gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

migrations.go (3304B)


      1 package migrate
      2 
      3 import (
      4 	"errors"
      5 	"fmt"
      6 	"io/fs"
      7 	"os"
      8 	"path/filepath"
      9 	"regexp"
     10 	"runtime"
     11 	"strings"
     12 )
     13 
     14 type MigrationsOption func(m *Migrations)
     15 
     16 func WithMigrationsDirectory(directory string) MigrationsOption {
     17 	return func(m *Migrations) {
     18 		m.explicitDirectory = directory
     19 	}
     20 }
     21 
     22 type Migrations struct {
     23 	ms MigrationSlice
     24 
     25 	explicitDirectory string
     26 	implicitDirectory string
     27 }
     28 
     29 func NewMigrations(opts ...MigrationsOption) *Migrations {
     30 	m := new(Migrations)
     31 	for _, opt := range opts {
     32 		opt(m)
     33 	}
     34 	m.implicitDirectory = filepath.Dir(migrationFile())
     35 	return m
     36 }
     37 
     38 func (m *Migrations) Sorted() MigrationSlice {
     39 	migrations := make(MigrationSlice, len(m.ms))
     40 	copy(migrations, m.ms)
     41 	sortAsc(migrations)
     42 	return migrations
     43 }
     44 
     45 func (m *Migrations) MustRegister(up, down MigrationFunc) {
     46 	if err := m.Register(up, down); err != nil {
     47 		panic(err)
     48 	}
     49 }
     50 
     51 func (m *Migrations) Register(up, down MigrationFunc) error {
     52 	fpath := migrationFile()
     53 	name, comment, err := extractMigrationName(fpath)
     54 	if err != nil {
     55 		return err
     56 	}
     57 
     58 	m.Add(Migration{
     59 		Name:    name,
     60 		Comment: comment,
     61 		Up:      up,
     62 		Down:    down,
     63 	})
     64 
     65 	return nil
     66 }
     67 
     68 func (m *Migrations) Add(migration Migration) {
     69 	if migration.Name == "" {
     70 		panic("migration name is required")
     71 	}
     72 	m.ms = append(m.ms, migration)
     73 }
     74 
     75 func (m *Migrations) DiscoverCaller() error {
     76 	dir := filepath.Dir(migrationFile())
     77 	return m.Discover(os.DirFS(dir))
     78 }
     79 
     80 func (m *Migrations) Discover(fsys fs.FS) error {
     81 	return fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error {
     82 		if err != nil {
     83 			return err
     84 		}
     85 		if d.IsDir() {
     86 			return nil
     87 		}
     88 
     89 		if !strings.HasSuffix(path, ".up.sql") && !strings.HasSuffix(path, ".down.sql") {
     90 			return nil
     91 		}
     92 
     93 		name, comment, err := extractMigrationName(path)
     94 		if err != nil {
     95 			return err
     96 		}
     97 
     98 		migration := m.getOrCreateMigration(name)
     99 		if err != nil {
    100 			return err
    101 		}
    102 
    103 		migration.Comment = comment
    104 		migrationFunc := NewSQLMigrationFunc(fsys, path)
    105 
    106 		if strings.HasSuffix(path, ".up.sql") {
    107 			migration.Up = migrationFunc
    108 			return nil
    109 		}
    110 		if strings.HasSuffix(path, ".down.sql") {
    111 			migration.Down = migrationFunc
    112 			return nil
    113 		}
    114 
    115 		return errors.New("migrate: not reached")
    116 	})
    117 }
    118 
    119 func (m *Migrations) getOrCreateMigration(name string) *Migration {
    120 	for i := range m.ms {
    121 		m := &m.ms[i]
    122 		if m.Name == name {
    123 			return m
    124 		}
    125 	}
    126 
    127 	m.ms = append(m.ms, Migration{Name: name})
    128 	return &m.ms[len(m.ms)-1]
    129 }
    130 
    131 func (m *Migrations) getDirectory() string {
    132 	if m.explicitDirectory != "" {
    133 		return m.explicitDirectory
    134 	}
    135 	if m.implicitDirectory != "" {
    136 		return m.implicitDirectory
    137 	}
    138 	return filepath.Dir(migrationFile())
    139 }
    140 
    141 func migrationFile() string {
    142 	const depth = 32
    143 	var pcs [depth]uintptr
    144 	n := runtime.Callers(1, pcs[:])
    145 	frames := runtime.CallersFrames(pcs[:n])
    146 
    147 	for {
    148 		f, ok := frames.Next()
    149 		if !ok {
    150 			break
    151 		}
    152 		if !strings.Contains(f.Function, "/bun/migrate.") {
    153 			return f.File
    154 		}
    155 	}
    156 
    157 	return ""
    158 }
    159 
    160 var fnameRE = regexp.MustCompile(`^(\d{1,14})_([0-9a-z_\-]+)\.`)
    161 
    162 func extractMigrationName(fpath string) (string, string, error) {
    163 	fname := filepath.Base(fpath)
    164 
    165 	matches := fnameRE.FindStringSubmatch(fname)
    166 	if matches == nil {
    167 		return "", "", fmt.Errorf("migrate: unsupported migration name format: %q", fname)
    168 	}
    169 
    170 	return matches[1], matches[2], nil
    171 }