You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
95 lines
2.3 KiB
95 lines
2.3 KiB
package ticket
|
|
|
|
import (
|
|
"crypto/rsa"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/labstack/echo/v4"
|
|
"slices"
|
|
"sorbet/pkg/env"
|
|
)
|
|
|
|
type Config struct {
|
|
Skipper func(c echo.Context) bool
|
|
Anonymously bool // 是否允许匿名访问
|
|
Audiences []string
|
|
PublicKey *rsa.PublicKey
|
|
TicketKey string
|
|
TicketClaimsKey string
|
|
TicketFinder Finder
|
|
ClaimsLooker func(c echo.Context, claims *Claims) error
|
|
ErrorHandler func(c echo.Context, err error) error
|
|
SuccessHandler func(c echo.Context)
|
|
}
|
|
|
|
func Middleware(anonymously bool, roles ...string) echo.MiddlewareFunc {
|
|
return WithConfig(Config{
|
|
Anonymously: anonymously,
|
|
TicketFinder: DefaultFinder,
|
|
ClaimsLooker: func(c echo.Context, claims *Claims) error {
|
|
if len(roles) > 0 && slices.Contains(roles, claims.Role) {
|
|
return nil
|
|
}
|
|
return ErrUnauthorized
|
|
},
|
|
})
|
|
}
|
|
|
|
func WithConfig(config Config) echo.MiddlewareFunc {
|
|
return config.ToMiddleware()
|
|
}
|
|
|
|
func (t Config) ToMiddleware() echo.MiddlewareFunc {
|
|
if t.Skipper == nil {
|
|
t.Skipper = func(c echo.Context) bool { return false }
|
|
}
|
|
if len(t.Audiences) == 0 {
|
|
t.Audiences = append(t.Audiences, "*")
|
|
}
|
|
if t.TicketFinder == nil {
|
|
t.TicketFinder = DefaultFinder
|
|
}
|
|
if t.ErrorHandler == nil {
|
|
t.ErrorHandler = func(c echo.Context, err error) error { return err }
|
|
}
|
|
if t.PublicKey == nil {
|
|
var err error
|
|
source := []byte(env.String("TICKET_PUBLIC_KEY"))
|
|
t.PublicKey, err = jwt.ParseRSAPublicKeyFromPEM(source)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
if t.TicketKey == "" {
|
|
t.TicketKey = "ticket"
|
|
}
|
|
if t.TicketClaimsKey == "" {
|
|
t.TicketClaimsKey = "ticket_claims"
|
|
}
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
|
return func(c echo.Context) error {
|
|
if t.Skipper(c) {
|
|
return next(c)
|
|
}
|
|
ticketString := t.TicketFinder(c.Request())
|
|
if ticketString == "" {
|
|
if t.Anonymously {
|
|
return next(c)
|
|
}
|
|
return t.ErrorHandler(c, ErrNoTicketFound)
|
|
}
|
|
claims, err := Verify(ticketString, t.PublicKey, t.Audiences...)
|
|
if err != nil {
|
|
return t.ErrorHandler(c, err)
|
|
}
|
|
if t.ClaimsLooker != nil {
|
|
err = t.ClaimsLooker(c, claims)
|
|
if err != nil {
|
|
return t.ErrorHandler(c, err)
|
|
}
|
|
}
|
|
c.Set(t.TicketKey, ticketString)
|
|
c.Set(t.TicketClaimsKey, claims)
|
|
return next(c)
|
|
}
|
|
}
|
|
}
|
|
|