package ticket import ( "crypto/rsa" "encoding/base64" "errors" "github.com/golang-jwt/jwt/v5" "github.com/rs/xid" "net/http" "slices" "sorbet/pkg/env" "strings" "time" ) var ( ErrUnauthorized = errors.New("ticket is unauthorized") ErrInvalidTicket = errors.New("invalid ticket") ErrTicketExpired = errors.New("ticket is expired") ErrNoTicketFound = errors.New("no ticket found") ErrBadAudience = errors.New("bad audience") ) type Claims struct { ID string `json:"jti,omitempty"` // ticket编号 UID uint `json:"uid"` // 用户编号 Role string `json:"role"` // 账号类型 Issuer string `json:"iss,omitempty"` // 签发人 Subject string `json:"sub,omitempty"` // 主题 Audience jwt.ClaimStrings `json:"aud,omitempty"` // 受众 ExpiresAt *jwt.NumericDate `json:"exp,omitempty"` // 过期时间 NotBefore *jwt.NumericDate `json:"nbf,omitempty"` // 生效时间 IssuedAt *jwt.NumericDate `json:"iat,omitempty"` // 签发时间 } // GetExpirationTime implements the Claims interface. func (c *Claims) GetExpirationTime() (*jwt.NumericDate, error) { return c.ExpiresAt, nil } // GetNotBefore implements the Claims interface. func (c *Claims) GetNotBefore() (*jwt.NumericDate, error) { return c.NotBefore, nil } // GetIssuedAt implements the Claims interface. func (c *Claims) GetIssuedAt() (*jwt.NumericDate, error) { return c.IssuedAt, nil } // GetAudience implements the Claims interface. func (c *Claims) GetAudience() (jwt.ClaimStrings, error) { return c.Audience, nil } // GetIssuer implements the Claims interface. func (c *Claims) GetIssuer() (string, error) { return c.Issuer, nil } // GetSubject implements the Claims interface. func (c *Claims) GetSubject() (string, error) { return c.Subject, nil } // Verify 验证令牌 func Verify(ticket string, key *rsa.PublicKey, audiences ...string) (*Claims, error) { decodeString, err := base64.StdEncoding.DecodeString(ticket) if err != nil { return nil, err } claims := new(Claims) token, err := jwt.ParseWithClaims(string(decodeString), claims, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { return nil, ErrInvalidTicket } return key, nil }) if err != nil { if errors.Is(err, jwt.ErrTokenExpired) { return nil, ErrTicketExpired } return nil, err } if !token.Valid { return nil, ErrInvalidTicket } var pass bool for _, audience := range audiences { for _, allow := range claims.Audience { if audience == allow { pass = true goto NEXT } } } NEXT: if !pass && slices.Contains(claims.Audience, "*") { pass = true } if !pass { return claims, ErrBadAudience } return claims, nil } // Create 创建令牌 // // 如果参数 // claims // 中未给出过期时间,将默认1小时过期 // 如果未给出令牌编号,则自动生成。 func Create(claims *Claims) (string, error) { if claims.ID == "" { claims.ID = xid.New().String() } if claims.Issuer == "" { claims.Issuer = env.String("TICKET_ISSUER") } if claims.Subject == "" { claims.Issuer = env.String("TICKET_SUBJECT") } if claims.Audience == nil { claims.Audience = env.List("TICKET_AUDIENCE") } if claims.ExpiresAt == nil { ttl := env.Duration("TICKET_TTL", time.Hour) claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(ttl)) } source := []byte(env.String("TICKET_PRIVATE_KEY")) privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(source) if err != nil { return "", err } token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) signedString, err := token.SignedString(privateKey) if err != nil { return "", err } return base64.StdEncoding.EncodeToString([]byte(signedString)), nil } type Finder func(r *http.Request) string func DefaultFinder(r *http.Request) string { if s := FromHeader(r); s != "" { return s } if s := FromCookie(r); s != "" { return s } return FromQuery(r) } func FromCookie(r *http.Request) string { cookie, err := r.Cookie("ticket") if err != nil { return "" } return cookie.Value } func FromHeader(r *http.Request) string { bearer := r.Header.Get("Authorization") if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" { return bearer[7:] } return "" } func FromQuery(r *http.Request, keys ...string) string { q := r.URL.Query() for _, key := range keys { s := q.Get(key) if s != "" { return s } } return q.Get("ticket") }