query_select.go (25908B)
1 package bun 2 3 import ( 4 "bytes" 5 "context" 6 "database/sql" 7 "errors" 8 "fmt" 9 "strconv" 10 "strings" 11 "sync" 12 13 "github.com/uptrace/bun/dialect" 14 15 "github.com/uptrace/bun/dialect/feature" 16 "github.com/uptrace/bun/internal" 17 "github.com/uptrace/bun/schema" 18 ) 19 20 type union struct { 21 expr string 22 query *SelectQuery 23 } 24 25 type SelectQuery struct { 26 whereBaseQuery 27 idxHintsQuery 28 29 distinctOn []schema.QueryWithArgs 30 joins []joinQuery 31 group []schema.QueryWithArgs 32 having []schema.QueryWithArgs 33 order []schema.QueryWithArgs 34 limit int32 35 offset int32 36 selFor schema.QueryWithArgs 37 38 union []union 39 } 40 41 var _ Query = (*SelectQuery)(nil) 42 43 func NewSelectQuery(db *DB) *SelectQuery { 44 return &SelectQuery{ 45 whereBaseQuery: whereBaseQuery{ 46 baseQuery: baseQuery{ 47 db: db, 48 conn: db.DB, 49 }, 50 }, 51 } 52 } 53 54 func (q *SelectQuery) Conn(db IConn) *SelectQuery { 55 q.setConn(db) 56 return q 57 } 58 59 func (q *SelectQuery) Model(model interface{}) *SelectQuery { 60 q.setModel(model) 61 return q 62 } 63 64 func (q *SelectQuery) Err(err error) *SelectQuery { 65 q.setErr(err) 66 return q 67 } 68 69 // Apply calls the fn passing the SelectQuery as an argument. 70 func (q *SelectQuery) Apply(fn func(*SelectQuery) *SelectQuery) *SelectQuery { 71 if fn != nil { 72 return fn(q) 73 } 74 return q 75 } 76 77 func (q *SelectQuery) With(name string, query schema.QueryAppender) *SelectQuery { 78 q.addWith(name, query, false) 79 return q 80 } 81 82 func (q *SelectQuery) WithRecursive(name string, query schema.QueryAppender) *SelectQuery { 83 q.addWith(name, query, true) 84 return q 85 } 86 87 func (q *SelectQuery) Distinct() *SelectQuery { 88 q.distinctOn = make([]schema.QueryWithArgs, 0) 89 return q 90 } 91 92 func (q *SelectQuery) DistinctOn(query string, args ...interface{}) *SelectQuery { 93 q.distinctOn = append(q.distinctOn, schema.SafeQuery(query, args)) 94 return q 95 } 96 97 //------------------------------------------------------------------------------ 98 99 func (q *SelectQuery) Table(tables ...string) *SelectQuery { 100 for _, table := range tables { 101 q.addTable(schema.UnsafeIdent(table)) 102 } 103 return q 104 } 105 106 func (q *SelectQuery) TableExpr(query string, args ...interface{}) *SelectQuery { 107 q.addTable(schema.SafeQuery(query, args)) 108 return q 109 } 110 111 func (q *SelectQuery) ModelTableExpr(query string, args ...interface{}) *SelectQuery { 112 q.modelTableName = schema.SafeQuery(query, args) 113 return q 114 } 115 116 //------------------------------------------------------------------------------ 117 118 func (q *SelectQuery) Column(columns ...string) *SelectQuery { 119 for _, column := range columns { 120 q.addColumn(schema.UnsafeIdent(column)) 121 } 122 return q 123 } 124 125 func (q *SelectQuery) ColumnExpr(query string, args ...interface{}) *SelectQuery { 126 q.addColumn(schema.SafeQuery(query, args)) 127 return q 128 } 129 130 func (q *SelectQuery) ExcludeColumn(columns ...string) *SelectQuery { 131 q.excludeColumn(columns) 132 return q 133 } 134 135 //------------------------------------------------------------------------------ 136 137 func (q *SelectQuery) WherePK(cols ...string) *SelectQuery { 138 q.addWhereCols(cols) 139 return q 140 } 141 142 func (q *SelectQuery) Where(query string, args ...interface{}) *SelectQuery { 143 q.addWhere(schema.SafeQueryWithSep(query, args, " AND ")) 144 return q 145 } 146 147 func (q *SelectQuery) WhereOr(query string, args ...interface{}) *SelectQuery { 148 q.addWhere(schema.SafeQueryWithSep(query, args, " OR ")) 149 return q 150 } 151 152 func (q *SelectQuery) WhereGroup(sep string, fn func(*SelectQuery) *SelectQuery) *SelectQuery { 153 saved := q.where 154 q.where = nil 155 156 q = fn(q) 157 158 where := q.where 159 q.where = saved 160 161 q.addWhereGroup(sep, where) 162 163 return q 164 } 165 166 func (q *SelectQuery) WhereDeleted() *SelectQuery { 167 q.whereDeleted() 168 return q 169 } 170 171 func (q *SelectQuery) WhereAllWithDeleted() *SelectQuery { 172 q.whereAllWithDeleted() 173 return q 174 } 175 176 //------------------------------------------------------------------------------ 177 178 func (q *SelectQuery) UseIndex(indexes ...string) *SelectQuery { 179 if q.db.dialect.Name() == dialect.MySQL { 180 q.addUseIndex(indexes...) 181 } 182 return q 183 } 184 185 func (q *SelectQuery) UseIndexForJoin(indexes ...string) *SelectQuery { 186 if q.db.dialect.Name() == dialect.MySQL { 187 q.addUseIndexForJoin(indexes...) 188 } 189 return q 190 } 191 192 func (q *SelectQuery) UseIndexForOrderBy(indexes ...string) *SelectQuery { 193 if q.db.dialect.Name() == dialect.MySQL { 194 q.addUseIndexForOrderBy(indexes...) 195 } 196 return q 197 } 198 199 func (q *SelectQuery) UseIndexForGroupBy(indexes ...string) *SelectQuery { 200 if q.db.dialect.Name() == dialect.MySQL { 201 q.addUseIndexForGroupBy(indexes...) 202 } 203 return q 204 } 205 206 func (q *SelectQuery) IgnoreIndex(indexes ...string) *SelectQuery { 207 if q.db.dialect.Name() == dialect.MySQL { 208 q.addIgnoreIndex(indexes...) 209 } 210 return q 211 } 212 213 func (q *SelectQuery) IgnoreIndexForJoin(indexes ...string) *SelectQuery { 214 if q.db.dialect.Name() == dialect.MySQL { 215 q.addIgnoreIndexForJoin(indexes...) 216 } 217 return q 218 } 219 220 func (q *SelectQuery) IgnoreIndexForOrderBy(indexes ...string) *SelectQuery { 221 if q.db.dialect.Name() == dialect.MySQL { 222 q.addIgnoreIndexForOrderBy(indexes...) 223 } 224 return q 225 } 226 227 func (q *SelectQuery) IgnoreIndexForGroupBy(indexes ...string) *SelectQuery { 228 if q.db.dialect.Name() == dialect.MySQL { 229 q.addIgnoreIndexForGroupBy(indexes...) 230 } 231 return q 232 } 233 234 func (q *SelectQuery) ForceIndex(indexes ...string) *SelectQuery { 235 if q.db.dialect.Name() == dialect.MySQL { 236 q.addForceIndex(indexes...) 237 } 238 return q 239 } 240 241 func (q *SelectQuery) ForceIndexForJoin(indexes ...string) *SelectQuery { 242 if q.db.dialect.Name() == dialect.MySQL { 243 q.addForceIndexForJoin(indexes...) 244 } 245 return q 246 } 247 248 func (q *SelectQuery) ForceIndexForOrderBy(indexes ...string) *SelectQuery { 249 if q.db.dialect.Name() == dialect.MySQL { 250 q.addForceIndexForOrderBy(indexes...) 251 } 252 return q 253 } 254 255 func (q *SelectQuery) ForceIndexForGroupBy(indexes ...string) *SelectQuery { 256 if q.db.dialect.Name() == dialect.MySQL { 257 q.addForceIndexForGroupBy(indexes...) 258 } 259 return q 260 } 261 262 //------------------------------------------------------------------------------ 263 264 func (q *SelectQuery) Group(columns ...string) *SelectQuery { 265 for _, column := range columns { 266 q.group = append(q.group, schema.UnsafeIdent(column)) 267 } 268 return q 269 } 270 271 func (q *SelectQuery) GroupExpr(group string, args ...interface{}) *SelectQuery { 272 q.group = append(q.group, schema.SafeQuery(group, args)) 273 return q 274 } 275 276 func (q *SelectQuery) Having(having string, args ...interface{}) *SelectQuery { 277 q.having = append(q.having, schema.SafeQuery(having, args)) 278 return q 279 } 280 281 func (q *SelectQuery) Order(orders ...string) *SelectQuery { 282 for _, order := range orders { 283 if order == "" { 284 continue 285 } 286 287 index := strings.IndexByte(order, ' ') 288 if index == -1 { 289 q.order = append(q.order, schema.UnsafeIdent(order)) 290 continue 291 } 292 293 field := order[:index] 294 sort := order[index+1:] 295 296 switch strings.ToUpper(sort) { 297 case "ASC", "DESC", "ASC NULLS FIRST", "DESC NULLS FIRST", 298 "ASC NULLS LAST", "DESC NULLS LAST": 299 q.order = append(q.order, schema.SafeQuery("? ?", []interface{}{ 300 Ident(field), 301 Safe(sort), 302 })) 303 default: 304 q.order = append(q.order, schema.UnsafeIdent(order)) 305 } 306 } 307 return q 308 } 309 310 func (q *SelectQuery) OrderExpr(query string, args ...interface{}) *SelectQuery { 311 q.order = append(q.order, schema.SafeQuery(query, args)) 312 return q 313 } 314 315 func (q *SelectQuery) Limit(n int) *SelectQuery { 316 q.limit = int32(n) 317 return q 318 } 319 320 func (q *SelectQuery) Offset(n int) *SelectQuery { 321 q.offset = int32(n) 322 return q 323 } 324 325 func (q *SelectQuery) For(s string, args ...interface{}) *SelectQuery { 326 q.selFor = schema.SafeQuery(s, args) 327 return q 328 } 329 330 //------------------------------------------------------------------------------ 331 332 func (q *SelectQuery) Union(other *SelectQuery) *SelectQuery { 333 return q.addUnion(" UNION ", other) 334 } 335 336 func (q *SelectQuery) UnionAll(other *SelectQuery) *SelectQuery { 337 return q.addUnion(" UNION ALL ", other) 338 } 339 340 func (q *SelectQuery) Intersect(other *SelectQuery) *SelectQuery { 341 return q.addUnion(" INTERSECT ", other) 342 } 343 344 func (q *SelectQuery) IntersectAll(other *SelectQuery) *SelectQuery { 345 return q.addUnion(" INTERSECT ALL ", other) 346 } 347 348 func (q *SelectQuery) Except(other *SelectQuery) *SelectQuery { 349 return q.addUnion(" EXCEPT ", other) 350 } 351 352 func (q *SelectQuery) ExceptAll(other *SelectQuery) *SelectQuery { 353 return q.addUnion(" EXCEPT ALL ", other) 354 } 355 356 func (q *SelectQuery) addUnion(expr string, other *SelectQuery) *SelectQuery { 357 q.union = append(q.union, union{ 358 expr: expr, 359 query: other, 360 }) 361 return q 362 } 363 364 //------------------------------------------------------------------------------ 365 366 func (q *SelectQuery) Join(join string, args ...interface{}) *SelectQuery { 367 q.joins = append(q.joins, joinQuery{ 368 join: schema.SafeQuery(join, args), 369 }) 370 return q 371 } 372 373 func (q *SelectQuery) JoinOn(cond string, args ...interface{}) *SelectQuery { 374 return q.joinOn(cond, args, " AND ") 375 } 376 377 func (q *SelectQuery) JoinOnOr(cond string, args ...interface{}) *SelectQuery { 378 return q.joinOn(cond, args, " OR ") 379 } 380 381 func (q *SelectQuery) joinOn(cond string, args []interface{}, sep string) *SelectQuery { 382 if len(q.joins) == 0 { 383 q.err = errors.New("bun: query has no joins") 384 return q 385 } 386 j := &q.joins[len(q.joins)-1] 387 j.on = append(j.on, schema.SafeQueryWithSep(cond, args, sep)) 388 return q 389 } 390 391 //------------------------------------------------------------------------------ 392 393 // Relation adds a relation to the query. 394 func (q *SelectQuery) Relation(name string, apply ...func(*SelectQuery) *SelectQuery) *SelectQuery { 395 if len(apply) > 1 { 396 panic("only one apply function is supported") 397 } 398 399 if q.tableModel == nil { 400 q.setErr(errNilModel) 401 return q 402 } 403 404 join := q.tableModel.join(name) 405 if join == nil { 406 q.setErr(fmt.Errorf("%s does not have relation=%q", q.table, name)) 407 return q 408 } 409 410 var apply1, apply2 func(*SelectQuery) *SelectQuery 411 412 if len(join.Relation.Condition) > 0 { 413 apply1 = func(q *SelectQuery) *SelectQuery { 414 for _, opt := range join.Relation.Condition { 415 q.addWhere(schema.SafeQueryWithSep(opt, nil, " AND ")) 416 } 417 418 return q 419 } 420 } 421 422 if len(apply) == 1 { 423 apply2 = apply[0] 424 } 425 426 join.apply = func(q *SelectQuery) *SelectQuery { 427 if apply1 != nil { 428 q = apply1(q) 429 } 430 if apply2 != nil { 431 q = apply2(q) 432 } 433 434 return q 435 } 436 437 return q 438 } 439 440 func (q *SelectQuery) forEachInlineRelJoin(fn func(*relationJoin) error) error { 441 if q.tableModel == nil { 442 return nil 443 } 444 return q._forEachInlineRelJoin(fn, q.tableModel.getJoins()) 445 } 446 447 func (q *SelectQuery) _forEachInlineRelJoin(fn func(*relationJoin) error, joins []relationJoin) error { 448 for i := range joins { 449 j := &joins[i] 450 switch j.Relation.Type { 451 case schema.HasOneRelation, schema.BelongsToRelation: 452 if err := fn(j); err != nil { 453 return err 454 } 455 if err := q._forEachInlineRelJoin(fn, j.JoinModel.getJoins()); err != nil { 456 return err 457 } 458 } 459 } 460 return nil 461 } 462 463 func (q *SelectQuery) selectJoins(ctx context.Context, joins []relationJoin) error { 464 for i := range joins { 465 j := &joins[i] 466 467 var err error 468 469 switch j.Relation.Type { 470 case schema.HasOneRelation, schema.BelongsToRelation: 471 err = q.selectJoins(ctx, j.JoinModel.getJoins()) 472 case schema.HasManyRelation: 473 err = j.selectMany(ctx, q.db.NewSelect().Conn(q.conn)) 474 case schema.ManyToManyRelation: 475 err = j.selectM2M(ctx, q.db.NewSelect().Conn(q.conn)) 476 default: 477 panic("not reached") 478 } 479 480 if err != nil { 481 return err 482 } 483 } 484 return nil 485 } 486 487 //------------------------------------------------------------------------------ 488 489 func (q *SelectQuery) Operation() string { 490 return "SELECT" 491 } 492 493 func (q *SelectQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { 494 return q.appendQuery(fmter, b, false) 495 } 496 497 func (q *SelectQuery) appendQuery( 498 fmter schema.Formatter, b []byte, count bool, 499 ) (_ []byte, err error) { 500 if q.err != nil { 501 return nil, q.err 502 } 503 504 fmter = formatterWithModel(fmter, q) 505 506 cteCount := count && (len(q.group) > 0 || q.distinctOn != nil) 507 if cteCount { 508 b = append(b, "WITH _count_wrapper AS ("...) 509 } 510 511 if len(q.union) > 0 { 512 b = append(b, '(') 513 } 514 515 b, err = q.appendWith(fmter, b) 516 if err != nil { 517 return nil, err 518 } 519 520 b = append(b, "SELECT "...) 521 522 if len(q.distinctOn) > 0 { 523 b = append(b, "DISTINCT ON ("...) 524 for i, app := range q.distinctOn { 525 if i > 0 { 526 b = append(b, ", "...) 527 } 528 b, err = app.AppendQuery(fmter, b) 529 if err != nil { 530 return nil, err 531 } 532 } 533 b = append(b, ") "...) 534 } else if q.distinctOn != nil { 535 b = append(b, "DISTINCT "...) 536 } 537 538 if count && !cteCount { 539 b = append(b, "count(*)"...) 540 } else { 541 b, err = q.appendColumns(fmter, b) 542 if err != nil { 543 return nil, err 544 } 545 } 546 547 if q.hasTables() { 548 b, err = q.appendTables(fmter, b) 549 if err != nil { 550 return nil, err 551 } 552 } 553 554 if err := q.forEachInlineRelJoin(func(j *relationJoin) error { 555 b = append(b, ' ') 556 b, err = j.appendHasOneJoin(fmter, b, q) 557 return err 558 }); err != nil { 559 return nil, err 560 } 561 562 for _, j := range q.joins { 563 b, err = j.AppendQuery(fmter, b) 564 if err != nil { 565 return nil, err 566 } 567 } 568 569 b, err = q.appendIndexHints(fmter, b) 570 if err != nil { 571 return nil, err 572 } 573 574 b, err = q.appendWhere(fmter, b, true) 575 if err != nil { 576 return nil, err 577 } 578 579 if len(q.group) > 0 { 580 b = append(b, " GROUP BY "...) 581 for i, f := range q.group { 582 if i > 0 { 583 b = append(b, ", "...) 584 } 585 b, err = f.AppendQuery(fmter, b) 586 if err != nil { 587 return nil, err 588 } 589 } 590 } 591 592 if len(q.having) > 0 { 593 b = append(b, " HAVING "...) 594 for i, f := range q.having { 595 if i > 0 { 596 b = append(b, " AND "...) 597 } 598 b = append(b, '(') 599 b, err = f.AppendQuery(fmter, b) 600 if err != nil { 601 return nil, err 602 } 603 b = append(b, ')') 604 } 605 } 606 607 if !count { 608 b, err = q.appendOrder(fmter, b) 609 if err != nil { 610 return nil, err 611 } 612 613 if fmter.Dialect().Features().Has(feature.OffsetFetch) { 614 if q.limit > 0 && q.offset > 0 { 615 b = append(b, " OFFSET "...) 616 b = strconv.AppendInt(b, int64(q.offset), 10) 617 b = append(b, " ROWS"...) 618 619 b = append(b, " FETCH NEXT "...) 620 b = strconv.AppendInt(b, int64(q.limit), 10) 621 b = append(b, " ROWS ONLY"...) 622 } else if q.limit > 0 { 623 b = append(b, " OFFSET 0 ROWS"...) 624 625 b = append(b, " FETCH NEXT "...) 626 b = strconv.AppendInt(b, int64(q.limit), 10) 627 b = append(b, " ROWS ONLY"...) 628 } else if q.offset > 0 { 629 b = append(b, " OFFSET "...) 630 b = strconv.AppendInt(b, int64(q.offset), 10) 631 b = append(b, " ROWS"...) 632 } 633 } else { 634 if q.limit > 0 { 635 b = append(b, " LIMIT "...) 636 b = strconv.AppendInt(b, int64(q.limit), 10) 637 } 638 if q.offset > 0 { 639 b = append(b, " OFFSET "...) 640 b = strconv.AppendInt(b, int64(q.offset), 10) 641 } 642 } 643 644 if !q.selFor.IsZero() { 645 b = append(b, " FOR "...) 646 b, err = q.selFor.AppendQuery(fmter, b) 647 if err != nil { 648 return nil, err 649 } 650 } 651 } 652 653 if len(q.union) > 0 { 654 b = append(b, ')') 655 656 for _, u := range q.union { 657 b = append(b, u.expr...) 658 b = append(b, '(') 659 b, err = u.query.AppendQuery(fmter, b) 660 if err != nil { 661 return nil, err 662 } 663 b = append(b, ')') 664 } 665 } 666 667 if cteCount { 668 b = append(b, ") SELECT count(*) FROM _count_wrapper"...) 669 } 670 671 return b, nil 672 } 673 674 func (q *SelectQuery) appendColumns(fmter schema.Formatter, b []byte) (_ []byte, err error) { 675 start := len(b) 676 677 switch { 678 case q.columns != nil: 679 for i, col := range q.columns { 680 if i > 0 { 681 b = append(b, ", "...) 682 } 683 684 if col.Args == nil && q.table != nil { 685 if field, ok := q.table.FieldMap[col.Query]; ok { 686 b = append(b, q.table.SQLAlias...) 687 b = append(b, '.') 688 b = append(b, field.SQLName...) 689 continue 690 } 691 } 692 693 b, err = col.AppendQuery(fmter, b) 694 if err != nil { 695 return nil, err 696 } 697 } 698 case q.table != nil: 699 if len(q.table.Fields) > 10 && fmter.IsNop() { 700 b = append(b, q.table.SQLAlias...) 701 b = append(b, '.') 702 b = fmter.Dialect().AppendString(b, fmt.Sprintf("%d columns", len(q.table.Fields))) 703 } else { 704 b = appendColumns(b, q.table.SQLAlias, q.table.Fields) 705 } 706 default: 707 b = append(b, '*') 708 } 709 710 if err := q.forEachInlineRelJoin(func(join *relationJoin) error { 711 if len(b) != start { 712 b = append(b, ", "...) 713 start = len(b) 714 } 715 716 b, err = q.appendInlineRelColumns(fmter, b, join) 717 if err != nil { 718 return err 719 } 720 721 return nil 722 }); err != nil { 723 return nil, err 724 } 725 726 b = bytes.TrimSuffix(b, []byte(", ")) 727 728 return b, nil 729 } 730 731 func (q *SelectQuery) appendInlineRelColumns( 732 fmter schema.Formatter, b []byte, join *relationJoin, 733 ) (_ []byte, err error) { 734 join.applyTo(q) 735 736 if join.columns != nil { 737 table := join.JoinModel.Table() 738 for i, col := range join.columns { 739 if i > 0 { 740 b = append(b, ", "...) 741 } 742 743 if col.Args == nil { 744 if field, ok := table.FieldMap[col.Query]; ok { 745 b = join.appendAlias(fmter, b) 746 b = append(b, '.') 747 b = append(b, field.SQLName...) 748 b = append(b, " AS "...) 749 b = join.appendAliasColumn(fmter, b, field.Name) 750 continue 751 } 752 } 753 754 b, err = col.AppendQuery(fmter, b) 755 if err != nil { 756 return nil, err 757 } 758 } 759 return b, nil 760 } 761 762 for i, field := range join.JoinModel.Table().Fields { 763 if i > 0 { 764 b = append(b, ", "...) 765 } 766 b = join.appendAlias(fmter, b) 767 b = append(b, '.') 768 b = append(b, field.SQLName...) 769 b = append(b, " AS "...) 770 b = join.appendAliasColumn(fmter, b, field.Name) 771 } 772 return b, nil 773 } 774 775 func (q *SelectQuery) appendTables(fmter schema.Formatter, b []byte) (_ []byte, err error) { 776 b = append(b, " FROM "...) 777 return q.appendTablesWithAlias(fmter, b) 778 } 779 780 func (q *SelectQuery) appendOrder(fmter schema.Formatter, b []byte) (_ []byte, err error) { 781 if len(q.order) > 0 { 782 b = append(b, " ORDER BY "...) 783 784 for i, f := range q.order { 785 if i > 0 { 786 b = append(b, ", "...) 787 } 788 b, err = f.AppendQuery(fmter, b) 789 if err != nil { 790 return nil, err 791 } 792 } 793 794 return b, nil 795 } 796 return b, nil 797 } 798 799 //------------------------------------------------------------------------------ 800 801 func (q *SelectQuery) Rows(ctx context.Context) (*sql.Rows, error) { 802 if q.err != nil { 803 return nil, q.err 804 } 805 806 if err := q.beforeAppendModel(ctx, q); err != nil { 807 return nil, err 808 } 809 810 queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) 811 if err != nil { 812 return nil, err 813 } 814 815 query := internal.String(queryBytes) 816 return q.conn.QueryContext(ctx, query) 817 } 818 819 func (q *SelectQuery) Exec(ctx context.Context, dest ...interface{}) (res sql.Result, err error) { 820 if q.err != nil { 821 return nil, q.err 822 } 823 if err := q.beforeAppendModel(ctx, q); err != nil { 824 return nil, err 825 } 826 827 queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) 828 if err != nil { 829 return nil, err 830 } 831 832 query := internal.String(queryBytes) 833 834 if len(dest) > 0 { 835 model, err := q.getModel(dest) 836 if err != nil { 837 return nil, err 838 } 839 840 res, err = q.scan(ctx, q, query, model, true) 841 if err != nil { 842 return nil, err 843 } 844 } else { 845 res, err = q.exec(ctx, q, query) 846 if err != nil { 847 return nil, err 848 } 849 } 850 851 return res, nil 852 } 853 854 func (q *SelectQuery) Scan(ctx context.Context, dest ...interface{}) error { 855 if q.err != nil { 856 return q.err 857 } 858 859 model, err := q.getModel(dest) 860 if err != nil { 861 return err 862 } 863 864 if q.table != nil { 865 if err := q.beforeSelectHook(ctx); err != nil { 866 return err 867 } 868 } 869 870 if err := q.beforeAppendModel(ctx, q); err != nil { 871 return err 872 } 873 874 queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) 875 if err != nil { 876 return err 877 } 878 879 query := internal.String(queryBytes) 880 881 res, err := q.scan(ctx, q, query, model, true) 882 if err != nil { 883 return err 884 } 885 886 if n, _ := res.RowsAffected(); n > 0 { 887 if tableModel, ok := model.(TableModel); ok { 888 if err := q.selectJoins(ctx, tableModel.getJoins()); err != nil { 889 return err 890 } 891 } 892 } 893 894 if q.table != nil { 895 if err := q.afterSelectHook(ctx); err != nil { 896 return err 897 } 898 } 899 900 return nil 901 } 902 903 func (q *SelectQuery) beforeSelectHook(ctx context.Context) error { 904 if hook, ok := q.table.ZeroIface.(BeforeSelectHook); ok { 905 if err := hook.BeforeSelect(ctx, q); err != nil { 906 return err 907 } 908 } 909 return nil 910 } 911 912 func (q *SelectQuery) afterSelectHook(ctx context.Context) error { 913 if hook, ok := q.table.ZeroIface.(AfterSelectHook); ok { 914 if err := hook.AfterSelect(ctx, q); err != nil { 915 return err 916 } 917 } 918 return nil 919 } 920 921 func (q *SelectQuery) Count(ctx context.Context) (int, error) { 922 if q.err != nil { 923 return 0, q.err 924 } 925 926 qq := countQuery{q} 927 928 queryBytes, err := qq.AppendQuery(q.db.fmter, nil) 929 if err != nil { 930 return 0, err 931 } 932 933 query := internal.String(queryBytes) 934 ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model) 935 936 var num int 937 err = q.conn.QueryRowContext(ctx, query).Scan(&num) 938 939 q.db.afterQuery(ctx, event, nil, err) 940 941 return num, err 942 } 943 944 func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (int, error) { 945 if _, ok := q.conn.(*DB); ok { 946 return q.scanAndCountConc(ctx, dest...) 947 } 948 return q.scanAndCountSeq(ctx, dest...) 949 } 950 951 func (q *SelectQuery) scanAndCountConc(ctx context.Context, dest ...interface{}) (int, error) { 952 var count int 953 var wg sync.WaitGroup 954 var mu sync.Mutex 955 var firstErr error 956 957 if q.limit >= 0 { 958 wg.Add(1) 959 go func() { 960 defer wg.Done() 961 962 if err := q.Scan(ctx, dest...); err != nil { 963 mu.Lock() 964 if firstErr == nil { 965 firstErr = err 966 } 967 mu.Unlock() 968 } 969 }() 970 } 971 972 wg.Add(1) 973 go func() { 974 defer wg.Done() 975 976 var err error 977 count, err = q.Count(ctx) 978 if err != nil { 979 mu.Lock() 980 if firstErr == nil { 981 firstErr = err 982 } 983 mu.Unlock() 984 } 985 }() 986 987 wg.Wait() 988 return count, firstErr 989 } 990 991 func (q *SelectQuery) scanAndCountSeq(ctx context.Context, dest ...interface{}) (int, error) { 992 var firstErr error 993 994 if q.limit >= 0 { 995 firstErr = q.Scan(ctx, dest...) 996 } 997 998 count, err := q.Count(ctx) 999 if err != nil && firstErr == nil { 1000 firstErr = err 1001 } 1002 1003 return count, firstErr 1004 } 1005 1006 func (q *SelectQuery) Exists(ctx context.Context) (bool, error) { 1007 if q.err != nil { 1008 return false, q.err 1009 } 1010 1011 if q.hasFeature(feature.SelectExists) { 1012 return q.selectExists(ctx) 1013 } 1014 return q.whereExists(ctx) 1015 } 1016 1017 func (q *SelectQuery) selectExists(ctx context.Context) (bool, error) { 1018 qq := selectExistsQuery{q} 1019 1020 queryBytes, err := qq.AppendQuery(q.db.fmter, nil) 1021 if err != nil { 1022 return false, err 1023 } 1024 1025 query := internal.String(queryBytes) 1026 ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model) 1027 1028 var exists bool 1029 err = q.conn.QueryRowContext(ctx, query).Scan(&exists) 1030 1031 q.db.afterQuery(ctx, event, nil, err) 1032 1033 return exists, err 1034 } 1035 1036 func (q *SelectQuery) whereExists(ctx context.Context) (bool, error) { 1037 qq := whereExistsQuery{q} 1038 1039 queryBytes, err := qq.AppendQuery(q.db.fmter, nil) 1040 if err != nil { 1041 return false, err 1042 } 1043 1044 query := internal.String(queryBytes) 1045 res, err := q.exec(ctx, qq, query) 1046 if err != nil { 1047 return false, err 1048 } 1049 1050 n, err := res.RowsAffected() 1051 if err != nil { 1052 return false, err 1053 } 1054 1055 return n == 1, nil 1056 } 1057 1058 func (q *SelectQuery) String() string { 1059 buf, err := q.AppendQuery(q.db.Formatter(), nil) 1060 if err != nil { 1061 panic(err) 1062 } 1063 1064 return string(buf) 1065 } 1066 1067 //------------------------------------------------------------------------------ 1068 1069 func (q *SelectQuery) QueryBuilder() QueryBuilder { 1070 return &selectQueryBuilder{q} 1071 } 1072 1073 func (q *SelectQuery) ApplyQueryBuilder(fn func(QueryBuilder) QueryBuilder) *SelectQuery { 1074 return fn(q.QueryBuilder()).Unwrap().(*SelectQuery) 1075 } 1076 1077 type selectQueryBuilder struct { 1078 *SelectQuery 1079 } 1080 1081 func (q *selectQueryBuilder) WhereGroup( 1082 sep string, fn func(QueryBuilder) QueryBuilder, 1083 ) QueryBuilder { 1084 q.SelectQuery = q.SelectQuery.WhereGroup(sep, func(qs *SelectQuery) *SelectQuery { 1085 return fn(q).(*selectQueryBuilder).SelectQuery 1086 }) 1087 return q 1088 } 1089 1090 func (q *selectQueryBuilder) Where(query string, args ...interface{}) QueryBuilder { 1091 q.SelectQuery.Where(query, args...) 1092 return q 1093 } 1094 1095 func (q *selectQueryBuilder) WhereOr(query string, args ...interface{}) QueryBuilder { 1096 q.SelectQuery.WhereOr(query, args...) 1097 return q 1098 } 1099 1100 func (q *selectQueryBuilder) WhereDeleted() QueryBuilder { 1101 q.SelectQuery.WhereDeleted() 1102 return q 1103 } 1104 1105 func (q *selectQueryBuilder) WhereAllWithDeleted() QueryBuilder { 1106 q.SelectQuery.WhereAllWithDeleted() 1107 return q 1108 } 1109 1110 func (q *selectQueryBuilder) WherePK(cols ...string) QueryBuilder { 1111 q.SelectQuery.WherePK(cols...) 1112 return q 1113 } 1114 1115 func (q *selectQueryBuilder) Unwrap() interface{} { 1116 return q.SelectQuery 1117 } 1118 1119 //------------------------------------------------------------------------------ 1120 1121 type joinQuery struct { 1122 join schema.QueryWithArgs 1123 on []schema.QueryWithSep 1124 } 1125 1126 func (j *joinQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { 1127 b = append(b, ' ') 1128 1129 b, err = j.join.AppendQuery(fmter, b) 1130 if err != nil { 1131 return nil, err 1132 } 1133 1134 if len(j.on) > 0 { 1135 b = append(b, " ON "...) 1136 for i, on := range j.on { 1137 if i > 0 { 1138 b = append(b, on.Sep...) 1139 } 1140 1141 b = append(b, '(') 1142 b, err = on.AppendQuery(fmter, b) 1143 if err != nil { 1144 return nil, err 1145 } 1146 b = append(b, ')') 1147 } 1148 } 1149 1150 return b, nil 1151 } 1152 1153 //------------------------------------------------------------------------------ 1154 1155 type countQuery struct { 1156 *SelectQuery 1157 } 1158 1159 func (q countQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { 1160 if q.err != nil { 1161 return nil, q.err 1162 } 1163 return q.appendQuery(fmter, b, true) 1164 } 1165 1166 //------------------------------------------------------------------------------ 1167 1168 type selectExistsQuery struct { 1169 *SelectQuery 1170 } 1171 1172 func (q selectExistsQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { 1173 if q.err != nil { 1174 return nil, q.err 1175 } 1176 1177 b = append(b, "SELECT EXISTS ("...) 1178 1179 b, err = q.appendQuery(fmter, b, false) 1180 if err != nil { 1181 return nil, err 1182 } 1183 1184 b = append(b, ")"...) 1185 1186 return b, nil 1187 } 1188 1189 //------------------------------------------------------------------------------ 1190 1191 type whereExistsQuery struct { 1192 *SelectQuery 1193 } 1194 1195 func (q whereExistsQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, err error) { 1196 if q.err != nil { 1197 return nil, q.err 1198 } 1199 1200 b = append(b, "SELECT 1 WHERE EXISTS ("...) 1201 1202 b, err = q.appendQuery(fmter, b, false) 1203 if err != nil { 1204 return nil, err 1205 } 1206 1207 b = append(b, ")"...) 1208 1209 return b, nil 1210 }