decode.go (7615B)
1 // Copyright 2019 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package impl 6 7 import ( 8 "math/bits" 9 10 "google.golang.org/protobuf/encoding/protowire" 11 "google.golang.org/protobuf/internal/errors" 12 "google.golang.org/protobuf/internal/flags" 13 "google.golang.org/protobuf/proto" 14 "google.golang.org/protobuf/reflect/protoreflect" 15 "google.golang.org/protobuf/reflect/protoregistry" 16 "google.golang.org/protobuf/runtime/protoiface" 17 ) 18 19 var errDecode = errors.New("cannot parse invalid wire-format data") 20 var errRecursionDepth = errors.New("exceeded maximum recursion depth") 21 22 type unmarshalOptions struct { 23 flags protoiface.UnmarshalInputFlags 24 resolver interface { 25 FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) 26 FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) 27 } 28 depth int 29 } 30 31 func (o unmarshalOptions) Options() proto.UnmarshalOptions { 32 return proto.UnmarshalOptions{ 33 Merge: true, 34 AllowPartial: true, 35 DiscardUnknown: o.DiscardUnknown(), 36 Resolver: o.resolver, 37 } 38 } 39 40 func (o unmarshalOptions) DiscardUnknown() bool { 41 return o.flags&protoiface.UnmarshalDiscardUnknown != 0 42 } 43 44 func (o unmarshalOptions) IsDefault() bool { 45 return o.flags == 0 && o.resolver == protoregistry.GlobalTypes 46 } 47 48 var lazyUnmarshalOptions = unmarshalOptions{ 49 resolver: protoregistry.GlobalTypes, 50 depth: protowire.DefaultRecursionLimit, 51 } 52 53 type unmarshalOutput struct { 54 n int // number of bytes consumed 55 initialized bool 56 } 57 58 // unmarshal is protoreflect.Methods.Unmarshal. 
//
// It locates the pointer backing in.Message, decodes in.Buf into it with
// unmarshalPointer (group tag 0, options built from in.Flags, in.Resolver and
// in.Depth), and sets protoiface.UnmarshalInitialized when the decoded
// message was reported as initialized. The error from unmarshalPointer is
// returned as-is alongside the flags.
func (mi *MessageInfo) unmarshal(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
	var p pointer
	if ms, ok := in.Message.(*messageState); ok {
		p = ms.pointer()
	} else {
		// Any message that is not a *messageState must be a
		// *messageReflectWrapper; this unchecked assertion panics otherwise.
		p = in.Message.(*messageReflectWrapper).pointer()
	}
	out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
		flags:    in.Flags,
		resolver: in.Resolver,
		depth:    in.Depth,
	})
	// This local shadows the internal/flags package for the rest of the
	// function; the package is not needed here.
	var flags protoiface.UnmarshalOutputFlags
	if out.initialized {
		flags |= protoiface.UnmarshalInitialized
	}
	return protoiface.UnmarshalOutput{
		Flags: flags,
	}, err
}

// errUnknown is returned during unmarshaling to indicate a parse error that
// should result in a field being placed in the unknown fields section (for example,
// when the wire type doesn't match) as opposed to the entire unmarshal operation
// failing (for example, when a field extends past the available input).
//
// This is a sentinel error which should never be visible to the user.
var errUnknown = errors.New("unknown")

