feat: rsp 库支持 ticket

main
熊二 1 year ago
parent 2aadccd5b8
commit a8d03f5a8f
  1. 4
      internal/init.go
  2. 104
      internal/middleware/ticket.go
  3. 38
      internal/util/ticket.go
  4. 63
      pkg/rsp/error.go
  5. 130
      pkg/rsp/respond_utils.go
  6. 45
      pkg/ticket/claims.go
  7. 95
      pkg/ticket/middleware.go
  8. 134
      pkg/ticket/ticket.go
  9. 45
      pkg/ticket/util.go

@ -67,7 +67,7 @@ func Start() error {
if err != nil { if err != nil {
return err return err
} }
token, err := util.CreateTicket(&ticket.Claims{ token, err := ticket.Create(&ticket.Claims{
UID: company.ID, UID: company.ID,
Role: "system", Role: "system",
Issuer: "chshs", Issuer: "chshs",
@ -80,7 +80,7 @@ func Start() error {
"token": token, "token": token,
}) })
}) })
e.Group("", middleware.Ticket(false, "system")).GET("/u", func(c echo.Context) error { e.Group("", ticket.Middleware(false, "system")).GET("/u", func(c echo.Context) error {
return rsp.Ok(c, echo.Map{ return rsp.Ok(c, echo.Map{
"ticket": c.Get("ticket"), "ticket": c.Get("ticket"),
"claims": c.Get("ticket_claims").(*ticket.Claims), "claims": c.Get("ticket_claims").(*ticket.Claims),

@ -1,104 +0,0 @@
package middleware
import (
"crypto/rsa"
"github.com/golang-jwt/jwt/v5"
"github.com/labstack/echo/v4"
"slices"
"sorbet/pkg/env"
"sorbet/pkg/ticket"
)
var ticketPublicKey *rsa.PublicKey
type TicketConfig struct {
Skipper Skipper
Anonymously bool
Audiences []string
PublicKey *rsa.PublicKey
TicketFinder ticket.Finder
ClaimsLooker func(c echo.Context, claims *ticket.Claims) error
ErrorHandler func(c echo.Context, err error) error
SuccessHandler func(c echo.Context)
}
func Ticket(anonymously bool, roles ...string) echo.MiddlewareFunc {
return TicketWithConfig(TicketConfig{
Anonymously: anonymously,
TicketFinder: ticket.DefaultFinder,
ClaimsLooker: func(c echo.Context, claims *ticket.Claims) error {
if len(roles) > 0 && slices.Contains(roles, claims.Role) {
return nil
}
return ticket.ErrUnauthorized
},
})
}
func TicketWithConfig(config TicketConfig) echo.MiddlewareFunc {
return config.ToMiddleware()
}
func (t *TicketConfig) ToMiddleware() echo.MiddlewareFunc {
if t.Skipper == nil {
t.Skipper = DefaultSkipper
}
if len(t.Audiences) == 0 {
t.Audiences = append(t.Audiences, "*")
}
if t.TicketFinder == nil {
t.TicketFinder = ticket.DefaultFinder
}
if t.ErrorHandler == nil {
t.ErrorHandler = func(c echo.Context, err error) error {
return err
}
}
if t.ClaimsLooker == nil {
t.ClaimsLooker = func(c echo.Context, claims *ticket.Claims) error {
return nil
}
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if t.Skipper(c) {
return next(c)
}
ticketString := t.TicketFinder(c.Request())
if ticketString == "" {
if t.Anonymously {
return next(c)
}
return t.ErrorHandler(c, ticket.ErrNoTicketFound)
}
publicKey := t.PublicKey
if publicKey == nil {
key, err := getTicketPublicKey()
if err != nil {
return err
}
publicKey = key
}
claims, err := ticket.Verify(ticketString, publicKey, t.Audiences...)
if err != nil {
return t.ErrorHandler(c, err)
}
if err = t.ClaimsLooker(c, claims); err != nil {
return t.ErrorHandler(c, err)
}
c.Set("ticket", ticketString)
c.Set("ticket_claims", claims)
return next(c)
}
}
}
func getTicketPublicKey() (*rsa.PublicKey, error) {
if ticketPublicKey != nil {
return ticketPublicKey, nil
}
var err error
source := []byte(env.String("TICKET_PUBLIC_KEY"))
ticketPublicKey, err = jwt.ParseRSAPublicKeyFromPEM(source)
return ticketPublicKey, err
}

@ -1,38 +0,0 @@
package util
import (
"github.com/golang-jwt/jwt/v5"
"github.com/rs/xid"
"sorbet/pkg/env"
"sorbet/pkg/ticket"
"time"
)
func CreateTicket(claims *ticket.Claims) (string, error) {
if claims.ID == "" {
claims.ID = xid.New().String()
}
if claims.Issuer == "" {
claims.Issuer = env.String("TICKET_ISSUER")
}
if claims.Subject == "" {
claims.Issuer = env.String("TICKET_SUBJECT")
}
if claims.Audience == nil {
claims.Audience = env.List("TICKET_AUDIENCE")
}
if claims.ExpiresAt == nil {
ttl := env.Duration("TICKET_TTL", time.Hour)
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(ttl))
}
source := []byte(env.String("TICKET_PRIVATE_KEY"))
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(source)
if err != nil {
return "", err
}
signedString, err := ticket.Create(claims, privateKey)
if err != nil {
return "", err
}
return signedString, nil
}

@ -2,73 +2,75 @@ package rsp
import ( import (
"fmt" "fmt"
"github.com/labstack/echo/v4" "net/http"
"strings" "strings"
) )
var ( var (
// ErrOK 表示没有任何错误。 // ErrOK 表示没有任何错误。
// 对应 HTTP 响应状态码为 500。 // 对应 HTTP 响应状态码为 500。
ErrOK = NewError(0, "") ErrOK = NewError(http.StatusOK, 0, "OK")
// ErrInternal 客户端请求有效,但服务器处理时发生了意外。 // ErrInternal 客户端请求有效,但服务器处理时发生了意外。
// 对应 HTTP 响应状态码为 500。 // 对应 HTTP 响应状态码为 500。
ErrInternal = NewError(-100, "系统内部错误") ErrInternal = NewError(http.StatusInternalServerError, -100, "系统内部错误")
// ErrServiceUnavailable 服务器无法处理请求,一般用于网站维护状态。 // ErrServiceUnavailable 服务器无法处理请求,一般用于网站维护状态。
// 对应 HTTP 响应状态码为 503。 // 对应 HTTP 响应状态码为 503。
ErrServiceUnavailable = NewError(-101, "Service Unavailable") ErrServiceUnavailable = NewError(http.StatusServiceUnavailable, -101, "服务不可用")
// ErrUnauthorized 用户未提供身份验证凭据,或者没有通过身份验证。 // ErrUnauthorized 用户未提供身份验证凭据,或者没有通过身份验证。
// 响应的 HTTP 状态码为 401。 // 响应的 HTTP 状态码为 401。
ErrUnauthorized = NewError(-102, "unauthorized") ErrUnauthorized = NewError(http.StatusUnauthorized, -102, "身份验证失败")
// ErrForbidden 用户通过了身份验证,但是不具有访问资源所需的权限。 // ErrForbidden 用户通过了身份验证,但是不具有访问资源所需的权限。
// 响应的 HTTP 状态码为 403。 // 响应的 HTTP 状态码为 403。
ErrForbidden = NewError(-103, "Forbidden") ErrForbidden = NewError(http.StatusForbidden, -103, "不具有访问资源所需的权限")
// ErrGone 所请求的资源已从这个地址转移,不再可用。 // ErrGone 所请求的资源已从这个地址转移,不再可用。
// 响应的 HTTP 状态码为 410。 // 响应的 HTTP 状态码为 410。
ErrGone = NewError(-104, "Gone") ErrGone = NewError(http.StatusGone, -104, "所请求的资源不存在")
// ErrUnsupportedMediaType 客户端要求的返回格式不支持。 // ErrUnsupportedMediaType 客户端要求的返回格式不支持。
// 比如,API 只能返回 JSON 格式,但是客户端要求返回 XML 格式。 // 比如,API 只能返回 JSON 格式,但是客户端要求返回 XML 格式。
// 响应的 HTTP 状态码为 415。 // 响应的 HTTP 状态码为 415。
ErrUnsupportedMediaType = NewError(-105, "Unsupported Media Type") ErrUnsupportedMediaType = NewError(http.StatusUnsupportedMediaType, -105, "请求的数据格式错误")
// ErrUnprocessableEntity 无法处理客户端上传的附件,导致请求失败。 // ErrUnprocessableEntity 无法处理客户端上传的附件,导致请求失败。
// 响应的 HTTP 状态码为 422。 // 响应的 HTTP 状态码为 422。
ErrUnprocessableEntity = NewError(-106, "Unprocessable Entity") ErrUnprocessableEntity = NewError(http.StatusUnprocessableEntity, -106, "上传了不被支持的附件")
// ErrTooManyRequests 客户端的请求次数超过限额。 // ErrTooManyRequests 客户端的请求次数超过限额。
// 响应的 HTTP 状态码为 422 // 响应的 HTTP 状态码为 429
ErrTooManyRequests = NewError(-107, "Too Many Requests") ErrTooManyRequests = NewError(http.StatusTooManyRequests, -107, "请求次数超过限额")
// ErrSeeOther 表示需要参考另一个 URL 才能完成接收的请求操作, // ErrSeeOther 表示需要参考另一个 URL 才能完成接收的请求操作,
// 当请求方式使用 POST、PUT 和 DELETE 时,对应的 HTTP 状态码为 303, // 当请求方式使用 POST、PUT 和 DELETE 时,对应的 HTTP 状态码为 303,
// 其它的请求方式在大多数情况下应该使用 400 状态码。 // 其它的请求方式在大多数情况下应该使用 400 状态码。
ErrSeeOther = NewError(-108, "see other") ErrSeeOther = NewError(http.StatusSeeOther, -108, "需要更进一步才能完成操作")
// ErrBadRequest 服务器不理解客户端的请求。 // ErrBadRequest 服务器不理解客户端的请求。
// 对应 HTTP 状态码为 404 // 对应 HTTP 状态码为 400
ErrBadRequest = NewError(-109, "bad request") ErrBadRequest = NewError(http.StatusBadRequest, -109, "请求错误")
// ErrBadParams 客户端提交的参数不符合要求 // ErrBadParams 客户端提交的参数不符合要求
// 对应 HTTP 状态码为 400。 // 对应 HTTP 状态码为 400。
ErrBadParams = NewError(-110, "参数错误") ErrBadParams = NewError(http.StatusBadRequest, -110, "参数错误")
// ErrRecordNotFound 访问的数据不存在 // ErrRecordNotFound 访问的数据不存在
// 对应 HTTP 状态码为 404。 // 对应 HTTP 状态码为 404。
ErrRecordNotFound = NewError(-111, "record not found") ErrRecordNotFound = NewError(http.StatusNotFound, -111, "访问的数据不存在")
) )
type Error struct { type Error struct {
code int internal error
text string status int // HTTP 状态码
code int // 请求错误码
text string // 响应提示消息
} }
func NewError(code int, text string) *Error { func NewError(status, code int, text string) *Error {
return &Error{code, text} return &Error{nil, status, code, text}
} }
func (e *Error) Code() int { func (e *Error) Code() int {
@ -79,10 +81,24 @@ func (e *Error) Text() string {
return e.text return e.text
} }
func (e *Error) WithInternal(err error) *Error {
c := *e
c.internal = err
return &c
}
func (e *Error) WithStatus(status int) *Error {
c := *e
c.status = status
return &c
}
func (e *Error) WithText(text ...string) *Error { func (e *Error) WithText(text ...string) *Error {
for _, s := range text { for _, s := range text {
if s != "" { if s != "" {
return NewError(e.code, s) c := *e
c.text = s
return &c
} }
} }
return e return e
@ -97,11 +113,6 @@ func (e *Error) AsProblem(label string) *Problem {
} }
} }
func (e *Error) AsHttpError(code int) *echo.HTTPError {
he := echo.NewHTTPError(code, e.text)
return he.WithInternal(e)
}
func (e *Error) String() string { func (e *Error) String() string {
return strings.TrimSpace(fmt.Sprintf("%d %s", e.code, e.text)) return strings.TrimSpace(fmt.Sprintf("%d %s", e.code, e.text))
} }

@ -8,6 +8,7 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"net/http" "net/http"
"runtime" "runtime"
"sorbet/pkg/ticket"
"sorbet/pkg/v" "sorbet/pkg/v"
) )
@ -52,7 +53,7 @@ func (r *RespondData[T]) RespondValue() any {
} }
type response struct { type response struct {
code int status int
headers map[string]string headers map[string]string
cookies []*http.Cookie cookies []*http.Cookie
err error err error
@ -62,9 +63,9 @@ type response struct {
type Option func(o *response) type Option func(o *response)
func StatusCode(code int) Option { func StatusCode(status int) Option {
return func(o *response) { return func(o *response) {
o.code = code o.status = status
} }
} }
@ -103,57 +104,99 @@ func Data(data any) Option {
} }
} }
func respond(c echo.Context, o *response) error { func (r *response) result(c echo.Context) (m map[string]any, status int) {
defer func() { status = r.status
if o.err != nil { m = map[string]any{
c.Logger().Error(o.err)
}
}()
// 返回的数据
m := map[string]any{
"code": nil, "code": nil,
"success": false, "success": false,
"message": o.message, "message": r.message,
} }
var success bool isDebug := c.Echo().Debug
if err, ok := o.err.(*v.Errors); ok { var err error
if ee, ok := r.err.(*v.Errors); ok {
pb := &Problem{} pb := &Problem{}
for _, e := range err.All() { for _, e := range ee.All() {
pb.AddSubproblem(ErrBadParams.WithText(e.Error()).AsProblem(e.Field())) pb.AddSubproblem(ErrBadParams.WithText(e.Error()).AsProblem(e.Field()))
} }
m["code"] = ErrBadParams.code m["code"] = ErrBadParams.code
m["message"] = ErrBadParams.text m["message"] = ErrBadParams.text
m["problems"] = pb.Problems m["problems"] = pb.Problems
} else if ex, yes := o.err.(*Error); yes { if status < 400 {
status = ErrBadParams.status
}
} else if ex, yes := r.err.(*Error); yes {
err = ex.internal
m["code"] = ex.code m["code"] = ex.code
m["success"] = errors.Is(ex, ErrOK)
m["message"] = ex.text m["message"] = ex.text
success = errors.Is(ex, ErrOK) if status < 400 {
} else if pb, okay := o.err.(*Problem); okay { status = ex.status
}
} else if pb, okay := r.err.(*Problem); okay {
m["code"] = pb.Code m["code"] = pb.Code
m["message"] = pb.Message m["message"] = pb.Message
m["problems"] = pb.Problems m["problems"] = pb.Problems
} else if o.err != nil { if status < 400 {
m["code"] = ErrInternal.code status = ErrInternal.status
m["message"] = o.err.Error() }
} else { } else if r.err == nil {
success = true
m["code"] = 0 m["code"] = 0
if o.data != nil { m["success"] = true
if val, ok := o.data.(RespondValuer); ok { if r.data != nil {
if val, ok := r.data.(RespondValuer); ok {
m["data"] = val.RespondValue() m["data"] = val.RespondValue()
} else { } else {
m["data"] = o.data m["data"] = r.data
}
}
} else {
err = r.err
switch {
case errors.Is(r.err, ticket.ErrNoTicketFound),
errors.Is(r.err, ticket.ErrInvalidTicket),
errors.Is(r.err, ticket.ErrBadAudience),
errors.Is(r.err, ticket.ErrTicketExpired),
errors.Is(r.err, ticket.ErrUnauthorized):
m["code"] = ErrUnauthorized.Code()
if isDebug {
m["message"] = r.err.Error()
} else {
m["message"] = ErrUnauthorized.Code()
}
if status < 400 {
status = ErrUnauthorized.status
}
default:
var he *echo.HTTPError
if errors.As(err, &he) {
status = he.Code
err = he.Internal
}
m["code"] = ErrInternal.Code()
m["message"] = r.err.Error()
if status < 400 {
status = ErrInternal.status
} }
} }
} }
m["success"] = success if c.Echo().Debug && m["success"] != true {
if !success && c.Echo().Debug { if err != nil {
m["error"] = relevantCaller() m["error"] = err.Error()
}
m["stack"] = relevantCaller()
} }
if m["message"] == "" { if m["message"] == "" {
m["message"] = http.StatusText(o.code) m["message"] = http.StatusText(status)
} }
return
}
func respond(c echo.Context, o *response) error {
defer func() {
if o.err != nil {
c.Logger().Error(o.err)
}
}()
// 如果已经输出过,就忽略 // 如果已经输出过,就忽略
if c.Response().Committed { if c.Response().Committed {
return nil return nil
@ -171,10 +214,11 @@ func respond(c echo.Context, o *response) error {
c.SetCookie(cookie) c.SetCookie(cookie)
} }
} }
m, status := o.result(c)
// HEAD 请求没有结果 // HEAD 请求没有结果
r := c.Request() r := c.Request()
if r.Method == http.MethodHead { if r.Method == http.MethodHead {
return c.NoContent(o.code) return c.NoContent(status)
} }
// 根据报头响应不同的格式 // 根据报头响应不同的格式
accept := r.Header.Get(echo.HeaderAccept) accept := r.Header.Get(echo.HeaderAccept)
@ -183,28 +227,28 @@ func respond(c echo.Context, o *response) error {
if html, err := HtmlMarshaller(m); err != nil { if html, err := HtmlMarshaller(m); err != nil {
return err return err
} else { } else {
return c.HTML(o.code, html) return c.HTML(status, html)
} }
case "json": case "json":
return c.JSON(o.code, m) return c.JSON(status, m)
case "jsonp": case "jsonp":
qs := c.Request().URL.Query() qs := c.Request().URL.Query()
for _, name := range JsonpCallbacks { for _, name := range JsonpCallbacks {
if cb := qs.Get(name); cb != "" { if cb := qs.Get(name); cb != "" {
return c.JSONP(o.code, cb, m) return c.JSONP(status, cb, m)
} }
} }
return c.JSONP(o.code, DefaultJsonpCallback, m) return c.JSONP(status, DefaultJsonpCallback, m)
case "xml": case "xml":
return c.XML(o.code, m) return c.XML(status, m)
case "text", "text/*": case "text", "text/*":
if text, err := TextMarshaller(m); err != nil { if text, err := TextMarshaller(m); err != nil {
return err return err
} else { } else {
return c.String(o.code, text) return c.String(status, text)
} }
} }
return c.JSON(o.code, m) return c.JSON(status, m)
} }
func relevantCaller() []string { func relevantCaller() []string {
@ -222,7 +266,7 @@ func relevantCaller() []string {
} }
func Respond(c echo.Context, opts ...Option) error { func Respond(c echo.Context, opts ...Option) error {
o := response{code: http.StatusOK} o := response{status: http.StatusOK}
for _, option := range opts { for _, option := range opts {
option(&o) option(&o)
} }
@ -239,15 +283,11 @@ func Created(c echo.Context, data any) error {
// Fail 响应一个错误 // Fail 响应一个错误
func Fail(c echo.Context, err error, opts ...Option) error { func Fail(c echo.Context, err error, opts ...Option) error {
o := response{code: http.StatusInternalServerError} o := response{status: http.StatusInternalServerError}
for _, option := range opts { for _, option := range opts {
option(&o) option(&o)
} }
o.err = err o.err = err
var he *echo.HTTPError
if errors.As(err, &he) {
o.code = he.Code
}
return respond(c, &o) return respond(c, &o)
} }

@ -1,45 +0,0 @@
package ticket
import "github.com/golang-jwt/jwt/v5"
type Claims struct {
ID string `json:"jti,omitempty"` // ticket编号
UID uint `json:"uid"` // 用户编号
Role string `json:"role"` // 账号类型
Issuer string `json:"iss,omitempty"` // 签发人
Subject string `json:"sub,omitempty"` // 主题
Audience jwt.ClaimStrings `json:"aud,omitempty"` // 受众
ExpiresAt *jwt.NumericDate `json:"exp,omitempty"` // 过期时间
NotBefore *jwt.NumericDate `json:"nbf,omitempty"` // 生效时间
IssuedAt *jwt.NumericDate `json:"iat,omitempty"` // 签发时间
}
// GetExpirationTime implements the Claims interface.
func (c *Claims) GetExpirationTime() (*jwt.NumericDate, error) {
return c.ExpiresAt, nil
}
// GetNotBefore implements the Claims interface.
func (c *Claims) GetNotBefore() (*jwt.NumericDate, error) {
return c.NotBefore, nil
}
// GetIssuedAt implements the Claims interface.
func (c *Claims) GetIssuedAt() (*jwt.NumericDate, error) {
return c.IssuedAt, nil
}
// GetAudience implements the Claims interface.
func (c *Claims) GetAudience() (jwt.ClaimStrings, error) {
return c.Audience, nil
}
// GetIssuer implements the Claims interface.
func (c *Claims) GetIssuer() (string, error) {
return c.Issuer, nil
}
// GetSubject implements the Claims interface.
func (c *Claims) GetSubject() (string, error) {
return c.Subject, nil
}

@ -0,0 +1,95 @@
package ticket
import (
"crypto/rsa"
"github.com/golang-jwt/jwt/v5"
"github.com/labstack/echo/v4"
"slices"
"sorbet/pkg/env"
)
type Config struct {
Skipper func(c echo.Context) bool
Anonymously bool // 是否允许匿名访问
Audiences []string
PublicKey *rsa.PublicKey
TicketKey string
TicketClaimsKey string
TicketFinder Finder
ClaimsLooker func(c echo.Context, claims *Claims) error
ErrorHandler func(c echo.Context, err error) error
SuccessHandler func(c echo.Context)
}
func Middleware(anonymously bool, roles ...string) echo.MiddlewareFunc {
return WithConfig(Config{
Anonymously: anonymously,
TicketFinder: DefaultFinder,
ClaimsLooker: func(c echo.Context, claims *Claims) error {
if len(roles) > 0 && slices.Contains(roles, claims.Role) {
return nil
}
return ErrUnauthorized
},
})
}
func WithConfig(config Config) echo.MiddlewareFunc {
return config.ToMiddleware()
}
func (t Config) ToMiddleware() echo.MiddlewareFunc {
if t.Skipper == nil {
t.Skipper = func(c echo.Context) bool { return false }
}
if len(t.Audiences) == 0 {
t.Audiences = append(t.Audiences, "*")
}
if t.TicketFinder == nil {
t.TicketFinder = DefaultFinder
}
if t.ErrorHandler == nil {
t.ErrorHandler = func(c echo.Context, err error) error { return err }
}
if t.PublicKey == nil {
var err error
source := []byte(env.String("TICKET_PUBLIC_KEY"))
t.PublicKey, err = jwt.ParseRSAPublicKeyFromPEM(source)
if err != nil {
panic(err)
}
}
if t.TicketKey == "" {
t.TicketKey = "ticket"
}
if t.TicketClaimsKey == "" {
t.TicketClaimsKey = "ticket_claims"
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if t.Skipper(c) {
return next(c)
}
ticketString := t.TicketFinder(c.Request())
if ticketString == "" {
if t.Anonymously {
return next(c)
}
return t.ErrorHandler(c, ErrNoTicketFound)
}
claims, err := Verify(ticketString, t.PublicKey, t.Audiences...)
if err != nil {
return t.ErrorHandler(c, err)
}
if t.ClaimsLooker != nil {
err = t.ClaimsLooker(c, claims)
if err != nil {
return t.ErrorHandler(c, err)
}
}
c.Set(t.TicketKey, ticketString)
c.Set(t.TicketClaimsKey, claims)
return next(c)
}
}
}

@ -6,7 +6,10 @@ import (
"errors" "errors"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/rs/xid" "github.com/rs/xid"
"net/http"
"slices" "slices"
"sorbet/pkg/env"
"strings"
"time" "time"
) )
@ -18,22 +21,46 @@ var (
ErrBadAudience = errors.New("bad audience") ErrBadAudience = errors.New("bad audience")
) )
// Create 创建令牌 type Claims struct {
// 如果参数 claims 中未给出过期时间,将默认1小时过期 ID string `json:"jti,omitempty"` // ticket编号
// 如果未给出令牌编号,则自动生成。 UID uint `json:"uid"` // 用户编号
func Create(claims *Claims, key *rsa.PrivateKey) (string, error) { Role string `json:"role"` // 账号类型
if claims.ID == "" { Issuer string `json:"iss,omitempty"` // 签发人
claims.ID = xid.New().String() Subject string `json:"sub,omitempty"` // 主题
} Audience jwt.ClaimStrings `json:"aud,omitempty"` // 受众
if claims.ExpiresAt == nil { ExpiresAt *jwt.NumericDate `json:"exp,omitempty"` // 过期时间
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(time.Hour)) NotBefore *jwt.NumericDate `json:"nbf,omitempty"` // 生效时间
} IssuedAt *jwt.NumericDate `json:"iat,omitempty"` // 签发时间
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) }
signedString, err := token.SignedString(key)
if err != nil { // GetExpirationTime implements the Claims interface.
return "", err func (c *Claims) GetExpirationTime() (*jwt.NumericDate, error) {
} return c.ExpiresAt, nil
return base64.StdEncoding.EncodeToString([]byte(signedString)), nil }
// GetNotBefore implements the Claims interface.
func (c *Claims) GetNotBefore() (*jwt.NumericDate, error) {
return c.NotBefore, nil
}
// GetIssuedAt implements the Claims interface.
func (c *Claims) GetIssuedAt() (*jwt.NumericDate, error) {
return c.IssuedAt, nil
}
// GetAudience implements the Claims interface.
func (c *Claims) GetAudience() (jwt.ClaimStrings, error) {
return c.Audience, nil
}
// GetIssuer implements the Claims interface.
func (c *Claims) GetIssuer() (string, error) {
return c.Issuer, nil
}
// GetSubject implements the Claims interface.
func (c *Claims) GetSubject() (string, error) {
return c.Subject, nil
} }
// Verify 验证令牌 // Verify 验证令牌
@ -76,3 +103,78 @@ NEXT:
} }
return claims, nil return claims, nil
} }
// Create 创建令牌
//
// 如果参数
// claims
// 中未给出过期时间,将默认1小时过期
// 如果未给出令牌编号,则自动生成。
func Create(claims *Claims) (string, error) {
if claims.ID == "" {
claims.ID = xid.New().String()
}
if claims.Issuer == "" {
claims.Issuer = env.String("TICKET_ISSUER")
}
if claims.Subject == "" {
claims.Issuer = env.String("TICKET_SUBJECT")
}
if claims.Audience == nil {
claims.Audience = env.List("TICKET_AUDIENCE")
}
if claims.ExpiresAt == nil {
ttl := env.Duration("TICKET_TTL", time.Hour)
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(ttl))
}
source := []byte(env.String("TICKET_PRIVATE_KEY"))
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(source)
if err != nil {
return "", err
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
signedString, err := token.SignedString(privateKey)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString([]byte(signedString)), nil
}
type Finder func(r *http.Request) string
func DefaultFinder(r *http.Request) string {
if s := FromHeader(r); s != "" {
return s
}
if s := FromCookie(r); s != "" {
return s
}
return FromQuery(r)
}
func FromCookie(r *http.Request) string {
cookie, err := r.Cookie("ticket")
if err != nil {
return ""
}
return cookie.Value
}
func FromHeader(r *http.Request) string {
bearer := r.Header.Get("Authorization")
if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" {
return bearer[7:]
}
return ""
}
func FromQuery(r *http.Request, keys ...string) string {
q := r.URL.Query()
for _, key := range keys {
s := q.Get(key)
if s != "" {
return s
}
}
return q.Get("ticket")
}

@ -1,45 +0,0 @@
package ticket
import (
"net/http"
"strings"
)
type Finder func(r *http.Request) string
func DefaultFinder(r *http.Request) string {
if s := FromHeader(r); s != "" {
return s
}
if s := FromCookie(r); s != "" {
return s
}
return FromQuery(r)
}
func FromCookie(r *http.Request) string {
cookie, err := r.Cookie("ticket")
if err != nil {
return ""
}
return cookie.Value
}
func FromHeader(r *http.Request) string {
bearer := r.Header.Get("Authorization")
if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" {
return bearer[7:]
}
return ""
}
func FromQuery(r *http.Request, keys ...string) string {
q := r.URL.Query()
for _, key := range keys {
s := q.Get(key)
if s != "" {
return s
}
}
return q.Get("ticket")
}
Loading…
Cancel
Save