codec_map.go (10423B)
1 // Copyright 2019 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package impl 6 7 import ( 8 "reflect" 9 "sort" 10 11 "google.golang.org/protobuf/encoding/protowire" 12 "google.golang.org/protobuf/internal/genid" 13 "google.golang.org/protobuf/reflect/protoreflect" 14 ) 15 16 type mapInfo struct { 17 goType reflect.Type 18 keyWiretag uint64 19 valWiretag uint64 20 keyFuncs valueCoderFuncs 21 valFuncs valueCoderFuncs 22 keyZero protoreflect.Value 23 keyKind protoreflect.Kind 24 conv *mapConverter 25 } 26 27 func encoderFuncsForMap(fd protoreflect.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) { 28 // TODO: Consider generating specialized map coders. 29 keyField := fd.MapKey() 30 valField := fd.MapValue() 31 keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()]) 32 valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()]) 33 keyFuncs := encoderFuncsForValue(keyField) 34 valFuncs := encoderFuncsForValue(valField) 35 conv := newMapConverter(ft, fd) 36 37 mapi := &mapInfo{ 38 goType: ft, 39 keyWiretag: keyWiretag, 40 valWiretag: valWiretag, 41 keyFuncs: keyFuncs, 42 valFuncs: valFuncs, 43 keyZero: keyField.Default(), 44 keyKind: keyField.Kind(), 45 conv: conv, 46 } 47 if valField.Kind() == protoreflect.MessageKind { 48 valueMessage = getMessageInfo(ft.Elem()) 49 } 50 51 funcs = pointerCoderFuncs{ 52 size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int { 53 return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts) 54 }, 55 marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { 56 return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts) 57 }, 58 unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) { 59 mp := p.AsValueOf(ft) 60 if mp.Elem().IsNil() { 61 mp.Elem().Set(reflect.MakeMap(mapi.goType)) 62 } 63 if f.mi == nil { 64 return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts) 65 } else { 66 return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts) 67 } 68 }, 69 } 70 switch valField.Kind() { 71 case protoreflect.MessageKind: 72 funcs.merge = mergeMapOfMessage 73 case protoreflect.BytesKind: 74 funcs.merge = mergeMapOfBytes 75 default: 76 funcs.merge = mergeMap 77 } 78 if valFuncs.isInit != nil { 79 funcs.isInit = func(p pointer, f *coderFieldInfo) error { 80 return isInitMap(p.AsValueOf(ft).Elem(), mapi, f) 81 } 82 } 83 return valueMessage, funcs 84 } 85 86 const ( 87 mapKeyTagSize = 1 // field 1, tag size 1. 88 mapValTagSize = 1 // field 2, tag size 2. 89 ) 90 91 func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int { 92 if mapv.Len() == 0 { 93 return 0 94 } 95 n := 0 96 iter := mapRange(mapv) 97 for iter.Next() { 98 key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey() 99 keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts) 100 var valSize int 101 value := mapi.conv.valConv.PBValueOf(iter.Value()) 102 if f.mi == nil { 103 valSize = mapi.valFuncs.size(value, mapValTagSize, opts) 104 } else { 105 p := pointerOfValue(iter.Value()) 106 valSize += mapValTagSize 107 valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts)) 108 } 109 n += f.tagsize + protowire.SizeBytes(keySize+valSize) 110 } 111 return n 112 } 113 114 func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) { 115 if wtyp != protowire.BytesType { 116 return out, errUnknown 117 } 118 b, n := protowire.ConsumeBytes(b) 119 if n < 0 { 120 return out, errDecode 121 } 122 var ( 123 key = mapi.keyZero 124 val = mapi.conv.valConv.New() 125 ) 126 for len(b) > 0 { 127 num, wtyp, n := protowire.ConsumeTag(b) 128 if n < 0 { 129 return out, errDecode 130 } 131 if num > protowire.MaxValidNumber { 132 return out, errDecode 133 } 134 b = b[n:] 135 err := errUnknown 136 switch num { 137 case genid.MapEntry_Key_field_number: 138 var v protoreflect.Value 139 var o unmarshalOutput 140 v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts) 141 if err != nil { 142 break 143 } 144 key = v 145 n = o.n 146 case genid.MapEntry_Value_field_number: 147 var v protoreflect.Value 148 var o unmarshalOutput 149 v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts) 150 if err != nil { 151 break 152 } 153 val = v 154 n = o.n 155 } 156 if err == errUnknown { 157 n = protowire.ConsumeFieldValue(num, wtyp, b) 158 if n < 0 { 159 return out, errDecode 160 } 161 } else if err != nil { 162 return out, err 163 } 164 b = b[n:] 165 } 166 mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val)) 167 out.n = n 168 return out, nil 169 } 170 171 func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) { 172 if wtyp != protowire.BytesType { 173 return out, errUnknown 174 } 175 b, n := protowire.ConsumeBytes(b) 176 if n < 0 { 177 return out, errDecode 178 } 179 var ( 180 key = mapi.keyZero 181 val = reflect.New(f.mi.GoReflectType.Elem()) 182 ) 183 for len(b) > 0 { 184 num, wtyp, n := protowire.ConsumeTag(b) 185 if n < 0 { 186 return out, errDecode 187 } 188 if num > protowire.MaxValidNumber { 189 return out, errDecode 190 } 191 b = b[n:] 192 err := errUnknown 193 switch num { 194 case 1: 195 var v protoreflect.Value 196 var o unmarshalOutput 197 v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts) 198 if err != nil { 199 break 200 } 201 key = v 202 n = o.n 203 case 2: 204 if wtyp != protowire.BytesType { 205 break 206 } 207 var v []byte 208 v, n = protowire.ConsumeBytes(b) 209 if n < 0 { 210 return out, errDecode 211 } 212 var o unmarshalOutput 213 o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts) 214 if o.initialized { 215 // Consider this map item initialized so long as we see 216 // an initialized value. 217 out.initialized = true 218 } 219 } 220 if err == errUnknown { 221 n = protowire.ConsumeFieldValue(num, wtyp, b) 222 if n < 0 { 223 return out, errDecode 224 } 225 } else if err != nil { 226 return out, err 227 } 228 b = b[n:] 229 } 230 mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val) 231 out.n = n 232 return out, nil 233 } 234 235 func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { 236 if f.mi == nil { 237 key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey() 238 val := mapi.conv.valConv.PBValueOf(valrv) 239 size := 0 240 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts) 241 size += mapi.valFuncs.size(val, mapValTagSize, opts) 242 b = protowire.AppendVarint(b, uint64(size)) 243 b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts) 244 if err != nil { 245 return nil, err 246 } 247 return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts) 248 } else { 249 key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey() 250 val := pointerOfValue(valrv) 251 valSize := f.mi.sizePointer(val, opts) 252 size := 0 253 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts) 254 size += mapValTagSize + protowire.SizeBytes(valSize) 255 b = protowire.AppendVarint(b, uint64(size)) 256 b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts) 257 if err != nil { 258 return nil, err 259 } 260 b = protowire.AppendVarint(b, mapi.valWiretag) 261 b = protowire.AppendVarint(b, uint64(valSize)) 262 return f.mi.marshalAppendPointer(b, val, opts) 263 } 264 } 265 266 func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { 267 if mapv.Len() == 0 { 268 return b, nil 269 } 270 if opts.Deterministic() { 271 return appendMapDeterministic(b, mapv, mapi, f, opts) 272 } 273 iter := mapRange(mapv) 274 for iter.Next() { 275 var err error 276 b = protowire.AppendVarint(b, f.wiretag) 277 b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts) 278 if err != nil { 279 return b, err 280 } 281 } 282 return b, nil 283 } 284 285 func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) { 286 keys := mapv.MapKeys() 287 sort.Slice(keys, func(i, j int) bool { 288 switch keys[i].Kind() { 289 case reflect.Bool: 290 return !keys[i].Bool() && keys[j].Bool() 291 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 292 return keys[i].Int() < keys[j].Int() 293 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: 294 return keys[i].Uint() < keys[j].Uint() 295 case reflect.Float32, reflect.Float64: 296 return keys[i].Float() < keys[j].Float() 297 case reflect.String: 298 return keys[i].String() < keys[j].String() 299 default: 300 panic("invalid kind: " + keys[i].Kind().String()) 301 } 302 }) 303 for _, key := range keys { 304 var err error 305 b = protowire.AppendVarint(b, f.wiretag) 306 b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts) 307 if err != nil { 308 return b, err 309 } 310 } 311 return b, nil 312 } 313 314 func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error { 315 if mi := f.mi; mi != nil { 316 mi.init() 317 if !mi.needsInitCheck { 318 return nil 319 } 320 iter := mapRange(mapv) 321 for iter.Next() { 322 val := pointerOfValue(iter.Value()) 323 if err := mi.checkInitializedPointer(val); err != nil { 324 return err 325 } 326 } 327 } else { 328 iter := mapRange(mapv) 329 for iter.Next() { 330 val := mapi.conv.valConv.PBValueOf(iter.Value()) 331 if err := mapi.valFuncs.isInit(val); err != nil { 332 return err 333 } 334 } 335 } 336 return nil 337 } 338 339 func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) { 340 dstm := dst.AsValueOf(f.ft).Elem() 341 srcm := src.AsValueOf(f.ft).Elem() 342 if srcm.Len() == 0 { 343 return 344 } 345 if dstm.IsNil() { 346 dstm.Set(reflect.MakeMap(f.ft)) 347 } 348 iter := mapRange(srcm) 349 for iter.Next() { 350 dstm.SetMapIndex(iter.Key(), iter.Value()) 351 } 352 } 353 354 func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) { 355 dstm := dst.AsValueOf(f.ft).Elem() 356 srcm := src.AsValueOf(f.ft).Elem() 357 if srcm.Len() == 0 { 358 return 359 } 360 if dstm.IsNil() { 361 dstm.Set(reflect.MakeMap(f.ft)) 362 } 363 iter := mapRange(srcm) 364 for iter.Next() { 365 dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...))) 366 } 367 } 368 369 func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) { 370 dstm := dst.AsValueOf(f.ft).Elem() 371 srcm := src.AsValueOf(f.ft).Elem() 372 if srcm.Len() == 0 { 373 return 374 } 375 if dstm.IsNil() { 376 dstm.Set(reflect.MakeMap(f.ft)) 377 } 378 iter := mapRange(srcm) 379 for iter.Next() { 380 val := reflect.New(f.ft.Elem().Elem()) 381 if f.mi != nil { 382 f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts) 383 } else { 384 opts.Merge(asMessage(val), asMessage(iter.Value())) 385 } 386 dstm.SetMapIndex(iter.Key(), val) 387 } 388 }