diff --git a/app/db.go b/app/db.go index 9f54c74..7df8365 100644 --- a/app/db.go +++ b/app/db.go @@ -5,6 +5,7 @@ import ( "gorm.io/gorm" "math" "strings" + "time" ) // DB 用户数据操作 @@ -29,6 +30,14 @@ type User struct { Password string `json:"-"` // 登录密码 } +// UserToken 用户令牌 +type UserToken struct { + Code string `gorm:"primarykey"` // 主键 + AccessToken string // 授权令牌 + RefreshToken string // 刷新令牌 + CreatedAt time.Time // 创建时间 +} + // Goods 商品 type Goods struct { gorm.Model diff --git a/app/fns.go b/app/fns.go index a420c4e..95079d0 100644 --- a/app/fns.go +++ b/app/fns.go @@ -1,6 +1,15 @@ package app -import "golang.org/x/crypto/bcrypt" +import ( + "errors" + "github.com/go-chi/jwtauth/v5" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/rs/xid" + "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" + "net/http" + "time" +) func HashPassword(password string) (string, error) { bytes, err := bcrypt.GenerateFromPassword([]byte(password), 14) @@ -11,3 +20,87 @@ func CheckPasswordHash(password, hash string) bool { err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) return err == nil } + +func GenerateAuthToken(r *Request, uid uint) (*UserToken, error) { + code := xid.New().String() + + rawToken, err := jwt.NewBuilder(). + //Audience(). + Expiration(time.Now().Add(time.Hour*24)). + Issuer(r.URL.Hostname()). + IssuedAt(time.Now()). + //JwtID(). + //NotBefore(). + //Subject(). + Claim("uid", uid). + Claim("code", code). + Build() + if err != nil { + return nil, err + } + + claims, err := rawToken.AsMap(r.Context()) + if err != nil { + return nil, err + } + + _, accessToken, err := tokenAuth.Encode(claims) + if err != nil { + return nil, NewError(1, "生成授权令牌失败") + } + + if err = rawToken.Set(jwt.ExpirationKey, time.Now().Add(time.Hour*24*30)); err != nil { + return nil, err + } + if claims, err = rawToken.AsMap(r.Context()); err != nil { + return nil, err + } + _, refreshToken, err := tokenAuth.Encode(claims) + if err != nil { + return nil, NewError(1, "生成刷新令牌失败") + } + + return &UserToken{ + Code: code, + AccessToken: accessToken, + RefreshToken: refreshToken, + CreatedAt: time.Now(), + }, nil +} + +func AuthInfo(r *Request) (*UserToken, uint, error) { + token, _, err := jwtauth.FromContext(r.Context()) + if err != nil { + return nil, 0, err + } + + ex := NewError(403, "错误令牌") + ex.Status = http.StatusForbidden + + code, ok := token.Get("code") + if !ok { + return nil, 0, nil + } + if _, ok = code.(string); !ok { + return nil, 0, ex + } + + uidUint, ok := token.Get("uid") + if !ok { + return nil, 0, ex + } + uid, ok := uidUint.(uint) + if !ok { + return nil, 0, ex + } + + var ut UserToken + if err = DB.First(&ut, "code = ?", code).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, 0, ex + } + return nil, 0, err + } + + return &ut, uid, nil +} diff --git a/app/rts.go b/app/rts.go index 52a366a..0b1a8a1 100644 --- a/app/rts.go +++ b/app/rts.go @@ -190,7 +190,6 @@ func ListUser(w *ResponseWriter, r *Request) { func Login(w *ResponseWriter, r *Request) { var phoneNumber string var password string - var token string var ok bool // 提交的手机号码 @@ -215,35 +214,71 @@ func Login(w *ResponseWriter, r *Request) { return } - // 提交的设备码 - if token, ok = r.Get("token"); !ok || len(token) == 0 { - w.Error(NewError(2, "缺少设备码")) + // 查询用户是否存在 + var user User + if err := DB.First(&user, "phone_number = ?", phoneNumber).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + err = NewError(1, "手机号码或密码错误") + } + w.Error(err) return } - // 查询用户是否存在 - var user User - err := DB.First(&user, "phone_number = ?", phoneNumber).Error - if errors.Is(err, gorm.ErrRecordNotFound) { + // 验证密码 + if !CheckPasswordHash(password, user.Password) { w.Error(NewError(1, "手机号码或密码错误")) - } else if err != nil { + return + } + + ut, err := GenerateAuthToken(r, user.ID) + if err == nil { + err = DB.Create(&ut).Error + } + if err != nil { w.Error(err) - } else { - if !CheckPasswordHash(password, user.Password) { - w.Error(NewError(1, "手机号码或密码错误")) - return - } - _, tokenString, err := tokenAuth.Encode(map[string]any{"uid": user.ID, "tkn": token}) - if err != nil { - LogError(err) - w.Error(NewError(1, "登录失败")) - return + return + } + + w.Ok(map[string]any{ + "user": user, + "access_token": ut.AccessToken, + "refresh_token": ut.RefreshToken, + }, "登录成功") +} + +// RefreshToken 刷新授权令牌 +func RefreshToken(w *ResponseWriter, r *Request) { + // 获取刷新令牌信息 + ut, uid, err := AuthInfo(r) + if err != nil { + w.Error(err) + return + } + + // 生成新的令牌 + ut2, err := GenerateAuthToken(r, uid) + if err != nil { + w.Error(err) + return + } + + // 删除旧的令牌,保持新的令牌 + err = DB.Transaction(func(tx *gorm.DB) error { + if err := tx.Delete(&ut).Error; err != nil { + return err } - w.Ok(map[string]any{ - "user": user, - "token": tokenString, - }, "登录成功") + return tx.Create(&ut2).Error + }) + + if err != nil { + w.Error(err) + return } + + w.Ok(map[string]any{ + "access_token": ut.AccessToken, + "refresh_token": ut.RefreshToken, + }, "刷新令牌成功") } // CreateGoods 创建商品 diff --git a/go.mod b/go.mod index 3523056..6d1e469 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,8 @@ go 1.19 require ( github.com/go-chi/chi/v5 v5.0.8 github.com/go-chi/jwtauth/v5 v5.1.0 + github.com/rs/xid v1.4.0 + golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f gorm.io/driver/sqlite v1.4.3 gorm.io/gorm v1.24.2 ) @@ -21,5 +23,4 @@ require ( github.com/lestrrat-go/jwx/v2 v2.0.6 // indirect github.com/lestrrat-go/option v1.0.0 // indirect github.com/mattn/go-sqlite3 v1.14.15 // indirect - golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f // indirect ) diff --git a/go.sum b/go.sum index 42360dc..7f54fac 100644 --- a/go.sum +++ b/go.sum @@ -31,6 +31,8 @@ github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOj github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY= +github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=