query_base.go (29372B)
1 package bun 2 3 import ( 4 "context" 5 "database/sql" 6 "database/sql/driver" 7 "errors" 8 "fmt" 9 "time" 10 11 "github.com/uptrace/bun/dialect/feature" 12 "github.com/uptrace/bun/internal" 13 "github.com/uptrace/bun/schema" 14 ) 15 16 const ( 17 forceDeleteFlag internal.Flag = 1 << iota 18 deletedFlag 19 allWithDeletedFlag 20 ) 21 22 type withQuery struct { 23 name string 24 query schema.QueryAppender 25 recursive bool 26 } 27 28 // IConn is a common interface for *sql.DB, *sql.Conn, and *sql.Tx. 29 type IConn interface { 30 QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) 31 ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) 32 QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row 33 } 34 35 var ( 36 _ IConn = (*sql.DB)(nil) 37 _ IConn = (*sql.Conn)(nil) 38 _ IConn = (*sql.Tx)(nil) 39 _ IConn = (*DB)(nil) 40 _ IConn = (*Conn)(nil) 41 _ IConn = (*Tx)(nil) 42 ) 43 44 // IDB is a common interface for *bun.DB, bun.Conn, and bun.Tx. 45 type IDB interface { 46 IConn 47 Dialect() schema.Dialect 48 49 NewValues(model interface{}) *ValuesQuery 50 NewSelect() *SelectQuery 51 NewInsert() *InsertQuery 52 NewUpdate() *UpdateQuery 53 NewDelete() *DeleteQuery 54 NewRaw(query string, args ...interface{}) *RawQuery 55 NewCreateTable() *CreateTableQuery 56 NewDropTable() *DropTableQuery 57 NewCreateIndex() *CreateIndexQuery 58 NewDropIndex() *DropIndexQuery 59 NewTruncateTable() *TruncateTableQuery 60 NewAddColumn() *AddColumnQuery 61 NewDropColumn() *DropColumnQuery 62 63 BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) 64 RunInTx(ctx context.Context, opts *sql.TxOptions, f func(ctx context.Context, tx Tx) error) error 65 } 66 67 var ( 68 _ IDB = (*DB)(nil) 69 _ IDB = (*Conn)(nil) 70 _ IDB = (*Tx)(nil) 71 ) 72 73 // QueryBuilder is used for common query methods 74 type QueryBuilder interface { 75 Query 76 Where(query string, args ...interface{}) QueryBuilder 77 WhereGroup(sep string, fn func(QueryBuilder) QueryBuilder) QueryBuilder 78 WhereOr(query string, args ...interface{}) QueryBuilder 79 WhereDeleted() QueryBuilder 80 WhereAllWithDeleted() QueryBuilder 81 WherePK(cols ...string) QueryBuilder 82 Unwrap() interface{} 83 } 84 85 var ( 86 _ QueryBuilder = (*selectQueryBuilder)(nil) 87 _ QueryBuilder = (*updateQueryBuilder)(nil) 88 _ QueryBuilder = (*deleteQueryBuilder)(nil) 89 ) 90 91 type baseQuery struct { 92 db *DB 93 conn IConn 94 95 model Model 96 err error 97 98 tableModel TableModel 99 table *schema.Table 100 101 with []withQuery 102 modelTableName schema.QueryWithArgs 103 tables []schema.QueryWithArgs 104 columns []schema.QueryWithArgs 105 106 flags internal.Flag 107 } 108 109 func (q *baseQuery) DB() *DB { 110 return q.db 111 } 112 113 func (q *baseQuery) GetConn() IConn { 114 return q.conn 115 } 116 117 func (q *baseQuery) GetModel() Model { 118 return q.model 119 } 120 121 func (q *baseQuery) GetTableName() string { 122 if q.table != nil { 123 return q.table.Name 124 } 125 126 for _, wq := range q.with { 127 if v, ok := wq.query.(Query); ok { 128 if model := v.GetModel(); model != nil { 129 return v.GetTableName() 130 } 131 } 132 } 133 134 if q.modelTableName.Query != "" { 135 return q.modelTableName.Query 136 } 137 138 if len(q.tables) > 0 { 139 b, _ := q.tables[0].AppendQuery(q.db.fmter, nil) 140 if len(b) < 64 { 141 return string(b) 142 } 143 } 144 145 return "" 146 } 147 148 func (q *baseQuery) setConn(db IConn) { 149 // Unwrap Bun wrappers to not call query hooks twice. 150 switch db := db.(type) { 151 case *DB: 152 q.conn = db.DB 153 case Conn: 154 q.conn = db.Conn 155 case Tx: 156 q.conn = db.Tx 157 default: 158 q.conn = db 159 } 160 } 161 162 func (q *baseQuery) setModel(modeli interface{}) { 163 model, err := newSingleModel(q.db, modeli) 164 if err != nil { 165 q.setErr(err) 166 return 167 } 168 169 q.model = model 170 if tm, ok := model.(TableModel); ok { 171 q.tableModel = tm 172 q.table = tm.Table() 173 } 174 } 175 176 func (q *baseQuery) setErr(err error) { 177 if q.err == nil { 178 q.err = err 179 } 180 } 181 182 func (q *baseQuery) getModel(dest []interface{}) (Model, error) { 183 if len(dest) > 0 { 184 return newModel(q.db, dest) 185 } 186 if q.model != nil { 187 return q.model, nil 188 } 189 return nil, errNilModel 190 } 191 192 func (q *baseQuery) beforeAppendModel(ctx context.Context, query Query) error { 193 if q.tableModel != nil { 194 return q.tableModel.BeforeAppendModel(ctx, query) 195 } 196 return nil 197 } 198 199 func (q *baseQuery) hasFeature(feature feature.Feature) bool { 200 return q.db.features.Has(feature) 201 } 202 203 //------------------------------------------------------------------------------ 204 205 func (q *baseQuery) checkSoftDelete() error { 206 if q.table == nil { 207 return errors.New("bun: can't use soft deletes without a table") 208 } 209 if q.table.SoftDeleteField == nil { 210 return fmt.Errorf("%s does not have a soft delete field", q.table) 211 } 212 if q.tableModel == nil { 213 return errors.New("bun: can't use soft deletes without a table model") 214 } 215 return nil 216 } 217 218 // Deleted adds `WHERE deleted_at IS NOT NULL` clause for soft deleted models. 219 func (q *baseQuery) whereDeleted() { 220 if err := q.checkSoftDelete(); err != nil { 221 q.setErr(err) 222 return 223 } 224 q.flags = q.flags.Set(deletedFlag) 225 q.flags = q.flags.Remove(allWithDeletedFlag) 226 } 227 228 // AllWithDeleted changes query to return all rows including soft deleted ones. 229 func (q *baseQuery) whereAllWithDeleted() { 230 if err := q.checkSoftDelete(); err != nil { 231 q.setErr(err) 232 return 233 } 234 q.flags = q.flags.Set(allWithDeletedFlag).Remove(deletedFlag) 235 } 236 237 func (q *baseQuery) isSoftDelete() bool { 238 if q.table != nil { 239 return q.table.SoftDeleteField != nil && 240 !q.flags.Has(allWithDeletedFlag) && 241 (!q.flags.Has(forceDeleteFlag) || q.flags.Has(deletedFlag)) 242 } 243 return false 244 } 245 246 //------------------------------------------------------------------------------ 247 248 func (q *baseQuery) addWith(name string, query schema.QueryAppender, recursive bool) { 249 q.with = append(q.with, withQuery{ 250 name: name, 251 query: query, 252 recursive: recursive, 253 }) 254 } 255 256 func (q *baseQuery) appendWith(fmter schema.Formatter, b []byte) (_ []byte, err error) { 257 if len(q.with) == 0 { 258 return b, nil 259 } 260 261 b = append(b, "WITH "...) 262 for i, with := range q.with { 263 if i > 0 { 264 b = append(b, ", "...) 265 } 266 267 if with.recursive { 268 b = append(b, "RECURSIVE "...) 269 } 270 271 b, err = q.appendCTE(fmter, b, with) 272 if err != nil { 273 return nil, err 274 } 275 } 276 b = append(b, ' ') 277 return b, nil 278 } 279 280 func (q *baseQuery) appendCTE( 281 fmter schema.Formatter, b []byte, cte withQuery, 282 ) (_ []byte, err error) { 283 if !fmter.Dialect().Features().Has(feature.WithValues) { 284 if values, ok := cte.query.(*ValuesQuery); ok { 285 return q.appendSelectFromValues(fmter, b, cte, values) 286 } 287 } 288 289 b = fmter.AppendIdent(b, cte.name) 290 291 if q, ok := cte.query.(schema.ColumnsAppender); ok { 292 b = append(b, " ("...) 293 b, err = q.AppendColumns(fmter, b) 294 if err != nil { 295 return nil, err 296 } 297 b = append(b, ")"...) 298 } 299 300 b = append(b, " AS ("...) 301 302 b, err = cte.query.AppendQuery(fmter, b) 303 if err != nil { 304 return nil, err 305 } 306 307 b = append(b, ")"...) 308 return b, nil 309 } 310 311 func (q *baseQuery) appendSelectFromValues( 312 fmter schema.Formatter, b []byte, cte withQuery, values *ValuesQuery, 313 ) (_ []byte, err error) { 314 b = fmter.AppendIdent(b, cte.name) 315 b = append(b, " AS (SELECT * FROM ("...) 316 317 b, err = cte.query.AppendQuery(fmter, b) 318 if err != nil { 319 return nil, err 320 } 321 322 b = append(b, ") AS t"...) 323 if q, ok := cte.query.(schema.ColumnsAppender); ok { 324 b = append(b, " ("...) 325 b, err = q.AppendColumns(fmter, b) 326 if err != nil { 327 return nil, err 328 } 329 b = append(b, ")"...) 330 } 331 b = append(b, ")"...) 332 333 return b, nil 334 } 335 336 //------------------------------------------------------------------------------ 337 338 func (q *baseQuery) addTable(table schema.QueryWithArgs) { 339 q.tables = append(q.tables, table) 340 } 341 342 func (q *baseQuery) addColumn(column schema.QueryWithArgs) { 343 q.columns = append(q.columns, column) 344 } 345 346 func (q *baseQuery) excludeColumn(columns []string) { 347 if q.table == nil { 348 q.setErr(errNilModel) 349 return 350 } 351 352 if q.columns == nil { 353 for _, f := range q.table.Fields { 354 q.columns = append(q.columns, schema.UnsafeIdent(f.Name)) 355 } 356 } 357 358 if len(columns) == 1 && columns[0] == "*" { 359 q.columns = make([]schema.QueryWithArgs, 0) 360 return 361 } 362 363 for _, column := range columns { 364 if !q._excludeColumn(column) { 365 q.setErr(fmt.Errorf("bun: can't find column=%q", column)) 366 return 367 } 368 } 369 } 370 371 func (q *baseQuery) _excludeColumn(column string) bool { 372 for i, col := range q.columns { 373 if col.Args == nil && col.Query == column { 374 q.columns = append(q.columns[:i], q.columns[i+1:]...) 375 return true 376 } 377 } 378 return false 379 } 380 381 //------------------------------------------------------------------------------ 382 383 func (q *baseQuery) modelHasTableName() bool { 384 if !q.modelTableName.IsZero() { 385 return q.modelTableName.Query != "" 386 } 387 return q.table != nil 388 } 389 390 func (q *baseQuery) hasTables() bool { 391 return q.modelHasTableName() || len(q.tables) > 0 392 } 393 394 func (q *baseQuery) appendTables( 395 fmter schema.Formatter, b []byte, 396 ) (_ []byte, err error) { 397 return q._appendTables(fmter, b, false) 398 } 399 400 func (q *baseQuery) appendTablesWithAlias( 401 fmter schema.Formatter, b []byte, 402 ) (_ []byte, err error) { 403 return q._appendTables(fmter, b, true) 404 } 405 406 func (q *baseQuery) _appendTables( 407 fmter schema.Formatter, b []byte, withAlias bool, 408 ) (_ []byte, err error) { 409 startLen := len(b) 410 411 if q.modelHasTableName() { 412 if !q.modelTableName.IsZero() { 413 b, err = q.modelTableName.AppendQuery(fmter, b) 414 if err != nil { 415 return nil, err 416 } 417 } else { 418 b = fmter.AppendQuery(b, string(q.table.SQLNameForSelects)) 419 if withAlias && q.table.SQLAlias != q.table.SQLNameForSelects { 420 b = append(b, " AS "...) 421 b = append(b, q.table.SQLAlias...) 422 } 423 } 424 } 425 426 for _, table := range q.tables { 427 if len(b) > startLen { 428 b = append(b, ", "...) 429 } 430 b, err = table.AppendQuery(fmter, b) 431 if err != nil { 432 return nil, err 433 } 434 } 435 436 return b, nil 437 } 438 439 func (q *baseQuery) appendFirstTable(fmter schema.Formatter, b []byte) ([]byte, error) { 440 return q._appendFirstTable(fmter, b, false) 441 } 442 443 func (q *baseQuery) appendFirstTableWithAlias( 444 fmter schema.Formatter, b []byte, 445 ) ([]byte, error) { 446 return q._appendFirstTable(fmter, b, true) 447 } 448 449 func (q *baseQuery) _appendFirstTable( 450 fmter schema.Formatter, b []byte, withAlias bool, 451 ) ([]byte, error) { 452 if !q.modelTableName.IsZero() { 453 return q.modelTableName.AppendQuery(fmter, b) 454 } 455 456 if q.table != nil { 457 b = fmter.AppendQuery(b, string(q.table.SQLName)) 458 if withAlias { 459 b = append(b, " AS "...) 460 b = append(b, q.table.SQLAlias...) 461 } 462 return b, nil 463 } 464 465 if len(q.tables) > 0 { 466 return q.tables[0].AppendQuery(fmter, b) 467 } 468 469 return nil, errors.New("bun: query does not have a table") 470 } 471 472 func (q *baseQuery) hasMultiTables() bool { 473 if q.modelHasTableName() { 474 return len(q.tables) >= 1 475 } 476 return len(q.tables) >= 2 477 } 478 479 func (q *baseQuery) appendOtherTables(fmter schema.Formatter, b []byte) (_ []byte, err error) { 480 tables := q.tables 481 if !q.modelHasTableName() { 482 tables = tables[1:] 483 } 484 for i, table := range tables { 485 if i > 0 { 486 b = append(b, ", "...) 487 } 488 b, err = table.AppendQuery(fmter, b) 489 if err != nil { 490 return nil, err 491 } 492 } 493 return b, nil 494 } 495 496 //------------------------------------------------------------------------------ 497 498 func (q *baseQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) { 499 for i, f := range q.columns { 500 if i > 0 { 501 b = append(b, ", "...) 502 } 503 b, err = f.AppendQuery(fmter, b) 504 if err != nil { 505 return nil, err 506 } 507 } 508 return b, nil 509 } 510 511 func (q *baseQuery) getFields() ([]*schema.Field, error) { 512 if len(q.columns) == 0 { 513 if q.table == nil { 514 return nil, errNilModel 515 } 516 return q.table.Fields, nil 517 } 518 return q._getFields(false) 519 } 520 521 func (q *baseQuery) getDataFields() ([]*schema.Field, error) { 522 if len(q.columns) == 0 { 523 if q.table == nil { 524 return nil, errNilModel 525 } 526 return q.table.DataFields, nil 527 } 528 return q._getFields(true) 529 } 530 531 func (q *baseQuery) _getFields(omitPK bool) ([]*schema.Field, error) { 532 fields := make([]*schema.Field, 0, len(q.columns)) 533 for _, col := range q.columns { 534 if col.Args != nil { 535 continue 536 } 537 538 field, err := q.table.Field(col.Query) 539 if err != nil { 540 return nil, err 541 } 542 543 if omitPK && field.IsPK { 544 continue 545 } 546 547 fields = append(fields, field) 548 } 549 return fields, nil 550 } 551 552 func (q *baseQuery) scan( 553 ctx context.Context, 554 iquery Query, 555 query string, 556 model Model, 557 hasDest bool, 558 ) (sql.Result, error) { 559 ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, query, q.model) 560 561 rows, err := q.conn.QueryContext(ctx, query) 562 if err != nil { 563 q.db.afterQuery(ctx, event, nil, err) 564 return nil, err 565 } 566 defer rows.Close() 567 568 numRow, err := model.ScanRows(ctx, rows) 569 if err != nil { 570 q.db.afterQuery(ctx, event, nil, err) 571 return nil, err 572 } 573 574 if numRow == 0 && hasDest && isSingleRowModel(model) { 575 err = sql.ErrNoRows 576 } 577 578 res := driver.RowsAffected(numRow) 579 q.db.afterQuery(ctx, event, res, err) 580 581 return res, err 582 } 583 584 func (q *baseQuery) exec( 585 ctx context.Context, 586 iquery Query, 587 query string, 588 ) (sql.Result, error) { 589 ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, query, q.model) 590 res, err := q.conn.ExecContext(ctx, query) 591 q.db.afterQuery(ctx, event, nil, err) 592 return res, err 593 } 594 595 //------------------------------------------------------------------------------ 596 597 func (q *baseQuery) AppendNamedArg(fmter schema.Formatter, b []byte, name string) ([]byte, bool) { 598 if q.table == nil { 599 return b, false 600 } 601 602 if m, ok := q.tableModel.(*structTableModel); ok { 603 if b, ok := m.AppendNamedArg(fmter, b, name); ok { 604 return b, ok 605 } 606 } 607 608 switch name { 609 case "TableName": 610 b = fmter.AppendQuery(b, string(q.table.SQLName)) 611 return b, true 612 case "TableAlias": 613 b = fmter.AppendQuery(b, string(q.table.SQLAlias)) 614 return b, true 615 case "PKs": 616 b = appendColumns(b, "", q.table.PKs) 617 return b, true 618 case "TablePKs": 619 b = appendColumns(b, q.table.SQLAlias, q.table.PKs) 620 return b, true 621 case "Columns": 622 b = appendColumns(b, "", q.table.Fields) 623 return b, true 624 case "TableColumns": 625 b = appendColumns(b, q.table.SQLAlias, q.table.Fields) 626 return b, true 627 } 628 629 return b, false 630 } 631 632 //------------------------------------------------------------------------------ 633 634 func (q *baseQuery) Dialect() schema.Dialect { 635 return q.db.Dialect() 636 } 637 638 func (q *baseQuery) NewValues(model interface{}) *ValuesQuery { 639 return NewValuesQuery(q.db, model).Conn(q.conn) 640 } 641 642 func (q *baseQuery) NewSelect() *SelectQuery { 643 return NewSelectQuery(q.db).Conn(q.conn) 644 } 645 646 func (q *baseQuery) NewInsert() *InsertQuery { 647 return NewInsertQuery(q.db).Conn(q.conn) 648 } 649 650 func (q *baseQuery) NewUpdate() *UpdateQuery { 651 return NewUpdateQuery(q.db).Conn(q.conn) 652 } 653 654 func (q *baseQuery) NewDelete() *DeleteQuery { 655 return NewDeleteQuery(q.db).Conn(q.conn) 656 } 657 658 func (q *baseQuery) NewRaw(query string, args ...interface{}) *RawQuery { 659 return NewRawQuery(q.db, query, args...).Conn(q.conn) 660 } 661 662 func (q *baseQuery) NewCreateTable() *CreateTableQuery { 663 return NewCreateTableQuery(q.db).Conn(q.conn) 664 } 665 666 func (q *baseQuery) NewDropTable() *DropTableQuery { 667 return NewDropTableQuery(q.db).Conn(q.conn) 668 } 669 670 func (q *baseQuery) NewCreateIndex() *CreateIndexQuery { 671 return NewCreateIndexQuery(q.db).Conn(q.conn) 672 } 673 674 func (q *baseQuery) NewDropIndex() *DropIndexQuery { 675 return NewDropIndexQuery(q.db).Conn(q.conn) 676 } 677 678 func (q *baseQuery) NewTruncateTable() *TruncateTableQuery { 679 return NewTruncateTableQuery(q.db).Conn(q.conn) 680 } 681 682 func (q *baseQuery) NewAddColumn() *AddColumnQuery { 683 return NewAddColumnQuery(q.db).Conn(q.conn) 684 } 685 686 func (q *baseQuery) NewDropColumn() *DropColumnQuery { 687 return NewDropColumnQuery(q.db).Conn(q.conn) 688 } 689 690 //------------------------------------------------------------------------------ 691 692 func appendColumns(b []byte, table schema.Safe, fields []*schema.Field) []byte { 693 for i, f := range fields { 694 if i > 0 { 695 b = append(b, ", "...) 696 } 697 698 if len(table) > 0 { 699 b = append(b, table...) 700 b = append(b, '.') 701 } 702 b = append(b, f.SQLName...) 703 } 704 return b 705 } 706 707 func formatterWithModel(fmter schema.Formatter, model schema.NamedArgAppender) schema.Formatter { 708 if fmter.IsNop() { 709 return fmter 710 } 711 return fmter.WithArg(model) 712 } 713 714 //------------------------------------------------------------------------------ 715 716 type whereBaseQuery struct { 717 baseQuery 718 719 where []schema.QueryWithSep 720 whereFields []*schema.Field 721 } 722 723 func (q *whereBaseQuery) addWhere(where schema.QueryWithSep) { 724 q.where = append(q.where, where) 725 } 726 727 func (q *whereBaseQuery) addWhereGroup(sep string, where []schema.QueryWithSep) { 728 if len(where) == 0 { 729 return 730 } 731 732 q.addWhere(schema.SafeQueryWithSep("", nil, sep)) 733 q.addWhere(schema.SafeQueryWithSep("", nil, "(")) 734 735 where[0].Sep = "" 736 q.where = append(q.where, where...) 737 738 q.addWhere(schema.SafeQueryWithSep("", nil, ")")) 739 } 740 741 func (q *whereBaseQuery) addWhereCols(cols []string) { 742 if q.table == nil { 743 err := fmt.Errorf("bun: got %T, but WherePK requires a struct or slice-based model", q.model) 744 q.setErr(err) 745 return 746 } 747 if q.whereFields != nil { 748 err := errors.New("bun: WherePK can only be called once") 749 q.setErr(err) 750 return 751 } 752 753 if cols == nil { 754 if err := q.table.CheckPKs(); err != nil { 755 q.setErr(err) 756 return 757 } 758 q.whereFields = q.table.PKs 759 return 760 } 761 762 q.whereFields = make([]*schema.Field, len(cols)) 763 for i, col := range cols { 764 field, err := q.table.Field(col) 765 if err != nil { 766 q.setErr(err) 767 return 768 } 769 q.whereFields[i] = field 770 } 771 } 772 773 func (q *whereBaseQuery) mustAppendWhere( 774 fmter schema.Formatter, b []byte, withAlias bool, 775 ) ([]byte, error) { 776 if len(q.where) == 0 && q.whereFields == nil && !q.flags.Has(deletedFlag) { 777 err := errors.New("bun: Update and Delete queries require at least one Where") 778 return nil, err 779 } 780 return q.appendWhere(fmter, b, withAlias) 781 } 782 783 func (q *whereBaseQuery) appendWhere( 784 fmter schema.Formatter, b []byte, withAlias bool, 785 ) (_ []byte, err error) { 786 if len(q.where) == 0 && q.whereFields == nil && !q.isSoftDelete() { 787 return b, nil 788 } 789 790 b = append(b, " WHERE "...) 791 startLen := len(b) 792 793 if len(q.where) > 0 { 794 b, err = appendWhere(fmter, b, q.where) 795 if err != nil { 796 return nil, err 797 } 798 } 799 800 if q.isSoftDelete() { 801 if len(b) > startLen { 802 b = append(b, " AND "...) 803 } 804 805 if withAlias { 806 b = append(b, q.tableModel.Table().SQLAlias...) 807 } else { 808 b = append(b, q.tableModel.Table().SQLName...) 809 } 810 b = append(b, '.') 811 812 field := q.tableModel.Table().SoftDeleteField 813 b = append(b, field.SQLName...) 814 815 if field.IsPtr || field.NullZero { 816 if q.flags.Has(deletedFlag) { 817 b = append(b, " IS NOT NULL"...) 818 } else { 819 b = append(b, " IS NULL"...) 820 } 821 } else { 822 if q.flags.Has(deletedFlag) { 823 b = append(b, " != "...) 824 } else { 825 b = append(b, " = "...) 826 } 827 b = fmter.Dialect().AppendTime(b, time.Time{}) 828 } 829 } 830 831 if q.whereFields != nil { 832 if len(b) > startLen { 833 b = append(b, " AND "...) 834 } 835 b, err = q.appendWhereFields(fmter, b, q.whereFields, withAlias) 836 if err != nil { 837 return nil, err 838 } 839 } 840 841 return b, nil 842 } 843 844 func appendWhere( 845 fmter schema.Formatter, b []byte, where []schema.QueryWithSep, 846 ) (_ []byte, err error) { 847 for i, where := range where { 848 if i > 0 { 849 b = append(b, where.Sep...) 850 } 851 852 if where.Query == "" { 853 continue 854 } 855 856 b = append(b, '(') 857 b, err = where.AppendQuery(fmter, b) 858 if err != nil { 859 return nil, err 860 } 861 b = append(b, ')') 862 } 863 return b, nil 864 } 865 866 func (q *whereBaseQuery) appendWhereFields( 867 fmter schema.Formatter, b []byte, fields []*schema.Field, withAlias bool, 868 ) (_ []byte, err error) { 869 if q.table == nil { 870 err := fmt.Errorf("bun: got %T, but WherePK requires struct or slice-based model", q.model) 871 return nil, err 872 } 873 874 switch model := q.tableModel.(type) { 875 case *structTableModel: 876 return q.appendWhereStructFields(fmter, b, model, fields, withAlias) 877 case *sliceTableModel: 878 return q.appendWhereSliceFields(fmter, b, model, fields, withAlias) 879 default: 880 return nil, fmt.Errorf("bun: WhereColumn does not support %T", q.tableModel) 881 } 882 } 883 884 func (q *whereBaseQuery) appendWhereStructFields( 885 fmter schema.Formatter, 886 b []byte, 887 model *structTableModel, 888 fields []*schema.Field, 889 withAlias bool, 890 ) (_ []byte, err error) { 891 if !model.strct.IsValid() { 892 return nil, errNilModel 893 } 894 895 isTemplate := fmter.IsNop() 896 b = append(b, '(') 897 for i, f := range fields { 898 if i > 0 { 899 b = append(b, " AND "...) 900 } 901 if withAlias { 902 b = append(b, q.table.SQLAlias...) 903 b = append(b, '.') 904 } 905 b = append(b, f.SQLName...) 906 b = append(b, " = "...) 907 if isTemplate { 908 b = append(b, '?') 909 } else { 910 b = f.AppendValue(fmter, b, model.strct) 911 } 912 } 913 b = append(b, ')') 914 return b, nil 915 } 916 917 func (q *whereBaseQuery) appendWhereSliceFields( 918 fmter schema.Formatter, 919 b []byte, 920 model *sliceTableModel, 921 fields []*schema.Field, 922 withAlias bool, 923 ) (_ []byte, err error) { 924 if len(fields) > 1 { 925 b = append(b, '(') 926 } 927 if withAlias { 928 b = appendColumns(b, q.table.SQLAlias, fields) 929 } else { 930 b = appendColumns(b, "", fields) 931 } 932 if len(fields) > 1 { 933 b = append(b, ')') 934 } 935 936 b = append(b, " IN ("...) 937 938 isTemplate := fmter.IsNop() 939 slice := model.slice 940 sliceLen := slice.Len() 941 for i := 0; i < sliceLen; i++ { 942 if i > 0 { 943 if isTemplate { 944 break 945 } 946 b = append(b, ", "...) 947 } 948 949 el := indirect(slice.Index(i)) 950 951 if len(fields) > 1 { 952 b = append(b, '(') 953 } 954 for i, f := range fields { 955 if i > 0 { 956 b = append(b, ", "...) 957 } 958 if isTemplate { 959 b = append(b, '?') 960 } else { 961 b = f.AppendValue(fmter, b, el) 962 } 963 } 964 if len(fields) > 1 { 965 b = append(b, ')') 966 } 967 } 968 969 b = append(b, ')') 970 971 return b, nil 972 } 973 974 //------------------------------------------------------------------------------ 975 976 type returningQuery struct { 977 returning []schema.QueryWithArgs 978 returningFields []*schema.Field 979 } 980 981 func (q *returningQuery) addReturning(ret schema.QueryWithArgs) { 982 q.returning = append(q.returning, ret) 983 } 984 985 func (q *returningQuery) addReturningField(field *schema.Field) { 986 if len(q.returning) > 0 { 987 return 988 } 989 for _, f := range q.returningFields { 990 if f == field { 991 return 992 } 993 } 994 q.returningFields = append(q.returningFields, field) 995 } 996 997 func (q *returningQuery) appendReturning( 998 fmter schema.Formatter, b []byte, 999 ) (_ []byte, err error) { 1000 return q._appendReturning(fmter, b, "") 1001 } 1002 1003 func (q *returningQuery) appendOutput( 1004 fmter schema.Formatter, b []byte, 1005 ) (_ []byte, err error) { 1006 return q._appendReturning(fmter, b, "INSERTED") 1007 } 1008 1009 func (q *returningQuery) _appendReturning( 1010 fmter schema.Formatter, b []byte, table string, 1011 ) (_ []byte, err error) { 1012 for i, f := range q.returning { 1013 if i > 0 { 1014 b = append(b, ", "...) 1015 } 1016 b, err = f.AppendQuery(fmter, b) 1017 if err != nil { 1018 return nil, err 1019 } 1020 } 1021 1022 if len(q.returning) > 0 { 1023 return b, nil 1024 } 1025 1026 b = appendColumns(b, schema.Safe(table), q.returningFields) 1027 return b, nil 1028 } 1029 1030 func (q *returningQuery) hasReturning() bool { 1031 if len(q.returning) == 1 { 1032 if ret := q.returning[0]; len(ret.Args) == 0 { 1033 switch ret.Query { 1034 case "", "null", "NULL": 1035 return false 1036 } 1037 } 1038 } 1039 return len(q.returning) > 0 || len(q.returningFields) > 0 1040 } 1041 1042 //------------------------------------------------------------------------------ 1043 1044 type columnValue struct { 1045 column string 1046 value schema.QueryWithArgs 1047 } 1048 1049 type customValueQuery struct { 1050 modelValues map[string]schema.QueryWithArgs 1051 extraValues []columnValue 1052 } 1053 1054 func (q *customValueQuery) addValue( 1055 table *schema.Table, column string, value string, args []interface{}, 1056 ) { 1057 ok := false 1058 if table != nil { 1059 _, ok = table.FieldMap[column] 1060 } 1061 1062 if ok { 1063 if q.modelValues == nil { 1064 q.modelValues = make(map[string]schema.QueryWithArgs) 1065 } 1066 q.modelValues[column] = schema.SafeQuery(value, args) 1067 } else { 1068 q.extraValues = append(q.extraValues, columnValue{ 1069 column: column, 1070 value: schema.SafeQuery(value, args), 1071 }) 1072 } 1073 } 1074 1075 //------------------------------------------------------------------------------ 1076 1077 type setQuery struct { 1078 set []schema.QueryWithArgs 1079 } 1080 1081 func (q *setQuery) addSet(set schema.QueryWithArgs) { 1082 q.set = append(q.set, set) 1083 } 1084 1085 func (q setQuery) appendSet(fmter schema.Formatter, b []byte) (_ []byte, err error) { 1086 for i, f := range q.set { 1087 if i > 0 { 1088 b = append(b, ", "...) 1089 } 1090 b, err = f.AppendQuery(fmter, b) 1091 if err != nil { 1092 return nil, err 1093 } 1094 } 1095 return b, nil 1096 } 1097 1098 //------------------------------------------------------------------------------ 1099 1100 type cascadeQuery struct { 1101 cascade bool 1102 restrict bool 1103 } 1104 1105 func (q cascadeQuery) appendCascade(fmter schema.Formatter, b []byte) []byte { 1106 if !fmter.HasFeature(feature.TableCascade) { 1107 return b 1108 } 1109 if q.cascade { 1110 b = append(b, " CASCADE"...) 1111 } 1112 if q.restrict { 1113 b = append(b, " RESTRICT"...) 1114 } 1115 return b 1116 } 1117 1118 //------------------------------------------------------------------------------ 1119 1120 type idxHintsQuery struct { 1121 use *indexHints 1122 ignore *indexHints 1123 force *indexHints 1124 } 1125 1126 type indexHints struct { 1127 names []schema.QueryWithArgs 1128 forJoin []schema.QueryWithArgs 1129 forOrderBy []schema.QueryWithArgs 1130 forGroupBy []schema.QueryWithArgs 1131 } 1132 1133 func (ih *idxHintsQuery) lazyUse() *indexHints { 1134 if ih.use == nil { 1135 ih.use = new(indexHints) 1136 } 1137 return ih.use 1138 } 1139 1140 func (ih *idxHintsQuery) lazyIgnore() *indexHints { 1141 if ih.ignore == nil { 1142 ih.ignore = new(indexHints) 1143 } 1144 return ih.ignore 1145 } 1146 1147 func (ih *idxHintsQuery) lazyForce() *indexHints { 1148 if ih.force == nil { 1149 ih.force = new(indexHints) 1150 } 1151 return ih.force 1152 } 1153 1154 func (ih *idxHintsQuery) appendIndexes(hints []schema.QueryWithArgs, indexes ...string) []schema.QueryWithArgs { 1155 for _, idx := range indexes { 1156 hints = append(hints, schema.UnsafeIdent(idx)) 1157 } 1158 return hints 1159 } 1160 1161 func (ih *idxHintsQuery) addUseIndex(indexes ...string) { 1162 if len(indexes) == 0 { 1163 return 1164 } 1165 ih.lazyUse().names = ih.appendIndexes(ih.use.names, indexes...) 1166 } 1167 1168 func (ih *idxHintsQuery) addUseIndexForJoin(indexes ...string) { 1169 if len(indexes) == 0 { 1170 return 1171 } 1172 ih.lazyUse().forJoin = ih.appendIndexes(ih.use.forJoin, indexes...) 1173 } 1174 1175 func (ih *idxHintsQuery) addUseIndexForOrderBy(indexes ...string) { 1176 if len(indexes) == 0 { 1177 return 1178 } 1179 ih.lazyUse().forOrderBy = ih.appendIndexes(ih.use.forOrderBy, indexes...) 1180 } 1181 1182 func (ih *idxHintsQuery) addUseIndexForGroupBy(indexes ...string) { 1183 if len(indexes) == 0 { 1184 return 1185 } 1186 ih.lazyUse().forGroupBy = ih.appendIndexes(ih.use.forGroupBy, indexes...) 1187 } 1188 1189 func (ih *idxHintsQuery) addIgnoreIndex(indexes ...string) { 1190 if len(indexes) == 0 { 1191 return 1192 } 1193 ih.lazyIgnore().names = ih.appendIndexes(ih.ignore.names, indexes...) 1194 } 1195 1196 func (ih *idxHintsQuery) addIgnoreIndexForJoin(indexes ...string) { 1197 if len(indexes) == 0 { 1198 return 1199 } 1200 ih.lazyIgnore().forJoin = ih.appendIndexes(ih.ignore.forJoin, indexes...) 1201 } 1202 1203 func (ih *idxHintsQuery) addIgnoreIndexForOrderBy(indexes ...string) { 1204 if len(indexes) == 0 { 1205 return 1206 } 1207 ih.lazyIgnore().forOrderBy = ih.appendIndexes(ih.ignore.forOrderBy, indexes...) 1208 } 1209 1210 func (ih *idxHintsQuery) addIgnoreIndexForGroupBy(indexes ...string) { 1211 if len(indexes) == 0 { 1212 return 1213 } 1214 ih.lazyIgnore().forGroupBy = ih.appendIndexes(ih.ignore.forGroupBy, indexes...) 1215 } 1216 1217 func (ih *idxHintsQuery) addForceIndex(indexes ...string) { 1218 if len(indexes) == 0 { 1219 return 1220 } 1221 ih.lazyForce().names = ih.appendIndexes(ih.force.names, indexes...) 1222 } 1223 1224 func (ih *idxHintsQuery) addForceIndexForJoin(indexes ...string) { 1225 if len(indexes) == 0 { 1226 return 1227 } 1228 ih.lazyForce().forJoin = ih.appendIndexes(ih.force.forJoin, indexes...) 1229 } 1230 1231 func (ih *idxHintsQuery) addForceIndexForOrderBy(indexes ...string) { 1232 if len(indexes) == 0 { 1233 return 1234 } 1235 ih.lazyForce().forOrderBy = ih.appendIndexes(ih.force.forOrderBy, indexes...) 1236 } 1237 1238 func (ih *idxHintsQuery) addForceIndexForGroupBy(indexes ...string) { 1239 if len(indexes) == 0 { 1240 return 1241 } 1242 ih.lazyForce().forGroupBy = ih.appendIndexes(ih.force.forGroupBy, indexes...) 1243 } 1244 1245 func (ih *idxHintsQuery) appendIndexHints( 1246 fmter schema.Formatter, b []byte, 1247 ) ([]byte, error) { 1248 type IdxHint struct { 1249 Name string 1250 Values []schema.QueryWithArgs 1251 } 1252 1253 var hints []IdxHint 1254 if ih.use != nil { 1255 hints = append(hints, []IdxHint{ 1256 { 1257 Name: "USE INDEX", 1258 Values: ih.use.names, 1259 }, 1260 { 1261 Name: "USE INDEX FOR JOIN", 1262 Values: ih.use.forJoin, 1263 }, 1264 { 1265 Name: "USE INDEX FOR ORDER BY", 1266 Values: ih.use.forOrderBy, 1267 }, 1268 { 1269 Name: "USE INDEX FOR GROUP BY", 1270 Values: ih.use.forGroupBy, 1271 }, 1272 }...) 1273 } 1274 1275 if ih.ignore != nil { 1276 hints = append(hints, []IdxHint{ 1277 { 1278 Name: "IGNORE INDEX", 1279 Values: ih.ignore.names, 1280 }, 1281 { 1282 Name: "IGNORE INDEX FOR JOIN", 1283 Values: ih.ignore.forJoin, 1284 }, 1285 { 1286 Name: "IGNORE INDEX FOR ORDER BY", 1287 Values: ih.ignore.forOrderBy, 1288 }, 1289 { 1290 Name: "IGNORE INDEX FOR GROUP BY", 1291 Values: ih.ignore.forGroupBy, 1292 }, 1293 }...) 1294 } 1295 1296 if ih.force != nil { 1297 hints = append(hints, []IdxHint{ 1298 { 1299 Name: "FORCE INDEX", 1300 Values: ih.force.names, 1301 }, 1302 { 1303 Name: "FORCE INDEX FOR JOIN", 1304 Values: ih.force.forJoin, 1305 }, 1306 { 1307 Name: "FORCE INDEX FOR ORDER BY", 1308 Values: ih.force.forOrderBy, 1309 }, 1310 { 1311 Name: "FORCE INDEX FOR GROUP BY", 1312 Values: ih.force.forGroupBy, 1313 }, 1314 }...) 1315 } 1316 1317 var err error 1318 for _, h := range hints { 1319 b, err = ih.bufIndexHint(h.Name, h.Values, fmter, b) 1320 if err != nil { 1321 return nil, err 1322 } 1323 } 1324 return b, nil 1325 } 1326 1327 func (ih *idxHintsQuery) bufIndexHint( 1328 name string, 1329 hints []schema.QueryWithArgs, 1330 fmter schema.Formatter, b []byte, 1331 ) ([]byte, error) { 1332 var err error 1333 if len(hints) == 0 { 1334 return b, nil 1335 } 1336 b = append(b, fmt.Sprintf(" %s (", name)...) 1337 for i, f := range hints { 1338 if i > 0 { 1339 b = append(b, ", "...) 1340 } 1341 b, err = f.AppendQuery(fmter, b) 1342 if err != nil { 1343 return nil, err 1344 } 1345 } 1346 b = append(b, ")"...) 1347 return b, nil 1348 }