Skip to content

Commit

Permalink
transport: Fix closing a closed channel panic in handlePing (#5854)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasweq committed Dec 13, 2022
1 parent 2f413c4 commit 9373e5c
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 12 deletions.
25 changes: 13 additions & 12 deletions internal/transport/http2_server.go
Expand Up @@ -42,6 +42,7 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
Expand Down Expand Up @@ -102,13 +103,13 @@ type http2Server struct {

mu sync.Mutex // guard the following

// drainChan is initialized when Drain() is called the first time.
// After which the server writes out the first GoAway(with ID 2^31-1) frame.
// Then an independent goroutine will be launched to later send the second GoAway.
// During this time we don't want to write another first GoAway(with ID 2^31 -1) frame.
// Thus call to Drain() will be a no-op if drainChan is already initialized since draining is
// already underway.
drainChan chan struct{}
// drainEvent is initialized when Drain() is called the first time. After
// which the server writes out the first GoAway(with ID 2^31-1) frame. Then
// an independent goroutine will be launched to later send the second
// GoAway. During this time we don't want to write another first GoAway(with
// ID 2^31 -1) frame. Thus call to Drain() will be a no-op if drainEvent is
// already initialized since draining is already underway.
drainEvent *grpcsync.Event
state transportState
activeStreams map[uint32]*Stream
// idle is the time instant when the connection went idle.
Expand Down Expand Up @@ -838,8 +839,8 @@ const (

func (t *http2Server) handlePing(f *http2.PingFrame) {
if f.IsAck() {
if f.Data == goAwayPing.data && t.drainChan != nil {
close(t.drainChan)
if f.Data == goAwayPing.data && t.drainEvent != nil {
t.drainEvent.Fire()
return
}
// Maybe it's a BDP ping.
Expand Down Expand Up @@ -1287,10 +1288,10 @@ func (t *http2Server) RemoteAddr() net.Addr {
func (t *http2Server) Drain() {
t.mu.Lock()
defer t.mu.Unlock()
if t.drainChan != nil {
if t.drainEvent != nil {
return
}
t.drainChan = make(chan struct{})
t.drainEvent = grpcsync.NewEvent()
t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte{}, headsUp: true})
}

Expand Down Expand Up @@ -1346,7 +1347,7 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) {
timer := time.NewTimer(time.Minute)
defer timer.Stop()
select {
case <-t.drainChan:
case <-t.drainEvent.Done():
case <-timer.C:
case <-t.done:
return
Expand Down
58 changes: 58 additions & 0 deletions test/goaway_test.go
Expand Up @@ -363,6 +363,7 @@ func testServerMultipleGoAwayPendingRPC(t *testing.T, e env) {
close(ch2)
}()
// Loop until the server side GoAway signal is propagated to the client.

for {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil {
Expand Down Expand Up @@ -402,6 +403,7 @@ func testServerMultipleGoAwayPendingRPC(t *testing.T, e env) {
if err := stream.CloseSend(); err != nil {
t.Fatalf("%v.CloseSend() = %v, want <nil>", stream, err)
}

<-ch1
<-ch2
cancel()
Expand Down Expand Up @@ -707,3 +709,59 @@ func (s) TestGoAwayStreamIDSmallerThanCreatedStreams(t *testing.T) {
ct.writeGoAway(1, http2.ErrCodeNo, []byte{})
goAwayWritten.Fire()
}

// TestTwoGoAwayPingFrames tests the scenario where you get two go away ping
// frames from the client during graceful shutdown. This should not crash the
// server.
func (s) TestTwoGoAwayPingFrames(t *testing.T) {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
defer lis.Close()
s := grpc.NewServer()
defer s.Stop()
go s.Serve(lis)

conn, err := net.DialTimeout("tcp", lis.Addr().String(), defaultTestTimeout)
if err != nil {
t.Fatalf("Failed to dial: %v", err)
}

st := newServerTesterFromConn(t, conn)
st.greet()
pingReceivedClientSide := testutils.NewChannel()
go func() {
for {
f, err := st.readFrame()
if err != nil {
return
}
switch f.(type) {
case *http2.GoAwayFrame:
case *http2.PingFrame:
pingReceivedClientSide.Send(nil)
default:
t.Errorf("server tester received unexpected frame type %T", f)
}
}
}()
gsDone := testutils.NewChannel()
go func() {
s.GracefulStop()
gsDone.Send(nil)
}()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := pingReceivedClientSide.Receive(ctx); err != nil {
t.Fatalf("Error waiting for ping frame client side from graceful shutdown: %v", err)
}
// Write two goaway pings here.
st.writePing(true, [8]byte{1, 6, 1, 8, 0, 3, 3, 9})
st.writePing(true, [8]byte{1, 6, 1, 8, 0, 3, 3, 9})
// Close the conn to finish up the Graceful Shutdown process.
conn.Close()
if _, err := gsDone.Receive(ctx); err != nil {
t.Fatalf("Error waiting for graceful shutdown of the server: %v", err)
}
}
6 changes: 6 additions & 0 deletions test/servertester.go
Expand Up @@ -273,3 +273,9 @@ func (st *serverTester) writeRSTStream(streamID uint32, code http2.ErrCode) {
st.t.Fatalf("Error writing RST_STREAM: %v", err)
}
}

func (st *serverTester) writePing(ack bool, data [8]byte) {
if err := st.fr.WritePing(ack, data); err != nil {
st.t.Fatalf("Error writing PING: %v", err)
}
}

0 comments on commit 9373e5c

Please sign in to comment.