gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

query_insert.go (14778B)


      1 package bun
      2 
      3 import (
      4 	"context"
      5 	"database/sql"
      6 	"fmt"
      7 	"reflect"
      8 	"strings"
      9 
     10 	"github.com/uptrace/bun/dialect/feature"
     11 	"github.com/uptrace/bun/internal"
     12 	"github.com/uptrace/bun/schema"
     13 )
     14 
     15 type InsertQuery struct {
     16 	whereBaseQuery
     17 	returningQuery
     18 	customValueQuery
     19 
     20 	on schema.QueryWithArgs
     21 	setQuery
     22 
     23 	ignore  bool
     24 	replace bool
     25 }
     26 
     27 var _ Query = (*InsertQuery)(nil)
     28 
     29 func NewInsertQuery(db *DB) *InsertQuery {
     30 	q := &InsertQuery{
     31 		whereBaseQuery: whereBaseQuery{
     32 			baseQuery: baseQuery{
     33 				db:   db,
     34 				conn: db.DB,
     35 			},
     36 		},
     37 	}
     38 	return q
     39 }
     40 
     41 func (q *InsertQuery) Conn(db IConn) *InsertQuery {
     42 	q.setConn(db)
     43 	return q
     44 }
     45 
     46 func (q *InsertQuery) Model(model interface{}) *InsertQuery {
     47 	q.setModel(model)
     48 	return q
     49 }
     50 
     51 func (q *InsertQuery) Err(err error) *InsertQuery {
     52 	q.setErr(err)
     53 	return q
     54 }
     55 
     56 // Apply calls the fn passing the SelectQuery as an argument.
     57 func (q *InsertQuery) Apply(fn func(*InsertQuery) *InsertQuery) *InsertQuery {
     58 	if fn != nil {
     59 		return fn(q)
     60 	}
     61 	return q
     62 }
     63 
     64 func (q *InsertQuery) With(name string, query schema.QueryAppender) *InsertQuery {
     65 	q.addWith(name, query, false)
     66 	return q
     67 }
     68 
     69 func (q *InsertQuery) WithRecursive(name string, query schema.QueryAppender) *InsertQuery {
     70 	q.addWith(name, query, true)
     71 	return q
     72 }
     73 
     74 //------------------------------------------------------------------------------
     75 
     76 func (q *InsertQuery) Table(tables ...string) *InsertQuery {
     77 	for _, table := range tables {
     78 		q.addTable(schema.UnsafeIdent(table))
     79 	}
     80 	return q
     81 }
     82 
     83 func (q *InsertQuery) TableExpr(query string, args ...interface{}) *InsertQuery {
     84 	q.addTable(schema.SafeQuery(query, args))
     85 	return q
     86 }
     87 
     88 func (q *InsertQuery) ModelTableExpr(query string, args ...interface{}) *InsertQuery {
     89 	q.modelTableName = schema.SafeQuery(query, args)
     90 	return q
     91 }
     92 
     93 //------------------------------------------------------------------------------
     94 
     95 func (q *InsertQuery) Column(columns ...string) *InsertQuery {
     96 	for _, column := range columns {
     97 		q.addColumn(schema.UnsafeIdent(column))
     98 	}
     99 	return q
    100 }
    101 
    102 func (q *InsertQuery) ColumnExpr(query string, args ...interface{}) *InsertQuery {
    103 	q.addColumn(schema.SafeQuery(query, args))
    104 	return q
    105 }
    106 
    107 func (q *InsertQuery) ExcludeColumn(columns ...string) *InsertQuery {
    108 	q.excludeColumn(columns)
    109 	return q
    110 }
    111 
    112 // Value overwrites model value for the column.
    113 func (q *InsertQuery) Value(column string, expr string, args ...interface{}) *InsertQuery {
    114 	if q.table == nil {
    115 		q.err = errNilModel
    116 		return q
    117 	}
    118 	q.addValue(q.table, column, expr, args)
    119 	return q
    120 }
    121 
    122 func (q *InsertQuery) Where(query string, args ...interface{}) *InsertQuery {
    123 	q.addWhere(schema.SafeQueryWithSep(query, args, " AND "))
    124 	return q
    125 }
    126 
    127 func (q *InsertQuery) WhereOr(query string, args ...interface{}) *InsertQuery {
    128 	q.addWhere(schema.SafeQueryWithSep(query, args, " OR "))
    129 	return q
    130 }
    131 
    132 //------------------------------------------------------------------------------
    133 
    134 // Returning adds a RETURNING clause to the query.
    135 //
    136 // To suppress the auto-generated RETURNING clause, use `Returning("")`.
    137 func (q *InsertQuery) Returning(query string, args ...interface{}) *InsertQuery {
    138 	q.addReturning(schema.SafeQuery(query, args))
    139 	return q
    140 }
    141 
    142 //------------------------------------------------------------------------------
    143 
    144 // Ignore generates different queries depending on the DBMS:
    145 //   - On MySQL, it generates `INSERT IGNORE INTO`.
    146 //   - On PostgreSQL, it generates `ON CONFLICT DO NOTHING`.
    147 func (q *InsertQuery) Ignore() *InsertQuery {
    148 	if q.db.fmter.HasFeature(feature.InsertOnConflict) {
    149 		return q.On("CONFLICT DO NOTHING")
    150 	}
    151 	if q.db.fmter.HasFeature(feature.InsertIgnore) {
    152 		q.ignore = true
    153 	}
    154 	return q
    155 }
    156 
    157 // Replaces generates a `REPLACE INTO` query (MySQL and MariaDB).
    158 func (q *InsertQuery) Replace() *InsertQuery {
    159 	q.replace = true
    160 	return q
    161 }
    162 
    163 //------------------------------------------------------------------------------
    164 
    165 func (q *InsertQuery) Operation() string {
    166 	return "INSERT"
    167 }
    168 
    169 func (q *InsertQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
    170 	if q.err != nil {
    171 		return nil, q.err
    172 	}
    173 
    174 	fmter = formatterWithModel(fmter, q)
    175 
    176 	b, err = q.appendWith(fmter, b)
    177 	if err != nil {
    178 		return nil, err
    179 	}
    180 
    181 	if q.replace {
    182 		b = append(b, "REPLACE "...)
    183 	} else {
    184 		b = append(b, "INSERT "...)
    185 		if q.ignore {
    186 			b = append(b, "IGNORE "...)
    187 		}
    188 	}
    189 	b = append(b, "INTO "...)
    190 
    191 	if q.db.features.Has(feature.InsertTableAlias) && !q.on.IsZero() {
    192 		b, err = q.appendFirstTableWithAlias(fmter, b)
    193 	} else {
    194 		b, err = q.appendFirstTable(fmter, b)
    195 	}
    196 	if err != nil {
    197 		return nil, err
    198 	}
    199 
    200 	b, err = q.appendColumnsValues(fmter, b, false)
    201 	if err != nil {
    202 		return nil, err
    203 	}
    204 
    205 	b, err = q.appendOn(fmter, b)
    206 	if err != nil {
    207 		return nil, err
    208 	}
    209 
    210 	if q.hasFeature(feature.InsertReturning) && q.hasReturning() {
    211 		b = append(b, " RETURNING "...)
    212 		b, err = q.appendReturning(fmter, b)
    213 		if err != nil {
    214 			return nil, err
    215 		}
    216 	}
    217 
    218 	return b, nil
    219 }
    220 
    221 func (q *InsertQuery) appendColumnsValues(
    222 	fmter schema.Formatter, b []byte, skipOutput bool,
    223 ) (_ []byte, err error) {
    224 	if q.hasMultiTables() {
    225 		if q.columns != nil {
    226 			b = append(b, " ("...)
    227 			b, err = q.appendColumns(fmter, b)
    228 			if err != nil {
    229 				return nil, err
    230 			}
    231 			b = append(b, ")"...)
    232 		}
    233 
    234 		if q.hasFeature(feature.Output) && q.hasReturning() {
    235 			b = append(b, " OUTPUT "...)
    236 			b, err = q.appendOutput(fmter, b)
    237 			if err != nil {
    238 				return nil, err
    239 			}
    240 		}
    241 
    242 		b = append(b, " SELECT "...)
    243 
    244 		if q.columns != nil {
    245 			b, err = q.appendColumns(fmter, b)
    246 			if err != nil {
    247 				return nil, err
    248 			}
    249 		} else {
    250 			b = append(b, "*"...)
    251 		}
    252 
    253 		b = append(b, " FROM "...)
    254 		b, err = q.appendOtherTables(fmter, b)
    255 		if err != nil {
    256 			return nil, err
    257 		}
    258 
    259 		return b, nil
    260 	}
    261 
    262 	if m, ok := q.model.(*mapModel); ok {
    263 		return m.appendColumnsValues(fmter, b), nil
    264 	}
    265 	if _, ok := q.model.(*mapSliceModel); ok {
    266 		return nil, fmt.Errorf("Insert(*[]map[string]interface{}) is not supported")
    267 	}
    268 
    269 	if q.model == nil {
    270 		return nil, errNilModel
    271 	}
    272 
    273 	// Build fields to populate RETURNING clause.
    274 	fields, err := q.getFields()
    275 	if err != nil {
    276 		return nil, err
    277 	}
    278 
    279 	b = append(b, " ("...)
    280 	b = q.appendFields(fmter, b, fields)
    281 	b = append(b, ")"...)
    282 
    283 	if q.hasFeature(feature.Output) && q.hasReturning() && !skipOutput {
    284 		b = append(b, " OUTPUT "...)
    285 		b, err = q.appendOutput(fmter, b)
    286 		if err != nil {
    287 			return nil, err
    288 		}
    289 	}
    290 
    291 	b = append(b, " VALUES ("...)
    292 
    293 	switch model := q.tableModel.(type) {
    294 	case *structTableModel:
    295 		b, err = q.appendStructValues(fmter, b, fields, model.strct)
    296 		if err != nil {
    297 			return nil, err
    298 		}
    299 	case *sliceTableModel:
    300 		b, err = q.appendSliceValues(fmter, b, fields, model.slice)
    301 		if err != nil {
    302 			return nil, err
    303 		}
    304 	default:
    305 		return nil, fmt.Errorf("bun: Insert does not support %T", q.tableModel)
    306 	}
    307 
    308 	b = append(b, ')')
    309 
    310 	return b, nil
    311 }
    312 
    313 func (q *InsertQuery) appendStructValues(
    314 	fmter schema.Formatter, b []byte, fields []*schema.Field, strct reflect.Value,
    315 ) (_ []byte, err error) {
    316 	isTemplate := fmter.IsNop()
    317 	for i, f := range fields {
    318 		if i > 0 {
    319 			b = append(b, ", "...)
    320 		}
    321 
    322 		app, ok := q.modelValues[f.Name]
    323 		if ok {
    324 			b, err = app.AppendQuery(fmter, b)
    325 			if err != nil {
    326 				return nil, err
    327 			}
    328 			q.addReturningField(f)
    329 			continue
    330 		}
    331 
    332 		switch {
    333 		case isTemplate:
    334 			b = append(b, '?')
    335 		case (f.IsPtr && f.HasNilValue(strct)) || (f.NullZero && f.HasZeroValue(strct)):
    336 			if q.db.features.Has(feature.DefaultPlaceholder) {
    337 				b = append(b, "DEFAULT"...)
    338 			} else if f.SQLDefault != "" {
    339 				b = append(b, f.SQLDefault...)
    340 			} else {
    341 				b = append(b, "NULL"...)
    342 			}
    343 			q.addReturningField(f)
    344 		default:
    345 			b = f.AppendValue(fmter, b, strct)
    346 		}
    347 	}
    348 
    349 	for i, v := range q.extraValues {
    350 		if i > 0 || len(fields) > 0 {
    351 			b = append(b, ", "...)
    352 		}
    353 
    354 		b, err = v.value.AppendQuery(fmter, b)
    355 		if err != nil {
    356 			return nil, err
    357 		}
    358 	}
    359 
    360 	return b, nil
    361 }
    362 
    363 func (q *InsertQuery) appendSliceValues(
    364 	fmter schema.Formatter, b []byte, fields []*schema.Field, slice reflect.Value,
    365 ) (_ []byte, err error) {
    366 	if fmter.IsNop() {
    367 		return q.appendStructValues(fmter, b, fields, reflect.Value{})
    368 	}
    369 
    370 	sliceLen := slice.Len()
    371 	for i := 0; i < sliceLen; i++ {
    372 		if i > 0 {
    373 			b = append(b, "), ("...)
    374 		}
    375 		el := indirect(slice.Index(i))
    376 		b, err = q.appendStructValues(fmter, b, fields, el)
    377 		if err != nil {
    378 			return nil, err
    379 		}
    380 	}
    381 
    382 	return b, nil
    383 }
    384 
    385 func (q *InsertQuery) getFields() ([]*schema.Field, error) {
    386 	hasIdentity := q.db.features.Has(feature.Identity)
    387 
    388 	if len(q.columns) > 0 || q.db.features.Has(feature.DefaultPlaceholder) && !hasIdentity {
    389 		return q.baseQuery.getFields()
    390 	}
    391 
    392 	var strct reflect.Value
    393 
    394 	switch model := q.tableModel.(type) {
    395 	case *structTableModel:
    396 		strct = model.strct
    397 	case *sliceTableModel:
    398 		if model.sliceLen == 0 {
    399 			return nil, fmt.Errorf("bun: Insert(empty %T)", model.slice.Type())
    400 		}
    401 		strct = indirect(model.slice.Index(0))
    402 	default:
    403 		return nil, errNilModel
    404 	}
    405 
    406 	fields := make([]*schema.Field, 0, len(q.table.Fields))
    407 
    408 	for _, f := range q.table.Fields {
    409 		if hasIdentity && f.AutoIncrement {
    410 			q.addReturningField(f)
    411 			continue
    412 		}
    413 		if f.NotNull && f.SQLDefault == "" {
    414 			if (f.IsPtr && f.HasNilValue(strct)) || (f.NullZero && f.HasZeroValue(strct)) {
    415 				q.addReturningField(f)
    416 				continue
    417 			}
    418 		}
    419 		fields = append(fields, f)
    420 	}
    421 
    422 	return fields, nil
    423 }
    424 
    425 func (q *InsertQuery) appendFields(
    426 	fmter schema.Formatter, b []byte, fields []*schema.Field,
    427 ) []byte {
    428 	b = appendColumns(b, "", fields)
    429 	for i, v := range q.extraValues {
    430 		if i > 0 || len(fields) > 0 {
    431 			b = append(b, ", "...)
    432 		}
    433 		b = fmter.AppendIdent(b, v.column)
    434 	}
    435 	return b
    436 }
    437 
    438 //------------------------------------------------------------------------------
    439 
    440 func (q *InsertQuery) On(s string, args ...interface{}) *InsertQuery {
    441 	q.on = schema.SafeQuery(s, args)
    442 	return q
    443 }
    444 
    445 func (q *InsertQuery) Set(query string, args ...interface{}) *InsertQuery {
    446 	q.addSet(schema.SafeQuery(query, args))
    447 	return q
    448 }
    449 
    450 func (q *InsertQuery) appendOn(fmter schema.Formatter, b []byte) (_ []byte, err error) {
    451 	if q.on.IsZero() {
    452 		return b, nil
    453 	}
    454 
    455 	b = append(b, " ON "...)
    456 	b, err = q.on.AppendQuery(fmter, b)
    457 	if err != nil {
    458 		return nil, err
    459 	}
    460 
    461 	if len(q.set) > 0 {
    462 		if fmter.HasFeature(feature.InsertOnDuplicateKey) {
    463 			b = append(b, ' ')
    464 		} else {
    465 			b = append(b, " SET "...)
    466 		}
    467 
    468 		b, err = q.appendSet(fmter, b)
    469 		if err != nil {
    470 			return nil, err
    471 		}
    472 	} else if q.onConflictDoUpdate() {
    473 		fields, err := q.getDataFields()
    474 		if err != nil {
    475 			return nil, err
    476 		}
    477 
    478 		if len(fields) == 0 {
    479 			fields = q.tableModel.Table().DataFields
    480 		}
    481 
    482 		b = q.appendSetExcluded(b, fields)
    483 	} else if q.onDuplicateKeyUpdate() {
    484 		fields, err := q.getDataFields()
    485 		if err != nil {
    486 			return nil, err
    487 		}
    488 
    489 		if len(fields) == 0 {
    490 			fields = q.tableModel.Table().DataFields
    491 		}
    492 
    493 		b = q.appendSetValues(b, fields)
    494 	}
    495 
    496 	if len(q.where) > 0 {
    497 		b = append(b, " WHERE "...)
    498 
    499 		b, err = appendWhere(fmter, b, q.where)
    500 		if err != nil {
    501 			return nil, err
    502 		}
    503 	}
    504 
    505 	return b, nil
    506 }
    507 
    508 func (q *InsertQuery) onConflictDoUpdate() bool {
    509 	return strings.HasSuffix(strings.ToUpper(q.on.Query), " DO UPDATE")
    510 }
    511 
    512 func (q *InsertQuery) onDuplicateKeyUpdate() bool {
    513 	return strings.ToUpper(q.on.Query) == "DUPLICATE KEY UPDATE"
    514 }
    515 
    516 func (q *InsertQuery) appendSetExcluded(b []byte, fields []*schema.Field) []byte {
    517 	b = append(b, " SET "...)
    518 	for i, f := range fields {
    519 		if i > 0 {
    520 			b = append(b, ", "...)
    521 		}
    522 		b = append(b, f.SQLName...)
    523 		b = append(b, " = EXCLUDED."...)
    524 		b = append(b, f.SQLName...)
    525 	}
    526 	return b
    527 }
    528 
    529 func (q *InsertQuery) appendSetValues(b []byte, fields []*schema.Field) []byte {
    530 	b = append(b, " "...)
    531 	for i, f := range fields {
    532 		if i > 0 {
    533 			b = append(b, ", "...)
    534 		}
    535 		b = append(b, f.SQLName...)
    536 		b = append(b, " = VALUES("...)
    537 		b = append(b, f.SQLName...)
    538 		b = append(b, ")"...)
    539 	}
    540 	return b
    541 }
    542 
    543 //------------------------------------------------------------------------------
    544 
    545 func (q *InsertQuery) Scan(ctx context.Context, dest ...interface{}) error {
    546 	_, err := q.scanOrExec(ctx, dest, true)
    547 	return err
    548 }
    549 
    550 func (q *InsertQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
    551 	return q.scanOrExec(ctx, dest, len(dest) > 0)
    552 }
    553 
    554 func (q *InsertQuery) scanOrExec(
    555 	ctx context.Context, dest []interface{}, hasDest bool,
    556 ) (sql.Result, error) {
    557 	if q.err != nil {
    558 		return nil, q.err
    559 	}
    560 
    561 	if q.table != nil {
    562 		if err := q.beforeInsertHook(ctx); err != nil {
    563 			return nil, err
    564 		}
    565 	}
    566 
    567 	// Run append model hooks before generating the query.
    568 	if err := q.beforeAppendModel(ctx, q); err != nil {
    569 		return nil, err
    570 	}
    571 
    572 	// Generate the query before checking hasReturning.
    573 	queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
    574 	if err != nil {
    575 		return nil, err
    576 	}
    577 
    578 	useScan := hasDest || (q.hasReturning() && q.hasFeature(feature.InsertReturning|feature.Output))
    579 	var model Model
    580 
    581 	if useScan {
    582 		var err error
    583 		model, err = q.getModel(dest)
    584 		if err != nil {
    585 			return nil, err
    586 		}
    587 	}
    588 
    589 	query := internal.String(queryBytes)
    590 	var res sql.Result
    591 
    592 	if useScan {
    593 		res, err = q.scan(ctx, q, query, model, hasDest)
    594 		if err != nil {
    595 			return nil, err
    596 		}
    597 	} else {
    598 		res, err = q.exec(ctx, q, query)
    599 		if err != nil {
    600 			return nil, err
    601 		}
    602 
    603 		if err := q.tryLastInsertID(res, dest); err != nil {
    604 			return nil, err
    605 		}
    606 	}
    607 
    608 	if q.table != nil {
    609 		if err := q.afterInsertHook(ctx); err != nil {
    610 			return nil, err
    611 		}
    612 	}
    613 
    614 	return res, nil
    615 }
    616 
    617 func (q *InsertQuery) beforeInsertHook(ctx context.Context) error {
    618 	if hook, ok := q.table.ZeroIface.(BeforeInsertHook); ok {
    619 		if err := hook.BeforeInsert(ctx, q); err != nil {
    620 			return err
    621 		}
    622 	}
    623 	return nil
    624 }
    625 
    626 func (q *InsertQuery) afterInsertHook(ctx context.Context) error {
    627 	if hook, ok := q.table.ZeroIface.(AfterInsertHook); ok {
    628 		if err := hook.AfterInsert(ctx, q); err != nil {
    629 			return err
    630 		}
    631 	}
    632 	return nil
    633 }
    634 
    635 func (q *InsertQuery) tryLastInsertID(res sql.Result, dest []interface{}) error {
    636 	if q.db.features.Has(feature.Returning) ||
    637 		q.db.features.Has(feature.Output) ||
    638 		q.table == nil ||
    639 		len(q.table.PKs) != 1 ||
    640 		!q.table.PKs[0].AutoIncrement {
    641 		return nil
    642 	}
    643 
    644 	id, err := res.LastInsertId()
    645 	if err != nil {
    646 		return err
    647 	}
    648 	if id == 0 {
    649 		return nil
    650 	}
    651 
    652 	model, err := q.getModel(dest)
    653 	if err != nil {
    654 		return err
    655 	}
    656 
    657 	pk := q.table.PKs[0]
    658 	switch model := model.(type) {
    659 	case *structTableModel:
    660 		if err := pk.ScanValue(model.strct, id); err != nil {
    661 			return err
    662 		}
    663 	case *sliceTableModel:
    664 		sliceLen := model.slice.Len()
    665 		for i := 0; i < sliceLen; i++ {
    666 			strct := indirect(model.slice.Index(i))
    667 			if err := pk.ScanValue(strct, id); err != nil {
    668 				return err
    669 			}
    670 			id++
    671 		}
    672 	}
    673 
    674 	return nil
    675 }
    676 
    677 func (q *InsertQuery) String() string {
    678 	buf, err := q.AppendQuery(q.db.Formatter(), nil)
    679 	if err != nil {
    680 		panic(err)
    681 	}
    682 
    683 	return string(buf)
    684 }