From 376457aabbfce6173f989299753ff3eb65b27047 Mon Sep 17 00:00:00 2001 From: Sanjay Pujare Date: Mon, 28 Jun 2021 17:29:34 -0700 Subject: [PATCH 1/2] xds: fix race condition in SslContextProviderSupplier's updateSslContext and close --- .../sds/SslContextProviderSupplier.java | 11 +++-- .../sds/SslContextProviderSupplierTest.java | 47 +++++-------------- 2 files changed, 19 insertions(+), 39 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java index 3902569d873..5454a55b894 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java @@ -17,7 +17,6 @@ package io.grpc.xds.internal.sds; import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; @@ -56,13 +55,14 @@ public BaseTlsContext getTlsContext() { public synchronized void updateSslContext(final SslContextProvider.Callback callback) { checkNotNull(callback, "callback"); try { - checkState(!shutdown, "Supplier is shutdown!"); - if (sslContextProvider == null) { - sslContextProvider = getSslContextProvider(); + if (!shutdown) { + if (sslContextProvider == null) { + sslContextProvider = getSslContextProvider(); + } } // we want to increment the ref-count so call findOrCreate again... final SslContextProvider toRelease = getSslContextProvider(); - sslContextProvider.addCallback( + toRelease.addCallback( new SslContextProvider.Callback(callback.getExecutor()) { @Override @@ -115,6 +115,7 @@ public synchronized void close() { tlsContextManager.releaseServerSslContextProvider(sslContextProvider); } } + // don't set sslContextProvider to null since we don't want reallocation under any circumstances shutdown = true; } diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java index ec2c85e5b8c..19fd0e189c1 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java @@ -23,16 +23,13 @@ import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import com.google.common.util.concurrent.MoreExecutors; import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.TlsContextManager; import io.netty.handler.ssl.SslContext; import java.util.concurrent.Executor; -import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -91,11 +88,11 @@ public void get_updateSecret() { capturedCallback.updateSecret(mockSslContext); verify(mockCallback, times(1)).updateSecret(eq(mockSslContext)); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider)); + .releaseClientSslContextProvider(eq(mockSslContextProvider)); SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); supplier.updateSslContext(mockCallback); verify(mockTlsContextManager, times(3)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); } @Test @@ -106,9 +103,11 @@ public void get_onException() { verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); assertThat(capturedCallback).isNotNull(); - capturedCallback.onException(new Exception("test")); + Exception exception = new Exception("test"); + capturedCallback.onException(exception); + verify(mockCallback, times(1)).onException(eq(exception)); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider)); + .releaseClientSslContextProvider(eq(mockSslContextProvider)); } @Test @@ -118,20 +117,11 @@ public void testClose() { supplier.close(); verify(mockTlsContextManager, times(1)) .releaseClientSslContextProvider(eq(mockSslContextProvider)); - SslContextProvider.Callback mockCallback = spy( - new SslContextProvider.Callback(MoreExecutors.directExecutor()) { - @Override - public void updateSecret(SslContext sslContext) { - Assert.fail("unexpected call"); - } - - @Override - protected void onException(Throwable argument) { - assertThat(argument).isInstanceOf(IllegalStateException.class); - assertThat(argument).hasMessageThat().contains("Supplier is shutdown!"); - } - }); supplier.updateSslContext(mockCallback); + verify(mockTlsContextManager, times(3)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + verify(mockTlsContextManager, times(1)) + .releaseClientSslContextProvider(any(SslContextProvider.class)); } @Test @@ -142,19 +132,8 @@ public void testClose_nullSslContextProvider() { supplier.close(); verify(mockTlsContextManager, never()) .releaseClientSslContextProvider(eq(mockSslContextProvider)); - SslContextProvider.Callback mockCallback = spy( - new SslContextProvider.Callback(MoreExecutors.directExecutor()) { - @Override - public void updateSecret(SslContext sslContext) { - Assert.fail("unexpected call"); - } - - @Override - protected void onException(Throwable argument) { - assertThat(argument).isInstanceOf(IllegalStateException.class); - assertThat(argument).hasMessageThat().contains("Supplier is shutdown!"); - } - }); - supplier.updateSslContext(mockCallback); + callUpdateSslContext(); + verify(mockTlsContextManager, times(1)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); } } From d1d07b4a95a7d1a8264bd35ce8dbfe9a138636ab Mon Sep 17 00:00:00 2001 From: Sanjay Pujare Date: Thu, 8 Jul 2021 17:36:52 -0700 Subject: [PATCH 2/2] address review comment --- .../io/grpc/xds/internal/sds/SslContextProviderSupplier.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java index 5454a55b894..3300c22b2bf 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java @@ -115,7 +115,7 @@ public synchronized void close() { tlsContextManager.releaseServerSslContextProvider(sslContextProvider); } } - // don't set sslContextProvider to null since we don't want reallocation under any circumstances + sslContextProvider = null; shutdown = true; }