decode.go (8732B)
1 // Copyright 2018 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package proto 6 7 import ( 8 "google.golang.org/protobuf/encoding/protowire" 9 "google.golang.org/protobuf/internal/encoding/messageset" 10 "google.golang.org/protobuf/internal/errors" 11 "google.golang.org/protobuf/internal/flags" 12 "google.golang.org/protobuf/internal/genid" 13 "google.golang.org/protobuf/internal/pragma" 14 "google.golang.org/protobuf/reflect/protoreflect" 15 "google.golang.org/protobuf/reflect/protoregistry" 16 "google.golang.org/protobuf/runtime/protoiface" 17 ) 18 19 // UnmarshalOptions configures the unmarshaler. 20 // 21 // Example usage: 22 // 23 // err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m) 24 type UnmarshalOptions struct { 25 pragma.NoUnkeyedLiterals 26 27 // Merge merges the input into the destination message. 28 // The default behavior is to always reset the message before unmarshaling, 29 // unless Merge is specified. 30 Merge bool 31 32 // AllowPartial accepts input for messages that will result in missing 33 // required fields. If AllowPartial is false (the default), Unmarshal will 34 // return an error if there are any missing required fields. 35 AllowPartial bool 36 37 // If DiscardUnknown is set, unknown fields are ignored. 38 DiscardUnknown bool 39 40 // Resolver is used for looking up types when unmarshaling extension fields. 41 // If nil, this defaults to using protoregistry.GlobalTypes. 42 Resolver interface { 43 FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) 44 FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) 45 } 46 47 // RecursionLimit limits how deeply messages may be nested. 48 // If zero, a default limit is applied. 49 RecursionLimit int 50 } 51 52 // Unmarshal parses the wire-format message in b and places the result in m. 53 // The provided message must be mutable (e.g., a non-nil pointer to a message). 54 func Unmarshal(b []byte, m Message) error { 55 _, err := UnmarshalOptions{RecursionLimit: protowire.DefaultRecursionLimit}.unmarshal(b, m.ProtoReflect()) 56 return err 57 } 58 59 // Unmarshal parses the wire-format message in b and places the result in m. 60 // The provided message must be mutable (e.g., a non-nil pointer to a message). 61 func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error { 62 if o.RecursionLimit == 0 { 63 o.RecursionLimit = protowire.DefaultRecursionLimit 64 } 65 _, err := o.unmarshal(b, m.ProtoReflect()) 66 return err 67 } 68 69 // UnmarshalState parses a wire-format message and places the result in m. 70 // 71 // This method permits fine-grained control over the unmarshaler. 72 // Most users should use Unmarshal instead. 73 func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) { 74 if o.RecursionLimit == 0 { 75 o.RecursionLimit = protowire.DefaultRecursionLimit 76 } 77 return o.unmarshal(in.Buf, in.Message) 78 } 79 80 // unmarshal is a centralized function that all unmarshal operations go through. 81 // For profiling purposes, avoid changing the name of this function or 82 // introducing other code paths for unmarshal that do not go through this. 83 func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) { 84 if o.Resolver == nil { 85 o.Resolver = protoregistry.GlobalTypes 86 } 87 if !o.Merge { 88 Reset(m.Interface()) 89 } 90 allowPartial := o.AllowPartial 91 o.Merge = true 92 o.AllowPartial = true 93 methods := protoMethods(m) 94 if methods != nil && methods.Unmarshal != nil && 95 !(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) { 96 in := protoiface.UnmarshalInput{ 97 Message: m, 98 Buf: b, 99 Resolver: o.Resolver, 100 Depth: o.RecursionLimit, 101 } 102 if o.DiscardUnknown { 103 in.Flags |= protoiface.UnmarshalDiscardUnknown 104 } 105 out, err = methods.Unmarshal(in) 106 } else { 107 o.RecursionLimit-- 108 if o.RecursionLimit < 0 { 109 return out, errors.New("exceeded max recursion depth") 110 } 111 err = o.unmarshalMessageSlow(b, m) 112 } 113 if err != nil { 114 return out, err 115 } 116 if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) { 117 return out, nil 118 } 119 return out, checkInitialized(m) 120 } 121 122 func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error { 123 _, err := o.unmarshal(b, m) 124 return err 125 } 126 127 func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error { 128 md := m.Descriptor() 129 if messageset.IsMessageSet(md) { 130 return o.unmarshalMessageSet(b, m) 131 } 132 fields := md.Fields() 133 for len(b) > 0 { 134 // Parse the tag (field number and wire type). 135 num, wtyp, tagLen := protowire.ConsumeTag(b) 136 if tagLen < 0 { 137 return errDecode 138 } 139 if num > protowire.MaxValidNumber { 140 return errDecode 141 } 142 143 // Find the field descriptor for this field number. 144 fd := fields.ByNumber(num) 145 if fd == nil && md.ExtensionRanges().Has(num) { 146 extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num) 147 if err != nil && err != protoregistry.NotFound { 148 return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err) 149 } 150 if extType != nil { 151 fd = extType.TypeDescriptor() 152 } 153 } 154 var err error 155 if fd == nil { 156 err = errUnknown 157 } else if flags.ProtoLegacy { 158 if fd.IsWeak() && fd.Message().IsPlaceholder() { 159 err = errUnknown // weak referent is not linked in 160 } 161 } 162 163 // Parse the field value. 164 var valLen int 165 switch { 166 case err != nil: 167 case fd.IsList(): 168 valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd) 169 case fd.IsMap(): 170 valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd) 171 default: 172 valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd) 173 } 174 if err != nil { 175 if err != errUnknown { 176 return err 177 } 178 valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:]) 179 if valLen < 0 { 180 return errDecode 181 } 182 if !o.DiscardUnknown { 183 m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...)) 184 } 185 } 186 b = b[tagLen+valLen:] 187 } 188 return nil 189 } 190 191 func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) { 192 v, n, err := o.unmarshalScalar(b, wtyp, fd) 193 if err != nil { 194 return 0, err 195 } 196 switch fd.Kind() { 197 case protoreflect.GroupKind, protoreflect.MessageKind: 198 m2 := m.Mutable(fd).Message() 199 if err := o.unmarshalMessage(v.Bytes(), m2); err != nil { 200 return n, err 201 } 202 default: 203 // Non-message scalars replace the previous value. 204 m.Set(fd, v) 205 } 206 return n, nil 207 } 208 209 func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) { 210 if wtyp != protowire.BytesType { 211 return 0, errUnknown 212 } 213 b, n = protowire.ConsumeBytes(b) 214 if n < 0 { 215 return 0, errDecode 216 } 217 var ( 218 keyField = fd.MapKey() 219 valField = fd.MapValue() 220 key protoreflect.Value 221 val protoreflect.Value 222 haveKey bool 223 haveVal bool 224 ) 225 switch valField.Kind() { 226 case protoreflect.GroupKind, protoreflect.MessageKind: 227 val = mapv.NewValue() 228 } 229 // Map entries are represented as a two-element message with fields 230 // containing the key and value. 231 for len(b) > 0 { 232 num, wtyp, n := protowire.ConsumeTag(b) 233 if n < 0 { 234 return 0, errDecode 235 } 236 if num > protowire.MaxValidNumber { 237 return 0, errDecode 238 } 239 b = b[n:] 240 err = errUnknown 241 switch num { 242 case genid.MapEntry_Key_field_number: 243 key, n, err = o.unmarshalScalar(b, wtyp, keyField) 244 if err != nil { 245 break 246 } 247 haveKey = true 248 case genid.MapEntry_Value_field_number: 249 var v protoreflect.Value 250 v, n, err = o.unmarshalScalar(b, wtyp, valField) 251 if err != nil { 252 break 253 } 254 switch valField.Kind() { 255 case protoreflect.GroupKind, protoreflect.MessageKind: 256 if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil { 257 return 0, err 258 } 259 default: 260 val = v 261 } 262 haveVal = true 263 } 264 if err == errUnknown { 265 n = protowire.ConsumeFieldValue(num, wtyp, b) 266 if n < 0 { 267 return 0, errDecode 268 } 269 } else if err != nil { 270 return 0, err 271 } 272 b = b[n:] 273 } 274 // Every map entry should have entries for key and value, but this is not strictly required. 275 if !haveKey { 276 key = keyField.Default() 277 } 278 if !haveVal { 279 switch valField.Kind() { 280 case protoreflect.GroupKind, protoreflect.MessageKind: 281 default: 282 val = valField.Default() 283 } 284 } 285 mapv.Set(key.MapKey(), val) 286 return n, nil 287 } 288 289 // errUnknown is used internally to indicate fields which should be added 290 // to the unknown field set of a message. It is never returned from an exported 291 // function. 292 var errUnknown = errors.New("BUG: internal error (unknown)") 293 294 var errDecode = errors.New("cannot parse invalid wire-format data")