gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

migrator.go (9493B)


      1 package migrate
      2 
      3 import (
      4 	"context"
      5 	"errors"
      6 	"fmt"
      7 	"os"
      8 	"path/filepath"
      9 	"regexp"
     10 	"time"
     11 
     12 	"github.com/uptrace/bun"
     13 )
     14 
     15 type MigratorOption func(m *Migrator)
     16 
     17 func WithTableName(table string) MigratorOption {
     18 	return func(m *Migrator) {
     19 		m.table = table
     20 	}
     21 }
     22 
     23 func WithLocksTableName(table string) MigratorOption {
     24 	return func(m *Migrator) {
     25 		m.locksTable = table
     26 	}
     27 }
     28 
     29 // WithMarkAppliedOnSuccess sets the migrator to only mark migrations as applied/unapplied
     30 // when their up/down is successful
     31 func WithMarkAppliedOnSuccess(enabled bool) MigratorOption {
     32 	return func(m *Migrator) {
     33 		m.markAppliedOnSuccess = enabled
     34 	}
     35 }
     36 
     37 type Migrator struct {
     38 	db         *bun.DB
     39 	migrations *Migrations
     40 
     41 	ms MigrationSlice
     42 
     43 	table                string
     44 	locksTable           string
     45 	markAppliedOnSuccess bool
     46 }
     47 
     48 func NewMigrator(db *bun.DB, migrations *Migrations, opts ...MigratorOption) *Migrator {
     49 	m := &Migrator{
     50 		db:         db,
     51 		migrations: migrations,
     52 
     53 		ms: migrations.ms,
     54 
     55 		table:      "bun_migrations",
     56 		locksTable: "bun_migration_locks",
     57 	}
     58 	for _, opt := range opts {
     59 		opt(m)
     60 	}
     61 	return m
     62 }
     63 
     64 func (m *Migrator) DB() *bun.DB {
     65 	return m.db
     66 }
     67 
     68 // MigrationsWithStatus returns migrations with status in ascending order.
     69 func (m *Migrator) MigrationsWithStatus(ctx context.Context) (MigrationSlice, error) {
     70 	sorted, _, err := m.migrationsWithStatus(ctx)
     71 	return sorted, err
     72 }
     73 
     74 func (m *Migrator) migrationsWithStatus(ctx context.Context) (MigrationSlice, int64, error) {
     75 	sorted := m.migrations.Sorted()
     76 
     77 	applied, err := m.AppliedMigrations(ctx)
     78 	if err != nil {
     79 		return nil, 0, err
     80 	}
     81 
     82 	appliedMap := migrationMap(applied)
     83 	for i := range sorted {
     84 		m1 := &sorted[i]
     85 		if m2, ok := appliedMap[m1.Name]; ok {
     86 			m1.ID = m2.ID
     87 			m1.GroupID = m2.GroupID
     88 			m1.MigratedAt = m2.MigratedAt
     89 		}
     90 	}
     91 
     92 	return sorted, applied.LastGroupID(), nil
     93 }
     94 
     95 func (m *Migrator) Init(ctx context.Context) error {
     96 	if _, err := m.db.NewCreateTable().
     97 		Model((*Migration)(nil)).
     98 		ModelTableExpr(m.table).
     99 		IfNotExists().
    100 		Exec(ctx); err != nil {
    101 		return err
    102 	}
    103 	if _, err := m.db.NewCreateTable().
    104 		Model((*migrationLock)(nil)).
    105 		ModelTableExpr(m.locksTable).
    106 		IfNotExists().
    107 		Exec(ctx); err != nil {
    108 		return err
    109 	}
    110 	return nil
    111 }
    112 
    113 func (m *Migrator) Reset(ctx context.Context) error {
    114 	if _, err := m.db.NewDropTable().
    115 		Model((*Migration)(nil)).
    116 		ModelTableExpr(m.table).
    117 		IfExists().
    118 		Exec(ctx); err != nil {
    119 		return err
    120 	}
    121 	if _, err := m.db.NewDropTable().
    122 		Model((*migrationLock)(nil)).
    123 		ModelTableExpr(m.locksTable).
    124 		IfExists().
    125 		Exec(ctx); err != nil {
    126 		return err
    127 	}
    128 	return m.Init(ctx)
    129 }
    130 
    131 // Migrate runs unapplied migrations. If a migration fails, migrate immediately exits.
    132 func (m *Migrator) Migrate(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) {
    133 	cfg := newMigrationConfig(opts)
    134 
    135 	if err := m.validate(); err != nil {
    136 		return nil, err
    137 	}
    138 
    139 	migrations, lastGroupID, err := m.migrationsWithStatus(ctx)
    140 	if err != nil {
    141 		return nil, err
    142 	}
    143 	migrations = migrations.Unapplied()
    144 
    145 	group := new(MigrationGroup)
    146 	if len(migrations) == 0 {
    147 		return group, nil
    148 	}
    149 	group.ID = lastGroupID + 1
    150 
    151 	for i := range migrations {
    152 		migration := &migrations[i]
    153 		migration.GroupID = group.ID
    154 
    155 		if !m.markAppliedOnSuccess {
    156 			if err := m.MarkApplied(ctx, migration); err != nil {
    157 				return group, err
    158 			}
    159 		}
    160 
    161 		group.Migrations = migrations[:i+1]
    162 
    163 		if !cfg.nop && migration.Up != nil {
    164 			if err := migration.Up(ctx, m.db); err != nil {
    165 				return group, err
    166 			}
    167 		}
    168 
    169 		if m.markAppliedOnSuccess {
    170 			if err := m.MarkApplied(ctx, migration); err != nil {
    171 				return group, err
    172 			}
    173 		}
    174 	}
    175 
    176 	return group, nil
    177 }
    178 
    179 func (m *Migrator) Rollback(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) {
    180 	cfg := newMigrationConfig(opts)
    181 
    182 	if err := m.validate(); err != nil {
    183 		return nil, err
    184 	}
    185 
    186 	migrations, err := m.MigrationsWithStatus(ctx)
    187 	if err != nil {
    188 		return nil, err
    189 	}
    190 
    191 	lastGroup := migrations.LastGroup()
    192 
    193 	for i := len(lastGroup.Migrations) - 1; i >= 0; i-- {
    194 		migration := &lastGroup.Migrations[i]
    195 
    196 		if !m.markAppliedOnSuccess {
    197 			if err := m.MarkUnapplied(ctx, migration); err != nil {
    198 				return lastGroup, err
    199 			}
    200 		}
    201 
    202 		if !cfg.nop && migration.Down != nil {
    203 			if err := migration.Down(ctx, m.db); err != nil {
    204 				return lastGroup, err
    205 			}
    206 		}
    207 
    208 		if m.markAppliedOnSuccess {
    209 			if err := m.MarkUnapplied(ctx, migration); err != nil {
    210 				return lastGroup, err
    211 			}
    212 		}
    213 	}
    214 
    215 	return lastGroup, nil
    216 }
    217 
    218 type goMigrationConfig struct {
    219 	packageName string
    220 	goTemplate  string
    221 }
    222 
    223 type GoMigrationOption func(cfg *goMigrationConfig)
    224 
    225 func WithPackageName(name string) GoMigrationOption {
    226 	return func(cfg *goMigrationConfig) {
    227 		cfg.packageName = name
    228 	}
    229 }
    230 
    231 func WithGoTemplate(template string) GoMigrationOption {
    232 	return func(cfg *goMigrationConfig) {
    233 		cfg.goTemplate = template
    234 	}
    235 }
    236 
    237 // CreateGoMigration creates a Go migration file.
    238 func (m *Migrator) CreateGoMigration(
    239 	ctx context.Context, name string, opts ...GoMigrationOption,
    240 ) (*MigrationFile, error) {
    241 	cfg := &goMigrationConfig{
    242 		packageName: "migrations",
    243 		goTemplate:  goTemplate,
    244 	}
    245 	for _, opt := range opts {
    246 		opt(cfg)
    247 	}
    248 
    249 	name, err := m.genMigrationName(name)
    250 	if err != nil {
    251 		return nil, err
    252 	}
    253 
    254 	fname := name + ".go"
    255 	fpath := filepath.Join(m.migrations.getDirectory(), fname)
    256 	content := fmt.Sprintf(cfg.goTemplate, cfg.packageName)
    257 
    258 	if err := os.WriteFile(fpath, []byte(content), 0o644); err != nil {
    259 		return nil, err
    260 	}
    261 
    262 	mf := &MigrationFile{
    263 		Name:    fname,
    264 		Path:    fpath,
    265 		Content: content,
    266 	}
    267 	return mf, nil
    268 }
    269 
    270 // CreateSQLMigrations creates an up and down SQL migration files.
    271 func (m *Migrator) CreateSQLMigrations(ctx context.Context, name string) ([]*MigrationFile, error) {
    272 	name, err := m.genMigrationName(name)
    273 	if err != nil {
    274 		return nil, err
    275 	}
    276 
    277 	up, err := m.createSQL(ctx, name+".up.sql")
    278 	if err != nil {
    279 		return nil, err
    280 	}
    281 
    282 	down, err := m.createSQL(ctx, name+".down.sql")
    283 	if err != nil {
    284 		return nil, err
    285 	}
    286 
    287 	return []*MigrationFile{up, down}, nil
    288 }
    289 
    290 func (m *Migrator) createSQL(ctx context.Context, fname string) (*MigrationFile, error) {
    291 	fpath := filepath.Join(m.migrations.getDirectory(), fname)
    292 
    293 	if err := os.WriteFile(fpath, []byte(sqlTemplate), 0o644); err != nil {
    294 		return nil, err
    295 	}
    296 
    297 	mf := &MigrationFile{
    298 		Name:    fname,
    299 		Path:    fpath,
    300 		Content: goTemplate,
    301 	}
    302 	return mf, nil
    303 }
    304 
    305 var nameRE = regexp.MustCompile(`^[0-9a-z_\-]+$`)
    306 
    307 func (m *Migrator) genMigrationName(name string) (string, error) {
    308 	const timeFormat = "20060102150405"
    309 
    310 	if name == "" {
    311 		return "", errors.New("migrate: migration name can't be empty")
    312 	}
    313 	if !nameRE.MatchString(name) {
    314 		return "", fmt.Errorf("migrate: invalid migration name: %q", name)
    315 	}
    316 
    317 	version := time.Now().UTC().Format(timeFormat)
    318 	return fmt.Sprintf("%s_%s", version, name), nil
    319 }
    320 
    321 // MarkApplied marks the migration as applied (completed).
    322 func (m *Migrator) MarkApplied(ctx context.Context, migration *Migration) error {
    323 	_, err := m.db.NewInsert().Model(migration).
    324 		ModelTableExpr(m.table).
    325 		Exec(ctx)
    326 	return err
    327 }
    328 
    329 // MarkUnapplied marks the migration as unapplied (new).
    330 func (m *Migrator) MarkUnapplied(ctx context.Context, migration *Migration) error {
    331 	_, err := m.db.NewDelete().
    332 		Model(migration).
    333 		ModelTableExpr(m.table).
    334 		Where("id = ?", migration.ID).
    335 		Exec(ctx)
    336 	return err
    337 }
    338 
    339 func (m *Migrator) TruncateTable(ctx context.Context) error {
    340 	_, err := m.db.NewTruncateTable().TableExpr(m.table).Exec(ctx)
    341 	return err
    342 }
    343 
    344 // MissingMigrations returns applied migrations that can no longer be found.
    345 func (m *Migrator) MissingMigrations(ctx context.Context) (MigrationSlice, error) {
    346 	applied, err := m.AppliedMigrations(ctx)
    347 	if err != nil {
    348 		return nil, err
    349 	}
    350 
    351 	existing := migrationMap(m.migrations.ms)
    352 	for i := len(applied) - 1; i >= 0; i-- {
    353 		m := &applied[i]
    354 		if _, ok := existing[m.Name]; ok {
    355 			applied = append(applied[:i], applied[i+1:]...)
    356 		}
    357 	}
    358 
    359 	return applied, nil
    360 }
    361 
    362 // AppliedMigrations selects applied (applied) migrations in descending order.
    363 func (m *Migrator) AppliedMigrations(ctx context.Context) (MigrationSlice, error) {
    364 	var ms MigrationSlice
    365 	if err := m.db.NewSelect().
    366 		ColumnExpr("*").
    367 		Model(&ms).
    368 		ModelTableExpr(m.table).
    369 		Scan(ctx); err != nil {
    370 		return nil, err
    371 	}
    372 	return ms, nil
    373 }
    374 
    375 func (m *Migrator) formattedTableName(db *bun.DB) string {
    376 	return db.Formatter().FormatQuery(m.table)
    377 }
    378 
    379 func (m *Migrator) validate() error {
    380 	if len(m.ms) == 0 {
    381 		return errors.New("migrate: there are no migrations")
    382 	}
    383 	return nil
    384 }
    385 
    386 //------------------------------------------------------------------------------
    387 
