parent
2aadccd5b8
commit
a8d03f5a8f
@ -1,104 +0,0 @@ |
||||
package middleware |
||||
|
||||
import ( |
||||
"crypto/rsa" |
||||
"github.com/golang-jwt/jwt/v5" |
||||
"github.com/labstack/echo/v4" |
||||
"slices" |
||||
"sorbet/pkg/env" |
||||
"sorbet/pkg/ticket" |
||||
) |
||||
|
||||
var ticketPublicKey *rsa.PublicKey |
||||
|
||||
type TicketConfig struct { |
||||
Skipper Skipper |
||||
Anonymously bool |
||||
Audiences []string |
||||
PublicKey *rsa.PublicKey |
||||
TicketFinder ticket.Finder |
||||
ClaimsLooker func(c echo.Context, claims *ticket.Claims) error |
||||
ErrorHandler func(c echo.Context, err error) error |
||||
SuccessHandler func(c echo.Context) |
||||
} |
||||
|
||||
func Ticket(anonymously bool, roles ...string) echo.MiddlewareFunc { |
||||
return TicketWithConfig(TicketConfig{ |
||||
Anonymously: anonymously, |
||||
TicketFinder: ticket.DefaultFinder, |
||||
ClaimsLooker: func(c echo.Context, claims *ticket.Claims) error { |
||||
if len(roles) > 0 && slices.Contains(roles, claims.Role) { |
||||
return nil |
||||
} |
||||
return ticket.ErrUnauthorized |
||||
}, |
||||
}) |
||||
} |
||||
|
||||
func TicketWithConfig(config TicketConfig) echo.MiddlewareFunc { |
||||
return config.ToMiddleware() |
||||
} |
||||
|
||||
func (t *TicketConfig) ToMiddleware() echo.MiddlewareFunc { |
||||
if t.Skipper == nil { |
||||
t.Skipper = DefaultSkipper |
||||
} |
||||
if len(t.Audiences) == 0 { |
||||
t.Audiences = append(t.Audiences, "*") |
||||
} |
||||
if t.TicketFinder == nil { |
||||
t.TicketFinder = ticket.DefaultFinder |
||||
} |
||||
if t.ErrorHandler == nil { |
||||
t.ErrorHandler = func(c echo.Context, err error) error { |
||||
return err |
||||
} |
||||
} |
||||
if t.ClaimsLooker == nil { |
||||
t.ClaimsLooker = func(c echo.Context, claims *ticket.Claims) error { |
||||
return nil |
||||
} |
||||
} |
||||
return func(next echo.HandlerFunc) echo.HandlerFunc { |
||||
return func(c echo.Context) error { |
||||
if t.Skipper(c) { |
||||
return next(c) |
||||
} |
||||
ticketString := t.TicketFinder(c.Request()) |
||||
if ticketString == "" { |
||||
if t.Anonymously { |
||||
return next(c) |
||||
} |
||||
return t.ErrorHandler(c, ticket.ErrNoTicketFound) |
||||
} |
||||
publicKey := t.PublicKey |
||||
if publicKey == nil { |
||||
key, err := getTicketPublicKey() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
publicKey = key |
||||
} |
||||
claims, err := ticket.Verify(ticketString, publicKey, t.Audiences...) |
||||
if err != nil { |
||||
return t.ErrorHandler(c, err) |
||||
} |
||||
if err = t.ClaimsLooker(c, claims); err != nil { |
||||
return t.ErrorHandler(c, err) |
||||
} |
||||
c.Set("ticket", ticketString) |
||||
c.Set("ticket_claims", claims) |
||||
return next(c) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func getTicketPublicKey() (*rsa.PublicKey, error) { |
||||
if ticketPublicKey != nil { |
||||
return ticketPublicKey, nil |
||||
} |
||||
var err error |
||||
source := []byte(env.String("TICKET_PUBLIC_KEY")) |
||||
ticketPublicKey, err = jwt.ParseRSAPublicKeyFromPEM(source) |
||||
return ticketPublicKey, err |
||||
} |
@ -1,38 +0,0 @@ |
||||
package util |
||||
|
||||
import ( |
||||
"github.com/golang-jwt/jwt/v5" |
||||
"github.com/rs/xid" |
||||
"sorbet/pkg/env" |
||||
"sorbet/pkg/ticket" |
||||
"time" |
||||
) |
||||
|
||||
func CreateTicket(claims *ticket.Claims) (string, error) { |
||||
if claims.ID == "" { |
||||
claims.ID = xid.New().String() |
||||
} |
||||
if claims.Issuer == "" { |
||||
claims.Issuer = env.String("TICKET_ISSUER") |
||||
} |
||||
if claims.Subject == "" { |
||||
claims.Issuer = env.String("TICKET_SUBJECT") |
||||
} |
||||
if claims.Audience == nil { |
||||
claims.Audience = env.List("TICKET_AUDIENCE") |
||||
} |
||||
if claims.ExpiresAt == nil { |
||||
ttl := env.Duration("TICKET_TTL", time.Hour) |
||||
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(ttl)) |
||||
} |
||||
source := []byte(env.String("TICKET_PRIVATE_KEY")) |
||||
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(source) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
signedString, err := ticket.Create(claims, privateKey) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
return signedString, nil |
||||
} |
@ -1,45 +0,0 @@ |
||||
package ticket |
||||
|
||||
import "github.com/golang-jwt/jwt/v5" |
||||
|
||||
type Claims struct { |
||||
ID string `json:"jti,omitempty"` // ticket编号
|
||||
UID uint `json:"uid"` // 用户编号
|
||||
Role string `json:"role"` // 账号类型
|
||||
Issuer string `json:"iss,omitempty"` // 签发人
|
||||
Subject string `json:"sub,omitempty"` // 主题
|
||||
Audience jwt.ClaimStrings `json:"aud,omitempty"` // 受众
|
||||
ExpiresAt *jwt.NumericDate `json:"exp,omitempty"` // 过期时间
|
||||
NotBefore *jwt.NumericDate `json:"nbf,omitempty"` // 生效时间
|
||||
IssuedAt *jwt.NumericDate `json:"iat,omitempty"` // 签发时间
|
||||
} |
||||
|
||||
// GetExpirationTime implements the Claims interface.
|
||||
func (c *Claims) GetExpirationTime() (*jwt.NumericDate, error) { |
||||
return c.ExpiresAt, nil |
||||
} |
||||
|
||||
// GetNotBefore implements the Claims interface.
|
||||
func (c *Claims) GetNotBefore() (*jwt.NumericDate, error) { |
||||
return c.NotBefore, nil |
||||
} |
||||
|
||||
// GetIssuedAt implements the Claims interface.
|
||||
func (c *Claims) GetIssuedAt() (*jwt.NumericDate, error) { |
||||
return c.IssuedAt, nil |
||||
} |
||||
|
||||
// GetAudience implements the Claims interface.
|
||||
func (c *Claims) GetAudience() (jwt.ClaimStrings, error) { |
||||
return c.Audience, nil |
||||
} |
||||
|
||||
// GetIssuer implements the Claims interface.
|
||||
func (c *Claims) GetIssuer() (string, error) { |
||||
return c.Issuer, nil |
||||
} |
||||
|
||||
// GetSubject implements the Claims interface.
|
||||
func (c *Claims) GetSubject() (string, error) { |
||||
return c.Subject, nil |
||||
} |
@ -0,0 +1,95 @@ |
||||
package ticket |
||||
|
||||
import ( |
||||
"crypto/rsa" |
||||
"github.com/golang-jwt/jwt/v5" |
||||
"github.com/labstack/echo/v4" |
||||
"slices" |
||||
"sorbet/pkg/env" |
||||
) |
||||
|
||||
type Config struct { |
||||
Skipper func(c echo.Context) bool |
||||
Anonymously bool // 是否允许匿名访问
|
||||
Audiences []string |
||||
PublicKey *rsa.PublicKey |
||||
TicketKey string |
||||
TicketClaimsKey string |
||||
TicketFinder Finder |
||||
ClaimsLooker func(c echo.Context, claims *Claims) error |
||||
ErrorHandler func(c echo.Context, err error) error |
||||
SuccessHandler func(c echo.Context) |
||||
} |
||||
|
||||
func Middleware(anonymously bool, roles ...string) echo.MiddlewareFunc { |
||||
return WithConfig(Config{ |
||||
Anonymously: anonymously, |
||||
TicketFinder: DefaultFinder, |
||||
ClaimsLooker: func(c echo.Context, claims *Claims) error { |
||||
if len(roles) > 0 && slices.Contains(roles, claims.Role) { |
||||
return nil |
||||
} |
||||
return ErrUnauthorized |
||||
}, |
||||
}) |
||||
} |
||||
|
||||
func WithConfig(config Config) echo.MiddlewareFunc { |
||||
return config.ToMiddleware() |
||||
} |
||||
|
||||
func (t Config) ToMiddleware() echo.MiddlewareFunc { |
||||
if t.Skipper == nil { |
||||
t.Skipper = func(c echo.Context) bool { return false } |
||||
} |
||||
if len(t.Audiences) == 0 { |
||||
t.Audiences = append(t.Audiences, "*") |
||||
} |
||||
if t.TicketFinder == nil { |
||||
t.TicketFinder = DefaultFinder |
||||
} |
||||
if t.ErrorHandler == nil { |
||||
t.ErrorHandler = func(c echo.Context, err error) error { return err } |
||||
} |
||||
if t.PublicKey == nil { |
||||
var err error |
||||
source := []byte(env.String("TICKET_PUBLIC_KEY")) |
||||
t.PublicKey, err = jwt.ParseRSAPublicKeyFromPEM(source) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
} |
||||
if t.TicketKey == "" { |
||||
t.TicketKey = "ticket" |
||||
} |
||||
if t.TicketClaimsKey == "" { |
||||
t.TicketClaimsKey = "ticket_claims" |
||||
} |
||||
return func(next echo.HandlerFunc) echo.HandlerFunc { |
||||
return func(c echo.Context) error { |
||||
if t.Skipper(c) { |
||||
return next(c) |
||||
} |
||||
ticketString := t.TicketFinder(c.Request()) |
||||
if ticketString == "" { |
||||
if t.Anonymously { |
||||
return next(c) |
||||
} |
||||
return t.ErrorHandler(c, ErrNoTicketFound) |
||||
} |
||||
claims, err := Verify(ticketString, t.PublicKey, t.Audiences...) |
||||
if err != nil { |
||||
return t.ErrorHandler(c, err) |
||||
} |
||||
if t.ClaimsLooker != nil { |
||||
err = t.ClaimsLooker(c, claims) |
||||
if err != nil { |
||||
return t.ErrorHandler(c, err) |
||||
} |
||||
} |
||||
c.Set(t.TicketKey, ticketString) |
||||
c.Set(t.TicketClaimsKey, claims) |
||||
return next(c) |
||||
} |
||||
} |
||||
} |
@ -1,45 +0,0 @@ |
||||
package ticket |
||||
|
||||
import ( |
||||
"net/http" |
||||
"strings" |
||||
) |
||||
|
||||
type Finder func(r *http.Request) string |
||||
|
||||
func DefaultFinder(r *http.Request) string { |
||||
if s := FromHeader(r); s != "" { |
||||
return s |
||||
} |
||||
if s := FromCookie(r); s != "" { |
||||
return s |
||||
} |
||||
return FromQuery(r) |
||||
} |
||||
|
||||
func FromCookie(r *http.Request) string { |
||||
cookie, err := r.Cookie("ticket") |
||||
if err != nil { |
||||
return "" |
||||
} |
||||
return cookie.Value |
||||
} |
||||
|
||||
func FromHeader(r *http.Request) string { |
||||
bearer := r.Header.Get("Authorization") |
||||
if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" { |
||||
return bearer[7:] |
||||
} |
||||
return "" |
||||
} |
||||
|
||||
func FromQuery(r *http.Request, keys ...string) string { |
||||
q := r.URL.Query() |
||||
for _, key := range keys { |
||||
s := q.Get(key) |
||||
if s != "" { |
||||
return s |
||||
} |
||||
} |
||||
return q.Get("ticket") |
||||
} |
Loading…
Reference in new issue