feat: rsp 库支持 ticket

main
熊二 1 year ago
parent 2aadccd5b8
commit a8d03f5a8f
  1. 4
      internal/init.go
  2. 104
      internal/middleware/ticket.go
  3. 38
      internal/util/ticket.go
  4. 63
      pkg/rsp/error.go
  5. 130
      pkg/rsp/respond_utils.go
  6. 45
      pkg/ticket/claims.go
  7. 95
      pkg/ticket/middleware.go
  8. 134
      pkg/ticket/ticket.go
  9. 45
      pkg/ticket/util.go

@ -67,7 +67,7 @@ func Start() error {
if err != nil {
return err
}
token, err := util.CreateTicket(&ticket.Claims{
token, err := ticket.Create(&ticket.Claims{
UID: company.ID,
Role: "system",
Issuer: "chshs",
@ -80,7 +80,7 @@ func Start() error {
"token": token,
})
})
e.Group("", middleware.Ticket(false, "system")).GET("/u", func(c echo.Context) error {
e.Group("", ticket.Middleware(false, "system")).GET("/u", func(c echo.Context) error {
return rsp.Ok(c, echo.Map{
"ticket": c.Get("ticket"),
"claims": c.Get("ticket_claims").(*ticket.Claims),

@ -1,104 +0,0 @@
package middleware
import (
"crypto/rsa"
"github.com/golang-jwt/jwt/v5"
"github.com/labstack/echo/v4"
"slices"
"sorbet/pkg/env"
"sorbet/pkg/ticket"
)
var ticketPublicKey *rsa.PublicKey
type TicketConfig struct {
Skipper Skipper
Anonymously bool
Audiences []string
PublicKey *rsa.PublicKey
TicketFinder ticket.Finder
ClaimsLooker func(c echo.Context, claims *ticket.Claims) error
ErrorHandler func(c echo.Context, err error) error
SuccessHandler func(c echo.Context)
}
func Ticket(anonymously bool, roles ...string) echo.MiddlewareFunc {
return TicketWithConfig(TicketConfig{
Anonymously: anonymously,
TicketFinder: ticket.DefaultFinder,
ClaimsLooker: func(c echo.Context, claims *ticket.Claims) error {
if len(roles) > 0 && slices.Contains(roles, claims.Role) {
return nil
}
return ticket.ErrUnauthorized
},
})
}
func TicketWithConfig(config TicketConfig) echo.MiddlewareFunc {
return config.ToMiddleware()
}
func (t *TicketConfig) ToMiddleware() echo.MiddlewareFunc {
if t.Skipper == nil {
t.Skipper = DefaultSkipper
}
if len(t.Audiences) == 0 {
t.Audiences = append(t.Audiences, "*")
}
if t.TicketFinder == nil {
t.TicketFinder = ticket.DefaultFinder
}
if t.ErrorHandler == nil {
t.ErrorHandler = func(c echo.Context, err error) error {
return err
}
}
if t.ClaimsLooker == nil {
t.ClaimsLooker = func(c echo.Context, claims *ticket.Claims) error {
return nil
}
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if t.Skipper(c) {
return next(c)
}
ticketString := t.TicketFinder(c.Request())
if ticketString == "" {
if t.Anonymously {
return next(c)
}
return t.ErrorHandler(c, ticket.ErrNoTicketFound)
}
publicKey := t.PublicKey
if publicKey == nil {
key, err := getTicketPublicKey()
if err != nil {
return err
}
publicKey = key
}
claims, err := ticket.Verify(ticketString, publicKey, t.Audiences...)
if err != nil {
return t.ErrorHandler(c, err)
}
if err = t.ClaimsLooker(c, claims); err != nil {
return t.ErrorHandler(c, err)
}
c.Set("ticket", ticketString)
c.Set("ticket_claims", claims)
return next(c)
}
}
}
func getTicketPublicKey() (*rsa.PublicKey, error) {
if ticketPublicKey != nil {
return ticketPublicKey, nil
}
var err error
source := []byte(env.String("TICKET_PUBLIC_KEY"))
ticketPublicKey, err = jwt.ParseRSAPublicKeyFromPEM(source)
return ticketPublicKey, err
}

@ -1,38 +0,0 @@
package util
import (
"github.com/golang-jwt/jwt/v5"
"github.com/rs/xid"
"sorbet/pkg/env"
"sorbet/pkg/ticket"
"time"
)
func CreateTicket(claims *ticket.Claims) (string, error) {
if claims.ID == "" {
claims.ID = xid.New().String()
}
if claims.Issuer == "" {
claims.Issuer = env.String("TICKET_ISSUER")
}
if claims.Subject == "" {
claims.Issuer = env.String("TICKET_SUBJECT")
}
if claims.Audience == nil {
claims.Audience = env.List("TICKET_AUDIENCE")
}
if claims.ExpiresAt == nil {
ttl := env.Duration("TICKET_TTL", time.Hour)
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(ttl))
}
source := []byte(env.String("TICKET_PRIVATE_KEY"))
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(source)
if err != nil {
return "", err
}
signedString, err := ticket.Create(claims, privateKey)
if err != nil {
return "", err
}
return signedString, nil
}

@ -2,73 +2,75 @@ package rsp
import (
"fmt"
"github.com/labstack/echo/v4"
"net/http"
"strings"
)
var (
// ErrOK 表示没有任何错误。
// 对应 HTTP 响应状态码为 500。
ErrOK = NewError(0, "")
ErrOK = NewError(http.StatusOK, 0, "OK")
// ErrInternal 客户端请求有效,但服务器处理时发生了意外。
// 对应 HTTP 响应状态码为 500。
ErrInternal = NewError(-100, "系统内部错误")
ErrInternal = NewError(http.StatusInternalServerError, -100, "系统内部错误")
// ErrServiceUnavailable 服务器无法处理请求,一般用于网站维护状态。
// 对应 HTTP 响应状态码为 503。
ErrServiceUnavailable = NewError(-101, "Service Unavailable")
ErrServiceUnavailable = NewError(http.StatusServiceUnavailable, -101, "服务不可用")
// ErrUnauthorized 用户未提供身份验证凭据,或者没有通过身份验证。
// 响应的 HTTP 状态码为 401。
ErrUnauthorized = NewError(-102, "unauthorized")
ErrUnauthorized = NewError(http.StatusUnauthorized, -102, "身份验证失败")
// ErrForbidden 用户通过了身份验证,但是不具有访问资源所需的权限。
// 响应的 HTTP 状态码为 403。
ErrForbidden = NewError(-103, "Forbidden")
ErrForbidden = NewError(http.StatusForbidden, -103, "不具有访问资源所需的权限")
// ErrGone 所请求的资源已从这个地址转移,不再可用。
// 响应的 HTTP 状态码为 410。
ErrGone = NewError(-104, "Gone")
ErrGone = NewError(http.StatusGone, -104, "所请求的资源不存在")
// ErrUnsupportedMediaType 客户端要求的返回格式不支持。
// 比如,API 只能返回 JSON 格式,但是客户端要求返回 XML 格式。
// 响应的 HTTP 状态码为 415。
ErrUnsupportedMediaType = NewError(-105, "Unsupported Media Type")
ErrUnsupportedMediaType = NewError(http.StatusUnsupportedMediaType, -105, "请求的数据格式错误")
// ErrUnprocessableEntity 无法处理客户端上传的附件,导致请求失败。
// 响应的 HTTP 状态码为 422。
ErrUnprocessableEntity = NewError(-106, "Unprocessable Entity")
ErrUnprocessableEntity = NewError(http.StatusUnprocessableEntity, -106, "上传了不被支持的附件")
// ErrTooManyRequests 客户端的请求次数超过限额。
// 响应的 HTTP 状态码为 422
ErrTooManyRequests = NewError(-107, "Too Many Requests")
// 响应的 HTTP 状态码为 429
ErrTooManyRequests = NewError(http.StatusTooManyRequests, -107, "请求次数超过限额")
// ErrSeeOther 表示需要参考另一个 URL 才能完成接收的请求操作,
// 当请求方式使用 POST、PUT 和 DELETE 时,对应的 HTTP 状态码为 303,
// 其它的请求方式在大多数情况下应该使用 400 状态码。
ErrSeeOther = NewError(-108, "see other")
ErrSeeOther = NewError(http.StatusSeeOther, -108, "需要更进一步才能完成操作")
// ErrBadRequest 服务器不理解客户端的请求。
// 对应 HTTP 状态码为 404
ErrBadRequest = NewError(-109, "bad request")
// 对应 HTTP 状态码为 400
ErrBadRequest = NewError(http.StatusBadRequest, -109, "请求错误")
// ErrBadParams 客户端提交的参数不符合要求
// 对应 HTTP 状态码为 400。
ErrBadParams = NewError(-110, "参数错误")
ErrBadParams = NewError(http.StatusBadRequest, -110, "参数错误")
// ErrRecordNotFound 访问的数据不存在
// 对应 HTTP 状态码为 404。
ErrRecordNotFound = NewError(-111, "record not found")
ErrRecordNotFound = NewError(http.StatusNotFound, -111, "访问的数据不存在")
)
type Error struct {
code int
text string
internal error
status int // HTTP 状态码
code int // 请求错误码
text string // 响应提示消息
}
func NewError(code int, text string) *Error {
return &Error{code, text}
func NewError(status, code int, text string) *Error {
return &Error{nil, status, code, text}
}
func (e *Error) Code() int {
@ -79,10 +81,24 @@ func (e *Error) Text() string {
return e.text
}
func (e *Error) WithInternal(err error) *Error {
c := *e
c.internal = err
return &c
}
func (e *Error) WithStatus(status int) *Error {
c := *e
c.status = status
return &c
}
func (e *Error) WithText(text ...string) *Error {
for _, s := range text {
if s != "" {
return NewError(e.code, s)
c := *e
c.text = s
return &c
}
}
return e
@ -97,11 +113,6 @@ func (e *Error) AsProblem(label string) *Problem {
}
}
func (e *Error) AsHttpError(code int) *echo.HTTPError {
he := echo.NewHTTPError(code, e.text)
return he.WithInternal(e)
}
func (e *Error) String() string {
return strings.TrimSpace(fmt.Sprintf("%d %s", e.code, e.text))
}

@ -8,6 +8,7 @@ import (
"github.com/labstack/echo/v4"
"net/http"
"runtime"
"sorbet/pkg/ticket"
"sorbet/pkg/v"
)
@ -52,7 +53,7 @@ func (r *RespondData[T]) RespondValue() any {
}
type response struct {
code int
status int
headers map[string]string
cookies []*http.Cookie
err error
@ -62,9 +63,9 @@ type response struct {
type Option func(o *response)
func StatusCode(code int) Option {
func StatusCode(status int) Option {
return func(o *response) {
o.code = code
o.status = status
}
}
@ -103,57 +104,99 @@ func Data(data any) Option {
}
}
func respond(c echo.Context, o *response) error {
defer func() {
if o.err != nil {
c.Logger().Error(o.err)
}
}()
// 返回的数据
m := map[string]any{
func (r *response) result(c echo.Context) (m map[string]any, status int) {
status = r.status
m = map[string]any{
"code": nil,
"success": false,
"message": o.message,
"message": r.message,
}
var success bool
if err, ok := o.err.(*v.Errors); ok {
isDebug := c.Echo().Debug
var err error
if ee, ok := r.err.(*v.Errors); ok {
pb := &Problem{}
for _, e := range err.All() {
for _, e := range ee.All() {
pb.AddSubproblem(ErrBadParams.WithText(e.Error()).AsProblem(e.Field()))
}
m["code"] = ErrBadParams.code
m["message"] = ErrBadParams.text
m["problems"] = pb.Problems
} else if ex, yes := o.err.(*Error); yes {
if status < 400 {
status = ErrBadParams.status
}
} else if ex, yes := r.err.(*Error); yes {
err = ex.internal
m["code"] = ex.code
m["success"] = errors.Is(ex, ErrOK)
m["message"] = ex.text
success = errors.Is(ex, ErrOK)
} else if pb, okay := o.err.(*Problem); okay {
if status < 400 {
status = ex.status
}
} else if pb, okay := r.err.(*Problem); okay {
m["code"] = pb.Code
m["message"] = pb.Message
m["problems"] = pb.Problems
} else if o.err != nil {
m["code"] = ErrInternal.code
m["message"] = o.err.Error()
} else {
success = true
if status < 400 {
status = ErrInternal.status
}
} else if r.err == nil {
m["code"] = 0
if o.data != nil {
if val, ok := o.data.(RespondValuer); ok {
m["success"] = true
if r.data != nil {
if val, ok := r.data.(RespondValuer); ok {
m["data"] = val.RespondValue()
} else {
m["data"] = o.data
m["data"] = r.data
}
}
} else {
err = r.err
switch {
case errors.Is(r.err, ticket.ErrNoTicketFound),
errors.Is(r.err, ticket.ErrInvalidTicket),
errors.Is(r.err, ticket.ErrBadAudience),
errors.Is(r.err, ticket.ErrTicketExpired),
errors.Is(r.err, ticket.ErrUnauthorized):
m["code"] = ErrUnauthorized.Code()
if isDebug {
m["message"] = r.err.Error()
} else {
m["message"] = ErrUnauthorized.Code()
}
if status < 400 {
status = ErrUnauthorized.status
}
default:
var he *echo.HTTPError
if errors.As(err, &he) {
status = he.Code
err = he.Internal
}
m["success"] = success
if !success && c.Echo().Debug {
m["error"] = relevantCaller()
m["code"] = ErrInternal.Code()
m["message"] = r.err.Error()
if status < 400 {
status = ErrInternal.status
}
}
}
if c.Echo().Debug && m["success"] != true {
if err != nil {
m["error"] = err.Error()
}
m["stack"] = relevantCaller()
}
if m["message"] == "" {
m["message"] = http.StatusText(o.code)
m["message"] = http.StatusText(status)
}
return
}
func respond(c echo.Context, o *response) error {
defer func() {
if o.err != nil {
c.Logger().Error(o.err)
}
}()
// 如果已经输出过,就忽略
if c.Response().Committed {
return nil
@ -171,10 +214,11 @@ func respond(c echo.Context, o *response) error {
c.SetCookie(cookie)
}
}
m, status := o.result(c)
// HEAD 请求没有结果
r := c.Request()
if r.Method == http.MethodHead {
return c.NoContent(o.code)
return c.NoContent(status)
}
// 根据报头响应不同的格式
accept := r.Header.Get(echo.HeaderAccept)
@ -183,28 +227,28 @@ func respond(c echo.Context, o *response) error {
if html, err := HtmlMarshaller(m); err != nil {
return err
} else {
return c.HTML(o.code, html)
return c.HTML(status, html)
}
case "json":
return c.JSON(o.code, m)
return c.JSON(status, m)
case "jsonp":
qs := c.Request().URL.Query()
for _, name := range JsonpCallbacks {
if cb := qs.Get(name); cb != "" {
return c.JSONP(o.code, cb, m)
return c.JSONP(status, cb, m)
}
}
return c.JSONP(o.code, DefaultJsonpCallback, m)
return c.JSONP(status, DefaultJsonpCallback, m)
case "xml":
return c.XML(o.code, m)
return c.XML(status, m)
case "text", "text/*":
if text, err := TextMarshaller(m); err != nil {
return err
} else {
return c.String(o.code, text)
return c.String(status, text)
}
}
return c.JSON(o.code, m)
return c.JSON(status, m)
}
func relevantCaller() []string {
@ -222,7 +266,7 @@ func relevantCaller() []string {
}
func Respond(c echo.Context, opts ...Option) error {
o := response{code: http.StatusOK}
o := response{status: http.StatusOK}
for _, option := range opts {
option(&o)
}
@ -239,15 +283,11 @@ func Created(c echo.Context, data any) error {
// Fail 响应一个错误
func Fail(c echo.Context, err error, opts ...Option) error {
o := response{code: http.StatusInternalServerError}
o := response{status: http.StatusInternalServerError}
for _, option := range opts {
option(&o)
}
o.err = err
var he *echo.HTTPError
if errors.As(err, &he) {
o.code = he.Code
}
return respond(c, &o)
}

@ -1,45 +0,0 @@
package ticket
import "github.com/golang-jwt/jwt/v5"
type Claims struct {
ID string `json:"jti,omitempty"` // ticket编号
UID uint `json:"uid"` // 用户编号
Role string `json:"role"` // 账号类型
Issuer string `json:"iss,omitempty"` // 签发人
Subject string `json:"sub,omitempty"` // 主题
Audience jwt.ClaimStrings `json:"aud,omitempty"` // 受众
ExpiresAt *jwt.NumericDate `json:"exp,omitempty"` // 过期时间
NotBefore *jwt.NumericDate `json:"nbf,omitempty"` // 生效时间
IssuedAt *jwt.NumericDate `json:"iat,omitempty"` // 签发时间
}
// GetExpirationTime implements the Claims interface.
func (c *Claims) GetExpirationTime() (*jwt.NumericDate, error) {
return c.ExpiresAt, nil
}
// GetNotBefore implements the Claims interface.
func (c *Claims) GetNotBefore() (*jwt.NumericDate, error) {
return c.NotBefore, nil
}
// GetIssuedAt implements the Claims interface.
func (c *Claims) GetIssuedAt() (*jwt.NumericDate, error) {
return c.IssuedAt, nil
}
// GetAudience implements the Claims interface.
func (c *Claims) GetAudience() (jwt.ClaimStrings, error) {
return c.Audience, nil
}
// GetIssuer implements the Claims interface.
func (c *Claims) GetIssuer() (string, error) {
return c.Issuer, nil
}
// GetSubject implements the Claims interface.
func (c *Claims) GetSubject() (string, error) {
return c.Subject, nil
}

@ -0,0 +1,95 @@
package ticket
import (
"crypto/rsa"
"github.com/golang-jwt/jwt/v5"
"github.com/labstack/echo/v4"
"slices"
"sorbet/pkg/env"
)
type Config struct {
Skipper func(c echo.Context) bool
Anonymously bool // 是否允许匿名访问
Audiences []string
PublicKey *rsa.PublicKey
TicketKey string
TicketClaimsKey string
TicketFinder Finder
ClaimsLooker func(c echo.Context, claims *Claims) error
ErrorHandler func(c echo.Context, err error) error
SuccessHandler func(c echo.Context)
}
func Middleware(anonymously bool, roles ...string) echo.MiddlewareFunc {
return WithConfig(Config{
Anonymously: anonymously,
TicketFinder: DefaultFinder,
ClaimsLooker: func(c echo.Context, claims *Claims) error {
if len(roles) > 0 && slices.Contains(roles, claims.Role) {
return nil
}
return ErrUnauthorized
},
})
}
func WithConfig(config Config) echo.MiddlewareFunc {
return config.ToMiddleware()
}
func (t Config) ToMiddleware() echo.MiddlewareFunc {
if t.Skipper == nil {
t.Skipper = func(c echo.Context) bool { return false }
}
if len(t.Audiences) == 0 {
t.Audiences = append(t.Audiences, "*")
}
if t.TicketFinder == nil {
t.TicketFinder = DefaultFinder
}
if t.ErrorHandler == nil {
t.ErrorHandler = func(c echo.Context, err error) error { return err }
}
if t.PublicKey == nil {
var err error
source := []byte(env.String("TICKET_PUBLIC_KEY"))
t.PublicKey, err = jwt.ParseRSAPublicKeyFromPEM(source)
if err != nil {
panic(err)
}
}
if t.TicketKey == "" {
t.TicketKey = "ticket"
}
if t.TicketClaimsKey == "" {
t.TicketClaimsKey = "ticket_claims"
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if t.Skipper(c) {
return next(c)
}
ticketString := t.TicketFinder(c.Request())
if ticketString == "" {
if t.Anonymously {
return next(c)
}
return t.ErrorHandler(c, ErrNoTicketFound)
}
claims, err := Verify(ticketString, t.PublicKey, t.Audiences...)
if err != nil {
return t.ErrorHandler(c, err)
}
if t.ClaimsLooker != nil {
err = t.ClaimsLooker(c, claims)
if err != nil {
return t.ErrorHandler(c, err)
}
}
c.Set(t.TicketKey, ticketString)
c.Set(t.TicketClaimsKey, claims)
return next(c)
}
}
}

@ -6,7 +6,10 @@ import (
"errors"
"github.com/golang-jwt/jwt/v5"
"github.com/rs/xid"
"net/http"
"slices"
"sorbet/pkg/env"
"strings"
"time"
)
@ -18,22 +21,46 @@ var (
ErrBadAudience = errors.New("bad audience")
)
// Create 创建令牌
// 如果参数 claims 中未给出过期时间,将默认1小时过期
// 如果未给出令牌编号,则自动生成。
func Create(claims *Claims, key *rsa.PrivateKey) (string, error) {
if claims.ID == "" {
claims.ID = xid.New().String()
}
if claims.ExpiresAt == nil {
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(time.Hour))
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
signedString, err := token.SignedString(key)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString([]byte(signedString)), nil
type Claims struct {
ID string `json:"jti,omitempty"` // ticket编号
UID uint `json:"uid"` // 用户编号
Role string `json:"role"` // 账号类型
Issuer string `json:"iss,omitempty"` // 签发人
Subject string `json:"sub,omitempty"` // 主题
Audience jwt.ClaimStrings `json:"aud,omitempty"` // 受众
ExpiresAt *jwt.NumericDate `json:"exp,omitempty"` // 过期时间
NotBefore *jwt.NumericDate `json:"nbf,omitempty"` // 生效时间
IssuedAt *jwt.NumericDate `json:"iat,omitempty"` // 签发时间
}
// GetExpirationTime implements the Claims interface.
func (c *Claims) GetExpirationTime() (*jwt.NumericDate, error) {
return c.ExpiresAt, nil
}
// GetNotBefore implements the Claims interface.
func (c *Claims) GetNotBefore() (*jwt.NumericDate, error) {
return c.NotBefore, nil
}
// GetIssuedAt implements the Claims interface.
func (c *Claims) GetIssuedAt() (*jwt.NumericDate, error) {
return c.IssuedAt, nil
}
// GetAudience implements the Claims interface.
func (c *Claims) GetAudience() (jwt.ClaimStrings, error) {
return c.Audience, nil
}
// GetIssuer implements the Claims interface.
func (c *Claims) GetIssuer() (string, error) {
return c.Issuer, nil
}
// GetSubject implements the Claims interface.
func (c *Claims) GetSubject() (string, error) {
return c.Subject, nil
}
// Verify 验证令牌
@ -76,3 +103,78 @@ NEXT:
}
return claims, nil
}
// Create 创建令牌
//
// 如果参数
// claims
// 中未给出过期时间,将默认1小时过期
// 如果未给出令牌编号,则自动生成。
func Create(claims *Claims) (string, error) {
if claims.ID == "" {
claims.ID = xid.New().String()
}
if claims.Issuer == "" {
claims.Issuer = env.String("TICKET_ISSUER")
}
if claims.Subject == "" {
claims.Issuer = env.String("TICKET_SUBJECT")
}
if claims.Audience == nil {
claims.Audience = env.List("TICKET_AUDIENCE")
}
if claims.ExpiresAt == nil {
ttl := env.Duration("TICKET_TTL", time.Hour)
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(ttl))
}
source := []byte(env.String("TICKET_PRIVATE_KEY"))
privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(source)
if err != nil {
return "", err
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
signedString, err := token.SignedString(privateKey)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString([]byte(signedString)), nil
}
type Finder func(r *http.Request) string
func DefaultFinder(r *http.Request) string {
if s := FromHeader(r); s != "" {
return s
}
if s := FromCookie(r); s != "" {
return s
}
return FromQuery(r)
}
func FromCookie(r *http.Request) string {
cookie, err := r.Cookie("ticket")
if err != nil {
return ""
}
return cookie.Value
}
func FromHeader(r *http.Request) string {
bearer := r.Header.Get("Authorization")
if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" {
return bearer[7:]
}
return ""
}
func FromQuery(r *http.Request, keys ...string) string {
q := r.URL.Query()
for _, key := range keys {
s := q.Get(key)
if s != "" {
return s
}
}
return q.Get("ticket")
}

@ -1,45 +0,0 @@
package ticket
import (
"net/http"
"strings"
)
type Finder func(r *http.Request) string
func DefaultFinder(r *http.Request) string {
if s := FromHeader(r); s != "" {
return s
}
if s := FromCookie(r); s != "" {
return s
}
return FromQuery(r)
}
func FromCookie(r *http.Request) string {
cookie, err := r.Cookie("ticket")
if err != nil {
return ""
}
return cookie.Value
}
func FromHeader(r *http.Request) string {
bearer := r.Header.Get("Authorization")
if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" {
return bearer[7:]
}
return ""
}
func FromQuery(r *http.Request, keys ...string) string {
q := r.URL.Query()
for _, key := range keys {
s := q.Get(key)
if s != "" {
return s
}
}
return q.Get("ticket")
}
Loading…
Cancel
Save