Skip to content

Commit

Permalink
transport: validate http 200 status for responses (#4474)
Browse files Browse the repository at this point in the history
  • Loading branch information
JNProtzman committed Jul 14, 2021
1 parent ebfe3be commit ba41bba
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 82 deletions.
14 changes: 9 additions & 5 deletions internal/status/status.go
Expand Up @@ -97,7 +97,7 @@ func (s *Status) Err() error {
if s.Code() == codes.OK {
return nil
}
return &Error{e: s.Proto()}
return &Error{s: s}
}

// WithDetails returns a new status with the provided details messages appended to the status.
Expand Down Expand Up @@ -136,19 +136,23 @@ func (s *Status) Details() []interface{} {
return details
}

func (s *Status) String() string {
return fmt.Sprintf("rpc error: code = %s desc = %s", s.Code(), s.Message())
}

// Error wraps a pointer of a status proto. It implements error and Status,
// and a nil *Error should never be returned by this package.
type Error struct {
e *spb.Status
s *Status
}

func (e *Error) Error() string {
return fmt.Sprintf("rpc error: code = %s desc = %s", codes.Code(e.e.GetCode()), e.e.GetMessage())
return e.s.String()
}

// GRPCStatus returns the Status represented by se.
func (e *Error) GRPCStatus() *Status {
return FromProto(e.e)
return e.s
}

// Is implements future error.Is functionality.
Expand All @@ -158,5 +162,5 @@ func (e *Error) Is(target error) bool {
if !ok {
return false
}
return proto.Equal(e.e, tse.e)
return proto.Equal(e.s.s, tse.s.s)
}
85 changes: 51 additions & 34 deletions internal/transport/http2_client.go
Expand Up @@ -24,6 +24,7 @@ import (
"io"
"math"
"net"
"net/http"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -1280,29 +1281,40 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
// that the peer is speaking gRPC and we are in gRPC mode.
isGRPC = !initialHeader
mdata = make(map[string][]string)
contentTypeErr string
contentTypeErr = "malformed header: missing HTTP content-type"
grpcMessage string
statusGen *status.Status

httpStatus string
rawStatus string
httpStatusCode *int
httpStatusErr string
rawStatusCode = codes.Unknown
// headerError is set if an error is encountered while parsing the headers
headerError string
)

if initialHeader {
httpStatusErr = "malformed header: missing HTTP status"
}

for _, hf := range frame.Fields {
switch hf.Name {
case "content-type":
if _, validContentType := grpcutil.ContentSubtype(hf.Value); !validContentType {
contentTypeErr = fmt.Sprintf("transport: received the unexpected content-type %q", hf.Value)
contentTypeErr = fmt.Sprintf("transport: received unexpected content-type %q", hf.Value)
break
}
contentTypeErr = ""
mdata[hf.Name] = append(mdata[hf.Name], hf.Value)
isGRPC = true
case "grpc-encoding":
s.recvCompress = hf.Value
case "grpc-status":
rawStatus = hf.Value
code, err := strconv.ParseInt(hf.Value, 10, 32)
if err != nil {
se := status.New(codes.Internal, fmt.Sprintf("transport: malformed grpc-status: %v", err))
t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream)
return
}
rawStatusCode = codes.Code(uint32(code))
case "grpc-message":
grpcMessage = decodeGrpcMessage(hf.Value)
case "grpc-status-details-bin":
Expand All @@ -1312,7 +1324,27 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
headerError = fmt.Sprintf("transport: malformed grpc-status-details-bin: %v", err)
}
case ":status":
httpStatus = hf.Value
if hf.Value == "200" {
httpStatusErr = ""
statusCode := 200
httpStatusCode = &statusCode
break
}

c, err := strconv.ParseInt(hf.Value, 10, 32)
if err != nil {
se := status.New(codes.Internal, fmt.Sprintf("transport: malformed http-status: %v", err))
t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream)
return
}
statusCode := int(c)
httpStatusCode = &statusCode

