gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

mux.go (13397B)


      1 package runtime
      2 
      3 import (
      4 	"context"
      5 	"errors"
      6 	"fmt"
      7 	"net/http"
      8 	"net/textproto"
      9 	"strings"
     10 
     11 	"github.com/grpc-ecosystem/grpc-gateway/v2/internal/httprule"
     12 	"google.golang.org/grpc/codes"
     13 	"google.golang.org/grpc/metadata"
     14 	"google.golang.org/grpc/status"
     15 	"google.golang.org/protobuf/proto"
     16 )
     17 
     18 // UnescapingMode defines the behavior of ServeMux when unescaping path parameters.
     19 type UnescapingMode int
     20 
     21 const (
     22 	// UnescapingModeLegacy is the default V2 behavior, which escapes the entire
     23 	// path string before doing any routing.
     24 	UnescapingModeLegacy UnescapingMode = iota
     25 
     26 	// EscapingTypeExceptReserved unescapes all path parameters except RFC 6570
     27 	// reserved characters.
     28 	UnescapingModeAllExceptReserved
     29 
     30 	// EscapingTypeExceptSlash unescapes URL path parameters except path
     31 	// seperators, which will be left as "%2F".
     32 	UnescapingModeAllExceptSlash
     33 
     34 	// URL path parameters will be fully decoded.
     35 	UnescapingModeAllCharacters
     36 
     37 	// UnescapingModeDefault is the default escaping type.
     38 	// TODO(v3): default this to UnescapingModeAllExceptReserved per grpc-httpjson-transcoding's
     39 	// reference implementation
     40 	UnescapingModeDefault = UnescapingModeLegacy
     41 )
     42 
     43 // A HandlerFunc handles a specific pair of path pattern and HTTP method.
     44 type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)
     45 
     46 // ServeMux is a request multiplexer for grpc-gateway.
     47 // It matches http requests to patterns and invokes the corresponding handler.
     48 type ServeMux struct {
     49 	// handlers maps HTTP method to a list of handlers.
     50 	handlers                  map[string][]handler
     51 	forwardResponseOptions    []func(context.Context, http.ResponseWriter, proto.Message) error
     52 	marshalers                marshalerRegistry
     53 	incomingHeaderMatcher     HeaderMatcherFunc
     54 	outgoingHeaderMatcher     HeaderMatcherFunc
     55 	metadataAnnotators        []func(context.Context, *http.Request) metadata.MD
     56 	errorHandler              ErrorHandlerFunc
     57 	streamErrorHandler        StreamErrorHandlerFunc
     58 	routingErrorHandler       RoutingErrorHandlerFunc
     59 	disablePathLengthFallback bool
     60 	unescapingMode            UnescapingMode
     61 }
     62 
     63 // ServeMuxOption is an option that can be given to a ServeMux on construction.
     64 type ServeMuxOption func(*ServeMux)
     65 
     66 // WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption.
     67 //
     68 // forwardResponseOption is an option that will be called on the relevant context.Context,
     69 // http.ResponseWriter, and proto.Message before every forwarded response.
     70 //
     71 // The message may be nil in the case where just a header is being sent.
     72 func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption {
     73 	return func(serveMux *ServeMux) {
     74 		serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption)
     75 	}
     76 }
     77 
     78 // WithEscapingType sets the escaping type. See the definitions of UnescapingMode
     79 // for more information.
     80 func WithUnescapingMode(mode UnescapingMode) ServeMuxOption {
     81 	return func(serveMux *ServeMux) {
     82 		serveMux.unescapingMode = mode
     83 	}
     84 }
     85 
     86 // SetQueryParameterParser sets the query parameter parser, used to populate message from query parameters.
     87 // Configuring this will mean the generated OpenAPI output is no longer correct, and it should be
     88 // done with careful consideration.
     89 func SetQueryParameterParser(queryParameterParser QueryParameterParser) ServeMuxOption {
     90 	return func(serveMux *ServeMux) {
     91 		currentQueryParser = queryParameterParser
     92 	}
     93 }
     94 
     95 // HeaderMatcherFunc checks whether a header key should be forwarded to/from gRPC context.
     96 type HeaderMatcherFunc func(string) (string, bool)
     97 
     98 // DefaultHeaderMatcher is used to pass http request headers to/from gRPC context. This adds permanent HTTP header
     99 // keys (as specified by the IANA) to gRPC context with grpcgateway- prefix. HTTP headers that start with
    100 // 'Grpc-Metadata-' are mapped to gRPC metadata after removing prefix 'Grpc-Metadata-'.
    101 func DefaultHeaderMatcher(key string) (string, bool) {
    102 	key = textproto.CanonicalMIMEHeaderKey(key)
    103 	if isPermanentHTTPHeader(key) {
    104 		return MetadataPrefix + key, true
    105 	} else if strings.HasPrefix(key, MetadataHeaderPrefix) {
    106 		return key[len(MetadataHeaderPrefix):], true
    107 	}
    108 	return "", false
    109 }
    110 
    111 // WithIncomingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for incoming request to gateway.
    112 //
    113 // This matcher will be called with each header in http.Request. If matcher returns true, that header will be
    114 // passed to gRPC context. To transform the header before passing to gRPC context, matcher should return modified header.
    115 func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
    116 	return func(mux *ServeMux) {
    117 		mux.incomingHeaderMatcher = fn
    118 	}
    119 }
    120 
    121 // WithOutgoingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway.
    122 //
    123 // This matcher will be called with each header in response header metadata. If matcher returns true, that header will be
    124 // passed to http response returned from gateway. To transform the header before passing to response,
    125 // matcher should return modified header.
    126 func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
    127 	return func(mux *ServeMux) {
    128 		mux.outgoingHeaderMatcher = fn
    129 	}
    130 }
    131 
    132 // WithMetadata returns a ServeMuxOption for passing metadata to a gRPC context.
    133 //
    134 // This can be used by services that need to read from http.Request and modify gRPC context. A common use case
    135 // is reading token from cookie and adding it in gRPC context.
    136 func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) ServeMuxOption {
    137 	return func(serveMux *ServeMux) {
    138 		serveMux.metadataAnnotators = append(serveMux.metadataAnnotators, annotator)
    139 	}
    140 }
    141 
    142 // WithErrorHandler returns a ServeMuxOption for configuring a custom error handler.
    143 //
    144 // This can be used to configure a custom error response.
    145 func WithErrorHandler(fn ErrorHandlerFunc) ServeMuxOption {
    146 	return func(serveMux *ServeMux) {
    147 		serveMux.errorHandler = fn
    148 	}
    149 }
    150 
    151 // WithStreamErrorHandler returns a ServeMuxOption that will use the given custom stream
    152 // error handler, which allows for customizing the error trailer for server-streaming
    153 // calls.
    154 //
    155 // For stream errors that occur before any response has been written, the mux's
    156 // ErrorHandler will be invoked. However, once data has been written, the errors must
    157 // be handled differently: they must be included in the response body. The response body's
    158 // final message will include the error details returned by the stream error handler.
    159 func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption {
    160 	return func(serveMux *ServeMux) {
    161 		serveMux.streamErrorHandler = fn
    162 	}
    163 }
    164 
    165 // WithRoutingErrorHandler returns a ServeMuxOption for configuring a custom error handler to  handle http routing errors.
    166 //
    167 // Method called for errors which can happen before gRPC route selected or executed.
    168 // The following error codes: StatusMethodNotAllowed StatusNotFound StatusBadRequest
    169 func WithRoutingErrorHandler(fn RoutingErrorHandlerFunc) ServeMuxOption {
    170 	return func(serveMux *ServeMux) {
    171 		serveMux.routingErrorHandler = fn
    172 	}
    173 }
    174 
    175 // WithDisablePathLengthFallback returns a ServeMuxOption for disable path length fallback.
    176 func WithDisablePathLengthFallback() ServeMuxOption {
    177 	return func(serveMux *ServeMux) {
    178 		serveMux.disablePathLengthFallback = true
    179 	}
    180 }
    181 
    182 // NewServeMux returns a new ServeMux whose internal mapping is empty.
    183 func NewServeMux(opts ...ServeMuxOption) *ServeMux {
    184 	serveMux := &ServeMux{
    185 		handlers:               make(map[string][]handler),
    186 		forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
    187 		marshalers:             makeMarshalerMIMERegistry(),
    188 		errorHandler:           DefaultHTTPErrorHandler,
    189 		streamErrorHandler:     DefaultStreamErrorHandler,
    190 		routingErrorHandler:    DefaultRoutingErrorHandler,
    191 		unescapingMode:         UnescapingModeDefault,
    192 	}
    193 
    194 	for _, opt := range opts {
    195 		opt(serveMux)
    196 	}
    197 
    198 	if serveMux.incomingHeaderMatcher == nil {
    199 		serveMux.incomingHeaderMatcher = DefaultHeaderMatcher
    200 	}
    201 
    202 	if serveMux.outgoingHeaderMatcher == nil {
    203 		serveMux.outgoingHeaderMatcher = func(key string) (string, bool) {
    204 			return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true
    205 		}
    206 	}
    207 
    208 	return serveMux
    209 }
    210 
    211 // Handle associates "h" to the pair of HTTP method and path pattern.
    212 func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
    213 	s.handlers[meth] = append([]handler{{pat: pat, h: h}}, s.handlers[meth]...)
    214 }
    215 
    216 // HandlePath allows users to configure custom path handlers.
    217 // refer: https://grpc-ecosystem.github.io/grpc-gateway/docs/operations/inject_router/
    218 func (s *ServeMux) HandlePath(meth string, pathPattern string, h HandlerFunc) error {
    219 	compiler, err := httprule.Parse(pathPattern)
    220 	if err != nil {
    221 		return fmt.Errorf("parsing path pattern: %w", err)
    222 	}
    223 	tp := compiler.Compile()
    224 	pattern, err := NewPattern(tp.Version, tp.OpCodes, tp.Pool, tp.Verb)
    225 	if err != nil {
    226 		return fmt.Errorf("creating new pattern: %w", err)
    227 	}
    228 	s.Handle(meth, pattern, h)
    229 	return nil
    230 }
    231 
    232 // ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.Path.
    233 func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    234 	ctx := r.Context()
    235 
    236 	path := r.URL.Path
    237 	if !strings.HasPrefix(path, "/") {
    238 		_, outboundMarshaler := MarshalerForRequest(s, r)
    239 		s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusBadRequest)
    240 		return
    241 	}
    242 
    243 	// TODO(v3): remove UnescapingModeLegacy
    244 	if s.unescapingMode != UnescapingModeLegacy && r.URL.RawPath != "" {
    245 		path = r.URL.RawPath
    246 	}
    247 
    248 	components := strings.Split(path[1:], "/")
    249 
    250 	if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) {
    251 		r.Method = strings.ToUpper(override)
    252 		if err := r.ParseForm(); err != nil {
    253 			_, outboundMarshaler := MarshalerForRequest(s, r)
    254 			sterr := status.Error(codes.InvalidArgument, err.Error())
    255 			s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr)
    256 			return
    257 		}
    258 	}
    259 
    260 	// Verb out here is to memoize for the fallback case below
    261 	var verb string
    262 
    263 	for _, h := range s.handlers[r.Method] {
    264 		// If the pattern has a verb, explicitly look for a suffix in the last
    265 		// component that matches a colon plus the verb. This allows us to
    266 		// handle some cases that otherwise can't be correctly handled by the
    267 		// former LastIndex case, such as when the verb literal itself contains
    268 		// a colon. This should work for all cases that have run through the
    269 		// parser because we know what verb we're looking for, however, there
    270 		// are still some cases that the parser itself cannot disambiguate. See
    271 		// the comment there if interested.
    272 		patVerb := h.pat.Verb()
    273 		l := len(components)
    274 		lastComponent := components[l-1]
    275 		var idx int = -1
    276 		if patVerb != "" && strings.HasSuffix(lastComponent, ":"+patVerb) {
    277 			idx = len(lastComponent) - len(patVerb) - 1
    278 		}
    279 		if idx == 0 {
    280 			_, outboundMarshaler := MarshalerForRequest(s, r)
    281 			s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusNotFound)
    282 			return
    283 		}
    284 		if idx > 0 {
    285 			components[l-1], verb = lastComponent[:idx], lastComponent[idx+1:]
    286 		}
    287 
    288 		pathParams, err := h.pat.MatchAndEscape(components, verb, s.unescapingMode)
    289 		if err != nil {
    290 			var mse MalformedSequenceError
    291 			if ok := errors.As(err, &mse); ok {
    292 				_, outboundMarshaler := MarshalerForRequest(s, r)
    293 				s.errorHandler(ctx, s, outboundMarshaler, w, r, &HTTPStatusError{
    294 					HTTPStatus: http.StatusBadRequest,
    295 					Err:        mse,
    296 				})
    297 			}
    298 			continue
    299 		}
    300 		h.h(w, r, pathParams)
    301 		return
    302 	}
    303 
    304 	// lookup other methods to handle fallback from GET to POST and
    305 	// to determine if it is NotImplemented or NotFound.
    306 	for m, handlers := range s.handlers {
    307 		if m == r.Method {
    308 			continue
    309 		}
    310 		for _, h := range handlers {
    311 			pathParams, err := h.pat.MatchAndEscape(components, verb, s.unescapingMode)
    312 			if err != nil {
    313 				var mse MalformedSequenceError
    314 				if ok := errors.As(err, &mse); ok {
    315 					_, outboundMarshaler := MarshalerForRequest(s, r)
    316 					s.errorHandler(ctx, s, outboundMarshaler, w, r, &HTTPStatusError{
    317 						HTTPStatus: http.StatusBadRequest,
    318 						Err:        mse,
    319 					})
    320 				}
    321 				continue
    322 			}
    323 			// X-HTTP-Method-Override is optional. Always allow fallback to POST.
    324 			if s.isPathLengthFallback(r) {
    325 				if err := r.ParseForm(); err != nil {
    326 					_, outboundMarshaler := MarshalerForRequest(s, r)
    327 					sterr := status.Error(codes.InvalidArgument, err.Error())
    328 					s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr)
    329 					return
    330 				}
    331 				h.h(w, r, pathParams)
    332 				return
    333 			}
    334 			_, outboundMarshaler := MarshalerForRequest(s, r)
    335 			s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusMethodNotAllowed)
    336 			return
    337 		}
    338 	}
    339 
    340 	_, outboundMarshaler := MarshalerForRequest(s, r)
    341 	s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusNotFound)
    342 }
    343 
    344 // GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux.
    345 func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error {
    346 	return s.forwardResponseOptions
    347 }
    348 
    349 func (s *ServeMux) isPathLengthFallback(r *http.Request) bool {
    350 	return !s.disablePathLengthFallback && r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded"
    351 }
    352 
// handler pairs a path pattern with the HandlerFunc that serves requests whose
// path matches it. ServeMux keeps a slice of these per HTTP method.
type handler struct {
	// pat is matched against the request path components in ServeHTTP.
	pat Pattern
	// h is invoked with the path parameters produced by a successful match.
	h   HandlerFunc
}