You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
229 lines
5.2 KiB
229 lines
5.2 KiB
1 year ago
|
package ioc
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"reflect"
|
||
|
)
|
||
|
|
||
|
// binding describes one factory function registered with a Container.
type binding struct {
	name     string       // name the factory was registered under ("" is the default name)
	typ      reflect.Type // first result type of the factory; the lookup key in Container.factories
	resolver any          // the factory function itself, executed through Container.Invoke
	shared   bool         // when true, binding.make caches the produced value in Container.instances
}
|
||
|
|
||
|
func (b *binding) make(c *Container) (reflect.Value, error) {
|
||
|
if v, exists := c.instances[b.typ][b.name]; exists {
|
||
|
return v, nil
|
||
|
}
|
||
|
val, err := c.Invoke(b.resolver)
|
||
|
if err != nil {
|
||
|
return reflect.Value{}, err
|
||
|
}
|
||
|
rv := val[0]
|
||
|
if len(val) == 2 {
|
||
|
err = val[1].Interface().(error)
|
||
|
if err != nil {
|
||
|
return reflect.Value{}, err
|
||
|
}
|
||
|
}
|
||
|
if b.shared {
|
||
|
if _, exists := c.instances[b.typ]; !exists {
|
||
|
c.instances[b.typ] = make(map[string]reflect.Value)
|
||
|
}
|
||
|
c.instances[b.typ][b.name] = rv
|
||
|
}
|
||
|
return rv, nil
|
||
|
}
|
||
|
|
||
|
func (b *binding) mustMake(c *Container) reflect.Value {
|
||
|
val, err := b.make(c)
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
return val
|
||
|
}
|
||
|
|
||
|
type Container struct {
|
||
|
// 注册的工厂函数
|
||
|
factories map[reflect.Type]map[string]*binding
|
||
|
// 注册的共享实例
|
||
|
instances map[reflect.Type]map[string]reflect.Value
|
||
|
parent *Container
|
||
|
}
|
||
|
|
||
|
func New() *Container {
|
||
|
return &Container{
|
||
|
factories: make(map[reflect.Type]map[string]*binding),
|
||
|
instances: make(map[reflect.Type]map[string]reflect.Value),
|
||
|
parent: nil,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Fork 分支
|
||
|
func (c *Container) Fork() *Container {
|
||
|
ioc := New()
|
||
|
ioc.parent = c
|
||
|
return ioc
|
||
|
}
|
||
|
|
||
|
// Bind 绑定值到容器,有效类型:
|
||
|
// - 接口的具体实现值
|
||
|
// - 结构体的实例
|
||
|
// - 类型的值(尽量不要使用原始类型,而应该使用元素类型的变体)
|
||
|
func (c *Container) Bind(instance any) {
|
||
|
c.NamedBind("", instance)
|
||
|
}
|
||
|
|
||
|
// NamedBind 绑定具名值到容器
|
||
|
func (c *Container) NamedBind(name string, instance any) {
|
||
|
//typ := InterfaceOf(instance)
|
||
|
typ := reflect.TypeOf(instance)
|
||
|
if _, ok := c.instances[typ]; !ok {
|
||
|
c.instances[typ] = make(map[string]reflect.Value)
|
||
|
}
|
||
|
c.instances[typ][name] = reflect.ValueOf(instance)
|
||
|
}
|
||
|
|
||
|
// Factory 绑定工厂函数
|
||
|
func (c *Container) Factory(factory any, shared ...bool) error {
|
||
|
return c.NamedFactory("", factory, shared...)
|
||
|
}
|
||
|
|
||
|
// NamedFactory 绑定具名工厂函数
|
||
|
func (c *Container) NamedFactory(name string, factory any, shared ...bool) error {
|
||
|
reflectedFactory := reflect.TypeOf(factory)
|
||
|
if reflectedFactory.Kind() != reflect.Func {
|
||
|
return errors.New("container: the factory must be a function")
|
||
|
}
|
||
|
if returnCount := reflectedFactory.NumOut(); returnCount == 0 || returnCount > 2 {
|
||
|
return errors.New("container: factory function signature is invalid - it must return abstract, or abstract and error")
|
||
|
}
|
||
|
// TODO(hupeh): 验证第二个参数必须是 error 接口
|
||
|
concreteType := reflectedFactory.Out(0)
|
||
|
for i := 0; i < reflectedFactory.NumIn(); i++ {
|
||
|
// 循环依赖
|
||
|
if reflectedFactory.In(i) == concreteType {
|
||
|
return fmt.Errorf("container: factory function signature is invalid - depends on abstract it returns")
|
||
|
}
|
||
|
}
|
||
|
if _, exists := c.factories[concreteType]; !exists {
|
||
|
c.factories[concreteType] = make(map[string]*binding)
|
||
|
}
|
||
|
bd := &binding{
|
||
|
name: name,
|
||
|
typ: concreteType,
|
||
|
resolver: factory,
|
||
|
shared: false,
|
||
|
}
|
||
|
for _, b := range shared {
|
||
|
bd.shared = b
|
||
|
}
|
||
|
c.factories[concreteType][name] = bd
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Resolve 完成的注入
|
||
|
func (c *Container) Resolve(i any) error {
|
||
|
v := reflect.ValueOf(i)
|
||
|
for v.Kind() == reflect.Ptr {
|
||
|
v = v.Elem()
|
||
|
}
|
||
|
if v.Kind() != reflect.Struct {
|
||
|
return errors.New("must given a struct")
|
||
|
}
|
||
|
t := v.Type()
|
||
|
for i := 0; i < v.NumField(); i++ {
|
||
|
f := v.Field(i)
|
||
|
structField := t.Field(i)
|
||
|
inject, willInject := structField.Tag.Lookup("inject")
|
||
|
if !f.CanSet() {
|
||
|
if willInject {
|
||
|
return fmt.Errorf("container: cannot make %v field", t.Field(i).Name)
|
||
|
}
|
||
|
continue
|
||
|
}
|
||
|
ft := f.Type()
|
||
|
fv := c.NamedGet(inject, ft)
|
||
|
if !fv.IsValid() {
|
||
|
return fmt.Errorf("value not found for type %v", ft)
|
||
|
}
|
||
|
f.Set(fv)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Get 获取指定类型的值
|
||
|
func (c *Container) Get(t reflect.Type) reflect.Value {
|
||
|
return c.NamedGet("", t)
|
||
|
}
|
||
|
|
||
|
// NamedGet 通过注入的名称获取指定类型的值
|
||
|
func (c *Container) NamedGet(name string, t reflect.Type) reflect.Value {
|
||
|
val, exists := c.instances[t][name]
|
||
|
|
||
|
if exists && val.IsValid() {
|
||
|
return val
|
||
|
}
|
||
|
|
||
|
if factory, exists := c.factories[t][name]; exists {
|
||
|
val = factory.mustMake(c)
|
||
|
}
|
||
|
|
||
|
if val.IsValid() || t.Kind() != reflect.Interface {
|
||
|
goto RESULT
|
||
|
}
|
||
|
|
||
|
// 使用共享值里面该接口的实现者
|
||
|
for k, v := range c.instances {
|
||
|
if k.Implements(t) {
|
||
|
for _, value := range v {
|
||
|
if value.IsValid() {
|
||
|
val = value
|
||
|
goto RESULT
|
||
|
}
|
||
|
}
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// 使用工厂函数里面该接口的实现者
|
||
|
for k, v := range c.factories {
|
||
|
if k.Implements(t) {
|
||
|
for _, bd := range v {
|
||
|
if x := bd.mustMake(c); x.IsValid() {
|
||
|
val = x
|
||
|
goto RESULT
|
||
|
}
|
||
|
}
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
|
||
|
RESULT:
|
||
|
if !val.IsValid() && c.parent != nil {
|
||
|
val = c.parent.NamedGet(name, t)
|
||
|
}
|
||
|
return val
|
||
|
}
|
||
|
|
||
|
// Invoke 执行函数
|
||
|
func (c *Container) Invoke(f any) ([]reflect.Value, error) {
|
||
|
t := reflect.TypeOf(f)
|
||
|
if t.Kind() != reflect.Func {
|
||
|
return nil, errors.New("container: invalid function")
|
||
|
}
|
||
|
|
||
|
var in = make([]reflect.Value, t.NumIn()) //Panic if t is not kind of Func
|
||
|
for i := 0; i < t.NumIn(); i++ {
|
||
|
argType := t.In(i)
|
||
|
val := c.Get(argType)
|
||
|
if !val.IsValid() {
|
||
|
return nil, fmt.Errorf("value not found for type %v", argType)
|
||
|
}
|
||
|
in[i] = val
|
||
|
}
|
||
|
return reflect.ValueOf(f).Call(in), nil
|
||
|
}
|