validate.go (15488B)
1 // Copyright 2019 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package impl 6 7 import ( 8 "fmt" 9 "math" 10 "math/bits" 11 "reflect" 12 "unicode/utf8" 13 14 "google.golang.org/protobuf/encoding/protowire" 15 "google.golang.org/protobuf/internal/encoding/messageset" 16 "google.golang.org/protobuf/internal/flags" 17 "google.golang.org/protobuf/internal/genid" 18 "google.golang.org/protobuf/internal/strs" 19 "google.golang.org/protobuf/reflect/protoreflect" 20 "google.golang.org/protobuf/reflect/protoregistry" 21 "google.golang.org/protobuf/runtime/protoiface" 22 ) 23 24 // ValidationStatus is the result of validating the wire-format encoding of a message. 25 type ValidationStatus int 26 27 const ( 28 // ValidationUnknown indicates that unmarshaling the message might succeed or fail. 29 // The validator was unable to render a judgement. 30 // 31 // The only causes of this status are an aberrant message type appearing somewhere 32 // in the message or a failure in the extension resolver. 33 ValidationUnknown ValidationStatus = iota + 1 34 35 // ValidationInvalid indicates that unmarshaling the message will fail. 36 ValidationInvalid 37 38 // ValidationValid indicates that unmarshaling the message will succeed. 39 ValidationValid 40 ) 41 42 func (v ValidationStatus) String() string { 43 switch v { 44 case ValidationUnknown: 45 return "ValidationUnknown" 46 case ValidationInvalid: 47 return "ValidationInvalid" 48 case ValidationValid: 49 return "ValidationValid" 50 default: 51 return fmt.Sprintf("ValidationStatus(%d)", int(v)) 52 } 53 } 54 55 // Validate determines whether the contents of the buffer are a valid wire encoding 56 // of the message type. 57 // 58 // This function is exposed for testing. 
func Validate(mt protoreflect.MessageType, in protoiface.UnmarshalInput) (out protoiface.UnmarshalOutput, _ ValidationStatus) {
	// Only message types implemented by *MessageInfo can be inspected;
	// anything else cannot be judged.
	mi, ok := mt.(*MessageInfo)
	if !ok {
		return out, ValidationUnknown
	}
	// Fall back to the global registry when no extension resolver is given.
	if in.Resolver == nil {
		in.Resolver = protoregistry.GlobalTypes
	}
	// A groupTag of 0 validates in.Buf as a top-level message rather than
	// as the body of a group.
	o, st := mi.validate(in.Buf, 0, unmarshalOptions{
		flags:    in.Flags,
		resolver: in.Resolver,
	})
	// Surface the required-field check result through the output flags.
	if o.initialized {
		out.Flags |= protoiface.UnmarshalInitialized
	}
	return out, st
}

// validationInfo is the precomputed description the validator uses to check
// one field's encoded value.
//
// typ is the expected encoding of the field itself. keyType and valType are
// read only when typ is validationTypeMap. mi is the message type to descend
// into for message, group, and map-with-message-value fields. It may be nil,
// for example for weak fields, which validate resolves at validation time.
// When a message or group value is encountered with a nil mi, validation
// reports ValidationUnknown.
type validationInfo struct {
	mi               *MessageInfo
	typ              validationType
	keyType, valType validationType

	// For non-required fields, requiredBit is 0.
	//
	// For required fields, requiredBit's nth bit is set, where n is a
	// unique index in the range [0, MessageInfo.numRequiredFields).
	//
	// If there are more than 64 required fields, requiredBit is 0.
	requiredBit uint64
}

// validationType classifies the encoding the validator expects for a field,
// or for a map entry's key or value.
type validationType uint8

const (
	// No type-specific checks; the value is only skipped according to its
	// wire type.
	validationTypeOther validationType = iota
	// Length-delimited message, validated against validationInfo.mi.
	validationTypeMessage
	// Group delimited by start/end-group tags, validated against
	// validationInfo.mi.
	validationTypeGroup
	// Map entry; the key and value are checked per keyType and valType.
	validationTypeMap
	// Repeated varint scalar. A length-delimited (packed) occurrence must
	// consist of complete varints.
	validationTypeRepeatedVarint
	// Repeated 32-bit fixed scalar. A packed payload length must be a
	// multiple of 4.
	validationTypeRepeatedFixed32
	// Repeated 64-bit fixed scalar. A packed payload length must be a
	// multiple of 8.
	validationTypeRepeatedFixed64
	// Singular scalars. Their wire type is compared against these only when
	// recording required fields.
	validationTypeVarint
	validationTypeFixed32
	validationTypeFixed64
	// Length-delimited bytes with no content checks.
	validationTypeBytes
	// Length-delimited string whose contents must be valid UTF-8.
	validationTypeUTF8String
	// Item group of a legacy message set, recognized only when
	// flags.ProtoLegacy is set.
	validationTypeMessageSetItem
)

