checkinit.go (3573B)
1 // Copyright 2019 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package impl 6 7 import ( 8 "sync" 9 10 "google.golang.org/protobuf/internal/errors" 11 "google.golang.org/protobuf/reflect/protoreflect" 12 "google.golang.org/protobuf/runtime/protoiface" 13 ) 14 15 func (mi *MessageInfo) checkInitialized(in protoiface.CheckInitializedInput) (protoiface.CheckInitializedOutput, error) { 16 var p pointer 17 if ms, ok := in.Message.(*messageState); ok { 18 p = ms.pointer() 19 } else { 20 p = in.Message.(*messageReflectWrapper).pointer() 21 } 22 return protoiface.CheckInitializedOutput{}, mi.checkInitializedPointer(p) 23 } 24 25 func (mi *MessageInfo) checkInitializedPointer(p pointer) error { 26 mi.init() 27 if !mi.needsInitCheck { 28 return nil 29 } 30 if p.IsNil() { 31 for _, f := range mi.orderedCoderFields { 32 if f.isRequired { 33 return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName())) 34 } 35 } 36 return nil 37 } 38 if mi.extensionOffset.IsValid() { 39 e := p.Apply(mi.extensionOffset).Extensions() 40 if err := mi.isInitExtensions(e); err != nil { 41 return err 42 } 43 } 44 for _, f := range mi.orderedCoderFields { 45 if !f.isRequired && f.funcs.isInit == nil { 46 continue 47 } 48 fptr := p.Apply(f.offset) 49 if f.isPointer && fptr.Elem().IsNil() { 50 if f.isRequired { 51 return errors.RequiredNotSet(string(mi.Desc.Fields().ByNumber(f.num).FullName())) 52 } 53 continue 54 } 55 if f.funcs.isInit == nil { 56 continue 57 } 58 if err := f.funcs.isInit(fptr, f); err != nil { 59 return err 60 } 61 } 62 return nil 63 } 64 65 func (mi *MessageInfo) isInitExtensions(ext *map[int32]ExtensionField) error { 66 if ext == nil { 67 return nil 68 } 69 for _, x := range *ext { 70 ei := getExtensionFieldInfo(x.Type()) 71 if ei.funcs.isInit == nil { 72 continue 73 } 74 v := x.Value() 75 if !v.IsValid() { 76 continue 77 } 78 if err := ei.funcs.isInit(v); err != nil { 79 return err 80 } 81 } 82 return nil 83 } 84 85 var ( 86 needsInitCheckMu sync.Mutex 87 needsInitCheckMap sync.Map 88 ) 89 90 // needsInitCheck reports whether a message needs to be checked for partial initialization. 91 // 92 // It returns true if the message transitively includes any required or extension fields. 93 func needsInitCheck(md protoreflect.MessageDescriptor) bool { 94 if v, ok := needsInitCheckMap.Load(md); ok { 95 if has, ok := v.(bool); ok { 96 return has 97 } 98 } 99 needsInitCheckMu.Lock() 100 defer needsInitCheckMu.Unlock() 101 return needsInitCheckLocked(md) 102 } 103 104 func needsInitCheckLocked(md protoreflect.MessageDescriptor) (has bool) { 105 if v, ok := needsInitCheckMap.Load(md); ok { 106 // If has is true, we've previously determined that this message 107 // needs init checks. 108 // 109 // If has is false, we've previously determined that it can never 110 // be uninitialized. 111 // 112 // If has is not a bool, we've just encountered a cycle in the 113 // message graph. In this case, it is safe to return false: If 114 // the message does have required fields, we'll detect them later 115 // in the graph traversal. 116 has, ok := v.(bool) 117 return ok && has 118 } 119 needsInitCheckMap.Store(md, struct{}{}) // avoid cycles while descending into this message 120 defer func() { 121 needsInitCheckMap.Store(md, has) 122 }() 123 if md.RequiredNumbers().Len() > 0 { 124 return true 125 } 126 if md.ExtensionRanges().Len() > 0 { 127 return true 128 } 129 for i := 0; i < md.Fields().Len(); i++ { 130 fd := md.Fields().Get(i) 131 // Map keys are never messages, so just consider the map value. 132 if fd.IsMap() { 133 fd = fd.MapValue() 134 } 135 fmd := fd.Message() 136 if fmd != nil && needsInitCheckLocked(fmd) { 137 return true 138 } 139 } 140 return false 141 }