go项目脚手架
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
sorbet/internal/util/controller.go

312 lines
5.8 KiB

package util
import (
"github.com/labstack/echo/v4"
"gorm.io/gorm"
"net/url"
"reflect"
"sorbet/pkg/db"
"sorbet/pkg/ioc"
"sorbet/pkg/rsp"
"strconv"
"strings"
)
type (
	// GetID is implemented by request payloads that can report the ID of the
	// record they refer to.
	GetID interface {
		GetID() any
	}

	// ToMap is implemented by request payloads that can express the columns
	// to update as a name -> value map.
	ToMap interface {
		ToMap() map[string]any
	}

	// ToEntity is implemented by request payloads that can be converted into
	// the entity value to persist.
	ToEntity interface {
		ToEntity() any
	}

	// ControllerRequest bundles GetID, ToMap and ToEntity for payloads that
	// support every controller operation.
	ControllerRequest interface {
		GetID
		ToMap
		ToEntity
	}
)
// ParseQuery applies the URL query parameters in query to qb and reports the
// requested page and page size.
//
// Recognised parameters:
//   - sortby: one or more field names. A leading '-' sorts descending; a
//     leading '+' or no prefix sorts ascending. An unescaped '+' in a query
//     string decodes to a space, so clients must send it as %2B.
//   - limit: maximum number of rows; must be a positive integer.
//   - page: 1-based page number; must be a positive integer. Its presence
//     makes the query paginated and sets the offset to (page-1)*limit.
//   - any other key is a filter. A bare "field" means equality; "field#op"
//     selects an operator: =, !=, <, <=, >, >=, <> (between), >< (not
//     between), nil, !nil, ~ (like), !~ (not like), in, !in. The range and set
//     operators accept either repeated values or one comma-separated value.
//
// When no limit is given, a limit of 30 is applied. page is 0 when the query
// is not paginated. Malformed parameters yield rsp.ErrBadParams.
func ParseQuery[T any](query url.Values, qb *db.QueryBuilder[T]) (page, limit int, err error) {
	const defaultLimit = 30

	var paginating bool
	for key, values := range query {
		// url.ParseQuery never produces an empty slice, but a hand-built
		// url.Values can, and values[0] is read below.
		if len(values) == 0 {
			continue
		}
		switch key {
		case "sortby":
			parseQuerySort(qb, values)
		case "limit", "page":
			v, perr := parseQueryPositiveInt(values[0])
			if perr != nil {
				return 0, 0, perr
			}
			if key == "limit" {
				qb.Limit(v)
				limit = v
			} else {
				paginating = true
				page = v
			}
		default:
			if ferr := parseQueryFilter(qb, key, values); ferr != nil {
				return 0, 0, ferr
			}
		}
	}

	if limit == 0 {
		limit = defaultLimit
		qb.Limit(limit)
	}
	// An offset is only meaningful for paginated queries. page >= 1 here,
	// so the offset is never negative.
	if paginating {
		qb.Offset((page - 1) * limit)
	}
	return page, limit, nil
}

// parseQueryPositiveInt parses s as a strictly positive decimal integer. It
// returns rsp.ErrBadParams for empty, malformed, zero or negative input.
func parseQueryPositiveInt(s string) (int, error) {
	v, err := strconv.Atoi(s)
	if err != nil || v <= 0 {
		return 0, rsp.ErrBadParams
	}
	return v, nil
}

// parseQuerySort adds one ordering term to qb per entry of fields. Entries
// that name no field ("", "-", "+") are ignored.
func parseQuerySort[T any](qb *db.QueryBuilder[T], fields []string) {
	for _, field := range fields {
		var desc bool
		switch {
		case strings.HasPrefix(field, "-"):
			field, desc = field[1:], true
		case strings.HasPrefix(field, "+"):
			field = field[1:]
		}
		if field == "" {
			continue
		}
		if desc {
			qb.DescentBy(field)
		} else {
			qb.AscentBy(field)
		}
	}
}

