gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

migration.go (5589B)


      1 package migrate
      2 
      3 import (
      4 	"bufio"
      5 	"bytes"
      6 	"context"
      7 	"fmt"
      8 	"io"
      9 	"io/fs"
     10 	"sort"
     11 	"strings"
     12 	"time"
     13 
     14 	"github.com/uptrace/bun"
     15 )
     16 
     17 type Migration struct {
     18 	bun.BaseModel
     19 
     20 	ID         int64 `bun:",pk,autoincrement"`
     21 	Name       string
     22 	Comment    string `bun:"-"`
     23 	GroupID    int64
     24 	MigratedAt time.Time `bun:",notnull,nullzero,default:current_timestamp"`
     25 
     26 	Up   MigrationFunc `bun:"-"`
     27 	Down MigrationFunc `bun:"-"`
     28 }
     29 
     30 func (m Migration) String() string {
     31 	return fmt.Sprintf("%s_%s", m.Name, m.Comment)
     32 }
     33 
     34 func (m Migration) IsApplied() bool {
     35 	return m.ID > 0
     36 }
     37 
// MigrationFunc is the signature of a migration step; Migration.Up and
// Migration.Down both have this type. It receives the database to operate on
// and returns a non-nil error on failure.
type MigrationFunc func(ctx context.Context, db *bun.DB) error
     39 
     40 func NewSQLMigrationFunc(fsys fs.FS, name string) MigrationFunc {
     41 	return func(ctx context.Context, db *bun.DB) error {
     42 		f, err := fsys.Open(name)
     43 		if err != nil {
     44 			return err
     45 		}
     46 
     47 		isTx := strings.HasSuffix(name, ".tx.up.sql") || strings.HasSuffix(name, ".tx.down.sql")
     48 		return Exec(ctx, db, f, isTx)
     49 	}
     50 }
     51 
     52 // Exec reads and executes the SQL migration in the f.
     53 func Exec(ctx context.Context, db *bun.DB, f io.Reader, isTx bool) error {
     54 	scanner := bufio.NewScanner(f)
     55 	var queries []string
     56 
     57 	var query []byte
     58 	for scanner.Scan() {
     59 		b := scanner.Bytes()
     60 
     61 		const prefix = "--bun:"
     62 		if bytes.HasPrefix(b, []byte(prefix)) {
     63 			b = b[len(prefix):]
     64 			if bytes.Equal(b, []byte("split")) {
     65 				queries = append(queries, string(query))
     66 				query = query[:0]
     67 				continue
     68 			}
     69 			return fmt.Errorf("bun: unknown directive: %q", b)
     70 		}
     71 
     72 		query = append(query, b...)
     73 		query = append(query, '\n')
     74 	}
     75 
     76 	if len(query) > 0 {
     77 		queries = append(queries, string(query))
     78 	}
     79 	if err := scanner.Err(); err != nil {
     80 		return err
     81 	}
     82 
     83 	var idb bun.IConn
     84 
     85 	if isTx {
     86 		tx, err := db.BeginTx(ctx, nil)
     87 		if err != nil {
     88 			return err
     89 		}
     90 		idb = tx
     91 	} else {
     92 		conn, err := db.Conn(ctx)
     93 		if err != nil {
     94 			return err
     95 		}
     96 		idb = conn
     97 	}
     98 
     99 	var retErr error
    100 	var execErr error
    101 
    102 	defer func() {
    103 		if tx, ok := idb.(bun.Tx); ok {
    104 			if execErr != nil {
    105 				retErr = tx.Rollback()
    106 			} else {
    107 				retErr = tx.Commit()
    108 			}
    109 			return
    110 		}
    111 
    112 		if conn, ok := idb.(bun.Conn); ok {
    113 			retErr = conn.Close()
    114 			return
    115 		}
    116 
    117 		panic("not reached")
    118 	}()
    119 
    120 	for _, q := range queries {
    121 		_, execErr = idb.ExecContext(ctx, q)
    122 		if execErr != nil {
    123 			return execErr
    124 		}
    125 	}
    126 
    127 	return retErr
    128 }
    129 
// goTemplate is the skeleton of a Go migration source file. Its single %s
// verb receives the package name. The generated init function registers an
// up and a down migration that each print a marker and return nil.
// NOTE(review): presumably used when scaffolding new Go migrations; the
// caller is not in this file.
const goTemplate = `package %s

import (
	"context"
	"fmt"

	"github.com/uptrace/bun"
)

func init() {
	Migrations.MustRegister(func(ctx context.Context, db *bun.DB) error {
		fmt.Print(" [up migration] ")
		return nil
	}, func(ctx context.Context, db *bun.DB) error {
		fmt.Print(" [down migration] ")
		return nil
	})
}
`
    149 
// sqlTemplate is the skeleton of a SQL migration file. It contains three
// statements separated by "--bun:split" lines, the directive Exec uses to
// split a file into separately executed queries.
// NOTE(review): presumably used when scaffolding new SQL migrations; the
// caller is not in this file.
const sqlTemplate = `SET statement_timeout = 0;

--bun:split

SELECT 1

--bun:split

SELECT 2
`
    160 
    161 //------------------------------------------------------------------------------
    162 
