package middleware import ( "crypto/rsa" "github.com/golang-jwt/jwt/v5" "github.com/labstack/echo/v4" "slices" "sorbet/pkg/env" "sorbet/pkg/ticket" ) var ticketPublicKey *rsa.PublicKey type TicketConfig struct { Skipper Skipper Anonymously bool Audiences []string PublicKey *rsa.PublicKey TicketFinder ticket.Finder ClaimsLooker func(c echo.Context, claims *ticket.Claims) error ErrorHandler func(c echo.Context, err error) error SuccessHandler func(c echo.Context) } func Ticket(anonymously bool, roles ...string) echo.MiddlewareFunc { return TicketWithConfig(TicketConfig{ Anonymously: anonymously, TicketFinder: ticket.DefaultFinder, ClaimsLooker: func(c echo.Context, claims *ticket.Claims) error { if len(roles) > 0 && slices.Contains(roles, claims.Role) { return nil } return ticket.ErrUnauthorized }, }) } func TicketWithConfig(config TicketConfig) echo.MiddlewareFunc { return config.ToMiddleware() } func (t *TicketConfig) ToMiddleware() echo.MiddlewareFunc { if t.Skipper == nil { t.Skipper = DefaultSkipper } if len(t.Audiences) == 0 { t.Audiences = append(t.Audiences, "*") } if t.TicketFinder == nil { t.TicketFinder = ticket.DefaultFinder } if t.ErrorHandler == nil { t.ErrorHandler = func(c echo.Context, err error) error { return err } } if t.ClaimsLooker == nil { t.ClaimsLooker = func(c echo.Context, claims *ticket.Claims) error { return nil } } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if t.Skipper(c) { return next(c) } ticketString := t.TicketFinder(c.Request()) if ticketString == "" { if t.Anonymously { return next(c) } return t.ErrorHandler(c, ticket.ErrNoTicketFound) } publicKey := t.PublicKey if publicKey == nil { key, err := getTicketPublicKey() if err != nil { return err } publicKey = key } claims, err := ticket.Verify(ticketString, publicKey, t.Audiences...) if err != nil { return t.ErrorHandler(c, err) } if err = t.ClaimsLooker(c, claims); err != nil { return t.ErrorHandler(c, err) } c.Set("ticket", ticketString) c.Set("ticket_claims", claims) return next(c) } } } func getTicketPublicKey() (*rsa.PublicKey, error) { if ticketPublicKey != nil { return ticketPublicKey, nil } var err error source := []byte(env.String("TICKET_PUBLIC_KEY")) ticketPublicKey, err = jwt.ParseRSAPublicKeyFromPEM(source) return ticketPublicKey, err }