migrations.go (3304B)
1 package migrate 2 3 import ( 4 "errors" 5 "fmt" 6 "io/fs" 7 "os" 8 "path/filepath" 9 "regexp" 10 "runtime" 11 "strings" 12 ) 13 14 type MigrationsOption func(m *Migrations) 15 16 func WithMigrationsDirectory(directory string) MigrationsOption { 17 return func(m *Migrations) { 18 m.explicitDirectory = directory 19 } 20 } 21 22 type Migrations struct { 23 ms MigrationSlice 24 25 explicitDirectory string 26 implicitDirectory string 27 } 28 29 func NewMigrations(opts ...MigrationsOption) *Migrations { 30 m := new(Migrations) 31 for _, opt := range opts { 32 opt(m) 33 } 34 m.implicitDirectory = filepath.Dir(migrationFile()) 35 return m 36 } 37 38 func (m *Migrations) Sorted() MigrationSlice { 39 migrations := make(MigrationSlice, len(m.ms)) 40 copy(migrations, m.ms) 41 sortAsc(migrations) 42 return migrations 43 } 44 45 func (m *Migrations) MustRegister(up, down MigrationFunc) { 46 if err := m.Register(up, down); err != nil { 47 panic(err) 48 } 49 } 50 51 func (m *Migrations) Register(up, down MigrationFunc) error { 52 fpath := migrationFile() 53 name, comment, err := extractMigrationName(fpath) 54 if err != nil { 55 return err 56 } 57 58 m.Add(Migration{ 59 Name: name, 60 Comment: comment, 61 Up: up, 62 Down: down, 63 }) 64 65 return nil 66 } 67 68 func (m *Migrations) Add(migration Migration) { 69 if migration.Name == "" { 70 panic("migration name is required") 71 } 72 m.ms = append(m.ms, migration) 73 } 74 75 func (m *Migrations) DiscoverCaller() error { 76 dir := filepath.Dir(migrationFile()) 77 return m.Discover(os.DirFS(dir)) 78 } 79 80 func (m *Migrations) Discover(fsys fs.FS) error { 81 return fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error { 82 if err != nil { 83 return err 84 } 85 if d.IsDir() { 86 return nil 87 } 88 89 if !strings.HasSuffix(path, ".up.sql") && !strings.HasSuffix(path, ".down.sql") { 90 return nil 91 } 92 93 name, comment, err := extractMigrationName(path) 94 if err != nil { 95 return err 96 } 97 98 migration := m.getOrCreateMigration(name) 99 if err != nil { 100 return err 101 } 102 103 migration.Comment = comment 104 migrationFunc := NewSQLMigrationFunc(fsys, path) 105 106 if strings.HasSuffix(path, ".up.sql") { 107 migration.Up = migrationFunc 108 return nil 109 } 110 if strings.HasSuffix(path, ".down.sql") { 111 migration.Down = migrationFunc 112 return nil 113 } 114 115 return errors.New("migrate: not reached") 116 }) 117 } 118 119 func (m *Migrations) getOrCreateMigration(name string) *Migration { 120 for i := range m.ms { 121 m := &m.ms[i] 122 if m.Name == name { 123 return m 124 } 125 } 126 127 m.ms = append(m.ms, Migration{Name: name}) 128 return &m.ms[len(m.ms)-1] 129 } 130 131 func (m *Migrations) getDirectory() string { 132 if m.explicitDirectory != "" { 133 return m.explicitDirectory 134 } 135 if m.implicitDirectory != "" { 136 return m.implicitDirectory 137 } 138 return filepath.Dir(migrationFile()) 139 } 140 141 func migrationFile() string { 142 const depth = 32 143 var pcs [depth]uintptr 144 n := runtime.Callers(1, pcs[:]) 145 frames := runtime.CallersFrames(pcs[:n]) 146 147 for { 148 f, ok := frames.Next() 149 if !ok { 150 break 151 } 152 if !strings.Contains(f.Function, "/bun/migrate.") { 153 return f.File 154 } 155 } 156 157 return "" 158 } 159 160 var fnameRE = regexp.MustCompile(`^(\d{1,14})_([0-9a-z_\-]+)\.`) 161 162 func extractMigrationName(fpath string) (string, string, error) { 163 fname := filepath.Base(fpath) 164 165 matches := fnameRE.FindStringSubmatch(fname) 166 if matches == nil { 167 return "", "", fmt.Errorf("migrate: unsupported migration name format: %q", fname) 168 } 169 170 return matches[1], matches[2], nil 171 }