Skip to content

Commit

Permalink
server: after GracefulStop, ensure connections are closed when final …
Browse files Browse the repository at this point in the history
…RPC completes (#5968)

Fixes #5930
  • Loading branch information
dfawley committed Jan 26, 2023
1 parent e2d69aa commit 2a1e934
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 9 deletions.
6 changes: 4 additions & 2 deletions internal/transport/controlbuf.go
Expand Up @@ -527,6 +527,9 @@ const minBatchSize = 1000
// As an optimization, to increase the batch size for each flush, loopy yields the processor, once
// if the batch size is too low to give stream goroutines a chance to fill it up.
func (l *loopyWriter) run() (err error) {
// Always flush the writer before exiting in case there are pending frames
// to be sent.
defer l.framer.writer.Flush()
for {
it, err := l.cbuf.get(true)
if err != nil {
Expand Down Expand Up @@ -759,7 +762,7 @@ func (l *loopyWriter) cleanupStreamHandler(c *cleanupStream) error {
return err
}
}
if l.side == clientSide && l.draining && len(l.estdStreams) == 0 {
if l.draining && len(l.estdStreams) == 0 {
return errors.New("finished processing active streams while in draining mode")
}
return nil
Expand Down Expand Up @@ -814,7 +817,6 @@ func (l *loopyWriter) goAwayHandler(g *goAway) error {
}

func (l *loopyWriter) closeConnectionHandler() error {
l.framer.writer.Flush()
// Exit loopyWriter entirely by returning an error here. This will lead to
// the transport closing the connection, and, ultimately, transport
// closure.
Expand Down
51 changes: 51 additions & 0 deletions test/gracefulstop_test.go
Expand Up @@ -26,6 +26,7 @@ import (
"testing"
"time"

"golang.org/x/net/http2"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
Expand Down Expand Up @@ -164,3 +165,53 @@ func (s) TestGracefulStop(t *testing.T) {
cancel()
wg.Wait()
}

func (s) TestGracefulStopClosesConnAfterLastStream(t *testing.T) {
// This test ensures that a server closes the connections to its clients
// when the final stream has completed after a GOAWAY.

handlerCalled := make(chan struct{})
gracefulStopCalled := make(chan struct{})

ts := &funcServer{streamingInputCall: func(stream testpb.TestService_StreamingInputCallServer) error {
close(handlerCalled) // Initiate call to GracefulStop.
<-gracefulStopCalled // Wait for GOAWAYs to be received by the client.
return nil
}}

te := newTest(t, tcpClearEnv)
te.startServer(ts)
defer te.tearDown()

te.withServerTester(func(st *serverTester) {
st.writeHeadersGRPC(1, "/grpc.testing.TestService/StreamingInputCall", false)

<-handlerCalled // Wait for the server to invoke its handler.

// Gracefully stop the server.
gracefulStopDone := make(chan struct{})
go func() {
te.srv.GracefulStop()
close(gracefulStopDone)
}()
st.wantGoAway(http2.ErrCodeNo) // Server sends a GOAWAY due to GracefulStop.
pf := st.wantPing() // Server sends a ping to verify client receipt.
st.writePing(true, pf.Data) // Send ping ack to confirm.
st.wantGoAway(http2.ErrCodeNo) // Wait for subsequent GOAWAY to indicate no new stream processing.

close(gracefulStopCalled) // Unblock server handler.

fr := st.wantAnyFrame() // Wait for trailer.
hdr, ok := fr.(*http2.MetaHeadersFrame)
if !ok {
t.Fatalf("Received unexpected frame of type (%T) from server: %v; want HEADERS", fr, fr)
}
if !hdr.StreamEnded() {
t.Fatalf("Received unexpected HEADERS frame from server: %v; want END_STREAM set", fr)
}

st.wantRSTStream(http2.ErrCodeNo) // Server should send RST_STREAM because client did not half-close.

<-gracefulStopDone // Wait for GracefulStop to return.
})
}
35 changes: 31 additions & 4 deletions test/servertester.go
Expand Up @@ -138,19 +138,46 @@ func (st *serverTester) writeSettingsAck() {
}
}

func (st *serverTester) wantGoAway(errCode http2.ErrCode) *http2.GoAwayFrame {
f, err := st.readFrame()
if err != nil {
st.t.Fatalf("Error while expecting an RST frame: %v", err)
}
gaf, ok := f.(*http2.GoAwayFrame)
if !ok {
st.t.Fatalf("got a %T; want *http2.GoAwayFrame", f)
}
if gaf.ErrCode != errCode {
st.t.Fatalf("expected GOAWAY error code '%v', got '%v'", errCode.String(), gaf.ErrCode.String())
}
return gaf
}

func (st *serverTester) wantPing() *http2.PingFrame {
f, err := st.readFrame()
if err != nil {
st.t.Fatalf("Error while expecting an RST frame: %v", err)
}
pf, ok := f.(*http2.PingFrame)
if !ok {
st.t.Fatalf("got a %T; want *http2.GoAwayFrame", f)
}
return pf
}

func (st *serverTester) wantRSTStream(errCode http2.ErrCode) *http2.RSTStreamFrame {
f, err := st.readFrame()
if err != nil {
st.t.Fatalf("Error while expecting an RST frame: %v", err)
}
sf, ok := f.(*http2.RSTStreamFrame)
rf, ok := f.(*http2.RSTStreamFrame)
if !ok {
st.t.Fatalf("got a %T; want *http2.RSTStreamFrame", f)
}
if sf.ErrCode != errCode {
st.t.Fatalf("expected RST error code '%v', got '%v'", errCode.String(), sf.ErrCode.String())
if rf.ErrCode != errCode {
st.t.Fatalf("expected RST error code '%v', got '%v'", errCode.String(), rf.ErrCode.String())
}
return sf
return rf
}

func (st *serverTester) wantSettings() *http2.SettingsFrame {
Expand Down
6 changes: 3 additions & 3 deletions test/stream_cleanup_test.go
Expand Up @@ -46,7 +46,7 @@ func (s) TestStreamCleanup(t *testing.T) {
return &testpb.Empty{}, nil
},
}
if err := ss.Start([]grpc.ServerOption{grpc.MaxConcurrentStreams(1)}, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(callRecvMsgSize))), grpc.WithInitialWindowSize(int32(initialWindowSize))); err != nil {
if err := ss.Start(nil, grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(callRecvMsgSize))), grpc.WithInitialWindowSize(int32(initialWindowSize))); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()
Expand Down Expand Up @@ -79,7 +79,7 @@ func (s) TestStreamCleanupAfterSendStatus(t *testing.T) {
})
},
}
if err := ss.Start([]grpc.ServerOption{grpc.MaxConcurrentStreams(1)}, grpc.WithInitialWindowSize(int32(initialWindowSize))); err != nil {
if err := ss.Start(nil, grpc.WithInitialWindowSize(int32(initialWindowSize))); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()
Expand Down Expand Up @@ -132,6 +132,6 @@ func (s) TestStreamCleanupAfterSendStatus(t *testing.T) {
case <-gracefulStopDone:
timer.Stop()
case <-timer.C:
t.Fatalf("s.GracefulStop() didn't finish without 1 second after the last RPC")
t.Fatalf("s.GracefulStop() didn't finish within 1 second after the last RPC")
}
}

0 comments on commit 2a1e934

Please sign in to comment.