feat: 取消反转控制功能

main
熊二 1 year ago
parent 34d8a15f06
commit bb44a78885
  1. 19
      internal/init.go
  2. 27
      internal/repositories/ioc.go
  3. 20
      pkg/app/controller.go
  4. 228
      pkg/ioc/container.go
  5. 96
      pkg/ioc/ioc.go
  6. 19
      pkg/ioc/util.go

@ -11,17 +11,12 @@ import (
"sorbet/internal/util"
"sorbet/pkg/db"
"sorbet/pkg/env"
"sorbet/pkg/ioc"
"sorbet/pkg/log"
"sorbet/pkg/rsp"
"sorbet/pkg/ticket"
)
func Init() error {
ioc.Bind(db.DB()) // 注入数据库操作
ioc.Bind(log.Default()) // 注入日志操作
repositories.Init() // 注入数据仓库操作
// 同步数据库结构
if err := syncEntities(); err != nil {
if !errors.Is(err, db.ErrNoCodeFirst) {
@ -31,7 +26,6 @@ func Init() error {
log.Error("同步数据表结构需要开启 [DB_CODE_FIRST],在生产模式下请务必关闭。")
}
}
return nil
}
@ -72,16 +66,6 @@ func Start() error {
e.Use(middleware.Recover())
e.Use(middleware.CORS())
e.Use(middleware.Logger)
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
db := ioc.MustGet[gorm.DB]().WithContext(c.Request().Context())
ci := ioc.Fork()
ci.Bind(db)
c.Set("db", db)
c.Set("ioc", ci)
return next(c)
}
})
e.GET("/", func(c echo.Context) error {
repo := repositories.NewCompanyRepository(c.Get("db").(*gorm.DB))
company, err := repo.GetByID(c.Request().Context(), 1)
@ -101,15 +85,12 @@ func Start() error {
"token": token,
})
})
e.Group("", middleware.Ticket(false, "system")).GET("/u", func(c echo.Context) error {
return rsp.Ok(c, echo.Map{
"ticket": c.Get("ticket"),
"claims": c.Get("ticket_claims").(*ticket.Claims),
})
})
e.Logger.Fatal(e.Start(":1323"))
return nil
}

@ -1,27 +0,0 @@
package repositories
import (
"sorbet/pkg/ioc"
)
func Init() {
ioc.MustFactory(NewCompanyRepository)
ioc.MustFactory(NewCompanyDepartmentRepository)
ioc.MustFactory(NewCompanyStaffRepository)
ioc.MustFactory(NewConfigRepository)
ioc.MustFactory(NewConfigGroupRepository)
ioc.MustFactory(NewFeatureRepository)
ioc.MustFactory(NewFeatureCategoryRepository)
ioc.MustFactory(NewFeatureConfigRepository)
ioc.MustFactory(NewFeatureContentRepository)
ioc.MustFactory(NewFeatureContentChapterRepository)
ioc.MustFactory(NewFeatureContentDetailRepository)
ioc.MustFactory(NewResourceRepository)
ioc.MustFactory(NewResourceCategoryRepository)
ioc.MustFactory(NewSystemLogRepository)
ioc.MustFactory(NewSystemMenuRepository)
ioc.MustFactory(NewSystemPermissionRepository)
ioc.MustFactory(NewSystemRoleRepository)
ioc.MustFactory(NewSystemRolePowerRepository)
ioc.MustFactory(NewSystemUserRepository)
}

