gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

query_table_create.go (9145B)


      1 package bun
      2 
      3 import (
      4 	"context"
      5 	"database/sql"
      6 	"fmt"
      7 	"sort"
      8 	"strconv"
      9 	"strings"
     10 
     11 	"github.com/uptrace/bun/dialect/feature"
     12 	"github.com/uptrace/bun/dialect/sqltype"
     13 	"github.com/uptrace/bun/internal"
     14 	"github.com/uptrace/bun/schema"
     15 )
     16 
// CreateTableQuery builds and executes a CREATE TABLE statement.
//
// Column definitions come from the model's table fields (see AppendQuery);
// extra column expressions, constraints and table options are added through
// the chainable methods defined on this type.
type CreateTableQuery struct {
	baseQuery

	// temp makes AppendQuery emit "CREATE TEMP TABLE"; ifNotExists adds
	// "IF NOT EXISTS", but only for dialects reporting feature.TableNotExists.
	temp        bool
	ifNotExists bool

	// varchar changes the default length for VARCHAR columns.
	// Because some dialects require that length is always specified for VARCHAR type,
	// we will use the exact user-defined type if length is set explicitly, as in `bun:",type:varchar(5)"`,
	// but assume the new default length when it's omitted, e.g. `bun:",type:varchar"`.
	varchar int

	// fks holds the text following each "FOREIGN KEY" keyword; partitionBy and
	// tablespace are emitted after the closing parenthesis when non-zero.
	fks         []schema.QueryWithArgs
	partitionBy schema.QueryWithArgs
	tablespace  schema.QueryWithArgs
}

// Compile-time check that *CreateTableQuery satisfies the Query interface.
var _ Query = (*CreateTableQuery)(nil)
     35 
     36 func NewCreateTableQuery(db *DB) *CreateTableQuery {
     37 	q := &CreateTableQuery{
     38 		baseQuery: baseQuery{
     39 			db:   db,
     40 			conn: db.DB,
     41 		},
     42 		varchar: db.Dialect().DefaultVarcharLen(),
     43 	}
     44 	return q
     45 }
     46 
     47 func (q *CreateTableQuery) Conn(db IConn) *CreateTableQuery {
     48 	q.setConn(db)
     49 	return q
     50 }
     51 
     52 func (q *CreateTableQuery) Model(model interface{}) *CreateTableQuery {
     53 	q.setModel(model)
     54 	return q
     55 }
     56 
     57 func (q *CreateTableQuery) Err(err error) *CreateTableQuery {
     58 	q.setErr(err)
     59 	return q
     60 }
     61 
     62 // ------------------------------------------------------------------------------
     63 
     64 func (q *CreateTableQuery) Table(tables ...string) *CreateTableQuery {
     65 	for _, table := range tables {
     66 		q.addTable(schema.UnsafeIdent(table))
     67 	}
     68 	return q
     69 }
     70 
     71 func (q *CreateTableQuery) TableExpr(query string, args ...interface{}) *CreateTableQuery {
     72 	q.addTable(schema.SafeQuery(query, args))
     73 	return q
     74 }
     75 
     76 func (q *CreateTableQuery) ModelTableExpr(query string, args ...interface{}) *CreateTableQuery {
     77 	q.modelTableName = schema.SafeQuery(query, args)
     78 	return q
     79 }
     80 
     81 func (q *CreateTableQuery) ColumnExpr(query string, args ...interface{}) *CreateTableQuery {
     82 	q.addColumn(schema.SafeQuery(query, args))
     83 	return q
     84 }
     85 
     86 // ------------------------------------------------------------------------------
     87 
     88 func (q *CreateTableQuery) Temp() *CreateTableQuery {
     89 	q.temp = true
     90 	return q
     91 }
     92 
     93 func (q *CreateTableQuery) IfNotExists() *CreateTableQuery {
     94 	q.ifNotExists = true
     95 	return q
     96 }
     97 
     98 // Varchar sets default length for VARCHAR columns.
     99 func (q *CreateTableQuery) Varchar(n int) *CreateTableQuery {
    100 	if n <= 0 {
    101 		q.setErr(fmt.Errorf("bun: illegal VARCHAR length: %d", n))
    102 		return q
    103 	}
    104 	q.varchar = n
    105 	return q
    106 }
    107 
    108 func (q *CreateTableQuery) ForeignKey(query string, args ...interface{}) *CreateTableQuery {
    109 	q.fks = append(q.fks, schema.SafeQuery(query, args))
    110 	return q
    111 }
    112 
    113 func (q *CreateTableQuery) PartitionBy(query string, args ...interface{}) *CreateTableQuery {
    114 	q.partitionBy = schema.SafeQuery(query, args)
    115 	return q
    116 }
    117 
    118 func (q *CreateTableQuery) TableSpace(tablespace string) *CreateTableQuery {
    119 	q.tablespace = schema.UnsafeIdent(tablespace)
    120 	return q
    121 }
    122 
    123 func (q *CreateTableQuery) WithForeignKeys() *CreateTableQuery {
    124 	for _, relation := range q.tableModel.Table().Relations {
    125 		if relation.Type == schema.ManyToManyRelation ||
    126 			relation.Type == schema.HasManyRelation {
    127 			continue
    128 		}
    129 
    130 		q = q.ForeignKey("(?) REFERENCES ? (?) ? ?",
    131 			Safe(appendColumns(nil, "", relation.BaseFields)),
    132 			relation.JoinTable.SQLName,
    133 			Safe(appendColumns(nil, "", relation.JoinFields)),
    134 			Safe(relation.OnUpdate),
    135 			Safe(relation.OnDelete),
    136 		)
    137 	}
    138 	return q
    139 }
    140 
    141 // ------------------------------------------------------------------------------
    142 
    143 func (q *CreateTableQuery) Operation() string {
    144 	return "CREATE TABLE"
    145 }
    146 
