gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

query.go (10645B)


      1 package runtime
      2 
      3 import (
      4 	"encoding/base64"
      5 	"errors"
      6 	"fmt"
      7 	"net/url"
      8 	"regexp"
      9 	"strconv"
     10 	"strings"
     11 	"time"
     12 
     13 	"github.com/grpc-ecosystem/grpc-gateway/v2/utilities"
     14 	"google.golang.org/genproto/protobuf/field_mask"
     15 	"google.golang.org/grpc/grpclog"
     16 	"google.golang.org/protobuf/proto"
     17 	"google.golang.org/protobuf/reflect/protoreflect"
     18 	"google.golang.org/protobuf/reflect/protoregistry"
     19 	"google.golang.org/protobuf/types/known/durationpb"
     20 	"google.golang.org/protobuf/types/known/timestamppb"
     21 	"google.golang.org/protobuf/types/known/wrapperspb"
     22 )
     23 
     24 var valuesKeyRegexp = regexp.MustCompile(`^(.*)\[(.*)\]$`)
     25 
     26 var currentQueryParser QueryParameterParser = &defaultQueryParser{}
     27 
     28 // QueryParameterParser defines interface for all query parameter parsers
     29 type QueryParameterParser interface {
     30 	Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error
     31 }
     32 
     33 // PopulateQueryParameters parses query parameters
     34 // into "msg" using current query parser
     35 func PopulateQueryParameters(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
     36 	return currentQueryParser.Parse(msg, values, filter)
     37 }
     38 
     39 type defaultQueryParser struct{}
     40 
     41 // Parse populates "values" into "msg".
     42 // A value is ignored if its key starts with one of the elements in "filter".
     43 func (*defaultQueryParser) Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
     44 	for key, values := range values {
     45 		match := valuesKeyRegexp.FindStringSubmatch(key)
     46 		if len(match) == 3 {
     47 			key = match[1]
     48 			values = append([]string{match[2]}, values...)
     49 		}
     50 		fieldPath := strings.Split(key, ".")
     51 		if filter.HasCommonPrefix(fieldPath) {
     52 			continue
     53 		}
     54 		if err := populateFieldValueFromPath(msg.ProtoReflect(), fieldPath, values); err != nil {
     55 			return err
     56 		}
     57 	}
     58 	return nil
     59 }
     60 
     61 // PopulateFieldFromPath sets a value in a nested Protobuf structure.
     62 func PopulateFieldFromPath(msg proto.Message, fieldPathString string, value string) error {
     63 	fieldPath := strings.Split(fieldPathString, ".")
     64 	return populateFieldValueFromPath(msg.ProtoReflect(), fieldPath, []string{value})
     65 }
     66 
     67 func populateFieldValueFromPath(msgValue protoreflect.Message, fieldPath []string, values []string) error {
     68 	if len(fieldPath) < 1 {
     69 		return errors.New("no field path")
     70 	}
     71 	if len(values) < 1 {
     72 		return errors.New("no value provided")
     73 	}
     74 
     75 	var fieldDescriptor protoreflect.FieldDescriptor
     76 	for i, fieldName := range fieldPath {
     77 		fields := msgValue.Descriptor().Fields()
     78 
     79 		// Get field by name
     80 		fieldDescriptor = fields.ByName(protoreflect.Name(fieldName))
     81 		if fieldDescriptor == nil {
     82 			fieldDescriptor = fields.ByJSONName(fieldName)
     83 			if fieldDescriptor == nil {
     84 				// We're not returning an error here because this could just be
     85 				// an extra query parameter that isn't part of the request.
     86 				grpclog.Infof("field not found in %q: %q", msgValue.Descriptor().FullName(), strings.Join(fieldPath, "."))
     87 				return nil
     88 			}
     89 		}
     90 
     91 		// If this is the last element, we're done
     92 		if i == len(fieldPath)-1 {
     93 			break
     94 		}
     95 
     96 		// Only singular message fields are allowed
     97 		if fieldDescriptor.Message() == nil || fieldDescriptor.Cardinality() == protoreflect.Repeated {
     98 			return fmt.Errorf("invalid path: %q is not a message", fieldName)
     99 		}
    100 
    101 		// Get the nested message
    102 		msgValue = msgValue.Mutable(fieldDescriptor).Message()
    103 	}
    104 
    105 	// Check if oneof already set
    106 	if of := fieldDescriptor.ContainingOneof(); of != nil {
    107 		if f := msgValue.WhichOneof(of); f != nil {
    108 			return fmt.Errorf("field already set for oneof %q", of.FullName().Name())
    109 		}
    110 	}
    111 
    112 	switch {
    113 	case fieldDescriptor.IsList():
    114 		return populateRepeatedField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).List(), values)
    115 	case fieldDescriptor.IsMap():
    116 		return populateMapField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).Map(), values)
    117 	}
    118 
    119 	if len(values) > 1 {
    120 		return fmt.Errorf("too many values for field %q: %s", fieldDescriptor.FullName().Name(), strings.Join(values, ", "))
    121 	}
    122 
    123 	return populateField(fieldDescriptor, msgValue, values[0])
    124 }
    125 
    126 func populateField(fieldDescriptor protoreflect.FieldDescriptor, msgValue protoreflect.Message, value string) error {
    127 	v, err := parseField(fieldDescriptor, value)
    128 	if err != nil {
    129 		return fmt.Errorf("parsing field %q: %w", fieldDescriptor.FullName().Name(), err)
    130 	}
    131 
    132 	msgValue.Set(fieldDescriptor, v)
    133 	return nil
    134 }
    135 
    136 func populateRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List, values []string) error {
    137 	for _, value := range values {
    138 		v, err := parseField(fieldDescriptor, value)
    139 		if err != nil {
    140 			return fmt.Errorf("parsing list %q: %w", fieldDescriptor.FullName().Name(), err)
    141 		}
    142 		list.Append(v)
    143 	}
    144 
    145 	return nil
    146 }
    147 
    148 func populateMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflect.Map, values []string) error {
    149 	if len(values) != 2 {
    150 		return fmt.Errorf("more than one value provided for key %q in map %q", values[0], fieldDescriptor.FullName())
    151 	}
    152 
    153 	key, err := parseField(fieldDescriptor.MapKey(), values[0])
    154 	if err != nil {
    155 		return fmt.Errorf("parsing map key %q: %w", fieldDescriptor.FullName().Name(), err)
    156 	}
    157 
    158 	value, err := parseField(fieldDescriptor.MapValue(), values[1])
    159 	if err != nil {
    160 		return fmt.Errorf("parsing map value %q: %w", fieldDescriptor.FullName().Name(), err)
    161 	}
    162 
    163 	mp.Set(key.MapKey(), value)
    164 
    165 	return nil
    166 }
    167 
    168 func parseField(fieldDescriptor protoreflect.FieldDescriptor, value string) (protoreflect.Value, error) {
    169 	switch fieldDescriptor.Kind() {
    170 	case protoreflect.BoolKind:
    171 		v, err := strconv.ParseBool(value)
    172 		if err != nil {
    173 			return protoreflect.Value{}, err
    174 		}
    175 		return protoreflect.ValueOfBool(v), nil
    176 	case protoreflect.EnumKind:
    177 		enum, err := protoregistry.GlobalTypes.FindEnumByName(fieldDescriptor.Enum().FullName())
    178 		switch {
    179 		case errors.Is(err, protoregistry.NotFound):
    180 			return protoreflect.Value{}, fmt.Errorf("enum %q is not registered", fieldDescriptor.Enum().FullName())
    181 		case err != nil:
    182 			return protoreflect.Value{}, fmt.Errorf("failed to look up enum: %w", err)
    183 		}
    184 		// Look for enum by name
    185 		v := enum.Descriptor().Values().ByName(protoreflect.Name(value))
    186 		if v == nil {
    187 			i, err := strconv.Atoi(value)
    188 			if err != nil {
    189 				return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
    190 			}
    191 			// Look for enum by number
    192 			v = enum.Descriptor().Values().ByNumber(protoreflect.EnumNumber(i))
    193 			if v == nil {
    194 				return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
    195 			}
    196 		}
    197 		return protoreflect.ValueOfEnum(v.Number()), nil
    198 	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
    199 		v, err := strconv.ParseInt(value, 10, 32)
    200 		if err != nil {
    201 			return protoreflect.Value{}, err
    202 		}
    203 		return protoreflect.ValueOfInt32(int32(v)), nil
    204 	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
    205 		v, err := strconv.ParseInt(value, 10, 64)
    206 		if err != nil {
    207 			return protoreflect.Value{}, err
    208 		}
    209 		return protoreflect.ValueOfInt64(v), nil
    210 	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
    211 		v, err := strconv.ParseUint(value, 10, 32)
    212 		if err != nil {
    213 			return protoreflect.Value{}, err
    214 		}
    215 		return protoreflect.ValueOfUint32(uint32(v)), nil
    216 	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
    217 		v, err := strconv.ParseUint(value, 10, 64)
    218 		if err != nil {
    219 			return protoreflect.Value{}, err
    220 		}
    221 		return protoreflect.ValueOfUint64(v), nil
    222 	case protoreflect.FloatKind:
    223 		v, err := strconv.ParseFloat(value, 32)
    224 		if err != nil {
    225 			return protoreflect.Value{}, err
    226 		}
    227 		return protoreflect.ValueOfFloat32(float32(v)), nil
    228 	case protoreflect.DoubleKind:
    229 		v, err := strconv.ParseFloat(value, 64)
    230 		if err != nil {
    231 			return protoreflect.Value{}, err
    232 		}
    233 		return protoreflect.ValueOfFloat64(v), nil
    234 	case protoreflect.StringKind:
    235 		return protoreflect.ValueOfString(value), nil
    236 	case protoreflect.BytesKind:
    237 		v, err := base64.URLEncoding.DecodeString(value)
    238 		if err != nil {
    239 			return protoreflect.Value{}, err
    240 		}
    241 		return protoreflect.ValueOfBytes(v), nil
    242 	case protoreflect.MessageKind, protoreflect.GroupKind:
    243 		return parseMessage(fieldDescriptor.Message(), value)
    244 	default:
    245 		panic(fmt.Sprintf("unknown field kind: %v", fieldDescriptor.Kind()))
    246 	}
    247 }
    248 
    249 func parseMessage(msgDescriptor protoreflect.MessageDescriptor, value string) (protoreflect.Value, error) {
    250 	var msg proto.Message
    251 	switch msgDescriptor.FullName() {
    252 	case "google.protobuf.Timestamp":
    253 		if value == "null" {
    254 			break
    255 		}
    256 		t, err := time.Parse(time.RFC3339Nano, value)
    257 		if err != nil {
    258 			return protoreflect.Value{}, err
    259 		}
    260 		msg = timestamppb.New(t)
    261 	case "google.protobuf.Duration":
    262 		if value == "null" {
    263 			break
    264 		}
    265 		d, err := time.ParseDuration(value)
    266 		if err != nil {
    267 			return protoreflect.Value{}, err
    268 		}
    269 		msg = durationpb.New(d)
    270 	case "google.protobuf.DoubleValue":
    271 		v, err := strconv.ParseFloat(value, 64)
    272 		if err != nil {
    273 			return protoreflect.Value{}, err
    274 		}
    275 		msg = &wrapperspb.DoubleValue{Value: v}
    276 	case "google.protobuf.FloatValue":
    277 		v, err := strconv.ParseFloat(value, 32)
    278 		if err != nil {
    279 			return protoreflect.Value{}, err
    280 		}
    281 		msg = &wrapperspb.FloatValue{Value: float32(v)}
    282 	case "google.protobuf.Int64Value":
    283 		v, err := strconv.ParseInt(value, 10, 64)
    284 		if err != nil {
    285 			return protoreflect.Value{}, err
    286 		}
    287 		msg = &wrapperspb.Int64Value{Value: v}
    288 	case "google.protobuf.Int32Value":
    289 		v, err := strconv.ParseInt(value, 10, 32)
    290 		if err != nil {
    291 			return protoreflect.Value{}, err
    292 		}
    293 		msg = &wrapperspb.Int32Value{Value: int32(v)}
    294 	case "google.protobuf.UInt64Value":
    295 		v, err := strconv.ParseUint(value, 10, 64)
    296 		if err != nil {
    297 			return protoreflect.Value{}, err
    298 		}
    299 		msg = &wrapperspb.UInt64Value{Value: v}
    300 	case "google.protobuf.UInt32Value":
    301 		v, err := strconv.ParseUint(value, 10, 32)
    302 		if err != nil {
    303 			return protoreflect.Value{}, err
    304 		}
    305 		msg = &wrapperspb.UInt32Value{Value: uint32(v)}
    306 	case "google.protobuf.BoolValue":
    307 		v, err := strconv.ParseBool(value)
    308 		if err != nil {
    309 			return protoreflect.Value{}, err
    310 		}
    311 		msg = &wrapperspb.BoolValue{Value: v}
    312 	case "google.protobuf.StringValue":
    313 		msg = &wrapperspb.StringValue{Value: value}
    314 	case "google.protobuf.BytesValue":
    315 		v, err := base64.URLEncoding.DecodeString(value)
    316 		if err != nil {
    317 			return protoreflect.Value{}, err
    318 		}
    319 		msg = &wrapperspb.BytesValue{Value: v}
    320 	case "google.protobuf.FieldMask":
    321 		fm := &field_mask.FieldMask{}
    322 		fm.Paths = append(fm.Paths, strings.Split(value, ",")...)
    323 		msg = fm
    324 	default:
    325 		return protoreflect.Value{}, fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName()))
    326 	}
    327 
    328 	return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil
    329 }