@ -5,7 +5,6 @@ import (
"gorm.io/gorm"
"reflect"
"sorbet/pkg/db"
"sorbet/pkg/ioc"
"sorbet/pkg/rsp"
)
@ -72,16 +71,15 @@ func (ctr *Controller[Entity, Upsert]) RegisterRoutes(path string, r *echo.Group
// ORM 获取 gorm.DB 实例
func (ctr *Controller[Entity, Upsert]) ORM(c echo.Context) (*gorm.DB, error) {
orm, err := ioc.Get[gorm.DB]()
if err != nil {
return nil, err
if orm, ok := c.Get("orm").(*gorm.DB); ok {
return orm, nil
}
return orm.WithContext(c.Request().Context()), nil
return db.DB().WithContext(c.Request().Context()), nil
}
// Repository 获取 Repository 实例
func (ctr *Controller[Entity, Upsert]) Repository() (*db.Repository[Entity], error) {
orm, err := ioc.Get[gorm.DB]()
func (ctr *Controller[Entity, Upsert]) Repository(c echo.Context) (*db.Repository[Entity], error) {
orm, err := ctr.ORM(c)
if err != nil {
return nil, err
}
@ -103,7 +101,7 @@ func (ctr *Controller[Entity, Upsert]) upsert(c echo.Context, isCreate bool) err
if err != nil {
return err
}
repo, err := ctr.Repository()
repo, err := ctr.Repository(c)
if err != nil {
return err
}
@ -142,7 +140,7 @@ func (ctr *Controller[Entity, Upsert]) Delete(c echo.Context) error {
if err != nil {
return err
}
repo, err := ctr.Repository()
repo, err := ctr.Repository(c)
if err != nil {
return err
}
@ -159,7 +157,7 @@ func (ctr *Controller[Entity, Upsert]) Get(c echo.Context) error {
if err != nil {
return err
}
repo, err := ctr.Repository()
repo, err := ctr.Repository(c)
if err != nil {
return err
}
@ -172,7 +170,7 @@ func (ctr *Controller[Entity, Upsert]) Get(c echo.Context) error {
// List 获取数据列表
func (ctr *Controller[Entity, Upsert]) List(c echo.Context) error {
repo, err := ctr.Repository()
repo, err := ctr.Repository(c)
if err != nil {
return err
}

@ -1,228 +0,0 @@
package ioc
import (
"errors"
"fmt"
"reflect"
)
type binding struct {
name string
typ reflect.Type
resolver any
shared bool
}
func (b *binding) make(c *Container) (reflect.Value, error) {
if v, exists := c.instances[b.typ][b.name]; exists {
return v, nil
}
val, err := c.Invoke(b.resolver)
if err != nil {
return reflect.Value{}, err
}
rv := val[0]
if len(val) == 2 {
err = val[1].Interface().(error)
if err != nil {
return reflect.Value{}, err
}
}
if b.shared {
if _, exists := c.instances[b.typ]; !exists {
c.instances[b.typ] = make(map[string]reflect.Value)
}
c.instances[b.typ][b.name] = rv
}
return rv, nil
}
func (b *binding) mustMake(c *Container) reflect.Value {
val, err := b.make(c)
if err != nil {
panic(err)
}
return val
}
type Container struct {
// 注册的工厂函数
factories map[reflect.Type]map[string]*binding
// 注册的共享实例
instances map[reflect.Type]map[string]reflect.Value
parent *Container
}
func New() *Container {
return &Container{
factories: make(map[reflect.Type]map[string]*binding),
instances: make(map[reflect.Type]map[string]reflect.Value),
parent: nil,
}
}
// Fork 分支
func (c *Container) Fork() *Container {
ioc := New()
ioc.parent = c
return ioc
}
// Bind 绑定值到容器,有效类型:
// - 接口的具体实现值
// - 结构体的实例
// - 类型的值(尽量不要使用原始类型,而应该使用元素类型的变体)
func (c *Container) Bind(instance any) {
c.NamedBind("", instance)
}
// NamedBind 绑定具名值到容器
func (c *Container) NamedBind(name string, instance any) {
//typ := InterfaceOf(instance)
typ := reflect.TypeOf(instance)
if _, ok := c.instances[typ]; !ok {
c.instances[typ] = make(map[string]reflect.Value)
}
c.instances[typ][name] = reflect.ValueOf(instance)
}
// Factory 绑定工厂函数
func (c *Container) Factory(factory any, shared ...bool) error {
return c.NamedFactory("", factory, shared...)
}
// NamedFactory 绑定具名工厂函数
func (c *Container) NamedFactory(name string, factory any, shared ...bool) error {
reflectedFactory := reflect.TypeOf(factory)
if reflectedFactory.Kind() != reflect.Func {
return errors.New("container: the factory must be a function")
}
if returnCount := reflectedFactory.NumOut(); returnCount == 0 || returnCount > 2 {
return errors.New("container: factory function signature is invalid - it must return abstract, or abstract and error")
}
// TODO(hupeh): 验证第二个参数必须是 error 接口
concreteType := reflectedFactory.Out(0)
for i := 0; i < reflectedFactory.NumIn(); i++ {
// 循环依赖
if reflectedFactory.In(i) == concreteType {
return fmt.Errorf("container: factory function signature is invalid - depends on abstract it returns")
}
}
if _, exists := c.factories[concreteType]; !exists {
c.factories[concreteType] = make(map[string]*binding)
}
bd := &binding{
name: name,
typ: concreteType,
resolver: factory,
shared: false,
}
for _, b := range shared {
bd.shared = b
}
c.factories[concreteType][name] = bd
return nil
}
// Resolve 完成的注入
func (c *Container) Resolve(i any) error {
v := reflect.ValueOf(i)
for v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() != reflect.Struct {
return errors.New("must given a struct")
}
t := v.Type()
for i := 0; i < v.NumField(); i++ {
f := v.Field(i)
structField := t.Field(i)
inject, willInject := structField.Tag.Lookup("inject")
if !f.CanSet() {
if willInject {
return fmt.Errorf("container: cannot make %v field", t.Field(i).Name)
}
continue
}
ft := f.Type()
fv := c.NamedGet(inject, ft)
if !fv.IsValid() {
return fmt.Errorf("value not found for type %v", ft)
}
f.Set(fv)
}
return nil
}
// Get 获取指定类型的值
func (c *Container) Get(t reflect.Type) reflect.Value {
return c.NamedGet("", t)
}
// NamedGet 通过注入的名称获取指定类型的值
func (c *Container) NamedGet(name string, t reflect.Type) reflect.Value {
val, exists := c.instances[t][name]
if exists && val.IsValid() {
return val
}
if factory, exists := c.factories[t][name]; exists {
val = factory.mustMake(c)
}
if val.IsValid() || t.Kind() != reflect.Interface {
goto RESULT
}
// 使用共享值里面该接口的实现者
for k, v := range c.instances {
if k.Implements(t) {
for _, value := range v {
if value.IsValid() {
val = value
goto RESULT
}
}
break
}
}
// 使用工厂函数里面该接口的实现者
for k, v := range c.factories {
if k.Implements(t) {
for _, bd := range v {
if x := bd.mustMake(c); x.IsValid() {
val = x
goto RESULT
}
}
break
}
}
RESULT:
if !val.IsValid() && c.parent != nil {
val = c.parent.NamedGet(name, t)
}
return val
}
// Invoke 执行函数
func (c *Container) Invoke(f any) ([]reflect.Value, error) {
t := reflect.TypeOf(f)
if t.Kind() != reflect.Func {
return nil, errors.New("container: invalid function")
}
var in = make([]reflect.Value, t.NumIn()) //Panic if t is not kind of Func
for i := 0; i < t.NumIn(); i++ {
argType := t.In(i)
val := c.Get(argType)
if !val.IsValid() {
return nil, fmt.Errorf("value not found for type %v", argType)
}
in[i] = val
}
return reflect.ValueOf(f).Call(in), nil
}

@ -1,96 +0,0 @@
package ioc
import (
"errors"
"reflect"
)
var (
ErrValueNotFound = errors.New("ioc: value not found")
)
var global = New()
// Fork 分支
func Fork() *Container {
return global.Fork()
}
// Bind 绑定值到容器,有效类型:
//
// - 接口的具体实现值
// - 结构体的实例
// - 类型的值(尽量不要使用原始类型,而应该使用元素类型的变体)
func Bind(instance any) {
global.Bind(instance)
}
// NamedBind 绑定具名值到容器
func NamedBind(name string, instance any) {
global.NamedBind(name, instance)
}
// Factory 绑定工厂函数
func Factory(factory any, shared ...bool) error {
return global.Factory(factory, shared...)
}
func MustFactory(factory any, shared ...bool) {
err := Factory(factory, shared...)
if err != nil {
panic(err)
}
}
// NamedFactory 绑定具名工厂函数
func NamedFactory(name string, factory any, shared ...bool) error {
return global.NamedFactory(name, factory, shared...)
}
func MustNamedFactory(name string, factory any, shared ...bool) {
err := NamedFactory(name, factory, shared...)
if err != nil {
panic(err)
}
}
// Resolve 完成的注入
func Resolve(i any) error {
return global.Resolve(i)
}
// Get 获取指定类型的值
func Get[T any]() (*T, error) {
return NamedGet[T]("")
}
func MustGet[T any]() *T {
return MustNamedGet[T]("")
}
// NamedGet 通过注入的名称获取指定类型的值
func NamedGet[T any](name string) (*T, error) {
var abs T
t := reflect.TypeOf(&abs)
v := global.NamedGet(name, t)
if !v.IsValid() {
return nil, ErrValueNotFound
}
if x, ok := v.Interface().(*T); ok {
return x, nil
}
return nil, ErrValueNotFound
}
func MustNamedGet[T any](name string) *T {
v, err := NamedGet[T](name)
if err != nil {
panic(err)
}
return v
}
// Invoke 执行函数
func Invoke(f any) ([]reflect.Value, error) {
return global.Invoke(f)
}

@ -1,19 +0,0 @@
package ioc
import "reflect"
// InterfaceOf dereferences a pointer to an Interface type.
// It panics if value is not a pointer to an interface.
func InterfaceOf(value interface{}) reflect.Type {
t := reflect.TypeOf(value)
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() != reflect.Interface {
panic("the value is not a pointer to an interface. (*MyInterface)(nil)")
}
return t
}
Loading…
Cancel
Save