gtsocial-umbx

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README | LICENSE

handler.go (6612B)


      1 package runtime
      2 
      3 import (
      4 	"context"
      5 	"fmt"
      6 	"io"
      7 	"net/http"
      8 	"net/textproto"
      9 	"strings"
     10 
     11 	"google.golang.org/genproto/googleapis/api/httpbody"
     12 	"google.golang.org/grpc/codes"
     13 	"google.golang.org/grpc/grpclog"
     14 	"google.golang.org/grpc/status"
     15 	"google.golang.org/protobuf/proto"
     16 )
     17 
     18 // ForwardResponseStream forwards the stream from gRPC server to REST client.
     19 func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
     20 	f, ok := w.(http.Flusher)
     21 	if !ok {
     22 		grpclog.Infof("Flush not supported in %T", w)
     23 		http.Error(w, "unexpected type of web server", http.StatusInternalServerError)
     24 		return
     25 	}
     26 
     27 	md, ok := ServerMetadataFromContext(ctx)
     28 	if !ok {
     29 		grpclog.Infof("Failed to extract ServerMetadata from context")
     30 		http.Error(w, "unexpected error", http.StatusInternalServerError)
     31 		return
     32 	}
     33 	handleForwardResponseServerMetadata(w, mux, md)
     34 
     35 	w.Header().Set("Transfer-Encoding", "chunked")
     36 	if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
     37 		HTTPError(ctx, mux, marshaler, w, req, err)
     38 		return
     39 	}
     40 
     41 	var delimiter []byte
     42 	if d, ok := marshaler.(Delimited); ok {
     43 		delimiter = d.Delimiter()
     44 	} else {
     45 		delimiter = []byte("\n")
     46 	}
     47 
     48 	var wroteHeader bool
     49 	for {
     50 		resp, err := recv()
     51 		if err == io.EOF {
     52 			return
     53 		}
     54 		if err != nil {
     55 			handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
     56 			return
     57 		}
     58 		if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
     59 			handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
     60 			return
     61 		}
     62 
     63 		if !wroteHeader {
     64 			w.Header().Set("Content-Type", marshaler.ContentType(resp))
     65 		}
     66 
     67 		var buf []byte
     68 		httpBody, isHTTPBody := resp.(*httpbody.HttpBody)
     69 		switch {
     70 		case resp == nil:
     71 			buf, err = marshaler.Marshal(errorChunk(status.New(codes.Internal, "empty response")))
     72 		case isHTTPBody:
     73 			buf = httpBody.GetData()
     74 		default:
     75 			result := map[string]interface{}{"result": resp}
     76 			if rb, ok := resp.(responseBody); ok {
     77 				result["result"] = rb.XXX_ResponseBody()
     78 			}
     79 
     80 			buf, err = marshaler.Marshal(result)
     81 		}
     82 
     83 		if err != nil {
     84 			grpclog.Infof("Failed to marshal response chunk: %v", err)
     85 			handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err)
     86 			return
     87 		}
     88 		if _, err = w.Write(buf); err != nil {
     89 			grpclog.Infof("Failed to send response chunk: %v", err)
     90 			return
     91 		}
     92 		wroteHeader = true
     93 		if _, err = w.Write(delimiter); err != nil {
     94 			grpclog.Infof("Failed to send delimiter chunk: %v", err)
     95 			return
     96 		}
     97 		f.Flush()
     98 	}
     99 }
    100 
    101 func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
    102 	for k, vs := range md.HeaderMD {
    103 		if h, ok := mux.outgoingHeaderMatcher(k); ok {
    104 			for _, v := range vs {
    105 				w.Header().Add(h, v)
    106 			}
    107 		}
    108 	}
    109 }
    110 
    111 func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) {
    112 	for k := range md.TrailerMD {
    113 		tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", MetadataTrailerPrefix, k))
    114 		w.Header().Add("Trailer", tKey)
    115 	}
    116 }
    117 
    118 func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) {
    119 	for k, vs := range md.TrailerMD {
    120 		tKey := fmt.Sprintf("%s%s", MetadataTrailerPrefix, k)
    121 		for _, v := range vs {
    122 			w.Header().Add(tKey, v)
    123 		}
    124 	}
    125 }
    126 
    127 // responseBody interface contains method for getting field for marshaling to the response body
    128 // this method is generated for response struct from the value of `response_body` in the `google.api.HttpRule`
    129 type responseBody interface {
    130 	XXX_ResponseBody() interface{}
    131 }
    132 
    133 // ForwardResponseMessage forwards the message "resp" from gRPC server to REST client.
    134 func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
    135 	md, ok := ServerMetadataFromContext(ctx)
    136 	if !ok {
    137 		grpclog.Infof("Failed to extract ServerMetadata from context")
    138 	}
    139 
    140 	handleForwardResponseServerMetadata(w, mux, md)
    141 
    142 	// RFC 7230 https://tools.ietf.org/html/rfc7230#section-4.1.2
    143 	// Unless the request includes a TE header field indicating "trailers"
    144 	// is acceptable, as described in Section 4.3, a server SHOULD NOT
    145 	// generate trailer fields that it believes are necessary for the user
    146 	// agent to receive.
    147 	doForwardTrailers := requestAcceptsTrailers(req)
    148 
    149 	if doForwardTrailers {
    150 		handleForwardResponseTrailerHeader(w, md)
    151 		w.Header().Set("Transfer-Encoding", "chunked")
    152 	}
    153 
    154 	handleForwardResponseTrailerHeader(w, md)
    155 
    156 	contentType := marshaler.ContentType(resp)
    157 	w.Header().Set("Content-Type", contentType)
    158 
    159 	if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
    160 		HTTPError(ctx, mux, marshaler, w, req, err)
    161 		return
    162 	}
    163 	var buf []byte
    164 	var err error
    165 	if rb, ok := resp.(responseBody); ok {
    166 		buf, err = marshaler.Marshal(rb.XXX_ResponseBody())
    167 	} else {
    168 		buf, err = marshaler.Marshal(resp)
    169 	}
    170 	if err != nil {
    171 		grpclog.Infof("Marshal error: %v", err)
    172 		HTTPError(ctx, mux, marshaler, w, req, err)
    173 		return
    174 	}
    175 
    176 	if _, err = w.Write(buf); err != nil {
    177 		grpclog.Infof("Failed to write response: %v", err)
    178 	}
    179 
    180 	if doForwardTrailers {
    181 		handleForwardResponseTrailer(w, md)
    182 	}
    183 }
    184 
    185 func requestAcceptsTrailers(req *http.Request) bool {
    186 	te := req.Header.Get("TE")
    187 	return strings.Contains(strings.ToLower(te), "trailers")
    188 }
    189 
    190 func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error {
    191 	if len(opts) == 0 {
    192 		return nil
    193 	}
    194 	for _, opt := range opts {
    195 		if err := opt(ctx, w, resp); err != nil {
    196 			grpclog.Infof("Error handling ForwardResponseOptions: %v", err)
    197 			return err
    198 		}
    199 	}
    200 	return nil
    201 }
    202 
    203 func handleForwardResponseStreamError(ctx context.Context, wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, req *http.Request, mux *ServeMux, err error) {
    204 	st := mux.streamErrorHandler(ctx, err)
    205 	msg := errorChunk(st)
    206 	if !wroteHeader {
    207 		w.Header().Set("Content-Type", marshaler.ContentType(msg))
    208 		w.WriteHeader(HTTPStatusFromCode(st.Code()))
    209 	}
    210 	buf, merr := marshaler.Marshal(msg)
    211 	if merr != nil {
    212 		grpclog.Infof("Failed to marshal an error: %v", merr)
    213 		return
    214 	}
    215 	if _, werr := w.Write(buf); werr != nil {
    216 		grpclog.Infof("Failed to notify error to client: %v", werr)
    217 		return
    218 	}
    219 }
    220 
    221 func errorChunk(st *status.Status) map[string]proto.Message {
    222 	return map[string]proto.Message{"error": st.Proto()}
    223 }