package ticket

import (
	"crypto/rsa"
	"encoding/base64"
	"errors"
	"fmt"
	"slices"
	"time"

	"github.com/golang-jwt/jwt/v5"
	"github.com/rs/xid"
)

var (
	// ErrUnauthorized is the sentinel for a ticket that is not authorized.
	ErrUnauthorized = errors.New("ticket is unauthorized")
	// ErrInvalidTicket is reported by Verify when a ticket cannot be decoded,
	// parsed or verified. Wrapped errors also carry the underlying cause.
	ErrInvalidTicket = errors.New("invalid ticket")
	// ErrTicketExpired is returned by Verify when the ticket has expired.
	ErrTicketExpired = errors.New("ticket is expired")
	// ErrNoTicketFound is the sentinel for a missing ticket.
	ErrNoTicketFound = errors.New("no ticket found")
	// ErrBadAudience is returned by Verify when none of the requested audiences
	// is listed in the ticket and the ticket is not issued for "*".
	ErrBadAudience = errors.New("bad audience")
	// ErrNilKey is returned by Create and Verify when the supplied RSA key is nil.
	ErrNilKey = errors.New("nil rsa key")
)

// Create signs claims with key using RS256 and returns the compact JWT,
// wrapped in an additional layer of standard base64 encoding.
//
// Create fills in missing fields on claims in place:
//   - if claims.ID is empty, a freshly generated xid is assigned;
//   - if claims.ExpiresAt is nil, it is set to one hour from now.
//
// It returns ErrNilKey if key is nil, and any signing error from the jwt library.
func Create(claims *Claims, key *rsa.PrivateKey) (string, error) {
	if key == nil {
		// A typed-nil key would pass jwt's type assertion and then be
		// dereferenced by crypto/rsa, panicking instead of returning an error.
		return "", ErrNilKey
	}
	if claims.ID == "" {
		claims.ID = xid.New().String()
	}
	if claims.ExpiresAt == nil {
		claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(time.Hour))
	}

	signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(key)
	if err != nil {
		return "", err
	}
	// The compact JWT is wrapped in standard base64; Verify strips this layer.
	return base64.StdEncoding.EncodeToString([]byte(signed)), nil
}

// Verify decodes and verifies a ticket produced by Create.
//
// ticket must be standard-base64 text wrapping a compact JWT that uses an RSA
// signing method and verifies against key. Results:
//   - ErrNilKey if key is nil;
//   - ErrTicketExpired if the jwt library reports the token as expired;
//   - an error wrapping ErrInvalidTicket (and the underlying cause, when there
//     is one) for any other decode, parse or verification failure;
//   - the parsed claims together with ErrBadAudience if the audience check
//     fails;
//   - the parsed claims and a nil error otherwise.
//
// The audience check passes if the ticket's audience list contains "*" or any
// of audiences. Consequently, calling Verify with no audiences accepts only
// tickets issued for "*".
func Verify(ticket string, key *rsa.PublicKey, audiences ...string) (*Claims, error) {
	if key == nil {
		// See Create: a typed-nil key would panic inside crypto/rsa.
		return nil, ErrNilKey
	}

	raw, err := base64.StdEncoding.DecodeString(ticket)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", ErrInvalidTicket, err)
	}

	claims := new(Claims)
	token, err := jwt.ParseWithClaims(string(raw), claims, func(token *jwt.Token) (any, error) {
		// Only accept RSA-signed tokens so the token cannot select a
		// different verification algorithm than the one this key is for.
		if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
			return nil, ErrInvalidTicket
		}
		return key, nil
	})
	if err != nil {
		if errors.Is(err, jwt.ErrTokenExpired) {
			return nil, ErrTicketExpired
		}
		return nil, fmt.Errorf("%w: %w", ErrInvalidTicket, err)
	}
	if !token.Valid {
		return nil, ErrInvalidTicket
	}

	if !audienceAllowed(claims.Audience, audiences) {
		// Claims are still returned so callers can inspect the rejected ticket.
		return claims, ErrBadAudience
	}
	return claims, nil
}

// audienceAllowed reports whether granted (the ticket's audience list)
// contains the wildcard "*" or any of the requested audiences.
func audienceAllowed(granted []string, requested []string) bool {
	if slices.Contains(granted, "*") {
		return true
	}
	return slices.ContainsFunc(requested, func(a string) bool {
		return slices.Contains(granted, a)
	})
}