package db import ( "context" "errors" "ims/util/backoff" "strings" "gorm.io/gorm" "zestack.dev/cast" "zestack.dev/env" ) const ( ctxKey = "ims/util/db:engine" ) var ( // ErrExecSQL is returned when the migrator fails to execute SQL. ErrExecSQL = errors.New("locking: failed to execute SQL") // ErrAcquireLock is returned when the migrator fails to acquire an advisory lock. ErrAcquireLock = errors.New("locking: failed to acquire advisory lock") // ErrReleaseLock is returned when the migrator fails to release an advisory lock. ErrReleaseLock = errors.New("locking: failed to release advisory lock") // ErrSwitchSchema is returned when the migrator fails to switch to a schema. ErrSwitchSchema = errors.New("db: nested schema switching") db *gorm.DB schemaHelper SchemaHelper ) type SchemaHelper interface { PublicSchema() string TenantSchema(tenantId uint) string CurrentSchema(tx *gorm.DB) string UseSchema(tx *gorm.DB, schema string) (reset func() error, err error) LockSchema(tx *gorm.DB, schema string, retry *backoff.Options) (unlock func() error, err error) CreateSchema(tx *gorm.DB, schema string) error DropSchema(tx *gorm.DB, schema string) error } // Engine 获取数据库操作引擎 func Engine() *gorm.DB { return db } func NewContext(ctx context.Context) context.Context { return context.WithValue(ctx, ctxKey, FromContext(ctx)) } func FromContext(ctx context.Context) *gorm.DB { if ctx == nil { return Engine() } engine, ok := ctx.Value(ctxKey).(*gorm.DB) if ok && engine != nil { return engine } return Engine().WithContext(ctx) } // WithContext 派生出基于指定上下文的数据库操作引擎 func WithContext(ctx context.Context) *gorm.DB { return Engine().WithContext(ctx) } func PublicSchema() string { return schemaHelper.PublicSchema() } func TenantSchema(tenantId uint) string { return schemaHelper.TenantSchema(tenantId) } // UseTenant 使用指定的租户 // 切记:应该在事务中使用 func UseTenant(tx *gorm.DB, tenant uint) (reset func() error, err error) { return schemaHelper.UseSchema(tx, schemaHelper.TenantSchema(tenant)) } func CurrentTenant(tx *gorm.DB) (tenant uint, ok bool) { schema := schemaHelper.CurrentSchema(tx) prefix := env.String("DB_TENANT_PREFIX", "tenant_") suffix := env.String("DB_TENANT_SUFFIX") if schema != "" && prefix != "" && strings.HasPrefix(schema, prefix) { schema = strings.TrimPrefix(schema, prefix) } if schema != "" && suffix != "" && strings.HasSuffix(schema, suffix) { schema = strings.TrimSuffix(schema, suffix) } if schema == "" { return } var err error tenant, err = cast.Uint(schema) ok = err == nil && tenant > 0 return } func WithTenant(tx *gorm.DB, tenant uint, fn func(tx *gorm.DB) error) error { return tx.Transaction(func(tx *gorm.DB) error { reset, err := UseTenant(tx, tenant) if err != nil { return err } defer reset() return fn(tx) }) } func LockSchema(tx *gorm.DB, schema string, opts ...backoff.Option) (unlock func() error, err error) { if len(opts) == 0 { return schemaHelper.LockSchema(tx, schema, nil) } var retry *backoff.Options retry.Apply(opts...) return schemaHelper.LockSchema(tx, schema, retry) } func CreateSchema(tx *gorm.DB, schema string) error { return schemaHelper.CreateSchema(tx, schema) } func DropSchema(tx *gorm.DB, schema string) error { return schemaHelper.DropSchema(tx, schema) }