// unmarshalPointer decodes the wire-format bytes in b into the message
// located at p.
//
// groupTag is 0 when b is a complete message (unmarshal passes 0). When it is
// non-zero, b holds the body of a group with that field number. Decoding then
// stops at the matching EndGroup tag, and running out of input before seeing
// it fails with errDecode.
//
// Each call consumes one unit of opts.depth and fails with errRecursionDepth
// once the budget is exhausted. When flags.ProtoLegacy is set and the message
// is a MessageSet, decoding is delegated to unmarshalMessageSet.
//
// Fields that cannot be decoded as a known field or extension are signalled
// internally by errUnknown. They are skipped with protowire.ConsumeFieldValue
// and, unless opts.DiscardUnknown() is set or the message has no
// unknown-field storage, appended (tag plus value) to the unknown-fields
// bytes. Any other error aborts decoding.
//
// On success:
//   - out.n is the number of bytes consumed, including the terminating
//     EndGroup tag when decoding a group.
//   - out.initialized is false if the count of distinct required bits seen
//     differs from mi.numRequiredFields, or if any decoded field with an
//     initialization check (or any extension) reported itself uninitialized.
func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
	mi.init()
	// opts is a by-value copy, so this decrement affects only this call and
	// whatever it passes opts on to.
	opts.depth--
	if opts.depth < 0 {
		return out, errRecursionDepth
	}
	if flags.ProtoLegacy && mi.isMessageSet {
		return unmarshalMessageSet(mi, b, p, opts)
	}
	initialized := true
	var requiredMask uint64
	// exts is resolved on first need: the first field that is not a known
	// field, and only if the message has extension storage.
	var exts *map[int32]ExtensionField
	start := len(b)
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		// Fast paths handle 1- and 2-byte varint tags inline; longer tags go
		// through protowire.ConsumeVarint.
		var tag uint64
		if b[0] < 0x80 {
			tag = uint64(b[0])
			b = b[1:]
		} else if len(b) >= 2 && b[1] < 128 {
			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
			b = b[2:]
		} else {
			var n int
			tag, n = protowire.ConsumeVarint(b)
			if n < 0 {
				return out, errDecode
			}
			b = b[n:]
		}
		var num protowire.Number
		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
			return out, errDecode
		} else {
			num = protowire.Number(n)
		}
		wtyp := protowire.Type(tag & 7)

		if wtyp == protowire.EndGroupType {
			// An EndGroup tag is only valid as the terminator of the group
			// currently being decoded (groupTag 0 never matches, since num
			// was validated to be >= MinValidNumber above).
			if num != groupTag {
				return out, errDecode
			}
			groupTag = 0
			break
		}

		// Look up the field's coder: low field numbers index a dense slice,
		// the rest fall back to a map.
		var f *coderFieldInfo
		if int(num) < len(mi.denseCoderFields) {
			f = mi.denseCoderFields[num]
		} else {
			f = mi.coderFields[num]
		}
		var n int
		// This err shadows the named result. It starts as errUnknown, so any
		// path below that does not decode the field (no unmarshal func, no
		// extension storage, unresolved extension, wire-type mismatch) routes
		// the field to the unknown-fields handling after the switch.
		err := errUnknown
		switch {
		case f != nil:
			if f.funcs.unmarshal == nil {
				break
			}
			var o unmarshalOutput
			o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
			n = o.n
			if err != nil {
				break
			}
			// NOTE(review): the OnesCount64 check after the loop assumes each
			// required field owns a distinct single bit here — confirm.
			requiredMask |= f.validation.requiredBit
			if f.funcs.isInit != nil && !o.initialized {
				initialized = false
			}
		default:
			// Possible extension.
			if exts == nil && mi.extensionOffset.IsValid() {
				exts = p.Apply(mi.extensionOffset).Extensions()
				if *exts == nil {
					*exts = make(map[int32]ExtensionField)
				}
			}
			if exts == nil {
				break
			}
			var o unmarshalOutput
			o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
			if err != nil {
				break
			}
			n = o.n
			if !o.initialized {
				initialized = false
			}
		}
		if err != nil {
			if err != errUnknown {
				return out, err
			}
			// Skip the field value; if even that fails, the input as a whole
			// is malformed.
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return out, errDecode
			}
			if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
				// Re-encode the already-consumed tag, then copy the raw value.
				u := mi.mutableUnknownBytes(p)
				*u = protowire.AppendTag(*u, num, wtyp)
				*u = append(*u, b[:n]...)
			}
		}
		b = b[n:]
	}
	// Input ended while still inside a group (no matching EndGroup tag).
	if groupTag != 0 {
		return out, errDecode
	}
	if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
		initialized = false
	}
	if initialized {
		out.initialized = true
	}
	out.n = start - len(b)
	return out, nil
}

