gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

validate.go (15488B)


      1 // Copyright 2019 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package impl
      6 
      7 import (
      8 	"fmt"
      9 	"math"
     10 	"math/bits"
     11 	"reflect"
     12 	"unicode/utf8"
     13 
     14 	"google.golang.org/protobuf/encoding/protowire"
     15 	"google.golang.org/protobuf/internal/encoding/messageset"
     16 	"google.golang.org/protobuf/internal/flags"
     17 	"google.golang.org/protobuf/internal/genid"
     18 	"google.golang.org/protobuf/internal/strs"
     19 	"google.golang.org/protobuf/reflect/protoreflect"
     20 	"google.golang.org/protobuf/reflect/protoregistry"
     21 	"google.golang.org/protobuf/runtime/protoiface"
     22 )
     23 
     24 // ValidationStatus is the result of validating the wire-format encoding of a message.
     25 type ValidationStatus int
     26 
     27 const (
     28 	// ValidationUnknown indicates that unmarshaling the message might succeed or fail.
     29 	// The validator was unable to render a judgement.
     30 	//
     31 	// The only causes of this status are an aberrant message type appearing somewhere
     32 	// in the message or a failure in the extension resolver.
     33 	ValidationUnknown ValidationStatus = iota + 1
     34 
     35 	// ValidationInvalid indicates that unmarshaling the message will fail.
     36 	ValidationInvalid
     37 
     38 	// ValidationValid indicates that unmarshaling the message will succeed.
     39 	ValidationValid
     40 )
     41 
     42 func (v ValidationStatus) String() string {
     43 	switch v {
     44 	case ValidationUnknown:
     45 		return "ValidationUnknown"
     46 	case ValidationInvalid:
     47 		return "ValidationInvalid"
     48 	case ValidationValid:
     49 		return "ValidationValid"
     50 	default:
     51 		return fmt.Sprintf("ValidationStatus(%d)", int(v))
     52 	}
     53 }
     54 
     55 // Validate determines whether the contents of the buffer are a valid wire encoding
     56 // of the message type.
     57 //
     58 // This function is exposed for testing.
     59 func Validate(mt protoreflect.MessageType, in protoiface.UnmarshalInput) (out protoiface.UnmarshalOutput, _ ValidationStatus) {
     60 	mi, ok := mt.(*MessageInfo)
     61 	if !ok {
     62 		return out, ValidationUnknown
     63 	}
     64 	if in.Resolver == nil {
     65 		in.Resolver = protoregistry.GlobalTypes
     66 	}
     67 	o, st := mi.validate(in.Buf, 0, unmarshalOptions{
     68 		flags:    in.Flags,
     69 		resolver: in.Resolver,
     70 	})
     71 	if o.initialized {
     72 		out.Flags |= protoiface.UnmarshalInitialized
     73 	}
     74 	return out, st
     75 }
     76 
     77 type validationInfo struct {
     78 	mi               *MessageInfo
     79 	typ              validationType
     80 	keyType, valType validationType
     81 
     82 	// For non-required fields, requiredBit is 0.
     83 	//
     84 	// For required fields, requiredBit's nth bit is set, where n is a
     85 	// unique index in the range [0, MessageInfo.numRequiredFields).
     86 	//
     87 	// If there are more than 64 required fields, requiredBit is 0.
     88 	requiredBit uint64
     89 }
     90 
     91 type validationType uint8
     92 
     93 const (
     94 	validationTypeOther validationType = iota
     95 	validationTypeMessage
     96 	validationTypeGroup
     97 	validationTypeMap
     98 	validationTypeRepeatedVarint
     99 	validationTypeRepeatedFixed32
    100 	validationTypeRepeatedFixed64
    101 	validationTypeVarint
    102 	validationTypeFixed32
    103 	validationTypeFixed64
    104 	validationTypeBytes
    105 	validationTypeUTF8String
    106 	validationTypeMessageSetItem
    107 )
    108 
    109 func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
    110 	var vi validationInfo
    111 	switch {
    112 	case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
    113 		switch fd.Kind() {
    114 		case protoreflect.MessageKind:
    115 			vi.typ = validationTypeMessage
    116 			if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
    117 				vi.mi = getMessageInfo(ot.Field(0).Type)
    118 			}
    119 		case protoreflect.GroupKind:
    120 			vi.typ = validationTypeGroup
    121 			if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
    122 				vi.mi = getMessageInfo(ot.Field(0).Type)
    123 			}
    124 		case protoreflect.StringKind:
    125 			if strs.EnforceUTF8(fd) {
    126 				vi.typ = validationTypeUTF8String
    127 			}
    128 		}
    129 	default:
    130 		vi = newValidationInfo(fd, ft)
    131 	}
    132 	if fd.Cardinality() == protoreflect.Required {
    133 		// Avoid overflow. The required field check is done with a 64-bit mask, with
    134 		// any message containing more than 64 required fields always reported as
    135 		// potentially uninitialized, so it is not important to get a precise count
    136 		// of the required fields past 64.
    137 		if mi.numRequiredFields < math.MaxUint8 {
    138 			mi.numRequiredFields++
    139 			vi.requiredBit = 1 << (mi.numRequiredFields - 1)
    140 		}
    141 	}
    142 	return vi
    143 }
    144 
    145 func newValidationInfo(fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
    146 	var vi validationInfo
    147 	switch {
    148 	case fd.IsList():
    149 		switch fd.Kind() {
    150 		case protoreflect.MessageKind:
    151 			vi.typ = validationTypeMessage
    152 			if ft.Kind() == reflect.Slice {
    153 				vi.mi = getMessageInfo(ft.Elem())
    154 			}
    155 		case protoreflect.GroupKind:
    156 			vi.typ = validationTypeGroup
    157 			if ft.Kind() == reflect.Slice {
    158 				vi.mi = getMessageInfo(ft.Elem())
    159 			}
    160 		case protoreflect.StringKind:
    161 			vi.typ = validationTypeBytes
    162 			if strs.EnforceUTF8(fd) {
    163 				vi.typ = validationTypeUTF8String
    164 			}
    165 		default:
    166 			switch wireTypes[fd.Kind()] {
    167 			case protowire.VarintType:
    168 				vi.typ = validationTypeRepeatedVarint
    169 			case protowire.Fixed32Type:
    170 				vi.typ = validationTypeRepeatedFixed32
    171 			case protowire.Fixed64Type:
    172 				vi.typ = validationTypeRepeatedFixed64
    173 			}
    174 		}
    175 	case fd.IsMap():
    176 		vi.typ = validationTypeMap
    177 		switch fd.MapKey().Kind() {
    178 		case protoreflect.StringKind:
    179 			if strs.EnforceUTF8(fd) {
    180 				vi.keyType = validationTypeUTF8String
    181 			}
    182 		}
    183 		switch fd.MapValue().Kind() {
    184 		case protoreflect.MessageKind:
    185 			vi.valType = validationTypeMessage
    186 			if ft.Kind() == reflect.Map {
    187 				vi.mi = getMessageInfo(ft.Elem())
    188 			}
    189 		case protoreflect.StringKind:
    190 			if strs.EnforceUTF8(fd) {
    191 				vi.valType = validationTypeUTF8String
    192 			}
    193 		}
    194 	default:
    195 		switch fd.Kind() {
    196 		case protoreflect.MessageKind:
    197 			vi.typ = validationTypeMessage
    198 			if !fd.IsWeak() {
    199 				vi.mi = getMessageInfo(ft)
    200 			}
    201 		case protoreflect.GroupKind:
    202 			vi.typ = validationTypeGroup
    203 			vi.mi = getMessageInfo(ft)
    204 		case protoreflect.StringKind:
    205 			vi.typ = validationTypeBytes
    206 			if strs.EnforceUTF8(fd) {
    207 				vi.typ = validationTypeUTF8String
    208 			}
    209 		default:
    210 			switch wireTypes[fd.Kind()] {
    211 			case protowire.VarintType:
    212 				vi.typ = validationTypeVarint
    213 			case protowire.Fixed32Type:
    214 				vi.typ = validationTypeFixed32
    215 			case protowire.Fixed64Type:
    216 				vi.typ = validationTypeFixed64
    217 			case protowire.BytesType:
    218 				vi.typ = validationTypeBytes
    219 			}
    220 		}
    221 	}
    222 	return vi
    223 }
    224 
    // validate checks that b is a well-formed wire encoding of the message
    // described by mi, without unmarshaling it.
    //
    // If groupTag is greater than zero, b is validated as the body of a group
    // that must be closed by an end-group tag carrying that field number.
    //
    // Nested messages, groups and map entries are handled iteratively with an
    // explicit stack of validationState values instead of recursion.
    //
    // On ValidationValid, out.n is the number of bytes of b that were consumed
    // and out.initialized is set unless some visited message (or map entry) was
    // found to be missing required fields. On any other status, out is returned
    // as its zero value.
    225 func (mi *MessageInfo) validate(b []byte, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
    226 	mi.init()
        	// validationState is one frame of the validation stack: the message,
        	// group or map entry currently being scanned.
    227 	type validationState struct {
    228 		typ              validationType
    229 		keyType, valType validationType
        		// endGroup is the field number of the end-group tag that closes
        		// this frame; it is zero for frames that are not groups.
    230 		endGroup         protowire.Number
    231 		mi               *MessageInfo
        		// tail holds the bytes that follow this frame's length-delimited
        		// payload; scanning resumes there once the payload is exhausted.
    232 		tail             []byte
        		// requiredMask accumulates the requiredBit of every required field
        		// seen in this frame with a compatible wire type.
    233 		requiredMask     uint64
    234 	}
    235 
    236 	// Pre-allocate some slots to avoid repeated slice reallocation.
    237 	states := make([]validationState, 0, 16)
    238 	states = append(states, validationState{
    239 		typ: validationTypeMessage,
    240 		mi:  mi,
    241 	})
    242 	if groupTag > 0 {
    243 		states[0].typ = validationTypeGroup
    244 		states[0].endGroup = groupTag
    245 	}
        	// initialized is cleared as soon as any frame is popped without all of
        	// its required fields having been seen.
    246 	initialized := true
        	// start is the original input length, used to compute out.n at the end.
    247 	start := len(b)
    248 State:
    249 	for len(states) > 0 {
        		// st is re-derived on every pass of the outer loop (and therefore
        		// after every "continue State"), so it always points at the top of
        		// the stack even after append has grown the slice.
    250 		st := &states[len(states)-1]
    251 		for len(b) > 0 {
    252 			// Parse the tag (field number and wire type).
        			// One- and two-byte varints are decoded inline; anything longer
        			// goes through protowire.ConsumeVarint.
    253 			var tag uint64
    254 			if b[0] < 0x80 {
    255 				tag = uint64(b[0])
    256 				b = b[1:]
    257 			} else if len(b) >= 2 && b[1] < 128 {
    258 				tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
    259 				b = b[2:]
    260 			} else {
    261 				var n int
    262 				tag, n = protowire.ConsumeVarint(b)
    263 				if n < 0 {
    264 					return out, ValidationInvalid
    265 				}
    266 				b = b[n:]
    267 			}
        			// The field number is the tag without its low three bits and
        			// must lie within the valid number range.
    268 			var num protowire.Number
    269 			if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
    270 				return out, ValidationInvalid
    271 			} else {
    272 				num = protowire.Number(n)
    273 			}
    274 			wtyp := protowire.Type(tag & 7)
    275 
        			// An end-group tag is only acceptable when it closes the group
        			// frame currently on top of the stack.
    276 			if wtyp == protowire.EndGroupType {
    277 				if st.endGroup == num {
    278 					goto PopState
    279 				}
    280 				return out, ValidationInvalid
    281 			}
        			// Look up how this field should be validated. vi keeps its zero
        			// value for fields that are not recognised.
    282 			var vi validationInfo
    283 			switch {
    284 			case st.typ == validationTypeMap:
    285 				switch num {
    286 				case genid.MapEntry_Key_field_number:
    287 					vi.typ = st.keyType
    288 				case genid.MapEntry_Value_field_number:
    289 					vi.typ = st.valType
    290 					vi.mi = st.mi
        					// The value is tracked like a required field so that
        					// PopState can tell whether it was present.
    291 					vi.requiredBit = 1
    292 				}
    293 			case flags.ProtoLegacy && st.mi.isMessageSet:
    294 				switch num {
    295 				case messageset.FieldItem:
    296 					vi.typ = validationTypeMessageSetItem
    297 				}
    298 			default:
        				// Known fields are looked up in the dense table when the
        				// number falls inside it, otherwise in coderFields.
    299 				var f *coderFieldInfo
    300 				if int(num) < len(st.mi.denseCoderFields) {
    301 					f = st.mi.denseCoderFields[num]
    302 				} else {
    303 					f = st.mi.coderFields[num]
    304 				}
    305 				if f != nil {
    306 					vi = f.validation
    307 					if vi.typ == validationTypeMessage && vi.mi == nil {
    308 						// Probable weak field.
    309 						//
    310 						// TODO: Consider storing the results of this lookup somewhere
    311 						// rather than recomputing it on every validation.
    312 						fd := st.mi.Desc.Fields().ByNumber(num)
    313 						if fd == nil || !fd.IsWeak() {
        							// Not a weak field: leave the enclosing switch with
        							// vi.mi still nil. A length-delimited occurrence is
        							// then reported as ValidationUnknown further down.
    314 							break
    315 						}
    316 						messageName := fd.Message().FullName()
    317 						messageType, err := protoregistry.GlobalTypes.FindMessageByName(messageName)
    318 						switch err {
    319 						case nil:
    320 							vi.mi, _ = messageType.(*MessageInfo)
    321 						case protoregistry.NotFound:
        							// The weak message type is not registered; its
        							// contents are not inspected.
    322 							vi.typ = validationTypeBytes
    323 						default:
    324 							return out, ValidationUnknown
    325 						}
    326 					}
    327 					break
    328 				}
    329 				// Possible extension field.
    330 				//
    331 				// TODO: We should return ValidationUnknown when:
    332 				//   1. The resolver is not frozen. (More extensions may be added to it.)
    333 				//   2. The resolver returns preg.NotFound.
    334 				// In this case, a type added to the resolver in the future could cause
    335 				// unmarshaling to begin failing. Supporting this requires some way to
    336 				// determine if the resolver is frozen.
    337 				xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
    338 				if err != nil && err != protoregistry.NotFound {
    339 					return out, ValidationUnknown
    340 				}
    341 				if err == nil {
    342 					vi = getExtensionFieldInfo(xt).validation
    343 				}
    344 			}
    345 			if vi.requiredBit != 0 {
    346 				// Check that the field has a compatible wire type.
    347 				// We only need to consider non-repeated field types,
    348 				// since repeated fields (and maps) can never be required.
    349 				ok := false
    350 				switch vi.typ {
    351 				case validationTypeVarint:
    352 					ok = wtyp == protowire.VarintType
    353 				case validationTypeFixed32:
    354 					ok = wtyp == protowire.Fixed32Type
    355 				case validationTypeFixed64:
    356 					ok = wtyp == protowire.Fixed64Type
    357 				case validationTypeBytes, validationTypeUTF8String, validationTypeMessage:
    358 					ok = wtyp == protowire.BytesType
    359 				case validationTypeGroup:
    360 					ok = wtyp == protowire.StartGroupType
    361 				}
    362 				if ok {
    363 					st.requiredMask |= vi.requiredBit
    364 				}
    365 			}
    366 
        			// Consume the field value according to its wire type.
    367 			switch wtyp {
    368 			case protowire.VarintType:
        				// Skip over the varint; its value is not needed. With at
        				// least ten bytes left no per-byte length check is required.
        				// The tenth byte must be 0 or 1 so that the value fits in
        				// 64 bits.
    369 				if len(b) >= 10 {
    370 					switch {
    371 					case b[0] < 0x80:
    372 						b = b[1:]
    373 					case b[1] < 0x80:
    374 						b = b[2:]
    375 					case b[2] < 0x80:
    376 						b = b[3:]
    377 					case b[3] < 0x80:
    378 						b = b[4:]
    379 					case b[4] < 0x80:
    380 						b = b[5:]
    381 					case b[5] < 0x80:
    382 						b = b[6:]
    383 					case b[6] < 0x80:
    384 						b = b[7:]
    385 					case b[7] < 0x80:
    386 						b = b[8:]
    387 					case b[8] < 0x80:
    388 						b = b[9:]
    389 					case b[9] < 0x80 && b[9] < 2:
    390 						b = b[10:]
    391 					default:
    392 						return out, ValidationInvalid
    393 					}
    394 				} else {
        					// Fewer than ten bytes remain: same scan, but every index
        					// is guarded by a length check. Running out of input
        					// before the terminating byte is invalid.
    395 					switch {
    396 					case len(b) > 0 && b[0] < 0x80:
    397 						b = b[1:]
    398 					case len(b) > 1 && b[1] < 0x80:
    399 						b = b[2:]
    400 					case len(b) > 2 && b[2] < 0x80:
    401 						b = b[3:]
    402 					case len(b) > 3 && b[3] < 0x80:
    403 						b = b[4:]
    404 					case len(b) > 4 && b[4] < 0x80:
    405 						b = b[5:]
    406 					case len(b) > 5 && b[5] < 0x80:
    407 						b = b[6:]
    408 					case len(b) > 6 && b[6] < 0x80:
    409 						b = b[7:]
    410 					case len(b) > 7 && b[7] < 0x80:
    411 						b = b[8:]
    412 					case len(b) > 8 && b[8] < 0x80:
    413 						b = b[9:]
    414 					case len(b) > 9 && b[9] < 2:
    415 						b = b[10:]
    416 					default:
    417 						return out, ValidationInvalid
    418 					}
    419 				}
        				// The stack is unchanged, so this resumes with the same frame.
    420 				continue State
    421 			case protowire.BytesType:
        				// Decode the length prefix, again with inline fast paths for
        				// one- and two-byte varints.
    422 				var size uint64
    423 				if len(b) >= 1 && b[0] < 0x80 {
    424 					size = uint64(b[0])
    425 					b = b[1:]
    426 				} else if len(b) >= 2 && b[1] < 128 {
    427 					size = uint64(b[0]&0x7f) + uint64(b[1])<<7
    428 					b = b[2:]
    429 				} else {
    430 					var n int
    431 					size, n = protowire.ConsumeVarint(b)
    432 					if n < 0 {
    433 						return out, ValidationInvalid
    434 					}
    435 					b = b[n:]
    436 				}
    437 				if size > uint64(len(b)) {
    438 					return out, ValidationInvalid
    439 				}
        				// v is the payload of this field; b is what follows it.
    440 				v := b[:size]
    441 				b = b[size:]
    442 				switch vi.typ {
    443 				case validationTypeMessage:
    444 					if vi.mi == nil {
    445 						return out, ValidationUnknown
    446 					}
    447 					vi.mi.init()
        					// Messages and map entries are pushed the same way.
    448 					fallthrough
    449 				case validationTypeMap:
    450 					if vi.mi != nil {
    451 						vi.mi.init()
    452 					}
        					// Push a frame for the nested payload, remembering the
        					// bytes after it in tail, and scan the payload next.
    453 					states = append(states, validationState{
    454 						typ:     vi.typ,
    455 						keyType: vi.keyType,
    456 						valType: vi.valType,
    457 						mi:      vi.mi,
    458 						tail:    b,
    459 					})
    460 					b = v
    461 					continue State
    462 				case validationTypeRepeatedVarint:
    463 					// Packed field.
    464 					for len(v) > 0 {
    465 						_, n := protowire.ConsumeVarint(v)
    466 						if n < 0 {
    467 							return out, ValidationInvalid
    468 						}
    469 						v = v[n:]
    470 					}
    471 				case validationTypeRepeatedFixed32:
    472 					// Packed field.
    473 					if len(v)%4 != 0 {
    474 						return out, ValidationInvalid
    475 					}
    476 				case validationTypeRepeatedFixed64:
    477 					// Packed field.
    478 					if len(v)%8 != 0 {
    479 						return out, ValidationInvalid
    480 					}
    481 				case validationTypeUTF8String:
    482 					if !utf8.Valid(v) {
    483 						return out, ValidationInvalid
    484 					}
    485 				}
    486 			case protowire.Fixed32Type:
    487 				if len(b) < 4 {
    488 					return out, ValidationInvalid
    489 				}
    490 				b = b[4:]
    491 			case protowire.Fixed64Type:
    492 				if len(b) < 8 {
    493 					return out, ValidationInvalid
    494 				}
    495 				b = b[8:]
    496 			case protowire.StartGroupType:
    497 				switch {
    498 				case vi.typ == validationTypeGroup:
    499 					if vi.mi == nil {
    500 						return out, ValidationUnknown
    501 					}
    502 					vi.mi.init()
        					// Push a group frame. It has no tail: the group ends at
        					// its end-group tag and scanning continues right after.
    503 					states = append(states, validationState{
    504 						typ:      validationTypeGroup,
    505 						mi:       vi.mi,
    506 						endGroup: num,
    507 					})
    508 					continue State
    509 				case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
    510 					typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
    511 					if err != nil {
    512 						return out, ValidationInvalid
    513 					}
    514 					xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
    515 					switch {
    516 					case err == protoregistry.NotFound:
        						// No extension is registered for this type id:
        						// skip the item without inspecting it.
    517 						b = b[n:]
    518 					case err != nil:
    519 						return out, ValidationUnknown
    520 					default:
        						// Validate the item's value using the extension's
        						// validation info, resuming after the item.
    521 						xvi := getExtensionFieldInfo(xt).validation
    522 						if xvi.mi != nil {
    523 							xvi.mi.init()
    524 						}
    525 						states = append(states, validationState{
    526 							typ:  xvi.typ,
    527 							mi:   xvi.mi,
    528 							tail: b[n:],
    529 						})
    530 						b = v
    531 						continue State
    532 					}
    533 				default:
        					// A group that is not validated structurally: let
        					// protowire consume the field value and move on.
    534 					n := protowire.ConsumeFieldValue(num, wtyp, b)
    535 					if n < 0 {
    536 						return out, ValidationInvalid
    537 					}
    538 					b = b[n:]
    539 				}
    540 			default:
    541 				return out, ValidationInvalid
    542 			}
    543 		}
        		// The input for this frame is exhausted. A group frame must be closed
        		// by its end-group tag, so running out of bytes inside one is invalid.
    544 		if st.endGroup != 0 {
    545 			return out, ValidationInvalid
    546 		}
    547 		if len(b) != 0 {
    548 			return out, ValidationInvalid
    549 		}
        		// Resume with the bytes that followed this frame's payload.
    550 		b = st.tail
        	// PopState is also reached directly via goto when an end-group tag closes
        	// the current frame; in that case b already points just past that tag.
    551 	PopState:
    552 		numRequiredFields := 0
    553 		switch st.typ {
    554 		case validationTypeMessage, validationTypeGroup:
    555 			numRequiredFields = int(st.mi.numRequiredFields)
    556 		case validationTypeMap:
    557 			// If this is a map field with a message value that contains
    558 			// required fields, require that the value be present.
    559 			if st.mi != nil && st.mi.numRequiredFields > 0 {
    560 				numRequiredFields = 1
    561 			}
    562 		}
    563 		// If there are more than 64 required fields, this check will
    564 		// always fail and we will report that the message is potentially
    565 		// uninitialized.
    566 		if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
    567 			initialized = false
    568 		}
    569 		states = states[:len(states)-1]
    570 	}
        	// Everything between the original start of b and what is left of it has
        	// been consumed.
    571 	out.n = start - len(b)
    572 	if initialized {
    573 		out.initialized = true
    574 	}
    575 	return out, ValidationValid
    576 }