handler.go (6612B)
1 package runtime 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 "net/http" 8 "net/textproto" 9 "strings" 10 11 "google.golang.org/genproto/googleapis/api/httpbody" 12 "google.golang.org/grpc/codes" 13 "google.golang.org/grpc/grpclog" 14 "google.golang.org/grpc/status" 15 "google.golang.org/protobuf/proto" 16 ) 17 18 // ForwardResponseStream forwards the stream from gRPC server to REST client. 19 func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) { 20 f, ok := w.(http.Flusher) 21 if !ok { 22 grpclog.Infof("Flush not supported in %T", w) 23 http.Error(w, "unexpected type of web server", http.StatusInternalServerError) 24 return 25 } 26 27 md, ok := ServerMetadataFromContext(ctx) 28 if !ok { 29 grpclog.Infof("Failed to extract ServerMetadata from context") 30 http.Error(w, "unexpected error", http.StatusInternalServerError) 31 return 32 } 33 handleForwardResponseServerMetadata(w, mux, md) 34 35 w.Header().Set("Transfer-Encoding", "chunked") 36 if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil { 37 HTTPError(ctx, mux, marshaler, w, req, err) 38 return 39 } 40 41 var delimiter []byte 42 if d, ok := marshaler.(Delimited); ok { 43 delimiter = d.Delimiter() 44 } else { 45 delimiter = []byte("\n") 46 } 47 48 var wroteHeader bool 49 for { 50 resp, err := recv() 51 if err == io.EOF { 52 return 53 } 54 if err != nil { 55 handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err) 56 return 57 } 58 if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil { 59 handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err) 60 return 61 } 62 63 if !wroteHeader { 64 w.Header().Set("Content-Type", marshaler.ContentType(resp)) 65 } 66 67 var buf []byte 68 httpBody, isHTTPBody := resp.(*httpbody.HttpBody) 69 switch { 70 case resp == nil: 71 buf, err = marshaler.Marshal(errorChunk(status.New(codes.Internal, "empty response"))) 72 case isHTTPBody: 73 buf = httpBody.GetData() 74 default: 75 result := map[string]interface{}{"result": resp} 76 if rb, ok := resp.(responseBody); ok { 77 result["result"] = rb.XXX_ResponseBody() 78 } 79 80 buf, err = marshaler.Marshal(result) 81 } 82 83 if err != nil { 84 grpclog.Infof("Failed to marshal response chunk: %v", err) 85 handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err) 86 return 87 } 88 if _, err = w.Write(buf); err != nil { 89 grpclog.Infof("Failed to send response chunk: %v", err) 90 return 91 } 92 wroteHeader = true 93 if _, err = w.Write(delimiter); err != nil { 94 grpclog.Infof("Failed to send delimiter chunk: %v", err) 95 return 96 } 97 f.Flush() 98 } 99 } 100 101 func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) { 102 for k, vs := range md.HeaderMD { 103 if h, ok := mux.outgoingHeaderMatcher(k); ok { 104 for _, v := range vs { 105 w.Header().Add(h, v) 106 } 107 } 108 } 109 } 110 111 func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) { 112 for k := range md.TrailerMD { 113 tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", MetadataTrailerPrefix, k)) 114 w.Header().Add("Trailer", tKey) 115 } 116 } 117 118 func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) { 119 for k, vs := range md.TrailerMD { 120 tKey := fmt.Sprintf("%s%s", MetadataTrailerPrefix, k) 121 for _, v := range vs { 122 w.Header().Add(tKey, v) 123 } 124 } 125 } 126 127 // responseBody interface contains method for getting field for marshaling to the response body 128 // this method is generated for response struct from the value of `response_body` in the `google.api.HttpRule` 129 type responseBody interface { 130 XXX_ResponseBody() interface{} 131 } 132 133 // ForwardResponseMessage forwards the message "resp" from gRPC server to REST client. 134 func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) { 135 md, ok := ServerMetadataFromContext(ctx) 136 if !ok { 137 grpclog.Infof("Failed to extract ServerMetadata from context") 138 } 139 140 handleForwardResponseServerMetadata(w, mux, md) 141 142 // RFC 7230 https://tools.ietf.org/html/rfc7230#section-4.1.2 143 // Unless the request includes a TE header field indicating "trailers" 144 // is acceptable, as described in Section 4.3, a server SHOULD NOT 145 // generate trailer fields that it believes are necessary for the user 146 // agent to receive. 147 doForwardTrailers := requestAcceptsTrailers(req) 148 149 if doForwardTrailers { 150 handleForwardResponseTrailerHeader(w, md) 151 w.Header().Set("Transfer-Encoding", "chunked") 152 } 153 154 handleForwardResponseTrailerHeader(w, md) 155 156 contentType := marshaler.ContentType(resp) 157 w.Header().Set("Content-Type", contentType) 158 159 if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil { 160 HTTPError(ctx, mux, marshaler, w, req, err) 161 return 162 } 163 var buf []byte 164 var err error 165 if rb, ok := resp.(responseBody); ok { 166 buf, err = marshaler.Marshal(rb.XXX_ResponseBody()) 167 } else { 168 buf, err = marshaler.Marshal(resp) 169 } 170 if err != nil { 171 grpclog.Infof("Marshal error: %v", err) 172 HTTPError(ctx, mux, marshaler, w, req, err) 173 return 174 } 175 176 if _, err = w.Write(buf); err != nil { 177 grpclog.Infof("Failed to write response: %v", err) 178 } 179 180 if doForwardTrailers { 181 handleForwardResponseTrailer(w, md) 182 } 183 } 184 185 func requestAcceptsTrailers(req *http.Request) bool { 186 te := req.Header.Get("TE") 187 return strings.Contains(strings.ToLower(te), "trailers") 188 } 189 190 func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error { 191 if len(opts) == 0 { 192 return nil 193 } 194 for _, opt := range opts { 195 if err := opt(ctx, w, resp); err != nil { 196 grpclog.Infof("Error handling ForwardResponseOptions: %v", err) 197 return err 198 } 199 } 200 return nil 201 } 202 203 func handleForwardResponseStreamError(ctx context.Context, wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, req *http.Request, mux *ServeMux, err error) { 204 st := mux.streamErrorHandler(ctx, err) 205 msg := errorChunk(st) 206 if !wroteHeader { 207 w.Header().Set("Content-Type", marshaler.ContentType(msg)) 208 w.WriteHeader(HTTPStatusFromCode(st.Code())) 209 } 210 buf, merr := marshaler.Marshal(msg) 211 if merr != nil { 212 grpclog.Infof("Failed to marshal an error: %v", merr) 213 return 214 } 215 if _, werr := w.Write(buf); werr != nil { 216 grpclog.Infof("Failed to notify error to client: %v", werr) 217 return 218 } 219 } 220 221 func errorChunk(st *status.Status) map[string]proto.Message { 222 return map[string]proto.Message{"error": st.Proto()} 223 }