gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

decode.go (8732B)


      1 // Copyright 2018 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package proto
      6 
      7 import (
      8 	"google.golang.org/protobuf/encoding/protowire"
      9 	"google.golang.org/protobuf/internal/encoding/messageset"
     10 	"google.golang.org/protobuf/internal/errors"
     11 	"google.golang.org/protobuf/internal/flags"
     12 	"google.golang.org/protobuf/internal/genid"
     13 	"google.golang.org/protobuf/internal/pragma"
     14 	"google.golang.org/protobuf/reflect/protoreflect"
     15 	"google.golang.org/protobuf/reflect/protoregistry"
     16 	"google.golang.org/protobuf/runtime/protoiface"
     17 )
     18 
// UnmarshalOptions configures the unmarshaler.
//
// The zero value is usable: the destination is reset before decoding,
// missing required fields are reported as an error, unknown fields are kept,
// extensions are resolved through protoregistry.GlobalTypes, and the exported
// Unmarshal/UnmarshalState methods apply protowire.DefaultRecursionLimit.
//
// Example usage:
//
//	err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
type UnmarshalOptions struct {
	// Embedded marker from internal/pragma; presumably it forbids unkeyed
	// composite literals so that fields can be added without breaking callers.
	pragma.NoUnkeyedLiterals

	// Merge merges the input into the destination message.
	// The default behavior is to always reset the message before unmarshaling,
	// unless Merge is specified.
	Merge bool

	// AllowPartial accepts input for messages that will result in missing
	// required fields. If AllowPartial is false (the default), Unmarshal will
	// return an error if there are any missing required fields.
	AllowPartial bool

	// If DiscardUnknown is set, unknown fields are ignored.
	// Otherwise their raw wire bytes are appended to the message's
	// unknown-field set.
	DiscardUnknown bool

	// Resolver is used for looking up types when unmarshaling extension fields.
	// If nil, this defaults to using protoregistry.GlobalTypes.
	Resolver interface {
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
	}

	// RecursionLimit limits how deeply messages may be nested.
	// If zero, a default limit (protowire.DefaultRecursionLimit) is applied
	// by the exported Unmarshal and UnmarshalState methods.
	RecursionLimit int
}
     51 
     52 // Unmarshal parses the wire-format message in b and places the result in m.
     53 // The provided message must be mutable (e.g., a non-nil pointer to a message).
     54 func Unmarshal(b []byte, m Message) error {
     55 	_, err := UnmarshalOptions{RecursionLimit: protowire.DefaultRecursionLimit}.unmarshal(b, m.ProtoReflect())
     56 	return err
     57 }
     58 
     59 // Unmarshal parses the wire-format message in b and places the result in m.
     60 // The provided message must be mutable (e.g., a non-nil pointer to a message).
     61 func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
     62 	if o.RecursionLimit == 0 {
     63 		o.RecursionLimit = protowire.DefaultRecursionLimit
     64 	}
     65 	_, err := o.unmarshal(b, m.ProtoReflect())
     66 	return err
     67 }
     68 
     69 // UnmarshalState parses a wire-format message and places the result in m.
     70 //
     71 // This method permits fine-grained control over the unmarshaler.
     72 // Most users should use Unmarshal instead.
     73 func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
     74 	if o.RecursionLimit == 0 {
     75 		o.RecursionLimit = protowire.DefaultRecursionLimit
     76 	}
     77 	return o.unmarshal(in.Buf, in.Message)
     78 }
     79 
