parent
2aadccd5b8
commit
a8d03f5a8f
@ -1,104 +0,0 @@ |
|||||||
package middleware |
|
||||||
|
|
||||||
import ( |
|
||||||
"crypto/rsa" |
|
||||||
"github.com/golang-jwt/jwt/v5" |
|
||||||
"github.com/labstack/echo/v4" |
|
||||||
"slices" |
|
||||||
"sorbet/pkg/env" |
|
||||||
"sorbet/pkg/ticket" |
|
||||||
) |
|
||||||
|
|
||||||
var ticketPublicKey *rsa.PublicKey |
|
||||||
|
|
||||||
type TicketConfig struct { |
|
||||||
Skipper Skipper |
|
||||||
Anonymously bool |
|
||||||
Audiences []string |
|
||||||
PublicKey *rsa.PublicKey |
|
||||||
TicketFinder ticket.Finder |
|
||||||
ClaimsLooker func(c echo.Context, claims *ticket.Claims) error |
|
||||||
ErrorHandler func(c echo.Context, err error) error |
|
||||||
SuccessHandler func(c echo.Context) |
|
||||||
} |
|
||||||
|
|
||||||
func Ticket(anonymously bool, roles ...string) echo.MiddlewareFunc { |
|
||||||
return TicketWithConfig(TicketConfig{ |
|
||||||
Anonymously: anonymously, |
|
||||||
TicketFinder: ticket.DefaultFinder, |
|
||||||
ClaimsLooker: func(c echo.Context, claims *ticket.Claims) error { |
|
||||||
if len(roles) > 0 && slices.Contains(roles, claims.Role) { |
|
||||||
return nil |
|
||||||
} |
|
||||||
return ticket.ErrUnauthorized |
|
||||||
}, |
|
||||||
}) |
|
||||||
} |
|
||||||
|
|
||||||
func TicketWithConfig(config TicketConfig) echo.MiddlewareFunc { |
|
||||||
return config.ToMiddleware() |
|
||||||
} |
|
||||||
|
|
||||||
func (t *TicketConfig) ToMiddleware() echo.MiddlewareFunc { |
|
||||||
if t.Skipper == nil { |
|
||||||
t.Skipper = DefaultSkipper |
|
||||||
} |
|
||||||
if len(t.Audiences) == 0 { |
|
||||||
t.Audiences = append(t.Audiences, "*") |
|
||||||
} |
|
||||||
if t.TicketFinder == nil { |
|
||||||
t.TicketFinder = ticket.DefaultFinder |
|
||||||
} |
|
||||||
if t.ErrorHandler == nil { |
|
||||||
t.ErrorHandler = func(c echo.Context, err error) error { |
|
||||||
return err |
|
||||||
} |
|
||||||
} |
|
||||||
if t.ClaimsLooker == nil { |
|
||||||
t.ClaimsLooker = func(c echo.Context, claims *ticket.Claims) error { |
|
||||||
return nil |
|
||||||
} |
|
||||||
} |
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc { |
|
||||||
return func(c echo.Context) error { |
|
||||||
if t.Skipper(c) { |
|
||||||
return next(c) |
|
||||||
} |
|
||||||
ticketString := t.TicketFinder(c.Request()) |
|
||||||
if ticketString == "" { |
|
||||||
if t.Anonymously { |
|
||||||
return next(c) |
|
||||||
} |
|
||||||
return t.ErrorHandler(c, ticket.ErrNoTicketFound) |
|
||||||
} |
|
||||||
publicKey := t.PublicKey |
|
||||||
if publicKey == nil { |
|
||||||
key, err := getTicketPublicKey() |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
publicKey = key |
|
||||||
} |
|
||||||
claims, err := ticket.Verify(ticketString, publicKey, t.Audiences...) |
|
||||||
if err != nil { |
|
||||||
return t.ErrorHandler(c, err) |
|
||||||
} |
|
||||||
if err = t.ClaimsLooker(c, claims); err != nil { |
|
||||||
return t.ErrorHandler(c, err) |
|
||||||
} |
|
||||||
c.Set("ticket", ticketString) |
|
||||||
c.Set("ticket_claims", claims) |
|
||||||
return next(c) |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func getTicketPublicKey() (*rsa.PublicKey, error) { |
|
||||||
if ticketPublicKey != nil { |
|
||||||
return ticketPublicKey, nil |
|
||||||
} |
|
||||||
var err error |
|
||||||
source := []byte(env.String("TICKET_PUBLIC_KEY")) |
|
||||||
ticketPublicKey, err = jwt.ParseRSAPublicKeyFromPEM(source) |
|
||||||
return ticketPublicKey, err |
|
||||||
} |
|
@ -1,38 +0,0 @@ |
|||||||
package util |
|
||||||
|
|
||||||
import ( |
|
||||||
"github.com/golang-jwt/jwt/v5" |
|
||||||
"github.com/rs/xid" |
|
||||||
"sorbet/pkg/env" |
|
||||||
"sorbet/pkg/ticket" |
|
||||||
"time" |
|
||||||
) |
|
||||||
|
|
||||||
func CreateTicket(claims *ticket.Claims) (string, error) { |
|
||||||
if claims.ID == "" { |
|
||||||
claims.ID = xid.New().String() |
|
||||||
} |
|
||||||
if claims.Issuer == "" { |
|
||||||
claims.Issuer = env.String("TICKET_ISSUER") |
|
||||||
} |
|
||||||
if claims.Subject == "" { |
|
||||||
claims.Issuer = env.String("TICKET_SUBJECT") |
|
||||||
} |
|
||||||
if claims.Audience == nil { |
|
||||||
claims.Audience = env.List("TICKET_AUDIENCE") |
|
||||||
} |
|
||||||
if claims.ExpiresAt == nil { |
|
||||||
ttl := env.Duration("TICKET_TTL", time.Hour) |
|
||||||
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(ttl)) |
|
||||||
} |
|
||||||
source := []byte(env.String("TICKET_PRIVATE_KEY")) |
|
||||||
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(source) |
|
||||||
if err != nil { |
|
||||||
return "", err |
|
||||||
} |
|
||||||
signedString, err := ticket.Create(claims, privateKey) |
|
||||||
if err != nil { |
|
||||||
return "", err |
|
||||||
} |
|
||||||
return signedString, nil |
|
||||||
} |
|
@ -1,45 +0,0 @@ |
|||||||
package ticket |
|
||||||
|
|
||||||
import "github.com/golang-jwt/jwt/v5" |
|
||||||
|
|
||||||
type Claims struct { |
|
||||||
ID string `json:"jti,omitempty"` // ticket编号
|
|
||||||
UID uint `json:"uid"` // 用户编号
|
|
||||||
Role string `json:"role"` // 账号类型
|
|
||||||
Issuer string `json:"iss,omitempty"` // 签发人
|
|
||||||
Subject string `json:"sub,omitempty"` // 主题
|
|
||||||
Audience jwt.ClaimStrings `json:"aud,omitempty"` // 受众
|
|
||||||
ExpiresAt *jwt.NumericDate `json:"exp,omitempty"` // 过期时间
|
|
||||||
NotBefore *jwt.NumericDate `json:"nbf,omitempty"` // 生效时间
|
|
||||||
IssuedAt *jwt.NumericDate `json:"iat,omitempty"` // 签发时间
|
|
||||||
} |
|
||||||
|
|
||||||
// GetExpirationTime implements the Claims interface.
|
|
||||||
func (c *Claims) GetExpirationTime() (*jwt.NumericDate, error) { |
|
||||||
return c.ExpiresAt, nil |
|
||||||
} |
|
||||||
|
|
||||||
// GetNotBefore implements the Claims interface.
|
|
||||||
func (c *Claims) GetNotBefore() (*jwt.NumericDate, error) { |
|
||||||
return c.NotBefore, nil |
|
||||||
} |
|
||||||
|
|
||||||
// GetIssuedAt implements the Claims interface.
|
|
||||||
func (c *Claims) GetIssuedAt() (*jwt.NumericDate, error) { |
|
||||||
return c.IssuedAt, nil |
|
||||||
} |
|
||||||
|
|
||||||
// GetAudience implements the Claims interface.
|
|
||||||
func (c *Claims) GetAudience() (jwt.ClaimStrings, error) { |
|
||||||
return c.Audience, nil |
|
||||||
} |
|
||||||
|
|
||||||
// GetIssuer implements the Claims interface.
|
|
||||||
func (c *Claims) GetIssuer() (string, error) { |
|
||||||
return c.Issuer, nil |
|
||||||
} |
|
||||||
|
|
||||||
// GetSubject implements the Claims interface.
|
|
||||||
func (c *Claims) GetSubject() (string, error) { |
|
||||||
return c.Subject, nil |
|
||||||
} |
|
@ -0,0 +1,95 @@ |
|||||||
|
package ticket |
||||||
|
|
||||||
|
import ( |
||||||
|
"crypto/rsa" |
||||||
|
"github.com/golang-jwt/jwt/v5" |
||||||
|
"github.com/labstack/echo/v4" |
||||||
|
"slices" |
||||||
|
"sorbet/pkg/env" |
||||||
|
) |
||||||
|
|
||||||
|
type Config struct { |
||||||
|
Skipper func(c echo.Context) bool |
||||||
|
Anonymously bool // 是否允许匿名访问
|
||||||
|
Audiences []string |
||||||
|
PublicKey *rsa.PublicKey |
||||||
|
TicketKey string |
||||||
|
TicketClaimsKey string |
||||||
|
TicketFinder Finder |
||||||
|
ClaimsLooker func(c echo.Context, claims *Claims) error |
||||||
|
ErrorHandler func(c echo.Context, err error) error |
||||||
|
SuccessHandler func(c echo.Context) |
||||||
|
} |
||||||
|
|
||||||
|
func Middleware(anonymously bool, roles ...string) echo.MiddlewareFunc { |
||||||
|
return WithConfig(Config{ |
||||||
|
Anonymously: anonymously, |
||||||
|
TicketFinder: DefaultFinder, |
||||||
|
ClaimsLooker: func(c echo.Context, claims *Claims) error { |
||||||
|
if len(roles) > 0 && slices.Contains(roles, claims.Role) { |
||||||
|
return nil |
||||||
|
} |
||||||
|
return ErrUnauthorized |
||||||
|
}, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func WithConfig(config Config) echo.MiddlewareFunc { |
||||||
|
return config.ToMiddleware() |
||||||
|
} |
||||||
|
|
||||||
|
func (t Config) ToMiddleware() echo.MiddlewareFunc { |
||||||
|
if t.Skipper == nil { |
||||||
|
t.Skipper = func(c echo.Context) bool { return false } |
||||||
|
} |
||||||
|
if len(t.Audiences) == 0 { |
||||||
|
t.Audiences = append(t.Audiences, "*") |
||||||
|
} |
||||||
|
if t.TicketFinder == nil { |
||||||
|
t.TicketFinder = DefaultFinder |
||||||
|
} |
||||||
|
if t.ErrorHandler == nil { |
||||||
|
t.ErrorHandler = func(c echo.Context, err error) error { return err } |
||||||
|
} |
||||||
|
if t.PublicKey == nil { |
||||||
|
var err error |
||||||
|
source := []byte(env.String("TICKET_PUBLIC_KEY")) |
||||||
|
t.PublicKey, err = jwt.ParseRSAPublicKeyFromPEM(source) |
||||||
|
if err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
} |
||||||
|
if t.TicketKey == "" { |
||||||
|
t.TicketKey = "ticket" |
||||||
|
} |
||||||
|
if t.TicketClaimsKey == "" { |
||||||
|
t.TicketClaimsKey = "ticket_claims" |
||||||
|
} |
||||||
|
return func(next echo.HandlerFunc) echo.HandlerFunc { |
||||||
|
return func(c echo.Context) error { |
||||||
|
if t.Skipper(c) { |
||||||
|
return next(c) |
||||||
|
} |
||||||
|
ticketString := t.TicketFinder(c.Request()) |
||||||
|
if ticketString == "" { |
||||||
|
if t.Anonymously { |
||||||
|
return next(c) |
||||||
|
} |
||||||
|
return t.ErrorHandler(c, ErrNoTicketFound) |
||||||
|
} |
||||||
|
claims, err := Verify(ticketString, t.PublicKey, t.Audiences...) |
||||||
|
if err != nil { |
||||||
|
return t.ErrorHandler(c, err) |
||||||
|
} |
||||||
|
if t.ClaimsLooker != nil { |
||||||
|
err = t.ClaimsLooker(c, claims) |
||||||
|
if err != nil { |
||||||
|
return t.ErrorHandler(c, err) |
||||||
|
} |
||||||
|
} |
||||||
|
c.Set(t.TicketKey, ticketString) |
||||||
|
c.Set(t.TicketClaimsKey, claims) |
||||||
|
return next(c) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
@ -1,45 +0,0 @@ |
|||||||
package ticket |
|
||||||
|
|
||||||
import ( |
|
||||||
"net/http" |
|
||||||
"strings" |
|
||||||
) |
|
||||||
|
|
||||||
type Finder func(r *http.Request) string |
|
||||||
|
|
||||||
func DefaultFinder(r *http.Request) string { |
|
||||||
if s := FromHeader(r); s != "" { |
|
||||||
return s |
|
||||||
} |
|
||||||
if s := FromCookie(r); s != "" { |
|
||||||
return s |
|
||||||
} |
|
||||||
return FromQuery(r) |
|
||||||
} |
|
||||||
|
|
||||||
func FromCookie(r *http.Request) string { |
|
||||||
cookie, err := r.Cookie("ticket") |
|
||||||
if err != nil { |
|
||||||
return "" |
|
||||||
} |
|
||||||
return cookie.Value |
|
||||||
} |
|
||||||
|
|
||||||
func FromHeader(r *http.Request) string { |
|
||||||
bearer := r.Header.Get("Authorization") |
|
||||||
if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" { |
|
||||||
return bearer[7:] |
|
||||||
} |
|
||||||
return "" |
|
||||||
} |
|
||||||
|
|
||||||
func FromQuery(r *http.Request, keys ...string) string { |
|
||||||
q := r.URL.Query() |
|
||||||
for _, key := range keys { |
|
||||||
s := q.Get(key) |
|
||||||
if s != "" { |
|
||||||
return s |
|
||||||
} |
|
||||||
} |
|
||||||
return q.Get("ticket") |
|
||||||
} |
|
Loading…
Reference in new issue