diff --git a/pkg/ticket/claims.go b/pkg/ticket/claims.go new file mode 100644 index 0000000..6bd5b8f --- /dev/null +++ b/pkg/ticket/claims.go @@ -0,0 +1,45 @@ +package ticket + +import "github.com/golang-jwt/jwt/v5" + +type Claims struct { + ID string `json:"jti,omitempty"` // ticket编号 + UID uint `json:"uid"` // 用户编号 + Role string `json:"role"` // 账号类型 + Issuer string `json:"iss,omitempty"` // 签发人 + Subject string `json:"sub,omitempty"` // 主题 + Audience jwt.ClaimStrings `json:"aud,omitempty"` // 受众 + ExpiresAt *jwt.NumericDate `json:"exp,omitempty"` // 过期时间 + NotBefore *jwt.NumericDate `json:"nbf,omitempty"` // 生效时间 + IssuedAt *jwt.NumericDate `json:"iat,omitempty"` // 签发时间 +} + +// GetExpirationTime implements the Claims interface. +func (c *Claims) GetExpirationTime() (*jwt.NumericDate, error) { + return c.ExpiresAt, nil +} + +// GetNotBefore implements the Claims interface. +func (c *Claims) GetNotBefore() (*jwt.NumericDate, error) { + return c.NotBefore, nil +} + +// GetIssuedAt implements the Claims interface. +func (c *Claims) GetIssuedAt() (*jwt.NumericDate, error) { + return c.IssuedAt, nil +} + +// GetAudience implements the Claims interface. +func (c *Claims) GetAudience() (jwt.ClaimStrings, error) { + return c.Audience, nil +} + +// GetIssuer implements the Claims interface. +func (c *Claims) GetIssuer() (string, error) { + return c.Issuer, nil +} + +// GetSubject implements the Claims interface. +func (c *Claims) GetSubject() (string, error) { + return c.Subject, nil +} diff --git a/pkg/ticket/ticket.go b/pkg/ticket/ticket.go new file mode 100644 index 0000000..4b64f79 --- /dev/null +++ b/pkg/ticket/ticket.go @@ -0,0 +1,78 @@ +package ticket + +import ( + "crypto/rsa" + "encoding/base64" + "errors" + "github.com/golang-jwt/jwt/v5" + "github.com/rs/xid" + "slices" + "time" +) + +var ( + ErrUnauthorized = errors.New("ticket is unauthorized") + ErrInvalidTicket = errors.New("invalid ticket") + ErrTicketExpired = errors.New("ticket is expired") + ErrNoTicketFound = errors.New("no ticket found") + ErrBadAudience = errors.New("bad audience") +) + +// Create 创建令牌 +// 如果参数 claims 中未给出过期时间,将默认1小时过期 +// 如果未给出令牌编号,则自动生成。 +func Create(claims *Claims, key *rsa.PrivateKey) (string, error) { + if claims.ID == "" { + claims.ID = xid.New().String() + } + if claims.ExpiresAt == nil { + claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(time.Hour)) + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + signedString, err := token.SignedString(key) + if err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString([]byte(signedString)), nil +} + +// Verify 验证令牌 +func Verify(ticket string, key *rsa.PublicKey, audiences ...string) (*Claims, error) { + decodeString, err := base64.StdEncoding.DecodeString(ticket) + if err != nil { + return nil, err + } + claims := new(Claims) + token, err := jwt.ParseWithClaims(string(decodeString), claims, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, ErrInvalidTicket + } + return key, nil + }) + if err != nil { + if errors.Is(err, jwt.ErrTokenExpired) { + return nil, ErrTicketExpired + } + return nil, err + } + if !token.Valid { + return nil, ErrInvalidTicket + } + var pass bool + for _, audience := range audiences { + for _, allow := range claims.Audience { + if audience == allow { + pass = true + goto NEXT + } + } + } +NEXT: + if !pass && slices.Contains(claims.Audience, "*") { + pass = true + } + if !pass { + return claims, ErrBadAudience + } + return claims, nil +} diff --git a/pkg/ticket/util.go b/pkg/ticket/util.go new file mode 100644 index 0000000..c6d3ff0 --- /dev/null +++ b/pkg/ticket/util.go @@ -0,0 +1,45 @@ +package ticket + +import ( + "net/http" + "strings" +) + +type Finder func(r *http.Request) string + +func DefaultFinder(r *http.Request) string { + if s := FromHeader(r); s != "" { + return s + } + if s := FromCookie(r); s != "" { + return s + } + return FromQuery(r) +} + +func FromCookie(r *http.Request) string { + cookie, err := r.Cookie("ticket") + if err != nil { + return "" + } + return cookie.Value +} + +func FromHeader(r *http.Request) string { + bearer := r.Header.Get("Authorization") + if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" { + return bearer[7:] + } + return "" +} + +func FromQuery(r *http.Request, keys ...string) string { + q := r.URL.Query() + for _, key := range keys { + s := q.Get(key) + if s != "" { + return s + } + } + return q.Get("ticket") +}