You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
162 lines
3.7 KiB
162 lines
3.7 KiB
package db
|
|
|
|
import (
|
|
"github.com/go-gormigrate/gormigrate/v2"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// Migrator runs gormigrate migrations (or rollbacks) inside a database
// transaction guarded by a schema lock. It works either on the public tables
// (TenantId == 0) or on a single tenant (TenantId > 0).
type Migrator struct {
	// Schema is the schema name passed to CreateSchema before the transaction
	// starts and to LockSchema inside it.
	Schema string
	// TenantId selects the tenant to switch to via UseTenant; 0 means the
	// public tables are migrated and no tenant switch happens.
	TenantId uint
	// DB is the connection used to create the schema and begin the transaction.
	DB *gorm.DB
	// Migrations is the list handed to gormigrate.New.
	Migrations []*gormigrate.Migration
	// Log receives printf-style progress and error messages.
	Log func(string, ...any)
}
|
|
|
|
func (m *Migrator) Migrate() error {
|
|
return m.migrate(func(migrator *gormigrate.Gormigrate) error {
|
|
return migrator.Migrate()
|
|
})
|
|
}
|
|
|
|
func (m *Migrator) MigrateTo(migrationID string) error {
|
|
return m.migrate(func(migrator *gormigrate.Gormigrate) error {
|
|
return migrator.MigrateTo(migrationID)
|
|
})
|
|
}
|
|
|
|
func (m *Migrator) migrate(fn func(migrator *gormigrate.Gormigrate) error) error {
|
|
if m.TenantId > 0 {
|
|
m.Log("⏳ migrating tables for tenant %d", m.TenantId)
|
|
} else {
|
|
m.Log("⏳ migrating public tables")
|
|
}
|
|
|
|
if err := CreateSchema(m.DB, m.Schema); err != nil {
|
|
if m.TenantId > 0 {
|
|
m.Log("❌ failed to create schema for tenant %d: %s", m.TenantId, err)
|
|
} else {
|
|
m.Log("❌ failed to create public schema: %s", err)
|
|
}
|
|
return err
|
|
}
|
|
|
|
tx := m.DB.Begin()
|
|
defer func() {
|
|
if tx.Error == nil {
|
|
tx.Commit()
|
|
if m.TenantId > 0 {
|
|
m.Log("✅ private tables migrated for tenant %d", m.TenantId)
|
|
} else {
|
|
m.Log("✅ public tables migrated for all tenants")
|
|
}
|
|
} else {
|
|
tx.Rollback()
|
|
}
|
|
}()
|
|
|
|
unlock, err := LockSchema(tx, m.Schema)
|
|
if err != nil {
|
|
m.Log("❌ failed to acquire advisory lock: %w", err)
|
|
return err
|
|
}
|
|
defer unlock()
|
|
|
|
reset := func() error { return nil }
|
|
if m.TenantId > 0 {
|
|
reset, err = UseTenant(tx, m.TenantId)
|
|
if err != nil {
|
|
m.Log("❌ failed to switch schema for tenant %d: %w", m.TenantId, err)
|
|
return err
|
|
}
|
|
}
|
|
defer reset()
|
|
|
|
err = fn(gormigrate.New(tx, &gormigrate.Options{
|
|
TableName: tx.NamingStrategy.TableName("migrations"),
|
|
}, m.Migrations))
|
|
if err != nil {
|
|
if m.TenantId > 0 {
|
|
m.Log("❌ failed to call migrate for tenant %d: %w", m.TenantId, err)
|
|
} else {
|
|
m.Log("❌ failed to migrate public tables: %w", err)
|
|
}
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *Migrator) RollbackLast() error {
|
|
return m.rollback(func(migrator *gormigrate.Gormigrate) error {
|
|
return migrator.RollbackLast()
|
|
})
|
|
}
|
|
|
|
func (m *Migrator) RollbackTo(migrationID string) error {
|
|
return m.rollback(func(migrator *gormigrate.Gormigrate) error {
|
|
return migrator.MigrateTo(migrationID)
|
|
})
|
|
}
|
|
|
|
func (m *Migrator) rollback(fn func(migrator *gormigrate.Gormigrate) error) error {
|
|
if m.TenantId > 0 {
|
|
m.Log("⏳ rollbacking tables for tenant %d", m.TenantId)
|
|
} else {
|
|
m.Log("⏳ rollbacking public tables")
|
|
}
|
|
|
|
if err := CreateSchema(m.DB, m.Schema); err != nil {
|
|
if m.TenantId > 0 {
|
|
m.Log("❌ failed to create schema for tenant %d: %s", m.TenantId, err)
|
|
} else {
|
|
m.Log("❌ failed to create public schema: %s", err)
|
|
}
|
|
return err
|
|
}
|
|
|
|
tx := m.DB.Begin()
|
|
defer func() {
|
|
if tx.Error == nil {
|
|
tx.Commit()
|
|
if m.TenantId > 0 {
|
|
m.Log("✅ private tables rollbacked for tenant %d", m.TenantId)
|
|
} else {
|
|
m.Log("✅ public tables rollbacked for all tenants")
|
|
}
|
|
} else {
|
|
tx.Rollback()
|
|
}
|
|
}()
|
|
|
|
unlock, err := LockSchema(tx, m.Schema)
|
|
if err != nil {
|
|
m.Log("❌ failed to acquire advisory lock: %w", err)
|
|
return err
|
|
}
|
|
defer unlock()
|
|
|
|
reset := func() error { return nil }
|
|
if m.TenantId > 0 {
|
|
reset, err = UseTenant(tx, m.TenantId)
|
|
if err != nil {
|
|
m.Log("❌ failed to switch schema for tenant %d: %w", m.TenantId, err)
|
|
return err
|
|
}
|
|
}
|
|
defer reset()
|
|
|
|
err = fn(gormigrate.New(tx, &gormigrate.Options{
|
|
TableName: tx.NamingStrategy.TableName("migrations"),
|
|
}, m.Migrations))
|
|
if err != nil {
|
|
if m.TenantId > 0 {
|
|
m.Log("❌ failed to call rollback for tenant %d: %w", m.TenantId, err)
|
|
} else {
|
|
m.Log("❌ failed to rollback public tables: %w", err)
|
|
}
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|