diff --git a/app/db.go b/app/db.go index 1b8bf8b..1be0154 100644 --- a/app/db.go +++ b/app/db.go @@ -34,6 +34,8 @@ type User struct { // UserToken 用户令牌 type UserToken struct { Code string `gorm:"primarykey"` // 主键 + UserID uint // 用户ID + User *User // 关联用户 AccessToken string // 授权令牌 RefreshToken string // 刷新令牌 CreatedAt time.Time // 创建时间 diff --git a/app/fns.go b/app/fns.go index 95079d0..aa62e32 100644 --- a/app/fns.go +++ b/app/fns.go @@ -1,6 +1,7 @@ package app import ( + "context" "errors" "github.com/go-chi/jwtauth/v5" "github.com/lestrrat-go/jwx/v2/jwt" @@ -32,7 +33,6 @@ func GenerateAuthToken(r *Request, uid uint) (*UserToken, error) { //JwtID(). //NotBefore(). //Subject(). - Claim("uid", uid). Claim("code", code). Build() if err != nil { @@ -62,45 +62,63 @@ func GenerateAuthToken(r *Request, uid uint) (*UserToken, error) { return &UserToken{ Code: code, + UserID: uid, AccessToken: accessToken, RefreshToken: refreshToken, CreatedAt: time.Now(), }, nil } -func AuthInfo(r *Request) (*UserToken, uint, error) { +func AuthInfo(r *Request) (*UserToken, error) { token, _, err := jwtauth.FromContext(r.Context()) if err != nil { - return nil, 0, err + return nil, err } - ex := NewError(403, "错误令牌") - ex.Status = http.StatusForbidden - code, ok := token.Get("code") if !ok { - return nil, 0, nil + return nil, ErrInvalidToken } if _, ok = code.(string); !ok { - return nil, 0, ex - } - - uidUint, ok := token.Get("uid") - if !ok { - return nil, 0, ex - } - uid, ok := uidUint.(uint) - if !ok { - return nil, 0, ex + return nil, ErrInvalidToken } var ut UserToken - if err = DB.First(&ut, "code = ?", code).Error; err != nil { + err = DB.Model(&UserToken{}).Preload("User").First(&ut, "code = ?", code).Error + if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, 0, ex + return nil, ErrInvalidToken } - return nil, 0, err + return nil, err } - return &ut, uid, nil + return &ut, nil +} + +var ErrInvalidToken = &Error{ + Status: http.StatusForbidden, + Code: 401, + Message: "错误令牌", +} + +func CheckAuthToken(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token, _, _ := jwtauth.FromContext(r.Context()) + if code, ok := token.Get("code"); !ok { + NewResponseWriter(w).Error(ErrInvalidToken) + } else if codeString, ok := code.(string); !ok { + NewResponseWriter(w).Error(ErrInvalidToken) + } else if _, err := xid.FromString(codeString); err != nil { + NewResponseWriter(w).Error(ErrInvalidToken) + } else { + var ut UserToken + err = DB.First(&ut, "code = ?", codeString).Error + if err != nil { + NewResponseWriter(w).Error(ErrInvalidToken) + } else { + ctx := context.WithValue(r.Context(), "USER_TOKEN", &ut) + next.ServeHTTP(w, r.WithContext(ctx)) + } + } + }) } diff --git a/app/rts.go b/app/rts.go index 0b1a8a1..3e4743b 100644 --- a/app/rts.go +++ b/app/rts.go @@ -78,6 +78,7 @@ func CreateUser(w *ResponseWriter, r *Request) { Name: ui.Name, PhoneNumber: ui.PhoneNumber, Password: hash, + Admin: false, } if err := DB.Create(&user).Error; err != nil { w.Error(err) @@ -249,14 +250,14 @@ func Login(w *ResponseWriter, r *Request) { // RefreshToken 刷新授权令牌 func RefreshToken(w *ResponseWriter, r *Request) { // 获取刷新令牌信息 - ut, uid, err := AuthInfo(r) + ut, err := AuthInfo(r) if err != nil { w.Error(err) return } // 生成新的令牌 - ut2, err := GenerateAuthToken(r, uid) + ut2, err := GenerateAuthToken(r, ut.UserID) if err != nil { w.Error(err) return