httpStatusErr = fmt.Sprintf(
"unexpected HTTP status code received from server: %d (%s)",
statusCode,
http.StatusText(statusCode),
)
default:
if isReservedHeader(hf.Name) && !isWhitelistedHeader(hf.Name) {
break
Expand All @@ -1327,30 +1359,25 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
}
}

if !isGRPC {
var (
code = codes.Internal // when header does not include HTTP status, return INTERNAL
httpStatusCode int
)

if httpStatus != "" {
c, err := strconv.ParseInt(httpStatus, 10, 32)
if err != nil {
se := status.New(codes.Internal, fmt.Sprintf("transport: malformed http-status: %v", err))
t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream)
return
}
httpStatusCode = int(c)
if !isGRPC || httpStatusErr != "" {
var code = codes.Internal // when header does not include HTTP status, return INTERNAL

if httpStatusCode != nil {
var ok bool
code, ok = HTTPStatusConvTab[httpStatusCode]
code, ok = HTTPStatusConvTab[*httpStatusCode]
if !ok {
code = codes.Unknown
}
}

var errs []string
if httpStatusErr != "" {
errs = append(errs, httpStatusErr)
}
if contentTypeErr != "" {
errs = append(errs, contentTypeErr)
}
// Verify the HTTP response is a 200.
se := status.New(code, constructHTTPErrMsg(&httpStatusCode, contentTypeErr))
se := status.New(code, strings.Join(errs, "; "))
t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream)
return
}
Expand Down Expand Up @@ -1407,16 +1434,6 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
}

if statusGen == nil {
rawStatusCode := codes.Unknown
if rawStatus != "" {
code, err := strconv.ParseInt(rawStatus, 10, 32)
if err != nil {
se := status.New(codes.Internal, fmt.Sprintf("transport: malformed grpc-status: %v", err))
t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream)
return
}
rawStatusCode = codes.Code(uint32(code))
}
statusGen = status.New(rawStatusCode, grpcMessage)
}

Expand Down
20 changes: 0 additions & 20 deletions internal/transport/http_util.go
Expand Up @@ -173,26 +173,6 @@ func decodeGRPCStatusDetails(rawDetails string) (*status.Status, error) {
return status.FromProto(st), nil
}

// constructErrMsg constructs error message to be returned in HTTP fallback mode.
// Format: HTTP status code and its corresponding message + content-type error message.
func constructHTTPErrMsg(httpStatus *int, contentTypeErr string) string {
var errMsgs []string

if httpStatus == nil {
errMsgs = append(errMsgs, "malformed header: missing HTTP status")
} else {
errMsgs = append(errMsgs, fmt.Sprintf("%s: HTTP status code %d", http.StatusText(*(httpStatus)), *httpStatus))
}

if contentTypeErr == "" {
errMsgs = append(errMsgs, "transport: missing content-type field")
} else {
errMsgs = append(errMsgs, contentTypeErr)
}

return strings.Join(errMsgs, "; ")
}

type timeoutUnit uint8

