query_insert.go (14778B)
1 package bun 2 3 import ( 4 "context" 5 "database/sql" 6 "fmt" 7 "reflect" 8 "strings" 9 10 "github.com/uptrace/bun/dialect/feature" 11 "github.com/uptrace/bun/internal" 12 "github.com/uptrace/bun/schema" 13 ) 14 15 type InsertQuery struct { 16 whereBaseQuery 17 returningQuery 18 customValueQuery 19 20 on schema.QueryWithArgs 21 setQuery 22 23 ignore bool 24 replace bool 25 } 26 27 var _ Query = (*InsertQuery)(nil) 28 29 func NewInsertQuery(db *DB) *InsertQuery { 30 q := &InsertQuery{ 31 whereBaseQuery: whereBaseQuery{ 32 baseQuery: baseQuery{ 33 db: db, 34 conn: db.DB, 35 }, 36 }, 37 } 38 return q 39 } 40 41 func (q *InsertQuery) Conn(db IConn) *InsertQuery { 42 q.setConn(db) 43 return q 44 } 45 46 func (q *InsertQuery) Model(model interface{}) *InsertQuery { 47 q.setModel(model) 48 return q 49 } 50 51 func (q *InsertQuery) Err(err error) *InsertQuery { 52 q.setErr(err) 53 return q 54 } 55 56 // Apply calls the fn passing the SelectQuery as an argument. 57 func (q *InsertQuery) Apply(fn func(*InsertQuery) *InsertQuery) *InsertQuery { 58 if fn != nil { 59 return fn(q) 60 } 61 return q 62 } 63 64 func (q *InsertQuery) With(name string, query schema.QueryAppender) *InsertQuery { 65 q.addWith(name, query, false) 66 return q 67 } 68 69 func (q *InsertQuery) WithRecursive(name string, query schema.QueryAppender) *InsertQuery { 70 q.addWith(name, query, true) 71 return q 72 } 73 74 //------------------------------------------------------------------------------ 75 76 func (q *InsertQuery) Table(tables ...string) *InsertQuery { 77 for _, table := range tables { 78 q.addTable(schema.UnsafeIdent(table)) 79 } 80 return q 81 } 82 83 func (q *InsertQuery) TableExpr(query string, args ...interface{}) *InsertQuery { 84 q.addTable(schema.SafeQuery(query, args)) 85 return q 86 } 87 88 func (q *InsertQuery) ModelTableExpr(query string, args ...interface{}) *InsertQuery { 89 q.modelTableName = schema.SafeQuery(query, args) 90 return q 91 } 92 93 //------------------------------------------------------------------------------ 94 95 func (q *InsertQuery) Column(columns ...string) *InsertQuery { 96 for _, column := range columns { 97 q.addColumn(schema.UnsafeIdent(column)) 98 } 99 return q 100 } 101 102 func (q *InsertQuery) ColumnExpr(query string, args ...interface{}) *InsertQuery { 103 q.addColumn(schema.SafeQuery(query, args)) 104 return q 105 } 106 107 func (q *InsertQuery) ExcludeColumn(columns ...string) *InsertQuery { 108 q.excludeColumn(columns) 109 return q 110 } 111 112 // Value overwrites model value for the column. 113 func (q *InsertQuery) Value(column string, expr string, args ...interface{}) *InsertQuery { 114 if q.table == nil { 115 q.err = errNilModel 116 return q 117 } 118 q.addValue(q.table, column, expr, args) 119 return q 120 } 121 122 func (q *InsertQuery) Where(query string, args ...interface{}) *InsertQuery { 123 q.addWhere(schema.SafeQueryWithSep(query, args, " AND ")) 124 return q 125 } 126 127 func (q *InsertQuery) WhereOr(query string, args ...interface{}) *InsertQuery { 128 q.addWhere(schema.SafeQueryWithSep(query, args, " OR ")) 129 return q 130 } 131 132 //------------------------------------------------------------------------------ 133 134 // Returning adds a RETURNING clause to the query. 135 // 136 // To suppress the auto-generated RETURNING clause, use `Returning("")`. 137 func (q *InsertQuery) Returning(query string, args ...interface{}) *InsertQuery { 138 q.addReturning(schema.SafeQuery(query, args)) 139 return q 140 } 141 142 //------------------------------------------------------------------------------ 143 144 // Ignore generates different queries depending on the DBMS: 145 // - On MySQL, it generates `INSERT IGNORE INTO`. 146 // - On PostgreSQL, it generates `ON CONFLICT DO NOTHING`. 147 func (q *InsertQuery) Ignore() *InsertQuery { 148 if q.db.fmter.HasFeature(feature.InsertOnConflict) { 149 return q.On("CONFLICT DO NOTHING") 150 } 151 if q.db.fmter.HasFeature(feature.InsertIgnore) { 152 q.ignore = true 153 } 154 return q 155 } 156 157 // Replaces generates a `REPLACE INTO` query (MySQL and MariaDB). 158 func (q *InsertQuery) Replace() *InsertQuery { 159 q.replace = true 160 return q 161 } 162 163 //------------------------------------------------------------------------------ 164 165 func (q *InsertQuery) Operation() string { 166 return "INSERT" 167 } 168 169 func (q *InsertQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { 170 if q.err != nil { 171 return nil, q.err 172 } 173 174 fmter = formatterWithModel(fmter, q) 175 176 b, err = q.appendWith(fmter, b) 177 if err != nil { 178 return nil, err 179 } 180 181 if q.replace { 182 b = append(b, "REPLACE "...) 183 } else { 184 b = append(b, "INSERT "...) 185 if q.ignore { 186 b = append(b, "IGNORE "...) 187 } 188 } 189 b = append(b, "INTO "...) 190 191 if q.db.features.Has(feature.InsertTableAlias) && !q.on.IsZero() { 192 b, err = q.appendFirstTableWithAlias(fmter, b) 193 } else { 194 b, err = q.appendFirstTable(fmter, b) 195 } 196 if err != nil { 197 return nil, err 198 } 199 200 b, err = q.appendColumnsValues(fmter, b, false) 201 if err != nil { 202 return nil, err 203 } 204 205 b, err = q.appendOn(fmter, b) 206 if err != nil { 207 return nil, err 208 } 209 210 if q.hasFeature(feature.InsertReturning) && q.hasReturning() { 211 b = append(b, " RETURNING "...) 212 b, err = q.appendReturning(fmter, b) 213 if err != nil { 214 return nil, err 215 } 216 } 217 218 return b, nil 219 } 220 221 func (q *InsertQuery) appendColumnsValues( 222 fmter schema.Formatter, b []byte, skipOutput bool, 223 ) (_ []byte, err error) { 224 if q.hasMultiTables() { 225 if q.columns != nil { 226 b = append(b, " ("...) 227 b, err = q.appendColumns(fmter, b) 228 if err != nil { 229 return nil, err 230 } 231 b = append(b, ")"...) 232 } 233 234 if q.hasFeature(feature.Output) && q.hasReturning() { 235 b = append(b, " OUTPUT "...) 236 b, err = q.appendOutput(fmter, b) 237 if err != nil { 238 return nil, err 239 } 240 } 241 242 b = append(b, " SELECT "...) 243 244 if q.columns != nil { 245 b, err = q.appendColumns(fmter, b) 246 if err != nil { 247 return nil, err 248 } 249 } else { 250 b = append(b, "*"...) 251 } 252 253 b = append(b, " FROM "...) 254 b, err = q.appendOtherTables(fmter, b) 255 if err != nil { 256 return nil, err 257 } 258 259 return b, nil 260 } 261 262 if m, ok := q.model.(*mapModel); ok { 263 return m.appendColumnsValues(fmter, b), nil 264 } 265 if _, ok := q.model.(*mapSliceModel); ok { 266 return nil, fmt.Errorf("Insert(*[]map[string]interface{}) is not supported") 267 } 268 269 if q.model == nil { 270 return nil, errNilModel 271 } 272 273 // Build fields to populate RETURNING clause. 274 fields, err := q.getFields() 275 if err != nil { 276 return nil, err 277 } 278 279 b = append(b, " ("...) 280 b = q.appendFields(fmter, b, fields) 281 b = append(b, ")"...) 282 283 if q.hasFeature(feature.Output) && q.hasReturning() && !skipOutput { 284 b = append(b, " OUTPUT "...) 285 b, err = q.appendOutput(fmter, b) 286 if err != nil { 287 return nil, err 288 } 289 } 290 291 b = append(b, " VALUES ("...) 292 293 switch model := q.tableModel.(type) { 294 case *structTableModel: 295 b, err = q.appendStructValues(fmter, b, fields, model.strct) 296 if err != nil { 297 return nil, err 298 } 299 case *sliceTableModel: 300 b, err = q.appendSliceValues(fmter, b, fields, model.slice) 301 if err != nil { 302 return nil, err 303 } 304 default: 305 return nil, fmt.Errorf("bun: Insert does not support %T", q.tableModel) 306 } 307 308 b = append(b, ')') 309 310 return b, nil 311 } 312 313 func (q *InsertQuery) appendStructValues( 314 fmter schema.Formatter, b []byte, fields []*schema.Field, strct reflect.Value, 315 ) (_ []byte, err error) { 316 isTemplate := fmter.IsNop() 317 for i, f := range fields { 318 if i > 0 { 319 b = append(b, ", "...) 320 } 321 322 app, ok := q.modelValues[f.Name] 323 if ok { 324 b, err = app.AppendQuery(fmter, b) 325 if err != nil { 326 return nil, err 327 } 328 q.addReturningField(f) 329 continue 330 } 331 332 switch { 333 case isTemplate: 334 b = append(b, '?') 335 case (f.IsPtr && f.HasNilValue(strct)) || (f.NullZero && f.HasZeroValue(strct)): 336 if q.db.features.Has(feature.DefaultPlaceholder) { 337 b = append(b, "DEFAULT"...) 338 } else if f.SQLDefault != "" { 339 b = append(b, f.SQLDefault...) 340 } else { 341 b = append(b, "NULL"...) 342 } 343 q.addReturningField(f) 344 default: 345 b = f.AppendValue(fmter, b, strct) 346 } 347 } 348 349 for i, v := range q.extraValues { 350 if i > 0 || len(fields) > 0 { 351 b = append(b, ", "...) 352 } 353 354 b, err = v.value.AppendQuery(fmter, b) 355 if err != nil { 356 return nil, err 357 } 358 } 359 360 return b, nil 361 } 362 363 func (q *InsertQuery) appendSliceValues( 364 fmter schema.Formatter, b []byte, fields []*schema.Field, slice reflect.Value, 365 ) (_ []byte, err error) { 366 if fmter.IsNop() { 367 return q.appendStructValues(fmter, b, fields, reflect.Value{}) 368 } 369 370 sliceLen := slice.Len() 371 for i := 0; i < sliceLen; i++ { 372 if i > 0 { 373 b = append(b, "), ("...) 374 } 375 el := indirect(slice.Index(i)) 376 b, err = q.appendStructValues(fmter, b, fields, el) 377 if err != nil { 378 return nil, err 379 } 380 } 381 382 return b, nil 383 } 384 385 func (q *InsertQuery) getFields() ([]*schema.Field, error) { 386 hasIdentity := q.db.features.Has(feature.Identity) 387 388 if len(q.columns) > 0 || q.db.features.Has(feature.DefaultPlaceholder) && !hasIdentity { 389 return q.baseQuery.getFields() 390 } 391 392 var strct reflect.Value 393 394 switch model := q.tableModel.(type) { 395 case *structTableModel: 396 strct = model.strct 397 case *sliceTableModel: 398 if model.sliceLen == 0 { 399 return nil, fmt.Errorf("bun: Insert(empty %T)", model.slice.Type()) 400 } 401 strct = indirect(model.slice.Index(0)) 402 default: 403 return nil, errNilModel 404 } 405 406 fields := make([]*schema.Field, 0, len(q.table.Fields)) 407 408 for _, f := range q.table.Fields { 409 if hasIdentity && f.AutoIncrement { 410 q.addReturningField(f) 411 continue 412 } 413 if f.NotNull && f.SQLDefault == "" { 414 if (f.IsPtr && f.HasNilValue(strct)) || (f.NullZero && f.HasZeroValue(strct)) { 415 q.addReturningField(f) 416 continue 417 } 418 } 419 fields = append(fields, f) 420 } 421 422 return fields, nil 423 } 424 425 func (q *InsertQuery) appendFields( 426 fmter schema.Formatter, b []byte, fields []*schema.Field, 427 ) []byte { 428 b = appendColumns(b, "", fields) 429 for i, v := range q.extraValues { 430 if i > 0 || len(fields) > 0 { 431 b = append(b, ", "...) 432 } 433 b = fmter.AppendIdent(b, v.column) 434 } 435 return b 436 } 437 438 //------------------------------------------------------------------------------ 439 440 func (q *InsertQuery) On(s string, args ...interface{}) *InsertQuery { 441 q.on = schema.SafeQuery(s, args) 442 return q 443 } 444 445 func (q *InsertQuery) Set(query string, args ...interface{}) *InsertQuery { 446 q.addSet(schema.SafeQuery(query, args)) 447 return q 448 } 449 450 func (q *InsertQuery) appendOn(fmter schema.Formatter, b []byte) (_ []byte, err error) { 451 if q.on.IsZero() { 452 return b, nil 453 } 454 455 b = append(b, " ON "...) 456 b, err = q.on.AppendQuery(fmter, b) 457 if err != nil { 458 return nil, err 459 } 460 461 if len(q.set) > 0 { 462 if fmter.HasFeature(feature.InsertOnDuplicateKey) { 463 b = append(b, ' ') 464 } else { 465 b = append(b, " SET "...) 466 } 467 468 b, err = q.appendSet(fmter, b) 469 if err != nil { 470 return nil, err 471 } 472 } else if q.onConflictDoUpdate() { 473 fields, err := q.getDataFields() 474 if err != nil { 475 return nil, err 476 } 477 478 if len(fields) == 0 { 479 fields = q.tableModel.Table().DataFields 480 } 481 482 b = q.appendSetExcluded(b, fields) 483 } else if q.onDuplicateKeyUpdate() { 484 fields, err := q.getDataFields() 485 if err != nil { 486 return nil, err 487 } 488 489 if len(fields) == 0 { 490 fields = q.tableModel.Table().DataFields 491 } 492 493 b = q.appendSetValues(b, fields) 494 } 495 496 if len(q.where) > 0 { 497 b = append(b, " WHERE "...) 498 499 b, err = appendWhere(fmter, b, q.where) 500 if err != nil { 501 return nil, err 502 } 503 } 504 505 return b, nil 506 } 507 508 func (q *InsertQuery) onConflictDoUpdate() bool { 509 return strings.HasSuffix(strings.ToUpper(q.on.Query), " DO UPDATE") 510 } 511 512 func (q *InsertQuery) onDuplicateKeyUpdate() bool { 513 return strings.ToUpper(q.on.Query) == "DUPLICATE KEY UPDATE" 514 } 515 516 func (q *InsertQuery) appendSetExcluded(b []byte, fields []*schema.Field) []byte { 517 b = append(b, " SET "...) 518 for i, f := range fields { 519 if i > 0 { 520 b = append(b, ", "...) 521 } 522 b = append(b, f.SQLName...) 523 b = append(b, " = EXCLUDED."...) 524 b = append(b, f.SQLName...) 525 } 526 return b 527 } 528 529 func (q *InsertQuery) appendSetValues(b []byte, fields []*schema.Field) []byte { 530 b = append(b, " "...) 531 for i, f := range fields { 532 if i > 0 { 533 b = append(b, ", "...) 534 } 535 b = append(b, f.SQLName...) 536 b = append(b, " = VALUES("...) 537 b = append(b, f.SQLName...) 538 b = append(b, ")"...) 539 } 540 return b 541 } 542 543 //------------------------------------------------------------------------------ 544 545 func (q *InsertQuery) Scan(ctx context.Context, dest ...interface{}) error { 546 _, err := q.scanOrExec(ctx, dest, true) 547 return err 548 } 549 550 func (q *InsertQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) { 551 return q.scanOrExec(ctx, dest, len(dest) > 0) 552 } 553 554 func (q *InsertQuery) scanOrExec( 555 ctx context.Context, dest []interface{}, hasDest bool, 556 ) (sql.Result, error) { 557 if q.err != nil { 558 return nil, q.err 559 } 560 561 if q.table != nil { 562 if err := q.beforeInsertHook(ctx); err != nil { 563 return nil, err 564 } 565 } 566 567 // Run append model hooks before generating the query. 568 if err := q.beforeAppendModel(ctx, q); err != nil { 569 return nil, err 570 } 571 572 // Generate the query before checking hasReturning. 573 queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) 574 if err != nil { 575 return nil, err 576 } 577 578 useScan := hasDest || (q.hasReturning() && q.hasFeature(feature.InsertReturning|feature.Output)) 579 var model Model 580 581 if useScan { 582 var err error 583 model, err = q.getModel(dest) 584 if err != nil { 585 return nil, err 586 } 587 } 588 589 query := internal.String(queryBytes) 590 var res sql.Result 591 592 if useScan { 593 res, err = q.scan(ctx, q, query, model, hasDest) 594 if err != nil { 595 return nil, err 596 } 597 } else { 598 res, err = q.exec(ctx, q, query) 599 if err != nil { 600 return nil, err 601 } 602 603 if err := q.tryLastInsertID(res, dest); err != nil { 604 return nil, err 605 } 606 } 607 608 if q.table != nil { 609 if err := q.afterInsertHook(ctx); err != nil { 610 return nil, err 611 } 612 } 613 614 return res, nil 615 } 616 617 func (q *InsertQuery) beforeInsertHook(ctx context.Context) error { 618 if hook, ok := q.table.ZeroIface.(BeforeInsertHook); ok { 619 if err := hook.BeforeInsert(ctx, q); err != nil { 620 return err 621 } 622 } 623 return nil 624 } 625 626 func (q *InsertQuery) afterInsertHook(ctx context.Context) error { 627 if hook, ok := q.table.ZeroIface.(AfterInsertHook); ok { 628 if err := hook.AfterInsert(ctx, q); err != nil { 629 return err 630 } 631 } 632 return nil 633 } 634 635 func (q *InsertQuery) tryLastInsertID(res sql.Result, dest []interface{}) error { 636 if q.db.features.Has(feature.Returning) || 637 q.db.features.Has(feature.Output) || 638 q.table == nil || 639 len(q.table.PKs) != 1 || 640 !q.table.PKs[0].AutoIncrement { 641 return nil 642 } 643 644 id, err := res.LastInsertId() 645 if err != nil { 646 return err 647 } 648 if id == 0 { 649 return nil 650 } 651 652 model, err := q.getModel(dest) 653 if err != nil { 654 return err 655 } 656 657 pk := q.table.PKs[0] 658 switch model := model.(type) { 659 case *structTableModel: 660 if err := pk.ScanValue(model.strct, id); err != nil { 661 return err 662 } 663 case *sliceTableModel: 664 sliceLen := model.slice.Len() 665 for i := 0; i < sliceLen; i++ { 666 strct := indirect(model.slice.Index(i)) 667 if err := pk.ScanValue(strct, id); err != nil { 668 return err 669 } 670 id++ 671 } 672 } 673 674 return nil 675 } 676 677 func (q *InsertQuery) String() string { 678 buf, err := q.AppendQuery(q.db.Formatter(), nil) 679 if err != nil { 680 panic(err) 681 } 682 683 return string(buf) 684 }