|
|
|
package app
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"errors"
|
|
|
|
"github.com/go-chi/jwtauth/v5"
|
|
|
|
"github.com/lestrrat-go/jwx/v2/jwt"
|
|
|
|
"github.com/rs/xid"
|
|
|
|
"golang.org/x/crypto/bcrypt"
|
|
|
|
"gorm.io/gorm"
|
|
|
|
"net/http"
|
|
|
|
"time"
|
|
|
|
)
|
|
|
|
|
|
|
|
func HashPassword(password string) (string, error) {
|
|
|
|
bytes, err := bcrypt.GenerateFromPassword([]byte(password), 14)
|
|
|
|
return string(bytes), err
|
|
|
|
}
|
|
|
|
|
|
|
|
func CheckPasswordHash(password, hash string) bool {
|
|
|
|
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
|
|
|
return err == nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func GenerateAuthToken(r *Request, uid uint) (*UserToken, error) {
|
|
|
|
code := xid.New().String()
|
|
|
|
|
|
|
|
rawToken, err := jwt.NewBuilder().
|
|
|
|
//Audience().
|
|
|
|
Expiration(time.Now().Add(time.Hour*24)).
|
|
|
|
Issuer(r.URL.Hostname()).
|
|
|
|
IssuedAt(time.Now()).
|
|
|
|
//JwtID().
|
|
|
|
//NotBefore().
|
|
|
|
//Subject().
|
|
|
|
Claim("code", code).
|
|
|
|
Build()
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
claims, err := rawToken.AsMap(r.Context())
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
_, accessToken, err := tokenAuth.Encode(claims)
|
|
|
|
if err != nil {
|
|
|
|
return nil, NewError(1, "生成授权令牌失败")
|
|
|
|
}
|
|
|
|
|
|
|
|
if err = rawToken.Set(jwt.ExpirationKey, time.Now().Add(time.Hour*24*30)); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
if claims, err = rawToken.AsMap(r.Context()); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
_, refreshToken, err := tokenAuth.Encode(claims)
|
|
|
|
if err != nil {
|
|
|
|
return nil, NewError(1, "生成刷新令牌失败")
|
|
|
|
}
|
|
|
|
|
|
|
|
return &UserToken{
|
|
|
|
Code: code,
|
|
|
|
UserID: uid,
|
|
|
|
AccessToken: accessToken,
|
|
|
|
RefreshToken: refreshToken,
|
|
|
|
CreatedAt: time.Now(),
|
|
|
|
}, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func AuthInfo(r *Request) (*UserToken, error) {
|
|
|
|
token, _, err := jwtauth.FromContext(r.Context())
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
code, ok := token.Get("code")
|
|
|
|
if !ok {
|
|
|
|
return nil, ErrInvalidToken
|
|
|
|
}
|
|
|
|
if _, ok = code.(string); !ok {
|
|
|
|
return nil, ErrInvalidToken
|
|
|
|
}
|
|
|
|
|
|
|
|
var ut UserToken
|
|
|
|
err = DB.Model(&UserToken{}).Preload("User").First(&ut, "code = ?", code).Error
|
|
|
|
if err != nil {
|
|
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
|
|
return nil, ErrInvalidToken
|
|
|
|
}
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return &ut, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
var ErrInvalidToken = &Error{
|
|
|
|
Status: http.StatusForbidden,
|
|
|
|
Code: 401,
|
|
|
|
Message: "错误令牌",
|
|
|
|
}
|
|
|
|
|
|
|
|
func UserTokenFromContext(ctx context.Context) (*UserToken, bool) {
|
|
|
|
ut, ok := ctx.Value("USER_TOKEN").(*UserToken)
|
|
|
|
return ut, ok
|
|
|
|
}
|
|
|
|
|
|
|
|
func CheckAuthToken(next http.Handler) http.Handler {
|
|
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
token, _, _ := jwtauth.FromContext(r.Context())
|
|
|
|
if code, ok := token.Get("code"); !ok {
|
|
|
|
NewResponseWriter(w).Error(ErrInvalidToken)
|
|
|
|
} else if codeString, ok := code.(string); !ok {
|
|
|
|
NewResponseWriter(w).Error(ErrInvalidToken)
|
|
|
|
} else if _, err := xid.FromString(codeString); err != nil {
|
|
|
|
NewResponseWriter(w).Error(ErrInvalidToken)
|
|
|
|
} else {
|
|
|
|
var ut UserToken
|
|
|
|
err = DB.Model(&UserToken{}).Preload("User").First(&ut, "code = ?", codeString).Error
|
|
|
|
if err != nil {
|
|
|
|
NewResponseWriter(w).Error(ErrInvalidToken)
|
|
|
|
} else {
|
|
|
|
ctx := context.WithValue(r.Context(), "USER_TOKEN", &ut)
|
|
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|