const (
Expand Down
105 changes: 87 additions & 18 deletions internal/transport/transport_test.go
Expand Up @@ -1978,6 +1978,31 @@ func (s) TestClientHandshakeInfo(t *testing.T) {
}

func (s) TestClientDecodeHeaderStatusErr(t *testing.T) {
testStream := func() *Stream {
return &Stream{
done: make(chan struct{}),
headerChan: make(chan struct{}),
buf: &recvBuffer{
c: make(chan recvMsg),
mu: sync.Mutex{},
},
}
}

testClient := func(ts *Stream) *http2Client {
return &http2Client{
mu: sync.Mutex{},
activeStreams: map[uint32]*Stream{
0: ts,
},
controlBuf: &controlBuffer{
ch: make(chan struct{}),
done: make(chan struct{}),
list: &itemList{},
},
}
}

for _, test := range []struct {
name string
// input
Expand All @@ -1991,17 +2016,32 @@ func (s) TestClientDecodeHeaderStatusErr(t *testing.T) {
Fields: []hpack.HeaderField{
{Name: "content-type", Value: "application/grpc"},
{Name: "grpc-status", Value: "0"},
{Name: ":status", Value: "200"},
},
},
// no error
wantStatus: status.New(codes.OK, ""),
},
{
name: "missing content-type header",
metaHeaderFrame: &http2.MetaHeadersFrame{
Fields: []hpack.HeaderField{
{Name: "grpc-status", Value: "0"},
{Name: ":status", Value: "200"},
},
},
wantStatus: status.New(
codes.Unknown,
"malformed header: missing HTTP content-type",
),
},
{
name: "invalid grpc status header field",
metaHeaderFrame: &http2.MetaHeadersFrame{
Fields: []hpack.HeaderField{
{Name: "content-type", Value: "application/grpc"},
{Name: "grpc-status", Value: "xxxx"},
{Name: ":status", Value: "200"},
},
},
wantStatus: status.New(
Expand All @@ -2018,7 +2058,7 @@ func (s) TestClientDecodeHeaderStatusErr(t *testing.T) {
},
wantStatus: status.New(
codes.Internal,
": HTTP status code 0; transport: received the unexpected content-type \"application/json\"",
"malformed header: missing HTTP status; transport: received unexpected content-type \"application/json\"",
),
},
{
Expand All @@ -2045,27 +2085,56 @@ func (s) TestClientDecodeHeaderStatusErr(t *testing.T) {
"peer header list size exceeded limit",
),
},
{
name: "bad status in grpc mode",
metaHeaderFrame: &http2.MetaHeadersFrame{
Fields: []hpack.HeaderField{
{Name: "content-type", Value: "application/grpc"},
{Name: "grpc-status", Value: "0"},
{Name: ":status", Value: "504"},
},
},
wantStatus: status.New(
codes.Unavailable,
"unexpected HTTP status code received from server: 504 (Gateway Timeout)",
),
},
{
name: "missing http status",
metaHeaderFrame: &http2.MetaHeadersFrame{
Fields: []hpack.HeaderField{
{Name: "content-type", Value: "application/grpc"},
},
},
wantStatus: status.New(
codes.Internal,
"malformed header: missing HTTP status",
),
},
} {

t.Run(test.name, func(t *testing.T) {
ts := &Stream{
done: make(chan struct{}),
headerChan: make(chan struct{}),
buf: &recvBuffer{
c: make(chan recvMsg),
mu: sync.Mutex{},
ts := testStream()
s := testClient(ts)

test.metaHeaderFrame.HeadersFrame = &http2.HeadersFrame{
FrameHeader: http2.FrameHeader{
StreamID: 0,
},
}
s := &http2Client{
mu: sync.Mutex{},
activeStreams: map[uint32]*Stream{
0: ts,
},
controlBuf: &controlBuffer{
ch: make(chan struct{}),
done: make(chan struct{}),
list: &itemList{},
},

s.operateHeaders(test.metaHeaderFrame)

got := ts.status
want := test.wantStatus
if got.Code() != want.Code() || got.Message() != want.Message() {
t.Fatalf("operateHeaders(%v); status = \ngot: %s\nwant: %s", test.metaHeaderFrame, got, want)
}
})
t.Run(fmt.Sprintf("%s-end_stream", test.name), func(t *testing.T) {
ts := testStream()
s := testClient(ts)

test.metaHeaderFrame.HeadersFrame = &http2.HeadersFrame{
FrameHeader: http2.FrameHeader{
StreamID: 0,
Expand All @@ -2078,7 +2147,7 @@ func (s) TestClientDecodeHeaderStatusErr(t *testing.T) {
got := ts.status
want := test.wantStatus
if got.Code() != want.Code() || got.Message() != want.Message() {
t.Fatalf("operateHeaders(%v); status = %v; want %v", test.metaHeaderFrame, got, want)
t.Fatalf("operateHeaders(%v); status = \ngot: %s\nwant: %s", test.metaHeaderFrame, got, want)
}
})
}
Expand Down

0 comments on commit ba41bba

Please sign in to comment.