gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

codec_map.go (10423B)


      1 // Copyright 2019 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package impl
      6 
      7 import (
      8 	"reflect"
      9 	"sort"
     10 
     11 	"google.golang.org/protobuf/encoding/protowire"
     12 	"google.golang.org/protobuf/internal/genid"
     13 	"google.golang.org/protobuf/reflect/protoreflect"
     14 )
     15 
     16 type mapInfo struct {
     17 	goType     reflect.Type
     18 	keyWiretag uint64
     19 	valWiretag uint64
     20 	keyFuncs   valueCoderFuncs
     21 	valFuncs   valueCoderFuncs
     22 	keyZero    protoreflect.Value
     23 	keyKind    protoreflect.Kind
     24 	conv       *mapConverter
     25 }
     26 
     27 func encoderFuncsForMap(fd protoreflect.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) {
     28 	// TODO: Consider generating specialized map coders.
     29 	keyField := fd.MapKey()
     30 	valField := fd.MapValue()
     31 	keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()])
     32 	valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()])
     33 	keyFuncs := encoderFuncsForValue(keyField)
     34 	valFuncs := encoderFuncsForValue(valField)
     35 	conv := newMapConverter(ft, fd)
     36 
     37 	mapi := &mapInfo{
     38 		goType:     ft,
     39 		keyWiretag: keyWiretag,
     40 		valWiretag: valWiretag,
     41 		keyFuncs:   keyFuncs,
     42 		valFuncs:   valFuncs,
     43 		keyZero:    keyField.Default(),
     44 		keyKind:    keyField.Kind(),
     45 		conv:       conv,
     46 	}
     47 	if valField.Kind() == protoreflect.MessageKind {
     48 		valueMessage = getMessageInfo(ft.Elem())
     49 	}
     50 
     51 	funcs = pointerCoderFuncs{
     52 		size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
     53 			return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts)
     54 		},
     55 		marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
     56 			return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts)
     57 		},
     58 		unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
     59 			mp := p.AsValueOf(ft)
     60 			if mp.Elem().IsNil() {
     61 				mp.Elem().Set(reflect.MakeMap(mapi.goType))
     62 			}
     63 			if f.mi == nil {
     64 				return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts)
     65 			} else {
     66 				return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts)
     67 			}
     68 		},
     69 	}
     70 	switch valField.Kind() {
     71 	case protoreflect.MessageKind:
     72 		funcs.merge = mergeMapOfMessage
     73 	case protoreflect.BytesKind:
     74 		funcs.merge = mergeMapOfBytes
     75 	default:
     76 		funcs.merge = mergeMap
     77 	}
     78 	if valFuncs.isInit != nil {
     79 		funcs.isInit = func(p pointer, f *coderFieldInfo) error {
     80 			return isInitMap(p.AsValueOf(ft).Elem(), mapi, f)
     81 		}
     82 	}
     83 	return valueMessage, funcs
     84 }
     85 
     86 const (
     87 	mapKeyTagSize = 1 // field 1, tag size 1.
     88 	mapValTagSize = 1 // field 2, tag size 2.
     89 )
     90 
     91 func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int {
     92 	if mapv.Len() == 0 {
     93 		return 0
     94 	}
     95 	n := 0
     96 	iter := mapRange(mapv)
     97 	for iter.Next() {
     98 		key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
     99 		keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
    100 		var valSize int
    101 		value := mapi.conv.valConv.PBValueOf(iter.Value())
    102 		if f.mi == nil {
    103 			valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
    104 		} else {
    105 			p := pointerOfValue(iter.Value())
    106 			valSize += mapValTagSize
    107 			valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts))
    108 		}
    109 		n += f.tagsize + protowire.SizeBytes(keySize+valSize)
    110 	}
    111 	return n
    112 }
    113 
    114 func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
    115 	if wtyp != protowire.BytesType {
    116 		return out, errUnknown
    117 	}
    118 	b, n := protowire.ConsumeBytes(b)
    119 	if n < 0 {
    120 		return out, errDecode
    121 	}
    122 	var (
    123 		key = mapi.keyZero
    124 		val = mapi.conv.valConv.New()
    125 	)
    126 	for len(b) > 0 {
    127 		num, wtyp, n := protowire.ConsumeTag(b)
    128 		if n < 0 {
    129 			return out, errDecode
    130 		}
    131 		if num > protowire.MaxValidNumber {
    132 			return out, errDecode
    133 		}
    134 		b = b[n:]
    135 		err := errUnknown
    136 		switch num {
    137 		case genid.MapEntry_Key_field_number:
    138 			var v protoreflect.Value
    139 			var o unmarshalOutput
    140 			v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
    141 			if err != nil {
    142 				break
    143 			}
    144 			key = v
    145 			n = o.n
    146 		case genid.MapEntry_Value_field_number:
    147 			var v protoreflect.Value
    148 			var o unmarshalOutput
    149 			v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
    150 			if err != nil {
    151 				break
    152 			}
    153 			val = v
    154 			n = o.n
    155 		}
    156 		if err == errUnknown {
    157 			n = protowire.ConsumeFieldValue(num, wtyp, b)
    158 			if n < 0 {
    159 				return out, errDecode
    160 			}
    161 		} else if err != nil {
    162 			return out, err
    163 		}
    164 		b = b[n:]
    165 	}
    166 	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
    167 	out.n = n
    168 	return out, nil
    169 }
    170 
    171 func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
    172 	if wtyp != protowire.BytesType {
    173 		return out, errUnknown
    174 	}
    175 	b, n := protowire.ConsumeBytes(b)
    176 	if n < 0 {
    177 		return out, errDecode
    178 	}
    179 	var (
    180 		key = mapi.keyZero
    181 		val = reflect.New(f.mi.GoReflectType.Elem())
    182 	)
    183 	for len(b) > 0 {
    184 		num, wtyp, n := protowire.ConsumeTag(b)
    185 		if n < 0 {
    186 			return out, errDecode
    187 		}
    188 		if num > protowire.MaxValidNumber {
    189 			return out, errDecode
    190 		}
    191 		b = b[n:]
    192 		err := errUnknown
    193 		switch num {
    194 		case 1:
    195 			var v protoreflect.Value
    196 			var o unmarshalOutput
    197 			v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
    198 			if err != nil {
    199 				break
    200 			}
    201 			key = v
    202 			n = o.n
    203 		case 2:
    204 			if wtyp != protowire.BytesType {
    205 				break
    206 			}
    207 			var v []byte
    208 			v, n = protowire.ConsumeBytes(b)
    209 			if n < 0 {
    210 				return out, errDecode
    211 			}
    212 			var o unmarshalOutput
    213 			o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
    214 			if o.initialized {
    215 				// Consider this map item initialized so long as we see
    216 				// an initialized value.
    217 				out.initialized = true
    218 			}
    219 		}
    220 		if err == errUnknown {
    221 			n = protowire.ConsumeFieldValue(num, wtyp, b)
    222 			if n < 0 {
    223 				return out, errDecode
    224 			}
    225 		} else if err != nil {
    226 			return out, err
    227 		}
    228 		b = b[n:]
    229 	}
    230 	mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
    231 	out.n = n
    232 	return out, nil
    233 }
    234 
    235 func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
    236 	if f.mi == nil {
    237 		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
    238 		val := mapi.conv.valConv.PBValueOf(valrv)
    239 		size := 0
    240 		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
    241 		size += mapi.valFuncs.size(val, mapValTagSize, opts)
    242 		b = protowire.AppendVarint(b, uint64(size))
    243 		b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
    244 		if err != nil {
    245 			return nil, err
    246 		}
    247 		return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
    248 	} else {
    249 		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
    250 		val := pointerOfValue(valrv)
    251 		valSize := f.mi.sizePointer(val, opts)
    252 		size := 0
    253 		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
    254 		size += mapValTagSize + protowire.SizeBytes(valSize)
    255 		b = protowire.AppendVarint(b, uint64(size))
    256 		b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
    257 		if err != nil {
    258 			return nil, err
    259 		}
    260 		b = protowire.AppendVarint(b, mapi.valWiretag)
    261 		b = protowire.AppendVarint(b, uint64(valSize))
    262 		return f.mi.marshalAppendPointer(b, val, opts)
    263 	}
    264 }
    265 
    266 func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
    267 	if mapv.Len() == 0 {
    268 		return b, nil
    269 	}
    270 	if opts.Deterministic() {
    271 		return appendMapDeterministic(b, mapv, mapi, f, opts)
    272 	}
    273 	iter := mapRange(mapv)
    274 	for iter.Next() {
    275 		var err error
    276 		b = protowire.AppendVarint(b, f.wiretag)
    277 		b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts)
    278 		if err != nil {
    279 			return b, err
    280 		}
    281 	}
    282 	return b, nil
    283 }
    284 
    285 func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
    286 	keys := mapv.MapKeys()
    287 	sort.Slice(keys, func(i, j int) bool {
    288 		switch keys[i].Kind() {
    289 		case reflect.Bool:
    290 			return !keys[i].Bool() && keys[j].Bool()
    291 		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    292 			return keys[i].Int() < keys[j].Int()
    293 		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
    294 			return keys[i].Uint() < keys[j].Uint()
    295 		case reflect.Float32, reflect.Float64:
    296 			return keys[i].Float() < keys[j].Float()
    297 		case reflect.String:
    298 			return keys[i].String() < keys[j].String()
    299 		default:
    300 			panic("invalid kind: " + keys[i].Kind().String())
    301 		}
    302 	})
    303 	for _, key := range keys {
    304 		var err error
    305 		b = protowire.AppendVarint(b, f.wiretag)
    306 		b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts)
    307 		if err != nil {
    308 			return b, err
    309 		}
    310 	}
    311 	return b, nil
    312 }
    313 
    314 func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error {
    315 	if mi := f.mi; mi != nil {
    316 		mi.init()
    317 		if !mi.needsInitCheck {
    318 			return nil
    319 		}
    320 		iter := mapRange(mapv)
    321 		for iter.Next() {
    322 			val := pointerOfValue(iter.Value())
    323 			if err := mi.checkInitializedPointer(val); err != nil {
    324 				return err
    325 			}
    326 		}
    327 	} else {
    328 		iter := mapRange(mapv)
    329 		for iter.Next() {
    330 			val := mapi.conv.valConv.PBValueOf(iter.Value())
    331 			if err := mapi.valFuncs.isInit(val); err != nil {
    332 				return err
    333 			}
    334 		}
    335 	}
    336 	return nil
    337 }
    338 
    339 func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
    340 	dstm := dst.AsValueOf(f.ft).Elem()
    341 	srcm := src.AsValueOf(f.ft).Elem()
    342 	if srcm.Len() == 0 {
    343 		return
    344 	}
    345 	if dstm.IsNil() {
    346 		dstm.Set(reflect.MakeMap(f.ft))
    347 	}
    348 	iter := mapRange(srcm)
    349 	for iter.Next() {
    350 		dstm.SetMapIndex(iter.Key(), iter.Value())
    351 	}
    352 }
    353 
    354 func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
    355 	dstm := dst.AsValueOf(f.ft).Elem()
    356 	srcm := src.AsValueOf(f.ft).Elem()
    357 	if srcm.Len() == 0 {
    358 		return
    359 	}
    360 	if dstm.IsNil() {
    361 		dstm.Set(reflect.MakeMap(f.ft))
    362 	}
    363 	iter := mapRange(srcm)
    364 	for iter.Next() {
    365 		dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...)))
    366 	}
    367 }
    368 
    369 func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
    370 	dstm := dst.AsValueOf(f.ft).Elem()
    371 	srcm := src.AsValueOf(f.ft).Elem()
    372 	if srcm.Len() == 0 {
    373 		return
    374 	}
    375 	if dstm.IsNil() {
    376 		dstm.Set(reflect.MakeMap(f.ft))
    377 	}
    378 	iter := mapRange(srcm)
    379 	for iter.Next() {
    380 		val := reflect.New(f.ft.Elem().Elem())
    381 		if f.mi != nil {
    382 			f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts)
    383 		} else {
    384 			opts.Merge(asMessage(val), asMessage(iter.Value()))
    385 		}
    386 		dstm.SetMapIndex(iter.Key(), val)
    387 	}
    388 }