gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

query_update.go (13579B)


      1 package bun
      2 
      3 import (
      4 	"context"
      5 	"database/sql"
      6 	"errors"
      7 	"fmt"
      8 
      9 	"github.com/uptrace/bun/dialect"
     10 
     11 	"github.com/uptrace/bun/dialect/feature"
     12 	"github.com/uptrace/bun/internal"
     13 	"github.com/uptrace/bun/schema"
     14 )
     15 
// UpdateQuery builds and executes an SQL UPDATE statement. Its behaviour is
// composed from the embedded helpers: WHERE handling and the base query state
// (whereBaseQuery), RETURNING/OUTPUT handling (returningQuery), per-column
// value overrides (customValueQuery), explicit SET expressions (setQuery) and
// index hints (idxHintsQuery).
type UpdateQuery struct {
	whereBaseQuery
	returningQuery
	customValueQuery
	setQuery
	idxHintsQuery

	// omitZero makes the struct SET generation skip fields that hold their
	// zero value, unless an explicit value override exists for them.
	// Enabled via OmitZero.
	omitZero bool
}

// Compile-time check that *UpdateQuery implements the Query interface.
var _ Query = (*UpdateQuery)(nil)
     27 
     28 func NewUpdateQuery(db *DB) *UpdateQuery {
     29 	q := &UpdateQuery{
     30 		whereBaseQuery: whereBaseQuery{
     31 			baseQuery: baseQuery{
     32 				db:   db,
     33 				conn: db.DB,
     34 			},
     35 		},
     36 	}
     37 	return q
     38 }
     39 
     40 func (q *UpdateQuery) Conn(db IConn) *UpdateQuery {
     41 	q.setConn(db)
     42 	return q
     43 }
     44 
     45 func (q *UpdateQuery) Model(model interface{}) *UpdateQuery {
     46 	q.setModel(model)
     47 	return q
     48 }
     49 
     50 func (q *UpdateQuery) Err(err error) *UpdateQuery {
     51 	q.setErr(err)
     52 	return q
     53 }
     54 
     55 // Apply calls the fn passing the SelectQuery as an argument.
     56 func (q *UpdateQuery) Apply(fn func(*UpdateQuery) *UpdateQuery) *UpdateQuery {
     57 	if fn != nil {
     58 		return fn(q)
     59 	}
     60 	return q
     61 }
     62 
     63 func (q *UpdateQuery) With(name string, query schema.QueryAppender) *UpdateQuery {
     64 	q.addWith(name, query, false)
     65 	return q
     66 }
     67 
     68 func (q *UpdateQuery) WithRecursive(name string, query schema.QueryAppender) *UpdateQuery {
     69 	q.addWith(name, query, true)
     70 	return q
     71 }
     72 
     73 //------------------------------------------------------------------------------
     74 
     75 func (q *UpdateQuery) Table(tables ...string) *UpdateQuery {
     76 	for _, table := range tables {
     77 		q.addTable(schema.UnsafeIdent(table))
     78 	}
     79 	return q
     80 }
     81 
     82 func (q *UpdateQuery) TableExpr(query string, args ...interface{}) *UpdateQuery {
     83 	q.addTable(schema.SafeQuery(query, args))
     84 	return q
     85 }
     86 
     87 func (q *UpdateQuery) ModelTableExpr(query string, args ...interface{}) *UpdateQuery {
     88 	q.modelTableName = schema.SafeQuery(query, args)
     89 	return q
     90 }
     91 
     92 //------------------------------------------------------------------------------
     93 
     94 func (q *UpdateQuery) Column(columns ...string) *UpdateQuery {
     95 	for _, column := range columns {
     96 		q.addColumn(schema.UnsafeIdent(column))
     97 	}
     98 	return q
     99 }
    100 
    101 func (q *UpdateQuery) ExcludeColumn(columns ...string) *UpdateQuery {
    102 	q.excludeColumn(columns)
    103 	return q
    104 }
    105 
    106 func (q *UpdateQuery) Set(query string, args ...interface{}) *UpdateQuery {
    107 	q.addSet(schema.SafeQuery(query, args))
    108 	return q
    109 }
    110 
    111 func (q *UpdateQuery) SetColumn(column string, query string, args ...interface{}) *UpdateQuery {
    112 	if q.db.HasFeature(feature.UpdateMultiTable) {
    113 		column = q.table.Alias + "." + column
    114 	}
    115 	q.addSet(schema.SafeQuery(column+" = "+query, args))
    116 	return q
    117 }
    118 
    119 // Value overwrites model value for the column.
    120 func (q *UpdateQuery) Value(column string, query string, args ...interface{}) *UpdateQuery {
    121 	if q.table == nil {
    122 		q.err = errNilModel
    123 		return q
    124 	}
    125 	q.addValue(q.table, column, query, args)
    126 	return q
    127 }
    128 
    129 func (q *UpdateQuery) OmitZero() *UpdateQuery {
    130 	q.omitZero = true
    131 	return q
    132 }
    133 
    134 //------------------------------------------------------------------------------
    135 
    136 func (q *UpdateQuery) WherePK(cols ...string) *UpdateQuery {
    137 	q.addWhereCols(cols)
    138 	return q
    139 }
    140 
    141 func (q *UpdateQuery) Where(query string, args ...interface{}) *UpdateQuery {
    142 	q.addWhere(schema.SafeQueryWithSep(query, args, " AND "))
    143 	return q
    144 }
    145 
    146 func (q *UpdateQuery) WhereOr(query string, args ...interface{}) *UpdateQuery {
    147 	q.addWhere(schema.SafeQueryWithSep(query, args, " OR "))
    148 	return q
    149 }
    150 
    151 func (q *UpdateQuery) WhereGroup(sep string, fn func(*UpdateQuery) *UpdateQuery) *UpdateQuery {
    152 	saved := q.where
    153 	q.where = nil
    154 
    155 	q = fn(q)
    156 
    157 	where := q.where
    158 	q.where = saved
    159 
    160 	q.addWhereGroup(sep, where)
    161 
    162 	return q
    163 }
    164 
    165 func (q *UpdateQuery) WhereDeleted() *UpdateQuery {
    166 	q.whereDeleted()
    167 	return q
    168 }
    169 
    170 func (q *UpdateQuery) WhereAllWithDeleted() *UpdateQuery {
    171 	q.whereAllWithDeleted()
    172 	return q
    173 }
    174 
    175 //------------------------------------------------------------------------------
    176 
    177 // Returning adds a RETURNING clause to the query.
    178 //
    179 // To suppress the auto-generated RETURNING clause, use `Returning("NULL")`.
    180 func (q *UpdateQuery) Returning(query string, args ...interface{}) *UpdateQuery {
    181 	q.addReturning(schema.SafeQuery(query, args))
    182 	return q
    183 }
    184 
    185 //------------------------------------------------------------------------------
    186 
    187 func (q *UpdateQuery) Operation() string {
    188 	return "UPDATE"
    189 }
    190 
    191 func (q *UpdateQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) {
    192 	if q.err != nil {
    193 		return nil, q.err
    194 	}
    195 
    196 	fmter = formatterWithModel(fmter, q)
    197 
    198 	b, err = q.appendWith(fmter, b)
    199 	if err != nil {
    200 		return nil, err
    201 	}
    202 
    203 	b = append(b, "UPDATE "...)
    204 
    205 	if fmter.HasFeature(feature.UpdateMultiTable) {
    206 		b, err = q.appendTablesWithAlias(fmter, b)
    207 	} else if fmter.HasFeature(feature.UpdateTableAlias) {
    208 		b, err = q.appendFirstTableWithAlias(fmter, b)
    209 	} else {
    210 		b, err = q.appendFirstTable(fmter, b)
    211 	}
    212 	if err != nil {
    213 		return nil, err
    214 	}
    215 
    216 	b, err = q.appendIndexHints(fmter, b)
    217 	if err != nil {
    218 		return nil, err
    219 	}
    220 
    221 	b, err = q.mustAppendSet(fmter, b)
    222 	if err != nil {
    223 		return nil, err
    224 	}
    225 
    226 	if !fmter.HasFeature(feature.UpdateMultiTable) {
    227 		b, err = q.appendOtherTables(fmter, b)
    228 		if err != nil {
    229 			return nil, err
    230 		}
    231 	}
    232 
    233 	if q.hasFeature(feature.Output) && q.hasReturning() {
    234 		b = append(b, " OUTPUT "...)
    235 		b, err = q.appendOutput(fmter, b)
    236 		if err != nil {
    237 			return nil, err
    238 		}
    239 	}
    240 
    241 	b, err = q.mustAppendWhere(fmter, b, q.hasTableAlias(fmter))
    242 	if err != nil {
    243 		return nil, err
    244 	}
    245 
    246 	if q.hasFeature(feature.Returning) && q.hasReturning() {
    247 		b = append(b, " RETURNING "...)
    248 		b, err = q.appendReturning(fmter, b)
    249 		if err != nil {
    250 			return nil, err
    251 		}
    252 	}
    253 
    254 	return b, nil
    255 }
    256 
    257 func (q *UpdateQuery) mustAppendSet(fmter schema.Formatter, b []byte) (_ []byte, err error) {
    258 	b = append(b, " SET "...)
    259 
    260 	if len(q.set) > 0 {
    261 		return q.appendSet(fmter, b)
    262 	}
    263 
    264 	if m, ok := q.model.(*mapModel); ok {
    265 		return m.appendSet(fmter, b), nil
    266 	}
    267 
    268 	if q.tableModel == nil {
    269 		return nil, errNilModel
    270 	}
    271 
    272 	switch model := q.tableModel.(type) {
    273 	case *structTableModel:
    274 		b, err = q.appendSetStruct(fmter, b, model)
    275 		if err != nil {
    276 			return nil, err
    277 		}
    278 	case *sliceTableModel:
    279 		return nil, errors.New("bun: to bulk Update, use CTE and VALUES")
    280 	default:
    281 		return nil, fmt.Errorf("bun: Update does not support %T", q.tableModel)
    282 	}
    283 
    284 	return b, nil
    285 }
    286 
    287 func (q *UpdateQuery) appendSetStruct(
    288 	fmter schema.Formatter, b []byte, model *structTableModel,
    289 ) ([]byte, error) {
    290 	fields, err := q.getDataFields()
    291 	if err != nil {
    292 		return nil, err
    293 	}
    294 
    295 	isTemplate := fmter.IsNop()
    296 	pos := len(b)
    297 	for _, f := range fields {
    298 		if f.SkipUpdate() {
    299 			continue
    300 		}
    301 
    302 		app, hasValue := q.modelValues[f.Name]
    303 
    304 		if !hasValue && q.omitZero && f.HasZeroValue(model.strct) {
    305 			continue
    306 		}
    307 
    308 		if len(b) != pos {
    309 			b = append(b, ", "...)
    310 			pos = len(b)
    311 		}
    312 
    313 		b = append(b, f.SQLName...)
    314 		b = append(b, " = "...)
    315 
    316 		if isTemplate {
    317 			b = append(b, '?')
    318 			continue
    319 		}
    320 
    321 		if hasValue {
    322 			b, err = app.AppendQuery(fmter, b)
    323 			if err != nil {
    324 				return nil, err
    325 			}
    326 		} else {
    327 			b = f.AppendValue(fmter, b, model.strct)
    328 		}
    329 	}
    330 
    331 	for i, v := range q.extraValues {
    332 		if i > 0 || len(fields) > 0 {
    333 			b = append(b, ", "...)
    334 		}
    335 
    336 		b = append(b, v.column...)
    337 		b = append(b, " = "...)
    338 
    339 		b, err = v.value.AppendQuery(fmter, b)
    340 		if err != nil {
    341 			return nil, err
    342 		}
    343 	}
    344 
    345 	return b, nil
    346 }
    347 
    348 func (q *UpdateQuery) appendOtherTables(fmter schema.Formatter, b []byte) (_ []byte, err error) {
    349 	if !q.hasMultiTables() {
    350 		return b, nil
    351 	}
    352 
    353 	b = append(b, " FROM "...)
    354 
    355 	b, err = q.whereBaseQuery.appendOtherTables(fmter, b)
    356 	if err != nil {
    357 		return nil, err
    358 	}
    359 
    360 	return b, nil
    361 }
    362 
    363 //------------------------------------------------------------------------------
    364 
    365 func (q *UpdateQuery) Bulk() *UpdateQuery {
    366 	model, ok := q.model.(*sliceTableModel)
    367 	if !ok {
    368 		q.setErr(fmt.Errorf("bun: Bulk requires a slice, got %T", q.model))
    369 		return q
    370 	}
    371 
    372 	set, err := q.updateSliceSet(q.db.fmter, model)
    373 	if err != nil {
    374 		q.setErr(err)
    375 		return q
    376 	}
    377 
    378 	values := q.db.NewValues(model)
    379 	values.customValueQuery = q.customValueQuery
    380 
    381 	return q.With("_data", values).
    382 		Model(model).
    383 		TableExpr("_data").
    384 		Set(set).
    385 		Where(q.updateSliceWhere(q.db.fmter, model))
    386 }
    387 
    388 func (q *UpdateQuery) updateSliceSet(
    389 	fmter schema.Formatter, model *sliceTableModel,
    390 ) (string, error) {
    391 	fields, err := q.getDataFields()
    392 	if err != nil {
    393 		return "", err
    394 	}
    395 
    396 	var b []byte
    397 	pos := len(b)
    398 	for _, field := range fields {
    399 		if field.SkipUpdate() {
    400 			continue
    401 		}
    402 		if len(b) != pos {
    403 			b = append(b, ", "...)
    404 			pos = len(b)
    405 		}
    406 		if fmter.HasFeature(feature.UpdateMultiTable) {
    407 			b = append(b, model.table.SQLAlias...)
    408 			b = append(b, '.')
    409 		}
    410 		b = append(b, field.SQLName...)
    411 		b = append(b, " = _data."...)
    412 		b = append(b, field.SQLName...)
    413 	}
    414 	return internal.String(b), nil
    415 }
    416 
    417 func (q *UpdateQuery) updateSliceWhere(fmter schema.Formatter, model *sliceTableModel) string {
    418 	var b []byte
    419 	for i, pk := range model.table.PKs {
    420 		if i > 0 {
    421 			b = append(b, " AND "...)
    422 		}
    423 		if q.hasTableAlias(fmter) {
    424 			b = append(b, model.table.SQLAlias...)
    425 		} else {
    426 			b = append(b, model.table.SQLName...)
    427 		}
    428 		b = append(b, '.')
    429 		b = append(b, pk.SQLName...)
    430 		b = append(b, " = _data."...)
    431 		b = append(b, pk.SQLName...)
    432 	}
    433 	return internal.String(b)
    434 }
    435 
    436 //------------------------------------------------------------------------------
    437 
    438 func (q *UpdateQuery) Scan(ctx context.Context, dest ...interface{}) error {
    439 	_, err := q.scanOrExec(ctx, dest, true)
    440 	return err
    441 }
    442 
    443 func (q *UpdateQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
    444 	return q.scanOrExec(ctx, dest, len(dest) > 0)
    445 }
    446 
    447 func (q *UpdateQuery) scanOrExec(
    448 	ctx context.Context, dest []interface{}, hasDest bool,
    449 ) (sql.Result, error) {
    450 	if q.err != nil {
    451 		return nil, q.err
    452 	}
    453 
    454 	if q.table != nil {
    455 		if err := q.beforeUpdateHook(ctx); err != nil {
    456 			return nil, err
    457 		}
    458 	}
    459 
    460 	// Run append model hooks before generating the query.
    461 	if err := q.beforeAppendModel(ctx, q); err != nil {
    462 		return nil, err
    463 	}
    464 
    465 	// Generate the query before checking hasReturning.
    466 	queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
    467 	if err != nil {
    468 		return nil, err
    469 	}
    470 
    471 	useScan := hasDest || (q.hasReturning() && q.hasFeature(feature.Returning|feature.Output))
    472 	var model Model
    473 
    474 	if useScan {
    475 		var err error
    476 		model, err = q.getModel(dest)
    477 		if err != nil {
    478 			return nil, err
    479 		}
    480 	}
    481 
    482 	query := internal.String(queryBytes)
    483 
    484 	var res sql.Result
    485 
    486 	if useScan {
    487 		res, err = q.scan(ctx, q, query, model, hasDest)
    488 		if err != nil {
    489 			return nil, err
    490 		}
    491 	} else {
    492 		res, err = q.exec(ctx, q, query)
    493 		if err != nil {
    494 			return nil, err
    495 		}
    496 	}
    497 
    498 	if q.table != nil {
    499 		if err := q.afterUpdateHook(ctx); err != nil {
    500 			return nil, err
    501 		}
    502 	}
    503 
    504 	return res, nil
    505 }
    506 
    507 func (q *UpdateQuery) beforeUpdateHook(ctx context.Context) error {
    508 	if hook, ok := q.table.ZeroIface.(BeforeUpdateHook); ok {
    509 		if err := hook.BeforeUpdate(ctx, q); err != nil {
    510 			return err
    511 		}
    512 	}
    513 	return nil
    514 }
    515 
    516 func (q *UpdateQuery) afterUpdateHook(ctx context.Context) error {
    517 	if hook, ok := q.table.ZeroIface.(AfterUpdateHook); ok {
    518 		if err := hook.AfterUpdate(ctx, q); err != nil {
    519 			return err
    520 		}
    521 	}
    522 	return nil
    523 }
    524 
    525 // FQN returns a fully qualified column name, for example, table_name.column_name or
    526 // table_alias.column_alias.
    527 func (q *UpdateQuery) FQN(column string) Ident {
    528 	if q.table == nil {
    529 		panic("UpdateQuery.FQN requires a model")
    530 	}
    531 	if q.hasTableAlias(q.db.fmter) {
    532 		return Ident(q.table.Alias + "." + column)
    533 	}
    534 	return Ident(q.table.Name + "." + column)
    535 }
    536 
    537 func (q *UpdateQuery) hasTableAlias(fmter schema.Formatter) bool {
    538 	return fmter.HasFeature(feature.UpdateMultiTable | feature.UpdateTableAlias)
    539 }
    540 
    541 func (q *UpdateQuery) String() string {
    542 	buf, err := q.AppendQuery(q.db.Formatter(), nil)
    543 	if err != nil {
    544 		panic(err)
    545 	}
    546 
    547 	return string(buf)
    548 }
    549 
    550 //------------------------------------------------------------------------------
    551 
    552 func (q *UpdateQuery) QueryBuilder() QueryBuilder {
    553 	return &updateQueryBuilder{q}
    554 }
    555 
    556 func (q *UpdateQuery) ApplyQueryBuilder(fn func(QueryBuilder) QueryBuilder) *UpdateQuery {
    557 	return fn(q.QueryBuilder()).Unwrap().(*UpdateQuery)
    558 }
    559 
    560 type updateQueryBuilder struct {
    561 	*UpdateQuery
    562 }
    563 
    564 func (q *updateQueryBuilder) WhereGroup(
    565 	sep string, fn func(QueryBuilder) QueryBuilder,
    566 ) QueryBuilder {
    567 	q.UpdateQuery = q.UpdateQuery.WhereGroup(sep, func(qs *UpdateQuery) *UpdateQuery {
    568 		return fn(q).(*updateQueryBuilder).UpdateQuery
    569 	})
    570 	return q
    571 }
    572 
    573 func (q *updateQueryBuilder) Where(query string, args ...interface{}) QueryBuilder {
    574 	q.UpdateQuery.Where(query, args...)
    575 	return q
    576 }
    577 
    578 func (q *updateQueryBuilder) WhereOr(query string, args ...interface{}) QueryBuilder {
    579 	q.UpdateQuery.WhereOr(query, args...)
    580 	return q
    581 }
    582 
    583 func (q *updateQueryBuilder) WhereDeleted() QueryBuilder {
    584 	q.UpdateQuery.WhereDeleted()
    585 	return q
    586 }
    587 
    588 func (q *updateQueryBuilder) WhereAllWithDeleted() QueryBuilder {
    589 	q.UpdateQuery.WhereAllWithDeleted()
    590 	return q
    591 }
    592 
    593 func (q *updateQueryBuilder) WherePK(cols ...string) QueryBuilder {
    594 	q.UpdateQuery.WherePK(cols...)
    595 	return q
    596 }
    597 
    598 func (q *updateQueryBuilder) Unwrap() interface{} {
    599 	return q.UpdateQuery
    600 }
    601 
    602 //------------------------------------------------------------------------------
    603 
    604 func (q *UpdateQuery) UseIndex(indexes ...string) *UpdateQuery {
    605 	if q.db.dialect.Name() == dialect.MySQL {
    606 		q.addUseIndex(indexes...)
    607 	}
    608 	return q
    609 }
    610 
    611 func (q *UpdateQuery) IgnoreIndex(indexes ...string) *UpdateQuery {
    612 	if q.db.dialect.Name() == dialect.MySQL {
    613 		q.addIgnoreIndex(indexes...)
    614 	}
    615 	return q
    616 }
    617 
    618 func (q *UpdateQuery) ForceIndex(indexes ...string) *UpdateQuery {
    619 	if q.db.dialect.Name() == dialect.MySQL {
    620 		q.addForceIndex(indexes...)
    621 	}
    622 	return q
    623 }