You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
180 lines
4.4 KiB
180 lines
4.4 KiB
package ticket
|
|
|
|
import (
|
|
"crypto/rsa"
|
|
"encoding/base64"
|
|
"errors"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/rs/xid"
|
|
"net/http"
|
|
"slices"
|
|
"sorbet/pkg/env"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Sentinel errors reported by this package; compare with errors.Is.
var (
	// ErrNoTicketFound carries the message "no ticket found".
	ErrNoTicketFound = errors.New("no ticket found")
	// ErrInvalidTicket is returned by Verify when the token does not use an
	// RSA signing method or is not valid.
	ErrInvalidTicket = errors.New("invalid ticket")
	// ErrTicketExpired is returned by Verify when the JWT has expired.
	ErrTicketExpired = errors.New("ticket is expired")
	// ErrBadAudience is returned by Verify when none of the requested
	// audiences is granted by the ticket.
	ErrBadAudience = errors.New("bad audience")
	// ErrUnauthorized carries the message "ticket is unauthorized".
	ErrUnauthorized = errors.New("ticket is unauthorized")
)
|
|
|
|
// Claims is the JWT payload carried by a ticket: the registered JWT claims
// (jti, iss, sub, aud, exp, nbf, iat) plus the application-specific user ID
// and role. Field order determines the JSON encoding order of the payload.
type Claims struct {
	ID        string           `json:"jti,omitempty"` // ticket ID
	UID       uint             `json:"uid"`           // user ID
	Role      string           `json:"role"`          // account type
	Issuer    string           `json:"iss,omitempty"` // issuer
	Subject   string           `json:"sub,omitempty"` // subject
	Audience  jwt.ClaimStrings `json:"aud,omitempty"` // audience(s)
	ExpiresAt *jwt.NumericDate `json:"exp,omitempty"` // expiration time
	NotBefore *jwt.NumericDate `json:"nbf,omitempty"` // not valid before this time
	IssuedAt  *jwt.NumericDate `json:"iat,omitempty"` // issued-at time
}
|
|
|
|
// GetExpirationTime implements the Claims interface.
|
|
func (c *Claims) GetExpirationTime() (*jwt.NumericDate, error) {
|
|
return c.ExpiresAt, nil
|
|
}
|
|
|
|
// GetNotBefore implements the Claims interface.
|
|
func (c *Claims) GetNotBefore() (*jwt.NumericDate, error) {
|
|
return c.NotBefore, nil
|
|
}
|
|
|
|
// GetIssuedAt implements the Claims interface.
|
|
func (c *Claims) GetIssuedAt() (*jwt.NumericDate, error) {
|
|
return c.IssuedAt, nil
|
|
}
|
|
|
|
// GetAudience implements the Claims interface.
|
|
func (c *Claims) GetAudience() (jwt.ClaimStrings, error) {
|
|
return c.Audience, nil
|
|
}
|
|
|
|
// GetIssuer implements the Claims interface.
|
|
func (c *Claims) GetIssuer() (string, error) {
|
|
return c.Issuer, nil
|
|
}
|
|
|
|
// GetSubject implements the Claims interface.
|
|
func (c *Claims) GetSubject() (string, error) {
|
|
return c.Subject, nil
|
|
}
|
|
|
|
// Verify 验证令牌
|
|
func Verify(ticket string, key *rsa.PublicKey, audiences ...string) (*Claims, error) {
|
|
decodeString, err := base64.StdEncoding.DecodeString(ticket)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
claims := new(Claims)
|
|
token, err := jwt.ParseWithClaims(string(decodeString), claims, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
|
|
return nil, ErrInvalidTicket
|
|
}
|
|
return key, nil
|
|
})
|
|
if err != nil {
|
|
if errors.Is(err, jwt.ErrTokenExpired) {
|
|
return nil, ErrTicketExpired
|
|
}
|
|
return nil, err
|
|
}
|
|
if !token.Valid {
|
|
return nil, ErrInvalidTicket
|
|
}
|
|
var pass bool
|
|
for _, audience := range audiences {
|
|
for _, allow := range claims.Audience {
|
|
if audience == allow {
|
|
pass = true
|
|
goto NEXT
|
|
}
|
|
}
|
|
}
|
|
NEXT:
|
|
if !pass && slices.Contains(claims.Audience, "*") {
|
|
pass = true
|
|
}
|
|
if !pass {
|
|
return claims, ErrBadAudience
|
|
}
|
|
return claims, nil
|
|
}
|
|
|
|
// Create 创建令牌
|
|
//
|
|
// 如果参数
|
|
// claims
|
|
// 中未给出过期时间,将默认1小时过期
|
|
// 如果未给出令牌编号,则自动生成。
|
|
func Create(claims *Claims) (string, error) {
|
|
if claims.ID == "" {
|
|
claims.ID = xid.New().String()
|
|
}
|
|
if claims.Issuer == "" {
|
|
claims.Issuer = env.String("TICKET_ISSUER")
|
|
}
|
|
if claims.Subject == "" {
|
|
claims.Issuer = env.String("TICKET_SUBJECT")
|
|
}
|
|
if claims.Audience == nil {
|
|
claims.Audience = env.List("TICKET_AUDIENCE")
|
|
}
|
|
if claims.ExpiresAt == nil {
|
|
ttl := env.Duration("TICKET_TTL", time.Hour)
|
|
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(ttl))
|
|
}
|
|
source := []byte(env.String("TICKET_PRIVATE_KEY"))
|
|
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(source)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
|
signedString, err := token.SignedString(privateKey)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return base64.StdEncoding.EncodeToString([]byte(signedString)), nil
|
|
}
|
|
|
|
// Finder extracts a raw ticket string from an incoming request. It returns
// "" when the request carries no ticket.
type Finder func(req *http.Request) string
|
|
|
|
func DefaultFinder(r *http.Request) string {
|
|
if s := FromHeader(r); s != "" {
|
|
return s
|
|
}
|
|
if s := FromCookie(r); s != "" {
|
|
return s
|
|
}
|
|
return FromQuery(r)
|
|
}
|
|
|
|
// FromCookie returns the value of the request's "ticket" cookie, or "" when
// the cookie is absent.
func FromCookie(r *http.Request) string {
	if c, err := r.Cookie("ticket"); err == nil {
		return c.Value
	}
	return ""
}
|
|
|
|
// FromHeader extracts a bearer ticket from the Authorization header, in the
// form "Authorization: Bearer <ticket>". The scheme is matched
// case-insensitively and must be followed by a space. Whitespace around the
// ticket is trimmed. It returns "" when the header is missing, uses another
// scheme, or carries no ticket.
func FromHeader(r *http.Request) string {
	scheme, token, ok := strings.Cut(r.Header.Get("Authorization"), " ")
	// Requiring the space separator rejects values like "BearerXabc", which
	// slicing off a fixed 7-byte prefix would have accepted.
	if !ok || !strings.EqualFold(scheme, "Bearer") {
		return ""
	}
	return strings.TrimSpace(token)
}
|
|
|
|
// FromQuery returns the ticket from the request's query string. It checks
// each of keys in order and returns the first non-empty value. Failing that,
// it falls back to the "ticket" parameter; it returns "" if that is also
// empty or absent.
func FromQuery(r *http.Request, keys ...string) string {
	values := r.URL.Query()
	for i := range keys {
		if v := values.Get(keys[i]); v != "" {
			return v
		}
	}
	return values.Get("ticket")
}
|
|
|