Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transport: Fix closing a closed channel panic in handlePing #5854

Merged
merged 3 commits into from Dec 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 13 additions & 12 deletions internal/transport/http2_server.go
Expand Up @@ -42,6 +42,7 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
Expand Down Expand Up @@ -102,13 +103,13 @@ type http2Server struct {

mu sync.Mutex // guard the following

// drainChan is initialized when Drain() is called the first time.
// After which the server writes out the first GoAway(with ID 2^31-1) frame.
// Then an independent goroutine will be launched to later send the second GoAway.
// During this time we don't want to write another first GoAway(with ID 2^31 -1) frame.
// Thus call to Drain() will be a no-op if drainChan is already initialized since draining is
// already underway.
drainChan chan struct{}
// drainEvent is initialized when Drain() is called the first time. After
// which the server writes out the first GoAway(with ID 2^31-1) frame. Then
// an independent goroutine will be launched to later send the second
// GoAway. During this time we don't want to write another first GoAway(with
// ID 2^31 -1) frame. Thus call to Drain() will be a no-op if drainEvent is
// already initialized since draining is already underway.
drainEvent *grpcsync.Event
state transportState
activeStreams map[uint32]*Stream
// idle is the time instant when the connection went idle.
Expand Down Expand Up @@ -838,8 +839,8 @@ const (

func (t *http2Server) handlePing(f *http2.PingFrame) {
if f.IsAck() {
if f.Data == goAwayPing.data && t.drainChan != nil {
close(t.drainChan)
if f.Data == goAwayPing.data && t.drainEvent != nil {
t.drainEvent.Fire()
return
}
// Maybe it's a BDP ping.
Expand Down Expand Up @@ -1287,10 +1288,10 @@ func (t *http2Server) RemoteAddr() net.Addr {
func (t *http2Server) Drain() {
t.mu.Lock()
defer t.mu.Unlock()
if t.drainChan != nil {
if t.drainEvent != nil {
return
}
t.drainChan = make(chan struct{})
t.drainEvent = grpcsync.NewEvent()
t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte{}, headsUp: true})
}

Expand Down Expand Up @@ -1346,7 +1347,7 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) {
timer := time.NewTimer(time.Minute)
defer timer.Stop()
select {
case <-t.drainChan:
case <-t.drainEvent.Done():
case <-timer.C:
case <-t.done:
return
Expand Down
58 changes: 58 additions & 0 deletions test/goaway_test.go
Expand Up @@ -363,6 +363,7 @@ func testServerMultipleGoAwayPendingRPC(t *testing.T, e env) {
close(ch2)
}()
// Loop until the server side GoAway signal is propagated to the client.

for {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil {
Expand Down Expand Up @@ -402,6 +403,7 @@ func testServerMultipleGoAwayPendingRPC(t *testing.T, e env) {
if err := stream.CloseSend(); err != nil {
t.Fatalf("%v.CloseSend() = %v, want <nil>", stream, err)
}

<-ch1
<-ch2
cancel()
Expand Down Expand Up @@ -707,3 +709,59 @@ func (s) TestGoAwayStreamIDSmallerThanCreatedStreams(t *testing.T) {
ct.writeGoAway(1, http2.ErrCodeNo, []byte{})
goAwayWritten.Fire()
}

// TestTwoGoAwayPingFrames tests the scenario where you get two go away ping
// frames from the client during graceful shutdown. This should not crash the
// server.
func (s) TestTwoGoAwayPingFrames(t *testing.T) {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
defer lis.Close()
s := grpc.NewServer()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this needs to defer s.Stop() just in case GracefulStop() does hang (otherwise the leak detector hangs).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh interesting added.

defer s.Stop()
go s.Serve(lis)

conn, err := net.DialTimeout("tcp", lis.Addr().String(), defaultTestTimeout)
if err != nil {
t.Fatalf("Failed to dial: %v", err)
}

st := newServerTesterFromConn(t, conn)
st.greet()
pingReceivedClientSide := testutils.NewChannel()
go func() {
for {
f, err := st.readFrame()
if err != nil {
return
}
switch f.(type) {
case *http2.GoAwayFrame:
case *http2.PingFrame:
pingReceivedClientSide.Send(nil)
default:
t.Errorf("server tester received unexpected frame type %T", f)
}
}
}()
gsDone := testutils.NewChannel()
go func() {
s.GracefulStop()
gsDone.Send(nil)
}()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := pingReceivedClientSide.Receive(ctx); err != nil {
t.Fatalf("Error waiting for ping frame client side from graceful shutdown: %v", err)
}
// Write two goaway pings here.
st.writePing(true, [8]byte{1, 6, 1, 8, 0, 3, 3, 9})
st.writePing(true, [8]byte{1, 6, 1, 8, 0, 3, 3, 9})
// Close the conn to finish up the Graceful Shutdown process.
conn.Close()
if _, err := gsDone.Receive(ctx); err != nil {
t.Fatalf("Error waiting for graceful shutdown of the server: %v", err)
}
}
6 changes: 6 additions & 0 deletions test/servertester.go
Expand Up @@ -273,3 +273,9 @@ func (st *serverTester) writeRSTStream(streamID uint32, code http2.ErrCode) {
st.t.Fatalf("Error writing RST_STREAM: %v", err)
}
}

func (st *serverTester) writePing(ack bool, data [8]byte) {
if err := st.fr.WritePing(ack, data); err != nil {
st.t.Fatalf("Error writing PING: %v", err)
}
}