gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

decode.go (14197B)


      1 package msgpack
      2 
      3 import (
      4 	"bufio"
      5 	"bytes"
      6 	"errors"
      7 	"fmt"
      8 	"io"
      9 	"reflect"
     10 	"sync"
     11 	"time"
     12 
     13 	"github.com/vmihailenco/msgpack/v5/msgpcode"
     14 )
     15 
     16 const (
     17 	looseInterfaceDecodingFlag uint32 = 1 << iota
     18 	disallowUnknownFieldsFlag
     19 )
     20 
     21 const (
     22 	bytesAllocLimit = 1e6 // 1mb
     23 	sliceAllocLimit = 1e4
     24 	maxMapSize      = 1e6
     25 )
     26 
     27 type bufReader interface {
     28 	io.Reader
     29 	io.ByteScanner
     30 }
     31 
     32 //------------------------------------------------------------------------------
     33 
     34 var decPool = sync.Pool{
     35 	New: func() interface{} {
     36 		return NewDecoder(nil)
     37 	},
     38 }
     39 
     40 func GetDecoder() *Decoder {
     41 	return decPool.Get().(*Decoder)
     42 }
     43 
     44 func PutDecoder(dec *Decoder) {
     45 	dec.r = nil
     46 	dec.s = nil
     47 	decPool.Put(dec)
     48 }
     49 
     50 //------------------------------------------------------------------------------
     51 
     52 // Unmarshal decodes the MessagePack-encoded data and stores the result
     53 // in the value pointed to by v.
     54 func Unmarshal(data []byte, v interface{}) error {
     55 	dec := GetDecoder()
     56 
     57 	dec.Reset(bytes.NewReader(data))
     58 	err := dec.Decode(v)
     59 
     60 	PutDecoder(dec)
     61 
     62 	return err
     63 }
     64 
     65 // A Decoder reads and decodes MessagePack values from an input stream.
     66 type Decoder struct {
     67 	r   io.Reader
     68 	s   io.ByteScanner
     69 	buf []byte
     70 
     71 	rec []byte // accumulates read data if not nil
     72 
     73 	dict       []string
     74 	flags      uint32
     75 	structTag  string
     76 	mapDecoder func(*Decoder) (interface{}, error)
     77 }
     78 
     79 // NewDecoder returns a new decoder that reads from r.
     80 //
     81 // The decoder introduces its own buffering and may read data from r
     82 // beyond the requested msgpack values. Buffering can be disabled
     83 // by passing a reader that implements io.ByteScanner interface.
     84 func NewDecoder(r io.Reader) *Decoder {
     85 	d := new(Decoder)
     86 	d.Reset(r)
     87 	return d
     88 }
     89 
     90 // Reset discards any buffered data, resets all state, and switches the buffered
     91 // reader to read from r.
     92 func (d *Decoder) Reset(r io.Reader) {
     93 	d.ResetDict(r, nil)
     94 }
     95 
     96 // ResetDict is like Reset, but also resets the dict.
     97 func (d *Decoder) ResetDict(r io.Reader, dict []string) {
     98 	d.resetReader(r)
     99 	d.flags = 0
    100 	d.structTag = ""
    101 	d.mapDecoder = nil
    102 	d.dict = dict
    103 }
    104 
    105 func (d *Decoder) WithDict(dict []string, fn func(*Decoder) error) error {
    106 	oldDict := d.dict
    107 	d.dict = dict
    108 	err := fn(d)
    109 	d.dict = oldDict
    110 	return err
    111 }
    112 
    113 func (d *Decoder) resetReader(r io.Reader) {
    114 	if br, ok := r.(bufReader); ok {
    115 		d.r = br
    116 		d.s = br
    117 	} else {
    118 		br := bufio.NewReader(r)
    119 		d.r = br
    120 		d.s = br
    121 	}
    122 }
    123 
    124 func (d *Decoder) SetMapDecoder(fn func(*Decoder) (interface{}, error)) {
    125 	d.mapDecoder = fn
    126 }
    127 
    128 // UseLooseInterfaceDecoding causes decoder to use DecodeInterfaceLoose
    129 // to decode msgpack value into Go interface{}.
    130 func (d *Decoder) UseLooseInterfaceDecoding(on bool) {
    131 	if on {
    132 		d.flags |= looseInterfaceDecodingFlag
    133 	} else {
    134 		d.flags &= ^looseInterfaceDecodingFlag
    135 	}
    136 }
    137 
    138 // SetCustomStructTag causes the decoder to use the supplied tag as a fallback option
    139 // if there is no msgpack tag.
    140 func (d *Decoder) SetCustomStructTag(tag string) {
    141 	d.structTag = tag
    142 }
    143 
    144 // DisallowUnknownFields causes the Decoder to return an error when the destination
    145 // is a struct and the input contains object keys which do not match any
    146 // non-ignored, exported fields in the destination.
    147 func (d *Decoder) DisallowUnknownFields(on bool) {
    148 	if on {
    149 		d.flags |= disallowUnknownFieldsFlag
    150 	} else {
    151 		d.flags &= ^disallowUnknownFieldsFlag
    152 	}
    153 }
    154 
    155 // UseInternedStrings enables support for decoding interned strings.
    156 func (d *Decoder) UseInternedStrings(on bool) {
    157 	if on {
    158 		d.flags |= useInternedStringsFlag
    159 	} else {
    160 		d.flags &= ^useInternedStringsFlag
    161 	}
    162 }
    163 
    164 // Buffered returns a reader of the data remaining in the Decoder's buffer.
    165 // The reader is valid until the next call to Decode.
    166 func (d *Decoder) Buffered() io.Reader {
    167 	return d.r
    168 }
    169 
    170 //nolint:gocyclo
    171 func (d *Decoder) Decode(v interface{}) error {
    172 	var err error
    173 	switch v := v.(type) {
    174 	case *string:
    175 		if v != nil {
    176 			*v, err = d.DecodeString()
    177 			return err
    178 		}
    179 	case *[]byte:
    180 		if v != nil {
    181 			return d.decodeBytesPtr(v)
    182 		}
    183 	case *int:
    184 		if v != nil {
    185 			*v, err = d.DecodeInt()
    186 			return err
    187 		}
    188 	case *int8:
    189 		if v != nil {
    190 			*v, err = d.DecodeInt8()
    191 			return err
    192 		}
    193 	case *int16:
    194 		if v != nil {
    195 			*v, err = d.DecodeInt16()
    196 			return err
    197 		}
    198 	case *int32:
    199 		if v != nil {
    200 			*v, err = d.DecodeInt32()
    201 			return err
    202 		}
    203 	case *int64:
    204 		if v != nil {
    205 			*v, err = d.DecodeInt64()
    206 			return err
    207 		}
    208 	case *uint:
    209 		if v != nil {
    210 			*v, err = d.DecodeUint()
    211 			return err
    212 		}
    213 	case *uint8:
    214 		if v != nil {
    215 			*v, err = d.DecodeUint8()
    216 			return err
    217 		}
    218 	case *uint16:
    219 		if v != nil {
    220 			*v, err = d.DecodeUint16()
    221 			return err
    222 		}
    223 	case *uint32:
    224 		if v != nil {
    225 			*v, err = d.DecodeUint32()
    226 			return err
    227 		}
    228 	case *uint64:
    229 		if v != nil {
    230 			*v, err = d.DecodeUint64()
    231 			return err
    232 		}
    233 	case *bool:
    234 		if v != nil {
    235 			*v, err = d.DecodeBool()
    236 			return err
    237 		}
    238 	case *float32:
    239 		if v != nil {
    240 			*v, err = d.DecodeFloat32()
    241 			return err
    242 		}
    243 	case *float64:
    244 		if v != nil {
    245 			*v, err = d.DecodeFloat64()
    246 			return err
    247 		}
    248 	case *[]string:
    249 		return d.decodeStringSlicePtr(v)
    250 	case *map[string]string:
    251 		return d.decodeMapStringStringPtr(v)
    252 	case *map[string]interface{}:
    253 		return d.decodeMapStringInterfacePtr(v)
    254 	case *time.Duration:
    255 		if v != nil {
    256 			vv, err := d.DecodeInt64()
    257 			*v = time.Duration(vv)
    258 			return err
    259 		}
    260 	case *time.Time:
    261 		if v != nil {
    262 			*v, err = d.DecodeTime()
    263 			return err
    264 		}
    265 	}
    266 
    267 	vv := reflect.ValueOf(v)
    268 	if !vv.IsValid() {
    269 		return errors.New("msgpack: Decode(nil)")
    270 	}
    271 	if vv.Kind() != reflect.Ptr {
    272 		return fmt.Errorf("msgpack: Decode(non-pointer %T)", v)
    273 	}
    274 	if vv.IsNil() {
    275 		return fmt.Errorf("msgpack: Decode(non-settable %T)", v)
    276 	}
    277 
    278 	vv = vv.Elem()
    279 	if vv.Kind() == reflect.Interface {
    280 		if !vv.IsNil() {
    281 			vv = vv.Elem()
    282 			if vv.Kind() != reflect.Ptr {
    283 				return fmt.Errorf("msgpack: Decode(non-pointer %s)", vv.Type().String())
    284 			}
    285 		}
    286 	}
    287 
    288 	return d.DecodeValue(vv)
    289 }
    290 
    291 func (d *Decoder) DecodeMulti(v ...interface{}) error {
    292 	for _, vv := range v {
    293 		if err := d.Decode(vv); err != nil {
    294 			return err
    295 		}
    296 	}
    297 	return nil
    298 }
    299 
    300 func (d *Decoder) decodeInterfaceCond() (interface{}, error) {
    301 	if d.flags&looseInterfaceDecodingFlag != 0 {
    302 		return d.DecodeInterfaceLoose()
    303 	}
    304 	return d.DecodeInterface()
    305 }
    306 
    307 func (d *Decoder) DecodeValue(v reflect.Value) error {
    308 	decode := getDecoder(v.Type())
    309 	return decode(d, v)
    310 }
    311 
    312 func (d *Decoder) DecodeNil() error {
    313 	c, err := d.readCode()
    314 	if err != nil {
    315 		return err
    316 	}
    317 	if c != msgpcode.Nil {
    318 		return fmt.Errorf("msgpack: invalid code=%x decoding nil", c)
    319 	}
    320 	return nil
    321 }
    322 
    323 func (d *Decoder) decodeNilValue(v reflect.Value) error {
    324 	err := d.DecodeNil()
    325 	if v.IsNil() {
    326 		return err
    327 	}
    328 	if v.Kind() == reflect.Ptr {
    329 		v = v.Elem()
    330 	}
    331 	v.Set(reflect.Zero(v.Type()))
    332 	return err
    333 }
    334 
    335 func (d *Decoder) DecodeBool() (bool, error) {
    336 	c, err := d.readCode()
    337 	if err != nil {
    338 		return false, err
    339 	}
    340 	return d.bool(c)
    341 }
    342 
    343 func (d *Decoder) bool(c byte) (bool, error) {
    344 	if c == msgpcode.Nil {
    345 		return false, nil
    346 	}
    347 	if c == msgpcode.False {
    348 		return false, nil
    349 	}
    350 	if c == msgpcode.True {
    351 		return true, nil
    352 	}
    353 	return false, fmt.Errorf("msgpack: invalid code=%x decoding bool", c)
    354 }
    355 
    356 func (d *Decoder) DecodeDuration() (time.Duration, error) {
    357 	n, err := d.DecodeInt64()
    358 	if err != nil {
    359 		return 0, err
    360 	}
    361 	return time.Duration(n), nil
    362 }
    363 
    364 // DecodeInterface decodes value into interface. It returns following types:
    365 //   - nil,
    366 //   - bool,
    367 //   - int8, int16, int32, int64,
    368 //   - uint8, uint16, uint32, uint64,
    369 //   - float32 and float64,
    370 //   - string,
    371 //   - []byte,
    372 //   - slices of any of the above,
    373 //   - maps of any of the above.
    374 //
    375 // DecodeInterface should be used only when you don't know the type of value
    376 // you are decoding. For example, if you are decoding number it is better to use
    377 // DecodeInt64 for negative numbers and DecodeUint64 for positive numbers.
    378 func (d *Decoder) DecodeInterface() (interface{}, error) {
    379 	c, err := d.readCode()
    380 	if err != nil {
    381 		return nil, err
    382 	}
    383 
    384 	if msgpcode.IsFixedNum(c) {
    385 		return int8(c), nil
    386 	}
    387 	if msgpcode.IsFixedMap(c) {
    388 		err = d.s.UnreadByte()
    389 		if err != nil {
    390 			return nil, err
    391 		}
    392 		return d.decodeMapDefault()
    393 	}
    394 	if msgpcode.IsFixedArray(c) {
    395 		return d.decodeSlice(c)
    396 	}
    397 	if msgpcode.IsFixedString(c) {
    398 		return d.string(c)
    399 	}
    400 
    401 	switch c {
    402 	case msgpcode.Nil:
    403 		return nil, nil
    404 	case msgpcode.False, msgpcode.True:
    405 		return d.bool(c)
    406 	case msgpcode.Float:
    407 		return d.float32(c)
    408 	case msgpcode.Double:
    409 		return d.float64(c)
    410 	case msgpcode.Uint8:
    411 		return d.uint8()
    412 	case msgpcode.Uint16:
    413 		return d.uint16()
    414 	case msgpcode.Uint32:
    415 		return d.uint32()
    416 	case msgpcode.Uint64:
    417 		return d.uint64()
    418 	case msgpcode.Int8:
    419 		return d.int8()
    420 	case msgpcode.Int16:
    421 		return d.int16()
    422 	case msgpcode.Int32:
    423 		return d.int32()
    424 	case msgpcode.Int64:
    425 		return d.int64()
    426 	case msgpcode.Bin8, msgpcode.Bin16, msgpcode.Bin32:
    427 		return d.bytes(c, nil)
    428 	case msgpcode.Str8, msgpcode.Str16, msgpcode.Str32:
    429 		return d.string(c)
    430 	case msgpcode.Array16, msgpcode.Array32:
    431 		return d.decodeSlice(c)
    432 	case msgpcode.Map16, msgpcode.Map32:
    433 		err = d.s.UnreadByte()
    434 		if err != nil {
    435 			return nil, err
    436 		}
    437 		return d.decodeMapDefault()
    438 	case msgpcode.FixExt1, msgpcode.FixExt2, msgpcode.FixExt4, msgpcode.FixExt8, msgpcode.FixExt16,
    439 		msgpcode.Ext8, msgpcode.Ext16, msgpcode.Ext32:
    440 		return d.decodeInterfaceExt(c)
    441 	}
    442 
    443 	return 0, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c)
    444 }
    445 
    446 // DecodeInterfaceLoose is like DecodeInterface except that:
    447 //   - int8, int16, and int32 are converted to int64,
    448 //   - uint8, uint16, and uint32 are converted to uint64,
    449 //   - float32 is converted to float64.
    450 //   - []byte is converted to string.
    451 func (d *Decoder) DecodeInterfaceLoose() (interface{}, error) {
    452 	c, err := d.readCode()
    453 	if err != nil {
    454 		return nil, err
    455 	}
    456 
    457 	if msgpcode.IsFixedNum(c) {
    458 		return int64(int8(c)), nil
    459 	}
    460 	if msgpcode.IsFixedMap(c) {
    461 		err = d.s.UnreadByte()
    462 		if err != nil {
    463 			return nil, err
    464 		}
    465 		return d.decodeMapDefault()
    466 	}
    467 	if msgpcode.IsFixedArray(c) {
    468 		return d.decodeSlice(c)
    469 	}
    470 	if msgpcode.IsFixedString(c) {
    471 		return d.string(c)
    472 	}
    473 
    474 	switch c {
    475 	case msgpcode.Nil:
    476 		return nil, nil
    477 	case msgpcode.False, msgpcode.True:
    478 		return d.bool(c)
    479 	case msgpcode.Float, msgpcode.Double:
    480 		return d.float64(c)
    481 	case msgpcode.Uint8, msgpcode.Uint16, msgpcode.Uint32, msgpcode.Uint64:
    482 		return d.uint(c)
    483 	case msgpcode.Int8, msgpcode.Int16, msgpcode.Int32, msgpcode.Int64:
    484 		return d.int(c)
    485 	case msgpcode.Str8, msgpcode.Str16, msgpcode.Str32,
    486 		msgpcode.Bin8, msgpcode.Bin16, msgpcode.Bin32:
    487 		return d.string(c)
    488 	case msgpcode.Array16, msgpcode.Array32:
    489 		return d.decodeSlice(c)
    490 	case msgpcode.Map16, msgpcode.Map32:
    491 		err = d.s.UnreadByte()
    492 		if err != nil {
    493 			return nil, err
    494 		}
    495 		return d.decodeMapDefault()
    496 	case msgpcode.FixExt1, msgpcode.FixExt2, msgpcode.FixExt4, msgpcode.FixExt8, msgpcode.FixExt16,
    497 		msgpcode.Ext8, msgpcode.Ext16, msgpcode.Ext32:
    498 		return d.decodeInterfaceExt(c)
    499 	}
    500 
    501 	return 0, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c)
    502 }
    503 
    504 // Skip skips next value.
    505 func (d *Decoder) Skip() error {
    506 	c, err := d.readCode()
    507 	if err != nil {
    508 		return err
    509 	}
    510 
    511 	if msgpcode.IsFixedNum(c) {
    512 		return nil
    513 	}
    514 	if msgpcode.IsFixedMap(c) {
    515 		return d.skipMap(c)
    516 	}
    517 	if msgpcode.IsFixedArray(c) {
    518 		return d.skipSlice(c)
    519 	}
    520 	if msgpcode.IsFixedString(c) {
    521 		return d.skipBytes(c)
    522 	}
    523 
    524 	switch c {
    525 	case msgpcode.Nil, msgpcode.False, msgpcode.True:
    526 		return nil
    527 	case msgpcode.Uint8, msgpcode.Int8:
    528 		return d.skipN(1)
    529 	case msgpcode.Uint16, msgpcode.Int16:
    530 		return d.skipN(2)
    531 	case msgpcode.Uint32, msgpcode.Int32, msgpcode.Float:
    532 		return d.skipN(4)
    533 	case msgpcode.Uint64, msgpcode.Int64, msgpcode.Double:
    534 		return d.skipN(8)
    535 	case msgpcode.Bin8, msgpcode.Bin16, msgpcode.Bin32:
    536 		return d.skipBytes(c)
    537 	case msgpcode.Str8, msgpcode.Str16, msgpcode.Str32:
    538 		return d.skipBytes(c)
    539 	case msgpcode.Array16, msgpcode.Array32:
    540 		return d.skipSlice(c)
    541 	case msgpcode.Map16, msgpcode.Map32:
    542 		return d.skipMap(c)
    543 	case msgpcode.FixExt1, msgpcode.FixExt2, msgpcode.FixExt4, msgpcode.FixExt8, msgpcode.FixExt16,
    544 		msgpcode.Ext8, msgpcode.Ext16, msgpcode.Ext32:
    545 		return d.skipExt(c)
    546 	}
    547 
    548 	return fmt.Errorf("msgpack: unknown code %x", c)
    549 }
    550 
    551 func (d *Decoder) DecodeRaw() (RawMessage, error) {
    552 	d.rec = make([]byte, 0)
    553 	if err := d.Skip(); err != nil {
    554 		return nil, err
    555 	}
    556 	msg := RawMessage(d.rec)
    557 	d.rec = nil
    558 	return msg, nil
    559 }
    560 
    561 // PeekCode returns the next MessagePack code without advancing the reader.
    562 // Subpackage msgpack/codes defines the list of available msgpcode.
    563 func (d *Decoder) PeekCode() (byte, error) {
    564 	c, err := d.s.ReadByte()
    565 	if err != nil {
    566 		return 0, err
    567 	}
    568 	return c, d.s.UnreadByte()
    569 }
    570 
    571 // ReadFull reads exactly len(buf) bytes into the buf.
    572 func (d *Decoder) ReadFull(buf []byte) error {
    573 	_, err := readN(d.r, buf, len(buf))
    574 	return err
    575 }
    576 
    577 func (d *Decoder) hasNilCode() bool {
    578 	code, err := d.PeekCode()
    579 	return err == nil && code == msgpcode.Nil
    580 }
    581 
    582 func (d *Decoder) readCode() (byte, error) {
    583 	c, err := d.s.ReadByte()
    584 	if err != nil {
    585 		return 0, err
    586 	}
    587 	if d.rec != nil {
    588 		d.rec = append(d.rec, c)
    589 	}
    590 	return c, nil
    591 }
    592 
    593 func (d *Decoder) readFull(b []byte) error {
    594 	_, err := io.ReadFull(d.r, b)
    595 	if err != nil {
    596 		return err
    597 	}
    598 	if d.rec != nil {
    599 		d.rec = append(d.rec, b...)
    600 	}
    601 	return nil
    602 }
    603 
    604 func (d *Decoder) readN(n int) ([]byte, error) {
    605 	var err error
    606 	d.buf, err = readN(d.r, d.buf, n)
    607 	if err != nil {
    608 		return nil, err
    609 	}
    610 	if d.rec != nil {
    611 		// TODO: read directly into d.rec?
    612 		d.rec = append(d.rec, d.buf...)
    613 	}
    614 	return d.buf, nil
    615 }
    616 
    617 func readN(r io.Reader, b []byte, n int) ([]byte, error) {
    618 	if b == nil {
    619 		if n == 0 {
    620 			return make([]byte, 0), nil
    621 		}
    622 		switch {
    623 		case n < 64:
    624 			b = make([]byte, 0, 64)
    625 		case n <= bytesAllocLimit:
    626 			b = make([]byte, 0, n)
    627 		default:
    628 			b = make([]byte, 0, bytesAllocLimit)
    629 		}
    630 	}
    631 
    632 	if n <= cap(b) {
    633 		b = b[:n]
    634 		_, err := io.ReadFull(r, b)
    635 		return b, err
    636 	}
    637 	b = b[:cap(b)]
    638 
    639 	var pos int
    640 	for {
    641 		alloc := min(n-len(b), bytesAllocLimit)
    642 		b = append(b, make([]byte, alloc)...)
    643 
    644 		_, err := io.ReadFull(r, b[pos:])
    645 		if err != nil {
    646 			return b, err
    647 		}
    648 
    649 		if len(b) == n {
    650 			break
    651 		}
    652 		pos = len(b)
    653 	}
    654 
    655 	return b, nil
    656 }
    657 
    658 func min(a, b int) int { //nolint:unparam
    659 	if a <= b {
    660 		return a
    661 	}
    662 	return b
    663 }