// parseQueryFilter translates one filter parameter into a condition on qb.
// key is either a bare field name (equality) or "field#op"; values must hold
// at least one element.
func parseQueryFilter[T any](qb *db.QueryBuilder[T], key string, values []string) error {
	v := values[0]
	k, op, hasOp := strings.Cut(key, "#")
	if !hasOp {
		qb.Eq(key, v)
		return nil
	}
	switch op {
	case "=":
		qb.Eq(k, v)
	case "!=":
		qb.Neq(k, v)
	case "<":
		qb.Lt(k, v)
	case "<=":
		qb.Lte(k, v)
	case ">":
		qb.Gt(k, v)
	case ">=":
		qb.Gte(k, v)
	case "<>", "><":
		var less, more any
		switch len(values) {
		case 1:
			bounds := strings.Split(v, ",")
			if len(bounds) != 2 || bounds[0] == "" || bounds[1] == "" {
				return rsp.ErrBadParams
			}
			less, more = bounds[0], bounds[1]
		case 2:
			less, more = values[0], values[1]
		default:
			// More than two bounds cannot describe a range.
			return rsp.ErrBadParams
		}
		if op == "<>" {
			qb.Between(k, less, more)
		} else {
			qb.NotBetween(k, less, more)
		}
	case "nil":
		qb.IsNull(k)
	case "!nil":
		qb.NotNull(k)
	case "~":
		qb.Like(k, v)
	case "!~":
		qb.NotLike(k, v)
	case "in", "!in":
		if len(values) == 1 {
			values = strings.Split(v, ",")
		}
		set := make([]any, len(values))
		for i, s := range values {
			set[i] = s
		}
		if op == "in" {
			qb.In(k, set...)
		} else {
			qb.NotIn(k, set...)
		}
	default:
		// Without this, an unknown operator would be applied as equality on
		// a column literally named "field#op", which is almost certainly not
		// a real column. Reject it as a client error instead.
		return rsp.ErrBadParams
	}
	return nil
}
// getID returns the ID reported by req's GetID method. It returns nil in
// three cases:
//   - req does not implement GetID;
//   - the reported ID is nil;
//   - the reported ID is the zero value of its type (0, "", a nil pointer, ...).
func getID(req any) any {
	g, ok := req.(GetID)
	if !ok {
		return nil
	}
	id := g.GetID()
	// Check for a nil interface first: reflect.ValueOf(nil) is an invalid
	// Value, and IsZero panics on it. For any valid Value, IsZero is defined
	// for every kind. IsNil, by contrast, panics for non-nillable kinds such
	// as int or string.
	if id == nil || reflect.ValueOf(id).IsZero() {
		return nil
	}
	return id
}
// getValues returns req.ToMap() when req implements ToMap, and nil otherwise.
func getValues(req any) map[string]any {
	switch m := req.(type) {
	case ToMap:
		return m.ToMap()
	default:
		return nil
	}
}
// getEntity converts request into a *T via its ToEntity method. It returns nil
// when request does not implement ToEntity or when the converted value is not
// a *T.
func getEntity[T any](request any) *T {
	if conv, ok := request.(ToEntity); ok {
		if entity, ok := conv.ToEntity().(*T); ok {
			return entity
		}
	}
	return nil
}
// Controller is a generic base controller.
//
// The type parameter Entity is the concrete data type being operated on;
// Upsert is the data required when creating or updating a record.
type Controller[Entity, Upsert any] struct{}
// RegisterRoutes mounts the CRUD handlers on r under path:
//
//	PUT    path      -> Create
//	DELETE path/:id  -> Delete
//	POST   path/:id  -> Update
//	GET    path/:id  -> Get
//	GET    path      -> List
func (ctr *Controller[Entity, Upsert]) RegisterRoutes(path string, r *echo.Group) {
	item := path + "/:id"
	r.PUT(path, ctr.Create)
	r.DELETE(item, ctr.Delete)
	r.POST(item, ctr.Update)
	r.GET(item, ctr.Get)
	r.GET(path, ctr.List)
}
// ORM fetches the *gorm.DB registered in the ioc container and returns a
// session bound to the context of c's HTTP request.
func (ctr *Controller[Entity, Upsert]) ORM(c echo.Context) (*gorm.DB, error) {
	conn, err := ioc.Get[gorm.DB]()
	if err != nil {
		return nil, err
	}
	ctx := c.Request().Context()
	return conn.WithContext(ctx), nil
}
// Repository wraps the *gorm.DB registered in the ioc container in a
// db.Repository for Entity. db.NewRepository is invoked on every call.
func (ctr *Controller[Entity, Upsert]) Repository() (*db.Repository[Entity], error) {
	conn, err := ioc.Get[gorm.DB]()
	if err != nil {
		return nil, err
	}
	repo := db.NewRepository[Entity](conn)
	return repo, nil
}
// Create creates a record from the bound Upsert payload and replies with the
// created entity.
func (ctr *Controller[Entity, Upsert]) Create(c echo.Context) error {
	return ctr.upsert(c, true)
}

