decode_map.go (6337B)
1 package msgpack 2 3 import ( 4 "errors" 5 "fmt" 6 "reflect" 7 8 "github.com/vmihailenco/msgpack/v5/msgpcode" 9 ) 10 11 var errArrayStruct = errors.New("msgpack: number of fields in array-encoded struct has changed") 12 13 var ( 14 mapStringStringPtrType = reflect.TypeOf((*map[string]string)(nil)) 15 mapStringStringType = mapStringStringPtrType.Elem() 16 ) 17 18 var ( 19 mapStringInterfacePtrType = reflect.TypeOf((*map[string]interface{})(nil)) 20 mapStringInterfaceType = mapStringInterfacePtrType.Elem() 21 ) 22 23 func decodeMapValue(d *Decoder, v reflect.Value) error { 24 n, err := d.DecodeMapLen() 25 if err != nil { 26 return err 27 } 28 29 typ := v.Type() 30 if n == -1 { 31 v.Set(reflect.Zero(typ)) 32 return nil 33 } 34 35 if v.IsNil() { 36 v.Set(reflect.MakeMap(typ)) 37 } 38 if n == 0 { 39 return nil 40 } 41 42 return d.decodeTypedMapValue(v, n) 43 } 44 45 func (d *Decoder) decodeMapDefault() (interface{}, error) { 46 if d.mapDecoder != nil { 47 return d.mapDecoder(d) 48 } 49 return d.DecodeMap() 50 } 51 52 // DecodeMapLen decodes map length. Length is -1 when map is nil. 53 func (d *Decoder) DecodeMapLen() (int, error) { 54 c, err := d.readCode() 55 if err != nil { 56 return 0, err 57 } 58 59 if msgpcode.IsExt(c) { 60 if err = d.skipExtHeader(c); err != nil { 61 return 0, err 62 } 63 64 c, err = d.readCode() 65 if err != nil { 66 return 0, err 67 } 68 } 69 return d.mapLen(c) 70 } 71 72 func (d *Decoder) mapLen(c byte) (int, error) { 73 if c == msgpcode.Nil { 74 return -1, nil 75 } 76 if c >= msgpcode.FixedMapLow && c <= msgpcode.FixedMapHigh { 77 return int(c & msgpcode.FixedMapMask), nil 78 } 79 if c == msgpcode.Map16 { 80 size, err := d.uint16() 81 return int(size), err 82 } 83 if c == msgpcode.Map32 { 84 size, err := d.uint32() 85 return int(size), err 86 } 87 return 0, unexpectedCodeError{code: c, hint: "map length"} 88 } 89 90 func decodeMapStringStringValue(d *Decoder, v reflect.Value) error { 91 mptr := v.Addr().Convert(mapStringStringPtrType).Interface().(*map[string]string) 92 return d.decodeMapStringStringPtr(mptr) 93 } 94 95 func (d *Decoder) decodeMapStringStringPtr(ptr *map[string]string) error { 96 size, err := d.DecodeMapLen() 97 if err != nil { 98 return err 99 } 100 if size == -1 { 101 *ptr = nil 102 return nil 103 } 104 105 m := *ptr 106 if m == nil { 107 *ptr = make(map[string]string, min(size, maxMapSize)) 108 m = *ptr 109 } 110 111 for i := 0; i < size; i++ { 112 mk, err := d.DecodeString() 113 if err != nil { 114 return err 115 } 116 mv, err := d.DecodeString() 117 if err != nil { 118 return err 119 } 120 m[mk] = mv 121 } 122 123 return nil 124 } 125 126 func decodeMapStringInterfaceValue(d *Decoder, v reflect.Value) error { 127 ptr := v.Addr().Convert(mapStringInterfacePtrType).Interface().(*map[string]interface{}) 128 return d.decodeMapStringInterfacePtr(ptr) 129 } 130 131 func (d *Decoder) decodeMapStringInterfacePtr(ptr *map[string]interface{}) error { 132 m, err := d.DecodeMap() 133 if err != nil { 134 return err 135 } 136 *ptr = m 137 return nil 138 } 139 140 func (d *Decoder) DecodeMap() (map[string]interface{}, error) { 141 n, err := d.DecodeMapLen() 142 if err != nil { 143 return nil, err 144 } 145 146 if n == -1 { 147 return nil, nil 148 } 149 150 m := make(map[string]interface{}, min(n, maxMapSize)) 151 152 for i := 0; i < n; i++ { 153 mk, err := d.DecodeString() 154 if err != nil { 155 return nil, err 156 } 157 mv, err := d.decodeInterfaceCond() 158 if err != nil { 159 return nil, err 160 } 161 m[mk] = mv 162 } 163 164 return m, nil 165 } 166 167 func (d *Decoder) DecodeUntypedMap() (map[interface{}]interface{}, error) { 168 n, err := d.DecodeMapLen() 169 if err != nil { 170 return nil, err 171 } 172 173 if n == -1 { 174 return nil, nil 175 } 176 177 m := make(map[interface{}]interface{}, min(n, maxMapSize)) 178 179 for i := 0; i < n; i++ { 180 mk, err := d.decodeInterfaceCond() 181 if err != nil { 182 return nil, err 183 } 184 185 mv, err := d.decodeInterfaceCond() 186 if err != nil { 187 return nil, err 188 } 189 190 m[mk] = mv 191 } 192 193 return m, nil 194 } 195 196 // DecodeTypedMap decodes a typed map. Typed map is a map that has a fixed type for keys and values. 197 // Key and value types may be different. 198 func (d *Decoder) DecodeTypedMap() (interface{}, error) { 199 n, err := d.DecodeMapLen() 200 if err != nil { 201 return nil, err 202 } 203 if n <= 0 { 204 return nil, nil 205 } 206 207 key, err := d.decodeInterfaceCond() 208 if err != nil { 209 return nil, err 210 } 211 212 value, err := d.decodeInterfaceCond() 213 if err != nil { 214 return nil, err 215 } 216 217 keyType := reflect.TypeOf(key) 218 valueType := reflect.TypeOf(value) 219 220 if !keyType.Comparable() { 221 return nil, fmt.Errorf("msgpack: unsupported map key: %s", keyType.String()) 222 } 223 224 mapType := reflect.MapOf(keyType, valueType) 225 mapValue := reflect.MakeMap(mapType) 226 mapValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(value)) 227 228 n-- 229 if err := d.decodeTypedMapValue(mapValue, n); err != nil { 230 return nil, err 231 } 232 233 return mapValue.Interface(), nil 234 } 235 236 func (d *Decoder) decodeTypedMapValue(v reflect.Value, n int) error { 237 typ := v.Type() 238 keyType := typ.Key() 239 valueType := typ.Elem() 240 241 for i := 0; i < n; i++ { 242 mk := reflect.New(keyType).Elem() 243 if err := d.DecodeValue(mk); err != nil { 244 return err 245 } 246 247 mv := reflect.New(valueType).Elem() 248 if err := d.DecodeValue(mv); err != nil { 249 return err 250 } 251 252 v.SetMapIndex(mk, mv) 253 } 254 255 return nil 256 } 257 258 func (d *Decoder) skipMap(c byte) error { 259 n, err := d.mapLen(c) 260 if err != nil { 261 return err 262 } 263 for i := 0; i < n; i++ { 264 if err := d.Skip(); err != nil { 265 return err 266 } 267 if err := d.Skip(); err != nil { 268 return err 269 } 270 } 271 return nil 272 } 273 274 func decodeStructValue(d *Decoder, v reflect.Value) error { 275 c, err := d.readCode() 276 if err != nil { 277 return err 278 } 279 280 n, err := d.mapLen(c) 281 if err == nil { 282 return d.decodeStruct(v, n) 283 } 284 285 var err2 error 286 n, err2 = d.arrayLen(c) 287 if err2 != nil { 288 return err 289 } 290 291 if n <= 0 { 292 v.Set(reflect.Zero(v.Type())) 293 return nil 294 } 295 296 fields := structs.Fields(v.Type(), d.structTag) 297 if n != len(fields.List) { 298 return errArrayStruct 299 } 300 301 for _, f := range fields.List { 302 if err := f.DecodeValue(d, v); err != nil { 303 return err 304 } 305 } 306 307 return nil 308 } 309 310 func (d *Decoder) decodeStruct(v reflect.Value, n int) error { 311 if n == -1 { 312 v.Set(reflect.Zero(v.Type())) 313 return nil 314 } 315 316 fields := structs.Fields(v.Type(), d.structTag) 317 for i := 0; i < n; i++ { 318 name, err := d.decodeStringTemp() 319 if err != nil { 320 return err 321 } 322 323 if f := fields.Map[name]; f != nil { 324 if err := f.DecodeValue(d, v); err != nil { 325 return err 326 } 327 continue 328 } 329 330 if d.flags&disallowUnknownFieldsFlag != 0 { 331 return fmt.Errorf("msgpack: unknown field %q", name) 332 } 333 if err := d.Skip(); err != nil { 334 return err 335 } 336 } 337 338 return nil 339 }