table.go (25292B)
1 package schema 2 3 import ( 4 "database/sql" 5 "fmt" 6 "reflect" 7 "strings" 8 "sync" 9 "time" 10 11 "github.com/jinzhu/inflection" 12 13 "github.com/uptrace/bun/internal" 14 "github.com/uptrace/bun/internal/tagparser" 15 ) 16 17 const ( 18 beforeAppendModelHookFlag internal.Flag = 1 << iota 19 beforeScanHookFlag 20 afterScanHookFlag 21 beforeScanRowHookFlag 22 afterScanRowHookFlag 23 ) 24 25 var ( 26 baseModelType = reflect.TypeOf((*BaseModel)(nil)).Elem() 27 tableNameInflector = inflection.Plural 28 ) 29 30 type BaseModel struct{} 31 32 // SetTableNameInflector overrides the default func that pluralizes 33 // model name to get table name, e.g. my_article becomes my_articles. 34 func SetTableNameInflector(fn func(string) string) { 35 tableNameInflector = fn 36 } 37 38 // Table represents a SQL table created from Go struct. 39 type Table struct { 40 dialect Dialect 41 42 Type reflect.Type 43 ZeroValue reflect.Value // reflect.Struct 44 ZeroIface interface{} // struct pointer 45 46 TypeName string 47 ModelName string 48 49 Name string 50 SQLName Safe 51 SQLNameForSelects Safe 52 Alias string 53 SQLAlias Safe 54 55 Fields []*Field // PKs + DataFields 56 PKs []*Field 57 DataFields []*Field 58 59 fieldsMapMu sync.RWMutex 60 FieldMap map[string]*Field 61 62 Relations map[string]*Relation 63 Unique map[string][]*Field 64 65 SoftDeleteField *Field 66 UpdateSoftDeleteField func(fv reflect.Value, tm time.Time) error 67 68 allFields []*Field // read only 69 70 flags internal.Flag 71 } 72 73 func newTable(dialect Dialect, typ reflect.Type) *Table { 74 t := new(Table) 75 t.dialect = dialect 76 t.Type = typ 77 t.ZeroValue = reflect.New(t.Type).Elem() 78 t.ZeroIface = reflect.New(t.Type).Interface() 79 t.TypeName = internal.ToExported(t.Type.Name()) 80 t.ModelName = internal.Underscore(t.Type.Name()) 81 tableName := tableNameInflector(t.ModelName) 82 t.setName(tableName) 83 t.Alias = t.ModelName 84 t.SQLAlias = t.quoteIdent(t.ModelName) 85 86 hooks := []struct { 87 typ reflect.Type 88 flag internal.Flag 89 }{ 90 {beforeAppendModelHookType, beforeAppendModelHookFlag}, 91 92 {beforeScanHookType, beforeScanHookFlag}, 93 {afterScanHookType, afterScanHookFlag}, 94 95 {beforeScanRowHookType, beforeScanRowHookFlag}, 96 {afterScanRowHookType, afterScanRowHookFlag}, 97 } 98 99 typ = reflect.PtrTo(t.Type) 100 for _, hook := range hooks { 101 if typ.Implements(hook.typ) { 102 t.flags = t.flags.Set(hook.flag) 103 } 104 } 105 106 // Deprecated. 107 deprecatedHooks := []struct { 108 typ reflect.Type 109 flag internal.Flag 110 msg string 111 }{ 112 {beforeScanHookType, beforeScanHookFlag, "rename BeforeScan hook to BeforeScanRow"}, 113 {afterScanHookType, afterScanHookFlag, "rename AfterScan hook to AfterScanRow"}, 114 } 115 for _, hook := range deprecatedHooks { 116 if typ.Implements(hook.typ) { 117 internal.Deprecated.Printf("%s: %s", t.TypeName, hook.msg) 118 t.flags = t.flags.Set(hook.flag) 119 } 120 } 121 122 return t 123 } 124 125 func (t *Table) init1() { 126 t.initFields() 127 } 128 129 func (t *Table) init2() { 130 t.initRelations() 131 } 132 133 func (t *Table) setName(name string) { 134 t.Name = name 135 t.SQLName = t.quoteIdent(name) 136 t.SQLNameForSelects = t.quoteIdent(name) 137 if t.SQLAlias == "" { 138 t.Alias = name 139 t.SQLAlias = t.quoteIdent(name) 140 } 141 } 142 143 func (t *Table) String() string { 144 return "model=" + t.TypeName 145 } 146 147 func (t *Table) CheckPKs() error { 148 if len(t.PKs) == 0 { 149 return fmt.Errorf("bun: %s does not have primary keys", t) 150 } 151 return nil 152 } 153 154 func (t *Table) addField(field *Field) { 155 t.Fields = append(t.Fields, field) 156 if field.IsPK { 157 t.PKs = append(t.PKs, field) 158 } else { 159 t.DataFields = append(t.DataFields, field) 160 } 161 t.FieldMap[field.Name] = field 162 } 163 164 func (t *Table) removeField(field *Field) { 165 t.Fields = removeField(t.Fields, field) 166 if field.IsPK { 167 t.PKs = removeField(t.PKs, field) 168 } else { 169 t.DataFields = removeField(t.DataFields, field) 170 } 171 delete(t.FieldMap, field.Name) 172 } 173 174 func (t *Table) fieldWithLock(name string) *Field { 175 t.fieldsMapMu.RLock() 176 field := t.FieldMap[name] 177 t.fieldsMapMu.RUnlock() 178 return field 179 } 180 181 func (t *Table) HasField(name string) bool { 182 _, ok := t.FieldMap[name] 183 return ok 184 } 185 186 func (t *Table) Field(name string) (*Field, error) { 187 field, ok := t.FieldMap[name] 188 if !ok { 189 return nil, fmt.Errorf("bun: %s does not have column=%s", t, name) 190 } 191 return field, nil 192 } 193 194 func (t *Table) fieldByGoName(name string) *Field { 195 for _, f := range t.allFields { 196 if f.GoName == name { 197 return f 198 } 199 } 200 return nil 201 } 202 203 func (t *Table) initFields() { 204 t.Fields = make([]*Field, 0, t.Type.NumField()) 205 t.FieldMap = make(map[string]*Field, t.Type.NumField()) 206 t.addFields(t.Type, "", nil) 207 } 208 209 func (t *Table) addFields(typ reflect.Type, prefix string, index []int) { 210 for i := 0; i < typ.NumField(); i++ { 211 f := typ.Field(i) 212 unexported := f.PkgPath != "" 213 214 if unexported && !f.Anonymous { // unexported 215 continue 216 } 217 if f.Tag.Get("bun") == "-" { 218 continue 219 } 220 221 if f.Anonymous { 222 if f.Name == "BaseModel" && f.Type == baseModelType { 223 if len(index) == 0 { 224 t.processBaseModelField(f) 225 } 226 continue 227 } 228 229 // If field is an embedded struct, add each field of the embedded struct. 230 fieldType := indirectType(f.Type) 231 if fieldType.Kind() == reflect.Struct { 232 t.addFields(fieldType, "", withIndex(index, f.Index)) 233 234 tag := tagparser.Parse(f.Tag.Get("bun")) 235 if tag.HasOption("inherit") || tag.HasOption("extend") { 236 embeddedTable := t.dialect.Tables().Ref(fieldType) 237 t.TypeName = embeddedTable.TypeName 238 t.SQLName = embeddedTable.SQLName 239 t.SQLNameForSelects = embeddedTable.SQLNameForSelects 240 t.Alias = embeddedTable.Alias 241 t.SQLAlias = embeddedTable.SQLAlias 242 t.ModelName = embeddedTable.ModelName 243 } 244 continue 245 } 246 } 247 248 // If field is not a struct, add it. 249 // This will also add any embedded non-struct type as a field. 250 if field := t.newField(f, prefix, index); field != nil { 251 t.addField(field) 252 } 253 } 254 } 255 256 func (t *Table) processBaseModelField(f reflect.StructField) { 257 tag := tagparser.Parse(f.Tag.Get("bun")) 258 259 if isKnownTableOption(tag.Name) { 260 internal.Warn.Printf( 261 "%s.%s tag name %q is also an option name, is it a mistake? Try table:%s.", 262 t.TypeName, f.Name, tag.Name, tag.Name, 263 ) 264 } 265 266 for name := range tag.Options { 267 if !isKnownTableOption(name) { 268 internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name) 269 } 270 } 271 272 if tag.Name != "" { 273 t.setName(tag.Name) 274 } 275 276 if s, ok := tag.Option("table"); ok { 277 t.setName(s) 278 } 279 280 if s, ok := tag.Option("select"); ok { 281 t.SQLNameForSelects = t.quoteTableName(s) 282 } 283 284 if s, ok := tag.Option("alias"); ok { 285 t.Alias = s 286 t.SQLAlias = t.quoteIdent(s) 287 } 288 } 289 290 // nolint 291 func (t *Table) newField(f reflect.StructField, prefix string, index []int) *Field { 292 tag := tagparser.Parse(f.Tag.Get("bun")) 293 294 if nextPrefix, ok := tag.Option("embed"); ok { 295 fieldType := indirectType(f.Type) 296 if fieldType.Kind() != reflect.Struct { 297 panic(fmt.Errorf("bun: embed %s.%s: got %s, wanted reflect.Struct", 298 t.TypeName, f.Name, fieldType.Kind())) 299 } 300 t.addFields(fieldType, prefix+nextPrefix, withIndex(index, f.Index)) 301 return nil 302 } 303 304 sqlName := internal.Underscore(f.Name) 305 if tag.Name != "" && tag.Name != sqlName { 306 if isKnownFieldOption(tag.Name) { 307 internal.Warn.Printf( 308 "%s.%s tag name %q is also an option name, is it a mistake? Try column:%s.", 309 t.TypeName, f.Name, tag.Name, tag.Name, 310 ) 311 } 312 sqlName = tag.Name 313 } 314 if s, ok := tag.Option("column"); ok { 315 sqlName = s 316 } 317 sqlName = prefix + sqlName 318 319 for name := range tag.Options { 320 if !isKnownFieldOption(name) { 321 internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, f.Name, name) 322 } 323 } 324 325 index = withIndex(index, f.Index) 326 if field := t.fieldWithLock(sqlName); field != nil { 327 if indexEqual(field.Index, index) { 328 return field 329 } 330 t.removeField(field) 331 } 332 333 field := &Field{ 334 StructField: f, 335 IsPtr: f.Type.Kind() == reflect.Ptr, 336 337 Tag: tag, 338 IndirectType: indirectType(f.Type), 339 Index: index, 340 341 Name: sqlName, 342 GoName: f.Name, 343 SQLName: t.quoteIdent(sqlName), 344 } 345 346 field.NotNull = tag.HasOption("notnull") 347 field.NullZero = tag.HasOption("nullzero") 348 if tag.HasOption("pk") { 349 field.IsPK = true 350 field.NotNull = true 351 } 352 if tag.HasOption("autoincrement") { 353 field.AutoIncrement = true 354 field.NullZero = true 355 } 356 if tag.HasOption("identity") { 357 field.Identity = true 358 } 359 360 if v, ok := tag.Options["unique"]; ok { 361 var names []string 362 if len(v) == 1 { 363 // Split the value by comma, this will allow multiple names to be specified. 364 // We can use this to create multiple named unique constraints where a single column 365 // might be included in multiple constraints. 366 names = strings.Split(v[0], ",") 367 } else { 368 names = v 369 } 370 371 for _, uniqueName := range names { 372 if t.Unique == nil { 373 t.Unique = make(map[string][]*Field) 374 } 375 t.Unique[uniqueName] = append(t.Unique[uniqueName], field) 376 } 377 } 378 if s, ok := tag.Option("default"); ok { 379 field.SQLDefault = s 380 field.NullZero = true 381 } 382 if s, ok := field.Tag.Option("type"); ok { 383 field.UserSQLType = s 384 } 385 field.DiscoveredSQLType = DiscoverSQLType(field.IndirectType) 386 field.Append = FieldAppender(t.dialect, field) 387 field.Scan = FieldScanner(t.dialect, field) 388 field.IsZero = zeroChecker(field.StructField.Type) 389 390 if v, ok := tag.Option("alt"); ok { 391 t.FieldMap[v] = field 392 } 393 394 t.allFields = append(t.allFields, field) 395 if tag.HasOption("scanonly") { 396 t.FieldMap[field.Name] = field 397 if field.IndirectType.Kind() == reflect.Struct { 398 t.inlineFields(field, nil) 399 } 400 return nil 401 } 402 403 if _, ok := tag.Options["soft_delete"]; ok { 404 t.SoftDeleteField = field 405 t.UpdateSoftDeleteField = softDeleteFieldUpdater(field) 406 } 407 408 return field 409 } 410 411 //--------------------------------------------------------------------------------------- 412 413 func (t *Table) initRelations() { 414 for i := 0; i < len(t.Fields); { 415 f := t.Fields[i] 416 if t.tryRelation(f) { 417 t.Fields = removeField(t.Fields, f) 418 t.DataFields = removeField(t.DataFields, f) 419 } else { 420 i++ 421 } 422 423 if f.IndirectType.Kind() == reflect.Struct { 424 t.inlineFields(f, nil) 425 } 426 } 427 } 428 429 func (t *Table) tryRelation(field *Field) bool { 430 if rel, ok := field.Tag.Option("rel"); ok { 431 t.initRelation(field, rel) 432 return true 433 } 434 if field.Tag.HasOption("m2m") { 435 t.addRelation(t.m2mRelation(field)) 436 return true 437 } 438 439 if field.Tag.HasOption("join") { 440 internal.Warn.Printf( 441 `%s.%s "join" option must come together with "rel" option`, 442 t.TypeName, field.GoName, 443 ) 444 } 445 446 return false 447 } 448 449 func (t *Table) initRelation(field *Field, rel string) { 450 switch rel { 451 case "belongs-to": 452 t.addRelation(t.belongsToRelation(field)) 453 case "has-one": 454 t.addRelation(t.hasOneRelation(field)) 455 case "has-many": 456 t.addRelation(t.hasManyRelation(field)) 457 default: 458 panic(fmt.Errorf("bun: unknown relation=%s on field=%s", rel, field.GoName)) 459 } 460 } 461 462 func (t *Table) addRelation(rel *Relation) { 463 if t.Relations == nil { 464 t.Relations = make(map[string]*Relation) 465 } 466 _, ok := t.Relations[rel.Field.GoName] 467 if ok { 468 panic(fmt.Errorf("%s already has %s", t, rel)) 469 } 470 t.Relations[rel.Field.GoName] = rel 471 } 472 473 func (t *Table) belongsToRelation(field *Field) *Relation { 474 joinTable := t.dialect.Tables().Ref(field.IndirectType) 475 if err := joinTable.CheckPKs(); err != nil { 476 panic(err) 477 } 478 479 rel := &Relation{ 480 Type: HasOneRelation, 481 Field: field, 482 JoinTable: joinTable, 483 } 484 485 if field.Tag.HasOption("join_on") { 486 rel.Condition = field.Tag.Options["join_on"] 487 } 488 489 rel.OnUpdate = "ON UPDATE NO ACTION" 490 if onUpdate, ok := field.Tag.Options["on_update"]; ok { 491 if len(onUpdate) > 1 { 492 panic(fmt.Errorf("bun: %s belongs-to %s: on_update option must be a single field", t.TypeName, field.GoName)) 493 } 494 495 rule := strings.ToUpper(onUpdate[0]) 496 if !isKnownFKRule(rule) { 497 internal.Warn.Printf("bun: %s belongs-to %s: unknown on_update rule %s", t.TypeName, field.GoName, rule) 498 } 499 500 s := fmt.Sprintf("ON UPDATE %s", rule) 501 rel.OnUpdate = s 502 } 503 504 rel.OnDelete = "ON DELETE NO ACTION" 505 if onDelete, ok := field.Tag.Options["on_delete"]; ok { 506 if len(onDelete) > 1 { 507 panic(fmt.Errorf("bun: %s belongs-to %s: on_delete option must be a single field", t.TypeName, field.GoName)) 508 } 509 510 rule := strings.ToUpper(onDelete[0]) 511 if !isKnownFKRule(rule) { 512 internal.Warn.Printf("bun: %s belongs-to %s: unknown on_delete rule %s", t.TypeName, field.GoName, rule) 513 } 514 s := fmt.Sprintf("ON DELETE %s", rule) 515 rel.OnDelete = s 516 } 517 518 if join, ok := field.Tag.Options["join"]; ok { 519 baseColumns, joinColumns := parseRelationJoin(join) 520 for i, baseColumn := range baseColumns { 521 joinColumn := joinColumns[i] 522 523 if f := t.fieldWithLock(baseColumn); f != nil { 524 rel.BaseFields = append(rel.BaseFields, f) 525 } else { 526 panic(fmt.Errorf( 527 "bun: %s belongs-to %s: %s must have column %s", 528 t.TypeName, field.GoName, t.TypeName, baseColumn, 529 )) 530 } 531 532 if f := joinTable.fieldWithLock(joinColumn); f != nil { 533 rel.JoinFields = append(rel.JoinFields, f) 534 } else { 535 panic(fmt.Errorf( 536 "bun: %s belongs-to %s: %s must have column %s", 537 t.TypeName, field.GoName, joinTable.TypeName, joinColumn, 538 )) 539 } 540 } 541 return rel 542 } 543 544 rel.JoinFields = joinTable.PKs 545 fkPrefix := internal.Underscore(field.GoName) + "_" 546 for _, joinPK := range joinTable.PKs { 547 fkName := fkPrefix + joinPK.Name 548 if fk := t.fieldWithLock(fkName); fk != nil { 549 rel.BaseFields = append(rel.BaseFields, fk) 550 continue 551 } 552 553 if fk := t.fieldWithLock(joinPK.Name); fk != nil { 554 rel.BaseFields = append(rel.BaseFields, fk) 555 continue 556 } 557 558 panic(fmt.Errorf( 559 "bun: %s belongs-to %s: %s must have column %s "+ 560 "(to override, use join:base_column=join_column tag on %s field)", 561 t.TypeName, field.GoName, t.TypeName, fkName, field.GoName, 562 )) 563 } 564 return rel 565 } 566 567 func (t *Table) hasOneRelation(field *Field) *Relation { 568 if err := t.CheckPKs(); err != nil { 569 panic(err) 570 } 571 572 joinTable := t.dialect.Tables().Ref(field.IndirectType) 573 rel := &Relation{ 574 Type: BelongsToRelation, 575 Field: field, 576 JoinTable: joinTable, 577 } 578 579 if field.Tag.HasOption("join_on") { 580 rel.Condition = field.Tag.Options["join_on"] 581 } 582 583 if join, ok := field.Tag.Options["join"]; ok { 584 baseColumns, joinColumns := parseRelationJoin(join) 585 for i, baseColumn := range baseColumns { 586 if f := t.fieldWithLock(baseColumn); f != nil { 587 rel.BaseFields = append(rel.BaseFields, f) 588 } else { 589 panic(fmt.Errorf( 590 "bun: %s has-one %s: %s must have column %s", 591 field.GoName, t.TypeName, t.TypeName, baseColumn, 592 )) 593 } 594 595 joinColumn := joinColumns[i] 596 if f := joinTable.fieldWithLock(joinColumn); f != nil { 597 rel.JoinFields = append(rel.JoinFields, f) 598 } else { 599 panic(fmt.Errorf( 600 "bun: %s has-one %s: %s must have column %s", 601 field.GoName, t.TypeName, joinTable.TypeName, joinColumn, 602 )) 603 } 604 } 605 return rel 606 } 607 608 rel.BaseFields = t.PKs 609 fkPrefix := internal.Underscore(t.ModelName) + "_" 610 for _, pk := range t.PKs { 611 fkName := fkPrefix + pk.Name 612 if f := joinTable.fieldWithLock(fkName); f != nil { 613 rel.JoinFields = append(rel.JoinFields, f) 614 continue 615 } 616 617 if f := joinTable.fieldWithLock(pk.Name); f != nil { 618 rel.JoinFields = append(rel.JoinFields, f) 619 continue 620 } 621 622 panic(fmt.Errorf( 623 "bun: %s has-one %s: %s must have column %s "+ 624 "(to override, use join:base_column=join_column tag on %s field)", 625 field.GoName, t.TypeName, joinTable.TypeName, fkName, field.GoName, 626 )) 627 } 628 return rel 629 } 630 631 func (t *Table) hasManyRelation(field *Field) *Relation { 632 if err := t.CheckPKs(); err != nil { 633 panic(err) 634 } 635 if field.IndirectType.Kind() != reflect.Slice { 636 panic(fmt.Errorf( 637 "bun: %s.%s has-many relation requires slice, got %q", 638 t.TypeName, field.GoName, field.IndirectType.Kind(), 639 )) 640 } 641 642 joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem())) 643 polymorphicValue, isPolymorphic := field.Tag.Option("polymorphic") 644 rel := &Relation{ 645 Type: HasManyRelation, 646 Field: field, 647 JoinTable: joinTable, 648 } 649 650 if field.Tag.HasOption("join_on") { 651 rel.Condition = field.Tag.Options["join_on"] 652 } 653 654 var polymorphicColumn string 655 656 if join, ok := field.Tag.Options["join"]; ok { 657 baseColumns, joinColumns := parseRelationJoin(join) 658 for i, baseColumn := range baseColumns { 659 joinColumn := joinColumns[i] 660 661 if isPolymorphic && baseColumn == "type" { 662 polymorphicColumn = joinColumn 663 continue 664 } 665 666 if f := t.fieldWithLock(baseColumn); f != nil { 667 rel.BaseFields = append(rel.BaseFields, f) 668 } else { 669 panic(fmt.Errorf( 670 "bun: %s has-many %s: %s must have column %s", 671 t.TypeName, field.GoName, t.TypeName, baseColumn, 672 )) 673 } 674 675 if f := joinTable.fieldWithLock(joinColumn); f != nil { 676 rel.JoinFields = append(rel.JoinFields, f) 677 } else { 678 panic(fmt.Errorf( 679 "bun: %s has-many %s: %s must have column %s", 680 t.TypeName, field.GoName, joinTable.TypeName, joinColumn, 681 )) 682 } 683 } 684 } else { 685 rel.BaseFields = t.PKs 686 fkPrefix := internal.Underscore(t.ModelName) + "_" 687 if isPolymorphic { 688 polymorphicColumn = fkPrefix + "type" 689 } 690 691 for _, pk := range t.PKs { 692 joinColumn := fkPrefix + pk.Name 693 if fk := joinTable.fieldWithLock(joinColumn); fk != nil { 694 rel.JoinFields = append(rel.JoinFields, fk) 695 continue 696 } 697 698 if fk := joinTable.fieldWithLock(pk.Name); fk != nil { 699 rel.JoinFields = append(rel.JoinFields, fk) 700 continue 701 } 702 703 panic(fmt.Errorf( 704 "bun: %s has-many %s: %s must have column %s "+ 705 "(to override, use join:base_column=join_column tag on the field %s)", 706 t.TypeName, field.GoName, joinTable.TypeName, joinColumn, field.GoName, 707 )) 708 } 709 } 710 711 if isPolymorphic { 712 rel.PolymorphicField = joinTable.fieldWithLock(polymorphicColumn) 713 if rel.PolymorphicField == nil { 714 panic(fmt.Errorf( 715 "bun: %s has-many %s: %s must have polymorphic column %s", 716 t.TypeName, field.GoName, joinTable.TypeName, polymorphicColumn, 717 )) 718 } 719 720 if polymorphicValue == "" { 721 polymorphicValue = t.ModelName 722 } 723 rel.PolymorphicValue = polymorphicValue 724 } 725 726 return rel 727 } 728 729 func (t *Table) m2mRelation(field *Field) *Relation { 730 if field.IndirectType.Kind() != reflect.Slice { 731 panic(fmt.Errorf( 732 "bun: %s.%s m2m relation requires slice, got %q", 733 t.TypeName, field.GoName, field.IndirectType.Kind(), 734 )) 735 } 736 joinTable := t.dialect.Tables().Ref(indirectType(field.IndirectType.Elem())) 737 738 if err := t.CheckPKs(); err != nil { 739 panic(err) 740 } 741 if err := joinTable.CheckPKs(); err != nil { 742 panic(err) 743 } 744 745 m2mTableName, ok := field.Tag.Option("m2m") 746 if !ok { 747 panic(fmt.Errorf("bun: %s must have m2m tag option", field.GoName)) 748 } 749 750 m2mTable := t.dialect.Tables().ByName(m2mTableName) 751 if m2mTable == nil { 752 panic(fmt.Errorf( 753 "bun: can't find m2m %s table (use db.RegisterModel)", 754 m2mTableName, 755 )) 756 } 757 758 rel := &Relation{ 759 Type: ManyToManyRelation, 760 Field: field, 761 JoinTable: joinTable, 762 M2MTable: m2mTable, 763 } 764 765 if field.Tag.HasOption("join_on") { 766 rel.Condition = field.Tag.Options["join_on"] 767 } 768 769 var leftColumn, rightColumn string 770 771 if join, ok := field.Tag.Options["join"]; ok { 772 left, right := parseRelationJoin(join) 773 leftColumn = left[0] 774 rightColumn = right[0] 775 } else { 776 leftColumn = t.TypeName 777 rightColumn = joinTable.TypeName 778 } 779 780 leftField := m2mTable.fieldByGoName(leftColumn) 781 if leftField == nil { 782 panic(fmt.Errorf( 783 "bun: %s many-to-many %s: %s must have field %s "+ 784 "(to override, use tag join:LeftField=RightField on field %s.%s", 785 t.TypeName, field.GoName, m2mTable.TypeName, leftColumn, t.TypeName, field.GoName, 786 )) 787 } 788 789 rightField := m2mTable.fieldByGoName(rightColumn) 790 if rightField == nil { 791 panic(fmt.Errorf( 792 "bun: %s many-to-many %s: %s must have field %s "+ 793 "(to override, use tag join:LeftField=RightField on field %s.%s", 794 t.TypeName, field.GoName, m2mTable.TypeName, rightColumn, t.TypeName, field.GoName, 795 )) 796 } 797 798 leftRel := m2mTable.belongsToRelation(leftField) 799 rel.BaseFields = leftRel.JoinFields 800 rel.M2MBaseFields = leftRel.BaseFields 801 802 rightRel := m2mTable.belongsToRelation(rightField) 803 rel.JoinFields = rightRel.JoinFields 804 rel.M2MJoinFields = rightRel.BaseFields 805 806 return rel 807 } 808 809 func (t *Table) inlineFields(field *Field, seen map[reflect.Type]struct{}) { 810 if seen == nil { 811 seen = map[reflect.Type]struct{}{t.Type: {}} 812 } 813 814 if _, ok := seen[field.IndirectType]; ok { 815 return 816 } 817 seen[field.IndirectType] = struct{}{} 818 819 joinTable := t.dialect.Tables().Ref(field.IndirectType) 820 for _, f := range joinTable.allFields { 821 f = f.Clone() 822 f.GoName = field.GoName + "_" + f.GoName 823 f.Name = field.Name + "__" + f.Name 824 f.SQLName = t.quoteIdent(f.Name) 825 f.Index = withIndex(field.Index, f.Index) 826 827 t.fieldsMapMu.Lock() 828 if _, ok := t.FieldMap[f.Name]; !ok { 829 t.FieldMap[f.Name] = f 830 } 831 t.fieldsMapMu.Unlock() 832 833 if f.IndirectType.Kind() != reflect.Struct { 834 continue 835 } 836 837 if _, ok := seen[f.IndirectType]; !ok { 838 t.inlineFields(f, seen) 839 } 840 } 841 } 842 843 //------------------------------------------------------------------------------ 844 845 func (t *Table) Dialect() Dialect { return t.dialect } 846 847 func (t *Table) HasBeforeAppendModelHook() bool { return t.flags.Has(beforeAppendModelHookFlag) } 848 849 // DEPRECATED. Use HasBeforeScanRowHook. 850 func (t *Table) HasBeforeScanHook() bool { return t.flags.Has(beforeScanHookFlag) } 851 852 // DEPRECATED. Use HasAfterScanRowHook. 853 func (t *Table) HasAfterScanHook() bool { return t.flags.Has(afterScanHookFlag) } 854 855 func (t *Table) HasBeforeScanRowHook() bool { return t.flags.Has(beforeScanRowHookFlag) } 856 func (t *Table) HasAfterScanRowHook() bool { return t.flags.Has(afterScanRowHookFlag) } 857 858 //------------------------------------------------------------------------------ 859 860 func (t *Table) AppendNamedArg( 861 fmter Formatter, b []byte, name string, strct reflect.Value, 862 ) ([]byte, bool) { 863 if field, ok := t.FieldMap[name]; ok { 864 return field.AppendValue(fmter, b, strct), true 865 } 866 return b, false 867 } 868 869 func (t *Table) quoteTableName(s string) Safe { 870 // Don't quote if table name contains placeholder (?) or parentheses. 871 if strings.IndexByte(s, '?') >= 0 || 872 strings.IndexByte(s, '(') >= 0 || 873 strings.IndexByte(s, ')') >= 0 { 874 return Safe(s) 875 } 876 return t.quoteIdent(s) 877 } 878 879 func (t *Table) quoteIdent(s string) Safe { 880 return Safe(NewFormatter(t.dialect).AppendIdent(nil, s)) 881 } 882 883 func isKnownTableOption(name string) bool { 884 switch name { 885 case "table", "alias", "select": 886 return true 887 } 888 return false 889 } 890 891 func isKnownFieldOption(name string) bool { 892 switch name { 893 case "column", 894 "alias", 895 "type", 896 "array", 897 "hstore", 898 "composite", 899 "json_use_number", 900 "msgpack", 901 "notnull", 902 "nullzero", 903 "default", 904 "unique", 905 "soft_delete", 906 "scanonly", 907 "skipupdate", 908 909 "pk", 910 "autoincrement", 911 "rel", 912 "join", 913 "join_on", 914 "on_update", 915 "on_delete", 916 "m2m", 917 "polymorphic", 918 "identity": 919 return true 920 } 921 return false 922 } 923 924 func isKnownFKRule(name string) bool { 925 switch name { 926 case "CASCADE", 927 "RESTRICT", 928 "SET NULL", 929 "SET DEFAULT": 930 return true 931 } 932 return false 933 } 934 935 func removeField(fields []*Field, field *Field) []*Field { 936 for i, f := range fields { 937 if f == field { 938 return append(fields[:i], fields[i+1:]...) 939 } 940 } 941 return fields 942 } 943 944 func parseRelationJoin(join []string) ([]string, []string) { 945 var ss []string 946 if len(join) == 1 { 947 ss = strings.Split(join[0], ",") 948 } else { 949 ss = join 950 } 951 952 baseColumns := make([]string, len(ss)) 953 joinColumns := make([]string, len(ss)) 954 for i, s := range ss { 955 ss := strings.Split(strings.TrimSpace(s), "=") 956 if len(ss) != 2 { 957 panic(fmt.Errorf("can't parse relation join: %q", join)) 958 } 959 baseColumns[i] = ss[0] 960 joinColumns[i] = ss[1] 961 } 962 return baseColumns, joinColumns 963 } 964 965 //------------------------------------------------------------------------------ 966 967 func softDeleteFieldUpdater(field *Field) func(fv reflect.Value, tm time.Time) error { 968 typ := field.StructField.Type 969 970 switch typ { 971 case timeType: 972 return func(fv reflect.Value, tm time.Time) error { 973 ptr := fv.Addr().Interface().(*time.Time) 974 *ptr = tm 975 return nil 976 } 977 case nullTimeType: 978 return func(fv reflect.Value, tm time.Time) error { 979 ptr := fv.Addr().Interface().(*sql.NullTime) 980 *ptr = sql.NullTime{Time: tm} 981 return nil 982 } 983 case nullIntType: 984 return func(fv reflect.Value, tm time.Time) error { 985 ptr := fv.Addr().Interface().(*sql.NullInt64) 986 *ptr = sql.NullInt64{Int64: tm.UnixNano()} 987 return nil 988 } 989 } 990 991 switch field.IndirectType.Kind() { 992 case reflect.Int64: 993 return func(fv reflect.Value, tm time.Time) error { 994 ptr := fv.Addr().Interface().(*int64) 995 *ptr = tm.UnixNano() 996 return nil 997 } 998 case reflect.Ptr: 999 typ = typ.Elem() 1000 default: 1001 return softDeleteFieldUpdaterFallback(field) 1002 } 1003 1004 switch typ { //nolint:gocritic 1005 case timeType: 1006 return func(fv reflect.Value, tm time.Time) error { 1007 fv.Set(reflect.ValueOf(&tm)) 1008 return nil 1009 } 1010 } 1011 1012 switch typ.Kind() { //nolint:gocritic 1013 case reflect.Int64: 1014 return func(fv reflect.Value, tm time.Time) error { 1015 utime := tm.UnixNano() 1016 fv.Set(reflect.ValueOf(&utime)) 1017 return nil 1018 } 1019 } 1020 1021 return softDeleteFieldUpdaterFallback(field) 1022 } 1023 1024 func softDeleteFieldUpdaterFallback(field *Field) func(fv reflect.Value, tm time.Time) error { 1025 return func(fv reflect.Value, tm time.Time) error { 1026 return field.ScanWithCheck(fv, tm) 1027 } 1028 } 1029 1030 func withIndex(a, b []int) []int { 1031 dest := make([]int, 0, len(a)+len(b)) 1032 dest = append(dest, a...) 1033 dest = append(dest, b...) 1034 return dest 1035 }