query_update.go (13579B)
1 package bun 2 3 import ( 4 "context" 5 "database/sql" 6 "errors" 7 "fmt" 8 9 "github.com/uptrace/bun/dialect" 10 11 "github.com/uptrace/bun/dialect/feature" 12 "github.com/uptrace/bun/internal" 13 "github.com/uptrace/bun/schema" 14 ) 15 16 type UpdateQuery struct { 17 whereBaseQuery 18 returningQuery 19 customValueQuery 20 setQuery 21 idxHintsQuery 22 23 omitZero bool 24 } 25 26 var _ Query = (*UpdateQuery)(nil) 27 28 func NewUpdateQuery(db *DB) *UpdateQuery { 29 q := &UpdateQuery{ 30 whereBaseQuery: whereBaseQuery{ 31 baseQuery: baseQuery{ 32 db: db, 33 conn: db.DB, 34 }, 35 }, 36 } 37 return q 38 } 39 40 func (q *UpdateQuery) Conn(db IConn) *UpdateQuery { 41 q.setConn(db) 42 return q 43 } 44 45 func (q *UpdateQuery) Model(model interface{}) *UpdateQuery { 46 q.setModel(model) 47 return q 48 } 49 50 func (q *UpdateQuery) Err(err error) *UpdateQuery { 51 q.setErr(err) 52 return q 53 } 54 55 // Apply calls the fn passing the SelectQuery as an argument. 56 func (q *UpdateQuery) Apply(fn func(*UpdateQuery) *UpdateQuery) *UpdateQuery { 57 if fn != nil { 58 return fn(q) 59 } 60 return q 61 } 62 63 func (q *UpdateQuery) With(name string, query schema.QueryAppender) *UpdateQuery { 64 q.addWith(name, query, false) 65 return q 66 } 67 68 func (q *UpdateQuery) WithRecursive(name string, query schema.QueryAppender) *UpdateQuery { 69 q.addWith(name, query, true) 70 return q 71 } 72 73 //------------------------------------------------------------------------------ 74 75 func (q *UpdateQuery) Table(tables ...string) *UpdateQuery { 76 for _, table := range tables { 77 q.addTable(schema.UnsafeIdent(table)) 78 } 79 return q 80 } 81 82 func (q *UpdateQuery) TableExpr(query string, args ...interface{}) *UpdateQuery { 83 q.addTable(schema.SafeQuery(query, args)) 84 return q 85 } 86 87 func (q *UpdateQuery) ModelTableExpr(query string, args ...interface{}) *UpdateQuery { 88 q.modelTableName = schema.SafeQuery(query, args) 89 return q 90 } 91 92 //------------------------------------------------------------------------------ 93 94 func (q *UpdateQuery) Column(columns ...string) *UpdateQuery { 95 for _, column := range columns { 96 q.addColumn(schema.UnsafeIdent(column)) 97 } 98 return q 99 } 100 101 func (q *UpdateQuery) ExcludeColumn(columns ...string) *UpdateQuery { 102 q.excludeColumn(columns) 103 return q 104 } 105 106 func (q *UpdateQuery) Set(query string, args ...interface{}) *UpdateQuery { 107 q.addSet(schema.SafeQuery(query, args)) 108 return q 109 } 110 111 func (q *UpdateQuery) SetColumn(column string, query string, args ...interface{}) *UpdateQuery { 112 if q.db.HasFeature(feature.UpdateMultiTable) { 113 column = q.table.Alias + "." + column 114 } 115 q.addSet(schema.SafeQuery(column+" = "+query, args)) 116 return q 117 } 118 119 // Value overwrites model value for the column. 120 func (q *UpdateQuery) Value(column string, query string, args ...interface{}) *UpdateQuery { 121 if q.table == nil { 122 q.err = errNilModel 123 return q 124 } 125 q.addValue(q.table, column, query, args) 126 return q 127 } 128 129 func (q *UpdateQuery) OmitZero() *UpdateQuery { 130 q.omitZero = true 131 return q 132 } 133 134 //------------------------------------------------------------------------------ 135 136 func (q *UpdateQuery) WherePK(cols ...string) *UpdateQuery { 137 q.addWhereCols(cols) 138 return q 139 } 140 141 func (q *UpdateQuery) Where(query string, args ...interface{}) *UpdateQuery { 142 q.addWhere(schema.SafeQueryWithSep(query, args, " AND ")) 143 return q 144 } 145 146 func (q *UpdateQuery) WhereOr(query string, args ...interface{}) *UpdateQuery { 147 q.addWhere(schema.SafeQueryWithSep(query, args, " OR ")) 148 return q 149 } 150 151 func (q *UpdateQuery) WhereGroup(sep string, fn func(*UpdateQuery) *UpdateQuery) *UpdateQuery { 152 saved := q.where 153 q.where = nil 154 155 q = fn(q) 156 157 where := q.where 158 q.where = saved 159 160 q.addWhereGroup(sep, where) 161 162 return q 163 } 164 165 func (q *UpdateQuery) WhereDeleted() *UpdateQuery { 166 q.whereDeleted() 167 return q 168 } 169 170 func (q *UpdateQuery) WhereAllWithDeleted() *UpdateQuery { 171 q.whereAllWithDeleted() 172 return q 173 } 174 175 //------------------------------------------------------------------------------ 176 177 // Returning adds a RETURNING clause to the query. 178 // 179 // To suppress the auto-generated RETURNING clause, use `Returning("NULL")`. 180 func (q *UpdateQuery) Returning(query string, args ...interface{}) *UpdateQuery { 181 q.addReturning(schema.SafeQuery(query, args)) 182 return q 183 } 184 185 //------------------------------------------------------------------------------ 186 187 func (q *UpdateQuery) Operation() string { 188 return "UPDATE" 189 } 190 191 func (q *UpdateQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { 192 if q.err != nil { 193 return nil, q.err 194 } 195 196 fmter = formatterWithModel(fmter, q) 197 198 b, err = q.appendWith(fmter, b) 199 if err != nil { 200 return nil, err 201 } 202 203 b = append(b, "UPDATE "...) 204 205 if fmter.HasFeature(feature.UpdateMultiTable) { 206 b, err = q.appendTablesWithAlias(fmter, b) 207 } else if fmter.HasFeature(feature.UpdateTableAlias) { 208 b, err = q.appendFirstTableWithAlias(fmter, b) 209 } else { 210 b, err = q.appendFirstTable(fmter, b) 211 } 212 if err != nil { 213 return nil, err 214 } 215 216 b, err = q.appendIndexHints(fmter, b) 217 if err != nil { 218 return nil, err 219 } 220 221 b, err = q.mustAppendSet(fmter, b) 222 if err != nil { 223 return nil, err 224 } 225 226 if !fmter.HasFeature(feature.UpdateMultiTable) { 227 b, err = q.appendOtherTables(fmter, b) 228 if err != nil { 229 return nil, err 230 } 231 } 232 233 if q.hasFeature(feature.Output) && q.hasReturning() { 234 b = append(b, " OUTPUT "...) 235 b, err = q.appendOutput(fmter, b) 236 if err != nil { 237 return nil, err 238 } 239 } 240 241 b, err = q.mustAppendWhere(fmter, b, q.hasTableAlias(fmter)) 242 if err != nil { 243 return nil, err 244 } 245 246 if q.hasFeature(feature.Returning) && q.hasReturning() { 247 b = append(b, " RETURNING "...) 248 b, err = q.appendReturning(fmter, b) 249 if err != nil { 250 return nil, err 251 } 252 } 253 254 return b, nil 255 } 256 257 func (q *UpdateQuery) mustAppendSet(fmter schema.Formatter, b []byte) (_ []byte, err error) { 258 b = append(b, " SET "...) 259 260 if len(q.set) > 0 { 261 return q.appendSet(fmter, b) 262 } 263 264 if m, ok := q.model.(*mapModel); ok { 265 return m.appendSet(fmter, b), nil 266 } 267 268 if q.tableModel == nil { 269 return nil, errNilModel 270 } 271 272 switch model := q.tableModel.(type) { 273 case *structTableModel: 274 b, err = q.appendSetStruct(fmter, b, model) 275 if err != nil { 276 return nil, err 277 } 278 case *sliceTableModel: 279 return nil, errors.New("bun: to bulk Update, use CTE and VALUES") 280 default: 281 return nil, fmt.Errorf("bun: Update does not support %T", q.tableModel) 282 } 283 284 return b, nil 285 } 286 287 func (q *UpdateQuery) appendSetStruct( 288 fmter schema.Formatter, b []byte, model *structTableModel, 289 ) ([]byte, error) { 290 fields, err := q.getDataFields() 291 if err != nil { 292 return nil, err 293 } 294 295 isTemplate := fmter.IsNop() 296 pos := len(b) 297 for _, f := range fields { 298 if f.SkipUpdate() { 299 continue 300 } 301 302 app, hasValue := q.modelValues[f.Name] 303 304 if !hasValue && q.omitZero && f.HasZeroValue(model.strct) { 305 continue 306 } 307 308 if len(b) != pos { 309 b = append(b, ", "...) 310 pos = len(b) 311 } 312 313 b = append(b, f.SQLName...) 314 b = append(b, " = "...) 315 316 if isTemplate { 317 b = append(b, '?') 318 continue 319 } 320 321 if hasValue { 322 b, err = app.AppendQuery(fmter, b) 323 if err != nil { 324 return nil, err 325 } 326 } else { 327 b = f.AppendValue(fmter, b, model.strct) 328 } 329 } 330 331 for i, v := range q.extraValues { 332 if i > 0 || len(fields) > 0 { 333 b = append(b, ", "...) 334 } 335 336 b = append(b, v.column...) 337 b = append(b, " = "...) 338 339 b, err = v.value.AppendQuery(fmter, b) 340 if err != nil { 341 return nil, err 342 } 343 } 344 345 return b, nil 346 } 347 348 func (q *UpdateQuery) appendOtherTables(fmter schema.Formatter, b []byte) (_ []byte, err error) { 349 if !q.hasMultiTables() { 350 return b, nil 351 } 352 353 b = append(b, " FROM "...) 354 355 b, err = q.whereBaseQuery.appendOtherTables(fmter, b) 356 if err != nil { 357 return nil, err 358 } 359 360 return b, nil 361 } 362 363 //------------------------------------------------------------------------------ 364 365 func (q *UpdateQuery) Bulk() *UpdateQuery { 366 model, ok := q.model.(*sliceTableModel) 367 if !ok { 368 q.setErr(fmt.Errorf("bun: Bulk requires a slice, got %T", q.model)) 369 return q 370 } 371 372 set, err := q.updateSliceSet(q.db.fmter, model) 373 if err != nil { 374 q.setErr(err) 375 return q 376 } 377 378 values := q.db.NewValues(model) 379 values.customValueQuery = q.customValueQuery 380 381 return q.With("_data", values). 382 Model(model). 383 TableExpr("_data"). 384 Set(set). 385 Where(q.updateSliceWhere(q.db.fmter, model)) 386 } 387 388 func (q *UpdateQuery) updateSliceSet( 389 fmter schema.Formatter, model *sliceTableModel, 390 ) (string, error) { 391 fields, err := q.getDataFields() 392 if err != nil { 393 return "", err 394 } 395 396 var b []byte 397 pos := len(b) 398 for _, field := range fields { 399 if field.SkipUpdate() { 400 continue 401 } 402 if len(b) != pos { 403 b = append(b, ", "...) 404 pos = len(b) 405 } 406 if fmter.HasFeature(feature.UpdateMultiTable) { 407 b = append(b, model.table.SQLAlias...) 408 b = append(b, '.') 409 } 410 b = append(b, field.SQLName...) 411 b = append(b, " = _data."...) 412 b = append(b, field.SQLName...) 413 } 414 return internal.String(b), nil 415 } 416 417 func (q *UpdateQuery) updateSliceWhere(fmter schema.Formatter, model *sliceTableModel) string { 418 var b []byte 419 for i, pk := range model.table.PKs { 420 if i > 0 { 421 b = append(b, " AND "...) 422 } 423 if q.hasTableAlias(fmter) { 424 b = append(b, model.table.SQLAlias...) 425 } else { 426 b = append(b, model.table.SQLName...) 427 } 428 b = append(b, '.') 429 b = append(b, pk.SQLName...) 430 b = append(b, " = _data."...) 431 b = append(b, pk.SQLName...) 432 } 433 return internal.String(b) 434 } 435 436 //------------------------------------------------------------------------------ 437 438 func (q *UpdateQuery) Scan(ctx context.Context, dest ...interface{}) error { 439 _, err := q.scanOrExec(ctx, dest, true) 440 return err 441 } 442 443 func (q *UpdateQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { 444 return q.scanOrExec(ctx, dest, len(dest) > 0) 445 } 446 447 func (q *UpdateQuery) scanOrExec( 448 ctx context.Context, dest []interface{}, hasDest bool, 449 ) (sql.Result, error) { 450 if q.err != nil { 451 return nil, q.err 452 } 453 454 if q.table != nil { 455 if err := q.beforeUpdateHook(ctx); err != nil { 456 return nil, err 457 } 458 } 459 460 // Run append model hooks before generating the query. 461 if err := q.beforeAppendModel(ctx, q); err != nil { 462 return nil, err 463 } 464 465 // Generate the query before checking hasReturning. 466 queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) 467 if err != nil { 468 return nil, err 469 } 470 471 useScan := hasDest || (q.hasReturning() && q.hasFeature(feature.Returning|feature.Output)) 472 var model Model 473 474 if useScan { 475 var err error 476 model, err = q.getModel(dest) 477 if err != nil { 478 return nil, err 479 } 480 } 481 482 query := internal.String(queryBytes) 483 484 var res sql.Result 485 486 if useScan { 487 res, err = q.scan(ctx, q, query, model, hasDest) 488 if err != nil { 489 return nil, err 490 } 491 } else { 492 res, err = q.exec(ctx, q, query) 493 if err != nil { 494 return nil, err 495 } 496 } 497 498 if q.table != nil { 499 if err := q.afterUpdateHook(ctx); err != nil { 500 return nil, err 501 } 502 } 503 504 return res, nil 505 } 506 507 func (q *UpdateQuery) beforeUpdateHook(ctx context.Context) error { 508 if hook, ok := q.table.ZeroIface.(BeforeUpdateHook); ok { 509 if err := hook.BeforeUpdate(ctx, q); err != nil { 510 return err 511 } 512 } 513 return nil 514 } 515 516 func (q *UpdateQuery) afterUpdateHook(ctx context.Context) error { 517 if hook, ok := q.table.ZeroIface.(AfterUpdateHook); ok { 518 if err := hook.AfterUpdate(ctx, q); err != nil { 519 return err 520 } 521 } 522 return nil 523 } 524 525 // FQN returns a fully qualified column name, for example, table_name.column_name or 526 // table_alias.column_alias. 527 func (q *UpdateQuery) FQN(column string) Ident { 528 if q.table == nil { 529 panic("UpdateQuery.FQN requires a model") 530 } 531 if q.hasTableAlias(q.db.fmter) { 532 return Ident(q.table.Alias + "." + column) 533 } 534 return Ident(q.table.Name + "." + column) 535 } 536 537 func (q *UpdateQuery) hasTableAlias(fmter schema.Formatter) bool { 538 return fmter.HasFeature(feature.UpdateMultiTable | feature.UpdateTableAlias) 539 } 540 541 func (q *UpdateQuery) String() string { 542 buf, err := q.AppendQuery(q.db.Formatter(), nil) 543 if err != nil { 544 panic(err) 545 } 546 547 return string(buf) 548 } 549 550 //------------------------------------------------------------------------------ 551 552 func (q *UpdateQuery) QueryBuilder() QueryBuilder { 553 return &updateQueryBuilder{q} 554 } 555 556 func (q *UpdateQuery) ApplyQueryBuilder(fn func(QueryBuilder) QueryBuilder) *UpdateQuery { 557 return fn(q.QueryBuilder()).Unwrap().(*UpdateQuery) 558 } 559 560 type updateQueryBuilder struct { 561 *UpdateQuery 562 } 563 564 func (q *updateQueryBuilder) WhereGroup( 565 sep string, fn func(QueryBuilder) QueryBuilder, 566 ) QueryBuilder { 567 q.UpdateQuery = q.UpdateQuery.WhereGroup(sep, func(qs *UpdateQuery) *UpdateQuery { 568 return fn(q).(*updateQueryBuilder).UpdateQuery 569 }) 570 return q 571 } 572 573 func (q *updateQueryBuilder) Where(query string, args ...interface{}) QueryBuilder { 574 q.UpdateQuery.Where(query, args...) 575 return q 576 } 577 578 func (q *updateQueryBuilder) WhereOr(query string, args ...interface{}) QueryBuilder { 579 q.UpdateQuery.WhereOr(query, args...) 580 return q 581 } 582 583 func (q *updateQueryBuilder) WhereDeleted() QueryBuilder { 584 q.UpdateQuery.WhereDeleted() 585 return q 586 } 587 588 func (q *updateQueryBuilder) WhereAllWithDeleted() QueryBuilder { 589 q.UpdateQuery.WhereAllWithDeleted() 590 return q 591 } 592 593 func (q *updateQueryBuilder) WherePK(cols ...string) QueryBuilder { 594 q.UpdateQuery.WherePK(cols...) 595 return q 596 } 597 598 func (q *updateQueryBuilder) Unwrap() interface{} { 599 return q.UpdateQuery 600 } 601 602 //------------------------------------------------------------------------------ 603 604 func (q *UpdateQuery) UseIndex(indexes ...string) *UpdateQuery { 605 if q.db.dialect.Name() == dialect.MySQL { 606 q.addUseIndex(indexes...) 607 } 608 return q 609 } 610 611 func (q *UpdateQuery) IgnoreIndex(indexes ...string) *UpdateQuery { 612 if q.db.dialect.Name() == dialect.MySQL { 613 q.addIgnoreIndex(indexes...) 614 } 615 return q 616 } 617 618 func (q *UpdateQuery) ForceIndex(indexes ...string) *UpdateQuery { 619 if q.db.dialect.Name() == dialect.MySQL { 620 q.addForceIndex(indexes...) 621 } 622 return q 623 }