go语言依赖注入实现

  • A+

最近做项目中,生成对象还是使用比较原始的New和简单工厂的方式,使用过程中感觉不太爽快(依赖紧密,有点改动就比较麻烦),还是比较喜欢使用依赖注入的方式。

然后网上没有找到比较好用的依赖注入包,就自己动手写了一个,也不要求啥,能用就会,把我从繁琐的New方法中解脱出来。

先说一下简单实现原理

  1. 通过反射读取对象的依赖(golang是通过tag实现)
  2. 在容器中查找有无该对象实例
  3. 如果有该对象实例或者创建对象的工厂方法,则注入对象或使用工厂创建对象并注入
  4. 如果无该对象实例,则报错

需要注意的地方:

1、注入的对象首字母需要大写,小写的话,在go中代表私有,通过反射无法修改值

2、go反射无法通过读取配置文件信息动态创建对象

 

首先,介绍一下项目层次结构

go语言依赖注入实现

 

 

 主要解决:数据库-》仓储(读写分离)-》服务-》控制器 这几层的依赖注入问题

数据库,我这里为了简化数据库细节,采用模拟数据的办法来实现,实际项目中是需要读取真是数据库的,代码如下

//准备用户数据,实际开发一般从数据库读取
var users []entities.UserEntity

func init() {
    users = append(users, entities.UserEntity{ID: 1, Name: "小明", NickName: "无敌", Gender: 1, Age: 13, Tel: "18886588086", Address: "中国,广东,深圳"})
    users = append(users, entities.UserEntity{ID: 2, Name: "小红", NickName: "傻妞", Gender: 0, Age: 13, Tel: "1888658809", Address: "中国,广东,广州"})
}

type MockDB struct {
    Host  string
    User  string
    Pwd   string
    Alias string
}

func (db *MockDB) Connect() bool {
    return true
}

func (db *MockDB) Users() []entities.UserEntity {
    return users
}

func (db *MockDB) Close() {

}

数据仓储,为了实现读写分离,分离了两个接口,例如user仓储分为i_user_reader和i_user_repository,其中i_user_repository包含i_user_reader(即继承了i_user_reader)

接口定义如下:

type IUserReader interface {
    GetUsers() []dtos.UserDto
    GetUser(id int64) *dtos.UserDto
    GetMaxUserId() int64
}

type IUserRepository interface {
    IUserReader
    AddUser(user *inputs.UserInput) error
    UpdateUserNickName(id int64, nickName string) error
}

仓储实现如下:

user_read

type UserRead struct {
    ReadDb *db.MockDB `inject:"MockDBRead"`
}

func (r *UserRead) GetUsers() []dtos.UserDto {
    if r.ReadDb.Connect() {
        users := r.ReadDb.Users()
        var list []dtos.UserDto
        for _, user := range users {
            list = append(list, dtos.UserDto{ID: user.ID, Name: user.Name, NickName: user.NickName, Gender: user.Gender, Age: user.Age, Tel: user.Tel, Address: user.Address})
        }
        return list
    }
    return nil
}

func (r *UserRead) GetUser(id int64) *dtos.UserDto {
    if r.ReadDb.Connect() {
        users := r.ReadDb.Users()
        for _, user := range users {
            if user.ID == id {
                return &dtos.UserDto{ID: user.ID, Name: user.Name, NickName: user.NickName, Gender: user.Gender, Age: user.Age, Tel: user.Tel, Address: user.Address}
            }
        }
        return &dtos.UserDto{}
    }
    return nil
}

func (r *UserRead) GetMaxUserId() int64 {
    var maxId int64
    if r.ReadDb.Connect() {
        users := r.ReadDb.Users()
        for _, user := range users {
            if user.ID > maxId {
                maxId = user.ID
            }
        }
    }
    return maxId
}
UserRepository:
type UserRepository struct {
    UserRead
    WriteDb *db.MockDB `inject:"MockDBWrite"`
}

func (w *UserRepository) AddUser(user *inputs.UserInput) error {
    model := entities.UserEntity{}
    model.ID = w.GetMaxUserId() + 1
    model.Name = user.Name
    model.NickName = user.NickName
    model.Gender = user.Gender
    model.Age = user.Age
    model.Address = user.Address
    if w.ReadDb.Connect() {
        users := w.ReadDb.Users()
        users = append(users, model)
    }
    return nil
}

func (w *UserRepository) UpdateUserNickName(id int64, nickName string) error {
    user := w.GetUser(id)
    if user.ID > 0 {
        user.NickName = nickName
        return nil
    } else {
        return errors.New("未找到用户信息")
    }
}

注意,user_read依赖注入的是读db:ReadDB,user_repository依赖注入的是写db:WriteDB

 

服务的接口和实现

i_user_service:

type IUserService interface {
    GetUsers() []dtos.UserDto
    GetUser(id int64) *dtos.UserDto
    AddUser(user *inputs.UserInput) error
}

user_service:

type UserService struct {
    UserRepository repositories.IUserRepository `inject:"UserRepository"`
}

