go项目脚手架
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
sorbet/pkg/ticket/ticket.go

180 lines
4.4 KiB

package ticket
import (
"crypto/rsa"
"encoding/base64"
"errors"
"github.com/golang-jwt/jwt/v5"
"github.com/rs/xid"
"net/http"
"slices"
"sorbet/pkg/env"
"strings"
"time"
)
// Sentinel errors reported by this package. They are distinct values, so
// callers can compare them with errors.Is.
var (
	// ErrNoTicketFound reports that no ticket could be located.
	ErrNoTicketFound = errors.New("no ticket found")
	// ErrInvalidTicket reports a ticket that failed validation, e.g. one not
	// signed with an RSA method.
	ErrInvalidTicket = errors.New("invalid ticket")
	// ErrTicketExpired reports a ticket whose expiration time has passed.
	ErrTicketExpired = errors.New("ticket is expired")
	// ErrBadAudience reports a ticket whose audiences were not accepted.
	ErrBadAudience = errors.New("bad audience")
	// ErrUnauthorized reports an unauthorized ticket.
	ErrUnauthorized = errors.New("ticket is unauthorized")
)
// Claims is the JWT payload carried by a ticket: the registered JWT claims
// (jti, iss, sub, aud, exp, nbf, iat) plus the application-specific user ID
// and role. Fields tagged omitempty are dropped from the JSON payload when
// they hold their zero value.
type Claims struct {
	ID        string           `json:"jti,omitempty"` // ticket ID (JWT "jti")
	UID       uint             `json:"uid"`           // user ID
	Role      string           `json:"role"`          // account type
	Issuer    string           `json:"iss,omitempty"` // issuer (JWT "iss")
	Subject   string           `json:"sub,omitempty"` // subject (JWT "sub")
	Audience  jwt.ClaimStrings `json:"aud,omitempty"` // audiences (JWT "aud")
	ExpiresAt *jwt.NumericDate `json:"exp,omitempty"` // expiration time (JWT "exp")
	NotBefore *jwt.NumericDate `json:"nbf,omitempty"` // not-valid-before time (JWT "nbf")
	IssuedAt  *jwt.NumericDate `json:"iat,omitempty"` // issue time (JWT "iat")
}
// GetExpirationTime implements the Claims interface.
func (c *Claims) GetExpirationTime() (*jwt.NumericDate, error) {
return c.ExpiresAt, nil
}
// GetNotBefore implements the Claims interface.
func (c *Claims) GetNotBefore() (*jwt.NumericDate, error) {
return c.NotBefore, nil
}
// GetIssuedAt implements the Claims interface.
func (c *Claims) GetIssuedAt() (*jwt.NumericDate, error) {
return c.IssuedAt, nil
}
// GetAudience implements the Claims interface.
func (c *Claims) GetAudience() (jwt.ClaimStrings, error) {
return c.Audience, nil
}
// GetIssuer implements the Claims interface.
func (c *Claims) GetIssuer() (string, error) {
return c.Issuer, nil
}
// GetSubject implements the Claims interface.
func (c *Claims) GetSubject() (string, error) {
return c.Subject, nil
}
// Verify decodes a ticket produced by Create, verifies its signature and
// expiry, and applies the audience policy.
//
// ticket must be the standard-base64 encoding of a compact JWT signed with an
// RSA method. key is the RSA public key used to verify the signature and must
// not be nil.
//
// Audience policy: the ticket passes when its "aud" claim contains the
// wildcard "*", or contains at least one of audiences. With no audiences
// given, only tickets carrying "*" pass.
//
// Errors:
//   - ErrTicketExpired when jwt reports the token as expired;
//   - ErrBadAudience when the audience policy rejects the ticket; the decoded
//     claims are still returned together with this error;
//   - ErrInvalidTicket when the parsed token is not valid (a non-RSA signing
//     method also yields ErrInvalidTicket from the key function, surfaced
//     inside the error returned by jwt);
//   - the underlying base64 or jwt error otherwise.
func Verify(ticket string, key *rsa.PublicKey, audiences ...string) (*Claims, error) {
	// Handing a typed-nil *rsa.PublicKey to jwt would make the RSA verifier
	// dereference it and panic; reject it up front instead.
	if key == nil {
		return nil, errors.New("nil public key")
	}
	raw, err := base64.StdEncoding.DecodeString(ticket)
	if err != nil {
		return nil, err
	}
	claims := new(Claims)
	token, err := jwt.ParseWithClaims(string(raw), claims, func(t *jwt.Token) (any, error) {
		// Only RSA signatures are accepted, which rules out "none" and HMAC
		// tokens forged with the public key as the shared secret.
		if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok {
			return nil, ErrInvalidTicket
		}
		return key, nil
	})
	if err != nil {
		if errors.Is(err, jwt.ErrTokenExpired) {
			return nil, ErrTicketExpired
		}
		return nil, err
	}
	if !token.Valid {
		return nil, ErrInvalidTicket
	}
	if !audienceAllowed(claims.Audience, audiences) {
		// The decoded claims accompany ErrBadAudience.
		return claims, ErrBadAudience
	}
	return claims, nil
}

