package util

import (
	"net/url"
	"reflect"
	"strconv"
	"strings"

	"github.com/labstack/echo/v4"
	"gorm.io/gorm"

	"sorbet/pkg/db"
	"sorbet/pkg/ioc"
	"sorbet/pkg/rsp"
)

// GetID is implemented by request payloads that carry the identifier of the
// record they target.
type GetID interface {
	GetID() any
}

// ToMap is implemented by request payloads that can be turned into the
// field/value map handed to Repository.UpdateByID on update.
type ToMap interface {
	ToMap() map[string]any
}

// ToEntity is implemented by request payloads that can build the entity
// persisted on create. Controller expects the returned value to be *Entity.
type ToEntity interface {
	ToEntity() any
}

// ControllerRequest bundles every capability Controller looks for on an
// Upsert payload.
type ControllerRequest interface {
	GetID
	ToMap
	ToEntity
}

// ParseQuery translates URL query parameters into conditions on qb.
//
// Recognised keys:
//   - "sortby": one or more column names; a "-" prefix sorts descending,
//     a "+" prefix (or none) sorts ascending.
//   - "limit": a strictly positive page size.
//   - "page": a strictly positive page number; its presence enables
//     pagination.
//   - "<column>#<op>": a filter where op is one of =, !=, <, <=, >, >=,
//     <> (between), >< (not between), nil, !nil, ~ (like), !~ (not like),
//     in, !in. A bare "<column>" key is an equality filter.
//
// When no limit is given a default of 30 is applied. When "page" is given
// the offset is set to (page-1)*limit.
//
// It returns the parsed page (0 when not paginating) and the effective
// limit. Malformed input yields rsp.ErrBadParams.
//
// NOTE(review): column names come straight from the client; confirm that
// db.QueryBuilder quotes or whitelists them, otherwise this is an SQL
// injection vector.
func ParseQuery[T any](query url.Values, qb *db.QueryBuilder[T]) (page, limit int, err error) {
	const defaultLimit = 30

	for key, values := range query {
		if len(values) == 0 {
			continue
		}
		switch key {
		case "sortby":
			parseQuerySort(values, qb)
		case "limit":
			limit, err = parseQueryPositiveInt(values[0])
		case "page":
			page, err = parseQueryPositiveInt(values[0])
		default:
			err = applyQueryFilter(qb, key, values)
		}
		if err != nil {
			return 0, 0, err
		}
	}

	if limit == 0 {
		limit = defaultLimit
	}
	qb.Limit(limit)
	// Only a paginated query gets an offset; page is 0 otherwise and
	// (0-1)*limit would be negative.
	if page > 0 {
		qb.Offset((page - 1) * limit)
	}
	return page, limit, nil
}

// parseQueryPositiveInt parses s as a strictly positive integer, reporting
// rsp.ErrBadParams for empty, non-numeric, zero or negative input.
func parseQueryPositiveInt(s string) (int, error) {
	n, err := strconv.Atoi(s)
	if err != nil || n <= 0 {
		return 0, rsp.ErrBadParams
	}
	return n, nil
}

// parseQuerySort applies every "sortby" value to qb. Blank entries and bare
// "+"/"-" are ignored.
func parseQuerySort[T any](values []string, qb *db.QueryBuilder[T]) {
	for _, s := range values {
		// An unescaped "+" in a query string decodes to a space, so
		// "?sortby=+name" arrives as " name"; trimming keeps it an ascending
		// sort on "name" rather than a column name with a leading space.
		s = strings.TrimSpace(s)
		desc := false
		switch {
		case strings.HasPrefix(s, "-"):
			desc = true
			s = s[1:]
		case strings.HasPrefix(s, "+"):
			s = s[1:]
		}
		if s == "" {
			continue
		}
		if desc {
			qb.DescentBy(s)
		} else {
			qb.AscentBy(s)
		}
	}
}

