gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

decode_map.go (6337B)


      1 package msgpack
      2 
      3 import (
      4 	"errors"
      5 	"fmt"
      6 	"reflect"
      7 
      8 	"github.com/vmihailenco/msgpack/v5/msgpcode"
      9 )
     10 
     11 var errArrayStruct = errors.New("msgpack: number of fields in array-encoded struct has changed")
     12 
     13 var (
     14 	mapStringStringPtrType = reflect.TypeOf((*map[string]string)(nil))
     15 	mapStringStringType    = mapStringStringPtrType.Elem()
     16 )
     17 
     18 var (
     19 	mapStringInterfacePtrType = reflect.TypeOf((*map[string]interface{})(nil))
     20 	mapStringInterfaceType    = mapStringInterfacePtrType.Elem()
     21 )
     22 
     23 func decodeMapValue(d *Decoder, v reflect.Value) error {
     24 	n, err := d.DecodeMapLen()
     25 	if err != nil {
     26 		return err
     27 	}
     28 
     29 	typ := v.Type()
     30 	if n == -1 {
     31 		v.Set(reflect.Zero(typ))
     32 		return nil
     33 	}
     34 
     35 	if v.IsNil() {
     36 		v.Set(reflect.MakeMap(typ))
     37 	}
     38 	if n == 0 {
     39 		return nil
     40 	}
     41 
     42 	return d.decodeTypedMapValue(v, n)
     43 }
     44 
     45 func (d *Decoder) decodeMapDefault() (interface{}, error) {
     46 	if d.mapDecoder != nil {
     47 		return d.mapDecoder(d)
     48 	}
     49 	return d.DecodeMap()
     50 }
     51 
     52 // DecodeMapLen decodes map length. Length is -1 when map is nil.
     53 func (d *Decoder) DecodeMapLen() (int, error) {
     54 	c, err := d.readCode()
     55 	if err != nil {
     56 		return 0, err
     57 	}
     58 
     59 	if msgpcode.IsExt(c) {
     60 		if err = d.skipExtHeader(c); err != nil {
     61 			return 0, err
     62 		}
     63 
     64 		c, err = d.readCode()
     65 		if err != nil {
     66 			return 0, err
     67 		}
     68 	}
     69 	return d.mapLen(c)
     70 }
     71 
     72 func (d *Decoder) mapLen(c byte) (int, error) {
     73 	if c == msgpcode.Nil {
     74 		return -1, nil
     75 	}
     76 	if c >= msgpcode.FixedMapLow && c <= msgpcode.FixedMapHigh {
     77 		return int(c & msgpcode.FixedMapMask), nil
     78 	}
     79 	if c == msgpcode.Map16 {
     80 		size, err := d.uint16()
     81 		return int(size), err
     82 	}
     83 	if c == msgpcode.Map32 {
     84 		size, err := d.uint32()
     85 		return int(size), err
     86 	}
     87 	return 0, unexpectedCodeError{code: c, hint: "map length"}
     88 }
     89 
     90 func decodeMapStringStringValue(d *Decoder, v reflect.Value) error {
     91 	mptr := v.Addr().Convert(mapStringStringPtrType).Interface().(*map[string]string)
     92 	return d.decodeMapStringStringPtr(mptr)
     93 }
     94 
     95 func (d *Decoder) decodeMapStringStringPtr(ptr *map[string]string) error {
     96 	size, err := d.DecodeMapLen()
     97 	if err != nil {
     98 		return err
     99 	}
    100 	if size == -1 {
    101 		*ptr = nil
    102 		return nil
    103 	}
    104 
    105 	m := *ptr
    106 	if m == nil {
    107 		*ptr = make(map[string]string, min(size, maxMapSize))
    108 		m = *ptr
    109 	}
    110 
    111 	for i := 0; i < size; i++ {
    112 		mk, err := d.DecodeString()
    113 		if err != nil {
    114 			return err
    115 		}
    116 		mv, err := d.DecodeString()
    117 		if err != nil {
    118 			return err
    119 		}
    120 		m[mk] = mv
    121 	}
    122 
    123 	return nil
    124 }
    125 
    126 func decodeMapStringInterfaceValue(d *Decoder, v reflect.Value) error {
    127 	ptr := v.Addr().Convert(mapStringInterfacePtrType).Interface().(*map[string]interface{})
    128 	return d.decodeMapStringInterfacePtr(ptr)
    129 }
    130 
    131 func (d *Decoder) decodeMapStringInterfacePtr(ptr *map[string]interface{}) error {
    132 	m, err := d.DecodeMap()
    133 	if err != nil {
    134 		return err
    135 	}
    136 	*ptr = m
    137 	return nil
    138 }
    139 
    140 func (d *Decoder) DecodeMap() (map[string]interface{}, error) {
    141 	n, err := d.DecodeMapLen()
    142 	if err != nil {
    143 		return nil, err
    144 	}
    145 
    146 	if n == -1 {
    147 		return nil, nil
    148 	}
    149 
    150 	m := make(map[string]interface{}, min(n, maxMapSize))
    151 
    152 	for i := 0; i < n; i++ {
    153 		mk, err := d.DecodeString()
    154 		if err != nil {
    155 			return nil, err
    156 		}
    157 		mv, err := d.decodeInterfaceCond()
    158 		if err != nil {
    159 			return nil, err
    160 		}
    161 		m[mk] = mv
    162 	}
    163 
    164 	return m, nil
    165 }
    166 
    167 func (d *Decoder) DecodeUntypedMap() (map[interface{}]interface{}, error) {
    168 	n, err := d.DecodeMapLen()
    169 	if err != nil {
    170 		return nil, err
    171 	}
    172 
    173 	if n == -1 {
    174 		return nil, nil
    175 	}
    176 
    177 	m := make(map[interface{}]interface{}, min(n, maxMapSize))
    178 
    179 	for i := 0; i < n; i++ {
    180 		mk, err := d.decodeInterfaceCond()
    181 		if err != nil {
    182 			return nil, err
    183 		}
    184 
    185 		mv, err := d.decodeInterfaceCond()
    186 		if err != nil {
    187 			return nil, err
    188 		}
    189 
    190 		m[mk] = mv
    191 	}
    192 
    193 	return m, nil
    194 }
    195 
    196 // DecodeTypedMap decodes a typed map. Typed map is a map that has a fixed type for keys and values.
    197 // Key and value types may be different.
    198 func (d *Decoder) DecodeTypedMap() (interface{}, error) {
    199 	n, err := d.DecodeMapLen()
    200 	if err != nil {
    201 		return nil, err
    202 	}
    203 	if n <= 0 {
    204 		return nil, nil
    205 	}
    206 
    207 	key, err := d.decodeInterfaceCond()
    208 	if err != nil {
    209 		return nil, err
    210 	}
    211 
    212 	value, err := d.decodeInterfaceCond()
    213 	if err != nil {
    214 		return nil, err
    215 	}
    216 
    217 	keyType := reflect.TypeOf(key)
    218 	valueType := reflect.TypeOf(value)
    219 
    220 	if !keyType.Comparable() {
    221 		return nil, fmt.Errorf("msgpack: unsupported map key: %s", keyType.String())
    222 	}
    223 
    224 	mapType := reflect.MapOf(keyType, valueType)
    225 	mapValue := reflect.MakeMap(mapType)
    226 	mapValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(value))
    227 
    228 	n--
    229 	if err := d.decodeTypedMapValue(mapValue, n); err != nil {
    230 		return nil, err
    231 	}
    232 
    233 	return mapValue.Interface(), nil
    234 }
    235 
    236 func (d *Decoder) decodeTypedMapValue(v reflect.Value, n int) error {
    237 	typ := v.Type()
    238 	keyType := typ.Key()
    239 	valueType := typ.Elem()
    240 
    241 	for i := 0; i < n; i++ {
    242 		mk := reflect.New(keyType).Elem()
    243 		if err := d.DecodeValue(mk); err != nil {
    244 			return err
    245 		}
    246 
    247 		mv := reflect.New(valueType).Elem()
    248 		if err := d.DecodeValue(mv); err != nil {
    249 			return err
    250 		}
    251 
    252 		v.SetMapIndex(mk, mv)
    253 	}
    254 
    255 	return nil
    256 }
    257 
    258 func (d *Decoder) skipMap(c byte) error {
    259 	n, err := d.mapLen(c)
    260 	if err != nil {
    261 		return err
    262 	}
    263 	for i := 0; i < n; i++ {
    264 		if err := d.Skip(); err != nil {
    265 			return err
    266 		}
    267 		if err := d.Skip(); err != nil {
    268 			return err
    269 		}
    270 	}
    271 	return nil
    272 }
    273 
    274 func decodeStructValue(d *Decoder, v reflect.Value) error {
    275 	c, err := d.readCode()
    276 	if err != nil {
    277 		return err
    278 	}
    279 
    280 	n, err := d.mapLen(c)
    281 	if err == nil {
    282 		return d.decodeStruct(v, n)
    283 	}
    284 
    285 	var err2 error
    286 	n, err2 = d.arrayLen(c)
    287 	if err2 != nil {
    288 		return err
    289 	}
    290 
    291 	if n <= 0 {
    292 		v.Set(reflect.Zero(v.Type()))
    293 		return nil
    294 	}
    295 
    296 	fields := structs.Fields(v.Type(), d.structTag)
    297 	if n != len(fields.List) {
    298 		return errArrayStruct
    299 	}
    300 
    301 	for _, f := range fields.List {
    302 		if err := f.DecodeValue(d, v); err != nil {
    303 			return err
    304 		}
    305 	}
    306 
    307 	return nil
    308 }
    309 
    310 func (d *Decoder) decodeStruct(v reflect.Value, n int) error {
    311 	if n == -1 {
    312 		v.Set(reflect.Zero(v.Type()))
    313 		return nil
    314 	}
    315 
    316 	fields := structs.Fields(v.Type(), d.structTag)
    317 	for i := 0; i < n; i++ {
    318 		name, err := d.decodeStringTemp()
    319 		if err != nil {
    320 			return err
    321 		}
    322 
    323 		if f := fields.Map[name]; f != nil {
    324 			if err := f.DecodeValue(d, v); err != nil {
    325 				return err
    326 			}
    327 			continue
    328 		}
    329 
    330 		if d.flags&disallowUnknownFieldsFlag != 0 {
    331 			return fmt.Errorf("msgpack: unknown field %q", name)
    332 		}
    333 		if err := d.Skip(); err != nil {
    334 			return err
    335 		}
    336 	}
    337 
    338 	return nil
    339 }