Compare commits
No commits in common. '95ed5bbdd4e7a2d6af94e7858901062439af71e3' and '2aadccd5b87051a8de1d89028d89a101b70b0db5' have entirely different histories.
95ed5bbdd4
...
2aadccd5b8
@ -1,11 +0,0 @@ |
|||||||
package commands |
|
||||||
|
|
||||||
import ( |
|
||||||
"github.com/urfave/cli/v2" |
|
||||||
) |
|
||||||
|
|
||||||
func CreateEntity(c *cli.Context) error { |
|
||||||
args := c.Args() |
|
||||||
|
|
||||||
args.Get(0) |
|
||||||
} |
|
@ -0,0 +1,104 @@ |
|||||||
|
package middleware |
||||||
|
|
||||||
|
import ( |
||||||
|
"crypto/rsa" |
||||||
|
"github.com/golang-jwt/jwt/v5" |
||||||
|
"github.com/labstack/echo/v4" |
||||||
|
"slices" |
||||||
|
"sorbet/pkg/env" |
||||||
|
"sorbet/pkg/ticket" |
||||||
|
) |
||||||
|
|
||||||
|
var ticketPublicKey *rsa.PublicKey |
||||||
|
|
||||||
|
type TicketConfig struct { |
||||||
|
Skipper Skipper |
||||||
|
Anonymously bool |
||||||
|
Audiences []string |
||||||
|
PublicKey *rsa.PublicKey |
||||||
|
TicketFinder ticket.Finder |
||||||
|
ClaimsLooker func(c echo.Context, claims *ticket.Claims) error |
||||||
|
ErrorHandler func(c echo.Context, err error) error |
||||||
|
SuccessHandler func(c echo.Context) |
||||||
|
} |
||||||
|
|
||||||
|
func Ticket(anonymously bool, roles ...string) echo.MiddlewareFunc { |
||||||
|
return TicketWithConfig(TicketConfig{ |
||||||
|
Anonymously: anonymously, |
||||||
|
TicketFinder: ticket.DefaultFinder, |
||||||
|
ClaimsLooker: func(c echo.Context, claims *ticket.Claims) error { |
||||||
|
if len(roles) > 0 && slices.Contains(roles, claims.Role) { |
||||||
|
return nil |
||||||
|
} |
||||||
|
return ticket.ErrUnauthorized |
||||||
|
}, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func TicketWithConfig(config TicketConfig) echo.MiddlewareFunc { |
||||||
|
return config.ToMiddleware() |
||||||
|
} |
||||||
|
|
||||||
|
func (t *TicketConfig) ToMiddleware() echo.MiddlewareFunc { |
||||||
|
if t.Skipper == nil { |
||||||
|
t.Skipper = DefaultSkipper |
||||||
|
} |
||||||
|
if len(t.Audiences) == 0 { |
||||||
|
t.Audiences = append(t.Audiences, "*") |
||||||
|
} |
||||||
|
if t.TicketFinder == nil { |
||||||
|
t.TicketFinder = ticket.DefaultFinder |
||||||
|
} |
||||||
|
if t.ErrorHandler == nil { |
||||||
|
t.ErrorHandler = func(c echo.Context, err error) error { |
||||||
|
return err |
||||||
|
} |
||||||
|
} |
||||||
|
if t.ClaimsLooker == nil { |
||||||
|
t.ClaimsLooker = func(c echo.Context, claims *ticket.Claims) error { |
||||||
|
return nil |
||||||
|
} |
||||||
|
} |
||||||
|
return func(next echo.HandlerFunc) echo.HandlerFunc { |
||||||
|
return func(c echo.Context) error { |
||||||
|
if t.Skipper(c) { |
||||||
|
return next(c) |
||||||
|
} |
||||||
|
ticketString := t.TicketFinder(c.Request()) |
||||||
|
if ticketString == "" { |
||||||
|
if t.Anonymously { |
||||||
|
return next(c) |
||||||
|
} |
||||||
|
return t.ErrorHandler(c, ticket.ErrNoTicketFound) |
||||||
|
} |
||||||
|
publicKey := t.PublicKey |
||||||
|
if publicKey == nil { |
||||||
|
key, err := getTicketPublicKey() |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
publicKey = key |
||||||
|
} |
||||||
|
claims, err := ticket.Verify(ticketString, publicKey, t.Audiences...) |
||||||
|
if err != nil { |
||||||
|
return t.ErrorHandler(c, err) |
||||||
|
} |
||||||
|
if err = t.ClaimsLooker(c, claims); err != nil { |
||||||
|
return t.ErrorHandler(c, err) |
||||||
|
} |
||||||
|
c.Set("ticket", ticketString) |
||||||
|
c.Set("ticket_claims", claims) |
||||||
|
return next(c) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func getTicketPublicKey() (*rsa.PublicKey, error) { |
||||||
|
if ticketPublicKey != nil { |
||||||
|
return ticketPublicKey, nil |
||||||
|
} |
||||||
|
var err error |
||||||
|
source := []byte(env.String("TICKET_PUBLIC_KEY")) |
||||||
|
ticketPublicKey, err = jwt.ParseRSAPublicKeyFromPEM(source) |
||||||
|
return ticketPublicKey, err |
||||||
|
} |
@ -0,0 +1,38 @@ |
|||||||
|
package util |
||||||
|
|
||||||
|
import ( |
||||||
|
"github.com/golang-jwt/jwt/v5" |
||||||
|
"github.com/rs/xid" |
||||||
|
"sorbet/pkg/env" |
||||||
|
"sorbet/pkg/ticket" |
||||||
|
"time" |
||||||
|
) |
||||||
|
|
||||||
|
func CreateTicket(claims *ticket.Claims) (string, error) { |
||||||
|
if claims.ID == "" { |
||||||
|
claims.ID = xid.New().String() |
||||||
|
} |
||||||
|
if claims.Issuer == "" { |
||||||
|
claims.Issuer = env.String("TICKET_ISSUER") |
||||||
|
} |
||||||
|
if claims.Subject == "" { |
||||||
|
claims.Issuer = env.String("TICKET_SUBJECT") |
||||||
|
} |
||||||
|
if claims.Audience == nil { |
||||||
|
claims.Audience = env.List("TICKET_AUDIENCE") |
||||||
|
} |
||||||
|
if claims.ExpiresAt == nil { |
||||||
|
ttl := env.Duration("TICKET_TTL", time.Hour) |
||||||
|
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(ttl)) |
||||||
|
} |
||||||
|
source := []byte(env.String("TICKET_PRIVATE_KEY")) |
||||||
|
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(source) |
||||||
|
if err != nil { |
||||||
|
return "", err |
||||||
|
} |
||||||
|
signedString, err := ticket.Create(claims, privateKey) |
||||||
|
if err != nil { |
||||||
|
return "", err |
||||||
|
} |
||||||
|
return signedString, nil |
||||||
|
} |
@ -0,0 +1,45 @@ |
|||||||
|
package ticket |
||||||
|
|
||||||
|
import "github.com/golang-jwt/jwt/v5" |
||||||
|
|
||||||
|
type Claims struct { |
||||||
|
ID string `json:"jti,omitempty"` // ticket编号
|
||||||
|
UID uint `json:"uid"` // 用户编号
|
||||||
|
Role string `json:"role"` // 账号类型
|
||||||
|
Issuer string `json:"iss,omitempty"` // 签发人
|
||||||
|
Subject string `json:"sub,omitempty"` // 主题
|
||||||
|
Audience jwt.ClaimStrings `json:"aud,omitempty"` // 受众
|
||||||
|
ExpiresAt *jwt.NumericDate `json:"exp,omitempty"` // 过期时间
|
||||||
|
NotBefore *jwt.NumericDate `json:"nbf,omitempty"` // 生效时间
|
||||||
|
IssuedAt *jwt.NumericDate `json:"iat,omitempty"` // 签发时间
|
||||||
|
} |
||||||
|
|
||||||
|
// GetExpirationTime implements the Claims interface.
|
||||||
|
func (c *Claims) GetExpirationTime() (*jwt.NumericDate, error) { |
||||||
|
return c.ExpiresAt, nil |
||||||
|
} |
||||||
|
|
||||||
|
// GetNotBefore implements the Claims interface.
|
||||||
|
func (c *Claims) GetNotBefore() (*jwt.NumericDate, error) { |
||||||
|
return c.NotBefore, nil |
||||||
|
} |
||||||
|
|
||||||
|
// GetIssuedAt implements the Claims interface.
|
||||||
|
func (c *Claims) GetIssuedAt() (*jwt.NumericDate, error) { |
||||||
|
return c.IssuedAt, nil |
||||||
|
} |
||||||
|
|
||||||
|
// GetAudience implements the Claims interface.
|
||||||
|
func (c *Claims) GetAudience() (jwt.ClaimStrings, error) { |
||||||
|
return c.Audience, nil |
||||||
|
} |
||||||
|
|
||||||
|
// GetIssuer implements the Claims interface.
|
||||||
|
func (c *Claims) GetIssuer() (string, error) { |
||||||
|
return c.Issuer, nil |
||||||
|
} |
||||||
|
|
||||||
|
// GetSubject implements the Claims interface.
|
||||||
|
func (c *Claims) GetSubject() (string, error) { |
||||||
|
return c.Subject, nil |
||||||
|
} |
@ -1,95 +0,0 @@ |
|||||||
package ticket |
|
||||||
|
|
||||||
import ( |
|
||||||
"crypto/rsa" |
|
||||||
"github.com/golang-jwt/jwt/v5" |
|
||||||
"github.com/labstack/echo/v4" |
|
||||||
"slices" |
|
||||||
"sorbet/pkg/env" |
|
||||||
) |
|
||||||
|
|
||||||
type Config struct { |
|
||||||
Skipper func(c echo.Context) bool |
|
||||||
Anonymously bool // 是否允许匿名访问
|
|
||||||
Audiences []string |
|
||||||
PublicKey *rsa.PublicKey |
|
||||||
TicketKey string |
|
||||||
TicketClaimsKey string |
|
||||||
TicketFinder Finder |
|
||||||
ClaimsLooker func(c echo.Context, claims *Claims) error |
|
||||||
ErrorHandler func(c echo.Context, err error) error |
|
||||||
SuccessHandler func(c echo.Context) |
|
||||||
} |
|
||||||
|
|
||||||
func Middleware(anonymously bool, roles ...string) echo.MiddlewareFunc { |
|
||||||
return WithConfig(Config{ |
|
||||||
Anonymously: anonymously, |
|
||||||
TicketFinder: DefaultFinder, |
|
||||||
ClaimsLooker: func(c echo.Context, claims *Claims) error { |
|
||||||
if len(roles) > 0 && slices.Contains(roles, claims.Role) { |
|
||||||
return nil |
|
||||||
} |
|
||||||
return ErrUnauthorized |
|
||||||
}, |
|
||||||
}) |
|
||||||
} |
|
||||||
|
|
||||||
func WithConfig(config Config) echo.MiddlewareFunc { |
|
||||||
return config.ToMiddleware() |
|
||||||
} |
|
||||||
|
|
||||||
func (t Config) ToMiddleware() echo.MiddlewareFunc { |
|
||||||
if t.Skipper == nil { |
|
||||||
t.Skipper = func(c echo.Context) bool { return false } |
|
||||||
} |
|
||||||
if len(t.Audiences) == 0 { |
|
||||||
t.Audiences = append(t.Audiences, "*") |
|
||||||
} |
|
||||||
if t.TicketFinder == nil { |
|
||||||
t.TicketFinder = DefaultFinder |
|
||||||
} |
|
||||||
if t.ErrorHandler == nil { |
|
||||||
t.ErrorHandler = func(c echo.Context, err error) error { return err } |
|
||||||
} |
|
||||||
if t.PublicKey == nil { |
|
||||||
var err error |
|
||||||
source := []byte(env.String("TICKET_PUBLIC_KEY")) |
|
||||||
t.PublicKey, err = jwt.ParseRSAPublicKeyFromPEM(source) |
|
||||||
if err != nil { |
|
||||||
panic(err) |
|
||||||
} |
|
||||||
} |
|
||||||
if t.TicketKey == "" { |
|
||||||
t.TicketKey = "ticket" |
|
||||||
} |
|
||||||
if t.TicketClaimsKey == "" { |
|
||||||
t.TicketClaimsKey = "ticket_claims" |
|
||||||
} |
|
||||||
return func(next echo.HandlerFunc) echo.HandlerFunc { |
|
||||||
return func(c echo.Context) error { |
|
||||||
if t.Skipper(c) { |
|
||||||
return next(c) |
|
||||||
} |
|
||||||
ticketString := t.TicketFinder(c.Request()) |
|
||||||
if ticketString == "" { |
|
||||||
if t.Anonymously { |
|
||||||
return next(c) |
|
||||||
} |
|
||||||
return t.ErrorHandler(c, ErrNoTicketFound) |
|
||||||
} |
|
||||||
claims, err := Verify(ticketString, t.PublicKey, t.Audiences...) |
|
||||||
if err != nil { |
|
||||||
return t.ErrorHandler(c, err) |
|
||||||
} |
|
||||||
if t.ClaimsLooker != nil { |
|
||||||
err = t.ClaimsLooker(c, claims) |
|
||||||
if err != nil { |
|
||||||
return t.ErrorHandler(c, err) |
|
||||||
} |
|
||||||
} |
|
||||||
c.Set(t.TicketKey, ticketString) |
|
||||||
c.Set(t.TicketClaimsKey, claims) |
|
||||||
return next(c) |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
@ -0,0 +1,45 @@ |
|||||||
|
package ticket |
||||||
|
|
||||||
|
import ( |
||||||
|
"net/http" |
||||||
|
"strings" |
||||||
|
) |
||||||
|
|
||||||
|
type Finder func(r *http.Request) string |
||||||
|
|
||||||
|
func DefaultFinder(r *http.Request) string { |
||||||
|
if s := FromHeader(r); s != "" { |
||||||
|
return s |
||||||
|
} |
||||||
|
if s := FromCookie(r); s != "" { |
||||||
|
return s |
||||||
|
} |
||||||
|
return FromQuery(r) |
||||||
|
} |
||||||
|
|
||||||
|
func FromCookie(r *http.Request) string { |
||||||
|
cookie, err := r.Cookie("ticket") |
||||||
|
if err != nil { |
||||||
|
return "" |
||||||
|
} |
||||||
|
return cookie.Value |
||||||
|
} |
||||||
|
|
||||||
|
func FromHeader(r *http.Request) string { |
||||||
|
bearer := r.Header.Get("Authorization") |
||||||
|
if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" { |
||||||
|
return bearer[7:] |
||||||
|
} |
||||||
|
return "" |
||||||
|
} |
||||||
|
|
||||||
|
func FromQuery(r *http.Request, keys ...string) string { |
||||||
|
q := r.URL.Query() |
||||||
|
for _, key := range keys { |
||||||
|
s := q.Get(key) |
||||||
|
if s != "" { |
||||||
|
return s |
||||||
|
} |
||||||
|
} |
||||||
|
return q.Get("ticket") |
||||||
|
} |
Loading…
Reference in new issue