// applyQueryFilter adds the condition described by a "<column>[#<op>]" key
// to qb. values holds every value supplied for the key; single-operand
// operators use the first one. Unknown operators and malformed ranges yield
// rsp.ErrBadParams.
func applyQueryFilter[T any](qb *db.QueryBuilder[T], key string, values []string) error {
	v := values[0]
	column, op, found := strings.Cut(key, "#")
	if !found {
		qb.Eq(key, v)
		return nil
	}

	switch op {
	case "=":
		qb.Eq(column, v)
	case "!=":
		qb.Neq(column, v)
	case "<":
		qb.Lt(column, v)
	case "<=":
		qb.Lte(column, v)
	case ">":
		qb.Gt(column, v)
	case ">=":
		qb.Gte(column, v)
	case "<>", "><":
		low, high, ok := parseQueryRange(values)
		if !ok {
			return rsp.ErrBadParams
		}
		if op == "<>" {
			qb.Between(column, low, high)
		} else {
			qb.NotBetween(column, low, high)
		}
	case "nil":
		qb.IsNull(column)
	case "!nil":
		qb.NotNull(column)
	case "~":
		qb.Like(column, v)
	case "!~":
		qb.NotLike(column, v)
	case "in", "!in":
		items := values
		if len(items) == 1 {
			items = strings.Split(v, ",")
		}
		args := make([]any, len(items))
		for i, item := range items {
			args[i] = item
		}
		if op == "in" {
			qb.In(column, args...)
		} else {
			qb.NotIn(column, args...)
		}
	default:
		// Reject instead of silently using "column#op" as a column name.
		return rsp.ErrBadParams
	}
	return nil
}

// parseQueryRange extracts the two bounds of a (not-)between filter, given
// either as two separate values or as one comma-separated "low,high" value.
// Both bounds must be non-empty; ok is false otherwise.
func parseQueryRange(values []string) (low, high string, ok bool) {
	var bounds []string
	switch len(values) {
	case 1:
		bounds = strings.Split(values[0], ",")
	case 2:
		bounds = values
	default:
		return "", "", false
	}
	if len(bounds) != 2 || bounds[0] == "" || bounds[1] == "" {
		return "", "", false
	}
	return bounds[0], bounds[1], true
}

// getID returns the identifier carried by req. It returns nil when req does
// not implement GetID, or when the ID is nil or the zero value of its type.
func getID(req any) any {
	g, ok := req.(GetID)
	if !ok {
		return nil
	}
	id := g.GetID()
	// id is non-nil here, so reflect.ValueOf yields a valid Value and IsZero
	// is safe for every kind (unlike IsNil, which panics on ints/strings).
	if id == nil || reflect.ValueOf(id).IsZero() {
		return nil
	}
	return id
}

// getValues returns req's field/value map, or nil when req does not
// implement ToMap.
func getValues(req any) map[string]any {
	if v, ok := req.(ToMap); ok {
		return v.ToMap()
	}
	return nil
}

// getEntity returns the *T built by request's ToEntity, or nil when request
// does not implement ToEntity or builds a value of another type.
func getEntity[T any](request any) *T {
	if v, ok := request.(ToEntity); ok {
		if ent, ok := v.ToEntity().(*T); ok {
			return ent
		}
	}
	return nil
}

// Controller is a generic CRUD controller base.
//
// Type parameter Entity is the persisted model; Upsert is the payload bound
// for create and update requests. Upsert is expected to implement GetID,
// ToMap and ToEntity (see ControllerRequest). Create and Update respond with
// rsp.ErrInternal when the capability they need is missing.
type Controller[Entity any, Upsert any] struct{}

// RegisterRoutes mounts the CRUD handlers on r under path:
//
//	PUT    path      -> Create
//	DELETE path/:id  -> Delete
//	POST   path/:id  -> Update
//	GET    path/:id  -> Get
//	GET    path      -> List
func (ctr *Controller[Entity, Upsert]) RegisterRoutes(path string, r *echo.Group) {
	r.PUT(path, ctr.Create)
	r.DELETE(path+"/:id", ctr.Delete)
	r.POST(path+"/:id", ctr.Update)
	r.GET(path+"/:id", ctr.Get)
	r.GET(path, ctr.List)
}

