diff --git a/xds/internal/xdsclient/controller/controller.go b/xds/internal/xdsclient/controller/controller.go index 4b07dc8d6ac..520da06a103 100644 --- a/xds/internal/xdsclient/controller/controller.go +++ b/xds/internal/xdsclient/controller/controller.go @@ -57,6 +57,11 @@ type Controller struct { cc *grpc.ClientConn // Connection to the management server. vClient version.VersionedClient stopRunGoroutine context.CancelFunc + // The run goroutine closes this channel when it exits, and we block on this + // channel in Close(). This ensures that when Close() returns, the + // underlying transport is closed, and we can guarantee that we will not + // process any subsequent responses from the management server. + runDoneCh chan struct{} backoff func(int) time.Duration streamCh chan grpc.ClientStream @@ -77,6 +82,7 @@ type Controller struct { versionMap map[xdsresource.ResourceType]string // nonceMap contains the nonce from the most recent received response. nonceMap map[xdsresource.ResourceType]string + closed bool // Changes to map lrsClients and the lrsClient inside the map need to be // protected by lrsMu. @@ -127,6 +133,7 @@ func New(config *bootstrap.ServerConfig, updateHandler pubsub.UpdateHandler, val config: config, updateValidator: validator, updateHandler: updateHandler, + runDoneCh: make(chan struct{}), backoff: boff, streamCh: make(chan grpc.ClientStream, 1), @@ -170,6 +177,14 @@ func New(config *bootstrap.ServerConfig, updateHandler pubsub.UpdateHandler, val // Close closes the controller. func (t *Controller) Close() { + t.mu.Lock() + if t.closed { + t.mu.Unlock() + return + } + t.closed = true + t.mu.Unlock() + // Note that Close needs to check for nils even if some of them are always // set in the constructor. This is because the constructor defers Close() in // error cases, and the fields might not be set when the error happens. @@ -179,4 +194,8 @@ func (t *Controller) Close() { if t.cc != nil { t.cc.Close() } + // Wait on the run goroutine to be done only if it was started. + if t.stopRunGoroutine != nil { + <-t.runDoneCh + } } diff --git a/xds/internal/xdsclient/controller/transport.go b/xds/internal/xdsclient/controller/transport.go index 28641dc874a..526aefae29b 100644 --- a/xds/internal/xdsclient/controller/transport.go +++ b/xds/internal/xdsclient/controller/transport.go @@ -54,7 +54,13 @@ func (t *Controller) RemoveWatch(rType xdsresource.ResourceType, resourceName st // stream failed without receiving a single reply) and runs the sender and // receiver routines to send and receive data from the stream respectively. func (t *Controller) run(ctx context.Context) { - go t.send(ctx) + sendDoneCh := make(chan struct{}) + defer func() { + <-sendDoneCh + close(t.runDoneCh) + }() + go t.send(ctx, sendDoneCh) + // TODO: start a goroutine monitoring ClientConn's connectivity state, and // report error (and log) when stats is transient failure. @@ -109,7 +115,9 @@ func (t *Controller) run(ctx context.Context) { // Note that this goroutine doesn't do anything to the old stream when there's a // new one. In fact, there should be only one stream in progress, and new one // should only be created when the old one fails (recv returns an error). -func (t *Controller) send(ctx context.Context) { +func (t *Controller) send(ctx context.Context, doneCh chan struct{}) { + defer func() { close(doneCh) }() + var stream grpc.ClientStream for { select { diff --git a/xds/internal/xdsclient/e2e_test/misc_watchers_test.go b/xds/internal/xdsclient/e2e_test/misc_watchers_test.go index a22970ccdab..414fb249b9a 100644 --- a/xds/internal/xdsclient/e2e_test/misc_watchers_test.go +++ b/xds/internal/xdsclient/e2e_test/misc_watchers_test.go @@ -85,7 +85,6 @@ func (s) TestWatchCallAnotherWatch(t *testing.T) { }) t.Cleanup(rdsCancel3) }) - // defer rdsCancel1() t.Cleanup(rdsCancel1) // Verify the contents of the received update for the all watchers.