From bb44a788855535aa61ebfe01960b6aede652f0bd Mon Sep 17 00:00:00 2001 From: hupeh Date: Thu, 12 Oct 2023 16:07:04 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=8F=96=E6=B6=88=E5=8F=8D=E8=BD=AC?= =?UTF-8?q?=E6=8E=A7=E5=88=B6=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/init.go | 19 --- internal/repositories/ioc.go | 27 ----- pkg/app/controller.go | 20 ++- pkg/ioc/container.go | 228 ----------------------------------- pkg/ioc/ioc.go | 96 --------------- pkg/ioc/util.go | 19 --- 6 files changed, 9 insertions(+), 400 deletions(-) delete mode 100644 internal/repositories/ioc.go delete mode 100644 pkg/ioc/container.go delete mode 100644 pkg/ioc/ioc.go delete mode 100644 pkg/ioc/util.go diff --git a/internal/init.go b/internal/init.go index 1948a77..4ac4fcb 100644 --- a/internal/init.go +++ b/internal/init.go @@ -11,17 +11,12 @@ import ( "sorbet/internal/util" "sorbet/pkg/db" "sorbet/pkg/env" - "sorbet/pkg/ioc" "sorbet/pkg/log" "sorbet/pkg/rsp" "sorbet/pkg/ticket" ) func Init() error { - ioc.Bind(db.DB()) // 注入数据库操作 - ioc.Bind(log.Default()) // 注入日志操作 - repositories.Init() // 注入数据仓库操作 - // 同步数据库结构 if err := syncEntities(); err != nil { if !errors.Is(err, db.ErrNoCodeFirst) { @@ -31,7 +26,6 @@ func Init() error { log.Error("同步数据表结构需要开启 [DB_CODE_FIRST],在生产模式下请务必关闭。") } } - return nil } @@ -72,16 +66,6 @@ func Start() error { e.Use(middleware.Recover()) e.Use(middleware.CORS()) e.Use(middleware.Logger) - e.Use(func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - db := ioc.MustGet[gorm.DB]().WithContext(c.Request().Context()) - ci := ioc.Fork() - ci.Bind(db) - c.Set("db", db) - c.Set("ioc", ci) - return next(c) - } - }) e.GET("/", func(c echo.Context) error { repo := repositories.NewCompanyRepository(c.Get("db").(*gorm.DB)) company, err := repo.GetByID(c.Request().Context(), 1) @@ -101,15 +85,12 @@ func Start() error { "token": token, }) }) - e.Group("", middleware.Ticket(false, "system")).GET("/u", func(c echo.Context) error { return rsp.Ok(c, echo.Map{ "ticket": c.Get("ticket"), "claims": c.Get("ticket_claims").(*ticket.Claims), }) }) - e.Logger.Fatal(e.Start(":1323")) - return nil } diff --git a/internal/repositories/ioc.go b/internal/repositories/ioc.go deleted file mode 100644 index 81b3fa7..0000000 --- a/internal/repositories/ioc.go +++ /dev/null @@ -1,27 +0,0 @@ -package repositories - -import ( - "sorbet/pkg/ioc" -) - -func Init() { - ioc.MustFactory(NewCompanyRepository) - ioc.MustFactory(NewCompanyDepartmentRepository) - ioc.MustFactory(NewCompanyStaffRepository) - ioc.MustFactory(NewConfigRepository) - ioc.MustFactory(NewConfigGroupRepository) - ioc.MustFactory(NewFeatureRepository) - ioc.MustFactory(NewFeatureCategoryRepository) - ioc.MustFactory(NewFeatureConfigRepository) - ioc.MustFactory(NewFeatureContentRepository) - ioc.MustFactory(NewFeatureContentChapterRepository) - ioc.MustFactory(NewFeatureContentDetailRepository) - ioc.MustFactory(NewResourceRepository) - ioc.MustFactory(NewResourceCategoryRepository) - ioc.MustFactory(NewSystemLogRepository) - ioc.MustFactory(NewSystemMenuRepository) - ioc.MustFactory(NewSystemPermissionRepository) - ioc.MustFactory(NewSystemRoleRepository) - ioc.MustFactory(NewSystemRolePowerRepository) - ioc.MustFactory(NewSystemUserRepository) -} diff --git a/pkg/app/controller.go b/pkg/app/controller.go index 8e5b19c..76f6c1d 100644 --- a/pkg/app/controller.go +++ b/pkg/app/controller.go @@ -5,7 +5,6 @@ import ( "gorm.io/gorm" "reflect" "sorbet/pkg/db" - "sorbet/pkg/ioc" "sorbet/pkg/rsp" ) @@ -72,16 +71,15 @@ func (ctr *Controller[Entity, Upsert]) RegisterRoutes(path string, r *echo.Group // ORM 获取 gorm.DB 实例 func (ctr *Controller[Entity, Upsert]) ORM(c echo.Context) (*gorm.DB, error) { - orm, err := ioc.Get[gorm.DB]() - if err != nil { - return nil, err + if orm, ok := c.Get("orm").(*gorm.DB); ok { + return orm, nil } - return orm.WithContext(c.Request().Context()), nil + return db.DB().WithContext(c.Request().Context()), nil } // Repository 获取 Repository 实例 -func (ctr *Controller[Entity, Upsert]) Repository() (*db.Repository[Entity], error) { - orm, err := ioc.Get[gorm.DB]() +func (ctr *Controller[Entity, Upsert]) Repository(c echo.Context) (*db.Repository[Entity], error) { + orm, err := ctr.ORM(c) if err != nil { return nil, err } @@ -103,7 +101,7 @@ func (ctr *Controller[Entity, Upsert]) upsert(c echo.Context, isCreate bool) err if err != nil { return err } - repo, err := ctr.Repository() + repo, err := ctr.Repository(c) if err != nil { return err } @@ -142,7 +140,7 @@ func (ctr *Controller[Entity, Upsert]) Delete(c echo.Context) error { if err != nil { return err } - repo, err := ctr.Repository() + repo, err := ctr.Repository(c) if err != nil { return err } @@ -159,7 +157,7 @@ func (ctr *Controller[Entity, Upsert]) Get(c echo.Context) error { if err != nil { return err } - repo, err := ctr.Repository() + repo, err := ctr.Repository(c) if err != nil { return err } @@ -172,7 +170,7 @@ func (ctr *Controller[Entity, Upsert]) Get(c echo.Context) error { // List 获取数据列表 func (ctr *Controller[Entity, Upsert]) List(c echo.Context) error { - repo, err := ctr.Repository() + repo, err := ctr.Repository(c) if err != nil { return err } diff --git a/pkg/ioc/container.go b/pkg/ioc/container.go deleted file mode 100644 index dc0128c..0000000 --- a/pkg/ioc/container.go +++ /dev/null @@ -1,228 +0,0 @@ -package ioc - -import ( - "errors" - "fmt" - "reflect" -) - -type binding struct { - name string - typ reflect.Type - resolver any - shared bool -} - -func (b *binding) make(c *Container) (reflect.Value, error) { - if v, exists := c.instances[b.typ][b.name]; exists { - return v, nil - } - val, err := c.Invoke(b.resolver) - if err != nil { - return reflect.Value{}, err - } - rv := val[0] - if len(val) == 2 { - err = val[1].Interface().(error) - if err != nil { - return reflect.Value{}, err - } - } - if b.shared { - if _, exists := c.instances[b.typ]; !exists { - c.instances[b.typ] = make(map[string]reflect.Value) - } - c.instances[b.typ][b.name] = rv - } - return rv, nil -} - -func (b *binding) mustMake(c *Container) reflect.Value { - val, err := b.make(c) - if err != nil { - panic(err) - } - return val -} - -type Container struct { - // 注册的工厂函数 - factories map[reflect.Type]map[string]*binding - // 注册的共享实例 - instances map[reflect.Type]map[string]reflect.Value - parent *Container -} - -func New() *Container { - return &Container{ - factories: make(map[reflect.Type]map[string]*binding), - instances: make(map[reflect.Type]map[string]reflect.Value), - parent: nil, - } -} - -// Fork 分支 -func (c *Container) Fork() *Container { - ioc := New() - ioc.parent = c - return ioc -} - -// Bind 绑定值到容器,有效类型: -// - 接口的具体实现值 -// - 结构体的实例 -// - 类型的值(尽量不要使用原始类型,而应该使用元素类型的变体) -func (c *Container) Bind(instance any) { - c.NamedBind("", instance) -} - -// NamedBind 绑定具名值到容器 -func (c *Container) NamedBind(name string, instance any) { - //typ := InterfaceOf(instance) - typ := reflect.TypeOf(instance) - if _, ok := c.instances[typ]; !ok { - c.instances[typ] = make(map[string]reflect.Value) - } - c.instances[typ][name] = reflect.ValueOf(instance) -} - -// Factory 绑定工厂函数 -func (c *Container) Factory(factory any, shared ...bool) error { - return c.NamedFactory("", factory, shared...) -} - -// NamedFactory 绑定具名工厂函数 -func (c *Container) NamedFactory(name string, factory any, shared ...bool) error { - reflectedFactory := reflect.TypeOf(factory) - if reflectedFactory.Kind() != reflect.Func { - return errors.New("container: the factory must be a function") - } - if returnCount := reflectedFactory.NumOut(); returnCount == 0 || returnCount > 2 { - return errors.New("container: factory function signature is invalid - it must return abstract, or abstract and error") - } - // TODO(hupeh): 验证第二个参数必须是 error 接口 - concreteType := reflectedFactory.Out(0) - for i := 0; i < reflectedFactory.NumIn(); i++ { - // 循环依赖 - if reflectedFactory.In(i) == concreteType { - return fmt.Errorf("container: factory function signature is invalid - depends on abstract it returns") - } - } - if _, exists := c.factories[concreteType]; !exists { - c.factories[concreteType] = make(map[string]*binding) - } - bd := &binding{ - name: name, - typ: concreteType, - resolver: factory, - shared: false, - } - for _, b := range shared { - bd.shared = b - } - c.factories[concreteType][name] = bd - return nil -} - -// Resolve 完成的注入 -func (c *Container) Resolve(i any) error { - v := reflect.ValueOf(i) - for v.Kind() == reflect.Ptr { - v = v.Elem() - } - if v.Kind() != reflect.Struct { - return errors.New("must given a struct") - } - t := v.Type() - for i := 0; i < v.NumField(); i++ { - f := v.Field(i) - structField := t.Field(i) - inject, willInject := structField.Tag.Lookup("inject") - if !f.CanSet() { - if willInject { - return fmt.Errorf("container: cannot make %v field", t.Field(i).Name) - } - continue - } - ft := f.Type() - fv := c.NamedGet(inject, ft) - if !fv.IsValid() { - return fmt.Errorf("value not found for type %v", ft) - } - f.Set(fv) - } - return nil -} - -// Get 获取指定类型的值 -func (c *Container) Get(t reflect.Type) reflect.Value { - return c.NamedGet("", t) -} - -// NamedGet 通过注入的名称获取指定类型的值 -func (c *Container) NamedGet(name string, t reflect.Type) reflect.Value { - val, exists := c.instances[t][name] - - if exists && val.IsValid() { - return val - } - - if factory, exists := c.factories[t][name]; exists { - val = factory.mustMake(c) - } - - if val.IsValid() || t.Kind() != reflect.Interface { - goto RESULT - } - - // 使用共享值里面该接口的实现者 - for k, v := range c.instances { - if k.Implements(t) { - for _, value := range v { - if value.IsValid() { - val = value - goto RESULT - } - } - break - } - } - - // 使用工厂函数里面该接口的实现者 - for k, v := range c.factories { - if k.Implements(t) { - for _, bd := range v { - if x := bd.mustMake(c); x.IsValid() { - val = x - goto RESULT - } - } - break - } - } - -RESULT: - if !val.IsValid() && c.parent != nil { - val = c.parent.NamedGet(name, t) - } - return val -} - -// Invoke 执行函数 -func (c *Container) Invoke(f any) ([]reflect.Value, error) { - t := reflect.TypeOf(f) - if t.Kind() != reflect.Func { - return nil, errors.New("container: invalid function") - } - - var in = make([]reflect.Value, t.NumIn()) //Panic if t is not kind of Func - for i := 0; i < t.NumIn(); i++ { - argType := t.In(i) - val := c.Get(argType) - if !val.IsValid() { - return nil, fmt.Errorf("value not found for type %v", argType) - } - in[i] = val - } - return reflect.ValueOf(f).Call(in), nil -} diff --git a/pkg/ioc/ioc.go b/pkg/ioc/ioc.go deleted file mode 100644 index 8b71253..0000000 --- a/pkg/ioc/ioc.go +++ /dev/null @@ -1,96 +0,0 @@ -package ioc - -import ( - "errors" - "reflect" -) - -var ( - ErrValueNotFound = errors.New("ioc: value not found") -) - -var global = New() - -// Fork 分支 -func Fork() *Container { - return global.Fork() -} - -// Bind 绑定值到容器,有效类型: -// -// - 接口的具体实现值 -// - 结构体的实例 -// - 类型的值(尽量不要使用原始类型,而应该使用元素类型的变体) -func Bind(instance any) { - global.Bind(instance) -} - -// NamedBind 绑定具名值到容器 -func NamedBind(name string, instance any) { - global.NamedBind(name, instance) -} - -// Factory 绑定工厂函数 -func Factory(factory any, shared ...bool) error { - return global.Factory(factory, shared...) -} - -func MustFactory(factory any, shared ...bool) { - err := Factory(factory, shared...) - if err != nil { - panic(err) - } -} - -// NamedFactory 绑定具名工厂函数 -func NamedFactory(name string, factory any, shared ...bool) error { - return global.NamedFactory(name, factory, shared...) -} - -func MustNamedFactory(name string, factory any, shared ...bool) { - err := NamedFactory(name, factory, shared...) - if err != nil { - panic(err) - } -} - -// Resolve 完成的注入 -func Resolve(i any) error { - return global.Resolve(i) -} - -// Get 获取指定类型的值 -func Get[T any]() (*T, error) { - return NamedGet[T]("") -} - -func MustGet[T any]() *T { - return MustNamedGet[T]("") -} - -// NamedGet 通过注入的名称获取指定类型的值 -func NamedGet[T any](name string) (*T, error) { - var abs T - t := reflect.TypeOf(&abs) - v := global.NamedGet(name, t) - if !v.IsValid() { - return nil, ErrValueNotFound - } - if x, ok := v.Interface().(*T); ok { - return x, nil - } - return nil, ErrValueNotFound -} - -func MustNamedGet[T any](name string) *T { - v, err := NamedGet[T](name) - if err != nil { - panic(err) - } - return v -} - -// Invoke 执行函数 -func Invoke(f any) ([]reflect.Value, error) { - return global.Invoke(f) -} diff --git a/pkg/ioc/util.go b/pkg/ioc/util.go deleted file mode 100644 index 28fb498..0000000 --- a/pkg/ioc/util.go +++ /dev/null @@ -1,19 +0,0 @@ -package ioc - -import "reflect" - -// InterfaceOf dereferences a pointer to an Interface type. -// It panics if value is not a pointer to an interface. -func InterfaceOf(value interface{}) reflect.Type { - t := reflect.TypeOf(value) - - for t.Kind() == reflect.Ptr { - t = t.Elem() - } - - if t.Kind() != reflect.Interface { - panic("the value is not a pointer to an interface. (*MyInterface)(nil)") - } - - return t -}