// unmarshalExtension decodes one occurrence of extension field num (wire type
// wtyp, value starting at b) into exts, the message's extension map.
//
// If the entry for num has no type yet, the type is resolved through
// opts.resolver. protoregistry.NotFound yields errUnknown, so the caller keeps
// the field as unknown bytes. Any other resolver error is returned wrapped
// with the message name and field number. errUnknown is also returned when the
// extension's coder has no unmarshal function.
//
// When flags.LazyUnmarshalExtensions is enabled, opts.IsDefault() holds, and
// x.canLazy(xt) allows it, the value is first checked with skipExtension:
//   - a valid, initialized value is stored undecoded via appendLazyBytes;
//   - an invalid value fails with errDecode;
//   - anything else falls through to eager decoding.
//
// The entry is read from the map by value and written back after it is
// modified. out.n reports the bytes consumed, and out.initialized is forced
// true for extensions without an initialization check.
func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
	x := exts[int32(num)]
	xt := x.Type()
	if xt == nil {
		var err error
		xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
		if err != nil {
			if err == protoregistry.NotFound {
				return out, errUnknown
			}
			return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err)
		}
	}
	xi := getExtensionFieldInfo(xt)
	if xi.funcs.unmarshal == nil {
		return out, errUnknown
	}
	if flags.LazyUnmarshalExtensions {
		if opts.IsDefault() && x.canLazy(xt) {
			// This out shadows the named result within the lazy path only.
			out, valid := skipExtension(b, xi, num, wtyp, opts)
			switch valid {
			case ValidationValid:
				if out.initialized {
					x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
					exts[int32(num)] = x
					return out, nil
				}
				// Valid but not initialized: fall through to eager decoding.
			case ValidationInvalid:
				return out, errDecode
			case ValidationUnknown:
				// Validation could not decide: fall through to eager decoding.
			}
		}
	}
	ival := x.Value()
	if !ival.IsValid() && xi.unmarshalNeedsValue {
		// Create a new message, list, or map value to fill in.
		// For enums, create a prototype value to let the unmarshal func know the
		// concrete type.
		ival = xt.New()
	}
	// Only v is new here; out and err assign the named results.
	v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
	if err != nil {
		return out, err
	}
	if xi.funcs.isInit == nil {
		out.initialized = true
	}
	x.Set(xt, v)
	exts[int32(num)] = x
	return out, nil
}

// skipExtension checks the encoded extension value at b without decoding it.
// unmarshalExtension uses it for the lazy-extension path.
//
// It returns ValidationUnknown whenever it cannot decide:
//   - there is no validation info (xi.validation.mi is nil);
//   - the wire type does not match the message/group kind;
//   - the length prefix is malformed, in which case the eager path in the
//     caller reports the error;
//   - the validation type is anything else.
//
// Otherwise it returns the result of xi.validation.mi.validate. For
// length-delimited messages, out.n is replaced by the total number of bytes
// protowire.ConsumeBytes consumed. For groups, validate runs over b with num as
// the group number, and its out.n is returned unchanged. The caller only uses
// out.n when the status is ValidationValid.
func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
	if xi.validation.mi == nil {
		return out, ValidationUnknown
	}
	xi.validation.mi.init()
	switch xi.validation.typ {
	case validationTypeMessage:
		if wtyp != protowire.BytesType {
			return out, ValidationUnknown
		}
		v, n := protowire.ConsumeBytes(b)
		if n < 0 {
			return out, ValidationUnknown
		}
		out, st := xi.validation.mi.validate(v, 0, opts)
		out.n = n
		return out, st
	case validationTypeGroup:
		if wtyp != protowire.StartGroupType {
			return out, ValidationUnknown
		}
		out, st := xi.validation.mi.validate(b, num, opts)
		return out, st
	default:
		return out, ValidationUnknown
	}
}