From a8d03f5a8f1e2fac65805b1552e7f2df4e8142aa Mon Sep 17 00:00:00 2001 From: hupeh Date: Thu, 12 Oct 2023 19:08:59 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20rsp=20=E5=BA=93=E6=94=AF=E6=8C=81=20tic?= =?UTF-8?q?ket?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/init.go | 4 +- internal/middleware/ticket.go | 104 -------------------------- internal/util/ticket.go | 38 ---------- pkg/rsp/error.go | 63 +++++++++------- pkg/rsp/respond_utils.go | 130 +++++++++++++++++++++------------ pkg/ticket/claims.go | 45 ------------ pkg/ticket/middleware.go | 95 ++++++++++++++++++++++++ pkg/ticket/ticket.go | 134 ++++++++++++++++++++++++++++++---- pkg/ticket/util.go | 45 ------------ 9 files changed, 337 insertions(+), 321 deletions(-) delete mode 100644 internal/middleware/ticket.go delete mode 100644 internal/util/ticket.go delete mode 100644 pkg/ticket/claims.go create mode 100644 pkg/ticket/middleware.go delete mode 100644 pkg/ticket/util.go diff --git a/internal/init.go b/internal/init.go index 2363a91..5dbefe4 100644 --- a/internal/init.go +++ b/internal/init.go @@ -67,7 +67,7 @@ func Start() error { if err != nil { return err } - token, err := util.CreateTicket(&ticket.Claims{ + token, err := ticket.Create(&ticket.Claims{ UID: company.ID, Role: "system", Issuer: "chshs", @@ -80,7 +80,7 @@ func Start() error { "token": token, }) }) - e.Group("", middleware.Ticket(false, "system")).GET("/u", func(c echo.Context) error { + e.Group("", ticket.Middleware(false, "system")).GET("/u", func(c echo.Context) error { return rsp.Ok(c, echo.Map{ "ticket": c.Get("ticket"), "claims": c.Get("ticket_claims").(*ticket.Claims), diff --git a/internal/middleware/ticket.go b/internal/middleware/ticket.go deleted file mode 100644 index 9a2fc60..0000000 --- a/internal/middleware/ticket.go +++ /dev/null @@ -1,104 +0,0 @@ -package middleware - -import ( - "crypto/rsa" - "github.com/golang-jwt/jwt/v5" - "github.com/labstack/echo/v4" - "slices" - "sorbet/pkg/env" - "sorbet/pkg/ticket" -) - -var ticketPublicKey *rsa.PublicKey - -type TicketConfig struct { - Skipper Skipper - Anonymously bool - Audiences []string - PublicKey *rsa.PublicKey - TicketFinder ticket.Finder - ClaimsLooker func(c echo.Context, claims *ticket.Claims) error - ErrorHandler func(c echo.Context, err error) error - SuccessHandler func(c echo.Context) -} - -func Ticket(anonymously bool, roles ...string) echo.MiddlewareFunc { - return TicketWithConfig(TicketConfig{ - Anonymously: anonymously, - TicketFinder: ticket.DefaultFinder, - ClaimsLooker: func(c echo.Context, claims *ticket.Claims) error { - if len(roles) > 0 && slices.Contains(roles, claims.Role) { - return nil - } - return ticket.ErrUnauthorized - }, - }) -} - -func TicketWithConfig(config TicketConfig) echo.MiddlewareFunc { - return config.ToMiddleware() -} - -func (t *TicketConfig) ToMiddleware() echo.MiddlewareFunc { - if t.Skipper == nil { - t.Skipper = DefaultSkipper - } - if len(t.Audiences) == 0 { - t.Audiences = append(t.Audiences, "*") - } - if t.TicketFinder == nil { - t.TicketFinder = ticket.DefaultFinder - } - if t.ErrorHandler == nil { - t.ErrorHandler = func(c echo.Context, err error) error { - return err - } - } - if t.ClaimsLooker == nil { - t.ClaimsLooker = func(c echo.Context, claims *ticket.Claims) error { - return nil - } - } - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - if t.Skipper(c) { - return next(c) - } - ticketString := t.TicketFinder(c.Request()) - if ticketString == "" { - if t.Anonymously { - return next(c) - } - return t.ErrorHandler(c, ticket.ErrNoTicketFound) - } - publicKey := t.PublicKey - if publicKey == nil { - key, err := getTicketPublicKey() - if err != nil { - return err - } - publicKey = key - } - claims, err := ticket.Verify(ticketString, publicKey, t.Audiences...) - if err != nil { - return t.ErrorHandler(c, err) - } - if err = t.ClaimsLooker(c, claims); err != nil { - return t.ErrorHandler(c, err) - } - c.Set("ticket", ticketString) - c.Set("ticket_claims", claims) - return next(c) - } - } -} - -func getTicketPublicKey() (*rsa.PublicKey, error) { - if ticketPublicKey != nil { - return ticketPublicKey, nil - } - var err error - source := []byte(env.String("TICKET_PUBLIC_KEY")) - ticketPublicKey, err = jwt.ParseRSAPublicKeyFromPEM(source) - return ticketPublicKey, err -} diff --git a/internal/util/ticket.go b/internal/util/ticket.go deleted file mode 100644 index ff178d3..0000000 --- a/internal/util/ticket.go +++ /dev/null @@ -1,38 +0,0 @@ -package util - -import ( - "github.com/golang-jwt/jwt/v5" - "github.com/rs/xid" - "sorbet/pkg/env" - "sorbet/pkg/ticket" - "time" -) - -func CreateTicket(claims *ticket.Claims) (string, error) { - if claims.ID == "" { - claims.ID = xid.New().String() - } - if claims.Issuer == "" { - claims.Issuer = env.String("TICKET_ISSUER") - } - if claims.Subject == "" { - claims.Issuer = env.String("TICKET_SUBJECT") - } - if claims.Audience == nil { - claims.Audience = env.List("TICKET_AUDIENCE") - } - if claims.ExpiresAt == nil { - ttl := env.Duration("TICKET_TTL", time.Hour) - claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(ttl)) - } - source := []byte(env.String("TICKET_PRIVATE_KEY")) - privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(source) - if err != nil { - return "", err - } - signedString, err := ticket.Create(claims, privateKey) - if err != nil { - return "", err - } - return signedString, nil -} diff --git a/pkg/rsp/error.go b/pkg/rsp/error.go index 4e6ebf7..dc4fa59 100644 --- a/pkg/rsp/error.go +++ b/pkg/rsp/error.go @@ -2,73 +2,75 @@ package rsp import ( "fmt" - "github.com/labstack/echo/v4" + "net/http" "strings" ) var ( // ErrOK 表示没有任何错误。 // 对应 HTTP 响应状态码为 500。 - ErrOK = NewError(0, "") + ErrOK = NewError(http.StatusOK, 0, "OK") // ErrInternal 客户端请求有效,但服务器处理时发生了意外。 // 对应 HTTP 响应状态码为 500。 - ErrInternal = NewError(-100, "系统内部错误") + ErrInternal = NewError(http.StatusInternalServerError, -100, "系统内部错误") // ErrServiceUnavailable 服务器无法处理请求,一般用于网站维护状态。 // 对应 HTTP 响应状态码为 503。 - ErrServiceUnavailable = NewError(-101, "Service Unavailable") + ErrServiceUnavailable = NewError(http.StatusServiceUnavailable, -101, "服务不可用") // ErrUnauthorized 用户未提供身份验证凭据,或者没有通过身份验证。 // 响应的 HTTP 状态码为 401。 - ErrUnauthorized = NewError(-102, "unauthorized") + ErrUnauthorized = NewError(http.StatusUnauthorized, -102, "身份验证失败") // ErrForbidden 用户通过了身份验证,但是不具有访问资源所需的权限。 // 响应的 HTTP 状态码为 403。 - ErrForbidden = NewError(-103, "Forbidden") + ErrForbidden = NewError(http.StatusForbidden, -103, "不具有访问资源所需的权限") // ErrGone 所请求的资源已从这个地址转移,不再可用。 // 响应的 HTTP 状态码为 410。 - ErrGone = NewError(-104, "Gone") + ErrGone = NewError(http.StatusGone, -104, "所请求的资源不存在") // ErrUnsupportedMediaType 客户端要求的返回格式不支持。 // 比如,API 只能返回 JSON 格式,但是客户端要求返回 XML 格式。 // 响应的 HTTP 状态码为 415。 - ErrUnsupportedMediaType = NewError(-105, "Unsupported Media Type") + ErrUnsupportedMediaType = NewError(http.StatusUnsupportedMediaType, -105, "请求的数据格式错误") // ErrUnprocessableEntity 无法处理客户端上传的附件,导致请求失败。 // 响应的 HTTP 状态码为 422。 - ErrUnprocessableEntity = NewError(-106, "Unprocessable Entity") + ErrUnprocessableEntity = NewError(http.StatusUnprocessableEntity, -106, "上传了不被支持的附件") // ErrTooManyRequests 客户端的请求次数超过限额。 - // 响应的 HTTP 状态码为 422。 - ErrTooManyRequests = NewError(-107, "Too Many Requests") + // 响应的 HTTP 状态码为 429。 + ErrTooManyRequests = NewError(http.StatusTooManyRequests, -107, "请求次数超过限额") // ErrSeeOther 表示需要参考另一个 URL 才能完成接收的请求操作, // 当请求方式使用 POST、PUT 和 DELETE 时,对应的 HTTP 状态码为 303, // 其它的请求方式在大多数情况下应该使用 400 状态码。 - ErrSeeOther = NewError(-108, "see other") + ErrSeeOther = NewError(http.StatusSeeOther, -108, "需要更进一步才能完成操作") // ErrBadRequest 服务器不理解客户端的请求。 - // 对应 HTTP 状态码为 404。 - ErrBadRequest = NewError(-109, "bad request") + // 对应 HTTP 状态码为 400。 + ErrBadRequest = NewError(http.StatusBadRequest, -109, "请求错误") // ErrBadParams 客户端提交的参数不符合要求 // 对应 HTTP 状态码为 400。 - ErrBadParams = NewError(-110, "参数错误") + ErrBadParams = NewError(http.StatusBadRequest, -110, "参数错误") // ErrRecordNotFound 访问的数据不存在 // 对应 HTTP 状态码为 404。 - ErrRecordNotFound = NewError(-111, "record not found") + ErrRecordNotFound = NewError(http.StatusNotFound, -111, "访问的数据不存在") ) type Error struct { - code int - text string + internal error + status int // HTTP 状态码 + code int // 请求错误码 + text string // 响应提示消息 } -func NewError(code int, text string) *Error { - return &Error{code, text} +func NewError(status, code int, text string) *Error { + return &Error{nil, status, code, text} } func (e *Error) Code() int { @@ -79,10 +81,24 @@ func (e *Error) Text() string { return e.text } +func (e *Error) WithInternal(err error) *Error { + c := *e + c.internal = err + return &c +} + +func (e *Error) WithStatus(status int) *Error { + c := *e + c.status = status + return &c +} + func (e *Error) WithText(text ...string) *Error { for _, s := range text { if s != "" { - return NewError(e.code, s) + c := *e + c.text = s + return &c } } return e @@ -97,11 +113,6 @@ func (e *Error) AsProblem(label string) *Problem { } } -func (e *Error) AsHttpError(code int) *echo.HTTPError { - he := echo.NewHTTPError(code, e.text) - return he.WithInternal(e) -} - func (e *Error) String() string { return strings.TrimSpace(fmt.Sprintf("%d %s", e.code, e.text)) } diff --git a/pkg/rsp/respond_utils.go b/pkg/rsp/respond_utils.go index 16bc188..7bacf96 100644 --- a/pkg/rsp/respond_utils.go +++ b/pkg/rsp/respond_utils.go @@ -8,6 +8,7 @@ import ( "github.com/labstack/echo/v4" "net/http" "runtime" + "sorbet/pkg/ticket" "sorbet/pkg/v" ) @@ -52,7 +53,7 @@ func (r *RespondData[T]) RespondValue() any { } type response struct { - code int + status int headers map[string]string cookies []*http.Cookie err error @@ -62,9 +63,9 @@ type response struct { type Option func(o *response) -func StatusCode(code int) Option { +func StatusCode(status int) Option { return func(o *response) { - o.code = code + o.status = status } } @@ -103,57 +104,99 @@ func Data(data any) Option { } } -func respond(c echo.Context, o *response) error { - defer func() { - if o.err != nil { - c.Logger().Error(o.err) - } - }() - - // 返回的数据 - m := map[string]any{ +func (r *response) result(c echo.Context) (m map[string]any, status int) { + status = r.status + m = map[string]any{ "code": nil, "success": false, - "message": o.message, + "message": r.message, } - var success bool - if err, ok := o.err.(*v.Errors); ok { + isDebug := c.Echo().Debug + var err error + if ee, ok := r.err.(*v.Errors); ok { pb := &Problem{} - for _, e := range err.All() { + for _, e := range ee.All() { pb.AddSubproblem(ErrBadParams.WithText(e.Error()).AsProblem(e.Field())) } m["code"] = ErrBadParams.code m["message"] = ErrBadParams.text m["problems"] = pb.Problems - } else if ex, yes := o.err.(*Error); yes { + if status < 400 { + status = ErrBadParams.status + } + } else if ex, yes := r.err.(*Error); yes { + err = ex.internal m["code"] = ex.code + m["success"] = errors.Is(ex, ErrOK) m["message"] = ex.text - success = errors.Is(ex, ErrOK) - } else if pb, okay := o.err.(*Problem); okay { + if status < 400 { + status = ex.status + } + } else if pb, okay := r.err.(*Problem); okay { m["code"] = pb.Code m["message"] = pb.Message m["problems"] = pb.Problems - } else if o.err != nil { - m["code"] = ErrInternal.code - m["message"] = o.err.Error() - } else { - success = true + if status < 400 { + status = ErrInternal.status + } + } else if r.err == nil { m["code"] = 0 - if o.data != nil { - if val, ok := o.data.(RespondValuer); ok { + m["success"] = true + if r.data != nil { + if val, ok := r.data.(RespondValuer); ok { m["data"] = val.RespondValue() } else { - m["data"] = o.data + m["data"] = r.data + } + } + } else { + err = r.err + switch { + case errors.Is(r.err, ticket.ErrNoTicketFound), + errors.Is(r.err, ticket.ErrInvalidTicket), + errors.Is(r.err, ticket.ErrBadAudience), + errors.Is(r.err, ticket.ErrTicketExpired), + errors.Is(r.err, ticket.ErrUnauthorized): + m["code"] = ErrUnauthorized.Code() + if isDebug { + m["message"] = r.err.Error() + } else { + m["message"] = ErrUnauthorized.Code() + } + if status < 400 { + status = ErrUnauthorized.status + } + default: + var he *echo.HTTPError + if errors.As(err, &he) { + status = he.Code + err = he.Internal + } + m["code"] = ErrInternal.Code() + m["message"] = r.err.Error() + if status < 400 { + status = ErrInternal.status } } } - m["success"] = success - if !success && c.Echo().Debug { - m["error"] = relevantCaller() + if c.Echo().Debug && m["success"] != true { + if err != nil { + m["error"] = err.Error() + } + m["stack"] = relevantCaller() } if m["message"] == "" { - m["message"] = http.StatusText(o.code) + m["message"] = http.StatusText(status) } + return +} + +func respond(c echo.Context, o *response) error { + defer func() { + if o.err != nil { + c.Logger().Error(o.err) + } + }() // 如果已经输出过,就忽略 if c.Response().Committed { return nil @@ -171,10 +214,11 @@ func respond(c echo.Context, o *response) error { c.SetCookie(cookie) } } + m, status := o.result(c) // HEAD 请求没有结果 r := c.Request() if r.Method == http.MethodHead { - return c.NoContent(o.code) + return c.NoContent(status) } // 根据报头响应不同的格式 accept := r.Header.Get(echo.HeaderAccept) @@ -183,28 +227,28 @@ func respond(c echo.Context, o *response) error { if html, err := HtmlMarshaller(m); err != nil { return err } else { - return c.HTML(o.code, html) + return c.HTML(status, html) } case "json": - return c.JSON(o.code, m) + return c.JSON(status, m) case "jsonp": qs := c.Request().URL.Query() for _, name := range JsonpCallbacks { if cb := qs.Get(name); cb != "" { - return c.JSONP(o.code, cb, m) + return c.JSONP(status, cb, m) } } - return c.JSONP(o.code, DefaultJsonpCallback, m) + return c.JSONP(status, DefaultJsonpCallback, m) case "xml": - return c.XML(o.code, m) + return c.XML(status, m) case "text", "text/*": if text, err := TextMarshaller(m); err != nil { return err } else { - return c.String(o.code, text) + return c.String(status, text) } } - return c.JSON(o.code, m) + return c.JSON(status, m) } func relevantCaller() []string { @@ -222,7 +266,7 @@ func relevantCaller() []string { } func Respond(c echo.Context, opts ...Option) error { - o := response{code: http.StatusOK} + o := response{status: http.StatusOK} for _, option := range opts { option(&o) } @@ -239,15 +283,11 @@ func Created(c echo.Context, data any) error { // Fail 响应一个错误 func Fail(c echo.Context, err error, opts ...Option) error { - o := response{code: http.StatusInternalServerError} + o := response{status: http.StatusInternalServerError} for _, option := range opts { option(&o) } o.err = err - var he *echo.HTTPError - if errors.As(err, &he) { - o.code = he.Code - } return respond(c, &o) } diff --git a/pkg/ticket/claims.go b/pkg/ticket/claims.go deleted file mode 100644 index 6bd5b8f..0000000 --- a/pkg/ticket/claims.go +++ /dev/null @@ -1,45 +0,0 @@ -package ticket - -import "github.com/golang-jwt/jwt/v5" - -type Claims struct { - ID string `json:"jti,omitempty"` // ticket编号 - UID uint `json:"uid"` // 用户编号 - Role string `json:"role"` // 账号类型 - Issuer string `json:"iss,omitempty"` // 签发人 - Subject string `json:"sub,omitempty"` // 主题 - Audience jwt.ClaimStrings `json:"aud,omitempty"` // 受众 - ExpiresAt *jwt.NumericDate `json:"exp,omitempty"` // 过期时间 - NotBefore *jwt.NumericDate `json:"nbf,omitempty"` // 生效时间 - IssuedAt *jwt.NumericDate `json:"iat,omitempty"` // 签发时间 -} - -// GetExpirationTime implements the Claims interface. -func (c *Claims) GetExpirationTime() (*jwt.NumericDate, error) { - return c.ExpiresAt, nil -} - -// GetNotBefore implements the Claims interface. -func (c *Claims) GetNotBefore() (*jwt.NumericDate, error) { - return c.NotBefore, nil -} - -// GetIssuedAt implements the Claims interface. -func (c *Claims) GetIssuedAt() (*jwt.NumericDate, error) { - return c.IssuedAt, nil -} - -// GetAudience implements the Claims interface. -func (c *Claims) GetAudience() (jwt.ClaimStrings, error) { - return c.Audience, nil -} - -// GetIssuer implements the Claims interface. -func (c *Claims) GetIssuer() (string, error) { - return c.Issuer, nil -} - -// GetSubject implements the Claims interface. -func (c *Claims) GetSubject() (string, error) { - return c.Subject, nil -} diff --git a/pkg/ticket/middleware.go b/pkg/ticket/middleware.go new file mode 100644 index 0000000..080406e --- /dev/null +++ b/pkg/ticket/middleware.go @@ -0,0 +1,95 @@ +package ticket + +import ( + "crypto/rsa" + "github.com/golang-jwt/jwt/v5" + "github.com/labstack/echo/v4" + "slices" + "sorbet/pkg/env" +) + +type Config struct { + Skipper func(c echo.Context) bool + Anonymously bool // 是否允许匿名访问 + Audiences []string + PublicKey *rsa.PublicKey + TicketKey string + TicketClaimsKey string + TicketFinder Finder + ClaimsLooker func(c echo.Context, claims *Claims) error + ErrorHandler func(c echo.Context, err error) error + SuccessHandler func(c echo.Context) +} + +func Middleware(anonymously bool, roles ...string) echo.MiddlewareFunc { + return WithConfig(Config{ + Anonymously: anonymously, + TicketFinder: DefaultFinder, + ClaimsLooker: func(c echo.Context, claims *Claims) error { + if len(roles) > 0 && slices.Contains(roles, claims.Role) { + return nil + } + return ErrUnauthorized + }, + }) +} + +func WithConfig(config Config) echo.MiddlewareFunc { + return config.ToMiddleware() +} + +func (t Config) ToMiddleware() echo.MiddlewareFunc { + if t.Skipper == nil { + t.Skipper = func(c echo.Context) bool { return false } + } + if len(t.Audiences) == 0 { + t.Audiences = append(t.Audiences, "*") + } + if t.TicketFinder == nil { + t.TicketFinder = DefaultFinder + } + if t.ErrorHandler == nil { + t.ErrorHandler = func(c echo.Context, err error) error { return err } + } + if t.PublicKey == nil { + var err error + source := []byte(env.String("TICKET_PUBLIC_KEY")) + t.PublicKey, err = jwt.ParseRSAPublicKeyFromPEM(source) + if err != nil { + panic(err) + } + } + if t.TicketKey == "" { + t.TicketKey = "ticket" + } + if t.TicketClaimsKey == "" { + t.TicketClaimsKey = "ticket_claims" + } + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if t.Skipper(c) { + return next(c) + } + ticketString := t.TicketFinder(c.Request()) + if ticketString == "" { + if t.Anonymously { + return next(c) + } + return t.ErrorHandler(c, ErrNoTicketFound) + } + claims, err := Verify(ticketString, t.PublicKey, t.Audiences...) + if err != nil { + return t.ErrorHandler(c, err) + } + if t.ClaimsLooker != nil { + err = t.ClaimsLooker(c, claims) + if err != nil { + return t.ErrorHandler(c, err) + } + } + c.Set(t.TicketKey, ticketString) + c.Set(t.TicketClaimsKey, claims) + return next(c) + } + } +} diff --git a/pkg/ticket/ticket.go b/pkg/ticket/ticket.go index 4b64f79..db62062 100644 --- a/pkg/ticket/ticket.go +++ b/pkg/ticket/ticket.go @@ -6,7 +6,10 @@ import ( "errors" "github.com/golang-jwt/jwt/v5" "github.com/rs/xid" + "net/http" "slices" + "sorbet/pkg/env" + "strings" "time" ) @@ -18,22 +21,46 @@ var ( ErrBadAudience = errors.New("bad audience") ) -// Create 创建令牌 -// 如果参数 claims 中未给出过期时间,将默认1小时过期 -// 如果未给出令牌编号,则自动生成。 -func Create(claims *Claims, key *rsa.PrivateKey) (string, error) { - if claims.ID == "" { - claims.ID = xid.New().String() - } - if claims.ExpiresAt == nil { - claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(time.Hour)) - } - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) - signedString, err := token.SignedString(key) - if err != nil { - return "", err - } - return base64.StdEncoding.EncodeToString([]byte(signedString)), nil +type Claims struct { + ID string `json:"jti,omitempty"` // ticket编号 + UID uint `json:"uid"` // 用户编号 + Role string `json:"role"` // 账号类型 + Issuer string `json:"iss,omitempty"` // 签发人 + Subject string `json:"sub,omitempty"` // 主题 + Audience jwt.ClaimStrings `json:"aud,omitempty"` // 受众 + ExpiresAt *jwt.NumericDate `json:"exp,omitempty"` // 过期时间 + NotBefore *jwt.NumericDate `json:"nbf,omitempty"` // 生效时间 + IssuedAt *jwt.NumericDate `json:"iat,omitempty"` // 签发时间 +} + +// GetExpirationTime implements the Claims interface. +func (c *Claims) GetExpirationTime() (*jwt.NumericDate, error) { + return c.ExpiresAt, nil +} + +// GetNotBefore implements the Claims interface. +func (c *Claims) GetNotBefore() (*jwt.NumericDate, error) { + return c.NotBefore, nil +} + +// GetIssuedAt implements the Claims interface. +func (c *Claims) GetIssuedAt() (*jwt.NumericDate, error) { + return c.IssuedAt, nil +} + +// GetAudience implements the Claims interface. +func (c *Claims) GetAudience() (jwt.ClaimStrings, error) { + return c.Audience, nil +} + +// GetIssuer implements the Claims interface. +func (c *Claims) GetIssuer() (string, error) { + return c.Issuer, nil +} + +// GetSubject implements the Claims interface. +func (c *Claims) GetSubject() (string, error) { + return c.Subject, nil } // Verify 验证令牌 @@ -76,3 +103,78 @@ NEXT: } return claims, nil } + +// Create 创建令牌 +// +// 如果参数 +// claims +// 中未给出过期时间,将默认1小时过期 +// 如果未给出令牌编号,则自动生成。 +func Create(claims *Claims) (string, error) { + if claims.ID == "" { + claims.ID = xid.New().String() + } + if claims.Issuer == "" { + claims.Issuer = env.String("TICKET_ISSUER") + } + if claims.Subject == "" { + claims.Issuer = env.String("TICKET_SUBJECT") + } + if claims.Audience == nil { + claims.Audience = env.List("TICKET_AUDIENCE") + } + if claims.ExpiresAt == nil { + ttl := env.Duration("TICKET_TTL", time.Hour) + claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(ttl)) + } + source := []byte(env.String("TICKET_PRIVATE_KEY")) + privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(source) + if err != nil { + return "", err + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + signedString, err := token.SignedString(privateKey) + if err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString([]byte(signedString)), nil +} + +type Finder func(r *http.Request) string + +func DefaultFinder(r *http.Request) string { + if s := FromHeader(r); s != "" { + return s + } + if s := FromCookie(r); s != "" { + return s + } + return FromQuery(r) +} + +func FromCookie(r *http.Request) string { + cookie, err := r.Cookie("ticket") + if err != nil { + return "" + } + return cookie.Value +} + +func FromHeader(r *http.Request) string { + bearer := r.Header.Get("Authorization") + if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" { + return bearer[7:] + } + return "" +} + +func FromQuery(r *http.Request, keys ...string) string { + q := r.URL.Query() + for _, key := range keys { + s := q.Get(key) + if s != "" { + return s + } + } + return q.Get("ticket") +} diff --git a/pkg/ticket/util.go b/pkg/ticket/util.go deleted file mode 100644 index c6d3ff0..0000000 --- a/pkg/ticket/util.go +++ /dev/null @@ -1,45 +0,0 @@ -package ticket - -import ( - "net/http" - "strings" -) - -type Finder func(r *http.Request) string - -func DefaultFinder(r *http.Request) string { - if s := FromHeader(r); s != "" { - return s - } - if s := FromCookie(r); s != "" { - return s - } - return FromQuery(r) -} - -func FromCookie(r *http.Request) string { - cookie, err := r.Cookie("ticket") - if err != nil { - return "" - } - return cookie.Value -} - -func FromHeader(r *http.Request) string { - bearer := r.Header.Get("Authorization") - if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" { - return bearer[7:] - } - return "" -} - -func FromQuery(r *http.Request, keys ...string) string { - q := r.URL.Query() - for _, key := range keys { - s := q.Get(key) - if s != "" { - return s - } - } - return q.Get("ticket") -}