From 05537c1a16cf0463222bf1e99ee7b69e8ba34e39 Mon Sep 17 00:00:00 2001 From: Zach Reyes Date: Mon, 12 Sep 2022 22:47:44 -0400 Subject: [PATCH] Fixed deadlock in transport --- internal/transport/http2_client.go | 13 ++++++-- internal/transport/transport_test.go | 44 ++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 53643fa97477..7543357084d3 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -1232,16 +1232,23 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { if upperLimit == 0 { // This is the first GoAway Frame. upperLimit = math.MaxUint32 // Kill all streams after the GoAway ID. } + + activeStreams := make(map[uint32]*Stream) for streamID, stream := range t.activeStreams { + activeStreams[streamID] = stream + } + + t.prevGoAwayID = id + t.mu.Unlock() + for streamID, stream := range activeStreams { if streamID > id && streamID <= upperLimit { // The stream was unprocessed by the server. atomic.StoreUint32(&stream.unprocessed, 1) t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false) } } - t.prevGoAwayID = id - active := len(t.activeStreams) - t.mu.Unlock() + + active := len(activeStreams) if active == 0 { t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams")) } diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 760e1b64f358..972a828a801d 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -2501,3 +2501,47 @@ func (s) TestPeerSetInServerContext(t *testing.T) { } server.mu.Unlock() } + +// TestGoAwayCloseStreams tests the scenario where a client has many streams +// created, and the server sends a GOAWAY frame with a stream id less than some +// of them, while the client is still creating new streams. This should not +// induce a deadlock. +func (s) TestGoAwayCloseStreams(t *testing.T) { + server, ct, cancel := setUp(t, 0, math.MaxUint32, normal) + defer cancel() + defer server.stop() + defer ct.Close(fmt.Errorf("closed manually by test")) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + for i := 0; i < 5; i++ { + _, err := ct.NewStream(ctx, &CallHdr{}) + if err != nil { + t.Fatalf("error creating stream: %v", err) + } + } + + waitWhileTrue(t, func() (bool, error) { + server.mu.Lock() + defer server.mu.Unlock() + + if len(server.conns) == 0 { + return true, fmt.Errorf("timed-out while waiting for connection to be created on the server") + } + return false, nil + }) + + var st *http2Server + server.mu.Lock() + for k := range server.conns { + st = k.(*http2Server) + } + server.mu.Unlock() + + st.framer.fr.WriteGoAway(5, http2.ErrCodeNo, []byte{}) + for i := 0; i < 10; i++ { + _, err := ct.NewStream(ctx, &CallHdr{}) + if err != nil { + t.Fatalf("error creating stream: %v", err) + } + } +}