tables.go (2612B)
1 package schema 2 3 import ( 4 "fmt" 5 "reflect" 6 "sync" 7 ) 8 9 type tableInProgress struct { 10 table *Table 11 12 init1Once sync.Once 13 init2Once sync.Once 14 } 15 16 func newTableInProgress(table *Table) *tableInProgress { 17 return &tableInProgress{ 18 table: table, 19 } 20 } 21 22 func (inp *tableInProgress) init1() bool { 23 var inited bool 24 inp.init1Once.Do(func() { 25 inp.table.init1() 26 inited = true 27 }) 28 return inited 29 } 30 31 func (inp *tableInProgress) init2() bool { 32 var inited bool 33 inp.init2Once.Do(func() { 34 inp.table.init2() 35 inited = true 36 }) 37 return inited 38 } 39 40 type Tables struct { 41 dialect Dialect 42 tables sync.Map 43 44 mu sync.RWMutex 45 inProgress map[reflect.Type]*tableInProgress 46 } 47 48 func NewTables(dialect Dialect) *Tables { 49 return &Tables{ 50 dialect: dialect, 51 inProgress: make(map[reflect.Type]*tableInProgress), 52 } 53 } 54 55 func (t *Tables) Register(models ...interface{}) { 56 for _, model := range models { 57 _ = t.Get(reflect.TypeOf(model).Elem()) 58 } 59 } 60 61 func (t *Tables) Get(typ reflect.Type) *Table { 62 return t.table(typ, false) 63 } 64 65 func (t *Tables) Ref(typ reflect.Type) *Table { 66 return t.table(typ, true) 67 } 68 69 func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table { 70 typ = indirectType(typ) 71 if typ.Kind() != reflect.Struct { 72 panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct)) 73 } 74 75 if v, ok := t.tables.Load(typ); ok { 76 return v.(*Table) 77 } 78 79 t.mu.Lock() 80 81 if v, ok := t.tables.Load(typ); ok { 82 t.mu.Unlock() 83 return v.(*Table) 84 } 85 86 var table *Table 87 88 inProgress := t.inProgress[typ] 89 if inProgress == nil { 90 table = newTable(t.dialect, typ) 91 inProgress = newTableInProgress(table) 92 t.inProgress[typ] = inProgress 93 } else { 94 table = inProgress.table 95 } 96 97 t.mu.Unlock() 98 99 inProgress.init1() 100 if allowInProgress { 101 return table 102 } 103 104 if !inProgress.init2() { 105 return table 106 } 107 108 t.mu.Lock() 109 delete(t.inProgress, typ) 110 t.tables.Store(typ, table) 111 t.mu.Unlock() 112 113 t.dialect.OnTable(table) 114 115 for _, field := range table.FieldMap { 116 if field.UserSQLType == "" { 117 field.UserSQLType = field.DiscoveredSQLType 118 } 119 if field.CreateTableSQLType == "" { 120 field.CreateTableSQLType = field.UserSQLType 121 } 122 } 123 124 return table 125 } 126 127 func (t *Tables) ByModel(name string) *Table { 128 var found *Table 129 t.tables.Range(func(key, value interface{}) bool { 130 t := value.(*Table) 131 if t.TypeName == name { 132 found = t 133 return false 134 } 135 return true 136 }) 137 return found 138 } 139 140 func (t *Tables) ByName(name string) *Table { 141 var found *Table 142 t.tables.Range(func(key, value interface{}) bool { 143 t := value.(*Table) 144 if t.Name == name { 145 found = t 146 return false 147 } 148 return true 149 }) 150 return found 151 }