migrator.go (9493B)
1 package migrate 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "os" 8 "path/filepath" 9 "regexp" 10 "time" 11 12 "github.com/uptrace/bun" 13 ) 14 15 type MigratorOption func(m *Migrator) 16 17 func WithTableName(table string) MigratorOption { 18 return func(m *Migrator) { 19 m.table = table 20 } 21 } 22 23 func WithLocksTableName(table string) MigratorOption { 24 return func(m *Migrator) { 25 m.locksTable = table 26 } 27 } 28 29 // WithMarkAppliedOnSuccess sets the migrator to only mark migrations as applied/unapplied 30 // when their up/down is successful 31 func WithMarkAppliedOnSuccess(enabled bool) MigratorOption { 32 return func(m *Migrator) { 33 m.markAppliedOnSuccess = enabled 34 } 35 } 36 37 type Migrator struct { 38 db *bun.DB 39 migrations *Migrations 40 41 ms MigrationSlice 42 43 table string 44 locksTable string 45 markAppliedOnSuccess bool 46 } 47 48 func NewMigrator(db *bun.DB, migrations *Migrations, opts ...MigratorOption) *Migrator { 49 m := &Migrator{ 50 db: db, 51 migrations: migrations, 52 53 ms: migrations.ms, 54 55 table: "bun_migrations", 56 locksTable: "bun_migration_locks", 57 } 58 for _, opt := range opts { 59 opt(m) 60 } 61 return m 62 } 63 64 func (m *Migrator) DB() *bun.DB { 65 return m.db 66 } 67 68 // MigrationsWithStatus returns migrations with status in ascending order. 69 func (m *Migrator) MigrationsWithStatus(ctx context.Context) (MigrationSlice, error) { 70 sorted, _, err := m.migrationsWithStatus(ctx) 71 return sorted, err 72 } 73 74 func (m *Migrator) migrationsWithStatus(ctx context.Context) (MigrationSlice, int64, error) { 75 sorted := m.migrations.Sorted() 76 77 applied, err := m.AppliedMigrations(ctx) 78 if err != nil { 79 return nil, 0, err 80 } 81 82 appliedMap := migrationMap(applied) 83 for i := range sorted { 84 m1 := &sorted[i] 85 if m2, ok := appliedMap[m1.Name]; ok { 86 m1.ID = m2.ID 87 m1.GroupID = m2.GroupID 88 m1.MigratedAt = m2.MigratedAt 89 } 90 } 91 92 return sorted, applied.LastGroupID(), nil 93 } 94 95 func (m *Migrator) Init(ctx context.Context) error { 96 if _, err := m.db.NewCreateTable(). 97 Model((*Migration)(nil)). 98 ModelTableExpr(m.table). 99 IfNotExists(). 100 Exec(ctx); err != nil { 101 return err 102 } 103 if _, err := m.db.NewCreateTable(). 104 Model((*migrationLock)(nil)). 105 ModelTableExpr(m.locksTable). 106 IfNotExists(). 107 Exec(ctx); err != nil { 108 return err 109 } 110 return nil 111 } 112 113 func (m *Migrator) Reset(ctx context.Context) error { 114 if _, err := m.db.NewDropTable(). 115 Model((*Migration)(nil)). 116 ModelTableExpr(m.table). 117 IfExists(). 118 Exec(ctx); err != nil { 119 return err 120 } 121 if _, err := m.db.NewDropTable(). 122 Model((*migrationLock)(nil)). 123 ModelTableExpr(m.locksTable). 124 IfExists(). 125 Exec(ctx); err != nil { 126 return err 127 } 128 return m.Init(ctx) 129 } 130 131 // Migrate runs unapplied migrations. If a migration fails, migrate immediately exits. 132 func (m *Migrator) Migrate(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) { 133 cfg := newMigrationConfig(opts) 134 135 if err := m.validate(); err != nil { 136 return nil, err 137 } 138 139 migrations, lastGroupID, err := m.migrationsWithStatus(ctx) 140 if err != nil { 141 return nil, err 142 } 143 migrations = migrations.Unapplied() 144 145 group := new(MigrationGroup) 146 if len(migrations) == 0 { 147 return group, nil 148 } 149 group.ID = lastGroupID + 1 150 151 for i := range migrations { 152 migration := &migrations[i] 153 migration.GroupID = group.ID 154 155 if !m.markAppliedOnSuccess { 156 if err := m.MarkApplied(ctx, migration); err != nil { 157 return group, err 158 } 159 } 160 161 group.Migrations = migrations[:i+1] 162 163 if !cfg.nop && migration.Up != nil { 164 if err := migration.Up(ctx, m.db); err != nil { 165 return group, err 166 } 167 } 168 169 if m.markAppliedOnSuccess { 170 if err := m.MarkApplied(ctx, migration); err != nil { 171 return group, err 172 } 173 } 174 } 175 176 return group, nil 177 } 178 179 func (m *Migrator) Rollback(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) { 180 cfg := newMigrationConfig(opts) 181 182 if err := m.validate(); err != nil { 183 return nil, err 184 } 185 186 migrations, err := m.MigrationsWithStatus(ctx) 187 if err != nil { 188 return nil, err 189 } 190 191 lastGroup := migrations.LastGroup() 192 193 for i := len(lastGroup.Migrations) - 1; i >= 0; i-- { 194 migration := &lastGroup.Migrations[i] 195 196 if !m.markAppliedOnSuccess { 197 if err := m.MarkUnapplied(ctx, migration); err != nil { 198 return lastGroup, err 199 } 200 } 201 202 if !cfg.nop && migration.Down != nil { 203 if err := migration.Down(ctx, m.db); err != nil { 204 return lastGroup, err 205 } 206 } 207 208 if m.markAppliedOnSuccess { 209 if err := m.MarkUnapplied(ctx, migration); err != nil { 210 return lastGroup, err 211 } 212 } 213 } 214 215 return lastGroup, nil 216 } 217 218 type goMigrationConfig struct { 219 packageName string 220 goTemplate string 221 } 222 223 type GoMigrationOption func(cfg *goMigrationConfig) 224 225 func WithPackageName(name string) GoMigrationOption { 226 return func(cfg *goMigrationConfig) { 227 cfg.packageName = name 228 } 229 } 230 231 func WithGoTemplate(template string) GoMigrationOption { 232 return func(cfg *goMigrationConfig) { 233 cfg.goTemplate = template 234 } 235 } 236 237 // CreateGoMigration creates a Go migration file. 238 func (m *Migrator) CreateGoMigration( 239 ctx context.Context, name string, opts ...GoMigrationOption, 240 ) (*MigrationFile, error) { 241 cfg := &goMigrationConfig{ 242 packageName: "migrations", 243 goTemplate: goTemplate, 244 } 245 for _, opt := range opts { 246 opt(cfg) 247 } 248 249 name, err := m.genMigrationName(name) 250 if err != nil { 251 return nil, err 252 } 253 254 fname := name + ".go" 255 fpath := filepath.Join(m.migrations.getDirectory(), fname) 256 content := fmt.Sprintf(cfg.goTemplate, cfg.packageName) 257 258 if err := os.WriteFile(fpath, []byte(content), 0o644); err != nil { 259 return nil, err 260 } 261 262 mf := &MigrationFile{ 263 Name: fname, 264 Path: fpath, 265 Content: content, 266 } 267 return mf, nil 268 } 269 270 // CreateSQLMigrations creates an up and down SQL migration files. 271 func (m *Migrator) CreateSQLMigrations(ctx context.Context, name string) ([]*MigrationFile, error) { 272 name, err := m.genMigrationName(name) 273 if err != nil { 274 return nil, err 275 } 276 277 up, err := m.createSQL(ctx, name+".up.sql") 278 if err != nil { 279 return nil, err 280 } 281 282 down, err := m.createSQL(ctx, name+".down.sql") 283 if err != nil { 284 return nil, err 285 } 286 287 return []*MigrationFile{up, down}, nil 288 } 289 290 func (m *Migrator) createSQL(ctx context.Context, fname string) (*MigrationFile, error) { 291 fpath := filepath.Join(m.migrations.getDirectory(), fname) 292 293 if err := os.WriteFile(fpath, []byte(sqlTemplate), 0o644); err != nil { 294 return nil, err 295 } 296 297 mf := &MigrationFile{ 298 Name: fname, 299 Path: fpath, 300 Content: goTemplate, 301 } 302 return mf, nil 303 } 304 305 var nameRE = regexp.MustCompile(`^[0-9a-z_\-]+$`) 306 307 func (m *Migrator) genMigrationName(name string) (string, error) { 308 const timeFormat = "20060102150405" 309 310 if name == "" { 311 return "", errors.New("migrate: migration name can't be empty") 312 } 313 if !nameRE.MatchString(name) { 314 return "", fmt.Errorf("migrate: invalid migration name: %q", name) 315 } 316 317 version := time.Now().UTC().Format(timeFormat) 318 return fmt.Sprintf("%s_%s", version, name), nil 319 } 320 321 // MarkApplied marks the migration as applied (completed). 322 func (m *Migrator) MarkApplied(ctx context.Context, migration *Migration) error { 323 _, err := m.db.NewInsert().Model(migration). 324 ModelTableExpr(m.table). 325 Exec(ctx) 326 return err 327 } 328 329 // MarkUnapplied marks the migration as unapplied (new). 330 func (m *Migrator) MarkUnapplied(ctx context.Context, migration *Migration) error { 331 _, err := m.db.NewDelete(). 332 Model(migration). 333 ModelTableExpr(m.table). 334 Where("id = ?", migration.ID). 335 Exec(ctx) 336 return err 337 } 338 339 func (m *Migrator) TruncateTable(ctx context.Context) error { 340 _, err := m.db.NewTruncateTable().TableExpr(m.table).Exec(ctx) 341 return err 342 } 343 344 // MissingMigrations returns applied migrations that can no longer be found. 345 func (m *Migrator) MissingMigrations(ctx context.Context) (MigrationSlice, error) { 346 applied, err := m.AppliedMigrations(ctx) 347 if err != nil { 348 return nil, err 349 } 350 351 existing := migrationMap(m.migrations.ms) 352 for i := len(applied) - 1; i >= 0; i-- { 353 m := &applied[i] 354 if _, ok := existing[m.Name]; ok { 355 applied = append(applied[:i], applied[i+1:]...) 356 } 357 } 358 359 return applied, nil 360 } 361 362 // AppliedMigrations selects applied (applied) migrations in descending order. 363 func (m *Migrator) AppliedMigrations(ctx context.Context) (MigrationSlice, error) { 364 var ms MigrationSlice 365 if err := m.db.NewSelect(). 366 ColumnExpr("*"). 367 Model(&ms). 368 ModelTableExpr(m.table). 369 Scan(ctx); err != nil { 370 return nil, err 371 } 372 return ms, nil 373 } 374 375 func (m *Migrator) formattedTableName(db *bun.DB) string { 376 return db.Formatter().FormatQuery(m.table) 377 } 378 379 func (m *Migrator) validate() error { 380 if len(m.ms) == 0 { 381 return errors.New("migrate: there are no migrations") 382 } 383 return nil 384 } 385 386 //------------------------------------------------------------------------------ 387 388 type migrationLock struct { 389 ID int64 `bun:",pk,autoincrement"` 390 TableName string `bun:",unique"` 391 } 392 393 func (m *Migrator) Lock(ctx context.Context) error { 394 lock := &migrationLock{ 395 TableName: m.formattedTableName(m.db), 396 } 397 if _, err := m.db.NewInsert(). 398 Model(lock). 399 ModelTableExpr(m.locksTable). 400 Exec(ctx); err != nil { 401 return fmt.Errorf("migrate: migrations table is already locked (%w)", err) 402 } 403 return nil 404 } 405 406 func (m *Migrator) Unlock(ctx context.Context) error { 407 tableName := m.formattedTableName(m.db) 408 _, err := m.db.NewDelete(). 409 Model((*migrationLock)(nil)). 410 ModelTableExpr(m.locksTable). 411 Where("? = ?", bun.Ident("table_name"), tableName). 412 Exec(ctx) 413 return err 414 } 415 416 func migrationMap(ms MigrationSlice) map[string]*Migration { 417 mp := make(map[string]*Migration) 418 for i := range ms { 419 m := &ms[i] 420 mp[m.Name] = m 421 } 422 return mp 423 }