func (s *UserService) AddUser(user *inputs.UserInput) error {
    return s.UserRepository.AddUser(user)
}

func (s *UserService) GetUsers() []dtos.UserDto {
    return s.UserRepository.GetUsers()
}

func (s *UserService) GetUser(id int64) *dtos.UserDto {
    return s.UserRepository.GetUser(id)
}

UserService依赖注入UserRepository,另外,项目中,特意把仓储接口定义和服务放在同一层,是为了让服务只依赖仓储接口,不依赖仓储具体实现。这算是设计模式原则的依赖倒置原则的体现吧。

控制器实现:

type UserController struct {
    UserService user.IUserService `inject:"UserService"`
}

func (ctrl *UserController) GetUsers(ctx *gin.Context) {
    users := ctrl.UserService.GetUsers()
    Ok(Response{Code: Success, Msg: "获取用户成功!", Data: users}, ctx)
}

func (ctrl *UserController) GetUser(ctx *gin.Context) {
    idStr := ctx.Param("id")
    id, err := strconv.ParseInt(idStr, 10, 64)
    if err != nil {
        BadRequestError("id参数格式错误", ctx)
        return
    }
    users := ctrl.UserService.GetUser(id)
    Ok(Response{Code: Success, Msg: "获取用户成功!", Data: users}, ctx)
}

func (ctrl *UserController) AddUser(ctx *gin.Context) {
    input := inputs.UserInput{}
    err := ctx.ShouldBindJSON(&input)
    if err != nil {
        BadRequestError("参数错误", ctx)
        return
    }
    err = ctrl.UserService.AddUser(&input)
    if err != nil {
        Ok(Response{Code: Failed, Msg: err.Error()}, ctx)
        return
    }
    Ok(Response{Code: Success, Msg: "添加用户成功!"}, ctx)
}

UserController依赖注入UserService

接下来是实现依赖注入的核心代码,容器的实现

Container:

var injectTagName = "inject" //依赖注入tag名

//生命周期
// singleton:单例 单一实例,每次使用都是该实例
// transient:瞬时实例,每次使用都创建新的实例
type Container struct {
    sync.Mutex
    singletons map[string]interface{}
    transients map[string]factory
}

type factory = func() (interface{}, error)

//注册单例对象
func (c *Container) SetSingleton(name string, singleton interface{}) {
    c.Lock()
    c.singletons[name] = singleton
    c.Unlock()
}

func (c *Container) GetSingleton(name string) interface{} {
    return c.singletons[name]
}

//注册瞬时实例创建工厂方法
func (c *Container) SetTransient(name string, factory factory) {
    c.Lock()
    c.transients[name] = factory
    c.Unlock()
}

func (c *Container) GetTransient(name string) interface{} {
    factory := c.transients[name]
    instance, _ := factory()
    return instance
}

//注入实例
func (c *Container) Entry(instance interface{}) error {
    err := c.entryValue(reflect.ValueOf(instance))
    if err != nil {
        return err
    }
    return nil
}

func (c *Container) entryValue(value reflect.Value) error {
    if value.Kind() != reflect.Ptr {
        return errors.New("必须为指针")
    }
    elemType, elemValue := value.Type().Elem(), value.Elem()
    for i := 0; i < elemType.NumField(); i++ {
        if !elemValue.Field(i).CanSet() { //不可设置 跳过
            continue
        }

        fieldType := elemType.Field(i)
        if fieldType.Anonymous {
            //fmt.Println(fieldType.Name + "是匿名字段")
            item := reflect.New(elemValue.Field(i).Type())
            c.entryValue(item) //递归注入
            elemValue.Field(i).Set(item.Elem())
        } else {
            if elemValue.Field(i).IsZero() { //零值才注入
                //fmt.Println(elemValue.Field(i).Interface())
                //fmt.Println(fieldType.Name)
                tag := fieldType.Tag.Get(injectTagName)
                injectInstance, err := c.getInstance(tag)
                if err != nil {
                    return err
                }
                c.entryValue(reflect.ValueOf(injectInstance)) //递归注入

                elemValue.Field(i).Set(reflect.ValueOf(injectInstance))
            } else {
                fmt.Println(fieldType.Name)
            }
        }
    }
    return nil
}

func (c *Container) getInstance(tag string) (interface{}, error) {
    var injectName string
    tags := strings.Split(tag, ",")
    if len(tags) == 0 {
        injectName = ""
    } else {
        injectName = tags[0]
    }

    if c.isTransient(tag) {
        factory, ok := c.transients[injectName]
        if !ok {
            return nil, errors.New("transient factory not found")
        } else {
            return factory()
        }
    } else { //默认单例
        instance, ok := c.singletons[injectName]
        if !ok || instance == nil {
            return nil, errors.New(injectName + " dependency not found")
        } else {
            return instance, nil
        }
    }
}

// transient:瞬时实例,每次使用都创建新的实例
func (c *Container) isTransient(tag string) bool {
    tags := strings.Split(tag, ",")
    for _, name := range tags {
        if name == "transient" {
            return true
        }
    }
    return false
}

