package ioc import ( "errors" "fmt" "reflect" ) type binding struct { name string typ reflect.Type resolver any shared bool } func (b *binding) make(c *Container) (reflect.Value, error) { if v, exists := c.instances[b.typ][b.name]; exists { return v, nil } val, err := c.Invoke(b.resolver) if err != nil { return reflect.Value{}, err } rv := val[0] if len(val) == 2 { err = val[1].Interface().(error) if err != nil { return reflect.Value{}, err } } if b.shared { if _, exists := c.instances[b.typ]; !exists { c.instances[b.typ] = make(map[string]reflect.Value) } c.instances[b.typ][b.name] = rv } return rv, nil } func (b *binding) mustMake(c *Container) reflect.Value { val, err := b.make(c) if err != nil { panic(err) } return val } type Container struct { // 注册的工厂函数 factories map[reflect.Type]map[string]*binding // 注册的共享实例 instances map[reflect.Type]map[string]reflect.Value parent *Container } func New() *Container { return &Container{ factories: make(map[reflect.Type]map[string]*binding), instances: make(map[reflect.Type]map[string]reflect.Value), parent: nil, } } // Fork 分支 func (c *Container) Fork() *Container { ioc := New() ioc.parent = c return ioc } // Bind 绑定值到容器,有效类型: // - 接口的具体实现值 // - 结构体的实例 // - 类型的值(尽量不要使用原始类型,而应该使用元素类型的变体) func (c *Container) Bind(instance any) { c.NamedBind("", instance) } // NamedBind 绑定具名值到容器 func (c *Container) NamedBind(name string, instance any) { //typ := InterfaceOf(instance) typ := reflect.TypeOf(instance) if _, ok := c.instances[typ]; !ok { c.instances[typ] = make(map[string]reflect.Value) } c.instances[typ][name] = reflect.ValueOf(instance) } // Factory 绑定工厂函数 func (c *Container) Factory(factory any, shared ...bool) error { return c.NamedFactory("", factory, shared...) } // NamedFactory 绑定具名工厂函数 func (c *Container) NamedFactory(name string, factory any, shared ...bool) error { reflectedFactory := reflect.TypeOf(factory) if reflectedFactory.Kind() != reflect.Func { return errors.New("container: the factory must be a function") } if returnCount := reflectedFactory.NumOut(); returnCount == 0 || returnCount > 2 { return errors.New("container: factory function signature is invalid - it must return abstract, or abstract and error") } // TODO(hupeh): 验证第二个参数必须是 error 接口 concreteType := reflectedFactory.Out(0) for i := 0; i < reflectedFactory.NumIn(); i++ { // 循环依赖 if reflectedFactory.In(i) == concreteType { return fmt.Errorf("container: factory function signature is invalid - depends on abstract it returns") } } if _, exists := c.factories[concreteType]; !exists { c.factories[concreteType] = make(map[string]*binding) } bd := &binding{ name: name, typ: concreteType, resolver: factory, shared: false, } for _, b := range shared { bd.shared = b } c.factories[concreteType][name] = bd return nil } // Resolve 完成的注入 func (c *Container) Resolve(i any) error { v := reflect.ValueOf(i) for v.Kind() == reflect.Ptr { v = v.Elem() } if v.Kind() != reflect.Struct { return errors.New("must given a struct") } t := v.Type() for i := 0; i < v.NumField(); i++ { f := v.Field(i) structField := t.Field(i) inject, willInject := structField.Tag.Lookup("inject") if !f.CanSet() { if willInject { return fmt.Errorf("container: cannot make %v field", t.Field(i).Name) } continue } ft := f.Type() fv := c.NamedGet(inject, ft) if !fv.IsValid() { return fmt.Errorf("value not found for type %v", ft) } f.Set(fv) } return nil } // Get 获取指定类型的值 func (c *Container) Get(t reflect.Type) reflect.Value { return c.NamedGet("", t) } // NamedGet 通过注入的名称获取指定类型的值 func (c *Container) NamedGet(name string, t reflect.Type) reflect.Value { val, exists := c.instances[t][name] if exists && val.IsValid() { return val } if factory, exists := c.factories[t][name]; exists { val = factory.mustMake(c) } if val.IsValid() || t.Kind() != reflect.Interface { goto RESULT } // 使用共享值里面该接口的实现者 for k, v := range c.instances { if k.Implements(t) { for _, value := range v { if value.IsValid() { val = value goto RESULT } } break } } // 使用工厂函数里面该接口的实现者 for k, v := range c.factories { if k.Implements(t) { for _, bd := range v { if x := bd.mustMake(c); x.IsValid() { val = x goto RESULT } } break } } RESULT: if !val.IsValid() && c.parent != nil { val = c.parent.NamedGet(name, t) } return val } // Invoke 执行函数 func (c *Container) Invoke(f any) ([]reflect.Value, error) { t := reflect.TypeOf(f) if t.Kind() != reflect.Func { return nil, errors.New("container: invalid function") } var in = make([]reflect.Value, t.NumIn()) //Panic if t is not kind of Func for i := 0; i < t.NumIn(); i++ { argType := t.In(i) val := c.Get(argType) if !val.IsValid() { return nil, fmt.Errorf("value not found for type %v", argType) } in[i] = val } return reflect.ValueOf(f).Call(in), nil }