compiler.go (26860B)
1 package encoder 2 3 import ( 4 "context" 5 "encoding" 6 "encoding/json" 7 "reflect" 8 "sync/atomic" 9 "unsafe" 10 11 "github.com/goccy/go-json/internal/errors" 12 "github.com/goccy/go-json/internal/runtime" 13 ) 14 15 type marshalerContext interface { 16 MarshalJSON(context.Context) ([]byte, error) 17 } 18 19 var ( 20 marshalJSONType = reflect.TypeOf((*json.Marshaler)(nil)).Elem() 21 marshalJSONContextType = reflect.TypeOf((*marshalerContext)(nil)).Elem() 22 marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() 23 jsonNumberType = reflect.TypeOf(json.Number("")) 24 cachedOpcodeSets []*OpcodeSet 25 cachedOpcodeMap unsafe.Pointer // map[uintptr]*OpcodeSet 26 typeAddr *runtime.TypeAddr 27 ) 28 29 func init() { 30 typeAddr = runtime.AnalyzeTypeAddr() 31 if typeAddr == nil { 32 typeAddr = &runtime.TypeAddr{} 33 } 34 cachedOpcodeSets = make([]*OpcodeSet, typeAddr.AddrRange>>typeAddr.AddrShift+1) 35 } 36 37 func loadOpcodeMap() map[uintptr]*OpcodeSet { 38 p := atomic.LoadPointer(&cachedOpcodeMap) 39 return *(*map[uintptr]*OpcodeSet)(unsafe.Pointer(&p)) 40 } 41 42 func storeOpcodeSet(typ uintptr, set *OpcodeSet, m map[uintptr]*OpcodeSet) { 43 newOpcodeMap := make(map[uintptr]*OpcodeSet, len(m)+1) 44 newOpcodeMap[typ] = set 45 46 for k, v := range m { 47 newOpcodeMap[k] = v 48 } 49 50 atomic.StorePointer(&cachedOpcodeMap, *(*unsafe.Pointer)(unsafe.Pointer(&newOpcodeMap))) 51 } 52 53 func compileToGetCodeSetSlowPath(typeptr uintptr) (*OpcodeSet, error) { 54 opcodeMap := loadOpcodeMap() 55 if codeSet, exists := opcodeMap[typeptr]; exists { 56 return codeSet, nil 57 } 58 codeSet, err := newCompiler().compile(typeptr) 59 if err != nil { 60 return nil, err 61 } 62 storeOpcodeSet(typeptr, codeSet, opcodeMap) 63 return codeSet, nil 64 } 65 66 func getFilteredCodeSetIfNeeded(ctx *RuntimeContext, codeSet *OpcodeSet) (*OpcodeSet, error) { 67 if (ctx.Option.Flag & ContextOption) == 0 { 68 return codeSet, nil 69 } 70 query := FieldQueryFromContext(ctx.Option.Context) 71 if query == nil { 72 return codeSet, nil 73 } 74 ctx.Option.Flag |= FieldQueryOption 75 cacheCodeSet := codeSet.getQueryCache(query.Hash()) 76 if cacheCodeSet != nil { 77 return cacheCodeSet, nil 78 } 79 queryCodeSet, err := newCompiler().codeToOpcodeSet(codeSet.Type, codeSet.Code.Filter(query)) 80 if err != nil { 81 return nil, err 82 } 83 codeSet.setQueryCache(query.Hash(), queryCodeSet) 84 return queryCodeSet, nil 85 } 86 87 type Compiler struct { 88 structTypeToCode map[uintptr]*StructCode 89 } 90 91 func newCompiler() *Compiler { 92 return &Compiler{ 93 structTypeToCode: map[uintptr]*StructCode{}, 94 } 95 } 96 97 func (c *Compiler) compile(typeptr uintptr) (*OpcodeSet, error) { 98 // noescape trick for header.typ ( reflect.*rtype ) 99 typ := *(**runtime.Type)(unsafe.Pointer(&typeptr)) 100 code, err := c.typeToCode(typ) 101 if err != nil { 102 return nil, err 103 } 104 return c.codeToOpcodeSet(typ, code) 105 } 106 107 func (c *Compiler) codeToOpcodeSet(typ *runtime.Type, code Code) (*OpcodeSet, error) { 108 noescapeKeyCode := c.codeToOpcode(&compileContext{ 109 structTypeToCodes: map[uintptr]Opcodes{}, 110 recursiveCodes: &Opcodes{}, 111 }, typ, code) 112 if err := noescapeKeyCode.Validate(); err != nil { 113 return nil, err 114 } 115 escapeKeyCode := c.codeToOpcode(&compileContext{ 116 structTypeToCodes: map[uintptr]Opcodes{}, 117 recursiveCodes: &Opcodes{}, 118 escapeKey: true, 119 }, typ, code) 120 noescapeKeyCode = copyOpcode(noescapeKeyCode) 121 escapeKeyCode = copyOpcode(escapeKeyCode) 122 setTotalLengthToInterfaceOp(noescapeKeyCode) 123 setTotalLengthToInterfaceOp(escapeKeyCode) 124 interfaceNoescapeKeyCode := copyToInterfaceOpcode(noescapeKeyCode) 125 interfaceEscapeKeyCode := copyToInterfaceOpcode(escapeKeyCode) 126 codeLength := noescapeKeyCode.TotalLength() 127 return &OpcodeSet{ 128 Type: typ, 129 NoescapeKeyCode: noescapeKeyCode, 130 EscapeKeyCode: escapeKeyCode, 131 InterfaceNoescapeKeyCode: interfaceNoescapeKeyCode, 132 InterfaceEscapeKeyCode: interfaceEscapeKeyCode, 133 CodeLength: codeLength, 134 EndCode: ToEndCode(interfaceNoescapeKeyCode), 135 Code: code, 136 QueryCache: map[string]*OpcodeSet{}, 137 }, nil 138 } 139 140 func (c *Compiler) typeToCode(typ *runtime.Type) (Code, error) { 141 switch { 142 case c.implementsMarshalJSON(typ): 143 return c.marshalJSONCode(typ) 144 case c.implementsMarshalText(typ): 145 return c.marshalTextCode(typ) 146 } 147 148 isPtr := false 149 orgType := typ 150 if typ.Kind() == reflect.Ptr { 151 typ = typ.Elem() 152 isPtr = true 153 } 154 switch { 155 case c.implementsMarshalJSON(typ): 156 return c.marshalJSONCode(orgType) 157 case c.implementsMarshalText(typ): 158 return c.marshalTextCode(orgType) 159 } 160 switch typ.Kind() { 161 case reflect.Slice: 162 elem := typ.Elem() 163 if elem.Kind() == reflect.Uint8 { 164 p := runtime.PtrTo(elem) 165 if !c.implementsMarshalJSONType(p) && !p.Implements(marshalTextType) { 166 return c.bytesCode(typ, isPtr) 167 } 168 } 169 return c.sliceCode(typ) 170 case reflect.Map: 171 if isPtr { 172 return c.ptrCode(runtime.PtrTo(typ)) 173 } 174 return c.mapCode(typ) 175 case reflect.Struct: 176 return c.structCode(typ, isPtr) 177 case reflect.Int: 178 return c.intCode(typ, isPtr) 179 case reflect.Int8: 180 return c.int8Code(typ, isPtr) 181 case reflect.Int16: 182 return c.int16Code(typ, isPtr) 183 case reflect.Int32: 184 return c.int32Code(typ, isPtr) 185 case reflect.Int64: 186 return c.int64Code(typ, isPtr) 187 case reflect.Uint, reflect.Uintptr: 188 return c.uintCode(typ, isPtr) 189 case reflect.Uint8: 190 return c.uint8Code(typ, isPtr) 191 case reflect.Uint16: 192 return c.uint16Code(typ, isPtr) 193 case reflect.Uint32: 194 return c.uint32Code(typ, isPtr) 195 case reflect.Uint64: 196 return c.uint64Code(typ, isPtr) 197 case reflect.Float32: 198 return c.float32Code(typ, isPtr) 199 case reflect.Float64: 200 return c.float64Code(typ, isPtr) 201 case reflect.String: 202 return c.stringCode(typ, isPtr) 203 case reflect.Bool: 204 return c.boolCode(typ, isPtr) 205 case reflect.Interface: 206 return c.interfaceCode(typ, isPtr) 207 default: 208 if isPtr && typ.Implements(marshalTextType) { 209 typ = orgType 210 } 211 return c.typeToCodeWithPtr(typ, isPtr) 212 } 213 } 214 215 func (c *Compiler) typeToCodeWithPtr(typ *runtime.Type, isPtr bool) (Code, error) { 216 switch { 217 case c.implementsMarshalJSON(typ): 218 return c.marshalJSONCode(typ) 219 case c.implementsMarshalText(typ): 220 return c.marshalTextCode(typ) 221 } 222 switch typ.Kind() { 223 case reflect.Ptr: 224 return c.ptrCode(typ) 225 case reflect.Slice: 226 elem := typ.Elem() 227 if elem.Kind() == reflect.Uint8 { 228 p := runtime.PtrTo(elem) 229 if !c.implementsMarshalJSONType(p) && !p.Implements(marshalTextType) { 230 return c.bytesCode(typ, false) 231 } 232 } 233 return c.sliceCode(typ) 234 case reflect.Array: 235 return c.arrayCode(typ) 236 case reflect.Map: 237 return c.mapCode(typ) 238 case reflect.Struct: 239 return c.structCode(typ, isPtr) 240 case reflect.Interface: 241 return c.interfaceCode(typ, false) 242 case reflect.Int: 243 return c.intCode(typ, false) 244 case reflect.Int8: 245 return c.int8Code(typ, false) 246 case reflect.Int16: 247 return c.int16Code(typ, false) 248 case reflect.Int32: 249 return c.int32Code(typ, false) 250 case reflect.Int64: 251 return c.int64Code(typ, false) 252 case reflect.Uint: 253 return c.uintCode(typ, false) 254 case reflect.Uint8: 255 return c.uint8Code(typ, false) 256 case reflect.Uint16: 257 return c.uint16Code(typ, false) 258 case reflect.Uint32: 259 return c.uint32Code(typ, false) 260 case reflect.Uint64: 261 return c.uint64Code(typ, false) 262 case reflect.Uintptr: 263 return c.uintCode(typ, false) 264 case reflect.Float32: 265 return c.float32Code(typ, false) 266 case reflect.Float64: 267 return c.float64Code(typ, false) 268 case reflect.String: 269 return c.stringCode(typ, false) 270 case reflect.Bool: 271 return c.boolCode(typ, false) 272 } 273 return nil, &errors.UnsupportedTypeError{Type: runtime.RType2Type(typ)} 274 } 275 276 const intSize = 32 << (^uint(0) >> 63) 277 278 //nolint:unparam 279 func (c *Compiler) intCode(typ *runtime.Type, isPtr bool) (*IntCode, error) { 280 return &IntCode{typ: typ, bitSize: intSize, isPtr: isPtr}, nil 281 } 282 283 //nolint:unparam 284 func (c *Compiler) int8Code(typ *runtime.Type, isPtr bool) (*IntCode, error) { 285 return &IntCode{typ: typ, bitSize: 8, isPtr: isPtr}, nil 286 } 287 288 //nolint:unparam 289 func (c *Compiler) int16Code(typ *runtime.Type, isPtr bool) (*IntCode, error) { 290 return &IntCode{typ: typ, bitSize: 16, isPtr: isPtr}, nil 291 } 292 293 //nolint:unparam 294 func (c *Compiler) int32Code(typ *runtime.Type, isPtr bool) (*IntCode, error) { 295 return &IntCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil 296 } 297 298 //nolint:unparam 299 func (c *Compiler) int64Code(typ *runtime.Type, isPtr bool) (*IntCode, error) { 300 return &IntCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil 301 } 302 303 //nolint:unparam 304 func (c *Compiler) uintCode(typ *runtime.Type, isPtr bool) (*UintCode, error) { 305 return &UintCode{typ: typ, bitSize: intSize, isPtr: isPtr}, nil 306 } 307 308 //nolint:unparam 309 func (c *Compiler) uint8Code(typ *runtime.Type, isPtr bool) (*UintCode, error) { 310 return &UintCode{typ: typ, bitSize: 8, isPtr: isPtr}, nil 311 } 312 313 //nolint:unparam 314 func (c *Compiler) uint16Code(typ *runtime.Type, isPtr bool) (*UintCode, error) { 315 return &UintCode{typ: typ, bitSize: 16, isPtr: isPtr}, nil 316 } 317 318 //nolint:unparam 319 func (c *Compiler) uint32Code(typ *runtime.Type, isPtr bool) (*UintCode, error) { 320 return &UintCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil 321 } 322 323 //nolint:unparam 324 func (c *Compiler) uint64Code(typ *runtime.Type, isPtr bool) (*UintCode, error) { 325 return &UintCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil 326 } 327 328 //nolint:unparam 329 func (c *Compiler) float32Code(typ *runtime.Type, isPtr bool) (*FloatCode, error) { 330 return &FloatCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil 331 } 332 333 //nolint:unparam 334 func (c *Compiler) float64Code(typ *runtime.Type, isPtr bool) (*FloatCode, error) { 335 return &FloatCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil 336 } 337 338 //nolint:unparam 339 func (c *Compiler) stringCode(typ *runtime.Type, isPtr bool) (*StringCode, error) { 340 return &StringCode{typ: typ, isPtr: isPtr}, nil 341 } 342 343 //nolint:unparam 344 func (c *Compiler) boolCode(typ *runtime.Type, isPtr bool) (*BoolCode, error) { 345 return &BoolCode{typ: typ, isPtr: isPtr}, nil 346 } 347 348 //nolint:unparam 349 func (c *Compiler) intStringCode(typ *runtime.Type) (*IntCode, error) { 350 return &IntCode{typ: typ, bitSize: intSize, isString: true}, nil 351 } 352 353 //nolint:unparam 354 func (c *Compiler) int8StringCode(typ *runtime.Type) (*IntCode, error) { 355 return &IntCode{typ: typ, bitSize: 8, isString: true}, nil 356 } 357 358 //nolint:unparam 359 func (c *Compiler) int16StringCode(typ *runtime.Type) (*IntCode, error) { 360 return &IntCode{typ: typ, bitSize: 16, isString: true}, nil 361 } 362 363 //nolint:unparam 364 func (c *Compiler) int32StringCode(typ *runtime.Type) (*IntCode, error) { 365 return &IntCode{typ: typ, bitSize: 32, isString: true}, nil 366 } 367 368 //nolint:unparam 369 func (c *Compiler) int64StringCode(typ *runtime.Type) (*IntCode, error) { 370 return &IntCode{typ: typ, bitSize: 64, isString: true}, nil 371 } 372 373 //nolint:unparam 374 func (c *Compiler) uintStringCode(typ *runtime.Type) (*UintCode, error) { 375 return &UintCode{typ: typ, bitSize: intSize, isString: true}, nil 376 } 377 378 //nolint:unparam 379 func (c *Compiler) uint8StringCode(typ *runtime.Type) (*UintCode, error) { 380 return &UintCode{typ: typ, bitSize: 8, isString: true}, nil 381 } 382 383 //nolint:unparam 384 func (c *Compiler) uint16StringCode(typ *runtime.Type) (*UintCode, error) { 385 return &UintCode{typ: typ, bitSize: 16, isString: true}, nil 386 } 387 388 //nolint:unparam 389 func (c *Compiler) uint32StringCode(typ *runtime.Type) (*UintCode, error) { 390 return &UintCode{typ: typ, bitSize: 32, isString: true}, nil 391 } 392 393 //nolint:unparam 394 func (c *Compiler) uint64StringCode(typ *runtime.Type) (*UintCode, error) { 395 return &UintCode{typ: typ, bitSize: 64, isString: true}, nil 396 } 397 398 //nolint:unparam 399 func (c *Compiler) bytesCode(typ *runtime.Type, isPtr bool) (*BytesCode, error) { 400 return &BytesCode{typ: typ, isPtr: isPtr}, nil 401 } 402 403 //nolint:unparam 404 func (c *Compiler) interfaceCode(typ *runtime.Type, isPtr bool) (*InterfaceCode, error) { 405 return &InterfaceCode{typ: typ, isPtr: isPtr}, nil 406 } 407 408 //nolint:unparam 409 func (c *Compiler) marshalJSONCode(typ *runtime.Type) (*MarshalJSONCode, error) { 410 return &MarshalJSONCode{ 411 typ: typ, 412 isAddrForMarshaler: c.isPtrMarshalJSONType(typ), 413 isNilableType: c.isNilableType(typ), 414 isMarshalerContext: typ.Implements(marshalJSONContextType) || runtime.PtrTo(typ).Implements(marshalJSONContextType), 415 }, nil 416 } 417 418 //nolint:unparam 419 func (c *Compiler) marshalTextCode(typ *runtime.Type) (*MarshalTextCode, error) { 420 return &MarshalTextCode{ 421 typ: typ, 422 isAddrForMarshaler: c.isPtrMarshalTextType(typ), 423 isNilableType: c.isNilableType(typ), 424 }, nil 425 } 426 427 func (c *Compiler) ptrCode(typ *runtime.Type) (*PtrCode, error) { 428 code, err := c.typeToCodeWithPtr(typ.Elem(), true) 429 if err != nil { 430 return nil, err 431 } 432 ptr, ok := code.(*PtrCode) 433 if ok { 434 return &PtrCode{typ: typ, value: ptr.value, ptrNum: ptr.ptrNum + 1}, nil 435 } 436 return &PtrCode{typ: typ, value: code, ptrNum: 1}, nil 437 } 438 439 func (c *Compiler) sliceCode(typ *runtime.Type) (*SliceCode, error) { 440 elem := typ.Elem() 441 code, err := c.listElemCode(elem) 442 if err != nil { 443 return nil, err 444 } 445 if code.Kind() == CodeKindStruct { 446 structCode := code.(*StructCode) 447 structCode.enableIndirect() 448 } 449 return &SliceCode{typ: typ, value: code}, nil 450 } 451 452 func (c *Compiler) arrayCode(typ *runtime.Type) (*ArrayCode, error) { 453 elem := typ.Elem() 454 code, err := c.listElemCode(elem) 455 if err != nil { 456 return nil, err 457 } 458 if code.Kind() == CodeKindStruct { 459 structCode := code.(*StructCode) 460 structCode.enableIndirect() 461 } 462 return &ArrayCode{typ: typ, value: code}, nil 463 } 464 465 func (c *Compiler) mapCode(typ *runtime.Type) (*MapCode, error) { 466 keyCode, err := c.mapKeyCode(typ.Key()) 467 if err != nil { 468 return nil, err 469 } 470 valueCode, err := c.mapValueCode(typ.Elem()) 471 if err != nil { 472 return nil, err 473 } 474 if valueCode.Kind() == CodeKindStruct { 475 structCode := valueCode.(*StructCode) 476 structCode.enableIndirect() 477 } 478 return &MapCode{typ: typ, key: keyCode, value: valueCode}, nil 479 } 480 481 func (c *Compiler) listElemCode(typ *runtime.Type) (Code, error) { 482 switch { 483 case c.isPtrMarshalJSONType(typ): 484 return c.marshalJSONCode(typ) 485 case !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType): 486 return c.marshalTextCode(typ) 487 case typ.Kind() == reflect.Map: 488 return c.ptrCode(runtime.PtrTo(typ)) 489 default: 490 // isPtr was originally used to indicate whether the type of top level is pointer. 491 // However, since the slice/array element is a specification that can get the pointer address, explicitly set isPtr to true. 492 // See here for related issues: https://github.com/goccy/go-json/issues/370 493 code, err := c.typeToCodeWithPtr(typ, true) 494 if err != nil { 495 return nil, err 496 } 497 ptr, ok := code.(*PtrCode) 498 if ok { 499 if ptr.value.Kind() == CodeKindMap { 500 ptr.ptrNum++ 501 } 502 } 503 return code, nil 504 } 505 } 506 507 func (c *Compiler) mapKeyCode(typ *runtime.Type) (Code, error) { 508 switch { 509 case c.implementsMarshalText(typ): 510 return c.marshalTextCode(typ) 511 } 512 switch typ.Kind() { 513 case reflect.Ptr: 514 return c.ptrCode(typ) 515 case reflect.String: 516 return c.stringCode(typ, false) 517 case reflect.Int: 518 return c.intStringCode(typ) 519 case reflect.Int8: 520 return c.int8StringCode(typ) 521 case reflect.Int16: 522 return c.int16StringCode(typ) 523 case reflect.Int32: 524 return c.int32StringCode(typ) 525 case reflect.Int64: 526 return c.int64StringCode(typ) 527 case reflect.Uint: 528 return c.uintStringCode(typ) 529 case reflect.Uint8: 530 return c.uint8StringCode(typ) 531 case reflect.Uint16: 532 return c.uint16StringCode(typ) 533 case reflect.Uint32: 534 return c.uint32StringCode(typ) 535 case reflect.Uint64: 536 return c.uint64StringCode(typ) 537 case reflect.Uintptr: 538 return c.uintStringCode(typ) 539 } 540 return nil, &errors.UnsupportedTypeError{Type: runtime.RType2Type(typ)} 541 } 542 543 func (c *Compiler) mapValueCode(typ *runtime.Type) (Code, error) { 544 switch typ.Kind() { 545 case reflect.Map: 546 return c.ptrCode(runtime.PtrTo(typ)) 547 default: 548 code, err := c.typeToCodeWithPtr(typ, false) 549 if err != nil { 550 return nil, err 551 } 552 ptr, ok := code.(*PtrCode) 553 if ok { 554 if ptr.value.Kind() == CodeKindMap { 555 ptr.ptrNum++ 556 } 557 } 558 return code, nil 559 } 560 } 561 562 func (c *Compiler) structCode(typ *runtime.Type, isPtr bool) (*StructCode, error) { 563 typeptr := uintptr(unsafe.Pointer(typ)) 564 if code, exists := c.structTypeToCode[typeptr]; exists { 565 derefCode := *code 566 derefCode.isRecursive = true 567 return &derefCode, nil 568 } 569 indirect := runtime.IfaceIndir(typ) 570 code := &StructCode{typ: typ, isPtr: isPtr, isIndirect: indirect} 571 c.structTypeToCode[typeptr] = code 572 573 fieldNum := typ.NumField() 574 tags := c.typeToStructTags(typ) 575 fields := []*StructFieldCode{} 576 for i, tag := range tags { 577 isOnlyOneFirstField := i == 0 && fieldNum == 1 578 field, err := c.structFieldCode(code, tag, isPtr, isOnlyOneFirstField) 579 if err != nil { 580 return nil, err 581 } 582 if field.isAnonymous { 583 structCode := field.getAnonymousStruct() 584 if structCode != nil { 585 structCode.removeFieldsByTags(tags) 586 if c.isAssignableIndirect(field, isPtr) { 587 if indirect { 588 structCode.isIndirect = true 589 } else { 590 structCode.isIndirect = false 591 } 592 } 593 } 594 } else { 595 structCode := field.getStruct() 596 if structCode != nil { 597 if indirect { 598 // if parent is indirect type, set child indirect property to true 599 structCode.isIndirect = true 600 } else { 601 // if parent is not indirect type, set child indirect property to false. 602 // but if parent's indirect is false and isPtr is true, then indirect must be true. 603 // Do this only if indirectConversion is enabled at the end of compileStruct. 604 structCode.isIndirect = false 605 } 606 } 607 } 608 fields = append(fields, field) 609 } 610 fieldMap := c.getFieldMap(fields) 611 duplicatedFieldMap := c.getDuplicatedFieldMap(fieldMap) 612 code.fields = c.filteredDuplicatedFields(fields, duplicatedFieldMap) 613 if !code.disableIndirectConversion && !indirect && isPtr { 614 code.enableIndirect() 615 } 616 delete(c.structTypeToCode, typeptr) 617 return code, nil 618 } 619 620 func toElemType(t *runtime.Type) *runtime.Type { 621 for t.Kind() == reflect.Ptr { 622 t = t.Elem() 623 } 624 return t 625 } 626 627 func (c *Compiler) structFieldCode(structCode *StructCode, tag *runtime.StructTag, isPtr, isOnlyOneFirstField bool) (*StructFieldCode, error) { 628 field := tag.Field 629 fieldType := runtime.Type2RType(field.Type) 630 isIndirectSpecialCase := isPtr && isOnlyOneFirstField 631 fieldCode := &StructFieldCode{ 632 typ: fieldType, 633 key: tag.Key, 634 tag: tag, 635 offset: field.Offset, 636 isAnonymous: field.Anonymous && !tag.IsTaggedKey && toElemType(fieldType).Kind() == reflect.Struct, 637 isTaggedKey: tag.IsTaggedKey, 638 isNilableType: c.isNilableType(fieldType), 639 isNilCheck: true, 640 } 641 switch { 642 case c.isMovePointerPositionFromHeadToFirstMarshalJSONFieldCase(fieldType, isIndirectSpecialCase): 643 code, err := c.marshalJSONCode(fieldType) 644 if err != nil { 645 return nil, err 646 } 647 fieldCode.value = code 648 fieldCode.isAddrForMarshaler = true 649 fieldCode.isNilCheck = false 650 structCode.isIndirect = false 651 structCode.disableIndirectConversion = true 652 case c.isMovePointerPositionFromHeadToFirstMarshalTextFieldCase(fieldType, isIndirectSpecialCase): 653 code, err := c.marshalTextCode(fieldType) 654 if err != nil { 655 return nil, err 656 } 657 fieldCode.value = code 658 fieldCode.isAddrForMarshaler = true 659 fieldCode.isNilCheck = false 660 structCode.isIndirect = false 661 structCode.disableIndirectConversion = true 662 case isPtr && c.isPtrMarshalJSONType(fieldType): 663 // *struct{ field T } 664 // func (*T) MarshalJSON() ([]byte, error) 665 code, err := c.marshalJSONCode(fieldType) 666 if err != nil { 667 return nil, err 668 } 669 fieldCode.value = code 670 fieldCode.isAddrForMarshaler = true 671 fieldCode.isNilCheck = false 672 case isPtr && c.isPtrMarshalTextType(fieldType): 673 // *struct{ field T } 674 // func (*T) MarshalText() ([]byte, error) 675 code, err := c.marshalTextCode(fieldType) 676 if err != nil { 677 return nil, err 678 } 679 fieldCode.value = code 680 fieldCode.isAddrForMarshaler = true 681 fieldCode.isNilCheck = false 682 default: 683 code, err := c.typeToCodeWithPtr(fieldType, isPtr) 684 if err != nil { 685 return nil, err 686 } 687 switch code.Kind() { 688 case CodeKindPtr, CodeKindInterface: 689 fieldCode.isNextOpPtrType = true 690 } 691 fieldCode.value = code 692 } 693 return fieldCode, nil 694 } 695 696 func (c *Compiler) isAssignableIndirect(fieldCode *StructFieldCode, isPtr bool) bool { 697 if isPtr { 698 return false 699 } 700 codeType := fieldCode.value.Kind() 701 if codeType == CodeKindMarshalJSON { 702 return false 703 } 704 if codeType == CodeKindMarshalText { 705 return false 706 } 707 return true 708 } 709 710 func (c *Compiler) getFieldMap(fields []*StructFieldCode) map[string][]*StructFieldCode { 711 fieldMap := map[string][]*StructFieldCode{} 712 for _, field := range fields { 713 if field.isAnonymous { 714 for k, v := range c.getAnonymousFieldMap(field) { 715 fieldMap[k] = append(fieldMap[k], v...) 716 } 717 continue 718 } 719 fieldMap[field.key] = append(fieldMap[field.key], field) 720 } 721 return fieldMap 722 } 723 724 func (c *Compiler) getAnonymousFieldMap(field *StructFieldCode) map[string][]*StructFieldCode { 725 fieldMap := map[string][]*StructFieldCode{} 726 structCode := field.getAnonymousStruct() 727 if structCode == nil || structCode.isRecursive { 728 fieldMap[field.key] = append(fieldMap[field.key], field) 729 return fieldMap 730 } 731 for k, v := range c.getFieldMapFromAnonymousParent(structCode.fields) { 732 fieldMap[k] = append(fieldMap[k], v...) 733 } 734 return fieldMap 735 } 736 737 func (c *Compiler) getFieldMapFromAnonymousParent(fields []*StructFieldCode) map[string][]*StructFieldCode { 738 fieldMap := map[string][]*StructFieldCode{} 739 for _, field := range fields { 740 if field.isAnonymous { 741 for k, v := range c.getAnonymousFieldMap(field) { 742 // Do not handle tagged key when embedding more than once 743 for _, vv := range v { 744 vv.isTaggedKey = false 745 } 746 fieldMap[k] = append(fieldMap[k], v...) 747 } 748 continue 749 } 750 fieldMap[field.key] = append(fieldMap[field.key], field) 751 } 752 return fieldMap 753 } 754 755 func (c *Compiler) getDuplicatedFieldMap(fieldMap map[string][]*StructFieldCode) map[*StructFieldCode]struct{} { 756 duplicatedFieldMap := map[*StructFieldCode]struct{}{} 757 for _, fields := range fieldMap { 758 if len(fields) == 1 { 759 continue 760 } 761 if c.isTaggedKeyOnly(fields) { 762 for _, field := range fields { 763 if field.isTaggedKey { 764 continue 765 } 766 duplicatedFieldMap[field] = struct{}{} 767 } 768 } else { 769 for _, field := range fields { 770 duplicatedFieldMap[field] = struct{}{} 771 } 772 } 773 } 774 return duplicatedFieldMap 775 } 776 777 func (c *Compiler) filteredDuplicatedFields(fields []*StructFieldCode, duplicatedFieldMap map[*StructFieldCode]struct{}) []*StructFieldCode { 778 filteredFields := make([]*StructFieldCode, 0, len(fields)) 779 for _, field := range fields { 780 if field.isAnonymous { 781 structCode := field.getAnonymousStruct() 782 if structCode != nil && !structCode.isRecursive { 783 structCode.fields = c.filteredDuplicatedFields(structCode.fields, duplicatedFieldMap) 784 if len(structCode.fields) > 0 { 785 filteredFields = append(filteredFields, field) 786 } 787 continue 788 } 789 } 790 if _, exists := duplicatedFieldMap[field]; exists { 791 continue 792 } 793 filteredFields = append(filteredFields, field) 794 } 795 return filteredFields 796 } 797 798 func (c *Compiler) isTaggedKeyOnly(fields []*StructFieldCode) bool { 799 var taggedKeyFieldCount int 800 for _, field := range fields { 801 if field.isTaggedKey { 802 taggedKeyFieldCount++ 803 } 804 } 805 return taggedKeyFieldCount == 1 806 } 807 808 func (c *Compiler) typeToStructTags(typ *runtime.Type) runtime.StructTags { 809 tags := runtime.StructTags{} 810 fieldNum := typ.NumField() 811 for i := 0; i < fieldNum; i++ { 812 field := typ.Field(i) 813 if runtime.IsIgnoredStructField(field) { 814 continue 815 } 816 tags = append(tags, runtime.StructTagFromField(field)) 817 } 818 return tags 819 } 820 821 // *struct{ field T } => struct { field *T } 822 // func (*T) MarshalJSON() ([]byte, error) 823 func (c *Compiler) isMovePointerPositionFromHeadToFirstMarshalJSONFieldCase(typ *runtime.Type, isIndirectSpecialCase bool) bool { 824 return isIndirectSpecialCase && !c.isNilableType(typ) && c.isPtrMarshalJSONType(typ) 825 } 826 827 // *struct{ field T } => struct { field *T } 828 // func (*T) MarshalText() ([]byte, error) 829 func (c *Compiler) isMovePointerPositionFromHeadToFirstMarshalTextFieldCase(typ *runtime.Type, isIndirectSpecialCase bool) bool { 830 return isIndirectSpecialCase && !c.isNilableType(typ) && c.isPtrMarshalTextType(typ) 831 } 832 833 func (c *Compiler) implementsMarshalJSON(typ *runtime.Type) bool { 834 if !c.implementsMarshalJSONType(typ) { 835 return false 836 } 837 if typ.Kind() != reflect.Ptr { 838 return true 839 } 840 // type kind is reflect.Ptr 841 if !c.implementsMarshalJSONType(typ.Elem()) { 842 return true 843 } 844 // needs to dereference 845 return false 846 } 847 848 func (c *Compiler) implementsMarshalText(typ *runtime.Type) bool { 849 if !typ.Implements(marshalTextType) { 850 return false 851 } 852 if typ.Kind() != reflect.Ptr { 853 return true 854 } 855 // type kind is reflect.Ptr 856 if !typ.Elem().Implements(marshalTextType) { 857 return true 858 } 859 // needs to dereference 860 return false 861 } 862 863 func (c *Compiler) isNilableType(typ *runtime.Type) bool { 864 if !runtime.IfaceIndir(typ) { 865 return true 866 } 867 switch typ.Kind() { 868 case reflect.Ptr: 869 return true 870 case reflect.Map: 871 return true 872 case reflect.Func: 873 return true 874 default: 875 return false 876 } 877 } 878 879 func (c *Compiler) implementsMarshalJSONType(typ *runtime.Type) bool { 880 return typ.Implements(marshalJSONType) || typ.Implements(marshalJSONContextType) 881 } 882 883 func (c *Compiler) isPtrMarshalJSONType(typ *runtime.Type) bool { 884 return !c.implementsMarshalJSONType(typ) && c.implementsMarshalJSONType(runtime.PtrTo(typ)) 885 } 886 887 func (c *Compiler) isPtrMarshalTextType(typ *runtime.Type) bool { 888 return !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType) 889 } 890 891 func (c *Compiler) codeToOpcode(ctx *compileContext, typ *runtime.Type, code Code) *Opcode { 892 codes := code.ToOpcode(ctx) 893 codes.Last().Next = newEndOp(ctx, typ) 894 c.linkRecursiveCode(ctx) 895 return codes.First() 896 } 897 898 func (c *Compiler) linkRecursiveCode(ctx *compileContext) { 899 recursiveCodes := map[uintptr]*CompiledCode{} 900 for _, recursive := range *ctx.recursiveCodes { 901 typeptr := uintptr(unsafe.Pointer(recursive.Type)) 902 codes := ctx.structTypeToCodes[typeptr] 903 if recursiveCode, ok := recursiveCodes[typeptr]; ok { 904 *recursive.Jmp = *recursiveCode 905 continue 906 } 907 908 code := copyOpcode(codes.First()) 909 code.Op = code.Op.PtrHeadToHead() 910 lastCode := newEndOp(&compileContext{}, recursive.Type) 911 lastCode.Op = OpRecursiveEnd 912 913 // OpRecursiveEnd must set before call TotalLength 914 code.End.Next = lastCode 915 916 totalLength := code.TotalLength() 917 918 // Idx, ElemIdx, Length must set after call TotalLength 919 lastCode.Idx = uint32((totalLength + 1) * uintptrSize) 920 lastCode.ElemIdx = lastCode.Idx + uintptrSize 921 lastCode.Length = lastCode.Idx + 2*uintptrSize 922 923 // extend length to alloc slot for elemIdx + length 924 curTotalLength := uintptr(recursive.TotalLength()) + 3 925 nextTotalLength := uintptr(totalLength) + 3 926 927 compiled := recursive.Jmp 928 compiled.Code = code 929 compiled.CurLen = curTotalLength 930 compiled.NextLen = nextTotalLength 931 compiled.Linked = true 932 933 recursiveCodes[typeptr] = compiled 934 } 935 }