query.go (10645B)
1 package runtime 2 3 import ( 4 "encoding/base64" 5 "errors" 6 "fmt" 7 "net/url" 8 "regexp" 9 "strconv" 10 "strings" 11 "time" 12 13 "github.com/grpc-ecosystem/grpc-gateway/v2/utilities" 14 "google.golang.org/genproto/protobuf/field_mask" 15 "google.golang.org/grpc/grpclog" 16 "google.golang.org/protobuf/proto" 17 "google.golang.org/protobuf/reflect/protoreflect" 18 "google.golang.org/protobuf/reflect/protoregistry" 19 "google.golang.org/protobuf/types/known/durationpb" 20 "google.golang.org/protobuf/types/known/timestamppb" 21 "google.golang.org/protobuf/types/known/wrapperspb" 22 ) 23 24 var valuesKeyRegexp = regexp.MustCompile(`^(.*)\[(.*)\]$`) 25 26 var currentQueryParser QueryParameterParser = &defaultQueryParser{} 27 28 // QueryParameterParser defines interface for all query parameter parsers 29 type QueryParameterParser interface { 30 Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error 31 } 32 33 // PopulateQueryParameters parses query parameters 34 // into "msg" using current query parser 35 func PopulateQueryParameters(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error { 36 return currentQueryParser.Parse(msg, values, filter) 37 } 38 39 type defaultQueryParser struct{} 40 41 // Parse populates "values" into "msg". 42 // A value is ignored if its key starts with one of the elements in "filter". 43 func (*defaultQueryParser) Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error { 44 for key, values := range values { 45 match := valuesKeyRegexp.FindStringSubmatch(key) 46 if len(match) == 3 { 47 key = match[1] 48 values = append([]string{match[2]}, values...) 49 } 50 fieldPath := strings.Split(key, ".") 51 if filter.HasCommonPrefix(fieldPath) { 52 continue 53 } 54 if err := populateFieldValueFromPath(msg.ProtoReflect(), fieldPath, values); err != nil { 55 return err 56 } 57 } 58 return nil 59 } 60 61 // PopulateFieldFromPath sets a value in a nested Protobuf structure. 62 func PopulateFieldFromPath(msg proto.Message, fieldPathString string, value string) error { 63 fieldPath := strings.Split(fieldPathString, ".") 64 return populateFieldValueFromPath(msg.ProtoReflect(), fieldPath, []string{value}) 65 } 66 67 func populateFieldValueFromPath(msgValue protoreflect.Message, fieldPath []string, values []string) error { 68 if len(fieldPath) < 1 { 69 return errors.New("no field path") 70 } 71 if len(values) < 1 { 72 return errors.New("no value provided") 73 } 74 75 var fieldDescriptor protoreflect.FieldDescriptor 76 for i, fieldName := range fieldPath { 77 fields := msgValue.Descriptor().Fields() 78 79 // Get field by name 80 fieldDescriptor = fields.ByName(protoreflect.Name(fieldName)) 81 if fieldDescriptor == nil { 82 fieldDescriptor = fields.ByJSONName(fieldName) 83 if fieldDescriptor == nil { 84 // We're not returning an error here because this could just be 85 // an extra query parameter that isn't part of the request. 86 grpclog.Infof("field not found in %q: %q", msgValue.Descriptor().FullName(), strings.Join(fieldPath, ".")) 87 return nil 88 } 89 } 90 91 // If this is the last element, we're done 92 if i == len(fieldPath)-1 { 93 break 94 } 95 96 // Only singular message fields are allowed 97 if fieldDescriptor.Message() == nil || fieldDescriptor.Cardinality() == protoreflect.Repeated { 98 return fmt.Errorf("invalid path: %q is not a message", fieldName) 99 } 100 101 // Get the nested message 102 msgValue = msgValue.Mutable(fieldDescriptor).Message() 103 } 104 105 // Check if oneof already set 106 if of := fieldDescriptor.ContainingOneof(); of != nil { 107 if f := msgValue.WhichOneof(of); f != nil { 108 return fmt.Errorf("field already set for oneof %q", of.FullName().Name()) 109 } 110 } 111 112 switch { 113 case fieldDescriptor.IsList(): 114 return populateRepeatedField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).List(), values) 115 case fieldDescriptor.IsMap(): 116 return populateMapField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).Map(), values) 117 } 118 119 if len(values) > 1 { 120 return fmt.Errorf("too many values for field %q: %s", fieldDescriptor.FullName().Name(), strings.Join(values, ", ")) 121 } 122 123 return populateField(fieldDescriptor, msgValue, values[0]) 124 } 125 126 func populateField(fieldDescriptor protoreflect.FieldDescriptor, msgValue protoreflect.Message, value string) error { 127 v, err := parseField(fieldDescriptor, value) 128 if err != nil { 129 return fmt.Errorf("parsing field %q: %w", fieldDescriptor.FullName().Name(), err) 130 } 131 132 msgValue.Set(fieldDescriptor, v) 133 return nil 134 } 135 136 func populateRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List, values []string) error { 137 for _, value := range values { 138 v, err := parseField(fieldDescriptor, value) 139 if err != nil { 140 return fmt.Errorf("parsing list %q: %w", fieldDescriptor.FullName().Name(), err) 141 } 142 list.Append(v) 143 } 144 145 return nil 146 } 147 148 func populateMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflect.Map, values []string) error { 149 if len(values) != 2 { 150 return fmt.Errorf("more than one value provided for key %q in map %q", values[0], fieldDescriptor.FullName()) 151 } 152 153 key, err := parseField(fieldDescriptor.MapKey(), values[0]) 154 if err != nil { 155 return fmt.Errorf("parsing map key %q: %w", fieldDescriptor.FullName().Name(), err) 156 } 157 158 value, err := parseField(fieldDescriptor.MapValue(), values[1]) 159 if err != nil { 160 return fmt.Errorf("parsing map value %q: %w", fieldDescriptor.FullName().Name(), err) 161 } 162 163 mp.Set(key.MapKey(), value) 164 165 return nil 166 } 167 168 func parseField(fieldDescriptor protoreflect.FieldDescriptor, value string) (protoreflect.Value, error) { 169 switch fieldDescriptor.Kind() { 170 case protoreflect.BoolKind: 171 v, err := strconv.ParseBool(value) 172 if err != nil { 173 return protoreflect.Value{}, err 174 } 175 return protoreflect.ValueOfBool(v), nil 176 case protoreflect.EnumKind: 177 enum, err := protoregistry.GlobalTypes.FindEnumByName(fieldDescriptor.Enum().FullName()) 178 switch { 179 case errors.Is(err, protoregistry.NotFound): 180 return protoreflect.Value{}, fmt.Errorf("enum %q is not registered", fieldDescriptor.Enum().FullName()) 181 case err != nil: 182 return protoreflect.Value{}, fmt.Errorf("failed to look up enum: %w", err) 183 } 184 // Look for enum by name 185 v := enum.Descriptor().Values().ByName(protoreflect.Name(value)) 186 if v == nil { 187 i, err := strconv.Atoi(value) 188 if err != nil { 189 return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value) 190 } 191 // Look for enum by number 192 v = enum.Descriptor().Values().ByNumber(protoreflect.EnumNumber(i)) 193 if v == nil { 194 return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value) 195 } 196 } 197 return protoreflect.ValueOfEnum(v.Number()), nil 198 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: 199 v, err := strconv.ParseInt(value, 10, 32) 200 if err != nil { 201 return protoreflect.Value{}, err 202 } 203 return protoreflect.ValueOfInt32(int32(v)), nil 204 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: 205 v, err := strconv.ParseInt(value, 10, 64) 206 if err != nil { 207 return protoreflect.Value{}, err 208 } 209 return protoreflect.ValueOfInt64(v), nil 210 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: 211 v, err := strconv.ParseUint(value, 10, 32) 212 if err != nil { 213 return protoreflect.Value{}, err 214 } 215 return protoreflect.ValueOfUint32(uint32(v)), nil 216 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: 217 v, err := strconv.ParseUint(value, 10, 64) 218 if err != nil { 219 return protoreflect.Value{}, err 220 } 221 return protoreflect.ValueOfUint64(v), nil 222 case protoreflect.FloatKind: 223 v, err := strconv.ParseFloat(value, 32) 224 if err != nil { 225 return protoreflect.Value{}, err 226 } 227 return protoreflect.ValueOfFloat32(float32(v)), nil 228 case protoreflect.DoubleKind: 229 v, err := strconv.ParseFloat(value, 64) 230 if err != nil { 231 return protoreflect.Value{}, err 232 } 233 return protoreflect.ValueOfFloat64(v), nil 234 case protoreflect.StringKind: 235 return protoreflect.ValueOfString(value), nil 236 case protoreflect.BytesKind: 237 v, err := base64.URLEncoding.DecodeString(value) 238 if err != nil { 239 return protoreflect.Value{}, err 240 } 241 return protoreflect.ValueOfBytes(v), nil 242 case protoreflect.MessageKind, protoreflect.GroupKind: 243 return parseMessage(fieldDescriptor.Message(), value) 244 default: 245 panic(fmt.Sprintf("unknown field kind: %v", fieldDescriptor.Kind())) 246 } 247 } 248 249 func parseMessage(msgDescriptor protoreflect.MessageDescriptor, value string) (protoreflect.Value, error) { 250 var msg proto.Message 251 switch msgDescriptor.FullName() { 252 case "google.protobuf.Timestamp": 253 if value == "null" { 254 break 255 } 256 t, err := time.Parse(time.RFC3339Nano, value) 257 if err != nil { 258 return protoreflect.Value{}, err 259 } 260 msg = timestamppb.New(t) 261 case "google.protobuf.Duration": 262 if value == "null" { 263 break 264 } 265 d, err := time.ParseDuration(value) 266 if err != nil { 267 return protoreflect.Value{}, err 268 } 269 msg = durationpb.New(d) 270 case "google.protobuf.DoubleValue": 271 v, err := strconv.ParseFloat(value, 64) 272 if err != nil { 273 return protoreflect.Value{}, err 274 } 275 msg = &wrapperspb.DoubleValue{Value: v} 276 case "google.protobuf.FloatValue": 277 v, err := strconv.ParseFloat(value, 32) 278 if err != nil { 279 return protoreflect.Value{}, err 280 } 281 msg = &wrapperspb.FloatValue{Value: float32(v)} 282 case "google.protobuf.Int64Value": 283 v, err := strconv.ParseInt(value, 10, 64) 284 if err != nil { 285 return protoreflect.Value{}, err 286 } 287 msg = &wrapperspb.Int64Value{Value: v} 288 case "google.protobuf.Int32Value": 289 v, err := strconv.ParseInt(value, 10, 32) 290 if err != nil { 291 return protoreflect.Value{}, err 292 } 293 msg = &wrapperspb.Int32Value{Value: int32(v)} 294 case "google.protobuf.UInt64Value": 295 v, err := strconv.ParseUint(value, 10, 64) 296 if err != nil { 297 return protoreflect.Value{}, err 298 } 299 msg = &wrapperspb.UInt64Value{Value: v} 300 case "google.protobuf.UInt32Value": 301 v, err := strconv.ParseUint(value, 10, 32) 302 if err != nil { 303 return protoreflect.Value{}, err 304 } 305 msg = &wrapperspb.UInt32Value{Value: uint32(v)} 306 case "google.protobuf.BoolValue": 307 v, err := strconv.ParseBool(value) 308 if err != nil { 309 return protoreflect.Value{}, err 310 } 311 msg = &wrapperspb.BoolValue{Value: v} 312 case "google.protobuf.StringValue": 313 msg = &wrapperspb.StringValue{Value: value} 314 case "google.protobuf.BytesValue": 315 v, err := base64.URLEncoding.DecodeString(value) 316 if err != nil { 317 return protoreflect.Value{}, err 318 } 319 msg = &wrapperspb.BytesValue{Value: v} 320 case "google.protobuf.FieldMask": 321 fm := &field_mask.FieldMask{} 322 fm.Paths = append(fm.Paths, strings.Split(value, ",")...) 323 msg = fm 324 default: 325 return protoreflect.Value{}, fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName())) 326 } 327 328 return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil 329 }