migration.go (5589B)
1 package migrate 2 3 import ( 4 "bufio" 5 "bytes" 6 "context" 7 "fmt" 8 "io" 9 "io/fs" 10 "sort" 11 "strings" 12 "time" 13 14 "github.com/uptrace/bun" 15 ) 16 17 type Migration struct { 18 bun.BaseModel 19 20 ID int64 `bun:",pk,autoincrement"` 21 Name string 22 Comment string `bun:"-"` 23 GroupID int64 24 MigratedAt time.Time `bun:",notnull,nullzero,default:current_timestamp"` 25 26 Up MigrationFunc `bun:"-"` 27 Down MigrationFunc `bun:"-"` 28 } 29 30 func (m Migration) String() string { 31 return fmt.Sprintf("%s_%s", m.Name, m.Comment) 32 } 33 34 func (m Migration) IsApplied() bool { 35 return m.ID > 0 36 } 37 38 type MigrationFunc func(ctx context.Context, db *bun.DB) error 39 40 func NewSQLMigrationFunc(fsys fs.FS, name string) MigrationFunc { 41 return func(ctx context.Context, db *bun.DB) error { 42 f, err := fsys.Open(name) 43 if err != nil { 44 return err 45 } 46 47 isTx := strings.HasSuffix(name, ".tx.up.sql") || strings.HasSuffix(name, ".tx.down.sql") 48 return Exec(ctx, db, f, isTx) 49 } 50 } 51 52 // Exec reads and executes the SQL migration in the f. 53 func Exec(ctx context.Context, db *bun.DB, f io.Reader, isTx bool) error { 54 scanner := bufio.NewScanner(f) 55 var queries []string 56 57 var query []byte 58 for scanner.Scan() { 59 b := scanner.Bytes() 60 61 const prefix = "--bun:" 62 if bytes.HasPrefix(b, []byte(prefix)) { 63 b = b[len(prefix):] 64 if bytes.Equal(b, []byte("split")) { 65 queries = append(queries, string(query)) 66 query = query[:0] 67 continue 68 } 69 return fmt.Errorf("bun: unknown directive: %q", b) 70 } 71 72 query = append(query, b...) 73 query = append(query, '\n') 74 } 75 76 if len(query) > 0 { 77 queries = append(queries, string(query)) 78 } 79 if err := scanner.Err(); err != nil { 80 return err 81 } 82 83 var idb bun.IConn 84 85 if isTx { 86 tx, err := db.BeginTx(ctx, nil) 87 if err != nil { 88 return err 89 } 90 idb = tx 91 } else { 92 conn, err := db.Conn(ctx) 93 if err != nil { 94 return err 95 } 96 idb = conn 97 } 98 99 var retErr error 100 var execErr error 101 102 defer func() { 103 if tx, ok := idb.(bun.Tx); ok { 104 if execErr != nil { 105 retErr = tx.Rollback() 106 } else { 107 retErr = tx.Commit() 108 } 109 return 110 } 111 112 if conn, ok := idb.(bun.Conn); ok { 113 retErr = conn.Close() 114 return 115 } 116 117 panic("not reached") 118 }() 119 120 for _, q := range queries { 121 _, execErr = idb.ExecContext(ctx, q) 122 if execErr != nil { 123 return execErr 124 } 125 } 126 127 return retErr 128 } 129 130 const goTemplate = `package %s 131 132 import ( 133 "context" 134 "fmt" 135 136 "github.com/uptrace/bun" 137 ) 138 139 func init() { 140 Migrations.MustRegister(func(ctx context.Context, db *bun.DB) error { 141 fmt.Print(" [up migration] ") 142 return nil 143 }, func(ctx context.Context, db *bun.DB) error { 144 fmt.Print(" [down migration] ") 145 return nil 146 }) 147 } 148 ` 149 150 const sqlTemplate = `SET statement_timeout = 0; 151 152 --bun:split 153 154 SELECT 1 155 156 --bun:split 157 158 SELECT 2 159 ` 160 161 //------------------------------------------------------------------------------ 162 163 type MigrationSlice []Migration 164 165 func (ms MigrationSlice) String() string { 166 if len(ms) == 0 { 167 return "empty" 168 } 169 170 if len(ms) > 5 { 171 return fmt.Sprintf("%d migrations (%s ... %s)", len(ms), ms[0].Name, ms[len(ms)-1].Name) 172 } 173 174 var sb strings.Builder 175 176 for i := range ms { 177 if i > 0 { 178 sb.WriteString(", ") 179 } 180 sb.WriteString(ms[i].String()) 181 } 182 183 return sb.String() 184 } 185 186 // Applied returns applied migrations in descending order 187 // (the order is important and is used in Rollback). 188 func (ms MigrationSlice) Applied() MigrationSlice { 189 var applied MigrationSlice 190 for i := range ms { 191 if ms[i].IsApplied() { 192 applied = append(applied, ms[i]) 193 } 194 } 195 sortDesc(applied) 196 return applied 197 } 198 199 // Unapplied returns unapplied migrations in ascending order 200 // (the order is important and is used in Migrate). 201 func (ms MigrationSlice) Unapplied() MigrationSlice { 202 var unapplied MigrationSlice 203 for i := range ms { 204 if !ms[i].IsApplied() { 205 unapplied = append(unapplied, ms[i]) 206 } 207 } 208 sortAsc(unapplied) 209 return unapplied 210 } 211 212 // LastGroupID returns the last applied migration group id. 213 // The id is 0 when there are no migration groups. 214 func (ms MigrationSlice) LastGroupID() int64 { 215 var lastGroupID int64 216 for i := range ms { 217 groupID := ms[i].GroupID 218 if groupID > lastGroupID { 219 lastGroupID = groupID 220 } 221 } 222 return lastGroupID 223 } 224 225 // LastGroup returns the last applied migration group. 226 func (ms MigrationSlice) LastGroup() *MigrationGroup { 227 group := &MigrationGroup{ 228 ID: ms.LastGroupID(), 229 } 230 if group.ID == 0 { 231 return group 232 } 233 for i := range ms { 234 if ms[i].GroupID == group.ID { 235 group.Migrations = append(group.Migrations, ms[i]) 236 } 237 } 238 return group 239 } 240 241 type MigrationGroup struct { 242 ID int64 243 Migrations MigrationSlice 244 } 245 246 func (g MigrationGroup) IsZero() bool { 247 return g.ID == 0 && len(g.Migrations) == 0 248 } 249 250 func (g MigrationGroup) String() string { 251 if g.IsZero() { 252 return "nil" 253 } 254 return fmt.Sprintf("group #%d (%s)", g.ID, g.Migrations) 255 } 256 257 type MigrationFile struct { 258 Name string 259 Path string 260 Content string 261 } 262 263 //------------------------------------------------------------------------------ 264 265 type migrationConfig struct { 266 nop bool 267 } 268 269 func newMigrationConfig(opts []MigrationOption) *migrationConfig { 270 cfg := new(migrationConfig) 271 for _, opt := range opts { 272 opt(cfg) 273 } 274 return cfg 275 } 276 277 type MigrationOption func(cfg *migrationConfig) 278 279 func WithNopMigration() MigrationOption { 280 return func(cfg *migrationConfig) { 281 cfg.nop = true 282 } 283 } 284 285 //------------------------------------------------------------------------------ 286 287 func sortAsc(ms MigrationSlice) { 288 sort.Slice(ms, func(i, j int) bool { 289 return ms[i].Name < ms[j].Name 290 }) 291 } 292 293 func sortDesc(ms MigrationSlice) { 294 sort.Slice(ms, func(i, j int) bool { 295 return ms[i].Name > ms[j].Name 296 }) 297 }