diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 11be5599cd4..ffff382e8f1 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -304,37 +304,92 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err // operateHeader takes action on the decoded headers. func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (fatal bool) { streamID := frame.Header().StreamID - state := &decodeState{ - serverSide: true, - } - if h2code, err := state.decodeHeader(frame); err != nil { - if _, ok := status.FromError(err); ok { - t.controlBuf.put(&cleanupStream{ - streamID: streamID, - rst: true, - rstCode: h2code, - onWrite: func() {}, - }) - } + + // frame.Truncated is set to true when framer detects that the current header + // list size hits MaxHeaderListSize limit. + if frame.Truncated { + t.controlBuf.put(&cleanupStream{ + streamID: streamID, + rst: true, + rstCode: http2.ErrCodeFrameSize, + onWrite: func() {}, + }) return false } buf := newRecvBuffer() s := &Stream{ - id: streamID, - st: t, - buf: buf, - fc: &inFlow{limit: uint32(t.initialWindowSize)}, - recvCompress: state.data.encoding, - method: state.data.method, - contentSubtype: state.data.contentSubtype, + id: streamID, + st: t, + buf: buf, + fc: &inFlow{limit: uint32(t.initialWindowSize)}, + } + + var ( + // If a gRPC Response-Headers has already been received, then it means + // that the peer is speaking gRPC and we are in gRPC mode. + isGRPC = false + mdata = make(map[string][]string) + httpMethod string + // headerError is set if an error is encountered while parsing the headers + headerError bool + + timeoutSet bool + timeout time.Duration + ) + + for _, hf := range frame.Fields { + switch hf.Name { + case "content-type": + contentSubtype, validContentType := grpcutil.ContentSubtype(hf.Value) + if !validContentType { + break + } + mdata[hf.Name] = append(mdata[hf.Name], hf.Value) + s.contentSubtype = contentSubtype + isGRPC = true + case "grpc-encoding": + s.recvCompress = hf.Value + case ":method": + httpMethod = hf.Value + case ":path": + s.method = hf.Value + case "grpc-timeout": + timeoutSet = true + var err error + if timeout, err = decodeTimeout(hf.Value); err != nil { + headerError = true + } + default: + if isReservedHeader(hf.Name) && !isWhitelistedHeader(hf.Name) { + break + } + v, err := decodeMetadataHeader(hf.Name, hf.Value) + if err != nil { + headerError = true + logger.Warningf("Failed to decode metadata header (%q, %q): %v", hf.Name, hf.Value, err) + break + } + mdata[hf.Name] = append(mdata[hf.Name], v) + } } + + if !isGRPC || headerError { + t.controlBuf.put(&cleanupStream{ + streamID: streamID, + rst: true, + rstCode: http2.ErrCodeProtocol, + onWrite: func() {}, + }) + return false + } + if frame.StreamEnded() { // s is just created by the caller. No lock needed. s.state = streamReadDone } - if state.data.timeoutSet { - s.ctx, s.cancel = context.WithTimeout(t.ctx, state.data.timeout) + if timeoutSet { + s.ctx, s.cancel = context.WithTimeout(t.ctx, timeout) } else { s.ctx, s.cancel = context.WithCancel(t.ctx) } @@ -347,14 +402,14 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( } s.ctx = peer.NewContext(s.ctx, pr) // Attach the received metadata to the context. - if len(state.data.mdata) > 0 { - s.ctx = metadata.NewIncomingContext(s.ctx, state.data.mdata) - } - if state.data.statsTags != nil { - s.ctx = stats.SetIncomingTags(s.ctx, state.data.statsTags) - } - if state.data.statsTrace != nil { - s.ctx = stats.SetIncomingTrace(s.ctx, state.data.statsTrace) + if len(mdata) > 0 { + s.ctx = metadata.NewIncomingContext(s.ctx, mdata) + if statsTags := mdata["grpc-tags-bin"]; len(statsTags) > 0 { + s.ctx = stats.SetIncomingTags(s.ctx, []byte(statsTags[len(statsTags)-1])) + } + if statsTrace := mdata["grpc-trace-bin"]; len(statsTrace) > 0 { + s.ctx = stats.SetIncomingTrace(s.ctx, []byte(statsTrace[len(statsTrace)-1])) + } } t.mu.Lock() if t.state != reachable { @@ -383,10 +438,10 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( return true } t.maxStreamID = streamID - if state.data.httpMethod != http.MethodPost { + if httpMethod != http.MethodPost { t.mu.Unlock() if logger.V(logLevel) { - logger.Warningf("transport: http2Server.operateHeaders parsed a :method field: %v which should be POST", state.data.httpMethod) + logger.Infof("transport: http2Server.operateHeaders parsed a :method field: %v which should be POST", httpMethod) } t.controlBuf.put(&cleanupStream{ streamID: streamID, @@ -399,7 +454,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( } if t.inTapHandle != nil { var err error - if s.ctx, err = t.inTapHandle(s.ctx, &tap.Info{FullMethodName: state.data.method}); err != nil { + if s.ctx, err = t.inTapHandle(s.ctx, &tap.Info{FullMethodName: s.method}); err != nil { t.mu.Unlock() if logger.V(logLevel) { logger.Infof("transport: http2Server.operateHeaders got an error from InTapHandle: %v", err) @@ -437,7 +492,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( LocalAddr: t.localAddr, Compression: s.recvCompress, WireLength: int(frame.Header().Length), - Header: metadata.MD(state.data.mdata).Copy(), + Header: metadata.MD(mdata).Copy(), } t.stats.HandleRPC(s.ctx, inHeader) } diff --git a/internal/transport/http_util.go b/internal/transport/http_util.go index 2771e224f7b..15d775fca3c 100644 --- a/internal/transport/http_util.go +++ b/internal/transport/http_util.go @@ -39,7 +39,6 @@ import ( spb "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc/codes" "google.golang.org/grpc/grpclog" - "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/status" ) @@ -96,53 +95,6 @@ var ( logger = grpclog.Component("transport") ) -type parsedHeaderData struct { - encoding string - // statusGen caches the stream status received from the trailer the server - // sent. Client side only. Do not access directly. After all trailers are - // parsed, use the status method to retrieve the status. - statusGen *status.Status - // rawStatusCode and rawStatusMsg are set from the raw trailer fields and are not - // intended for direct access outside of parsing. - rawStatusCode *int - rawStatusMsg string - httpStatus *int - // Server side only fields. - timeoutSet bool - timeout time.Duration - method string - httpMethod string - // key-value metadata map from the peer. - mdata map[string][]string - statsTags []byte - statsTrace []byte - contentSubtype string - - // isGRPC field indicates whether the peer is speaking gRPC (otherwise HTTP). - // - // We are in gRPC mode (peer speaking gRPC) if: - // * We are client side and have already received a HEADER frame that indicates gRPC peer. - // * The header contains valid a content-type, i.e. a string starts with "application/grpc" - // And we should handle error specific to gRPC. - // - // Otherwise (i.e. a content-type string starts without "application/grpc", or does not exist), we - // are in HTTP fallback mode, and should handle error specific to HTTP. - isGRPC bool - grpcErr error - httpErr error - contentTypeErr string -} - -// decodeState configures decoding criteria and records the decoded data. -type decodeState struct { - // whether decoding on server side or not - serverSide bool - - // Records the states during HPACK decoding. It will be filled with info parsed from HTTP HEADERS - // frame once decodeHeader function has been invoked and returned. - data parsedHeaderData -} - // isReservedHeader checks whether hdr belongs to HTTP2 headers // reserved by gRPC protocol. Any other headers are classified as the // user-specified metadata. @@ -221,63 +173,6 @@ func decodeGRPCStatusDetails(rawDetails string) (*status.Status, error) { return status.FromProto(st), nil } -func (d *decodeState) decodeHeader(frame *http2.MetaHeadersFrame) (http2.ErrCode, error) { - // frame.Truncated is set to true when framer detects that the current header - // list size hits MaxHeaderListSize limit. - if frame.Truncated { - return http2.ErrCodeFrameSize, status.Error(codes.Internal, "peer header list size exceeded limit") - } - - for _, hf := range frame.Fields { - d.processHeaderField(hf) - } - - if d.data.isGRPC { - if d.data.grpcErr != nil { - return http2.ErrCodeProtocol, d.data.grpcErr - } - if d.serverSide { - return http2.ErrCodeNo, nil - } - if d.data.rawStatusCode == nil && d.data.statusGen == nil { - // gRPC status doesn't exist. - // Set rawStatusCode to be unknown and return nil error. - // So that, if the stream has ended this Unknown status - // will be propagated to the user. - // Otherwise, it will be ignored. In which case, status from - // a later trailer, that has StreamEnded flag set, is propagated. - code := int(codes.Unknown) - d.data.rawStatusCode = &code - } - return http2.ErrCodeNo, nil - } - - // HTTP fallback mode - if d.data.httpErr != nil { - return http2.ErrCodeProtocol, d.data.httpErr - } - - var ( - code = codes.Internal // when header does not include HTTP status, return INTERNAL - ok bool - ) - - if d.data.httpStatus != nil { - code, ok = HTTPStatusConvTab[*(d.data.httpStatus)] - if !ok { - code = codes.Unknown - } - } - - return http2.ErrCodeProtocol, status.Error(code, d.constructHTTPErrMsg()) -} - -// constructErrMsg constructs error message to be returned in HTTP fallback mode. -// Format: HTTP status code and its corresponding message + content-type error message. -func (d *decodeState) constructHTTPErrMsg() string { - return constructHTTPErrMsg(d.data.httpStatus, d.data.contentTypeErr) -} - // constructErrMsg constructs error message to be returned in HTTP fallback mode. // Format: HTTP status code and its corresponding message + content-type error message. func constructHTTPErrMsg(httpStatus *int, contentTypeErr string) string { @@ -298,99 +193,6 @@ func constructHTTPErrMsg(httpStatus *int, contentTypeErr string) string { return strings.Join(errMsgs, "; ") } -func (d *decodeState) addMetadata(k, v string) { - if d.data.mdata == nil { - d.data.mdata = make(map[string][]string) - } - d.data.mdata[k] = append(d.data.mdata[k], v) -} - -func (d *decodeState) processHeaderField(f hpack.HeaderField) { - switch f.Name { - case "content-type": - contentSubtype, validContentType := grpcutil.ContentSubtype(f.Value) - if !validContentType { - d.data.contentTypeErr = fmt.Sprintf("transport: received the unexpected content-type %q", f.Value) - return - } - d.data.contentSubtype = contentSubtype - // TODO: do we want to propagate the whole content-type in the metadata, - // or come up with a way to just propagate the content-subtype if it was set? - // ie {"content-type": "application/grpc+proto"} or {"content-subtype": "proto"} - // in the metadata? - d.addMetadata(f.Name, f.Value) - d.data.isGRPC = true - case "grpc-encoding": - d.data.encoding = f.Value - case "grpc-status": - code, err := strconv.Atoi(f.Value) - if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status: %v", err) - return - } - d.data.rawStatusCode = &code - case "grpc-message": - d.data.rawStatusMsg = decodeGrpcMessage(f.Value) - case "grpc-status-details-bin": - v, err := decodeBinHeader(f.Value) - if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) - return - } - s := &spb.Status{} - if err := proto.Unmarshal(v, s); err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-status-details-bin: %v", err) - return - } - d.data.statusGen = status.FromProto(s) - case "grpc-timeout": - d.data.timeoutSet = true - var err error - if d.data.timeout, err = decodeTimeout(f.Value); err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed time-out: %v", err) - } - case ":path": - d.data.method = f.Value - case ":status": - code, err := strconv.Atoi(f.Value) - if err != nil { - d.data.httpErr = status.Errorf(codes.Internal, "transport: malformed http-status: %v", err) - return - } - d.data.httpStatus = &code - case "grpc-tags-bin": - v, err := decodeBinHeader(f.Value) - if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-tags-bin: %v", err) - return - } - d.data.statsTags = v - d.addMetadata(f.Name, string(v)) - case "grpc-trace-bin": - v, err := decodeBinHeader(f.Value) - if err != nil { - d.data.grpcErr = status.Errorf(codes.Internal, "transport: malformed grpc-trace-bin: %v", err) - return - } - d.data.statsTrace = v - d.addMetadata(f.Name, string(v)) - case ":method": - d.data.httpMethod = f.Value - default: - if isReservedHeader(f.Name) && !isWhitelistedHeader(f.Name) { - break - } - v, err := decodeMetadataHeader(f.Name, f.Value) - if err != nil { - if logger.V(logLevel) { - logger.Errorf("Failed to decode metadata header (%q, %q): %v", f.Name, f.Value, err) - } - return - } - d.addMetadata(f.Name, v) - } -} - type timeoutUnit uint8 const ( diff --git a/internal/transport/http_util_test.go b/internal/transport/http_util_test.go index 2205050acea..bbd53180471 100644 --- a/internal/transport/http_util_test.go +++ b/internal/transport/http_util_test.go @@ -23,9 +23,6 @@ import ( "reflect" "testing" "time" - - "golang.org/x/net/http2" - "golang.org/x/net/http2/hpack" ) func (s) TestTimeoutDecode(t *testing.T) { @@ -189,68 +186,6 @@ func (s) TestDecodeMetadataHeader(t *testing.T) { } } -func (s) TestDecodeHeaderH2ErrCode(t *testing.T) { - for _, test := range []struct { - name string - // input - metaHeaderFrame *http2.MetaHeadersFrame - serverSide bool - // output - wantCode http2.ErrCode - }{ - { - name: "valid header", - metaHeaderFrame: &http2.MetaHeadersFrame{Fields: []hpack.HeaderField{ - {Name: "content-type", Value: "application/grpc"}, - }}, - wantCode: http2.ErrCodeNo, - }, - { - name: "valid header serverSide", - metaHeaderFrame: &http2.MetaHeadersFrame{Fields: []hpack.HeaderField{ - {Name: "content-type", Value: "application/grpc"}, - }}, - serverSide: true, - wantCode: http2.ErrCodeNo, - }, - { - name: "invalid grpc status header field", - metaHeaderFrame: &http2.MetaHeadersFrame{Fields: []hpack.HeaderField{ - {Name: "content-type", Value: "application/grpc"}, - {Name: "grpc-status", Value: "xxxx"}, - }}, - wantCode: http2.ErrCodeProtocol, - }, - { - name: "invalid http content type", - metaHeaderFrame: &http2.MetaHeadersFrame{Fields: []hpack.HeaderField{ - {Name: "content-type", Value: "application/json"}, - }}, - wantCode: http2.ErrCodeProtocol, - }, - { - name: "http fallback and invalid http status", - metaHeaderFrame: &http2.MetaHeadersFrame{Fields: []hpack.HeaderField{ - // No content type provided then fallback into handling http error. - {Name: ":status", Value: "xxxx"}, - }}, - wantCode: http2.ErrCodeProtocol, - }, - { - name: "http2 frame size exceeds", - metaHeaderFrame: &http2.MetaHeadersFrame{Fields: nil, Truncated: true}, - wantCode: http2.ErrCodeFrameSize, - }, - } { - t.Run(test.name, func(t *testing.T) { - state := &decodeState{serverSide: test.serverSide} - if h2code, _ := state.decodeHeader(test.metaHeaderFrame); h2code != test.wantCode { - t.Fatalf("decodeState.decodeHeader(%v) = %v, want %v", test.metaHeaderFrame, h2code, test.wantCode) - } - }) - } -} - func (s) TestParseDialTarget(t *testing.T) { for _, test := range []struct { target, wantNet, wantAddr string diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index c3830a8fd0b..3aef8627714 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -1976,3 +1976,110 @@ func (s) TestClientHandshakeInfo(t *testing.T) { t.Fatalf("received attributes %v in creds, want %v", gotAttr, wantAttr) } } + +func (s) TestClientDecodeHeaderStatusErr(t *testing.T) { + for _, test := range []struct { + name string + // input + metaHeaderFrame *http2.MetaHeadersFrame + // output + wantStatus *status.Status + }{ + { + name: "valid header", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: []hpack.HeaderField{ + {Name: "content-type", Value: "application/grpc"}, + {Name: "grpc-status", Value: "0"}, + }, + }, + // no error + wantStatus: status.New(codes.OK, ""), + }, + { + name: "invalid grpc status header field", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: []hpack.HeaderField{ + {Name: "content-type", Value: "application/grpc"}, + {Name: "grpc-status", Value: "xxxx"}, + }, + }, + wantStatus: status.New( + codes.Internal, + "transport: malformed grpc-status: strconv.ParseInt: parsing \"xxxx\": invalid syntax", + ), + }, + { + name: "invalid http content type", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: []hpack.HeaderField{ + {Name: "content-type", Value: "application/json"}, + }, + }, + wantStatus: status.New( + codes.Internal, + ": HTTP status code 0; transport: received the unexpected content-type \"application/json\"", + ), + }, + { + name: "http fallback and invalid http status", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: []hpack.HeaderField{ + // No content type provided then fallback into handling http error. + {Name: ":status", Value: "xxxx"}, + }, + }, + wantStatus: status.New( + codes.Internal, + "transport: malformed http-status: strconv.ParseInt: parsing \"xxxx\": invalid syntax", + ), + }, + { + name: "http2 frame size exceeds", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: nil, + Truncated: true, + }, + wantStatus: status.New( + codes.Internal, + "peer header list size exceeded limit", + ), + }, + } { + t.Run(test.name, func(t *testing.T) { + ts := &Stream{ + done: make(chan struct{}), + headerChan: make(chan struct{}), + buf: &recvBuffer{ + c: make(chan recvMsg), + mu: sync.Mutex{}, + }, + } + s := &http2Client{ + mu: sync.Mutex{}, + activeStreams: map[uint32]*Stream{ + 0: ts, + }, + controlBuf: &controlBuffer{ + ch: make(chan struct{}), + done: make(chan struct{}), + list: &itemList{}, + }, + } + test.metaHeaderFrame.HeadersFrame = &http2.HeadersFrame{ + FrameHeader: http2.FrameHeader{ + StreamID: 0, + Flags: http2.FlagHeadersEndStream, + }, + } + + s.operateHeaders(test.metaHeaderFrame) + + got := ts.status + want := test.wantStatus + if got.Code() != want.Code() || got.Message() != want.Message() { + t.Fatalf("operateHeaders(%v); status = %v; want %v", test.metaHeaderFrame, got, want) + } + }) + } +}