Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: fix a few issues where grpc server uses RST_STREAM for non-HTTP/2 errors #5893

Merged
merged 6 commits into from Jan 18, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions internal/transport/handler_server.go
Expand Up @@ -65,7 +65,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
contentSubtype, validContentType := grpcutil.ContentSubtype(contentType)
if !validContentType {
msg := fmt.Sprintf("invalid gRPC request content-type %q", contentType)
http.Error(w, msg, http.StatusBadRequest)
http.Error(w, msg, http.StatusUnsupportedMediaType)
return nil, errors.New(msg)
}
if _, ok := w.(http.Flusher); !ok {
Expand All @@ -87,7 +87,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
if v := r.Header.Get("grpc-timeout"); v != "" {
to, err := decodeTimeout(v)
if err != nil {
msg := fmt.Sprintf("malformed time-out: %v", err)
msg := fmt.Sprintf("malformed grpc-timeout: %v", err)
http.Error(w, msg, http.StatusBadRequest)
return nil, status.Error(codes.Internal, msg)
}
Expand Down
37 changes: 26 additions & 11 deletions internal/transport/handler_server_test.go
Expand Up @@ -41,11 +41,12 @@ import (

func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
type testCase struct {
name string
req *http.Request
wantErr string
modrw func(http.ResponseWriter) http.ResponseWriter
check func(*serverHandlerTransport, *testCase) error
name string
req *http.Request
wantErr string
wantErrCode int
modrw func(http.ResponseWriter) http.ResponseWriter
check func(*serverHandlerTransport, *testCase) error
}
tests := []testCase{
{
Expand All @@ -54,7 +55,8 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
ProtoMajor: 1,
ProtoMinor: 1,
},
wantErr: "gRPC requires HTTP/2",
wantErr: "gRPC requires HTTP/2",
wantErrCode: http.StatusBadRequest,
},
{
name: "bad method",
Expand All @@ -63,7 +65,8 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
Method: "GET",
Header: http.Header{},
},
wantErr: `invalid gRPC request method "GET"`,
wantErr: `invalid gRPC request method "GET"`,
wantErrCode: http.StatusBadRequest,
},
{
name: "bad content type",
Expand All @@ -74,7 +77,8 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
"Content-Type": {"application/foo"},
},
},
wantErr: `invalid gRPC request content-type "application/foo"`,
wantErr: `invalid gRPC request content-type "application/foo"`,
wantErrCode: http.StatusUnsupportedMediaType,
},
{
name: "not flusher",
Expand All @@ -93,7 +97,8 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
}
return struct{ onlyCloseNotifier }{w.(onlyCloseNotifier)}
},
wantErr: "gRPC requires a ResponseWriter supporting http.Flusher",
wantErr: "gRPC requires a ResponseWriter supporting http.Flusher",
wantErrCode: http.StatusInternalServerError,
},
{
name: "valid",
Expand Down Expand Up @@ -153,7 +158,8 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
Path: "/service/foo.bar",
},
},
wantErr: `rpc error: code = Internal desc = malformed time-out: transport: timeout unit is not recognized: "tomorrow"`,
wantErr: `rpc error: code = Internal desc = malformed grpc-timeout: transport: timeout unit is not recognized: "tomorrow"`,
wantErrCode: http.StatusBadRequest,
},
{
name: "with metadata",
Expand Down Expand Up @@ -187,7 +193,12 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
}

