extensions.go (10945B)
1 // Copyright 2010 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package proto 6 7 import ( 8 "errors" 9 "fmt" 10 "reflect" 11 12 "google.golang.org/protobuf/encoding/protowire" 13 "google.golang.org/protobuf/proto" 14 "google.golang.org/protobuf/reflect/protoreflect" 15 "google.golang.org/protobuf/reflect/protoregistry" 16 "google.golang.org/protobuf/runtime/protoiface" 17 "google.golang.org/protobuf/runtime/protoimpl" 18 ) 19 20 type ( 21 // ExtensionDesc represents an extension descriptor and 22 // is used to interact with an extension field in a message. 23 // 24 // Variables of this type are generated in code by protoc-gen-go. 25 ExtensionDesc = protoimpl.ExtensionInfo 26 27 // ExtensionRange represents a range of message extensions. 28 // Used in code generated by protoc-gen-go. 29 ExtensionRange = protoiface.ExtensionRangeV1 30 31 // Deprecated: Do not use; this is an internal type. 32 Extension = protoimpl.ExtensionFieldV1 33 34 // Deprecated: Do not use; this is an internal type. 35 XXX_InternalExtensions = protoimpl.ExtensionFields 36 ) 37 38 // ErrMissingExtension reports whether the extension was not present. 39 var ErrMissingExtension = errors.New("proto: missing extension") 40 41 var errNotExtendable = errors.New("proto: not an extendable proto.Message") 42 43 // HasExtension reports whether the extension field is present in m 44 // either as an explicitly populated field or as an unknown field. 45 func HasExtension(m Message, xt *ExtensionDesc) (has bool) { 46 mr := MessageReflect(m) 47 if mr == nil || !mr.IsValid() { 48 return false 49 } 50 51 // Check whether any populated known field matches the field number. 
52 xtd := xt.TypeDescriptor() 53 if isValidExtension(mr.Descriptor(), xtd) { 54 has = mr.Has(xtd) 55 } else { 56 mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool { 57 has = int32(fd.Number()) == xt.Field 58 return !has 59 }) 60 } 61 62 // Check whether any unknown field matches the field number. 63 for b := mr.GetUnknown(); !has && len(b) > 0; { 64 num, _, n := protowire.ConsumeField(b) 65 has = int32(num) == xt.Field 66 b = b[n:] 67 } 68 return has 69 } 70 71 // ClearExtension removes the extension field from m 72 // either as an explicitly populated field or as an unknown field. 73 func ClearExtension(m Message, xt *ExtensionDesc) { 74 mr := MessageReflect(m) 75 if mr == nil || !mr.IsValid() { 76 return 77 } 78 79 xtd := xt.TypeDescriptor() 80 if isValidExtension(mr.Descriptor(), xtd) { 81 mr.Clear(xtd) 82 } else { 83 mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool { 84 if int32(fd.Number()) == xt.Field { 85 mr.Clear(fd) 86 return false 87 } 88 return true 89 }) 90 } 91 clearUnknown(mr, fieldNum(xt.Field)) 92 } 93 94 // ClearAllExtensions clears all extensions from m. 95 // This includes populated fields and unknown fields in the extension range. 96 func ClearAllExtensions(m Message) { 97 mr := MessageReflect(m) 98 if mr == nil || !mr.IsValid() { 99 return 100 } 101 102 mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool { 103 if fd.IsExtension() { 104 mr.Clear(fd) 105 } 106 return true 107 }) 108 clearUnknown(mr, mr.Descriptor().ExtensionRanges()) 109 } 110 111 // GetExtension retrieves a proto2 extended field from m. 112 // 113 // If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil), 114 // then GetExtension parses the encoded field and returns a Go value of the specified type. 115 // If the field is not present, then the default value is returned (if one is specified), 116 // otherwise ErrMissingExtension is reported. 
//
// If the descriptor is type incomplete (i.e., ExtensionDesc.ExtensionType is nil),
// then GetExtension returns the raw encoded bytes for the extension field.
//
// Scalar values (bools, numbers, and strings) are returned as pointers to a
// copy of the value, matching the historical v1 representation.
// If the unknown fields of m are malformed, the corresponding parse error is returned.
func GetExtension(m Message, xt *ExtensionDesc) (interface{}, error) {
	mr := MessageReflect(m)
	if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
		return nil, errNotExtendable
	}

	// Retrieve the unknown fields for this extension field.
	var bo protoreflect.RawFields
	for bi := mr.GetUnknown(); len(bi) > 0; {
		num, _, n := protowire.ConsumeField(bi)
		if n < 0 {
			// A negative length signals malformed data; report it instead of
			// panicking on the slice expressions below.
			return nil, protowire.ParseError(n)
		}
		if int32(num) == xt.Field {
			bo = append(bo, bi[:n]...)
		}
		bi = bi[n:]
	}

	// For type incomplete descriptors, only retrieve the unknown fields.
	if xt.ExtensionType == nil {
		return []byte(bo), nil
	}

	// If the extension field only exists as unknown fields, unmarshal it.
	// This is rarely done since proto.Unmarshal eagerly unmarshals extensions.
	xtd := xt.TypeDescriptor()
	if !isValidExtension(mr.Descriptor(), xtd) {
		return nil, fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
	}
	if !mr.Has(xtd) && len(bo) > 0 {
		m2 := mr.New()
		if err := (proto.UnmarshalOptions{
			Resolver: extensionResolver{xt},
		}.Unmarshal(bo, m2.Interface())); err != nil {
			return nil, err
		}
		if m2.Has(xtd) {
			// Promote the parsed value into m and drop the now-redundant raw bytes.
			mr.Set(xtd, m2.Get(xtd))
			clearUnknown(mr, fieldNum(xt.Field))
		}
	}

	// Check whether the message has the extension field set or a default.
	var pv protoreflect.Value
	switch {
	case mr.Has(xtd):
		pv = mr.Get(xtd)
	case xtd.HasDefault():
		pv = xtd.Default()
	default:
		return nil, ErrMissingExtension
	}

	v := xt.InterfaceOf(pv)
	rv := reflect.ValueOf(v)
	if isScalarKind(rv.Kind()) {
		// The v1 API exposes scalars as *T rather than T.
		rv2 := reflect.New(rv.Type())
		rv2.Elem().Set(rv)
		v = rv2.Interface()
	}
	return v, nil
}