// newFieldValidationInfo returns the validationInfo for field fd of the
// message described by mi.
//
// ft is the field's Go type (presumably the type of the generated struct
// field; TODO confirm). si supplies the wrapper types of oneof members.
// Members of a non-synthetic oneof take their message type from the first
// field of the oneof wrapper struct. All other fields are delegated to
// newValidationInfo.
//
// Side effect: each required field increments mi.numRequiredFields, which
// saturates at math.MaxUint8, and is assigned the next bit of the
// required-field mask.
func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
	var vi validationInfo
	switch {
	case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
		// Oneof member. Scalar and bytes members, and strings without
		// UTF-8 enforcement, keep validationTypeOther.
		switch fd.Kind() {
		case protoreflect.MessageKind:
			vi.typ = validationTypeMessage
			if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
				vi.mi = getMessageInfo(ot.Field(0).Type)
			}
		case protoreflect.GroupKind:
			vi.typ = validationTypeGroup
			if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
				vi.mi = getMessageInfo(ot.Field(0).Type)
			}
		case protoreflect.StringKind:
			if strs.EnforceUTF8(fd) {
				vi.typ = validationTypeUTF8String
			}
		}
	default:
		vi = newValidationInfo(fd, ft)
	}
	if fd.Cardinality() == protoreflect.Required {
		// Avoid overflow. The required field check is done with a 64-bit mask, with
		// any message containing more than 64 required fields always reported as
		// potentially uninitialized, so it is not important to get a precise count
		// of the required fields past 64.
		if mi.numRequiredFields < math.MaxUint8 {
			mi.numRequiredFields++
			// From the 65th required field on, the shift count is >= 64,
			// which yields 0 in Go, so requiredBit stays 0.
			vi.requiredBit = 1 << (mi.numRequiredFields - 1)
		}
	}
	return vi
}

// newValidationInfo returns the validationInfo for field fd with Go type ft.
// It never sets requiredBit.
//
// The message type (mi) is resolved only when ft has the expected shape:
// a slice for repeated messages/groups, and a map for maps with message
// values. mi is left nil for weak singular message fields.
// UTF-8 enforcement for map keys and values is decided from the map field's
// own descriptor fd.
func newValidationInfo(fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
	var vi validationInfo
	switch {
	case fd.IsList():
		switch fd.Kind() {
		case protoreflect.MessageKind:
			vi.typ = validationTypeMessage
			if ft.Kind() == reflect.Slice {
				vi.mi = getMessageInfo(ft.Elem())
			}
		case protoreflect.GroupKind:
			vi.typ = validationTypeGroup
			if ft.Kind() == reflect.Slice {
				vi.mi = getMessageInfo(ft.Elem())
			}
		case protoreflect.StringKind:
			vi.typ = validationTypeBytes
			if strs.EnforceUTF8(fd) {
				vi.typ = validationTypeUTF8String
			}
		default:
			// Repeated scalars are classified by their wire type so that
			// packed encodings can be checked.
			switch wireTypes[fd.Kind()] {
			case protowire.VarintType:
				vi.typ = validationTypeRepeatedVarint
			case protowire.Fixed32Type:
				vi.typ = validationTypeRepeatedFixed32
			case protowire.Fixed64Type:
				vi.typ = validationTypeRepeatedFixed64
			}
		}
	case fd.IsMap():
		vi.typ = validationTypeMap
		switch fd.MapKey().Kind() {
		case protoreflect.StringKind:
			if strs.EnforceUTF8(fd) {
				vi.keyType = validationTypeUTF8String
			}
		}
		switch fd.MapValue().Kind() {
		case protoreflect.MessageKind:
			vi.valType = validationTypeMessage
			if ft.Kind() == reflect.Map {
				vi.mi = getMessageInfo(ft.Elem())
			}
		case protoreflect.StringKind:
			if strs.EnforceUTF8(fd) {
				vi.valType = validationTypeUTF8String
			}
		}
	default:
		switch fd.Kind() {
		case protoreflect.MessageKind:
			vi.typ = validationTypeMessage
			// Weak fields are resolved lazily in validate.
			if !fd.IsWeak() {
				vi.mi = getMessageInfo(ft)
			}
		case protoreflect.GroupKind:
			vi.typ = validationTypeGroup
			vi.mi = getMessageInfo(ft)
		case protoreflect.StringKind:
			vi.typ = validationTypeBytes
			if strs.EnforceUTF8(fd) {
				vi.typ = validationTypeUTF8String
			}
		default:
			switch wireTypes[fd.Kind()] {
			case protowire.VarintType:
				vi.typ = validationTypeVarint
			case protowire.Fixed32Type:
				vi.typ = validationTypeFixed32
			case protowire.Fixed64Type:
				vi.typ = validationTypeFixed64
			case protowire.BytesType:
				vi.typ = validationTypeBytes
			}
		}
	}
	return vi
}

