gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

decode.go (7615B)


      1 // Copyright 2019 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package impl
      6 
      7 import (
      8 	"math/bits"
      9 
     10 	"google.golang.org/protobuf/encoding/protowire"
     11 	"google.golang.org/protobuf/internal/errors"
     12 	"google.golang.org/protobuf/internal/flags"
     13 	"google.golang.org/protobuf/proto"
     14 	"google.golang.org/protobuf/reflect/protoreflect"
     15 	"google.golang.org/protobuf/reflect/protoregistry"
     16 	"google.golang.org/protobuf/runtime/protoiface"
     17 )
     18 
     19 var errDecode = errors.New("cannot parse invalid wire-format data")
     20 var errRecursionDepth = errors.New("exceeded maximum recursion depth")
     21 
// unmarshalOptions carries the decoding settings threaded through the
// fast-path unmarshal functions in this file. It is passed by value, so
// changes a callee makes (such as decrementing depth) do not affect the caller.
type unmarshalOptions struct {
	// flags holds the caller's protoiface input flags; see DiscardUnknown.
	flags    protoiface.UnmarshalInputFlags
	// resolver looks up extension types (e.g. by extended message name and
	// field number when decoding an extension) and is forwarded as the
	// Resolver of the proto.UnmarshalOptions built by Options.
	resolver interface {
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
	}
	// depth is the remaining nesting budget. unmarshalPointer decrements it
	// on entry and fails with errRecursionDepth once it drops below zero.
	depth int
}
     30 
     31 func (o unmarshalOptions) Options() proto.UnmarshalOptions {
     32 	return proto.UnmarshalOptions{
     33 		Merge:          true,
     34 		AllowPartial:   true,
     35 		DiscardUnknown: o.DiscardUnknown(),
     36 		Resolver:       o.resolver,
     37 	}
     38 }
     39 
     40 func (o unmarshalOptions) DiscardUnknown() bool {
     41 	return o.flags&protoiface.UnmarshalDiscardUnknown != 0
     42 }
     43 
     44 func (o unmarshalOptions) IsDefault() bool {
     45 	return o.flags == 0 && o.resolver == protoregistry.GlobalTypes
     46 }
     47 
     48 var lazyUnmarshalOptions = unmarshalOptions{
     49 	resolver: protoregistry.GlobalTypes,
     50 	depth:    protowire.DefaultRecursionLimit,
     51 }
     52 
// unmarshalOutput is the result reported by the internal unmarshal functions.
type unmarshalOutput struct {
	n           int  // number of bytes consumed
	initialized bool // set when decoding found no missing required fields
}
     57 
     58 // unmarshal is protoreflect.Methods.Unmarshal.
     59 func (mi *MessageInfo) unmarshal(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
     60 	var p pointer
     61 	if ms, ok := in.Message.(*messageState); ok {
     62 		p = ms.pointer()
     63 	} else {
     64 		p = in.Message.(*messageReflectWrapper).pointer()
     65 	}
     66 	out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
     67 		flags:    in.Flags,
     68 		resolver: in.Resolver,
     69 		depth:    in.Depth,
     70 	})
     71 	var flags protoiface.UnmarshalOutputFlags
     72 	if out.initialized {
     73 		flags |= protoiface.UnmarshalInitialized
     74 	}
     75 	return protoiface.UnmarshalOutput{
     76 		Flags: flags,
     77 	}, err
     78 }
     79 
     80 // errUnknown is returned during unmarshaling to indicate a parse error that
     81 // should result in a field being placed in the unknown fields section (for example,
     82 // when the wire type doesn't match) as opposed to the entire unmarshal operation
     83 // failing (for example, when a field extends past the available input).
     84 //
     85 // This is a sentinel error which should never be visible to the user.
     86 var errUnknown = errors.New("unknown")
     87 
