query_table_create.go (9145B)
1 package bun 2 3 import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "sort" 8 "strconv" 9 "strings" 10 11 "github.com/uptrace/bun/dialect/feature" 12 "github.com/uptrace/bun/dialect/sqltype" 13 "github.com/uptrace/bun/internal" 14 "github.com/uptrace/bun/schema" 15 ) 16 17 type CreateTableQuery struct { 18 baseQuery 19 20 temp bool 21 ifNotExists bool 22 23 // varchar changes the default length for VARCHAR columns. 24 // Because some dialects require that length is always specified for VARCHAR type, 25 // we will use the exact user-defined type if length is set explicitly, as in `bun:",type:varchar(5)"`, 26 // but assume the new default length when it's omitted, e.g. `bun:",type:varchar"`. 27 varchar int 28 29 fks []schema.QueryWithArgs 30 partitionBy schema.QueryWithArgs 31 tablespace schema.QueryWithArgs 32 } 33 34 var _ Query = (*CreateTableQuery)(nil) 35 36 func NewCreateTableQuery(db *DB) *CreateTableQuery { 37 q := &CreateTableQuery{ 38 baseQuery: baseQuery{ 39 db: db, 40 conn: db.DB, 41 }, 42 varchar: db.Dialect().DefaultVarcharLen(), 43 } 44 return q 45 } 46 47 func (q *CreateTableQuery) Conn(db IConn) *CreateTableQuery { 48 q.setConn(db) 49 return q 50 } 51 52 func (q *CreateTableQuery) Model(model interface{}) *CreateTableQuery { 53 q.setModel(model) 54 return q 55 } 56 57 func (q *CreateTableQuery) Err(err error) *CreateTableQuery { 58 q.setErr(err) 59 return q 60 } 61 62 // ------------------------------------------------------------------------------ 63 64 func (q *CreateTableQuery) Table(tables ...string) *CreateTableQuery { 65 for _, table := range tables { 66 q.addTable(schema.UnsafeIdent(table)) 67 } 68 return q 69 } 70 71 func (q *CreateTableQuery) TableExpr(query string, args ...interface{}) *CreateTableQuery { 72 q.addTable(schema.SafeQuery(query, args)) 73 return q 74 } 75 76 func (q *CreateTableQuery) ModelTableExpr(query string, args ...interface{}) *CreateTableQuery { 77 q.modelTableName = schema.SafeQuery(query, args) 78 return q 79 } 80 81 func (q *CreateTableQuery) ColumnExpr(query string, args ...interface{}) *CreateTableQuery { 82 q.addColumn(schema.SafeQuery(query, args)) 83 return q 84 } 85 86 // ------------------------------------------------------------------------------ 87 88 func (q *CreateTableQuery) Temp() *CreateTableQuery { 89 q.temp = true 90 return q 91 } 92 93 func (q *CreateTableQuery) IfNotExists() *CreateTableQuery { 94 q.ifNotExists = true 95 return q 96 } 97 98 // Varchar sets default length for VARCHAR columns. 99 func (q *CreateTableQuery) Varchar(n int) *CreateTableQuery { 100 if n <= 0 { 101 q.setErr(fmt.Errorf("bun: illegal VARCHAR length: %d", n)) 102 return q 103 } 104 q.varchar = n 105 return q 106 } 107 108 func (q *CreateTableQuery) ForeignKey(query string, args ...interface{}) *CreateTableQuery { 109 q.fks = append(q.fks, schema.SafeQuery(query, args)) 110 return q 111 } 112 113 func (q *CreateTableQuery) PartitionBy(query string, args ...interface{}) *CreateTableQuery { 114 q.partitionBy = schema.SafeQuery(query, args) 115 return q 116 } 117 118 func (q *CreateTableQuery) TableSpace(tablespace string) *CreateTableQuery { 119 q.tablespace = schema.UnsafeIdent(tablespace) 120 return q 121 } 122 123 func (q *CreateTableQuery) WithForeignKeys() *CreateTableQuery { 124 for _, relation := range q.tableModel.Table().Relations { 125 if relation.Type == schema.ManyToManyRelation || 126 relation.Type == schema.HasManyRelation { 127 continue 128 } 129 130 q = q.ForeignKey("(?) REFERENCES ? (?) ? ?", 131 Safe(appendColumns(nil, "", relation.BaseFields)), 132 relation.JoinTable.SQLName, 133 Safe(appendColumns(nil, "", relation.JoinFields)), 134 Safe(relation.OnUpdate), 135 Safe(relation.OnDelete), 136 ) 137 } 138 return q 139 } 140 141 // ------------------------------------------------------------------------------ 142 143 func (q *CreateTableQuery) Operation() string { 144 return "CREATE TABLE" 145 } 146 147 func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { 148 if q.err != nil { 149 return nil, q.err 150 } 151 if q.table == nil { 152 return nil, errNilModel 153 } 154 155 b = append(b, "CREATE "...) 156 if q.temp { 157 b = append(b, "TEMP "...) 158 } 159 b = append(b, "TABLE "...) 160 if q.ifNotExists && fmter.Dialect().Features().Has(feature.TableNotExists) { 161 b = append(b, "IF NOT EXISTS "...) 162 } 163 b, err = q.appendFirstTable(fmter, b) 164 if err != nil { 165 return nil, err 166 } 167 168 b = append(b, " ("...) 169 170 for i, field := range q.table.Fields { 171 if i > 0 { 172 b = append(b, ", "...) 173 } 174 175 b = append(b, field.SQLName...) 176 b = append(b, " "...) 177 b = q.appendSQLType(b, field) 178 if field.NotNull { 179 b = append(b, " NOT NULL"...) 180 } 181 if field.AutoIncrement { 182 switch { 183 case fmter.Dialect().Features().Has(feature.AutoIncrement): 184 b = append(b, " AUTO_INCREMENT"...) 185 case fmter.Dialect().Features().Has(feature.Identity): 186 b = append(b, " IDENTITY"...) 187 } 188 } 189 if field.Identity { 190 if fmter.Dialect().Features().Has(feature.GeneratedIdentity) { 191 b = append(b, " GENERATED BY DEFAULT AS IDENTITY"...) 192 } 193 } 194 if field.SQLDefault != "" { 195 b = append(b, " DEFAULT "...) 196 b = append(b, field.SQLDefault...) 197 } 198 } 199 200 for i, col := range q.columns { 201 // Only pre-pend the comma if we are on subsequent iterations, or if there were fields/columns appended before 202 // this. This way if we are only appending custom column expressions we will not produce a syntax error with a 203 // leading comma. 204 if i > 0 || len(q.table.Fields) > 0 { 205 b = append(b, ", "...) 206 } 207 b, err = col.AppendQuery(fmter, b) 208 if err != nil { 209 return nil, err 210 } 211 } 212 213 b = q.appendPKConstraint(b, q.table.PKs) 214 b = q.appendUniqueConstraints(fmter, b) 215 b, err = q.appendFKConstraints(fmter, b) 216 if err != nil { 217 return nil, err 218 } 219 220 b = append(b, ")"...) 221 222 if !q.partitionBy.IsZero() { 223 b = append(b, " PARTITION BY "...) 224 b, err = q.partitionBy.AppendQuery(fmter, b) 225 if err != nil { 226 return nil, err 227 } 228 } 229 230 if !q.tablespace.IsZero() { 231 b = append(b, " TABLESPACE "...) 232 b, err = q.tablespace.AppendQuery(fmter, b) 233 if err != nil { 234 return nil, err 235 } 236 } 237 238 return b, nil 239 } 240 241 func (q *CreateTableQuery) appendSQLType(b []byte, field *schema.Field) []byte { 242 // Most of the time these two will match, but for the cases where DiscoveredSQLType is dialect-specific, 243 // e.g. pgdialect would change sqltype.SmallInt to pgTypeSmallSerial for columns that have `bun:",autoincrement"` 244 if !strings.EqualFold(field.CreateTableSQLType, field.DiscoveredSQLType) { 245 return append(b, field.CreateTableSQLType...) 246 } 247 248 // For all common SQL types except VARCHAR, both UserDefinedSQLType and DiscoveredSQLType specify the correct type, 249 // and we needn't modify it. For VARCHAR columns, we will stop to check if a valid length has been set in .Varchar(int). 250 if !strings.EqualFold(field.CreateTableSQLType, sqltype.VarChar) || q.varchar <= 0 { 251 return append(b, field.CreateTableSQLType...) 252 } 253 254 b = append(b, sqltype.VarChar...) 255 b = append(b, "("...) 256 b = strconv.AppendInt(b, int64(q.varchar), 10) 257 b = append(b, ")"...) 258 return b 259 } 260 261 func (q *CreateTableQuery) appendUniqueConstraints(fmter schema.Formatter, b []byte) []byte { 262 unique := q.table.Unique 263 264 keys := make([]string, 0, len(unique)) 265 for key := range unique { 266 keys = append(keys, key) 267 } 268 sort.Strings(keys) 269 270 for _, key := range keys { 271 if key == "" { 272 for _, field := range unique[key] { 273 b = q.appendUniqueConstraint(fmter, b, key, field) 274 } 275 continue 276 } 277 b = q.appendUniqueConstraint(fmter, b, key, unique[key]...) 278 } 279 280 return b 281 } 282 283 func (q *CreateTableQuery) appendUniqueConstraint( 284 fmter schema.Formatter, b []byte, name string, fields ...*schema.Field, 285 ) []byte { 286 if name != "" { 287 b = append(b, ", CONSTRAINT "...) 288 b = fmter.AppendIdent(b, name) 289 } else { 290 b = append(b, ","...) 291 } 292 b = append(b, " UNIQUE ("...) 293 b = appendColumns(b, "", fields) 294 b = append(b, ")"...) 295 return b 296 } 297 298 func (q *CreateTableQuery) appendFKConstraints( 299 fmter schema.Formatter, b []byte, 300 ) (_ []byte, err error) { 301 for _, fk := range q.fks { 302 b = append(b, ", FOREIGN KEY "...) 303 b, err = fk.AppendQuery(fmter, b) 304 if err != nil { 305 return nil, err 306 } 307 } 308 return b, nil 309 } 310 311 func (q *CreateTableQuery) appendPKConstraint(b []byte, pks []*schema.Field) []byte { 312 if len(pks) == 0 { 313 return b 314 } 315 316 b = append(b, ", PRIMARY KEY ("...) 317 b = appendColumns(b, "", pks) 318 b = append(b, ")"...) 319 return b 320 } 321 322 // ------------------------------------------------------------------------------ 323 324 func (q *CreateTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { 325 if err := q.beforeCreateTableHook(ctx); err != nil { 326 return nil, err 327 } 328 329 queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) 330 if err != nil { 331 return nil, err 332 } 333 334 query := internal.String(queryBytes) 335 336 res, err := q.exec(ctx, q, query) 337 if err != nil { 338 return nil, err 339 } 340 341 if q.table != nil { 342 if err := q.afterCreateTableHook(ctx); err != nil { 343 return nil, err 344 } 345 } 346 347 return res, nil 348 } 349 350 func (q *CreateTableQuery) beforeCreateTableHook(ctx context.Context) error { 351 if hook, ok := q.table.ZeroIface.(BeforeCreateTableHook); ok { 352 if err := hook.BeforeCreateTable(ctx, q); err != nil { 353 return err 354 } 355 } 356 return nil 357 } 358 359 func (q *CreateTableQuery) afterCreateTableHook(ctx context.Context) error { 360 if hook, ok := q.table.ZeroIface.(AfterCreateTableHook); ok { 361 if err := hook.AfterCreateTable(ctx, q); err != nil { 362 return err 363 } 364 } 365 return nil 366 }