// validate checks that b is a well-formed wire encoding of mi's message type.
//
// If groupTag is non-zero, b is treated as the body of a group that must be
// closed by an end-group tag with that field number. Nested messages, groups,
// map entries, and message-set items are walked with an explicit stack of
// validationState frames rather than by recursion.
//
// On ValidationValid, out.n is the number of bytes of b consumed.
// out.initialized reports whether every required field was seen; it is
// always false for messages with more than 64 required fields. On any other
// status, out is returned as the zero value.
func (mi *MessageInfo) validate(b []byte, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
	mi.init()
	// validationState is one frame of the explicit stack:
	//   - typ, keyType, valType, mi: what this frame validates.
	//   - endGroup: for group frames, the field number whose end-group tag
	//     closes the frame; 0 otherwise.
	//   - tail: for length-delimited frames, the rest of the enclosing buffer,
	//     resumed once this frame's bytes are exhausted.
	//   - requiredMask: requiredBit values of required fields seen so far.
	type validationState struct {
		typ              validationType
		keyType, valType validationType
		endGroup         protowire.Number
		mi               *MessageInfo
		tail             []byte
		requiredMask     uint64
	}

	// Pre-allocate some slots to avoid repeated slice reallocation.
	states := make([]validationState, 0, 16)
	states = append(states, validationState{
		typ: validationTypeMessage,
		mi:  mi,
	})
	if groupTag > 0 {
		states[0].typ = validationTypeGroup
		states[0].endGroup = groupTag
	}
	initialized := true
	start := len(b)
	// Each outer iteration works on the innermost frame. Pushing a frame
	// re-enters the loop via "continue State", which also reloads st. This
	// matters because append may have moved the states backing array.
State:
	for len(states) > 0 {
		st := &states[len(states)-1]
		for len(b) > 0 {
			// Parse the tag (field number and wire type).
			// 1- and 2-byte varints are decoded inline; longer ones fall
			// back to protowire.ConsumeVarint.
			var tag uint64
			if b[0] < 0x80 {
				tag = uint64(b[0])
				b = b[1:]
			} else if len(b) >= 2 && b[1] < 128 {
				tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
				b = b[2:]
			} else {
				var n int
				tag, n = protowire.ConsumeVarint(b)
				if n < 0 {
					return out, ValidationInvalid
				}
				b = b[n:]
			}
			var num protowire.Number
			if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
				return out, ValidationInvalid
			} else {
				num = protowire.Number(n)
			}
			wtyp := protowire.Type(tag & 7)

			// An end-group tag is valid only if it closes this frame's group.
			// Groups share their parent's buffer, so the jump skips the
			// "b = st.tail" restore below.
			if wtyp == protowire.EndGroupType {
				if st.endGroup == num {
					goto PopState
				}
				return out, ValidationInvalid
			}
			// Work out how to validate this field's value.
			var vi validationInfo
			switch {
			case st.typ == validationTypeMap:
				// Inside a map entry only the key and value fields are
				// recognized. The value gets requiredBit 1 so that, when the
				// frame is popped, a map whose value message has required
				// fields can insist the value was present.
				switch num {
				case genid.MapEntry_Key_field_number:
					vi.typ = st.keyType
				case genid.MapEntry_Value_field_number:
					vi.typ = st.valType
					vi.mi = st.mi
					vi.requiredBit = 1
				}
			case flags.ProtoLegacy && st.mi.isMessageSet:
				// Legacy message set: only the item field is recognized.
				switch num {
				case messageset.FieldItem:
					vi.typ = validationTypeMessageSetItem
				}
			default:
				// Known field: small numbers come from the dense slice,
				// the rest from the map.
				var f *coderFieldInfo
				if int(num) < len(st.mi.denseCoderFields) {
					f = st.mi.denseCoderFields[num]
				} else {
					f = st.mi.coderFields[num]
				}
				if f != nil {
					vi = f.validation
					if vi.typ == validationTypeMessage && vi.mi == nil {
						// Probable weak field.
						//
						// TODO: Consider storing the results of this lookup somewhere
						// rather than recomputing it on every validation.
						fd := st.mi.Desc.Fields().ByNumber(num)
						if fd == nil || !fd.IsWeak() {
							break
						}
						// Weak message types are looked up in the global
						// registry, not opts.resolver. An unregistered type is
						// treated as opaque bytes. A registered type that is
						// not a *MessageInfo leaves vi.mi nil, which later
						// yields ValidationUnknown.
						messageName := fd.Message().FullName()
						messageType, err := protoregistry.GlobalTypes.FindMessageByName(messageName)
						switch err {
						case nil:
							vi.mi, _ = messageType.(*MessageInfo)
						case protoregistry.NotFound:
							vi.typ = validationTypeBytes
						default:
							return out, ValidationUnknown
						}
					}
					// Leave the switch; known fields skip the extension lookup.
					break
				}
				// Possible extension field.
				//
				// TODO: We should return ValidationUnknown when:
				//   1. The resolver is not frozen. (More extensions may be added to it.)
				//   2. The resolver returns preg.NotFound.
				// In this case, a type added to the resolver in the future could cause
				// unmarshaling to begin failing. Supporting this requires some way to
				// determine if the resolver is frozen.
				xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
				if err != nil && err != protoregistry.NotFound {
					return out, ValidationUnknown
				}
				// On NotFound, vi stays zero and the value is only skipped below.
				if err == nil {
					vi = getExtensionFieldInfo(xt).validation
				}
			}
			if vi.requiredBit != 0 {
				// Check that the field has a compatible wire type.
				// We only need to consider non-repeated field types,
				// since repeated fields (and maps) can never be required.
				ok := false
				switch vi.typ {
				case validationTypeVarint:
					ok = wtyp == protowire.VarintType
				case validationTypeFixed32:
					ok = wtyp == protowire.Fixed32Type
				case validationTypeFixed64:
					ok = wtyp == protowire.Fixed64Type
				case validationTypeBytes, validationTypeUTF8String, validationTypeMessage:
					ok = wtyp == protowire.BytesType
				case validationTypeGroup:
					ok = wtyp == protowire.StartGroupType
				}
				if ok {
					st.requiredMask |= vi.requiredBit
				}
			}

			switch wtyp {
			case protowire.VarintType:
				// Skip the varint with an unrolled scan. When at least 10
				// bytes remain, the per-byte length checks can be omitted.
				// The 10th byte may contribute only one bit to a 64-bit
				// value, hence the "< 2" test.
				if len(b) >= 10 {
					switch {
					case b[0] < 0x80:
						b = b[1:]
					case b[1] < 0x80:
						b = b[2:]
					case b[2] < 0x80:
						b = b[3:]
					case b[3] < 0x80:
						b = b[4:]
					case b[4] < 0x80:
						b = b[5:]
					case b[5] < 0x80:
						b = b[6:]
					case b[6] < 0x80:
						b = b[7:]
					case b[7] < 0x80:
						b = b[8:]
					case b[8] < 0x80:
						b = b[9:]
					case b[9] < 0x80 && b[9] < 2:
						b = b[10:]
					default:
						return out, ValidationInvalid
					}
				} else {
					// Fewer than 10 bytes remain here, so the final case can
					// never match; it mirrors the branch above.
					switch {
					case len(b) > 0 && b[0] < 0x80:
						b = b[1:]
					case len(b) > 1 && b[1] < 0x80:
						b = b[2:]
					case len(b) > 2 && b[2] < 0x80:
						b = b[3:]
					case len(b) > 3 && b[3] < 0x80:
						b = b[4:]
					case len(b) > 4 && b[4] < 0x80:
						b = b[5:]
					case len(b) > 5 && b[5] < 0x80:
						b = b[6:]
					case len(b) > 6 && b[6] < 0x80:
						b = b[7:]
					case len(b) > 7 && b[7] < 0x80:
						b = b[8:]
					case len(b) > 8 && b[8] < 0x80:
						b = b[9:]
					case len(b) > 9 && b[9] < 2:
						b = b[10:]
					default:
						return out, ValidationInvalid
					}
				}
				continue State
			case protowire.BytesType:
				// Decode the length prefix, using the same fast paths as the tag.
				var size uint64
				if len(b) >= 1 && b[0] < 0x80 {
					size = uint64(b[0])
					b = b[1:]
				} else if len(b) >= 2 && b[1] < 128 {
					size = uint64(b[0]&0x7f) + uint64(b[1])<<7
					b = b[2:]
				} else {
					var n int
					size, n = protowire.ConsumeVarint(b)
					if n < 0 {
						return out, ValidationInvalid
					}
					b = b[n:]
				}
				if size > uint64(len(b)) {
					return out, ValidationInvalid
				}
				// v is the field's payload; b now points just past it.
				v := b[:size]
				b = b[size:]
				switch vi.typ {
				case validationTypeMessage:
					if vi.mi == nil {
						return out, ValidationUnknown
					}
					vi.mi.init()
					fallthrough
				case validationTypeMap:
					if vi.mi != nil {
						vi.mi.init()
					}
					// Descend into the payload. When the new frame is
					// exhausted, b resumes from its tail.
					states = append(states, validationState{
						typ:     vi.typ,
						keyType: vi.keyType,
						valType: vi.valType,
						mi:      vi.mi,
						tail:    b,
					})
					b = v
					continue State
				case validationTypeRepeatedVarint:
					// Packed field.
					for len(v) > 0 {
						_, n := protowire.ConsumeVarint(v)
						if n < 0 {
							return out, ValidationInvalid
						}
						v = v[n:]
					}
				case validationTypeRepeatedFixed32:
					// Packed field.
					if len(v)%4 != 0 {
						return out, ValidationInvalid
					}
				case validationTypeRepeatedFixed64:
					// Packed field.
					if len(v)%8 != 0 {
						return out, ValidationInvalid
					}
				case validationTypeUTF8String:
					if !utf8.Valid(v) {
						return out, ValidationInvalid
					}
				}
			case protowire.Fixed32Type:
				if len(b) < 4 {
					return out, ValidationInvalid
				}
				b = b[4:]
			case protowire.Fixed64Type:
				if len(b) < 8 {
					return out, ValidationInvalid
				}
				b = b[8:]
			case protowire.StartGroupType:
				switch {
				case vi.typ == validationTypeGroup:
					if vi.mi == nil {
						return out, ValidationUnknown
					}
					vi.mi.init()
					// The group body continues in the current buffer and is
					// closed by an end-group tag for this field number.
					states = append(states, validationState{
						typ:      validationTypeGroup,
						mi:       vi.mi,
						endGroup: num,
					})
					continue State
				case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
					// typeid is the extension field number and v the item's
					// message payload. n is the number of bytes the item
					// occupies in b (presumably through its end-group tag;
					// confirm against messageset.ConsumeFieldValue).
					typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
					if err != nil {
						return out, ValidationInvalid
					}
					xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
					switch {
					case err == protoregistry.NotFound:
						// Unknown extension: skip the whole item.
						b = b[n:]
					case err != nil:
						return out, ValidationUnknown
					default:
						xvi := getExtensionFieldInfo(xt).validation
						if xvi.mi != nil {
							xvi.mi.init()
						}
						states = append(states, validationState{
							typ:  xvi.typ,
							mi:   xvi.mi,
							tail: b[n:],
						})
						b = v
						continue State
					}
				default:
					// Group for an unknown or non-group field: skip it whole.
					n := protowire.ConsumeFieldValue(num, wtyp, b)
					if n < 0 {
						return out, ValidationInvalid
					}
					b = b[n:]
				}
			default:
				return out, ValidationInvalid
			}
		}
		// The frame's bytes are exhausted. A group that never saw its
		// end-group tag is truncated.
		if st.endGroup != 0 {
			return out, ValidationInvalid
		}
		// Defensive: the loop above only exits normally once b is empty.
		if len(b) != 0 {
			return out, ValidationInvalid
		}
		b = st.tail
	PopState:
		// Reached both after a length-delimited frame ends and, via goto,
		// when a group's end-group tag is found.
		numRequiredFields := 0
		switch st.typ {
		case validationTypeMessage, validationTypeGroup:
			numRequiredFields = int(st.mi.numRequiredFields)
		case validationTypeMap:
			// If this is a map field with a message value that contains
			// required fields, require that the value be present.
			if st.mi != nil && st.mi.numRequiredFields > 0 {
				numRequiredFields = 1
			}
		}
		// If there are more than 64 required fields, this check will
		// always fail and we will report that the message is potentially
		// uninitialized.
		if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
			initialized = false
		}
		states = states[:len(states)-1]
	}
	out.n = start - len(b)
	if initialized {
		out.initialized = true
	}
	return out, ValidationValid
}