codec_extension.go (5981B)
1 // Copyright 2019 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package impl 6 7 import ( 8 "sync" 9 "sync/atomic" 10 11 "google.golang.org/protobuf/encoding/protowire" 12 "google.golang.org/protobuf/internal/errors" 13 "google.golang.org/protobuf/reflect/protoreflect" 14 ) 15 16 type extensionFieldInfo struct { 17 wiretag uint64 18 tagsize int 19 unmarshalNeedsValue bool 20 funcs valueCoderFuncs 21 validation validationInfo 22 } 23 24 var legacyExtensionFieldInfoCache sync.Map // map[protoreflect.ExtensionType]*extensionFieldInfo 25 26 func getExtensionFieldInfo(xt protoreflect.ExtensionType) *extensionFieldInfo { 27 if xi, ok := xt.(*ExtensionInfo); ok { 28 xi.lazyInit() 29 return xi.info 30 } 31 return legacyLoadExtensionFieldInfo(xt) 32 } 33 34 // legacyLoadExtensionFieldInfo dynamically loads a *ExtensionInfo for xt. 35 func legacyLoadExtensionFieldInfo(xt protoreflect.ExtensionType) *extensionFieldInfo { 36 if xi, ok := legacyExtensionFieldInfoCache.Load(xt); ok { 37 return xi.(*extensionFieldInfo) 38 } 39 e := makeExtensionFieldInfo(xt.TypeDescriptor()) 40 if e, ok := legacyMessageTypeCache.LoadOrStore(xt, e); ok { 41 return e.(*extensionFieldInfo) 42 } 43 return e 44 } 45 46 func makeExtensionFieldInfo(xd protoreflect.ExtensionDescriptor) *extensionFieldInfo { 47 var wiretag uint64 48 if !xd.IsPacked() { 49 wiretag = protowire.EncodeTag(xd.Number(), wireTypes[xd.Kind()]) 50 } else { 51 wiretag = protowire.EncodeTag(xd.Number(), protowire.BytesType) 52 } 53 e := &extensionFieldInfo{ 54 wiretag: wiretag, 55 tagsize: protowire.SizeVarint(wiretag), 56 funcs: encoderFuncsForValue(xd), 57 } 58 // Does the unmarshal function need a value passed to it? 59 // This is true for composite types, where we pass in a message, list, or map to fill in, 60 // and for enums, where we pass in a prototype value to specify the concrete enum type. 61 switch xd.Kind() { 62 case protoreflect.MessageKind, protoreflect.GroupKind, protoreflect.EnumKind: 63 e.unmarshalNeedsValue = true 64 default: 65 if xd.Cardinality() == protoreflect.Repeated { 66 e.unmarshalNeedsValue = true 67 } 68 } 69 return e 70 } 71 72 type lazyExtensionValue struct { 73 atomicOnce uint32 // atomically set if value is valid 74 mu sync.Mutex 75 xi *extensionFieldInfo 76 value protoreflect.Value 77 b []byte 78 fn func() protoreflect.Value 79 } 80 81 type ExtensionField struct { 82 typ protoreflect.ExtensionType 83 84 // value is either the value of GetValue, 85 // or a *lazyExtensionValue that then returns the value of GetValue. 86 value protoreflect.Value 87 lazy *lazyExtensionValue 88 } 89 90 func (f *ExtensionField) appendLazyBytes(xt protoreflect.ExtensionType, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, b []byte) { 91 if f.lazy == nil { 92 f.lazy = &lazyExtensionValue{xi: xi} 93 } 94 f.typ = xt 95 f.lazy.xi = xi 96 f.lazy.b = protowire.AppendTag(f.lazy.b, num, wtyp) 97 f.lazy.b = append(f.lazy.b, b...) 98 } 99 100 func (f *ExtensionField) canLazy(xt protoreflect.ExtensionType) bool { 101 if f.typ == nil { 102 return true 103 } 104 if f.typ == xt && f.lazy != nil && atomic.LoadUint32(&f.lazy.atomicOnce) == 0 { 105 return true 106 } 107 return false 108 } 109 110 func (f *ExtensionField) lazyInit() { 111 f.lazy.mu.Lock() 112 defer f.lazy.mu.Unlock() 113 if atomic.LoadUint32(&f.lazy.atomicOnce) == 1 { 114 return 115 } 116 if f.lazy.xi != nil { 117 b := f.lazy.b 118 val := f.typ.New() 119 for len(b) > 0 { 120 var tag uint64 121 if b[0] < 0x80 { 122 tag = uint64(b[0]) 123 b = b[1:] 124 } else if len(b) >= 2 && b[1] < 128 { 125 tag = uint64(b[0]&0x7f) + uint64(b[1])<<7 126 b = b[2:] 127 } else { 128 var n int 129 tag, n = protowire.ConsumeVarint(b) 130 if n < 0 { 131 panic(errors.New("bad tag in lazy extension decoding")) 132 } 133 b = b[n:] 134 } 135 num := protowire.Number(tag >> 3) 136 wtyp := protowire.Type(tag & 7) 137 var out unmarshalOutput 138 var err error 139 val, out, err = f.lazy.xi.funcs.unmarshal(b, val, num, wtyp, lazyUnmarshalOptions) 140 if err != nil { 141 panic(errors.New("decode failure in lazy extension decoding: %v", err)) 142 } 143 b = b[out.n:] 144 } 145 f.lazy.value = val 146 } else { 147 f.lazy.value = f.lazy.fn() 148 } 149 f.lazy.xi = nil 150 f.lazy.fn = nil 151 f.lazy.b = nil 152 atomic.StoreUint32(&f.lazy.atomicOnce, 1) 153 } 154 155 // Set sets the type and value of the extension field. 156 // This must not be called concurrently. 157 func (f *ExtensionField) Set(t protoreflect.ExtensionType, v protoreflect.Value) { 158 f.typ = t 159 f.value = v 160 f.lazy = nil 161 } 162 163 // SetLazy sets the type and a value that is to be lazily evaluated upon first use. 164 // This must not be called concurrently. 165 func (f *ExtensionField) SetLazy(t protoreflect.ExtensionType, fn func() protoreflect.Value) { 166 f.typ = t 167 f.lazy = &lazyExtensionValue{fn: fn} 168 } 169 170 // Value returns the value of the extension field. 171 // This may be called concurrently. 172 func (f *ExtensionField) Value() protoreflect.Value { 173 if f.lazy != nil { 174 if atomic.LoadUint32(&f.lazy.atomicOnce) == 0 { 175 f.lazyInit() 176 } 177 return f.lazy.value 178 } 179 return f.value 180 } 181 182 // Type returns the type of the extension field. 183 // This may be called concurrently. 184 func (f ExtensionField) Type() protoreflect.ExtensionType { 185 return f.typ 186 } 187 188 // IsSet returns whether the extension field is set. 189 // This may be called concurrently. 190 func (f ExtensionField) IsSet() bool { 191 return f.typ != nil 192 } 193 194 // IsLazy reports whether a field is lazily encoded. 195 // It is exported for testing. 196 func IsLazy(m protoreflect.Message, fd protoreflect.FieldDescriptor) bool { 197 var mi *MessageInfo 198 var p pointer 199 switch m := m.(type) { 200 case *messageState: 201 mi = m.messageInfo() 202 p = m.pointer() 203 case *messageReflectWrapper: 204 mi = m.messageInfo() 205 p = m.pointer() 206 default: 207 return false 208 } 209 xd, ok := fd.(protoreflect.ExtensionTypeDescriptor) 210 if !ok { 211 return false 212 } 213 xt := xd.Type() 214 ext := mi.extensionMap(p) 215 if ext == nil { 216 return false 217 } 218 f, ok := (*ext)[int32(fd.Number())] 219 if !ok { 220 return false 221 } 222 return f.typ == xt && f.lazy != nil && atomic.LoadUint32(&f.lazy.atomicOnce) == 0 223 }