|
|
|
@ -1,6 +1,7 @@ |
|
|
|
|
package app |
|
|
|
|
|
|
|
|
|
import ( |
|
|
|
|
"context" |
|
|
|
|
"errors" |
|
|
|
|
"github.com/go-chi/jwtauth/v5" |
|
|
|
|
"github.com/lestrrat-go/jwx/v2/jwt" |
|
|
|
@ -32,7 +33,6 @@ func GenerateAuthToken(r *Request, uid uint) (*UserToken, error) { |
|
|
|
|
//JwtID().
|
|
|
|
|
//NotBefore().
|
|
|
|
|
//Subject().
|
|
|
|
|
Claim("uid", uid). |
|
|
|
|
Claim("code", code). |
|
|
|
|
Build() |
|
|
|
|
if err != nil { |
|
|
|
@ -62,45 +62,63 @@ func GenerateAuthToken(r *Request, uid uint) (*UserToken, error) { |
|
|
|
|
|
|
|
|
|
return &UserToken{ |
|
|
|
|
Code: code, |
|
|
|
|
UserID: uid, |
|
|
|
|
AccessToken: accessToken, |
|
|
|
|
RefreshToken: refreshToken, |
|
|
|
|
CreatedAt: time.Now(), |
|
|
|
|
}, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func AuthInfo(r *Request) (*UserToken, uint, error) { |
|
|
|
|
func AuthInfo(r *Request) (*UserToken, error) { |
|
|
|
|
token, _, err := jwtauth.FromContext(r.Context()) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, 0, err |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
ex := NewError(403, "错误令牌") |
|
|
|
|
ex.Status = http.StatusForbidden |
|
|
|
|
|
|
|
|
|
code, ok := token.Get("code") |
|
|
|
|
if !ok { |
|
|
|
|
return nil, 0, nil |
|
|
|
|
return nil, ErrInvalidToken |
|
|
|
|
} |
|
|
|
|
if _, ok = code.(string); !ok { |
|
|
|
|
return nil, 0, ex |
|
|
|
|
return nil, ErrInvalidToken |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
uidUint, ok := token.Get("uid") |
|
|
|
|
if !ok { |
|
|
|
|
return nil, 0, ex |
|
|
|
|
var ut UserToken |
|
|
|
|
err = DB.Model(&UserToken{}).Preload("User").First(&ut, "code = ?", code).Error |
|
|
|
|
if err != nil { |
|
|
|
|
if errors.Is(err, gorm.ErrRecordNotFound) { |
|
|
|
|
return nil, ErrInvalidToken |
|
|
|
|
} |
|
|
|
|
uid, ok := uidUint.(uint) |
|
|
|
|
if !ok { |
|
|
|
|
return nil, 0, ex |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return &ut, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
var ErrInvalidToken = &Error{ |
|
|
|
|
Status: http.StatusForbidden, |
|
|
|
|
Code: 401, |
|
|
|
|
Message: "错误令牌", |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func CheckAuthToken(next http.Handler) http.Handler { |
|
|
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
token, _, _ := jwtauth.FromContext(r.Context()) |
|
|
|
|
if code, ok := token.Get("code"); !ok { |
|
|
|
|
NewResponseWriter(w).Error(ErrInvalidToken) |
|
|
|
|
} else if codeString, ok := code.(string); !ok { |
|
|
|
|
NewResponseWriter(w).Error(ErrInvalidToken) |
|
|
|
|
} else if _, err := xid.FromString(codeString); err != nil { |
|
|
|
|
NewResponseWriter(w).Error(ErrInvalidToken) |
|
|
|
|
} else { |
|
|
|
|
var ut UserToken |
|
|
|
|
if err = DB.First(&ut, "code = ?", code).Error; err != nil { |
|
|
|
|
if errors.Is(err, gorm.ErrRecordNotFound) { |
|
|
|
|
return nil, 0, ex |
|
|
|
|
err = DB.First(&ut, "code = ?", codeString).Error |
|
|
|
|
if err != nil { |
|
|
|
|
NewResponseWriter(w).Error(ErrInvalidToken) |
|
|
|
|
} else { |
|
|
|
|
ctx := context.WithValue(r.Context(), "USER_TOKEN", &ut) |
|
|
|
|
next.ServeHTTP(w, r.WithContext(ctx)) |
|
|
|
|
} |
|
|
|
|
return nil, 0, err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return &ut, uid, nil |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|