// migrationLock is the row model for the locks table. Lock inserts one row per
// formatted migrations-table name. The unique constraint on TableName is what
// should make a second Lock fail while a row already exists.
type migrationLock struct {
	ID        int64  `bun:",pk,autoincrement"`
	TableName string `bun:",unique"`
}
    392 
    393 func (m *Migrator) Lock(ctx context.Context) error {
    394 	lock := &migrationLock{
    395 		TableName: m.formattedTableName(m.db),
    396 	}
    397 	if _, err := m.db.NewInsert().
    398 		Model(lock).
    399 		ModelTableExpr(m.locksTable).
    400 		Exec(ctx); err != nil {
    401 		return fmt.Errorf("migrate: migrations table is already locked (%w)", err)
    402 	}
    403 	return nil
    404 }
    405 
    406 func (m *Migrator) Unlock(ctx context.Context) error {
    407 	tableName := m.formattedTableName(m.db)
    408 	_, err := m.db.NewDelete().
    409 		Model((*migrationLock)(nil)).
    410 		ModelTableExpr(m.locksTable).
    411 		Where("? = ?", bun.Ident("table_name"), tableName).
    412 		Exec(ctx)
    413 	return err
    414 }
    415 
    416 func migrationMap(ms MigrationSlice) map[string]*Migration {
    417 	mp := make(map[string]*Migration)
    418 	for i := range ms {
    419 		m := &ms[i]
    420 		mp[m.Name] = m
    421 	}
    422 	return mp
    423 }