// unmarshalPointer decodes the wire-format bytes in b into the message at p,
// using mi's field coders. Nothing here resets p first, so decoded fields are
// applied on top of whatever p already holds.
//
// groupTag is non-zero when b is the body of a group with that field number.
// Decoding then stops right after the matching EndGroup tag, and running out
// of input before that tag is errDecode. An EndGroup tag whose number differs
// from groupTag is errDecode. NOTE(review): with groupTag zero this
// presumably makes every EndGroup an error, because numbers below
// protowire.MinValidNumber are rejected first.
//
// Each call spends one unit of opts.depth and fails with errRecursionDepth
// once the budget drops below zero. With flags.ProtoLegacy enabled,
// message-set types are handed off to unmarshalMessageSet.
//
// On success:
//   - out.n is the number of bytes consumed from b, including a terminating
//     EndGroup tag if there is one;
//   - out.initialized is true unless a required field was missing or a
//     decoded field reported itself uninitialized.
//
// Some fields are not decoded as known fields or extensions: those with no
// coder, no unmarshal func, an errUnknown result, or no extension storage on
// the message. These are skipped with protowire.ConsumeFieldValue. Unless
// opts.DiscardUnknown() is set or mi has no unknown-field storage, they are
// appended (re-encoded tag plus raw value) to p's unknown bytes. Any other
// decoder error aborts decoding and is returned unchanged.
func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
	mi.init()
	// opts is a by-value copy, so this decrement only affects this level
	// and the callees it passes opts to.
	opts.depth--
	if opts.depth < 0 {
		return out, errRecursionDepth
	}
	if flags.ProtoLegacy && mi.isMessageSet {
		return unmarshalMessageSet(mi, b, p, opts)
	}
	initialized := true
	// requiredMask ORs together the requiredBit of every successfully
	// decoded known field; it is checked against numRequiredFields below.
	var requiredMask uint64
	// exts is fetched from the message when the first extension is seen.
	var exts *map[int32]ExtensionField
	start := len(b)
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		// One- and two-byte tag varints are decoded inline; longer ones go
		// through protowire.ConsumeVarint.
		var tag uint64
		if b[0] < 0x80 {
			tag = uint64(b[0])
			b = b[1:]
		} else if len(b) >= 2 && b[1] < 128 {
			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
			b = b[2:]
		} else {
			var n int
			tag, n = protowire.ConsumeVarint(b)
			if n < 0 {
				return out, errDecode
			}
			b = b[n:]
		}
		// The upper bits are the field number; it must be in the valid range.
		var num protowire.Number
		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
			return out, errDecode
		} else {
			num = protowire.Number(n)
		}
		// The low three bits are the wire type.
		wtyp := protowire.Type(tag & 7)

		if wtyp == protowire.EndGroupType {
			// Only the EndGroup matching the enclosing group is legal; it
			// ends this message's input.
			if num != groupTag {
				return out, errDecode
			}
			groupTag = 0
			break
		}

		// Small field numbers index the dense slice; larger ones use the
		// map. A nil result means this is not a known field.
		var f *coderFieldInfo
		if int(num) < len(mi.denseCoderFields) {
			f = mi.denseCoderFields[num]
		} else {
			f = mi.coderFields[num]
		}
		var n int
		// Assume "unknown" until a decoder succeeds or reports a real error.
		err := errUnknown
		switch {
		case f != nil:
			if f.funcs.unmarshal == nil {
				// No decoder for this field: keep err == errUnknown.
				break
			}
			var o unmarshalOutput
			o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
			n = o.n
			if err != nil {
				break
			}
			requiredMask |= f.validation.requiredBit
			if f.funcs.isInit != nil && !o.initialized {
				initialized = false
			}
		default:
			// Possible extension.
			// The message's extension map is allocated on first use.
			if exts == nil && mi.extensionOffset.IsValid() {
				exts = p.Apply(mi.extensionOffset).Extensions()
				if *exts == nil {
					*exts = make(map[int32]ExtensionField)
				}
			}
			if exts == nil {
				// Message has no extension storage: treat as unknown.
				break
			}
			var o unmarshalOutput
			o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
			if err != nil {
				break
			}
			n = o.n
			if !o.initialized {
				initialized = false
			}
		}
		if err != nil {
			if err != errUnknown {
				return out, err
			}
			// Skip the value; the tag was already consumed from b above.
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return out, errDecode
			}
			if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
				// Store the tag (re-encoded) followed by the raw value bytes.
				u := mi.mutableUnknownBytes(p)
				*u = protowire.AppendTag(*u, num, wtyp)
				*u = append(*u, b[:n]...)
			}
		}
		b = b[n:]
	}
	// Input ended while still inside a group (no EndGroup seen).
	if groupTag != 0 {
		return out, errDecode
	}
	// Compare the number of distinct required bits seen with the number of
	// required fields. NOTE(review): assumes each required field has its own
	// requiredBit — confirm against where requiredBit is assigned.
	if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
		initialized = false
	}
	if initialized {
		out.initialized = true
	}
	out.n = start - len(b)
	return out, nil
}
    206 
    207 func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
    208 	x := exts[int32(num)]
    209 	xt := x.Type()
    210 	if xt == nil {
    211 		var err error
    212 		xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
    213 		if err != nil {
    214 			if err == protoregistry.NotFound {
    215 				return out, errUnknown
    216 			}
    217 			return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err)
    218 		}
    219 	}
    220 	xi := getExtensionFieldInfo(xt)
    221 	if xi.funcs.unmarshal == nil {
    222 		return out, errUnknown
    223 	}
    224 	if flags.LazyUnmarshalExtensions {
    225 		if opts.IsDefault() && x.canLazy(xt) {
    226 			out, valid := skipExtension(b, xi, num, wtyp, opts)
    227 			switch valid {
    228 			case ValidationValid:
    229 				if out.initialized {
    230 					x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
    231 					exts[int32(num)] = x
    232 					return out, nil
    233 				}
    234 			case ValidationInvalid:
    235 				return out, errDecode
    236 			case ValidationUnknown:
    237 			}
    238 		}
    239 	}
    240 	ival := x.Value()
    241 	if !ival.IsValid() && xi.unmarshalNeedsValue {
    242 		// Create a new message, list, or map value to fill in.
    243 		// For enums, create a prototype value to let the unmarshal func know the
    244 		// concrete type.
    245 		ival = xt.New()
    246 	}
    247 	v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
    248 	if err != nil {
    249 		return out, err
    250 	}
    251 	if xi.funcs.isInit == nil {
    252 		out.initialized = true
    253 	}
    254 	x.Set(xt, v)
    255 	exts[int32(num)] = x
    256 	return out, nil
    257 }
    258 
    259 func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
    260 	if xi.validation.mi == nil {
    261 		return out, ValidationUnknown
    262 	}
    263 	xi.validation.mi.init()
    264 	switch xi.validation.typ {
    265 	case validationTypeMessage:
    266 		if wtyp != protowire.BytesType {
    267 			return out, ValidationUnknown
    268 		}
    269 		v, n := protowire.ConsumeBytes(b)
    270 		if n < 0 {
    271 			return out, ValidationUnknown
    272 		}
    273 		out, st := xi.validation.mi.validate(v, 0, opts)
    274 		out.n = n
    275 		return out, st
    276 	case validationTypeGroup:
    277 		if wtyp != protowire.StartGroupType {
    278 			return out, ValidationUnknown
    279 		}
    280 		out, st := xi.validation.mi.validate(b, num, opts)
    281 		return out, st
    282 	default:
    283 		return out, ValidationUnknown
    284 	}
    285 }