go项目脚手架
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
sorbet/pkg/ioc/container.go

228 lines
5.2 KiB

package ioc
import (
"errors"
"fmt"
"reflect"
)
// binding describes how to build a value of type typ registered under name.
type binding struct {
	// name is the registration name; "" for unnamed registrations.
	name string
	// typ is the factory's first result type and serves as the lookup key.
	typ reflect.Type
	// resolver is the factory function; it is run through Container.Invoke.
	resolver any
	// shared controls whether a freshly built value is stored in the
	// container's instances so later lookups reuse it.
	shared bool
}

// make returns the value for this binding.
//
// If c.instances already holds a value for (typ, name) — bound directly or
// cached earlier — it is returned without running the factory. Otherwise the
// factory is invoked with its parameters resolved from c. If the factory has
// a second result holding a non-nil error, that error is returned. Results of
// shared bindings are cached in c.instances.
func (b *binding) make(c *Container) (reflect.Value, error) {
	if v, exists := c.instances[b.typ][b.name]; exists {
		return v, nil
	}
	results, err := c.Invoke(b.resolver)
	if err != nil {
		return reflect.Value{}, err
	}
	rv := results[0]
	if len(results) == 2 {
		// A nil error result comes back as a nil interface; the plain
		// x.(error) form panics on it, so use the comma-ok assertion.
		if factoryErr, _ := results[1].Interface().(error); factoryErr != nil {
			return reflect.Value{}, factoryErr
		}
	}
	if b.shared {
		if _, exists := c.instances[b.typ]; !exists {
			c.instances[b.typ] = make(map[string]reflect.Value)
		}
		c.instances[b.typ][b.name] = rv
	}
	return rv, nil
}