// unmarshal is a centralized function that all unmarshal operations go through.
// For profiling purposes, avoid changing the name of this function or
// introducing other code paths for unmarshal that do not go through this.
//
// It decodes the wire-format bytes b into m. A nil Resolver is replaced with
// protoregistry.GlobalTypes, and m is reset first unless o.Merge is set.
// Decoding uses the message's fast-path Unmarshal method when one exists
// (and, if DiscardUnknown is requested, advertises support for it);
// otherwise it falls back to unmarshalMessageSlow, consuming one level of
// o.RecursionLimit. Unless o.AllowPartial is set, required fields are then
// checked with checkInitialized, except when the fast path already reports
// the message as initialized.
//
// This function does not substitute a default for a zero RecursionLimit;
// the exported entry points do that before calling it.
func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) {
	if o.Resolver == nil {
		o.Resolver = protoregistry.GlobalTypes
	}
	if !o.Merge {
		Reset(m.Interface())
	}
	// Remember the caller's AllowPartial before overriding it below.
	allowPartial := o.AllowPartial
	// Nested messages are decoded with this same o (a by-value copy): they must
	// merge into the already-reset parent rather than reset again, and
	// required-field checking is done once, at the end of this call.
	o.Merge = true
	o.AllowPartial = true
	methods := protoMethods(m)
	// Use the fast path unless DiscardUnknown is requested and the fast path
	// does not declare support for it.
	if methods != nil && methods.Unmarshal != nil &&
		!(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
		in := protoiface.UnmarshalInput{
			Message:  m,
			Buf:      b,
			Resolver: o.Resolver,
			// NOTE(review): depth enforcement is presumably done by the
			// fast-path implementation itself; it is not decremented here.
			Depth: o.RecursionLimit,
		}
		if o.DiscardUnknown {
			in.Flags |= protoiface.UnmarshalDiscardUnknown
		}
		out, err = methods.Unmarshal(in)
	} else {
		// The slow path recurses through unmarshalMessage -> unmarshal, so
		// each nesting level consumes one unit of the limit.
		o.RecursionLimit--
		if o.RecursionLimit < 0 {
			return out, errors.New("exceeded max recursion depth")
		}
		err = o.unmarshalMessageSlow(b, m)
	}
	if err != nil {
		return out, err
	}
	if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) {
		return out, nil
	}
	return out, checkInitialized(m)
}
    121 
    122 func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
    123 	_, err := o.unmarshal(b, m)
    124 	return err
    125 }
    126 
    127 func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
    128 	md := m.Descriptor()
    129 	if messageset.IsMessageSet(md) {
    130 		return o.unmarshalMessageSet(b, m)
    131 	}
    132 	fields := md.Fields()
    133 	for len(b) > 0 {
    134 		// Parse the tag (field number and wire type).
    135 		num, wtyp, tagLen := protowire.ConsumeTag(b)
    136 		if tagLen < 0 {
    137 			return errDecode
    138 		}
    139 		if num > protowire.MaxValidNumber {
    140 			return errDecode
    141 		}
    142 
    143 		// Find the field descriptor for this field number.
    144 		fd := fields.ByNumber(num)
    145 		if fd == nil && md.ExtensionRanges().Has(num) {
    146 			extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
    147 			if err != nil && err != protoregistry.NotFound {
    148 				return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err)
    149 			}
    150 			if extType != nil {
    151 				fd = extType.TypeDescriptor()
    152 			}
    153 		}
    154 		var err error
    155 		if fd == nil {
    156 			err = errUnknown
    157 		} else if flags.ProtoLegacy {
    158 			if fd.IsWeak() && fd.Message().IsPlaceholder() {
    159 				err = errUnknown // weak referent is not linked in
    160 			}
    161 		}
    162 
    163 		// Parse the field value.
    164 		var valLen int
    165 		switch {
    166 		case err != nil:
    167 		case fd.IsList():
    168 			valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
    169 		case fd.IsMap():
    170 			valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd)
    171 		default:
    172 			valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
    173 		}
    174 		if err != nil {
    175 			if err != errUnknown {
    176 				return err
    177 			}
    178 			valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:])
    179 			if valLen < 0 {
    180 				return errDecode
    181 			}
    182 			if !o.DiscardUnknown {
    183 				m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
    184 			}
    185 		}
    186 		b = b[tagLen+valLen:]
    187 	}
    188 	return nil
    189 }
    190 
    191 func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
    192 	v, n, err := o.unmarshalScalar(b, wtyp, fd)
    193 	if err != nil {
    194 		return 0, err
    195 	}
    196 	switch fd.Kind() {
    197 	case protoreflect.GroupKind, protoreflect.MessageKind:
    198 		m2 := m.Mutable(fd).Message()
    199 		if err := o.unmarshalMessage(v.Bytes(), m2); err != nil {
    200 			return n, err
    201 		}
    202 	default:
    203 		// Non-message scalars replace the previous value.
    204 		m.Set(fd, v)
    205 	}
    206 	return n, nil
    207 }
    208 
    209 func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
    210 	if wtyp != protowire.BytesType {
    211 		return 0, errUnknown
    212 	}
    213 	b, n = protowire.ConsumeBytes(b)
    214 	if n < 0 {
    215 		return 0, errDecode
    216 	}
    217 	var (
    218 		keyField = fd.MapKey()
    219 		valField = fd.MapValue()
    220 		key      protoreflect.Value
    221 		val      protoreflect.Value
    222 		haveKey  bool
    223 		haveVal  bool
    224 	)
    225 	switch valField.Kind() {
    226 	case protoreflect.GroupKind, protoreflect.MessageKind:
    227 		val = mapv.NewValue()
    228 	}
    229 	// Map entries are represented as a two-element message with fields
    230 	// containing the key and value.
    231 	for len(b) > 0 {
    232 		num, wtyp, n := protowire.ConsumeTag(b)
    233 		if n < 0 {
    234 			return 0, errDecode
    235 		}
    236 		if num > protowire.MaxValidNumber {
    237 			return 0, errDecode
    238 		}
    239 		b = b[n:]
    240 		err = errUnknown
    241 		switch num {
    242 		case genid.MapEntry_Key_field_number:
    243 			key, n, err = o.unmarshalScalar(b, wtyp, keyField)
    244 			if err != nil {
    245 				break
    246 			}
    247 			haveKey = true
    248 		case genid.MapEntry_Value_field_number:
    249 			var v protoreflect.Value
    250 			v, n, err = o.unmarshalScalar(b, wtyp, valField)
    251 			if err != nil {
    252 				break
    253 			}
    254 			switch valField.Kind() {
    255 			case protoreflect.GroupKind, protoreflect.MessageKind:
    256 				if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
    257 					return 0, err
    258 				}
    259 			default:
    260 				val = v
    261 			}
    262 			haveVal = true
    263 		}
    264 		if err == errUnknown {
    265 			n = protowire.ConsumeFieldValue(num, wtyp, b)
    266 			if n < 0 {
    267 				return 0, errDecode
    268 			}
    269 		} else if err != nil {
    270 			return 0, err
    271 		}
    272 		b = b[n:]
    273 	}
    274 	// Every map entry should have entries for key and value, but this is not strictly required.
    275 	if !haveKey {
    276 		key = keyField.Default()
    277 	}
    278 	if !haveVal {
    279 		switch valField.Kind() {
    280 		case protoreflect.GroupKind, protoreflect.MessageKind:
    281 		default:
    282 			val = valField.Default()
    283 		}
    284 	}
    285 	mapv.Set(key.MapKey(), val)
    286 	return n, nil
    287 }
    288 
    289 // errUnknown is used internally to indicate fields which should be added
    290 // to the unknown field set of a message. It is never returned from an exported
    291 // function.
    292 var errUnknown = errors.New("BUG: internal error (unknown)")
    293 
    294 var errDecode = errors.New("cannot parse invalid wire-format data")