// Update updates the record identified by the bound Upsert payload's ID.
func (ctr *Controller[Entity, Upsert]) Update(c echo.Context) error {
	return ctr.upsert(c, false)
}

// upsert implements both Create (isCreate == true) and Update.
//
// The request is bound into Upsert. Its ID is taken from its GetID method,
// when it has one. An ID is forbidden on create and required on update;
// violating either rule sends a bad-params response.
//
// An update writes the map returned by the request's ToMap method. A create
// inserts the *Entity returned by its ToEntity method. If the request cannot
// provide the needed value, rsp.ErrInternal is returned. That case is
// presumably a defect in the Upsert type rather than bad client input.
func (ctr *Controller[Entity, Upsert]) upsert(c echo.Context, isCreate bool) error {
	request, err := Bind[Upsert](c)
	if err != nil {
		return err
	}
	repo, err := ctr.Repository()
	if err != nil {
		return err
	}
	ctx := c.Request().Context()

	// getID reports a missing or zero ID as nil. Compare against nil
	// directly: reflect.ValueOf(nil).IsZero() panics.
	id := getID(request)
	if hasID := id != nil; hasID == isCreate {
		return rsp.BadParams(c, "参数错误")
	}

	// Update the existing record.
	if !isCreate {
		values := getValues(request)
		if values == nil {
			return rsp.ErrInternal
		}
		if err := repo.UpdateByID(ctx, id, values); err != nil {
			return err
		}
		// TODO(hupeh): return the updated entity.
		return rsp.Ok(c, nil)
	}

	// Create a new record.
	entity := getEntity[Entity](request)
	if entity == nil {
		return rsp.ErrInternal
	}
	if err := repo.Create(ctx, entity); err != nil {
		return err
	}
	return rsp.Created(c, entity)
}
// Delete removes the record whose ID is obtained via BindId (presumably the
// ":id" path parameter) and replies with an empty success response.
func (ctr *Controller[Entity, Upsert]) Delete(c echo.Context) error {
	id, err := BindId(c, true)
	if err != nil {
		return err
	}
	repo, err := ctr.Repository()
	if err != nil {
		return err
	}
	ctx := c.Request().Context()
	if err := repo.DeleteByID(ctx, id); err != nil {
		return err
	}
	return rsp.Ok(c, nil)
}
// Get loads the record whose ID is obtained via BindId (presumably the ":id"
// path parameter) and replies with it.
func (ctr *Controller[Entity, Upsert]) Get(c echo.Context) error {
	id, err := BindId(c, true)
	if err != nil {
		return err
	}
	repo, err := ctr.Repository()
	if err != nil {
		return err
	}
	ctx := c.Request().Context()
	found, err := repo.GetByID(ctx, id)
	if err != nil {
		return err
	}
	return rsp.Ok(c, found)
}
// List returns records filtered, sorted and limited according to the query
// string (see ParseQuery).
//
// When the "page" parameter is present, it replies with the result of
// qb.Paginate(). Otherwise it replies with the plain slice of matching
// entities.
func (ctr *Controller[Entity, Upsert]) List(c echo.Context) error {
	repo, err := ctr.Repository()
	if err != nil {
		return err
	}
	qb := repo.NewQueryBuilder(c.Request().Context())
	query := c.QueryParams()
	if _, _, err = ParseQuery[Entity](query, qb); err != nil {
		return err
	}

	// Paginated query.
	if query.Has("page") {
		pager, err := qb.Paginate()
		if err != nil {
			return err
		}
		return rsp.Ok(c, pager)
	}

	// Plain, non-paginated query.
	var rows []*Entity
	if err = qb.Find(&rows); err != nil {
		return err
	}
	return rsp.Ok(c, rows)
}