// mustMake is like make but panics with the error if the value cannot be built.
func (b *binding) mustMake(c *Container) reflect.Value {
	val, err := b.make(c)
	if err != nil {
		panic(err)
	}
	return val
}
// Container is a reflection-based dependency-injection container. Registered
// values and factories are keyed first by reflect.Type and then by
// registration name ("" for unnamed registrations).
type Container struct {
	// factories holds the registered factory functions.
	factories map[reflect.Type]map[string]*binding
	// instances holds bound values and the cached results of shared factories.
	instances map[reflect.Type]map[string]reflect.Value
	// parent, when non-nil, is consulted for anything this container cannot
	// resolve itself; Fork sets it.
	parent *Container
}
// New creates an empty root container: no bindings, no factories, no parent.
func New() *Container {
	c := new(Container)
	c.factories = make(map[reflect.Type]map[string]*binding)
	c.instances = make(map[reflect.Type]map[string]reflect.Value)
	return c
}
// Fork creates an empty child container whose parent is c. Anything the
// child cannot resolve on its own is looked up in c.
func (c *Container) Fork() *Container {
	child := New()
	child.parent = c
	return child
}
// Bind registers instance under the empty name, keyed by its dynamic type.
// Suitable values include:
//   - a concrete value implementing some interface;
//   - a struct value;
//   - a value of some other type (prefer a dedicated defined type over a
//     primitive such as string or int, since lookups match the exact type).
func (c *Container) Bind(instance any) {
	const unnamed = ""
	c.NamedBind(unnamed, instance)
}
// NamedBind registers instance under name, keyed by its dynamic type
// (reflect.TypeOf(instance)). Binding again with the same type and name
// replaces the previous value.
// NOTE(review): binding a nil instance stores an invalid reflect.Value under
// a nil type key.
func (c *Container) NamedBind(name string, instance any) {
	typ := reflect.TypeOf(instance)
	byName, found := c.instances[typ]
	if !found {
		byName = make(map[string]reflect.Value)
		c.instances[typ] = byName
	}
	byName[name] = reflect.ValueOf(instance)
}
// Factory registers factory under the empty name. It accepts the same
// signatures and the same optional shared flag as NamedFactory.
func (c *Container) Factory(factory any, shared ...bool) error {
	err := c.NamedFactory("", factory, shared...)
	return err
}
// NamedFactory registers factory under name, keyed by the type of its first
// result.
//
// factory must be a function returning either (T) or (T, error). Its
// parameters are resolved from the container when the factory runs. A
// parameter whose type equals T is rejected as a direct self-dependency. If
// shared is given, its last element decides whether the first built value is
// cached and reused (default false). Registering again with the same type
// and name replaces the earlier factory.
//
// An error is returned, and nothing is registered, when factory is nil, is
// not a function, or has an unsupported signature.
func (c *Container) NamedFactory(name string, factory any, shared ...bool) error {
	factoryType := reflect.TypeOf(factory)
	// reflect.TypeOf(nil) is nil, and calling Kind on a nil Type panics.
	if factoryType == nil || factoryType.Kind() != reflect.Func {
		return errors.New("container: the factory must be a function")
	}
	errorType := reflect.TypeOf((*error)(nil)).Elem()
	numOut := factoryType.NumOut()
	// A second result must be exactly `error`; anything else would break the
	// error check performed when the factory runs.
	if numOut == 0 || numOut > 2 || (numOut == 2 && factoryType.Out(1) != errorType) {
		return errors.New("container: factory function signature is invalid - it must return abstract, or abstract and error")
	}
	concreteType := factoryType.Out(0)
	for i := 0; i < factoryType.NumIn(); i++ {
		// A factory needing its own result type could recurse indefinitely.
		if factoryType.In(i) == concreteType {
			return errors.New("container: factory function signature is invalid - depends on abstract it returns")
		}
	}
	bd := &binding{
		name:     name,
		typ:      concreteType,
		resolver: factory,
	}
	if len(shared) > 0 {
		bd.shared = shared[len(shared)-1]
	}
	if _, exists := c.factories[concreteType]; !exists {
		c.factories[concreteType] = make(map[string]*binding)
	}
	c.factories[concreteType][name] = bd
	return nil
}
// Resolve injects container values into the fields of the struct that i
// points to; any number of pointer levels is followed.
//
// Every settable field (exported, and reached through a pointer) is filled
// via NamedGet, using the field's `inject` tag value as the name ("" when the
// tag is absent). A field that cannot be set is skipped unless it carries an
// `inject` tag, which is an error. An error is also returned when i does not
// lead to a struct, or when no value is found for a settable field.
func (c *Container) Resolve(i any) error {
	target := reflect.ValueOf(i)
	for target.Kind() == reflect.Ptr {
		target = target.Elem()
	}
	if target.Kind() != reflect.Struct {
		return errors.New("must given a struct")
	}
	targetType := target.Type()
	for idx, count := 0, target.NumField(); idx < count; idx++ {
		field := target.Field(idx)
		meta := targetType.Field(idx)
		name, tagged := meta.Tag.Lookup("inject")
		switch {
		case field.CanSet():
			fieldType := field.Type()
			resolved := c.NamedGet(name, fieldType)
			if !resolved.IsValid() {
				return fmt.Errorf("value not found for type %v", fieldType)
			}
			field.Set(resolved)
		case tagged:
			return fmt.Errorf("container: cannot make %v field", meta.Name)
		}
	}
	return nil
}
// Get returns the unnamed value of type t, or an invalid reflect.Value when
// none can be found. It is shorthand for NamedGet("", t).
func (c *Container) Get(t reflect.Type) reflect.Value {
	val := c.NamedGet("", t)
	return val
}
// NamedGet returns the value of type t registered under name, or an invalid
// reflect.Value (IsValid() == false) when nothing matches. Lookup order:
//
//  1. a value in c's instances for exactly (t, name), i.e. bound via
//     NamedBind or cached from a shared factory;
//  2. the factory registered in c for exactly (t, name);
//  3. when t is an interface type: any valid instance whose type implements
//     t, then any factory whose result type implements t. The name is not
//     considered in this step, and when several candidates exist the one
//     chosen depends on map iteration order;
//  4. the parent container (if any), with the same name and type.
//
// It panics if a factory run during the lookup returns an error.
func (c *Container) NamedGet(name string, t reflect.Type) reflect.Value {
	if val, exists := c.instances[t][name]; exists && val.IsValid() {
		return val
	}
	if factory, exists := c.factories[t][name]; exists {
		if val := factory.mustMake(c); val.IsValid() {
			return val
		}
	}
	// Guard against a nil t: calling Kind on a nil Type panics.
	if t != nil && t.Kind() == reflect.Interface {
		// Fall back to the shared values that implement the interface.
		for k, byName := range c.instances {
			// Bind(nil) stores its entry under a nil key, and Implements
			// would panic on it. Keep scanning past implementers without a
			// valid value instead of stopping at the first one.
			if k == nil || !k.Implements(t) {
				continue
			}
			for _, val := range byName {
				if val.IsValid() {
					return val
				}
			}
		}
		// Then fall back to factories whose result implements the interface.
		for k, byName := range c.factories {
			if !k.Implements(t) {
				continue
			}
			for _, bd := range byName {
				if val := bd.mustMake(c); val.IsValid() {
					return val
				}
			}
		}
	}
	if c.parent != nil {
		return c.parent.NamedGet(name, t)
	}
	return reflect.Value{}
}
// Invoke calls f after resolving each of its parameters by type from the
// container (unnamed lookup via Get), and returns f's results.
//
// It returns an error, without calling f, when f is not a non-nil function
// or when some parameter type cannot be resolved. For a variadic f the last
// parameter is looked up as its slice type ([]T) and passed as the whole
// variadic slice. Invoke panics if resolving a parameter panics (for example
// when a factory fails), or if f itself panics.
func (c *Container) Invoke(f any) ([]reflect.Value, error) {
	fnType := reflect.TypeOf(f)
	// reflect.TypeOf(nil) is nil, and calling Kind on a nil Type panics.
	if fnType == nil || fnType.Kind() != reflect.Func {
		return nil, errors.New("container: invalid function")
	}
	fn := reflect.ValueOf(f)
	// A typed nil func value would panic inside Call.
	if fn.IsNil() {
		return nil, errors.New("container: invalid function")
	}
	args := make([]reflect.Value, fnType.NumIn())
	for i := range args {
		argType := fnType.In(i)
		val := c.Get(argType)
		if !val.IsValid() {
			return nil, fmt.Errorf("value not found for type %v", argType)
		}
		args[i] = val
	}
	if fnType.IsVariadic() {
		// args already ends with the complete []T. Call would treat it as a
		// single T element and panic, so pass it through with CallSlice.
		return fn.CallSlice(args), nil
	}
	return fn.Call(args), nil
}