mux.go (13397B)
package runtime

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"net/textproto"
	"strings"

	"github.com/grpc-ecosystem/grpc-gateway/v2/internal/httprule"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/proto"
)

// UnescapingMode defines the behavior of ServeMux when unescaping path parameters.
type UnescapingMode int

const (
	// UnescapingModeLegacy is the default V2 behavior: ServeHTTP routes on the
	// already-decoded r.URL.Path instead of r.URL.RawPath.
	UnescapingModeLegacy UnescapingMode = iota

	// UnescapingModeAllExceptReserved unescapes all path parameters except RFC 6570
	// reserved characters.
	UnescapingModeAllExceptReserved

	// UnescapingModeAllExceptSlash unescapes URL path parameters except path
	// separators, which will be left as "%2F".
	UnescapingModeAllExceptSlash

	// UnescapingModeAllCharacters fully decodes URL path parameters.
	UnescapingModeAllCharacters

	// UnescapingModeDefault is the default unescaping mode.
	// TODO(v3): default this to UnescapingModeAllExceptReserved per grpc-httpjson-transcoding's
	// reference implementation
	UnescapingModeDefault = UnescapingModeLegacy
)

// A HandlerFunc handles a specific pair of path pattern and HTTP method.
// pathParams holds the path parameters produced by matching the request path
// against the registered pattern.
type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)

// ServeMux is a request multiplexer for grpc-gateway.
// It matches http requests to patterns and invokes the corresponding handler.
//
// Its fields are populated by NewServeMux and the ServeMuxOption values passed
// to it; handlers is additionally extended by Handle and HandlePath.
type ServeMux struct {
	// handlers maps HTTP method to a list of handlers.
	handlers                  map[string][]handler
	forwardResponseOptions    []func(context.Context, http.ResponseWriter, proto.Message) error
	marshalers                marshalerRegistry
	incomingHeaderMatcher     HeaderMatcherFunc
	outgoingHeaderMatcher     HeaderMatcherFunc
	metadataAnnotators        []func(context.Context, *http.Request) metadata.MD
	errorHandler              ErrorHandlerFunc
	streamErrorHandler        StreamErrorHandlerFunc
	routingErrorHandler       RoutingErrorHandlerFunc
	disablePathLengthFallback bool
	unescapingMode            UnescapingMode
}

// ServeMuxOption is an option that can be given to a ServeMux on construction.
type ServeMuxOption func(*ServeMux)

// WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption.
//
// forwardResponseOption is an option that will be called on the relevant context.Context,
// http.ResponseWriter, and proto.Message before every forwarded response.
//
// The message may be nil in the case where just a header is being sent.
//
// Options accumulate: each call appends to the mux's existing list.
func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption {
	return func(serveMux *ServeMux) {
		serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption)
	}
}

// WithUnescapingMode sets the unescaping mode. See the definitions of UnescapingMode
// for more information.
func WithUnescapingMode(mode UnescapingMode) ServeMuxOption {
	return func(serveMux *ServeMux) {
		serveMux.unescapingMode = mode
	}
}