// MigrationSlice is a list of migrations with helpers for filtering by
// applied state and for finding the last migration group.
type MigrationSlice []Migration
    164 
    165 func (ms MigrationSlice) String() string {
    166 	if len(ms) == 0 {
    167 		return "empty"
    168 	}
    169 
    170 	if len(ms) > 5 {
    171 		return fmt.Sprintf("%d migrations (%s ... %s)", len(ms), ms[0].Name, ms[len(ms)-1].Name)
    172 	}
    173 
    174 	var sb strings.Builder
    175 
    176 	for i := range ms {
    177 		if i > 0 {
    178 			sb.WriteString(", ")
    179 		}
    180 		sb.WriteString(ms[i].String())
    181 	}
    182 
    183 	return sb.String()
    184 }
    185 
    186 // Applied returns applied migrations in descending order
    187 // (the order is important and is used in Rollback).
    188 func (ms MigrationSlice) Applied() MigrationSlice {
    189 	var applied MigrationSlice
    190 	for i := range ms {
    191 		if ms[i].IsApplied() {
    192 			applied = append(applied, ms[i])
    193 		}
    194 	}
    195 	sortDesc(applied)
    196 	return applied
    197 }
    198 
    199 // Unapplied returns unapplied migrations in ascending order
    200 // (the order is important and is used in Migrate).
    201 func (ms MigrationSlice) Unapplied() MigrationSlice {
    202 	var unapplied MigrationSlice
    203 	for i := range ms {
    204 		if !ms[i].IsApplied() {
    205 			unapplied = append(unapplied, ms[i])
    206 		}
    207 	}
    208 	sortAsc(unapplied)
    209 	return unapplied
    210 }
    211 
    212 // LastGroupID returns the last applied migration group id.
    213 // The id is 0 when there are no migration groups.
    214 func (ms MigrationSlice) LastGroupID() int64 {
    215 	var lastGroupID int64
    216 	for i := range ms {
    217 		groupID := ms[i].GroupID
    218 		if groupID > lastGroupID {
    219 			lastGroupID = groupID
    220 		}
    221 	}
    222 	return lastGroupID
    223 }
    224 
    225 // LastGroup returns the last applied migration group.
    226 func (ms MigrationSlice) LastGroup() *MigrationGroup {
    227 	group := &MigrationGroup{
    228 		ID: ms.LastGroupID(),
    229 	}
    230 	if group.ID == 0 {
    231 		return group
    232 	}
    233 	for i := range ms {
    234 		if ms[i].GroupID == group.ID {
    235 			group.Migrations = append(group.Migrations, ms[i])
    236 		}
    237 	}
    238 	return group
    239 }
    240 
// MigrationGroup is a set of migrations that share the same group id.
type MigrationGroup struct {
	// ID is the group id; LastGroup fills it from LastGroupID.
	ID int64
	// Migrations are the migrations whose GroupID equals ID.
	Migrations MigrationSlice
}
    245 
    246 func (g MigrationGroup) IsZero() bool {
    247 	return g.ID == 0 && len(g.Migrations) == 0
    248 }
    249 
    250 func (g MigrationGroup) String() string {
    251 	if g.IsZero() {
    252 		return "nil"
    253 	}
    254 	return fmt.Sprintf("group #%d (%s)", g.ID, g.Migrations)
    255 }
    256 
// MigrationFile describes a migration file by name, path and content.
// NOTE(review): nothing in this file uses the type; presumably it is returned
// by the code that creates new migration files — confirm against the caller.
type MigrationFile struct {
	Name    string
	Path    string
	Content string
}
    262 
    263 //------------------------------------------------------------------------------
    264 
// migrationConfig holds the settings collected from a list of
// MigrationOption values by newMigrationConfig.
type migrationConfig struct {
	// nop is set by WithNopMigration.
	// NOTE(review): presumably it makes the migrator record migrations
	// without running them; the consumer is not in this file.
	nop bool
}
    268 
    269 func newMigrationConfig(opts []MigrationOption) *migrationConfig {
    270 	cfg := new(migrationConfig)
    271 	for _, opt := range opts {
    272 		opt(cfg)
    273 	}
    274 	return cfg
    275 }
    276 
// MigrationOption is a functional option that modifies a migrationConfig.
type MigrationOption func(cfg *migrationConfig)
    278 
    279 func WithNopMigration() MigrationOption {
    280 	return func(cfg *migrationConfig) {
    281 		cfg.nop = true
    282 	}
    283 }
    284 
    285 //------------------------------------------------------------------------------
    286 
    287 func sortAsc(ms MigrationSlice) {
    288 	sort.Slice(ms, func(i, j int) bool {
    289 		return ms[i].Name < ms[j].Name
    290 	})
    291 }
    292 
    293 func sortDesc(ms MigrationSlice) {
    294 	sort.Slice(ms, func(i, j int) bool {
    295 		return ms[i].Name > ms[j].Name
    296 	})
    297 }