package ticket

import (
	"crypto/rsa"
	"fmt"
	"slices"

	"github.com/golang-jwt/jwt/v5"
	"github.com/labstack/echo/v4"

	"sorbet/pkg/env"
)

// Config configures the ticket-verification middleware built by ToMiddleware.
// Zero-valued fields are replaced with defaults in ToMiddleware.
type Config struct {
	// Skipper, when it returns true, passes the request straight to the next
	// handler without looking for a ticket. Defaults to never skipping.
	Skipper func(c echo.Context) bool

	// Anonymously reports whether anonymous access is allowed: when true, a
	// request on which TicketFinder finds no ticket is passed through instead
	// of being rejected with ErrNoTicketFound. A ticket that is present but
	// fails verification is still rejected.
	Anonymously bool

	// Audiences is forwarded to Verify. Defaults to []string{"*"}
	// (NOTE(review): presumably a wildcard audience — confirm in Verify).
	Audiences []string

	// PublicKey is the RSA key passed to Verify. When nil, it is parsed from
	// the PEM text in the TICKET_PUBLIC_KEY environment variable.
	PublicKey *rsa.PublicKey

	// TicketKey is the echo.Context key under which the raw ticket string is
	// stored after successful verification. Defaults to "ticket".
	TicketKey string

	// TicketClaimsKey is the echo.Context key under which the verified claims
	// are stored. Defaults to "ticket_claims".
	TicketClaimsKey string

	// TicketFinder extracts the ticket string from the request; an empty
	// result means "no ticket". Defaults to DefaultFinder.
	TicketFinder Finder

	// ClaimsLooker, if set, performs an extra check on the verified claims.
	// A non-nil error is routed through ErrorHandler.
	ClaimsLooker func(c echo.Context, claims *Claims) error

	// ErrorHandler converts middleware errors into the handler result.
	// Defaults to returning the error unchanged.
	ErrorHandler func(c echo.Context, err error) error

	// SuccessHandler, if set, is called after a ticket has been verified,
	// checked by ClaimsLooker and stored in the context, just before the
	// next handler runs. It is not called for skipped or anonymous requests.
	SuccessHandler func(c echo.Context)
}

// Middleware returns ticket middleware with default settings.
//
// anonymously controls whether requests without a ticket are let through.
// When roles is non-empty, a verified ticket is additionally required to carry
// one of the listed roles, otherwise ErrUnauthorized is returned. When roles is
// empty, any successfully verified ticket is accepted.
func Middleware(anonymously bool, roles ...string) echo.MiddlewareFunc {
	config := Config{
		Anonymously:  anonymously,
		TicketFinder: DefaultFinder,
	}
	if len(roles) > 0 {
		// Copy so later mutation of a caller-provided slice (roles...) cannot
		// change the authorization rules of an already-built middleware.
		allowed := slices.Clone(roles)
		config.ClaimsLooker = func(c echo.Context, claims *Claims) error {
			if slices.Contains(allowed, claims.Role) {
				return nil
			}
			return ErrUnauthorized
		}
	}
	return WithConfig(config)
}

// WithConfig returns ticket middleware built from config.
// It is shorthand for config.ToMiddleware().
func WithConfig(config Config) echo.MiddlewareFunc {
	return config.ToMiddleware()
}

// ToMiddleware fills in defaults for unset fields and returns the middleware.
//
// For each request the middleware finds the ticket, verifies it with Verify,
// runs ClaimsLooker (if any), stores the ticket string and claims in the
// context under TicketKey / TicketClaimsKey, calls SuccessHandler (if any) and
// finally invokes the next handler. Every failure is passed to ErrorHandler.
//
// It panics if PublicKey is nil and TICKET_PUBLIC_KEY does not contain a valid
// PEM-encoded RSA public key, so misconfiguration surfaces at startup.
func (t Config) ToMiddleware() echo.MiddlewareFunc {
	if t.Skipper == nil {
		t.Skipper = func(c echo.Context) bool { return false }
	}
	if len(t.Audiences) == 0 {
		// Assign a fresh slice rather than appending: appending to an empty
		// but non-nil caller slice could write into the caller's backing array.
		t.Audiences = []string{"*"}
	}
	if t.TicketFinder == nil {
		t.TicketFinder = DefaultFinder
	}
	if t.ErrorHandler == nil {
		t.ErrorHandler = func(c echo.Context, err error) error { return err }
	}
	if t.PublicKey == nil {
		source := []byte(env.String("TICKET_PUBLIC_KEY"))
		key, err := jwt.ParseRSAPublicKeyFromPEM(source)
		if err != nil {
			panic(fmt.Errorf("ticket: parsing TICKET_PUBLIC_KEY: %w", err))
		}
		t.PublicKey = key
	}
	if t.TicketKey == "" {
		t.TicketKey = "ticket"
	}
	if t.TicketClaimsKey == "" {
		t.TicketClaimsKey = "ticket_claims"
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if t.Skipper(c) {
				return next(c)
			}

			ticketString := t.TicketFinder(c.Request())
			if ticketString == "" {
				if t.Anonymously {
					return next(c)
				}
				return t.ErrorHandler(c, ErrNoTicketFound)
			}

			claims, err := Verify(ticketString, t.PublicKey, t.Audiences...)
			if err != nil {
				return t.ErrorHandler(c, err)
			}

			if t.ClaimsLooker != nil {
				if err := t.ClaimsLooker(c, claims); err != nil {
					return t.ErrorHandler(c, err)
				}
			}

			c.Set(t.TicketKey, ticketString)
			c.Set(t.TicketClaimsKey, claims)

			// Previously declared but never invoked; honor it now.
			if t.SuccessHandler != nil {
				t.SuccessHandler(c)
			}
			return next(c)
		}
	}
}