gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

extensions.go (10945B)


      1 // Copyright 2010 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package proto
      6 
      7 import (
      8 	"errors"
      9 	"fmt"
     10 	"reflect"
     11 
     12 	"google.golang.org/protobuf/encoding/protowire"
     13 	"google.golang.org/protobuf/proto"
     14 	"google.golang.org/protobuf/reflect/protoreflect"
     15 	"google.golang.org/protobuf/reflect/protoregistry"
     16 	"google.golang.org/protobuf/runtime/protoiface"
     17 	"google.golang.org/protobuf/runtime/protoimpl"
     18 )
     19 
     20 type (
     21 	// ExtensionDesc represents an extension descriptor and
     22 	// is used to interact with an extension field in a message.
     23 	//
     24 	// Variables of this type are generated in code by protoc-gen-go.
     25 	ExtensionDesc = protoimpl.ExtensionInfo
     26 
     27 	// ExtensionRange represents a range of message extensions.
     28 	// Used in code generated by protoc-gen-go.
     29 	ExtensionRange = protoiface.ExtensionRangeV1
     30 
     31 	// Deprecated: Do not use; this is an internal type.
     32 	Extension = protoimpl.ExtensionFieldV1
     33 
     34 	// Deprecated: Do not use; this is an internal type.
     35 	XXX_InternalExtensions = protoimpl.ExtensionFields
     36 )
     37 
     38 // ErrMissingExtension reports whether the extension was not present.
     39 var ErrMissingExtension = errors.New("proto: missing extension")
     40 
     41 var errNotExtendable = errors.New("proto: not an extendable proto.Message")
     42 
     43 // HasExtension reports whether the extension field is present in m
     44 // either as an explicitly populated field or as an unknown field.
     45 func HasExtension(m Message, xt *ExtensionDesc) (has bool) {
     46 	mr := MessageReflect(m)
     47 	if mr == nil || !mr.IsValid() {
     48 		return false
     49 	}
     50 
     51 	// Check whether any populated known field matches the field number.
     52 	xtd := xt.TypeDescriptor()
     53 	if isValidExtension(mr.Descriptor(), xtd) {
     54 		has = mr.Has(xtd)
     55 	} else {
     56 		mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
     57 			has = int32(fd.Number()) == xt.Field
     58 			return !has
     59 		})
     60 	}
     61 
     62 	// Check whether any unknown field matches the field number.
     63 	for b := mr.GetUnknown(); !has && len(b) > 0; {
     64 		num, _, n := protowire.ConsumeField(b)
     65 		has = int32(num) == xt.Field
     66 		b = b[n:]
     67 	}
     68 	return has
     69 }
     70 
     71 // ClearExtension removes the extension field from m
     72 // either as an explicitly populated field or as an unknown field.
     73 func ClearExtension(m Message, xt *ExtensionDesc) {
     74 	mr := MessageReflect(m)
     75 	if mr == nil || !mr.IsValid() {
     76 		return
     77 	}
     78 
     79 	xtd := xt.TypeDescriptor()
     80 	if isValidExtension(mr.Descriptor(), xtd) {
     81 		mr.Clear(xtd)
     82 	} else {
     83 		mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
     84 			if int32(fd.Number()) == xt.Field {
     85 				mr.Clear(fd)
     86 				return false
     87 			}
     88 			return true
     89 		})
     90 	}
     91 	clearUnknown(mr, fieldNum(xt.Field))
     92 }
     93 
     94 // ClearAllExtensions clears all extensions from m.
     95 // This includes populated fields and unknown fields in the extension range.
     96 func ClearAllExtensions(m Message) {
     97 	mr := MessageReflect(m)
     98 	if mr == nil || !mr.IsValid() {
     99 		return
    100 	}
    101 
    102 	mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
    103 		if fd.IsExtension() {
    104 			mr.Clear(fd)
    105 		}
    106 		return true
    107 	})
    108 	clearUnknown(mr, mr.Descriptor().ExtensionRanges())
    109 }
    110 
    111 // GetExtension retrieves a proto2 extended field from m.
    112 //
    113 // If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil),
    114 // then GetExtension parses the encoded field and returns a Go value of the specified type.
    115 // If the field is not present, then the default value is returned (if one is specified),
    116 // otherwise ErrMissingExtension is reported.
    117 //
    118 // If the descriptor is type incomplete (i.e., ExtensionDesc.ExtensionType is nil),
    119 // then GetExtension returns the raw encoded bytes for the extension field.
    120 func GetExtension(m Message, xt *ExtensionDesc) (interface{}, error) {
    121 	mr := MessageReflect(m)
    122 	if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
    123 		return nil, errNotExtendable
    124 	}
    125 
    126 	// Retrieve the unknown fields for this extension field.
    127 	var bo protoreflect.RawFields
    128 	for bi := mr.GetUnknown(); len(bi) > 0; {
    129 		num, _, n := protowire.ConsumeField(bi)
    130 		if int32(num) == xt.Field {
    131 			bo = append(bo, bi[:n]...)
    132 		}
    133 		bi = bi[n:]
    134 	}
    135 
    136 	// For type incomplete descriptors, only retrieve the unknown fields.
    137 	if xt.ExtensionType == nil {
    138 		return []byte(bo), nil
    139 	}
    140 
    141 	// If the extension field only exists as unknown fields, unmarshal it.
    142 	// This is rarely done since proto.Unmarshal eagerly unmarshals extensions.
    143 	xtd := xt.TypeDescriptor()
    144 	if !isValidExtension(mr.Descriptor(), xtd) {
    145 		return nil, fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
    146 	}
    147 	if !mr.Has(xtd) && len(bo) > 0 {
    148 		m2 := mr.New()
    149 		if err := (proto.UnmarshalOptions{
    150 			Resolver: extensionResolver{xt},
    151 		}.Unmarshal(bo, m2.Interface())); err != nil {
    152 			return nil, err
    153 		}
    154 		if m2.Has(xtd) {
    155 			mr.Set(xtd, m2.Get(xtd))
    156 			clearUnknown(mr, fieldNum(xt.Field))
    157 		}
    158 	}
    159 
    160 	// Check whether the message has the extension field set or a default.
    161 	var pv protoreflect.Value
    162 	switch {
    163 	case mr.Has(xtd):
    164 		pv = mr.Get(xtd)
    165 	case xtd.HasDefault():
    166 		pv = xtd.Default()
    167 	default:
    168 		return nil, ErrMissingExtension
    169 	}
    170 
    171 	v := xt.InterfaceOf(pv)
    172 	rv := reflect.ValueOf(v)
    173 	if isScalarKind(rv.Kind()) {
    174 		rv2 := reflect.New(rv.Type())
    175 		rv2.Elem().Set(rv)
    176 		v = rv2.Interface()
    177 	}
    178 	return v, nil
    179 }
    180 
    181 // extensionResolver is a custom extension resolver that stores a single
    182 // extension type that takes precedence over the global registry.
    183 type extensionResolver struct{ xt protoreflect.ExtensionType }
    184 
    185 func (r extensionResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
    186 	if xtd := r.xt.TypeDescriptor(); xtd.FullName() == field {
    187 		return r.xt, nil
    188 	}
    189 	return protoregistry.GlobalTypes.FindExtensionByName(field)
    190 }
    191 
    192 func (r extensionResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
    193 	if xtd := r.xt.TypeDescriptor(); xtd.ContainingMessage().FullName() == message && xtd.Number() == field {
    194 		return r.xt, nil
    195 	}
    196 	return protoregistry.GlobalTypes.FindExtensionByNumber(message, field)
    197 }
    198 
    199 // GetExtensions returns a list of the extensions values present in m,
    200 // corresponding with the provided list of extension descriptors, xts.
    201 // If an extension is missing in m, the corresponding value is nil.
    202 func GetExtensions(m Message, xts []*ExtensionDesc) ([]interface{}, error) {
    203 	mr := MessageReflect(m)
    204 	if mr == nil || !mr.IsValid() {
    205 		return nil, errNotExtendable
    206 	}
    207 
    208 	vs := make([]interface{}, len(xts))
    209 	for i, xt := range xts {
    210 		v, err := GetExtension(m, xt)
    211 		if err != nil {
    212 			if err == ErrMissingExtension {
    213 				continue
    214 			}
    215 			return vs, err
    216 		}
    217 		vs[i] = v
    218 	}
    219 	return vs, nil
    220 }
    221 
    222 // SetExtension sets an extension field in m to the provided value.
    223 func SetExtension(m Message, xt *ExtensionDesc, v interface{}) error {
    224 	mr := MessageReflect(m)
    225 	if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
    226 		return errNotExtendable
    227 	}
    228 
    229 	rv := reflect.ValueOf(v)
    230 	if reflect.TypeOf(v) != reflect.TypeOf(xt.ExtensionType) {
    231 		return fmt.Errorf("proto: bad extension value type. got: %T, want: %T", v, xt.ExtensionType)
    232 	}
    233 	if rv.Kind() == reflect.Ptr {
    234 		if rv.IsNil() {
    235 			return fmt.Errorf("proto: SetExtension called with nil value of type %T", v)
    236 		}
    237 		if isScalarKind(rv.Elem().Kind()) {
    238 			v = rv.Elem().Interface()
    239 		}
    240 	}
    241 
    242 	xtd := xt.TypeDescriptor()
    243 	if !isValidExtension(mr.Descriptor(), xtd) {
    244 		return fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
    245 	}
    246 	mr.Set(xtd, xt.ValueOf(v))
    247 	clearUnknown(mr, fieldNum(xt.Field))
    248 	return nil
    249 }
    250 
    251 // SetRawExtension inserts b into the unknown fields of m.
    252 //
    253 // Deprecated: Use Message.ProtoReflect.SetUnknown instead.
    254 func SetRawExtension(m Message, fnum int32, b []byte) {
    255 	mr := MessageReflect(m)
    256 	if mr == nil || !mr.IsValid() {
    257 		return
    258 	}
    259 
    260 	// Verify that the raw field is valid.
    261 	for b0 := b; len(b0) > 0; {
    262 		num, _, n := protowire.ConsumeField(b0)
    263 		if int32(num) != fnum {
    264 			panic(fmt.Sprintf("mismatching field number: got %d, want %d", num, fnum))
    265 		}
    266 		b0 = b0[n:]
    267 	}
    268 
    269 	ClearExtension(m, &ExtensionDesc{Field: fnum})
    270 	mr.SetUnknown(append(mr.GetUnknown(), b...))
    271 }
    272 
    273 // ExtensionDescs returns a list of extension descriptors found in m,
    274 // containing descriptors for both populated extension fields in m and
    275 // also unknown fields of m that are in the extension range.
    276 // For the later case, an type incomplete descriptor is provided where only
    277 // the ExtensionDesc.Field field is populated.
    278 // The order of the extension descriptors is undefined.
    279 func ExtensionDescs(m Message) ([]*ExtensionDesc, error) {
    280 	mr := MessageReflect(m)
    281 	if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
    282 		return nil, errNotExtendable
    283 	}
    284 
    285 	// Collect a set of known extension descriptors.
    286 	extDescs := make(map[protoreflect.FieldNumber]*ExtensionDesc)
    287 	mr.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
    288 		if fd.IsExtension() {
    289 			xt := fd.(protoreflect.ExtensionTypeDescriptor)
    290 			if xd, ok := xt.Type().(*ExtensionDesc); ok {
    291 				extDescs[fd.Number()] = xd
    292 			}
    293 		}
    294 		return true
    295 	})
    296 
    297 	// Collect a set of unknown extension descriptors.
    298 	extRanges := mr.Descriptor().ExtensionRanges()
    299 	for b := mr.GetUnknown(); len(b) > 0; {
    300 		num, _, n := protowire.ConsumeField(b)
    301 		if extRanges.Has(num) && extDescs[num] == nil {
    302 			extDescs[num] = nil
    303 		}
    304 		b = b[n:]
    305 	}
    306 
    307 	// Transpose the set of descriptors into a list.
    308 	var xts []*ExtensionDesc
    309 	for num, xt := range extDescs {
    310 		if xt == nil {
    311 			xt = &ExtensionDesc{Field: int32(num)}
    312 		}
    313 		xts = append(xts, xt)
    314 	}
    315 	return xts, nil
    316 }
    317 
    318 // isValidExtension reports whether xtd is a valid extension descriptor for md.
    319 func isValidExtension(md protoreflect.MessageDescriptor, xtd protoreflect.ExtensionTypeDescriptor) bool {
    320 	return xtd.ContainingMessage() == md && md.ExtensionRanges().Has(xtd.Number())
    321 }
    322 
    323 // isScalarKind reports whether k is a protobuf scalar kind (except bytes).
    324 // This function exists for historical reasons since the representation of
    325 // scalars differs between v1 and v2, where v1 uses *T and v2 uses T.
    326 func isScalarKind(k reflect.Kind) bool {
    327 	switch k {
    328 	case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
    329 		return true
    330 	default:
    331 		return false
    332 	}
    333 }
    334 
    335 // clearUnknown removes unknown fields from m where remover.Has reports true.
    336 func clearUnknown(m protoreflect.Message, remover interface {
    337 	Has(protoreflect.FieldNumber) bool
    338 }) {
    339 	var bo protoreflect.RawFields
    340 	for bi := m.GetUnknown(); len(bi) > 0; {
    341 		num, _, n := protowire.ConsumeField(bi)
    342 		if !remover.Has(num) {
    343 			bo = append(bo, bi[:n]...)
    344 		}
    345 		bi = bi[n:]
    346 	}
    347 	if bi := m.GetUnknown(); len(bi) != len(bo) {
    348 		m.SetUnknown(bo)
    349 	}
    350 }
    351 
    352 type fieldNum protoreflect.FieldNumber
    353 
    354 func (n1 fieldNum) Has(n2 protoreflect.FieldNumber) bool {
    355 	return protoreflect.FieldNumber(n1) == n2
    356 }