decode.go (14197B)
1 package msgpack 2 3 import ( 4 "bufio" 5 "bytes" 6 "errors" 7 "fmt" 8 "io" 9 "reflect" 10 "sync" 11 "time" 12 13 "github.com/vmihailenco/msgpack/v5/msgpcode" 14 ) 15 16 const ( 17 looseInterfaceDecodingFlag uint32 = 1 << iota 18 disallowUnknownFieldsFlag 19 ) 20 21 const ( 22 bytesAllocLimit = 1e6 // 1mb 23 sliceAllocLimit = 1e4 24 maxMapSize = 1e6 25 ) 26 27 type bufReader interface { 28 io.Reader 29 io.ByteScanner 30 } 31 32 //------------------------------------------------------------------------------ 33 34 var decPool = sync.Pool{ 35 New: func() interface{} { 36 return NewDecoder(nil) 37 }, 38 } 39 40 func GetDecoder() *Decoder { 41 return decPool.Get().(*Decoder) 42 } 43 44 func PutDecoder(dec *Decoder) { 45 dec.r = nil 46 dec.s = nil 47 decPool.Put(dec) 48 } 49 50 //------------------------------------------------------------------------------ 51 52 // Unmarshal decodes the MessagePack-encoded data and stores the result 53 // in the value pointed to by v. 54 func Unmarshal(data []byte, v interface{}) error { 55 dec := GetDecoder() 56 57 dec.Reset(bytes.NewReader(data)) 58 err := dec.Decode(v) 59 60 PutDecoder(dec) 61 62 return err 63 } 64 65 // A Decoder reads and decodes MessagePack values from an input stream. 66 type Decoder struct { 67 r io.Reader 68 s io.ByteScanner 69 buf []byte 70 71 rec []byte // accumulates read data if not nil 72 73 dict []string 74 flags uint32 75 structTag string 76 mapDecoder func(*Decoder) (interface{}, error) 77 } 78 79 // NewDecoder returns a new decoder that reads from r. 80 // 81 // The decoder introduces its own buffering and may read data from r 82 // beyond the requested msgpack values. Buffering can be disabled 83 // by passing a reader that implements io.ByteScanner interface. 84 func NewDecoder(r io.Reader) *Decoder { 85 d := new(Decoder) 86 d.Reset(r) 87 return d 88 } 89 90 // Reset discards any buffered data, resets all state, and switches the buffered 91 // reader to read from r. 92 func (d *Decoder) Reset(r io.Reader) { 93 d.ResetDict(r, nil) 94 } 95 96 // ResetDict is like Reset, but also resets the dict. 97 func (d *Decoder) ResetDict(r io.Reader, dict []string) { 98 d.resetReader(r) 99 d.flags = 0 100 d.structTag = "" 101 d.mapDecoder = nil 102 d.dict = dict 103 } 104 105 func (d *Decoder) WithDict(dict []string, fn func(*Decoder) error) error { 106 oldDict := d.dict 107 d.dict = dict 108 err := fn(d) 109 d.dict = oldDict 110 return err 111 } 112 113 func (d *Decoder) resetReader(r io.Reader) { 114 if br, ok := r.(bufReader); ok { 115 d.r = br 116 d.s = br 117 } else { 118 br := bufio.NewReader(r) 119 d.r = br 120 d.s = br 121 } 122 } 123 124 func (d *Decoder) SetMapDecoder(fn func(*Decoder) (interface{}, error)) { 125 d.mapDecoder = fn 126 } 127 128 // UseLooseInterfaceDecoding causes decoder to use DecodeInterfaceLoose 129 // to decode msgpack value into Go interface{}. 130 func (d *Decoder) UseLooseInterfaceDecoding(on bool) { 131 if on { 132 d.flags |= looseInterfaceDecodingFlag 133 } else { 134 d.flags &= ^looseInterfaceDecodingFlag 135 } 136 } 137 138 // SetCustomStructTag causes the decoder to use the supplied tag as a fallback option 139 // if there is no msgpack tag. 140 func (d *Decoder) SetCustomStructTag(tag string) { 141 d.structTag = tag 142 } 143 144 // DisallowUnknownFields causes the Decoder to return an error when the destination 145 // is a struct and the input contains object keys which do not match any 146 // non-ignored, exported fields in the destination. 147 func (d *Decoder) DisallowUnknownFields(on bool) { 148 if on { 149 d.flags |= disallowUnknownFieldsFlag 150 } else { 151 d.flags &= ^disallowUnknownFieldsFlag 152 } 153 } 154 155 // UseInternedStrings enables support for decoding interned strings. 156 func (d *Decoder) UseInternedStrings(on bool) { 157 if on { 158 d.flags |= useInternedStringsFlag 159 } else { 160 d.flags &= ^useInternedStringsFlag 161 } 162 } 163 164 // Buffered returns a reader of the data remaining in the Decoder's buffer. 165 // The reader is valid until the next call to Decode. 166 func (d *Decoder) Buffered() io.Reader { 167 return d.r 168 } 169 170 //nolint:gocyclo 171 func (d *Decoder) Decode(v interface{}) error { 172 var err error 173 switch v := v.(type) { 174 case *string: 175 if v != nil { 176 *v, err = d.DecodeString() 177 return err 178 } 179 case *[]byte: 180 if v != nil { 181 return d.decodeBytesPtr(v) 182 } 183 case *int: 184 if v != nil { 185 *v, err = d.DecodeInt() 186 return err 187 } 188 case *int8: 189 if v != nil { 190 *v, err = d.DecodeInt8() 191 return err 192 } 193 case *int16: 194 if v != nil { 195 *v, err = d.DecodeInt16() 196 return err 197 } 198 case *int32: 199 if v != nil { 200 *v, err = d.DecodeInt32() 201 return err 202 } 203 case *int64: 204 if v != nil { 205 *v, err = d.DecodeInt64() 206 return err 207 } 208 case *uint: 209 if v != nil { 210 *v, err = d.DecodeUint() 211 return err 212 } 213 case *uint8: 214 if v != nil { 215 *v, err = d.DecodeUint8() 216 return err 217 } 218 case *uint16: 219 if v != nil { 220 *v, err = d.DecodeUint16() 221 return err 222 } 223 case *uint32: 224 if v != nil { 225 *v, err = d.DecodeUint32() 226 return err 227 } 228 case *uint64: 229 if v != nil { 230 *v, err = d.DecodeUint64() 231 return err 232 } 233 case *bool: 234 if v != nil { 235 *v, err = d.DecodeBool() 236 return err 237 } 238 case *float32: 239 if v != nil { 240 *v, err = d.DecodeFloat32() 241 return err 242 } 243 case *float64: 244 if v != nil { 245 *v, err = d.DecodeFloat64() 246 return err 247 } 248 case *[]string: 249 return d.decodeStringSlicePtr(v) 250 case *map[string]string: 251 return d.decodeMapStringStringPtr(v) 252 case *map[string]interface{}: 253 return d.decodeMapStringInterfacePtr(v) 254 case *time.Duration: 255 if v != nil { 256 vv, err := d.DecodeInt64() 257 *v = time.Duration(vv) 258 return err 259 } 260 case *time.Time: 261 if v != nil { 262 *v, err = d.DecodeTime() 263 return err 264 } 265 } 266 267 vv := reflect.ValueOf(v) 268 if !vv.IsValid() { 269 return errors.New("msgpack: Decode(nil)") 270 } 271 if vv.Kind() != reflect.Ptr { 272 return fmt.Errorf("msgpack: Decode(non-pointer %T)", v) 273 } 274 if vv.IsNil() { 275 return fmt.Errorf("msgpack: Decode(non-settable %T)", v) 276 } 277 278 vv = vv.Elem() 279 if vv.Kind() == reflect.Interface { 280 if !vv.IsNil() { 281 vv = vv.Elem() 282 if vv.Kind() != reflect.Ptr { 283 return fmt.Errorf("msgpack: Decode(non-pointer %s)", vv.Type().String()) 284 } 285 } 286 } 287 288 return d.DecodeValue(vv) 289 } 290 291 func (d *Decoder) DecodeMulti(v ...interface{}) error { 292 for _, vv := range v { 293 if err := d.Decode(vv); err != nil { 294 return err 295 } 296 } 297 return nil 298 } 299 300 func (d *Decoder) decodeInterfaceCond() (interface{}, error) { 301 if d.flags&looseInterfaceDecodingFlag != 0 { 302 return d.DecodeInterfaceLoose() 303 } 304 return d.DecodeInterface() 305 } 306 307 func (d *Decoder) DecodeValue(v reflect.Value) error { 308 decode := getDecoder(v.Type()) 309 return decode(d, v) 310 } 311 312 func (d *Decoder) DecodeNil() error { 313 c, err := d.readCode() 314 if err != nil { 315 return err 316 } 317 if c != msgpcode.Nil { 318 return fmt.Errorf("msgpack: invalid code=%x decoding nil", c) 319 } 320 return nil 321 } 322 323 func (d *Decoder) decodeNilValue(v reflect.Value) error { 324 err := d.DecodeNil() 325 if v.IsNil() { 326 return err 327 } 328 if v.Kind() == reflect.Ptr { 329 v = v.Elem() 330 } 331 v.Set(reflect.Zero(v.Type())) 332 return err 333 } 334 335 func (d *Decoder) DecodeBool() (bool, error) { 336 c, err := d.readCode() 337 if err != nil { 338 return false, err 339 } 340 return d.bool(c) 341 } 342 343 func (d *Decoder) bool(c byte) (bool, error) { 344 if c == msgpcode.Nil { 345 return false, nil 346 } 347 if c == msgpcode.False { 348 return false, nil 349 } 350 if c == msgpcode.True { 351 return true, nil 352 } 353 return false, fmt.Errorf("msgpack: invalid code=%x decoding bool", c) 354 } 355 356 func (d *Decoder) DecodeDuration() (time.Duration, error) { 357 n, err := d.DecodeInt64() 358 if err != nil { 359 return 0, err 360 } 361 return time.Duration(n), nil 362 } 363 364 // DecodeInterface decodes value into interface. It returns following types: 365 // - nil, 366 // - bool, 367 // - int8, int16, int32, int64, 368 // - uint8, uint16, uint32, uint64, 369 // - float32 and float64, 370 // - string, 371 // - []byte, 372 // - slices of any of the above, 373 // - maps of any of the above. 374 // 375 // DecodeInterface should be used only when you don't know the type of value 376 // you are decoding. For example, if you are decoding number it is better to use 377 // DecodeInt64 for negative numbers and DecodeUint64 for positive numbers. 378 func (d *Decoder) DecodeInterface() (interface{}, error) { 379 c, err := d.readCode() 380 if err != nil { 381 return nil, err 382 } 383 384 if msgpcode.IsFixedNum(c) { 385 return int8(c), nil 386 } 387 if msgpcode.IsFixedMap(c) { 388 err = d.s.UnreadByte() 389 if err != nil { 390 return nil, err 391 } 392 return d.decodeMapDefault() 393 } 394 if msgpcode.IsFixedArray(c) { 395 return d.decodeSlice(c) 396 } 397 if msgpcode.IsFixedString(c) { 398 return d.string(c) 399 } 400 401 switch c { 402 case msgpcode.Nil: 403 return nil, nil 404 case msgpcode.False, msgpcode.True: 405 return d.bool(c) 406 case msgpcode.Float: 407 return d.float32(c) 408 case msgpcode.Double: 409 return d.float64(c) 410 case msgpcode.Uint8: 411 return d.uint8() 412 case msgpcode.Uint16: 413 return d.uint16() 414 case msgpcode.Uint32: 415 return d.uint32() 416 case msgpcode.Uint64: 417 return d.uint64() 418 case msgpcode.Int8: 419 return d.int8() 420 case msgpcode.Int16: 421 return d.int16() 422 case msgpcode.Int32: 423 return d.int32() 424 case msgpcode.Int64: 425 return d.int64() 426 case msgpcode.Bin8, msgpcode.Bin16, msgpcode.Bin32: 427 return d.bytes(c, nil) 428 case msgpcode.Str8, msgpcode.Str16, msgpcode.Str32: 429 return d.string(c) 430 case msgpcode.Array16, msgpcode.Array32: 431 return d.decodeSlice(c) 432 case msgpcode.Map16, msgpcode.Map32: 433 err = d.s.UnreadByte() 434 if err != nil { 435 return nil, err 436 } 437 return d.decodeMapDefault() 438 case msgpcode.FixExt1, msgpcode.FixExt2, msgpcode.FixExt4, msgpcode.FixExt8, msgpcode.FixExt16, 439 msgpcode.Ext8, msgpcode.Ext16, msgpcode.Ext32: 440 return d.decodeInterfaceExt(c) 441 } 442 443 return 0, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c) 444 } 445 446 // DecodeInterfaceLoose is like DecodeInterface except that: 447 // - int8, int16, and int32 are converted to int64, 448 // - uint8, uint16, and uint32 are converted to uint64, 449 // - float32 is converted to float64. 450 // - []byte is converted to string. 451 func (d *Decoder) DecodeInterfaceLoose() (interface{}, error) { 452 c, err := d.readCode() 453 if err != nil { 454 return nil, err 455 } 456 457 if msgpcode.IsFixedNum(c) { 458 return int64(int8(c)), nil 459 } 460 if msgpcode.IsFixedMap(c) { 461 err = d.s.UnreadByte() 462 if err != nil { 463 return nil, err 464 } 465 return d.decodeMapDefault() 466 } 467 if msgpcode.IsFixedArray(c) { 468 return d.decodeSlice(c) 469 } 470 if msgpcode.IsFixedString(c) { 471 return d.string(c) 472 } 473 474 switch c { 475 case msgpcode.Nil: 476 return nil, nil 477 case msgpcode.False, msgpcode.True: 478 return d.bool(c) 479 case msgpcode.Float, msgpcode.Double: 480 return d.float64(c) 481 case msgpcode.Uint8, msgpcode.Uint16, msgpcode.Uint32, msgpcode.Uint64: 482 return d.uint(c) 483 case msgpcode.Int8, msgpcode.Int16, msgpcode.Int32, msgpcode.Int64: 484 return d.int(c) 485 case msgpcode.Str8, msgpcode.Str16, msgpcode.Str32, 486 msgpcode.Bin8, msgpcode.Bin16, msgpcode.Bin32: 487 return d.string(c) 488 case msgpcode.Array16, msgpcode.Array32: 489 return d.decodeSlice(c) 490 case msgpcode.Map16, msgpcode.Map32: 491 err = d.s.UnreadByte() 492 if err != nil { 493 return nil, err 494 } 495 return d.decodeMapDefault() 496 case msgpcode.FixExt1, msgpcode.FixExt2, msgpcode.FixExt4, msgpcode.FixExt8, msgpcode.FixExt16, 497 msgpcode.Ext8, msgpcode.Ext16, msgpcode.Ext32: 498 return d.decodeInterfaceExt(c) 499 } 500 501 return 0, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c) 502 } 503 504 // Skip skips next value. 505 func (d *Decoder) Skip() error { 506 c, err := d.readCode() 507 if err != nil { 508 return err 509 } 510 511 if msgpcode.IsFixedNum(c) { 512 return nil 513 } 514 if msgpcode.IsFixedMap(c) { 515 return d.skipMap(c) 516 } 517 if msgpcode.IsFixedArray(c) { 518 return d.skipSlice(c) 519 } 520 if msgpcode.IsFixedString(c) { 521 return d.skipBytes(c) 522 } 523 524 switch c { 525 case msgpcode.Nil, msgpcode.False, msgpcode.True: 526 return nil 527 case msgpcode.Uint8, msgpcode.Int8: 528 return d.skipN(1) 529 case msgpcode.Uint16, msgpcode.Int16: 530 return d.skipN(2) 531 case msgpcode.Uint32, msgpcode.Int32, msgpcode.Float: 532 return d.skipN(4) 533 case msgpcode.Uint64, msgpcode.Int64, msgpcode.Double: 534 return d.skipN(8) 535 case msgpcode.Bin8, msgpcode.Bin16, msgpcode.Bin32: 536 return d.skipBytes(c) 537 case msgpcode.Str8, msgpcode.Str16, msgpcode.Str32: 538 return d.skipBytes(c) 539 case msgpcode.Array16, msgpcode.Array32: 540 return d.skipSlice(c) 541 case msgpcode.Map16, msgpcode.Map32: 542 return d.skipMap(c) 543 case msgpcode.FixExt1, msgpcode.FixExt2, msgpcode.FixExt4, msgpcode.FixExt8, msgpcode.FixExt16, 544 msgpcode.Ext8, msgpcode.Ext16, msgpcode.Ext32: 545 return d.skipExt(c) 546 } 547 548 return fmt.Errorf("msgpack: unknown code %x", c) 549 } 550 551 func (d *Decoder) DecodeRaw() (RawMessage, error) { 552 d.rec = make([]byte, 0) 553 if err := d.Skip(); err != nil { 554 return nil, err 555 } 556 msg := RawMessage(d.rec) 557 d.rec = nil 558 return msg, nil 559 } 560 561 // PeekCode returns the next MessagePack code without advancing the reader. 562 // Subpackage msgpack/codes defines the list of available msgpcode. 563 func (d *Decoder) PeekCode() (byte, error) { 564 c, err := d.s.ReadByte() 565 if err != nil { 566 return 0, err 567 } 568 return c, d.s.UnreadByte() 569 } 570 571 // ReadFull reads exactly len(buf) bytes into the buf. 572 func (d *Decoder) ReadFull(buf []byte) error { 573 _, err := readN(d.r, buf, len(buf)) 574 return err 575 } 576 577 func (d *Decoder) hasNilCode() bool { 578 code, err := d.PeekCode() 579 return err == nil && code == msgpcode.Nil 580 } 581 582 func (d *Decoder) readCode() (byte, error) { 583 c, err := d.s.ReadByte() 584 if err != nil { 585 return 0, err 586 } 587 if d.rec != nil { 588 d.rec = append(d.rec, c) 589 } 590 return c, nil 591 } 592 593 func (d *Decoder) readFull(b []byte) error { 594 _, err := io.ReadFull(d.r, b) 595 if err != nil { 596 return err 597 } 598 if d.rec != nil { 599 d.rec = append(d.rec, b...) 600 } 601 return nil 602 } 603 604 func (d *Decoder) readN(n int) ([]byte, error) { 605 var err error 606 d.buf, err = readN(d.r, d.buf, n) 607 if err != nil { 608 return nil, err 609 } 610 if d.rec != nil { 611 // TODO: read directly into d.rec? 612 d.rec = append(d.rec, d.buf...) 613 } 614 return d.buf, nil 615 } 616 617 func readN(r io.Reader, b []byte, n int) ([]byte, error) { 618 if b == nil { 619 if n == 0 { 620 return make([]byte, 0), nil 621 } 622 switch { 623 case n < 64: 624 b = make([]byte, 0, 64) 625 case n <= bytesAllocLimit: 626 b = make([]byte, 0, n) 627 default: 628 b = make([]byte, 0, bytesAllocLimit) 629 } 630 } 631 632 if n <= cap(b) { 633 b = b[:n] 634 _, err := io.ReadFull(r, b) 635 return b, err 636 } 637 b = b[:cap(b)] 638 639 var pos int 640 for { 641 alloc := min(n-len(b), bytesAllocLimit) 642 b = append(b, make([]byte, alloc)...) 643 644 _, err := io.ReadFull(r, b[pos:]) 645 if err != nil { 646 return b, err 647 } 648 649 if len(b) == n { 650 break 651 } 652 pos = len(b) 653 } 654 655 return b, nil 656 } 657 658 func min(a, b int) int { //nolint:unparam 659 if a <= b { 660 return a 661 } 662 return b 663 }