// extensionResolver is a custom extension resolver that stores a single
// extension type that takes precedence over the global registry.
type extensionResolver struct{ xt protoreflect.ExtensionType }

// FindExtensionByName returns r.xt if its full name matches field,
// otherwise it defers to protoregistry.GlobalTypes.
func (r extensionResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
	if xtd := r.xt.TypeDescriptor(); xtd.FullName() == field {
		return r.xt, nil
	}
	return protoregistry.GlobalTypes.FindExtensionByName(field)
}

// FindExtensionByNumber returns r.xt if it extends message at field number
// field, otherwise it defers to protoregistry.GlobalTypes.
func (r extensionResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
	if xtd := r.xt.TypeDescriptor(); xtd.ContainingMessage().FullName() == message && xtd.Number() == field {
		return r.xt, nil
	}
	return protoregistry.GlobalTypes.FindExtensionByNumber(message, field)
}

// GetExtensions returns a list of the extension values present in m,
// corresponding with the provided list of extension descriptors, xts.
// If an extension is missing in m, the corresponding value is nil.
202 func GetExtensions(m Message, xts []*ExtensionDesc) ([]interface{}, error) { 203 mr := MessageReflect(m) 204 if mr == nil || !mr.IsValid() { 205 return nil, errNotExtendable 206 } 207 208 vs := make([]interface{}, len(xts)) 209 for i, xt := range xts { 210 v, err := GetExtension(m, xt) 211 if err != nil { 212 if err == ErrMissingExtension { 213 continue 214 } 215 return vs, err 216 } 217 vs[i] = v 218 } 219 return vs, nil 220 } 221 222 // SetExtension sets an extension field in m to the provided value. 223 func SetExtension(m Message, xt *ExtensionDesc, v interface{}) error { 224 mr := MessageReflect(m) 225 if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 { 226 return errNotExtendable 227 } 228 229 rv := reflect.ValueOf(v) 230 if reflect.TypeOf(v) != reflect.TypeOf(xt.ExtensionType) { 231 return fmt.Errorf("proto: bad extension value type. got: %T, want: %T", v, xt.ExtensionType) 232 } 233 if rv.Kind() == reflect.Ptr { 234 if rv.IsNil() { 235 return fmt.Errorf("proto: SetExtension called with nil value of type %T", v) 236 } 237 if isScalarKind(rv.Elem().Kind()) { 238 v = rv.Elem().Interface() 239 } 240 } 241 242 xtd := xt.TypeDescriptor() 243 if !isValidExtension(mr.Descriptor(), xtd) { 244 return fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m) 245 } 246 mr.Set(xtd, xt.ValueOf(v)) 247 clearUnknown(mr, fieldNum(xt.Field)) 248 return nil 249 } 250 251 // SetRawExtension inserts b into the unknown fields of m. 252 // 253 // Deprecated: Use Message.ProtoReflect.SetUnknown instead. 254 func SetRawExtension(m Message, fnum int32, b []byte) { 255 mr := MessageReflect(m) 256 if mr == nil || !mr.IsValid() { 257 return 258 } 259 260 // Verify that the raw field is valid. 
261 for b0 := b; len(b0) > 0; { 262 num, _, n := protowire.ConsumeField(b0) 263 if int32(num) != fnum { 264 panic(fmt.Sprintf("mismatching field number: got %d, want %d", num, fnum)) 265 } 266 b0 = b0[n:] 267 } 268 269 ClearExtension(m, &ExtensionDesc{Field: fnum}) 270 mr.SetUnknown(append(mr.GetUnknown(), b...)) 271 } 272 273 // ExtensionDescs returns a list of extension descriptors found in m, 274 // containing descriptors for both populated extension fields in m and 275 // also unknown fields of m that are in the extension range. 276 // For the later case, an type incomplete descriptor is provided where only 277 // the ExtensionDesc.Field field is populated. 278 // The order of the extension descriptors is undefined. 279 func ExtensionDescs(m Message) ([]*ExtensionDesc, error) { 280 mr := MessageReflect(m) 281 if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 { 282 return nil, errNotExtendable 283 } 284 285 // Collect a set of known extension descriptors. 286 extDescs := make(map[protoreflect.FieldNumber]*ExtensionDesc) 287 mr.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { 288 if fd.IsExtension() { 289 xt := fd.(protoreflect.ExtensionTypeDescriptor) 290 if xd, ok := xt.Type().(*ExtensionDesc); ok { 291 extDescs[fd.Number()] = xd 292 } 293 } 294 return true 295 }) 296 297 // Collect a set of unknown extension descriptors. 298 extRanges := mr.Descriptor().ExtensionRanges() 299 for b := mr.GetUnknown(); len(b) > 0; { 300 num, _, n := protowire.ConsumeField(b) 301 if extRanges.Has(num) && extDescs[num] == nil { 302 extDescs[num] = nil 303 } 304 b = b[n:] 305 } 306 307 // Transpose the set of descriptors into a list. 308 var xts []*ExtensionDesc 309 for num, xt := range extDescs { 310 if xt == nil { 311 xt = &ExtensionDesc{Field: int32(num)} 312 } 313 xts = append(xts, xt) 314 } 315 return xts, nil 316 } 317 318 // isValidExtension reports whether xtd is a valid extension descriptor for md. 
319 func isValidExtension(md protoreflect.MessageDescriptor, xtd protoreflect.ExtensionTypeDescriptor) bool { 320 return xtd.ContainingMessage() == md && md.ExtensionRanges().Has(xtd.Number()) 321 } 322 323 // isScalarKind reports whether k is a protobuf scalar kind (except bytes). 324 // This function exists for historical reasons since the representation of 325 // scalars differs between v1 and v2, where v1 uses *T and v2 uses T. 326 func isScalarKind(k reflect.Kind) bool { 327 switch k { 328 case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String: 329 return true 330 default: 331 return false 332 } 333 } 334 335 // clearUnknown removes unknown fields from m where remover.Has reports true. 336 func clearUnknown(m protoreflect.Message, remover interface { 337 Has(protoreflect.FieldNumber) bool 338 }) { 339 var bo protoreflect.RawFields 340 for bi := m.GetUnknown(); len(bi) > 0; { 341 num, _, n := protowire.ConsumeField(bi) 342 if !remover.Has(num) { 343 bo = append(bo, bi[:n]...) 344 } 345 bi = bi[n:] 346 } 347 if bi := m.GetUnknown(); len(bi) != len(bo) { 348 m.SetUnknown(bo) 349 } 350 } 351 352 type fieldNum protoreflect.FieldNumber 353 354 func (n1 fieldNum) Has(n2 protoreflect.FieldNumber) bool { 355 return protoreflect.FieldNumber(n1) == n2 356 }