You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
148 lines
3.2 KiB
148 lines
3.2 KiB
package middleware
|
|
|
|
import (
	"errors"
	"fmt"
	"reflect"

	"github.com/jinzhu/inflection"
	"github.com/redis/go-redis/v9"
	"golang.org/x/sync/singleflight"
	"zestack.dev/log"
	"zestack.dev/slim"

	"ims/app/models"
	"ims/util/db"
	"ims/util/jwt"
	"ims/util/rdb"
)
|
|
|
|
// sfg deduplicates concurrent model lookups that share the same
// "<plural label>:<id>" key (see loadModel), so only one of them hits
// Redis/the database at a time.
var (
	sfg singleflight.Group
)
|
|
|
|
// JWTExtra is the extra session data attached to a JWT.
//
// The ID fields are scanned from the Redis hash "jwt:<token>" through their
// `redis` tags (see load). The model pointers are filled afterwards from
// those IDs; Employee and Tenant stay nil when their ID is 0.
type JWTExtra struct {
	// AccountID is read from the hash field "account".
	AccountID uint `redis:"account"`
	// EmployeeID is read from the hash field "employee"; 0 if the field is absent.
	EmployeeID uint `redis:"employee"`
	// TenantID is read from the hash field "tenant"; 0 if the field is absent.
	TenantID uint `redis:"tenant"`

	// Account, Employee and Tenant are the records resolved from the IDs above.
	// They carry no `redis` tag, so they are not touched by the hash scan.
	Account  *models.Account
	Employee *models.Employee
	Tenant   *models.Tenant
}
|
|
|
|
func Auth(withTenant, anonymously bool) slim.MiddlewareFunc {
|
|
return jwt.Auth(jwt.AuthConfig{
|
|
Skipper: func(c slim.Context) bool {
|
|
if c.RouteMatchType() == slim.RouteMatchUnknown {
|
|
panic("unknown route")
|
|
}
|
|
switch c.RouteInfo().Name() {
|
|
case "login":
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
},
|
|
Anonymously: anonymously,
|
|
Claims: func(c slim.Context, token string, _ *jwt.Claims) error {
|
|
var extra JWTExtra
|
|
if err := extra.load(c, token); err != nil {
|
|
return err
|
|
}
|
|
// 针对租户,如果加载失败,说明没有权限
|
|
if withTenant && (extra.Tenant == nil || extra.Employee == nil) {
|
|
return jwt.ErrForbidden
|
|
}
|
|
|
|
// // 如果租户数据没有初始化,那就只能访问有限的几个页面,
|
|
// // 比如基本设置、初始化等。
|
|
// if withTenant && !tenant.IsReady() {
|
|
// return c.Redirect(http.StatusMovedPermanently, "/tenant/init")
|
|
// }
|
|
|
|
return nil
|
|
},
|
|
})
|
|
}
|
|
|
|
func (e *JWTExtra) load(c slim.Context, token string) error {
|
|
// 加载额外信息
|
|
if err := rdb.Redis().HGetAll(c, "jwt:"+token).Scan(e); err != nil {
|
|
return err
|
|
}
|
|
// 加载用户信息
|
|
if err := e.loadAccount(c); err != nil {
|
|
return err
|
|
}
|
|
// 加载员工信息
|
|
if err := e.loadEmployee(c); err != nil {
|
|
return err
|
|
}
|
|
// 加载租户信息
|
|
if err := e.loadTenant(c); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (e *JWTExtra) loadAccount(c slim.Context) error {
|
|
var account models.Account
|
|
err := e.loadModel(c, "account", e.AccountID, &account)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
e.Account = &account
|
|
return nil
|
|
}
|
|
|
|
func (e *JWTExtra) loadEmployee(c slim.Context) error {
|
|
if e.EmployeeID == 0 {
|
|
return nil
|
|
}
|
|
var employee models.Employee
|
|
err := e.loadModel(c, "employee", e.EmployeeID, &employee)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
e.Employee = &employee
|
|
return nil
|
|
}
|
|
|
|
func (e *JWTExtra) loadTenant(c slim.Context) error {
|
|
if e.TenantID == 0 {
|
|
return nil
|
|
}
|
|
var tenant models.Tenant
|
|
err := e.loadModel(c, "tenant", e.TenantID, &tenant)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
e.Tenant = &tenant
|
|
return nil
|
|
}
|
|
|
|
func (e *JWTExtra) loadModel(c slim.Context, label string, id uint, val any) error {
|
|
client := rdb.Redis()
|
|
key := fmt.Sprintf("%s:%d", inflection.Plural(label), id)
|
|
_, err, _ := sfg.Do(key, func() (any, error) {
|
|
// 加载 Redis 中的数据
|
|
err := client.HGetAll(c, key).Scan(val)
|
|
if err == nil {
|
|
return nil, nil
|
|
}
|
|
if !errors.Is(err, redis.Nil) {
|
|
return nil, err
|
|
}
|
|
// 加载数据库中的数据
|
|
err = db.Engine().First(val, id).Error
|
|
if err == nil && val != nil {
|
|
err = client.HSet(c, key, val).Err()
|
|
if err != nil {
|
|
log.Error("failed to cache "+label, "error", err, "id", id)
|
|
}
|
|
}
|
|
return nil, err
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
c.Set("jwt:"+label, val)
|
|
return nil
|
|
}
|
|
|