// audienceAllowed reports whether granted (the ticket's "aud" claim) contains
// the wildcard "*" or any of the wanted audiences.
func audienceAllowed(granted, wanted []string) bool {
	if slices.Contains(granted, "*") {
		return true
	}
	return slices.ContainsFunc(wanted, func(a string) bool {
		return slices.Contains(granted, a)
	})
}
// Create signs claims as an RS256 JWT and returns it wrapped in standard
// base64, the form Verify decodes.
//
// claims must not be nil. Empty fields are filled in place, so the caller's
// struct is modified:
//   - ID: a newly generated xid;
//   - Issuer: env TICKET_ISSUER;
//   - Subject: env TICKET_SUBJECT;
//   - Audience (only when nil): env.List("TICKET_AUDIENCE");
//   - ExpiresAt: now plus env TICKET_TTL, with time.Hour passed as the
//     fallback duration.
//
// The signing key is parsed from the PEM in env TICKET_PRIVATE_KEY on every
// call; a key that fails to parse is returned as an error.
func Create(claims *Claims) (string, error) {
	if claims == nil {
		return "", errors.New("nil claims")
	}
	if claims.ID == "" {
		claims.ID = xid.New().String()
	}
	if claims.Issuer == "" {
		claims.Issuer = env.String("TICKET_ISSUER")
	}
	if claims.Subject == "" {
		claims.Subject = env.String("TICKET_SUBJECT")
	}
	if claims.Audience == nil {
		claims.Audience = env.List("TICKET_AUDIENCE")
	}
	if claims.ExpiresAt == nil {
		ttl := env.Duration("TICKET_TTL", time.Hour)
		claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(ttl))
	}
	privateKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(env.String("TICKET_PRIVATE_KEY")))
	if err != nil {
		return "", err
	}
	signed, err := jwt.NewWithClaims(jwt.SigningMethodRS256, claims).SignedString(privateKey)
	if err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString([]byte(signed)), nil
}
// Finder extracts a raw ticket string from an HTTP request. The finders in
// this package return "" when they find no ticket.
type Finder func(*http.Request) string
func DefaultFinder(r *http.Request) string {
if s := FromHeader(r); s != "" {
return s
}
if s := FromCookie(r); s != "" {
return s
}
return FromQuery(r)
}
// FromCookie returns the value of the request's "ticket" cookie, or "" when
// no such cookie is present.
func FromCookie(r *http.Request) string {
	if c, err := r.Cookie("ticket"); err == nil {
		return c.Value
	}
	return ""
}
// FromHeader returns the token from an "Authorization: Bearer <token>" header,
// or "" when the header is absent, uses another scheme, or carries no token.
// The scheme name is matched case-insensitively and must be followed by a
// space; whitespace around the token is trimmed.
func FromHeader(r *http.Request) string {
	scheme, token, ok := strings.Cut(r.Header.Get("Authorization"), " ")
	// Comparing the whole scheme (not just a 6-byte prefix) rejects values
	// such as "BearerXabc" that have no separator after the scheme.
	if !ok || !strings.EqualFold(scheme, "Bearer") {
		return ""
	}
	return strings.TrimSpace(token)
}
// FromQuery looks the ticket up in the URL query string. Each parameter name
// in keys is tried in order and the first non-empty value wins; if none
// match, the value of the "ticket" parameter (possibly "") is returned.
func FromQuery(r *http.Request, keys ...string) string {
	params := r.URL.Query()
	for i := range keys {
		if found := params.Get(keys[i]); len(found) > 0 {
			return found
		}
	}
	return params.Get("ticket")
}