go项目脚手架
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
sorbet/pkg/ticket/middleware.go

96 lines
2.3 KiB

package ticket
import (
"crypto/rsa"
"github.com/golang-jwt/jwt/v5"
"github.com/labstack/echo/v4"
"slices"
"sorbet/pkg/env"
)
// Config configures the ticket-verification middleware built by
// ToMiddleware. ToMiddleware fills in defaults for unset Skipper, Audiences,
// PublicKey, TicketKey, TicketClaimsKey, TicketFinder and ErrorHandler;
// Anonymously, ClaimsLooker and SuccessHandler have no defaults.
type Config struct {
	// Skipper, when it returns true, sends the request straight to the next
	// handler with no ticket check. Default: never skip.
	Skipper func(c echo.Context) bool
	// Anonymously reports whether anonymous access is allowed: requests that
	// carry no ticket at all are passed through unauthenticated. A ticket that
	// is present but fails verification is still rejected.
	Anonymously bool
	// Audiences is passed to Verify as the accepted audiences.
	// Default: []string{"*"} (presumably a wildcard — confirm in Verify).
	Audiences []string
	// PublicKey is the RSA key handed to Verify. Default: parsed from the PEM
	// text returned by env.String("TICKET_PUBLIC_KEY"); ToMiddleware panics if
	// that text does not parse.
	PublicKey *rsa.PublicKey
	// TicketKey is the echo context key under which the raw ticket string is
	// stored after successful verification. Default: "ticket".
	TicketKey string
	// TicketClaimsKey is the echo context key under which the verified claims
	// are stored. Default: "ticket_claims".
	TicketClaimsKey string
	// TicketFinder extracts the raw ticket string from the request; an empty
	// result means "no ticket". Default: DefaultFinder.
	TicketFinder Finder
	// ClaimsLooker is an optional extra check on verified claims; a non-nil
	// error is routed through ErrorHandler and the request is rejected.
	ClaimsLooker func(c echo.Context, claims *Claims) error
	// ErrorHandler turns a failure into the handler's return value.
	// Default: return err unchanged.
	ErrorHandler func(c echo.Context, err error) error
	// SuccessHandler is an optional hook for requests whose ticket has been
	// verified.
	SuccessHandler func(c echo.Context)
}
// Middleware returns ticket-verification middleware that locates tickets with
// DefaultFinder and otherwise uses the defaults applied by ToMiddleware.
//
// anonymously lets requests that carry no ticket pass through
// unauthenticated. roles restricts access to verified tickets whose
// Claims.Role is one of the listed roles; when no roles are given, any
// verified ticket is accepted. A role mismatch is reported as ErrUnauthorized.
func Middleware(anonymously bool, roles ...string) echo.MiddlewareFunc {
	// Copy the variadic slice so later mutation by the caller cannot change
	// the access policy captured by the closure below.
	allowed := slices.Clone(roles)
	return WithConfig(Config{
		Anonymously:  anonymously,
		TicketFinder: DefaultFinder,
		ClaimsLooker: func(c echo.Context, claims *Claims) error {
			// An empty role list means "no role restriction", not "deny all".
			if len(allowed) == 0 || slices.Contains(allowed, claims.Role) {
				return nil
			}
			return ErrUnauthorized
		},
	})
}
// WithConfig builds the ticket middleware described by config. It is
// shorthand for config.ToMiddleware(); unset fields receive their defaults.
func WithConfig(config Config) echo.MiddlewareFunc {
	mw := config.ToMiddleware()
	return mw
}
// ToMiddleware builds an echo middleware from t. The receiver is a copy, so
// defaults are filled into that copy without mutating the caller's Config:
// a never-skipping Skipper, Audiences {"*"}, DefaultFinder, an ErrorHandler
// that returns the error unchanged, a PublicKey parsed from
// env.String("TICKET_PUBLIC_KEY"), and the context keys "ticket" and
// "ticket_claims".
//
// It panics if no PublicKey is set and the configured PEM cannot be parsed,
// so misconfiguration surfaces when the middleware is built rather than on
// the first request.
//
// Per request: skipped requests go straight to next; a missing ticket goes
// to next only when Anonymously is set and otherwise fails with
// ErrNoTicketFound; a verification or ClaimsLooker failure goes through
// ErrorHandler. On success, the ticket string and claims are stored in the
// context, SuccessHandler (if set) is invoked, and next is called.
func (t Config) ToMiddleware() echo.MiddlewareFunc {
	if t.Skipper == nil {
		t.Skipper = func(c echo.Context) bool { return false }
	}
	if len(t.Audiences) == 0 {
		// Allocate a fresh slice: appending to the caller's zero-length slice
		// could write into its backing array and alias our policy with it.
		t.Audiences = []string{"*"}
	}
	if t.TicketFinder == nil {
		t.TicketFinder = DefaultFinder
	}
	if t.ErrorHandler == nil {
		t.ErrorHandler = func(c echo.Context, err error) error { return err }
	}
	if t.PublicKey == nil {
		source := []byte(env.String("TICKET_PUBLIC_KEY"))
		key, err := jwt.ParseRSAPublicKeyFromPEM(source)
		if err != nil {
			panic(err)
		}
		t.PublicKey = key
	}
	if t.TicketKey == "" {
		t.TicketKey = "ticket"
	}
	if t.TicketClaimsKey == "" {
		t.TicketClaimsKey = "ticket_claims"
	}
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if t.Skipper(c) {
				return next(c)
			}
			ticketString := t.TicketFinder(c.Request())
			if ticketString == "" {
				if t.Anonymously {
					return next(c)
				}
				return t.ErrorHandler(c, ErrNoTicketFound)
			}
			claims, err := Verify(ticketString, t.PublicKey, t.Audiences...)
			if err != nil {
				return t.ErrorHandler(c, err)
			}
			if t.ClaimsLooker != nil {
				if err := t.ClaimsLooker(c, claims); err != nil {
					return t.ErrorHandler(c, err)
				}
			}
			c.Set(t.TicketKey, ticketString)
			c.Set(t.TicketClaimsKey, claims)
			// Config exposes SuccessHandler; honor it once the request is
			// fully authenticated and the context is populated.
			if t.SuccessHandler != nil {
				t.SuccessHandler(c)
			}
			return next(c)
		}
	}
}