go项目脚手架
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

104 lines
2.5 KiB

package middleware
import (
	"crypto/rsa"
	"fmt"
	"slices"
	"sync"

	"github.com/golang-jwt/jwt/v5"
	"github.com/labstack/echo/v4"

	"sorbet/pkg/env"
	"sorbet/pkg/ticket"
)
var (
	// ticketPublicKey caches the fallback ticket-verification key that
	// getTicketPublicKey parses from the TICKET_PUBLIC_KEY environment
	// variable. It is nil until getTicketPublicKey populates it.
	ticketPublicKey *rsa.PublicKey
)
// TicketConfig configures the ticket-verification middleware produced by
// ToMiddleware / TicketWithConfig. Any nil/empty field except PublicKey,
// Anonymously and SuccessHandler is replaced with a default by ToMiddleware.
type TicketConfig struct {
	// Skipper decides per request whether the middleware is bypassed
	// entirely; when it returns true the next handler runs unchecked.
	// Defaults to DefaultSkipper.
	Skipper Skipper

	// Anonymously lets requests that carry no ticket at all proceed to the
	// next handler. A ticket that is present must still verify.
	Anonymously bool

	// Audiences is passed to ticket.Verify. Defaults to []string{"*"};
	// NOTE(review): presumably "*" accepts any audience — confirm in
	// ticket.Verify.
	Audiences []string

	// PublicKey verifies ticket signatures. When nil, the key is loaded from
	// the TICKET_PUBLIC_KEY environment variable.
	PublicKey *rsa.PublicKey

	// TicketFinder extracts the raw ticket string from the request; an empty
	// result means "no ticket". Defaults to ticket.DefaultFinder.
	TicketFinder ticket.Finder

	// ClaimsLooker inspects the verified claims; a non-nil error rejects the
	// request via ErrorHandler. Defaults to accepting all claims.
	ClaimsLooker func(c echo.Context, claims *ticket.Claims) error

	// ErrorHandler converts a rejection (missing ticket, failed verification,
	// ClaimsLooker error) into the handler's return value. Defaults to
	// returning the error unchanged.
	ErrorHandler func(c echo.Context, err error) error

	// SuccessHandler is an optional callback for requests whose ticket was
	// accepted. NOTE(review): confirm that ToMiddleware actually invokes it.
	SuccessHandler func(c echo.Context)
}
// Ticket returns middleware that requires a verified ticket on each request.
//
// When anonymously is true, requests that carry no ticket at all are passed
// to the next handler; a ticket that is present must still verify.
//
// roles restricts access to tickets whose claims.Role is one of the given
// values. When no roles are given, any verified ticket is accepted.
func Ticket(anonymously bool, roles ...string) echo.MiddlewareFunc {
	// Copy the variadic slice so a caller-owned slice passed as roles... cannot
	// change the policy after the middleware has been built.
	allowed := slices.Clone(roles)
	return TicketWithConfig(TicketConfig{
		Anonymously:  anonymously,
		TicketFinder: ticket.DefaultFinder,
		ClaimsLooker: func(c echo.Context, claims *ticket.Claims) error {
			// An empty role list means "any authenticated caller"; without the
			// length check, slices.Contains would reject every request.
			if len(allowed) == 0 || slices.Contains(allowed, claims.Role) {
				return nil
			}
			return ticket.ErrUnauthorized
		},
	})
}
// TicketWithConfig builds the ticket-verification middleware described by
// config. The config is received by value, so the defaults ToMiddleware fills
// in are written to this function's private copy rather than to the caller's
// struct.
func TicketWithConfig(config TicketConfig) echo.MiddlewareFunc {
	cfg := &config
	return cfg.ToMiddleware()
}
// ToMiddleware turns the configuration into an echo middleware.
//
// Unset optional fields of t are filled with defaults. This mutates t, and the
// returned middleware keeps a reference to t. The defaults are:
//   - Skipper: DefaultSkipper
//   - Audiences: []string{"*"}
//   - TicketFinder: ticket.DefaultFinder
//   - ErrorHandler: returns the error unchanged
//   - ClaimsLooker: accepts every set of claims
//
// For each request that is not skipped, the middleware:
//  1. extracts the ticket with TicketFinder;
//  2. verifies it with t.PublicKey, or with the key from TICKET_PUBLIC_KEY
//     when t.PublicKey is nil;
//  3. runs ClaimsLooker;
//  4. stores the raw ticket and its claims on the context under "ticket"
//     and "ticket_claims";
//  5. calls SuccessHandler if set;
//  6. invokes the next handler.
//
// Missing tickets (unless Anonymously), verification failures and
// ClaimsLooker errors are routed through ErrorHandler.
func (t *TicketConfig) ToMiddleware() echo.MiddlewareFunc {
	if t.Skipper == nil {
		t.Skipper = DefaultSkipper
	}
	if len(t.Audiences) == 0 {
		// Assign a fresh slice rather than appending, so a zero-length slice
		// supplied by the caller never has its backing array written to.
		t.Audiences = []string{"*"}
	}
	if t.TicketFinder == nil {
		t.TicketFinder = ticket.DefaultFinder
	}
	if t.ErrorHandler == nil {
		t.ErrorHandler = func(c echo.Context, err error) error {
			return err
		}
	}
	if t.ClaimsLooker == nil {
		t.ClaimsLooker = func(c echo.Context, claims *ticket.Claims) error {
			return nil
		}
	}
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if t.Skipper(c) {
				return next(c)
			}
			ticketString := t.TicketFinder(c.Request())
			if ticketString == "" {
				// No ticket at all: allowed only in anonymous mode. In that
				// case the handler runs with no ticket data on the context.
				if t.Anonymously {
					return next(c)
				}
				return t.ErrorHandler(c, ticket.ErrNoTicketFound)
			}
			publicKey := t.PublicKey
			if publicKey == nil {
				key, err := getTicketPublicKey()
				if err != nil {
					// Failure to load the fallback key is returned as-is
					// rather than passed to ErrorHandler.
					// NOTE(review): confirm this is intended.
					return err
				}
				publicKey = key
			}
			claims, err := ticket.Verify(ticketString, publicKey, t.Audiences...)
			if err != nil {
				return t.ErrorHandler(c, err)
			}
			if err = t.ClaimsLooker(c, claims); err != nil {
				return t.ErrorHandler(c, err)
			}
			c.Set("ticket", ticketString)
			c.Set("ticket_claims", claims)
			// Give callers a hook once the ticket has been accepted, mirroring
			// ErrorHandler for the failure paths. It runs after the context
			// values are set so it can read them.
			if t.SuccessHandler != nil {
				t.SuccessHandler(c)
			}
			return next(c)
		}
	}
}
// ticketPublicKeyMu guards ticketPublicKey. getTicketPublicKey reads and
// lazily fills that variable from concurrently running request handlers.
var ticketPublicKeyMu sync.Mutex

// getTicketPublicKey returns the fallback RSA public key used when a
// TicketConfig has no PublicKey.
//
// On first use the key is parsed from the PEM text in the TICKET_PUBLIC_KEY
// environment variable and cached in ticketPublicKey. A failed parse is not
// cached, so a later call retries. On error the returned key is nil and the
// error wraps the parser's error.
func getTicketPublicKey() (*rsa.PublicKey, error) {
	ticketPublicKeyMu.Lock()
	defer ticketPublicKeyMu.Unlock()

	if ticketPublicKey != nil {
		return ticketPublicKey, nil
	}
	source := []byte(env.String("TICKET_PUBLIC_KEY"))
	key, err := jwt.ParseRSAPublicKeyFromPEM(source)
	if err != nil {
		return nil, fmt.Errorf("parsing TICKET_PUBLIC_KEY: %w", err)
	}
	ticketPublicKey = key
	return key, nil
}