parent
59ae7e72f0
commit
adb24eaf5d
@ -0,0 +1,45 @@ |
|||||||
|
package ticket |
||||||
|
|
||||||
|
import "github.com/golang-jwt/jwt/v5" |
||||||
|
|
||||||
|
type Claims struct { |
||||||
|
ID string `json:"jti,omitempty"` // ticket编号
|
||||||
|
UID uint `json:"uid"` // 用户编号
|
||||||
|
Role string `json:"role"` // 账号类型
|
||||||
|
Issuer string `json:"iss,omitempty"` // 签发人
|
||||||
|
Subject string `json:"sub,omitempty"` // 主题
|
||||||
|
Audience jwt.ClaimStrings `json:"aud,omitempty"` // 受众
|
||||||
|
ExpiresAt *jwt.NumericDate `json:"exp,omitempty"` // 过期时间
|
||||||
|
NotBefore *jwt.NumericDate `json:"nbf,omitempty"` // 生效时间
|
||||||
|
IssuedAt *jwt.NumericDate `json:"iat,omitempty"` // 签发时间
|
||||||
|
} |
||||||
|
|
||||||
|
// GetExpirationTime implements the Claims interface.
|
||||||
|
func (c *Claims) GetExpirationTime() (*jwt.NumericDate, error) { |
||||||
|
return c.ExpiresAt, nil |
||||||
|
} |
||||||
|
|
||||||
|
// GetNotBefore implements the Claims interface.
|
||||||
|
func (c *Claims) GetNotBefore() (*jwt.NumericDate, error) { |
||||||
|
return c.NotBefore, nil |
||||||
|
} |
||||||
|
|
||||||
|
// GetIssuedAt implements the Claims interface.
|
||||||
|
func (c *Claims) GetIssuedAt() (*jwt.NumericDate, error) { |
||||||
|
return c.IssuedAt, nil |
||||||
|
} |
||||||
|
|
||||||
|
// GetAudience implements the Claims interface.
|
||||||
|
func (c *Claims) GetAudience() (jwt.ClaimStrings, error) { |
||||||
|
return c.Audience, nil |
||||||
|
} |
||||||
|
|
||||||
|
// GetIssuer implements the Claims interface.
|
||||||
|
func (c *Claims) GetIssuer() (string, error) { |
||||||
|
return c.Issuer, nil |
||||||
|
} |
||||||
|
|
||||||
|
// GetSubject implements the Claims interface.
|
||||||
|
func (c *Claims) GetSubject() (string, error) { |
||||||
|
return c.Subject, nil |
||||||
|
} |
@ -0,0 +1,78 @@ |
|||||||
|
package ticket |
||||||
|
|
||||||
|
import ( |
||||||
|
"crypto/rsa" |
||||||
|
"encoding/base64" |
||||||
|
"errors" |
||||||
|
"github.com/golang-jwt/jwt/v5" |
||||||
|
"github.com/rs/xid" |
||||||
|
"slices" |
||||||
|
"time" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
ErrUnauthorized = errors.New("ticket is unauthorized") |
||||||
|
ErrInvalidTicket = errors.New("invalid ticket") |
||||||
|
ErrTicketExpired = errors.New("ticket is expired") |
||||||
|
ErrNoTicketFound = errors.New("no ticket found") |
||||||
|
ErrBadAudience = errors.New("bad audience") |
||||||
|
) |
||||||
|
|
||||||
|
// Create 创建令牌
|
||||||
|
// 如果参数 claims 中未给出过期时间,将默认1小时过期
|
||||||
|
// 如果未给出令牌编号,则自动生成。
|
||||||
|
func Create(claims *Claims, key *rsa.PrivateKey) (string, error) { |
||||||
|
if claims.ID == "" { |
||||||
|
claims.ID = xid.New().String() |
||||||
|
} |
||||||
|
if claims.ExpiresAt == nil { |
||||||
|
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(time.Hour)) |
||||||
|
} |
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) |
||||||
|
signedString, err := token.SignedString(key) |
||||||
|
if err != nil { |
||||||
|
return "", err |
||||||
|
} |
||||||
|
return base64.StdEncoding.EncodeToString([]byte(signedString)), nil |
||||||
|
} |
||||||
|
|
||||||
|
// Verify 验证令牌
|
||||||
|
func Verify(ticket string, key *rsa.PublicKey, audiences ...string) (*Claims, error) { |
||||||
|
decodeString, err := base64.StdEncoding.DecodeString(ticket) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
claims := new(Claims) |
||||||
|
token, err := jwt.ParseWithClaims(string(decodeString), claims, func(token *jwt.Token) (interface{}, error) { |
||||||
|
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { |
||||||
|
return nil, ErrInvalidTicket |
||||||
|
} |
||||||
|
return key, nil |
||||||
|
}) |
||||||
|
if err != nil { |
||||||
|
if errors.Is(err, jwt.ErrTokenExpired) { |
||||||
|
return nil, ErrTicketExpired |
||||||
|
} |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
if !token.Valid { |
||||||
|
return nil, ErrInvalidTicket |
||||||
|
} |
||||||
|
var pass bool |
||||||
|
for _, audience := range audiences { |
||||||
|
for _, allow := range claims.Audience { |
||||||
|
if audience == allow { |
||||||
|
pass = true |
||||||
|
goto NEXT |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
NEXT: |
||||||
|
if !pass && slices.Contains(claims.Audience, "*") { |
||||||
|
pass = true |
||||||
|
} |
||||||
|
if !pass { |
||||||
|
return claims, ErrBadAudience |
||||||
|
} |
||||||
|
return claims, nil |
||||||
|
} |
@ -0,0 +1,45 @@ |
|||||||
|
package ticket |
||||||
|
|
||||||
|
import ( |
||||||
|
"net/http" |
||||||
|
"strings" |
||||||
|
) |
||||||
|
|
||||||
|
type Finder func(r *http.Request) string |
||||||
|
|
||||||
|
func DefaultFinder(r *http.Request) string { |
||||||
|
if s := FromHeader(r); s != "" { |
||||||
|
return s |
||||||
|
} |
||||||
|
if s := FromCookie(r); s != "" { |
||||||
|
return s |
||||||
|
} |
||||||
|
return FromQuery(r) |
||||||
|
} |
||||||
|
|
||||||
|
func FromCookie(r *http.Request) string { |
||||||
|
cookie, err := r.Cookie("ticket") |
||||||
|
if err != nil { |
||||||
|
return "" |
||||||
|
} |
||||||
|
return cookie.Value |
||||||
|
} |
||||||
|
|
||||||
|
func FromHeader(r *http.Request) string { |
||||||
|
bearer := r.Header.Get("Authorization") |
||||||
|
if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" { |
||||||
|
return bearer[7:] |
||||||
|
} |
||||||
|
return "" |
||||||
|
} |
||||||
|
|
||||||
|
func FromQuery(r *http.Request, keys ...string) string { |
||||||
|
q := r.URL.Query() |
||||||
|
for _, key := range keys { |
||||||
|
s := q.Get(key) |
||||||
|
if s != "" { |
||||||
|
return s |
||||||
|
} |
||||||
|
} |
||||||
|
return q.Get("ticket") |
||||||
|
} |
Loading…
Reference in new issue