for _, tt := range tests {
rw := newTestHandlerResponseWriter()
rrec := httptest.NewRecorder()
rw := http.ResponseWriter(testHandlerResponseWriter{
ResponseRecorder: rrec,
closeNotify: make(chan bool, 1),
})

if tt.modrw != nil {
rw = tt.modrw(rw)
}
Expand All @@ -196,6 +207,10 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
t.Errorf("%s: error = %q; want %q", tt.name, gotErr.Error(), tt.wantErr)
continue
}
if tt.wantErr != "" && rrec.Code != tt.wantErrCode {
jhump marked this conversation as resolved.
Show resolved Hide resolved
t.Errorf("%s: code = %d; want %d", tt.name, rrec.Code, tt.wantErrCode)
continue
}
if gotErr != nil {
continue
}
Expand Down
50 changes: 36 additions & 14 deletions internal/transport/http2_server.go
Expand Up @@ -380,13 +380,14 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
fc: &inFlow{limit: uint32(t.initialWindowSize)},
}
var (
// If a gRPC Response-Headers has already been received, then it means
// that the peer is speaking gRPC and we are in gRPC mode.
isGRPC = false
mdata = make(map[string][]string)
httpMethod string
// headerError is set if an error is encountered while parsing the headers
headerError bool
// if false, content-type was missing or invalid
isGRPC = false
contentType = ""
mdata = make(map[string][]string)
httpMethod string
// these are set if an error is encountered while parsing the headers
protocolError bool
headerError *status.Status

timeoutSet bool
timeout time.Duration
Expand All @@ -397,6 +398,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
case "content-type":
contentSubtype, validContentType := grpcutil.ContentSubtype(hf.Value)
if !validContentType {
contentType = hf.Value
break
}
mdata[hf.Name] = append(mdata[hf.Name], hf.Value)
Expand All @@ -412,22 +414,22 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
timeoutSet = true
var err error
if timeout, err = decodeTimeout(hf.Value); err != nil {
headerError = true
headerError = status.Newf(codes.Internal, "malformed grpc-timeout: %v", err)
}
// "Transports must consider requests containing the Connection header
// as malformed." - A41
case "connection":
if logger.V(logLevel) {
logger.Errorf("transport: http2Server.operateHeaders parsed a :connection header which makes a request malformed as per the HTTP/2 spec")
}
headerError = true
protocolError = true
default:
if isReservedHeader(hf.Name) && !isWhitelistedHeader(hf.Name) {
break
}
v, err := decodeMetadataHeader(hf.Name, hf.Value)
if err != nil {
headerError = true
headerError = status.Newf(codes.Internal, "malformed binary metadata %q in header %q: %v", hf.Value, hf.Name, err)
dfawley marked this conversation as resolved.
Show resolved Hide resolved
logger.Warningf("Failed to decode metadata header (%q, %q): %v", hf.Name, hf.Value, err)
break
}
Expand All @@ -445,8 +447,8 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
if logger.V(logLevel) {
logger.Errorf("transport: %v", errMsg)
}
t.controlBuf.put(&earlyAbortStream{
httpStatus: 400,
_ = t.controlBuf.put(&earlyAbortStream{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it required to assign the return value to a blank identifier instead of ignoring it? Here and other places in this PR.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My IDE warns about code that ignores errors, which is what this function returns. This is a common analysis in CI (via tools like errcheck), to prevent mistakes where code forgets to check errors. So I added the _ = to make it explicit that I am ignoring the returned error (which suppresses the yellow indicator in GoLand).

I am happy to undo this if you want.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We generally don't prefer the _ = option. If and where it makes sense, we prefer adding a small comment saying why it is safe to ignore the error. So, if you could do that, that would be great.

httpStatus: http.StatusBadRequest,
streamID: streamID,
contentSubtype: s.contentSubtype,
status: status.New(codes.Internal, errMsg),
Expand All @@ -455,15 +457,35 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
return nil
}

if !isGRPC || headerError {
t.controlBuf.put(&cleanupStream{
if protocolError {
_ = t.controlBuf.put(&cleanupStream{
streamID: streamID,
rst: true,
rstCode: http2.ErrCodeProtocol,
onWrite: func() {},
})
return nil
}
if headerError != nil {
_ = t.controlBuf.put(&earlyAbortStream{
httpStatus: http.StatusBadRequest,
streamID: streamID,
contentSubtype: s.contentSubtype,
status: headerError,
rst: !frame.StreamEnded(),
})
return nil
}
if !isGRPC {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to me like this should come first -- why complain about the syntax of some grpc headers if the other side isn't even speaking grpc?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! Fixed.

_ = t.controlBuf.put(&earlyAbortStream{
httpStatus: http.StatusUnsupportedMediaType,
streamID: streamID,
contentSubtype: s.contentSubtype,
status: status.Newf(codes.InvalidArgument, "invalid gRPC request content-type %q", contentType),
jhump marked this conversation as resolved.
Show resolved Hide resolved
rst: !frame.StreamEnded(),
})
return nil
}
easwars marked this conversation as resolved.
Show resolved Hide resolved

// "If :authority is missing, Host must be renamed to :authority." - A41
if len(mdata[":authority"]) == 0 {
Expand Down