diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 75586307435..ea3babb118b 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -1073,7 +1073,7 @@ func (t *http2Client) handleData(f *http2.DataFrame) { } // The server has closed the stream without sending trailers. Record that // the read direction is closed, and set the status appropriately. - if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) { + if f.StreamEnded() { t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true) } } diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 3d1d5c1d4cd..f1594d663af 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -734,7 +734,7 @@ func (t *http2Server) handleData(f *http2.DataFrame) { s.write(recvMsg{buffer: buffer}) } } - if f.Header().Flags.Has(http2.FlagDataEndStream) { + if f.StreamEnded() { // Received the end of stream from the client. s.compareAndSwapState(streamActive, streamReadDone) s.write(recvMsg{err: io.EOF}) diff --git a/test/end2end_test.go b/test/end2end_test.go index bce752701da..e84b9e99170 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -7352,8 +7352,11 @@ type httpServerResponse struct { } type httpServer struct { - refuseStream func(uint32) bool - responses []httpServerResponse + // If waitForEndStream is set, wait for the client to send a frame with end + // stream in it before sending a response/refused stream. + waitForEndStream bool + refuseStream func(uint32) bool + responses []httpServerResponse } func (s *httpServer) writeHeader(framer *http2.Framer, sid uint32, headerFields []string, endStream bool) error { @@ -7416,8 +7419,25 @@ func (s *httpServer) start(t *testing.T, lis net.Listener) { } return } - if hframe, ok := frame.(*http2.HeadersFrame); ok { - sid = hframe.Header().StreamID + sid = 0 + switch fr := frame.(type) { + case *http2.HeadersFrame: + // Respond after this if we are not waiting for an end + // stream or if this frame ends it. + if !s.waitForEndStream || fr.StreamEnded() { + sid = fr.Header().StreamID + } + + case *http2.DataFrame: + // Respond after this if we were waiting for an end stream + // and this frame ends it. (If we were not waiting for an + // end stream, this stream was already responded to when + // the headers were received.) + if s.waitForEndStream && fr.StreamEnded() { + sid = fr.Header().StreamID + } + } + if sid != 0 { if s.refuseStream == nil || !s.refuseStream(sid) { break } diff --git a/test/retry_test.go b/test/retry_test.go index dcd3a2158db..7f068d79f44 100644 --- a/test/retry_test.go +++ b/test/retry_test.go @@ -517,6 +517,7 @@ func (s) TestRetryStats(t *testing.T) { } defer lis.Close() server := &httpServer{ + waitForEndStream: true, responses: []httpServerResponse{{ trailers: [][]string{{ ":status", "200", @@ -588,13 +589,6 @@ func (s) TestRetryStats(t *testing.T) { &stats.End{}, } - // There is a race between noticing the RST_STREAM during the first RPC - // attempt and writing the payload. If we detect that the client did not - // send the OutPayload, we remove it from want. - if _, ok := handler.s[2].(*stats.End); ok { - want = append(want[:2], want[3:]...) - } - toString := func(ss []stats.RPCStats) (ret []string) { for _, s := range ss { ret = append(ret, fmt.Sprintf("%T - %v", s, s)) @@ -612,8 +606,7 @@ func (s) TestRetryStats(t *testing.T) { // There is a race between receiving the payload (triggered by the // application / gRPC library) and receiving the trailer (triggered at the // transport layer). Adjust the received stats accordingly if necessary. - // Note: we measure from the end of the RPCStats due to the race above. - tIdx, pIdx := len(handler.s)-3, len(handler.s)-2 + const tIdx, pIdx = 13, 14 _, okT := handler.s[tIdx].(*stats.InTrailer) _, okP := handler.s[pIdx].(*stats.InPayload) if okT && okP { @@ -654,8 +647,8 @@ func (s) TestRetryStats(t *testing.T) { } // Validate timings between last Begin and preceding End. - end := handler.s[len(handler.s)-8].(*stats.End) - begin := handler.s[len(handler.s)-7].(*stats.Begin) + end := handler.s[8].(*stats.End) + begin := handler.s[9].(*stats.Begin) diff := begin.BeginTime.Sub(end.EndTime) if diff < 10*time.Millisecond || diff > 50*time.Millisecond { t.Fatalf("pushback time before final attempt = %v; want ~10ms", diff)