You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
104 lines
2.5 KiB
104 lines
2.5 KiB
package middleware
|
|
|
|
import (
	"crypto/rsa"
	"slices"
	"sync"

	"github.com/golang-jwt/jwt/v5"
	"github.com/labstack/echo/v4"

	"sorbet/pkg/env"
	"sorbet/pkg/ticket"
)
|
|
|
|
// ticketPublicKey is a process-wide cache for the RSA public key parsed from
// the TICKET_PUBLIC_KEY environment variable by getTicketPublicKey; once it is
// non-nil, getTicketPublicKey returns it without re-parsing.
var ticketPublicKey *rsa.PublicKey
|
|
|
|
// TicketConfig configures the ticket-verifying middleware built by
// ToMiddleware. Every field is optional; ToMiddleware fills in defaults for
// the ones noted below.
type TicketConfig struct {
	// Skipper, when it returns true for a request, bypasses the middleware
	// entirely. Defaults to DefaultSkipper.
	Skipper Skipper

	// Anonymously lets requests that carry no ticket through to the next
	// handler instead of failing with ticket.ErrNoTicketFound. A ticket that
	// is present must still verify.
	Anonymously bool

	// Audiences are passed to ticket.Verify. Defaults to []string{"*"}.
	Audiences []string

	// PublicKey verifies the ticket. When nil, the key is parsed from the
	// PEM held in the TICKET_PUBLIC_KEY environment variable.
	PublicKey *rsa.PublicKey

	// TicketFinder extracts the raw ticket from the request; an empty result
	// means "no ticket". Defaults to ticket.DefaultFinder.
	TicketFinder ticket.Finder

	// ClaimsLooker inspects the verified claims; a non-nil error rejects the
	// request via ErrorHandler. Defaults to accepting every claim set.
	ClaimsLooker func(c echo.Context, claims *ticket.Claims) error

	// ErrorHandler receives missing-ticket, verification and ClaimsLooker
	// errors and returns the error the middleware yields. Defaults to
	// returning err unchanged.
	ErrorHandler func(c echo.Context, err error) error

	// SuccessHandler is an optional hook meant to run once a ticket has been
	// verified and accepted by ClaimsLooker.
	SuccessHandler func(c echo.Context)
}
|
|
|
|
func Ticket(anonymously bool, roles ...string) echo.MiddlewareFunc {
|
|
return TicketWithConfig(TicketConfig{
|
|
Anonymously: anonymously,
|
|
TicketFinder: ticket.DefaultFinder,
|
|
ClaimsLooker: func(c echo.Context, claims *ticket.Claims) error {
|
|
if len(roles) > 0 && slices.Contains(roles, claims.Role) {
|
|
return nil
|
|
}
|
|
return ticket.ErrUnauthorized
|
|
},
|
|
})
|
|
}
|
|
|
|
func TicketWithConfig(config TicketConfig) echo.MiddlewareFunc {
|
|
return config.ToMiddleware()
|
|
}
|
|
|
|
func (t *TicketConfig) ToMiddleware() echo.MiddlewareFunc {
|
|
if t.Skipper == nil {
|
|
t.Skipper = DefaultSkipper
|
|
}
|
|
if len(t.Audiences) == 0 {
|
|
t.Audiences = append(t.Audiences, "*")
|
|
}
|
|
if t.TicketFinder == nil {
|
|
t.TicketFinder = ticket.DefaultFinder
|
|
}
|
|
if t.ErrorHandler == nil {
|
|
t.ErrorHandler = func(c echo.Context, err error) error {
|
|
return err
|
|
}
|
|
}
|
|
if t.ClaimsLooker == nil {
|
|
t.ClaimsLooker = func(c echo.Context, claims *ticket.Claims) error {
|
|
return nil
|
|
}
|
|
}
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
if t.Skipper(c) {
|
|
return next(c)
|
|
}
|
|
ticketString := t.TicketFinder(c.Request())
|
|
if ticketString == "" {
|
|
if t.Anonymously {
|
|
return next(c)
|
|
}
|
|
return t.ErrorHandler(c, ticket.ErrNoTicketFound)
|
|
}
|
|
publicKey := t.PublicKey
|
|
if publicKey == nil {
|
|
key, err := getTicketPublicKey()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
publicKey = key
|
|
}
|
|
claims, err := ticket.Verify(ticketString, publicKey, t.Audiences...)
|
|
if err != nil {
|
|
return t.ErrorHandler(c, err)
|
|
}
|
|
if err = t.ClaimsLooker(c, claims); err != nil {
|
|
return t.ErrorHandler(c, err)
|
|
}
|
|
c.Set("ticket", ticketString)
|
|
c.Set("ticket_claims", claims)
|
|
return next(c)
|
|
}
|
|
}
|
|
}
|
|
|
|
func getTicketPublicKey() (*rsa.PublicKey, error) {
|
|
if ticketPublicKey != nil {
|
|
return ticketPublicKey, nil
|
|
}
|
|
var err error
|
|
source := []byte(env.String("TICKET_PUBLIC_KEY"))
|
|
ticketPublicKey, err = jwt.ParseRSAPublicKeyFromPEM(source)
|
|
return ticketPublicKey, err
|
|
}
|
|
|