// AppendQuery appends the CREATE TABLE statement to b and returns the extended
// slice. The generated shape is
//
//	CREATE [TEMP ]TABLE [IF NOT EXISTS ]<table> (<field definitions>[, <ColumnExpr columns>]
//	    [, PRIMARY KEY (...)][, [CONSTRAINT <name> ]UNIQUE (...)][, FOREIGN KEY ...])
//	    [ PARTITION BY ...][ TABLESPACE ...]
//
// It returns the error previously recorded on the query, if any, and
// errNilModel when q.table is nil.
func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
	if q.err != nil {
		return nil, q.err
	}
	if q.table == nil {
		return nil, errNilModel
	}

	b = append(b, "CREATE "...)
	if q.temp {
		b = append(b, "TEMP "...)
	}
	b = append(b, "TABLE "...)
	// IF NOT EXISTS is dropped without error for dialects lacking feature.TableNotExists.
	if q.ifNotExists && fmter.Dialect().Features().Has(feature.TableNotExists) {
		b = append(b, "IF NOT EXISTS "...)
	}
	b, err = q.appendFirstTable(fmter, b)
	if err != nil {
		return nil, err
	}

	b = append(b, " ("...)

	// One definition per model field: name, SQL type, then the NOT NULL,
	// auto-increment, identity and DEFAULT modifiers in that order.
	for i, field := range q.table.Fields {
		if i > 0 {
			b = append(b, ", "...)
		}

		b = append(b, field.SQLName...)
		b = append(b, " "...)
		b = q.appendSQLType(b, field)
		if field.NotNull {
			b = append(b, " NOT NULL"...)
		}
		if field.AutoIncrement {
			// Dialects with neither feature get no keyword here; per the comment in
			// appendSQLType, pgdialect instead swaps the column type for a serial type.
			switch {
			case fmter.Dialect().Features().Has(feature.AutoIncrement):
				b = append(b, " AUTO_INCREMENT"...)
			case fmter.Dialect().Features().Has(feature.Identity):
				b = append(b, " IDENTITY"...)
			}
		}
		if field.Identity {
			if fmter.Dialect().Features().Has(feature.GeneratedIdentity) {
				b = append(b, " GENERATED BY DEFAULT AS IDENTITY"...)
			}
		}
		if field.SQLDefault != "" {
			b = append(b, " DEFAULT "...)
			b = append(b, field.SQLDefault...)
		}
	}

	for i, col := range q.columns {
		// Only pre-pend the comma if we are on subsequent iterations, or if there were fields/columns appended before
		// this. This way if we are only appending custom column expressions we will not produce a syntax error with a
		// leading comma.
		if i > 0 || len(q.table.Fields) > 0 {
			b = append(b, ", "...)
		}
		b, err = col.AppendQuery(fmter, b)
		if err != nil {
			return nil, err
		}
	}

	// Table-level constraints. Each helper prefixes its clause with a comma.
	// NOTE(review): a table with no fields and no ColumnExpr columns but with
	// ForeignKey clauses would render "(, FOREIGN KEY ..." — confirm this
	// combination cannot occur.
	b = q.appendPKConstraint(b, q.table.PKs)
	b = q.appendUniqueConstraints(fmter, b)
	b, err = q.appendFKConstraints(fmter, b)
	if err != nil {
		return nil, err
	}

	b = append(b, ")"...)

	// Table options follow the closing parenthesis.
	if !q.partitionBy.IsZero() {
		b = append(b, " PARTITION BY "...)
		b, err = q.partitionBy.AppendQuery(fmter, b)
		if err != nil {
			return nil, err
		}
	}

	if !q.tablespace.IsZero() {
		b = append(b, " TABLESPACE "...)
		b, err = q.tablespace.AppendQuery(fmter, b)
		if err != nil {
			return nil, err
		}
	}

	return b, nil
}
    240 
    241 func (q *CreateTableQuery) appendSQLType(b []byte, field *schema.Field) []byte {
    242 	// Most of the time these two will match, but for the cases where DiscoveredSQLType is dialect-specific,
    243 	// e.g. pgdialect would change sqltype.SmallInt to pgTypeSmallSerial for columns that have `bun:",autoincrement"`
    244 	if !strings.EqualFold(field.CreateTableSQLType, field.DiscoveredSQLType) {
    245 		return append(b, field.CreateTableSQLType...)
    246 	}
    247 
    248 	// For all common SQL types except VARCHAR, both UserDefinedSQLType and DiscoveredSQLType specify the correct type,
    249 	// and we needn't modify it. For VARCHAR columns, we will stop to check if a valid length has been set in .Varchar(int).
    250 	if !strings.EqualFold(field.CreateTableSQLType, sqltype.VarChar) || q.varchar <= 0 {
    251 		return append(b, field.CreateTableSQLType...)
    252 	}
    253 
    254 	b = append(b, sqltype.VarChar...)
    255 	b = append(b, "("...)
    256 	b = strconv.AppendInt(b, int64(q.varchar), 10)
    257 	b = append(b, ")"...)
    258 	return b
    259 }
    260 
    261 func (q *CreateTableQuery) appendUniqueConstraints(fmter schema.Formatter, b []byte) []byte {
    262 	unique := q.table.Unique
    263 
    264 	keys := make([]string, 0, len(unique))
    265 	for key := range unique {
    266 		keys = append(keys, key)
    267 	}
    268 	sort.Strings(keys)
    269 
    270 	for _, key := range keys {
    271 		if key == "" {
    272 			for _, field := range unique[key] {
    273 				b = q.appendUniqueConstraint(fmter, b, key, field)
    274 			}
    275 			continue
    276 		}
    277 		b = q.appendUniqueConstraint(fmter, b, key, unique[key]...)
    278 	}
    279 
    280 	return b
    281 }
    282 
    283 func (q *CreateTableQuery) appendUniqueConstraint(
    284 	fmter schema.Formatter, b []byte, name string, fields ...*schema.Field,
    285 ) []byte {
    286 	if name != "" {
    287 		b = append(b, ", CONSTRAINT "...)
    288 		b = fmter.AppendIdent(b, name)
    289 	} else {
    290 		b = append(b, ","...)
    291 	}
    292 	b = append(b, " UNIQUE ("...)
    293 	b = appendColumns(b, "", fields)
    294 	b = append(b, ")"...)
    295 	return b
    296 }
    297 
    298 func (q *CreateTableQuery) appendFKConstraints(
    299 	fmter schema.Formatter, b []byte,
    300 ) (_ []byte, err error) {
    301 	for _, fk := range q.fks {
    302 		b = append(b, ", FOREIGN KEY "...)
    303 		b, err = fk.AppendQuery(fmter, b)
    304 		if err != nil {
    305 			return nil, err
    306 		}
    307 	}
    308 	return b, nil
    309 }
    310 
    311 func (q *CreateTableQuery) appendPKConstraint(b []byte, pks []*schema.Field) []byte {
    312 	if len(pks) == 0 {
    313 		return b
    314 	}
    315 
    316 	b = append(b, ", PRIMARY KEY ("...)
    317 	b = appendColumns(b, "", pks)
    318 	b = append(b, ")"...)
    319 	return b
    320 }
    321 
    322 // ------------------------------------------------------------------------------
    323 
    324 func (q *CreateTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
    325 	if err := q.beforeCreateTableHook(ctx); err != nil {
    326 		return nil, err
    327 	}
    328 
    329 	queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
    330 	if err != nil {
    331 		return nil, err
    332 	}
    333 
    334 	query := internal.String(queryBytes)
    335 
    336 	res, err := q.exec(ctx, q, query)
    337 	if err != nil {
    338 		return nil, err
    339 	}
    340 
    341 	if q.table != nil {
    342 		if err := q.afterCreateTableHook(ctx); err != nil {
    343 			return nil, err
    344 		}
    345 	}
    346 
    347 	return res, nil
    348 }
    349 
    350 func (q *CreateTableQuery) beforeCreateTableHook(ctx context.Context) error {
    351 	if hook, ok := q.table.ZeroIface.(BeforeCreateTableHook); ok {
    352 		if err := hook.BeforeCreateTable(ctx, q); err != nil {
    353 			return err
    354 		}
    355 	}
    356 	return nil
    357 }
    358 
    359 func (q *CreateTableQuery) afterCreateTableHook(ctx context.Context) error {
    360 	if hook, ok := q.table.ZeroIface.(AfterCreateTableHook); ok {
    361 		if err := hook.AfterCreateTable(ctx, q); err != nil {
    362 			return err
    363 		}
    364 	}
    365 	return nil
    366 }