func (c *Container) String() string {
    lines := make([]string, 0, len(c.singletons)+len(c.transients)+2)
    lines = append(lines, "singletons:")
    for key, value := range c.singletons {
        line := fmt.Sprintf("    %s: %x %s", key, c.singletons[key], reflect.TypeOf(value).String())
        lines = append(lines, line)
    }

    lines = append(lines, "transients:")
    for key, value := range c.transients {
        line := fmt.Sprintf("    %s: %x %s", key, c.transients[key], reflect.TypeOf(value).String())
        lines = append(lines, line)
    }
    return strings.Join(lines, "\n")
}

这里使用了两种生命周期的实例:单例和瞬时(其他生命周期,水平有限哈)

简单说下原理,容器主要包含两个map对象,用来存储对象和创建对方方法,然后依赖注入实现,就是通过反射获取tag信息,再去容器map中获取对象,通过反射把获取的对象赋值到字段中。

我这里采用了递归注入的方式,所以本项目中,只用注入UserController对象即可,因为实际项目中多点是有多个Controller对象,所以我这里使用了个简单工厂来创建Controller对象,然后只用注入工厂方法即可

工厂方法实现如下:

type CtrlFactory struct {
    UserCtrl *controllers.UserController `inject:"UserController"`
}

使用容器前,需要先初始化好容器对象,这里使用一个全局对象,然后初始化好需要注入的对象,实现代码如下:

var GContainer = &Container{
    singletons: make(map[string]interface{}),
    transients: make(map[string]factory),
}

func Init() {
    //db
    GContainer.SetSingleton("MockDBRead", &db.MockDB{Host: "192.168.1.12:3036", User: "root", Pwd: "123456", Alias: "Read"})
    GContainer.SetSingleton("MockDBWrite", &db.MockDB{Host: "192.168.1.25:3036", User: "root", Pwd: "123456", Alias: "Write"})

    //仓储
    GContainer.SetSingleton("UserRepository", &user.UserRepository{})

    //服务
    GContainer.SetSingleton("UserService", &userDomain.UserService{})

    //控制器
    GContainer.SetSingleton("UserController", &controllers.UserController{})

    //控制器工厂
    ctlFactory := &CtrlFactory{}
    GContainer.SetSingleton("CtrlFactory", ctlFactory)

    GContainer.Entry(ctlFactory) //注入

    fmt.Println(GContainer.String())
}

依赖注入代码实现讲完了,然后就是具体使用了,使用时,先在main方法中调用容器出事化方法Init() (注意,这里Init特意大写,要和go包的init区分,go包的init是自动调用,这里大写的Init是需要手动调用的,至于为啥呢,注意是可以控制调用时机,go包的init调用顺序有点莫名其妙,特别是包引用复杂的时候),main代码如下:

func main() {
    Init()
    Run()
}

func Init() {
    inject.Init()
}

func Run() {
    router := router.Init()

    s := &http.Server{
        Addr:           ":8080",
        Handler:        router,
        ReadTimeout:    time.Duration(10) * time.Second,
        WriteTimeout:   time.Duration(10) * time.Second,
        MaxHeaderBytes: 1 << 20,
    }
    go func() {
        log.Println("Server Listen at:8080")
        if err := s.ListenAndServe(); err != nil {
            log.Printf("Listen:%s\n", err)
        }
    }()

    quit := make(chan os.Signal)
    signal.Notify(quit, os.Interrupt)
    <-quit

    log.Println("Shutdown Server...")
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()
    if err := s.Shutdown(ctx); err != nil {
        log.Fatal("Server Shutdown:", err)
    }
    log.Println("Server exiting")
}

我这里使用了gin框架来构建http服务

初始化话完毕后,就是在路由中使用controller了,先从容器中获取工厂对象,然后通过go类型推断转化为具体类型,代码如下:

func Init() *gin.Engine {
    // Creates a router without any middleware by default
    r := gin.New()
    r.Use(gin.Logger())
    // Recovery middleware recovers from any panics and writes a 500 if there was one.
    r.Use(gin.Recovery())

    r.GET("/ping", func(c *gin.Context) {
        c.JSON(200, gin.H{
            "message": "pong",
        })
    })

    factory := inject.GContainer.GetSingleton("CtrlFactory")
    ctrlFactory := factory.(*inject.CtrlFactory)

    apiV1 := r.Group("/api/v1")
    //users
    userRg := apiV1.Group("/user")
    {
        userRg.POST("", ctrlFactory.UserCtrl.AddUser)
        userRg.GET("", ctrlFactory.UserCtrl.GetUsers)
        userRg.GET("/:id", ctrlFactory.UserCtrl.GetUser)
    }

    gin.SetMode("debug")
    return r
}

核心代码就是:

factory := inject.GContainer.GetSingleton("CtrlFactory")
ctrlFactory := factory.(*inject.CtrlFactory)

ok,介绍完了。初始弄这个依赖注入可能觉得有点麻烦,但这是一劳永逸的办法,后面有啥增加修改的就比较简单

具体代码放在github上了,有兴趣可以关注一下:https://github.com/marshhu/ma-inject

 

90DIR-CMD