// SetQueryParameterParser sets the query parameter parser, used to populate message from query parameters.
// Configuring this will mean the generated OpenAPI output is no longer correct, and it should be
// done with careful consideration.
89 func SetQueryParameterParser(queryParameterParser QueryParameterParser) ServeMuxOption { 90 return func(serveMux *ServeMux) { 91 currentQueryParser = queryParameterParser 92 } 93 } 94 95 // HeaderMatcherFunc checks whether a header key should be forwarded to/from gRPC context. 96 type HeaderMatcherFunc func(string) (string, bool) 97 98 // DefaultHeaderMatcher is used to pass http request headers to/from gRPC context. This adds permanent HTTP header 99 // keys (as specified by the IANA) to gRPC context with grpcgateway- prefix. HTTP headers that start with 100 // 'Grpc-Metadata-' are mapped to gRPC metadata after removing prefix 'Grpc-Metadata-'. 101 func DefaultHeaderMatcher(key string) (string, bool) { 102 key = textproto.CanonicalMIMEHeaderKey(key) 103 if isPermanentHTTPHeader(key) { 104 return MetadataPrefix + key, true 105 } else if strings.HasPrefix(key, MetadataHeaderPrefix) { 106 return key[len(MetadataHeaderPrefix):], true 107 } 108 return "", false 109 } 110 111 // WithIncomingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for incoming request to gateway. 112 // 113 // This matcher will be called with each header in http.Request. If matcher returns true, that header will be 114 // passed to gRPC context. To transform the header before passing to gRPC context, matcher should return modified header. 115 func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption { 116 return func(mux *ServeMux) { 117 mux.incomingHeaderMatcher = fn 118 } 119 } 120 121 // WithOutgoingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway. 122 // 123 // This matcher will be called with each header in response header metadata. If matcher returns true, that header will be 124 // passed to http response returned from gateway. To transform the header before passing to response, 125 // matcher should return modified header. 
126 func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption { 127 return func(mux *ServeMux) { 128 mux.outgoingHeaderMatcher = fn 129 } 130 } 131 132 // WithMetadata returns a ServeMuxOption for passing metadata to a gRPC context. 133 // 134 // This can be used by services that need to read from http.Request and modify gRPC context. A common use case 135 // is reading token from cookie and adding it in gRPC context. 136 func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) ServeMuxOption { 137 return func(serveMux *ServeMux) { 138 serveMux.metadataAnnotators = append(serveMux.metadataAnnotators, annotator) 139 } 140 } 141 142 // WithErrorHandler returns a ServeMuxOption for configuring a custom error handler. 143 // 144 // This can be used to configure a custom error response. 145 func WithErrorHandler(fn ErrorHandlerFunc) ServeMuxOption { 146 return func(serveMux *ServeMux) { 147 serveMux.errorHandler = fn 148 } 149 } 150 151 // WithStreamErrorHandler returns a ServeMuxOption that will use the given custom stream 152 // error handler, which allows for customizing the error trailer for server-streaming 153 // calls. 154 // 155 // For stream errors that occur before any response has been written, the mux's 156 // ErrorHandler will be invoked. However, once data has been written, the errors must 157 // be handled differently: they must be included in the response body. The response body's 158 // final message will include the error details returned by the stream error handler. 159 func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption { 160 return func(serveMux *ServeMux) { 161 serveMux.streamErrorHandler = fn 162 } 163 } 164 165 // WithRoutingErrorHandler returns a ServeMuxOption for configuring a custom error handler to handle http routing errors. 166 // 167 // Method called for errors which can happen before gRPC route selected or executed. 
168 // The following error codes: StatusMethodNotAllowed StatusNotFound StatusBadRequest 169 func WithRoutingErrorHandler(fn RoutingErrorHandlerFunc) ServeMuxOption { 170 return func(serveMux *ServeMux) { 171 serveMux.routingErrorHandler = fn 172 } 173 } 174 175 // WithDisablePathLengthFallback returns a ServeMuxOption for disable path length fallback. 176 func WithDisablePathLengthFallback() ServeMuxOption { 177 return func(serveMux *ServeMux) { 178 serveMux.disablePathLengthFallback = true 179 } 180 } 181 182 // NewServeMux returns a new ServeMux whose internal mapping is empty. 183 func NewServeMux(opts ...ServeMuxOption) *ServeMux { 184 serveMux := &ServeMux{ 185 handlers: make(map[string][]handler), 186 forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0), 187 marshalers: makeMarshalerMIMERegistry(), 188 errorHandler: DefaultHTTPErrorHandler, 189 streamErrorHandler: DefaultStreamErrorHandler, 190 routingErrorHandler: DefaultRoutingErrorHandler, 191 unescapingMode: UnescapingModeDefault, 192 } 193 194 for _, opt := range opts { 195 opt(serveMux) 196 } 197 198 if serveMux.incomingHeaderMatcher == nil { 199 serveMux.incomingHeaderMatcher = DefaultHeaderMatcher 200 } 201 202 if serveMux.outgoingHeaderMatcher == nil { 203 serveMux.outgoingHeaderMatcher = func(key string) (string, bool) { 204 return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true 205 } 206 } 207 208 return serveMux 209 } 210 211 // Handle associates "h" to the pair of HTTP method and path pattern. 212 func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) { 213 s.handlers[meth] = append([]handler{{pat: pat, h: h}}, s.handlers[meth]...) 214 } 215 216 // HandlePath allows users to configure custom path handlers. 
217 // refer: https://grpc-ecosystem.github.io/grpc-gateway/docs/operations/inject_router/ 218 func (s *ServeMux) HandlePath(meth string, pathPattern string, h HandlerFunc) error { 219 compiler, err := httprule.Parse(pathPattern) 220 if err != nil { 221 return fmt.Errorf("parsing path pattern: %w", err) 222 } 223 tp := compiler.Compile() 224 pattern, err := NewPattern(tp.Version, tp.OpCodes, tp.Pool, tp.Verb) 225 if err != nil { 226 return fmt.Errorf("creating new pattern: %w", err) 227 } 228 s.Handle(meth, pattern, h) 229 return nil 230 } 231 232 // ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.Path. 233 func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { 234 ctx := r.Context() 235 236 path := r.URL.Path 237 if !strings.HasPrefix(path, "/") { 238 _, outboundMarshaler := MarshalerForRequest(s, r) 239 s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusBadRequest) 240 return 241 } 242 243 // TODO(v3): remove UnescapingModeLegacy 244 if s.unescapingMode != UnescapingModeLegacy && r.URL.RawPath != "" { 245 path = r.URL.RawPath 246 } 247 248 components := strings.Split(path[1:], "/") 249 250 if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) { 251 r.Method = strings.ToUpper(override) 252 if err := r.ParseForm(); err != nil { 253 _, outboundMarshaler := MarshalerForRequest(s, r) 254 sterr := status.Error(codes.InvalidArgument, err.Error()) 255 s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr) 256 return 257 } 258 } 259 260 // Verb out here is to memoize for the fallback case below 261 var verb string 262 263 for _, h := range s.handlers[r.Method] { 264 // If the pattern has a verb, explicitly look for a suffix in the last 265 // component that matches a colon plus the verb. 
This allows us to 266 // handle some cases that otherwise can't be correctly handled by the 267 // former LastIndex case, such as when the verb literal itself contains 268 // a colon. This should work for all cases that have run through the 269 // parser because we know what verb we're looking for, however, there 270 // are still some cases that the parser itself cannot disambiguate. See 271 // the comment there if interested. 272 patVerb := h.pat.Verb() 273 l := len(components) 274 lastComponent := components[l-1] 275 var idx int = -1 276 if patVerb != "" && strings.HasSuffix(lastComponent, ":"+patVerb) { 277 idx = len(lastComponent) - len(patVerb) - 1 278 } 279 if idx == 0 { 280 _, outboundMarshaler := MarshalerForRequest(s, r) 281 s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusNotFound) 282 return 283 } 284 if idx > 0 { 285 components[l-1], verb = lastComponent[:idx], lastComponent[idx+1:] 286 } 287 288 pathParams, err := h.pat.MatchAndEscape(components, verb, s.unescapingMode) 289 if err != nil { 290 var mse MalformedSequenceError 291 if ok := errors.As(err, &mse); ok { 292 _, outboundMarshaler := MarshalerForRequest(s, r) 293 s.errorHandler(ctx, s, outboundMarshaler, w, r, &HTTPStatusError{ 294 HTTPStatus: http.StatusBadRequest, 295 Err: mse, 296 }) 297 } 298 continue 299 } 300 h.h(w, r, pathParams) 301 return 302 } 303 304 // lookup other methods to handle fallback from GET to POST and 305 // to determine if it is NotImplemented or NotFound. 
306 for m, handlers := range s.handlers { 307 if m == r.Method { 308 continue 309 } 310 for _, h := range handlers { 311 pathParams, err := h.pat.MatchAndEscape(components, verb, s.unescapingMode) 312 if err != nil { 313 var mse MalformedSequenceError 314 if ok := errors.As(err, &mse); ok { 315 _, outboundMarshaler := MarshalerForRequest(s, r) 316 s.errorHandler(ctx, s, outboundMarshaler, w, r, &HTTPStatusError{ 317 HTTPStatus: http.StatusBadRequest, 318 Err: mse, 319 }) 320 } 321 continue 322 } 323 // X-HTTP-Method-Override is optional. Always allow fallback to POST. 324 if s.isPathLengthFallback(r) { 325 if err := r.ParseForm(); err != nil { 326 _, outboundMarshaler := MarshalerForRequest(s, r) 327 sterr := status.Error(codes.InvalidArgument, err.Error()) 328 s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr) 329 return 330 } 331 h.h(w, r, pathParams) 332 return 333 } 334 _, outboundMarshaler := MarshalerForRequest(s, r) 335 s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusMethodNotAllowed) 336 return 337 } 338 } 339 340 _, outboundMarshaler := MarshalerForRequest(s, r) 341 s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusNotFound) 342 } 343 344 // GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux. 345 func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error { 346 return s.forwardResponseOptions 347 } 348 349 func (s *ServeMux) isPathLengthFallback(r *http.Request) bool { 350 return !s.disablePathLengthFallback && r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded" 351 } 352 353 type handler struct { 354 pat Pattern 355 h HandlerFunc 356 }