// ORM returns the gorm handle resolved from the ioc container, bound to the
// context of c's request.
func (ctr *Controller[Entity, Upsert]) ORM(c echo.Context) (*gorm.DB, error) {
	orm, err := ioc.Get[gorm.DB]()
	if err != nil {
		return nil, err
	}
	return orm.WithContext(c.Request().Context()), nil
}

// Repository returns a db.Repository for Entity backed by the gorm handle
// resolved from the ioc container. Callers pass a context to each
// repository call.
func (ctr *Controller[Entity, Upsert]) Repository() (*db.Repository[Entity], error) {
	orm, err := ioc.Get[gorm.DB]()
	if err != nil {
		return nil, err
	}
	return db.NewRepository[Entity](orm), nil
}

// Create inserts a new record built from the request payload.
func (ctr *Controller[Entity, Upsert]) Create(c echo.Context) error {
	return ctr.upsert(c, true)
}

// Update modifies the record identified by the request payload.
func (ctr *Controller[Entity, Upsert]) Update(c echo.Context) error {
	return ctr.upsert(c, false)
}

// upsert binds an Upsert payload and either creates an entity (isCreate) or
// updates the one identified by the payload's ID. A create must not carry
// an ID and an update must; otherwise the request is rejected as bad
// parameters.
func (ctr *Controller[Entity, Upsert]) upsert(c echo.Context, isCreate bool) error {
	request, err := Bind[Upsert](c)
	if err != nil {
		return err
	}
	repo, err := ctr.Repository()
	if err != nil {
		return err
	}

	id := getID(request)
	if hasID := id != nil; hasID == isCreate {
		return rsp.BadParams(c, "参数错误")
	}

	ctx := c.Request().Context()

	if !isCreate {
		values := getValues(request)
		if values == nil {
			return rsp.ErrInternal
		}
		if err := repo.UpdateByID(ctx, id, values); err != nil {
			return err
		}
		// TODO(hupeh): return the updated entity.
		return rsp.Ok(c, nil)
	}

	entity := getEntity[Entity](request)
	if entity == nil {
		return rsp.ErrInternal
	}
	if err := repo.Create(ctx, entity); err != nil {
		return err
	}
	return rsp.Created(c, entity)
}

// Delete removes the record whose ID is bound from the request.
func (ctr *Controller[Entity, Upsert]) Delete(c echo.Context) error {
	id, err := BindId(c, true)
	if err != nil {
		return err
	}
	repo, err := ctr.Repository()
	if err != nil {
		return err
	}
	if err := repo.DeleteByID(c.Request().Context(), id); err != nil {
		return err
	}
	return rsp.Ok(c, nil)
}

// Get responds with the record whose ID is bound from the request.
func (ctr *Controller[Entity, Upsert]) Get(c echo.Context) error {
	id, err := BindId(c, true)
	if err != nil {
		return err
	}
	repo, err := ctr.Repository()
	if err != nil {
		return err
	}
	entity, err := repo.GetByID(c.Request().Context(), id)
	if err != nil {
		return err
	}
	return rsp.Ok(c, entity)
}

// List responds with the records matching the query string (see
// ParseQuery). Without a "page" parameter it returns a plain slice; with
// one it returns the result of qb.Paginate.
func (ctr *Controller[Entity, Upsert]) List(c echo.Context) error {
	repo, err := ctr.Repository()
	if err != nil {
		return err
	}
	qb := repo.NewQueryBuilder(c.Request().Context())
	page, _, err := ParseQuery[Entity](c.QueryParams(), qb)
	if err != nil {
		return err
	}

	// Not a paginated query: return the (limit-capped) list as-is.
	if page == 0 {
		var result []*Entity
		if err := qb.Find(&result); err != nil {
			return err
		}
		return rsp.Ok(c, result)
	}

	pager, err := qb.Paginate()
	if err != nil {
		return err
	}
	return rsp.Ok(c, pager)
}