package db import ( "context" "database/sql" "errors" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/driver/sqlserver" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/logger" "gorm.io/gorm/schema" "gorm.io/plugin/optimisticlock" "sorbet/pkg/env" "sync" "time" ) var ( // TODO(hupeh): 使用原子性操作 atomic.Value db *gorm.DB lock sync.RWMutex ErrNoCodeFirst = errors.New("no code first") // 使用东八区时间 // https://cloud.tencent.com/developer/article/1805859 cstZone = time.FixedZone("CST", 8*3600) ) type Version = optimisticlock.Version type SessionConfig = gorm.Session type BaseConfig struct { TimeLocation *time.Location NamingStrategy schema.Namer Logger logger.Interface Plugins map[string]gorm.Plugin TablePrefix string SingularTable bool NameReplacer schema.Replacer IdentifierMaxLength int MaxIdleConns int MaxOpenConns int ConnMaxLifetime time.Duration } type Config struct { BaseConfig Driver string StoreEngine string DSN string } // DB 获取数据库操作实例 func DB() *gorm.DB { lock.RLock() if db != nil { lock.RUnlock() return db } lock.RUnlock() lock.Lock() db = New() lock.Unlock() return db } func WithContext(ctx context.Context) *gorm.DB { return DB().WithContext(ctx) } // SetDB 自定义操作引擎 func SetDB(engine *gorm.DB) { lock.Lock() defer lock.Unlock() db = engine } // New 创建数据库操作引擎,初始化参数来自环境变量 func New() *gorm.DB { engine, err := NewWithConfig(&Config{ BaseConfig: BaseConfig{ TimeLocation: cstZone, TablePrefix: env.String("DB_PREFIX"), SingularTable: env.Bool("DB_SINGULAR_TABLE", false), IdentifierMaxLength: env.Int("DB_IDENTIFIER_MAX_LENGTH", 0), Logger: &dbLogger{200 * time.Millisecond}, MaxIdleConns: env.Int("DB_MAX_IDLE_CONNS", 0), MaxOpenConns: env.Int("DB_MAX_OPEN_CONNS", 0), ConnMaxLifetime: env.Duration("DB_CONN_MAX_LIFETIME", 0), }, Driver: env.String("DB_DRIVER", "sqlite3"), StoreEngine: env.String("DB_STORE_ENGINE", "InnoDB"), DSN: env.String("DB_DSN", "./app.db"), }) if err != nil { panic(err) } return engine } // NewWithConfig 通过配置创建数据库操作引擎 func NewWithConfig(config *Config) (*gorm.DB, error) { var dialector gorm.Dialector switch config.Driver { case "mysql": dialector = mysql.Open(config.DSN) case "pgsql": dialector = postgres.Open(config.DSN) case "sqlite", "sqlite3": dialector = sqlite.Open(config.DSN) case "sqlserver": dialector = sqlserver.Open(config.DSN) default: return nil, errors.New("不支持的数据库驱动:" + config.Driver) } engine, err := NewWithDialector(dialector, &config.BaseConfig) if err != nil { return nil, err } if config.Driver == "mysql" && config.StoreEngine != "" { engine = engine.Set("gorm:table_options", "ENGINE="+config.StoreEngine) } return engine, nil } // NewWithDialector 通过指定的 dialector 创建数据库操作引擎 func NewWithDialector(dialector gorm.Dialector, config *BaseConfig) (*gorm.DB, error) { engine, err := gorm.Open(dialector, &gorm.Config{ NamingStrategy: schema.NamingStrategy{ TablePrefix: config.TablePrefix, SingularTable: config.SingularTable, NameReplacer: config.NameReplacer, NoLowerCase: false, IdentifierMaxLength: config.IdentifierMaxLength, }, Logger: config.Logger, NowFunc: func() time.Time { if config.TimeLocation == nil { return time.Now() } return time.Now().In(config.TimeLocation) }, QueryFields: false, }) if err != nil { return nil, err } rawDB, err := engine.DB() if err != nil { return nil, err } if config.MaxIdleConns > 0 { rawDB.SetMaxIdleConns(config.MaxIdleConns) } if config.MaxOpenConns > 0 { rawDB.SetMaxOpenConns(config.MaxOpenConns) } if config.ConnMaxLifetime > 0 { rawDB.SetConnMaxLifetime(config.ConnMaxLifetime) } return engine, nil } // Sync 同步数据库结构,属于代码优先模式。 // // 在使用该方法之前需要在环境变量中开启 "DB_CODE_FIRST" 选项。 // // 这是非常危险的操作,必须慎之又慎,因为函数将进行如下的同步操作: // * 自动检测和创建表,这个检测是根据表的名字 // * 自动检测和新增表中的字段,这个检测是根据字段名,同时对表中多余的字段给出警告信息 // * 自动检测,创建和删除索引和唯一索引,这个检测是根据索引的一个或多个字段名,而不根据索引名称。因此这里需要注意,如果在一个有大量数据的表中引入新的索引,数据库可能需要一定的时间来建立索引。 // * 自动转换varchar字段类型到text字段类型,自动警告其它字段类型在模型和数据库之间不一致的情况。 // * 自动警告字段的默认值,是否为空信息在模型和数据库之间不匹配的情况 // // 以上这些警告信息需要将日志的显示级别调整为Warn级别才会显示。 func Sync(beans ...any) error { if env.Bool("DB_CODE_FIRST") { return DB().AutoMigrate(beans...) } return ErrNoCodeFirst } // Ping ping 一下数据库连接 func Ping() error { raw, err := DB().DB() if err != nil { return err } return raw.Ping() } // Stats 返回数据库统计信息 func Stats() (*sql.DBStats, error) { raw, err := DB().DB() if err != nil { return nil, err } stats := raw.Stats() return &stats, nil } // Now 这是个工具函数,返回当前时间 func Now() time.Time { return DB().Config.NowFunc() } // Session 会话模式 // // 在该模式下会创建并缓存预编译语句,从而提高后续的调用速度 func Session(config *SessionConfig) *gorm.DB { return DB().Session(config) } // Model 通过模型进行下一步操作 func Model(value any) *gorm.DB { return DB().Model(value) } // Table 通过数据表面进行下一步操作 func Table(name string, args ...any) *gorm.DB { return DB().Table(name, args...) } // Create 通过模型创建记录 // // 使用模型创建一条记录: // // user := User{Name: "Jinzhu", Age: 18, Birthday: time.Now()} // ok, err := db.Create(&user) // 通过数据的指针来创建 // // 我们还可以使用模型切边创建多项记录: // // users := []*User{ // User{Name: "Jinzhu", Age: 18, Birthday: time.Now()}, // User{Name: "Jackson", Age: 19, Birthday: time.Now()}, // } // ok, err := db.Create(users) // 通过 slice 创建多条记录 func Create(value any) (bool, error) { result := DB().Create(value) if err := result.Error; err != nil { return false, err } return result.RowsAffected > 0, nil } // Save 保存模型数据,由以下两点需要注意: // // - 该函数会保存所有的字段,即使字段是零值。 // - 如果模型中的主键值是零值,将会创建该数据。 func Save(value any) (bool, error) { result := DB().Save(value) if err := result.Error; err != nil { return false, err } return result.RowsAffected > 0, nil } func Upsert(bean any, conflict clause.OnConflict) (bool, error) { result := DB().Clauses(conflict).Create(bean) if err := result.Error; err != nil { return false, err } return result.RowsAffected > 0, nil } // Transaction 自动事务管理 // // 如果在 fc 中开启了新的事务,必须确保这个内嵌的事务被提交或被回滚。 func Transaction(fc func(tx *gorm.DB) error, opts ...*sql.TxOptions) error { return DB().Transaction(fc, opts...) } // Begin 开启事务 // // 使用示例 // // tx := db.Begin() // 开始事务 // tx.Create() // 执行一些数据库操作 // tx.Rollback() // 遇到错误时回滚事务 // tx.Commit() // 否则,提交事务 // // 事务一旦开始,就应该使用返回的 tx 对象处理数据 func Begin(opts ...*sql.TxOptions) (tx *gorm.DB) { return DB().Begin(opts...) } // Raw 执行 SQL 查询语句 func Raw(sql string, values ...any) *gorm.DB { return DB().Raw(sql, values...) } // Exec 执行Insert, Update, Delete 等命令的 SQL 语句, // 如果需要查询数据请使用 Query 函数 func Exec(sql string, values ...any) *gorm.DB { return DB().Exec(sql, values...) } func Unscoped() *gorm.DB { return DB().Unscoped() } // Migrator 返回迁移接口 func Migrator() gorm.Migrator { return DB().Migrator() }