go项目脚手架
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
sorbet/pkg/crud/controller.go

215 lines
4.5 KiB

package crud
import (
"github.com/labstack/echo/v4"
"gorm.io/gorm"
"reflect"
"sorbet/pkg/db"
"sorbet/pkg/rsp"
)
// Upsertable is implemented by request payloads that the generic
// Controller consumes in its Create and Update handlers.
type Upsertable interface {
	// GetID is used by Controller.Update to locate the record that
	// should be updated.
	GetID() any
	// ToMap is used by Controller.Update and returns the columns to
	// update; because it is a map, zero values are supported as well.
	ToMap() map[string]any
	// ToEntity is used by Controller.Create and returns the data of the
	// entity that should be created.
	ToEntity() any
}
// getID extracts the record ID from an Upsertable request.
//
// It returns nil when req does not implement Upsertable, when GetID
// returns nil, when GetID panics, or when the ID is the zero value of its
// type (0, "", a nil pointer, a nil map or slice, ...). Any other ID is
// returned unchanged.
func getID(req any) (val any) {
	// GetID is caller-supplied code and may panic (e.g. a value-receiver
	// method invoked through a typed nil pointer); treat that as "no ID".
	defer func() {
		if recover() != nil {
			val = nil
		}
	}()
	u, ok := req.(Upsertable)
	if !ok {
		return nil
	}
	val = u.GetID()
	// reflect.Value.IsZero already reports true for nil pointers, maps,
	// slices, funcs and chans, so no IsNil check is needed. IsNil would
	// panic on non-nillable kinds such as int or string, which previously
	// caused every non-zero scalar ID to be discarded.
	if val == nil || reflect.ValueOf(val).IsZero() {
		return nil
	}
	return val
}
// getValues returns the column/value map produced by req's ToMap method,
// or nil when req does not implement Upsertable.
func getValues(req any) map[string]any {
	u, ok := req.(Upsertable)
	if !ok {
		return nil
	}
	return u.ToMap()
}
// getEntity converts request into a *T through Upsertable.ToEntity.
//
// ToEntity may return either a *T, which is returned as-is, or a T value,
// in which case a pointer to a copy is returned. It returns nil when
// request does not implement Upsertable or ToEntity yields any other type.
func getEntity[T any](request any) *T {
	u, ok := request.(Upsertable)
	if !ok {
		return nil
	}
	switch ent := u.ToEntity().(type) {
	case *T:
		return ent
	case T:
		// Accepting a plain value lets ToEntity implementations return
		// the entity by value instead of failing with a nil result.
		return &ent
	default:
		return nil
	}
}
// Controller is a generic base controller providing CRUD handlers.
//
// The type parameter Entity is the model being operated on; Upsert is the
// payload type bound for create and update requests.
type Controller[Entity, Upsert any] struct{}
// RegisterRoutes mounts the CRUD handlers on r under path:
//
//	GET    path      -> List
//	GET    path/:id  -> Get
//	PUT    path      -> Create
//	DELETE path/:id  -> Delete
//	POST   path/:id  -> Update
func (ctr *Controller[Entity, Upsert]) RegisterRoutes(path string, r *echo.Group) {
	item := path + "/:id"
	r.GET(path, ctr.List)
	r.GET(item, ctr.Get)
	r.PUT(path, ctr.Create)
	r.DELETE(item, ctr.Delete)
	r.POST(item, ctr.Update)
}
// ORM returns the *gorm.DB handle to use for the current request.
//
// A *gorm.DB stored in the echo context under the key "orm" is returned
// as-is; otherwise the handle from db.DB() is bound to the request's
// context. The returned error is currently always nil.
func (ctr *Controller[Entity, Upsert]) ORM(c echo.Context) (*gorm.DB, error) {
	orm, found := c.Get("orm").(*gorm.DB)
	if !found {
		orm = db.DB().WithContext(c.Request().Context())
	}
	return orm, nil
}
// MustORM is like ORM but panics instead of returning an error.
func (ctr *Controller[Entity, Upsert]) MustORM(c echo.Context) *gorm.DB {
	handle, err := ctr.ORM(c)
	if err == nil {
		return handle
	}
	panic(err)
}
// Repository builds a db.Repository for Entity on top of the handle
// returned by ORM, propagating any error from ORM.
func (ctr *Controller[Entity, Upsert]) Repository(c echo.Context) (*db.Repository[Entity], error) {
	handle, err := ctr.ORM(c)
	if err == nil {
		return db.NewRepository[Entity](handle), nil
	}
	return nil, err
}
// MustRepository is like Repository but panics instead of returning an error.
func (ctr *Controller[Entity, Upsert]) MustRepository(c echo.Context) *db.Repository[Entity] {
	repo, err := ctr.Repository(c)
	if err == nil {
		return repo
	}
	panic(err)
}
// Create is the handler for creating a record; it delegates to upsert in
// create mode.
func (ctr *Controller[Entity, Upsert]) Create(c echo.Context) error {
	const creating = true
	return ctr.upsert(c, creating)
}
// Update is the handler for updating a record; it delegates to upsert in
// update mode.
func (ctr *Controller[Entity, Upsert]) Update(c echo.Context) error {
	const creating = false
	return ctr.upsert(c, creating)
}
// upsert implements both Create (isCreate == true) and Update.
//
// The request is bound into Upsert. A create request must not carry an ID
// and an update request must carry one; otherwise a bad-params response is
// written. Update applies the map from Upsertable.ToMap to the record with
// that ID and responds with Ok. Create inserts the entity from
// Upsertable.ToEntity and responds with Created and the stored entity.
// rsp.ErrInternal is returned when the bound request cannot supply the
// map or entity.
func (ctr *Controller[Entity, Upsert]) upsert(c echo.Context, isCreate bool) error {
	request, err := Bind[Upsert](c)
	if err != nil {
		return err
	}
	repo, err := ctr.Repository(c)
	if err != nil {
		return err
	}
	// getID normalizes missing and zero IDs to nil, so a plain nil check is
	// the correct test; reflect.ValueOf(nil).IsZero() would panic.
	id := getID(request)
	if isCreate != (id == nil) {
		return rsp.BadParams(c, "参数错误")
	}
	ctx := c.Request().Context()

	// Update the existing record.
	if !isCreate {
		values := getValues(request)
		if values == nil {
			return rsp.ErrInternal
		}
		if err := repo.UpdateByID(ctx, id, values); err != nil {
			return err
		}
		// TODO(hupeh): return the updated entity.
		return rsp.Ok(c, nil)
	}

	// Create a new record.
	entity := getEntity[Entity](request)
	if entity == nil {
		return rsp.ErrInternal
	}
	if err := repo.Create(ctx, entity); err != nil {
		return err
	}
	return rsp.Created(c, entity)
}
// Delete removes the record whose ID is obtained via BindID (RegisterRoutes
// mounts this handler on path/:id) and responds with Ok on success.
func (ctr *Controller[Entity, Upsert]) Delete(c echo.Context) error {
	id, err := BindID(c)
	if err != nil {
		return err
	}
	repo, err := ctr.Repository(c)
	if err != nil {
		return err
	}
	if err = repo.DeleteByID(c.Request().Context(), id); err != nil {
		return err
	}
	return rsp.Ok(c, nil)
}
// Get loads the record whose ID is obtained via BindID (RegisterRoutes
// mounts this handler on path/:id) and responds with it.
func (ctr *Controller[Entity, Upsert]) Get(c echo.Context) error {
	id, err := BindID(c)
	if err != nil {
		return err
	}
	repo, err := ctr.Repository(c)
	if err != nil {
		return err
	}
	found, err := repo.GetByID(c.Request().Context(), id)
	if err != nil {
		return err
	}
	return rsp.Ok(c, found)
}
// List responds with the records matching the query bound by BindQuery.
//
// Without a "page" query parameter every matching record is returned as a
// slice; with it, the query builder's Paginate result is returned instead.
func (ctr *Controller[Entity, Upsert]) List(c echo.Context) error {
	repo, err := ctr.Repository(c)
	if err != nil {
		return err
	}
	qb := repo.NewQueryBuilder(c.Request().Context())
	if _, _, err = BindQuery[Entity](c, qb); err != nil {
		return err
	}

	// Paginated query.
	if c.QueryParams().Has("page") {
		pager, err := qb.Paginate()
		if err != nil {
			return err
		}
		return rsp.Ok(c, pager)
	}

	// Unpaginated query: return everything that matches.
	var items []*Entity
	if err = qb.Find(&items); err != nil {
		return err
	}
	return rsp.Ok(c, items)
}