You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
313 lines
5.8 KiB
313 lines
5.8 KiB
1 year ago
|
package util
|
||
|
|
||
|
import (
|
||
|
"github.com/labstack/echo/v4"
|
||
|
"gorm.io/gorm"
|
||
|
"net/url"
|
||
|
"reflect"
|
||
|
"sorbet/pkg/db"
|
||
|
"sorbet/pkg/ioc"
|
||
|
"sorbet/pkg/rsp"
|
||
|
"strconv"
|
||
|
"strings"
|
||
|
)
|
||
|
|
||
|
// GetID is implemented by request payloads that can report the identifier
// of the record they refer to.
type GetID interface {
	// GetID returns the identifier carried by the request.
	GetID() any
}
|
||
|
|
||
|
// ToMap is implemented by request payloads that can express themselves as a
// map keyed by string.
type ToMap interface {
	// ToMap returns the payload as a string-keyed map.
	ToMap() map[string]any
}
|
||
|
|
||
|
// ToEntity is implemented by request payloads that can be converted into an
// entity value.
type ToEntity interface {
	// ToEntity returns the entity built from the payload.
	ToEntity() any
}
|
||
|
|
||
|
type ControllerRequest interface {
|
||
|
GetID
|
||
|
ToMap
|
||
|
ToEntity
|
||
|
}
|
||
|
|
||
|
func ParseQuery[T any](query url.Values, qb *db.QueryBuilder[T]) (page, limit int, err error) {
|
||
|
var paginating bool
|
||
|
for key, values := range query {
|
||
|
switch key {
|
||
|
case "sortby":
|
||
|
for _, s := range values {
|
||
|
if s[0] == '+' {
|
||
|
qb.AscentBy(s[1:])
|
||
|
} else if s[0] == '-' {
|
||
|
qb.DescentBy(s[1:])
|
||
|
} else {
|
||
|
qb.AscentBy(s)
|
||
|
}
|
||
|
}
|
||
|
case "limit", "page":
|
||
|
var v int
|
||
|
if values[0] != "" {
|
||
|
v, err = strconv.Atoi(values[0])
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
if v <= 0 {
|
||
|
return 0, 0, rsp.ErrInternal
|
||
|
}
|
||
|
if key == "limit" {
|
||
|
qb.Limit(v)
|
||
|
limit = v
|
||
|
} else {
|
||
|
paginating = true
|
||
|
page = max(v, 1)
|
||
|
}
|
||
|
default:
|
||
|
v := values[0]
|
||
|
i := strings.IndexByte(key, '#')
|
||
|
if i == -1 {
|
||
|
qb.Eq(key, v)
|
||
|
continue
|
||
|
}
|
||
|
switch k, op := key[:i], key[i+1:]; op {
|
||
|
case "=":
|
||
|
qb.Eq(k, v)
|
||
|
case "!=":
|
||
|
qb.Neq(k, v)
|
||
|
case "<":
|
||
|
qb.Lt(k, v)
|
||
|
case "<=":
|
||
|
qb.Lte(k, v)
|
||
|
case ">":
|
||
|
qb.Gt(k, v)
|
||
|
case ">=":
|
||
|
qb.Gte(k, v)
|
||
|
case "<>", "><":
|
||
|
var less, more any
|
||
|
switch len(values) {
|
||
|
case 2:
|
||
|
less, more = values[0], values[1]
|
||
|
case 1:
|
||
|
vs := strings.Split(v, ",")
|
||
|
if len(vs) != 2 || vs[0] == "" || vs[1] == "" {
|
||
|
return 0, 0, rsp.ErrBadParams
|
||
|
}
|
||
|
less, more = vs[0], vs[1]
|
||
|
}
|
||
|
if op == "<>" {
|
||
|
qb.Between(k, less, more)
|
||
|
} else {
|
||
|
qb.NotBetween(k, key, more)
|
||
|
}
|
||
|
case "nil":
|
||
|
qb.IsNull(k)
|
||
|
case "!nil":
|
||
|
qb.NotNull(k)
|
||
|
case "~":
|
||
|
qb.Like(k, v)
|
||
|
case "!~":
|
||
|
qb.NotLike(k, v)
|
||
|
case "in", "!in":
|
||
|
if len(values) == 1 {
|
||
|
values = strings.Split(v, ",")
|
||
|
}
|
||
|
vs := make([]any, len(values))
|
||
|
for i, value := range values {
|
||
|
vs[i] = value
|
||
|
}
|
||
|
if op == "in" {
|
||
|
qb.In(k, vs...)
|
||
|
} else {
|
||
|
qb.NotIn(k, vs...)
|
||
|
}
|
||
|
default:
|
||
|
qb.Eq(key, v)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
if paginating {
|
||
|
return
|
||
|
}
|
||
|
if limit == 0 {
|
||
|
limit = 30
|
||
|
qb.Limit(limit)
|
||
|
}
|
||
|
qb.Offset((page - 1) * limit)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func getID(req any) (val any) {
|
||
|
defer func() {
|
||
|
if recover() != nil {
|
||
|
val = nil
|
||
|
}
|
||
|
}()
|
||
|
val = req.(GetID).GetID()
|
||
|
rv := reflect.ValueOf(val)
|
||
|
if rv.IsZero() || rv.IsNil() {
|
||
|
val = nil
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func getValues(req any) map[string]any {
|
||
|
if v, ok := req.(ToMap); ok {
|
||
|
return v.ToMap()
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func getEntity[T any](request any) *T {
|
||
|
v, ok := request.(ToEntity)
|
||
|
if !ok {
|
||
|
return nil
|
||
|
}
|
||
|
ent, ok := v.ToEntity().(*T)
|
||
|
if ok {
|
||
|
return ent
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Controller is the base type for controllers.
//
// The type parameter Entity is the concrete data being operated on; the type
// parameter Upsert is the data required when creating or updating it.
type Controller[Entity any, Upsert any] struct{}
|
||
|
|
||
|
func (ctr *Controller[Entity, Upsert]) RegisterRoutes(path string, r *echo.Group) {
|
||
|
r.PUT(path, ctr.Create)
|
||
|
r.DELETE(path+"/:id", ctr.Delete)
|
||
|
r.POST(path+"/:id", ctr.Update)
|
||
|
r.GET(path+"/:id", ctr.Get)
|
||
|
r.GET(path, ctr.List)
|
||
|
}
|
||
|
|
||
|
func (ctr *Controller[Entity, Upsert]) ORM(c echo.Context) (*gorm.DB, error) {
|
||
|
orm, err := ioc.Get[gorm.DB]()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return orm.WithContext(c.Request().Context()), nil
|
||
|
}
|
||
|
|
||
|
func (ctr *Controller[Entity, Upsert]) Repository() (*db.Repository[Entity], error) {
|
||
|
orm, err := ioc.Get[gorm.DB]()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return db.NewRepository[Entity](orm), nil
|
||
|
}
|
||
|
|
||
|
// Create 创建数据
|
||
|
func (ctr *Controller[Entity, Upsert]) Create(c echo.Context) error {
|
||
|
return ctr.upsert(c, true)
|
||
|
}
|
||
|
|
||
|
// Update 更新数据
|
||
|
func (ctr *Controller[Entity, Upsert]) Update(c echo.Context) error {
|
||
|
return ctr.upsert(c, false)
|
||
|
}
|
||
|
|
||
|
func (ctr *Controller[Entity, Upsert]) upsert(c echo.Context, isCreate bool) error {
|
||
|
request, err := Bind[Upsert](c)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
repo, err := ctr.Repository()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
id := getID(request)
|
||
|
if isCreate != reflect.ValueOf(id).IsZero() {
|
||
|
return rsp.BadParams(c, "参数错误")
|
||
|
}
|
||
|
// 更新数据
|
||
|
if !isCreate {
|
||
|
values := getValues(request)
|
||
|
if values == nil {
|
||
|
return rsp.ErrInternal
|
||
|
}
|
||
|
err = repo.UpdateByID(c.Request().Context(), id, values)
|
||
|
if err == nil {
|
||
|
// TODO(hupeh): 返回更新后的实体数据
|
||
|
return rsp.Ok(c, nil)
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
// 创建数据
|
||
|
group := getEntity[Entity](request)
|
||
|
if group == nil {
|
||
|
return rsp.ErrInternal
|
||
|
}
|
||
|
err = repo.Create(c.Request().Context(), group)
|
||
|
if err != nil {
|
||
|
return rsp.Created(c, group)
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// Delete 通过ID删除数据
|
||
|
func (ctr *Controller[Entity, Upsert]) Delete(c echo.Context) error {
|
||
|
id, err := BindId(c, true)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
repo, err := ctr.Repository()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
err = repo.DeleteByID(c.Request().Context(), id)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return rsp.Ok(c, nil)
|
||
|
}
|
||
|
|
||
|
// Get 通过ID获取数据
|
||
|
func (ctr *Controller[Entity, Upsert]) Get(c echo.Context) error {
|
||
|
id, err := BindId(c, true)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
repo, err := ctr.Repository()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
entity, err := repo.GetByID(c.Request().Context(), id)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return rsp.Ok(c, entity)
|
||
|
}
|
||
|
|
||
|
// List 获取数据列表
|
||
|
func (ctr *Controller[Entity, Upsert]) List(c echo.Context) error {
|
||
|
repo, err := ctr.Repository()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
qb := repo.NewQueryBuilder(c.Request().Context())
|
||
|
_, _, err = ParseQuery[Entity](c.QueryParams(), qb)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// 不是分页查询
|
||
|
if !c.QueryParams().Has("page") {
|
||
|
var result []*Entity
|
||
|
err = qb.Find(&result)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return rsp.Ok(c, result)
|
||
|
}
|
||
|
|
||
|
// 分页查询
|
||
|
pager, err := qb.Paginate()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return rsp.Ok(c, pager)
|
||
|
}
|