You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
215 lines
4.5 KiB
215 lines
4.5 KiB
package crud
|
|
|
|
import (
|
|
"github.com/labstack/echo/v4"
|
|
"gorm.io/gorm"
|
|
"reflect"
|
|
"sorbet/pkg/db"
|
|
"sorbet/pkg/rsp"
|
|
)
|
|
|
|
// Upsertable is the interface request payloads implement so that the
// Controller's Create and Update handlers can work with them.
type Upsertable interface {
	// ToEntity is used by Controller.Create and returns the entity to be
	// created (Controller expects a pointer to its Entity type here).
	ToEntity() any

	// GetID is used by Controller.Update to locate the record being
	// updated.
	GetID() any

	// ToMap is used by Controller.Update and returns the fields to write;
	// zero values are supported.
	ToMap() map[string]any
}
|
|
|
|
func getID(req any) (val any) {
|
|
defer func() {
|
|
if recover() != nil {
|
|
val = nil
|
|
}
|
|
}()
|
|
val = req.(Upsertable).GetID()
|
|
rv := reflect.ValueOf(val)
|
|
if rv.IsZero() || rv.IsNil() {
|
|
val = nil
|
|
}
|
|
return
|
|
}
|
|
|
|
func getValues(req any) map[string]any {
|
|
if v, ok := req.(Upsertable); ok {
|
|
return v.ToMap()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func getEntity[T any](request any) *T {
|
|
v, ok := request.(Upsertable)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
if ent, ok := v.ToEntity().(*T); ok {
|
|
return ent
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Controller is a generic base for CRUD HTTP handlers.
//
// Type parameter Entity is the concrete data type being operated on;
// Upsert is the payload type bound when creating or updating.
type Controller[Entity, Upsert any] struct{}
|
|
|
|
// RegisterRoutes 注册路由
|
|
func (ctr *Controller[Entity, Upsert]) RegisterRoutes(path string, r *echo.Group) {
|
|
r.GET(path, ctr.List)
|
|
r.GET(path+"/:id", ctr.Get)
|
|
r.PUT(path, ctr.Create)
|
|
r.DELETE(path+"/:id", ctr.Delete)
|
|
r.POST(path+"/:id", ctr.Update)
|
|
}
|
|
|
|
// ORM 获取 gorm.DB 实例
|
|
func (ctr *Controller[Entity, Upsert]) ORM(c echo.Context) (*gorm.DB, error) {
|
|
if orm, ok := c.Get("orm").(*gorm.DB); ok {
|
|
return orm, nil
|
|
}
|
|
return db.DB().WithContext(c.Request().Context()), nil
|
|
}
|
|
|
|
func (ctr *Controller[Entity, Upsert]) MustORM(c echo.Context) *gorm.DB {
|
|
orm, err := ctr.ORM(c)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return orm
|
|
}
|
|
|
|
// Repository 获取 Repository 实例
|
|
func (ctr *Controller[Entity, Upsert]) Repository(c echo.Context) (*db.Repository[Entity], error) {
|
|
orm, err := ctr.ORM(c)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return db.NewRepository[Entity](orm), nil
|
|
}
|
|
|
|
func (ctr *Controller[Entity, Upsert]) MustRepository(c echo.Context) *db.Repository[Entity] {
|
|
repository, err := ctr.Repository(c)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return repository
|
|
}
|
|
|
|
// Create 创建数据
|
|
func (ctr *Controller[Entity, Upsert]) Create(c echo.Context) error {
|
|
return ctr.upsert(c, true)
|
|
}
|
|
|
|
// Update 更新数据
|
|
func (ctr *Controller[Entity, Upsert]) Update(c echo.Context) error {
|
|
return ctr.upsert(c, false)
|
|
}
|
|
|
|
func (ctr *Controller[Entity, Upsert]) upsert(c echo.Context, isCreate bool) error {
|
|
request, err := Bind[Upsert](c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
repo, err := ctr.Repository(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
id := getID(request)
|
|
if isCreate != reflect.ValueOf(id).IsZero() {
|
|
return rsp.BadParams(c, "参数错误")
|
|
}
|
|
// 更新数据
|
|
if !isCreate {
|
|
values := getValues(request)
|
|
if values == nil {
|
|
return rsp.ErrInternal
|
|
}
|
|
err = repo.UpdateByID(c.Request().Context(), id, values)
|
|
if err == nil {
|
|
// TODO(hupeh): 返回更新后的实体数据
|
|
return rsp.Ok(c, nil)
|
|
}
|
|
return err
|
|
}
|
|
// 创建数据
|
|
group := getEntity[Entity](request)
|
|
if group == nil {
|
|
return rsp.ErrInternal
|
|
}
|
|
err = repo.Create(c.Request().Context(), group)
|
|
if err != nil {
|
|
return rsp.Created(c, group)
|
|
}
|
|
return err
|
|
}
|
|
|
|
// Delete 通过ID删除数据
|
|
func (ctr *Controller[Entity, Upsert]) Delete(c echo.Context) error {
|
|
id, err := BindID(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
repo, err := ctr.Repository(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = repo.DeleteByID(c.Request().Context(), id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return rsp.Ok(c, nil)
|
|
}
|
|
|
|
// Get 通过ID获取数据
|
|
func (ctr *Controller[Entity, Upsert]) Get(c echo.Context) error {
|
|
id, err := BindID(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
repo, err := ctr.Repository(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
entity, err := repo.GetByID(c.Request().Context(), id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return rsp.Ok(c, entity)
|
|
}
|
|
|
|
// List 获取数据列表
|
|
func (ctr *Controller[Entity, Upsert]) List(c echo.Context) error {
|
|
repo, err := ctr.Repository(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
qb := repo.NewQueryBuilder(c.Request().Context())
|
|
_, _, err = BindQuery[Entity](c, qb)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 不是分页查询
|
|
if !c.QueryParams().Has("page") {
|
|
var result []*Entity
|
|
err = qb.Find(&result)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return rsp.Ok(c, result)
|
|
}
|
|
|
|
// 分页查询
|
|
pager, err := qb.Paginate()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return rsp.Ok(c, pager)
|
|
}
|
|
|