You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

148 lines
3.2 KiB

package middleware
import (
	"errors"
	"fmt"
	"reflect"

	"ims/app/models"
	"ims/util/db"
	"ims/util/jwt"
	"ims/util/rdb"

	"github.com/jinzhu/inflection"
	"github.com/redis/go-redis/v9"
	"golang.org/x/sync/singleflight"
	"zestack.dev/log"
	"zestack.dev/slim"
)
// sfg deduplicates concurrent loadModel calls for the same cache key
// ("<plural label>:<id>") so the Redis/database lookup runs once per key
// at a time.
var sfg singleflight.Group

// JWTExtra holds the extra per-token data that Auth loads for a request.
//
// The *ID fields are scanned from the Redis hash "jwt:<token>", matching
// hash fields by the names in their `redis` tags. The model pointers have
// no `redis` tag; they are filled in by load from the scanned IDs.
// Employee and Tenant stay nil when the corresponding ID is zero.
type JWTExtra struct {
	AccountID  uint `redis:"account"`
	EmployeeID uint `redis:"employee"`
	TenantID   uint `redis:"tenant"`

	Account  *models.Account
	Employee *models.Employee
	Tenant   *models.Tenant
}
// Auth builds the JWT authentication middleware.
//
// The route named "login" bypasses authentication. For every other route
// the token's extra data (account, employee, tenant) is loaded via
// JWTExtra.load, which also stores the loaded models in the request
// context. When withTenant is true, a request whose token has no tenant
// or no employee is rejected with jwt.ErrForbidden. anonymously is passed
// through unchanged to jwt.AuthConfig.Anonymously.
func Auth(withTenant, anonymously bool) slim.MiddlewareFunc {
	skipper := func(c slim.Context) bool {
		// NOTE(review): this panics when the route has not been matched;
		// presumably it guards against mounting Auth before routing. Confirm
		// that unmatched (404) requests can never reach this middleware.
		if c.RouteMatchType() == slim.RouteMatchUnknown {
			panic("unknown route")
		}
		return c.RouteInfo().Name() == "login"
	}

	claims := func(c slim.Context, token string, _ *jwt.Claims) error {
		extra := new(JWTExtra)
		if err := extra.load(c, token); err != nil {
			return err
		}
		if !withTenant {
			return nil
		}
		// For tenant-scoped routes, a missing tenant or employee means the
		// caller has no access.
		if extra.Tenant == nil || extra.Employee == nil {
			return jwt.ErrForbidden
		}
		// TODO: if the tenant's data has not been initialized yet, only a few
		// pages (basic settings, initialization, ...) should be reachable:
		//
		// if !tenant.IsReady() {
		// 	return c.Redirect(http.StatusMovedPermanently, "/tenant/init")
		// }
		return nil
	}

	return jwt.Auth(jwt.AuthConfig{
		Skipper:     skipper,
		Anonymously: anonymously,
		Claims:      claims,
	})
}
// load fills e for the given token.
//
// It first scans the Redis hash "jwt:"+token into e's tagged ID fields,
// then loads the account, employee and tenant models in that order,
// stopping at the first error.
func (e *JWTExtra) load(c slim.Context, token string) error {
	if err := rdb.Redis().HGetAll(c, "jwt:"+token).Scan(e); err != nil {
		return err
	}
	steps := []func(slim.Context) error{
		e.loadAccount,
		e.loadEmployee,
		e.loadTenant,
	}
	for _, step := range steps {
		if err := step(c); err != nil {
			return err
		}
	}
	return nil
}
// loadAccount loads the account identified by e.AccountID into e.Account.
//
// Unlike loadEmployee and loadTenant, there is no zero-ID shortcut: the
// lookup always runs, even when AccountID is 0. e.Account is only set
// when the lookup succeeds.
func (e *JWTExtra) loadAccount(c slim.Context) error {
	account := new(models.Account)
	if err := e.loadModel(c, "account", e.AccountID, account); err != nil {
		return err
	}
	e.Account = account
	return nil
}
// loadEmployee loads the employee identified by e.EmployeeID into
// e.Employee. A zero EmployeeID means "no employee": nothing is loaded and
// e.Employee stays nil.
func (e *JWTExtra) loadEmployee(c slim.Context) error {
	if e.EmployeeID != 0 {
		employee := new(models.Employee)
		if err := e.loadModel(c, "employee", e.EmployeeID, employee); err != nil {
			return err
		}
		e.Employee = employee
	}
	return nil
}
// loadTenant loads the tenant identified by e.TenantID into e.Tenant.
// A zero TenantID means "no tenant": nothing is loaded and e.Tenant stays
// nil.
func (e *JWTExtra) loadTenant(c slim.Context) error {
	if e.TenantID != 0 {
		tenant := new(models.Tenant)
		if err := e.loadModel(c, "tenant", e.TenantID, tenant); err != nil {
			return err
		}
		e.Tenant = tenant
	}
	return nil
}
// loadModel loads the record with the given label (e.g. "account") and id
// into val, then stores val in the request context under "jwt:"+label.
//
// val must be a non-nil pointer to the model struct. The record is read
// from the Redis hash "<plural label>:<id>". On a cache miss it is loaded
// from the database with First and written back to Redis; a failed
// write-back is only logged. Concurrent loads of the same key share one
// lookup through sfg.
func (e *JWTExtra) loadModel(c slim.Context, label string, id uint, val any) error {
	dst := reflect.ValueOf(val)
	if dst.Kind() != reflect.Pointer || dst.IsNil() {
		return fmt.Errorf("load %s %d: destination must be a non-nil pointer, got %T", label, id, val)
	}
	key := fmt.Sprintf("%s:%d", inflection.Plural(label), id)

	// singleflight runs only the first caller's function. Callers that join
	// an in-flight load receive its return value, but their own closure
	// (and thus their own val) is never touched. The shared load therefore
	// fills a fresh value, and every caller copies it into its own
	// destination below.
	// NOTE(review): the first caller's context is used for the shared
	// lookup, so its cancellation fails every joined caller too.
	shared, err, _ := sfg.Do(key, func() (any, error) {
		client := rdb.Redis()
		fresh := reflect.New(dst.Elem().Type()).Interface()

		// HGetAll reports a missing key as an empty map with a nil error
		// (not redis.Nil), so a cache miss must be detected by length.
		cached := client.HGetAll(c, key)
		if err := cached.Err(); err != nil && !errors.Is(err, redis.Nil) {
			return nil, err
		}
		if len(cached.Val()) > 0 {
			if err := cached.Scan(fresh); err != nil {
				return nil, err
			}
			return fresh, nil
		}

		// Cache miss: load from the database, then best-effort write-back.
		if err := db.Engine().First(fresh, id).Error; err != nil {
			return nil, err
		}
		if err := client.HSet(c, key, fresh).Err(); err != nil {
			log.Error("failed to cache "+label, "error", err, "id", id)
		}
		return fresh, nil
	})
	if err != nil {
		return err
	}

	src := reflect.ValueOf(shared)
	if src.Type() != dst.Type() {
		return fmt.Errorf("load %s %d: shared result has type %T, want %T", label, id, shared, val)
	}
	// Shallow copy: any reference-typed fields of the model are shared
	// between callers that joined the same load.
	dst.Elem().Set(src.Elem())
	c.Set("jwt:"+label, val)
	return nil
}