From 9717d5f7f8b3497b0e6af7ba3751fae145bbb3e4 Mon Sep 17 00:00:00 2001 From: Sanjay Pujare Date: Wed, 12 May 2021 09:14:22 -0700 Subject: [PATCH 1/2] xds: add null reference checks in SslContextProviderSupplier --- .../sds/SslContextProviderSupplier.java | 14 ++++--- .../sds/SslContextProviderSupplierTest.java | 37 +++++++++++++++---- 2 files changed, 37 insertions(+), 14 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java index f6803c10700..b319bb83110 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java @@ -41,8 +41,8 @@ public final class SslContextProviderSupplier implements Closeable { public SslContextProviderSupplier( BaseTlsContext tlsContext, TlsContextManager tlsContextManager) { - this.tlsContext = tlsContext; - this.tlsContextManager = tlsContextManager; + this.tlsContext = checkNotNull(tlsContext, "tlsContext"); + this.tlsContextManager = checkNotNull(tlsContextManager, "tlsContextManager"); } public BaseTlsContext getTlsContext() { @@ -92,10 +92,12 @@ private SslContextProvider getSslContextProvider() { /** Called by consumer when tlsContext changes. */ @Override public synchronized void close() { - if (tlsContext instanceof UpstreamTlsContext) { - tlsContextManager.releaseClientSslContextProvider(sslContextProvider); - } else { - tlsContextManager.releaseServerSslContextProvider(sslContextProvider); + if (sslContextProvider != null) { + if (tlsContext instanceof UpstreamTlsContext) { + tlsContextManager.releaseClientSslContextProvider(sslContextProvider); + } else { + tlsContextManager.releaseServerSslContextProvider(sslContextProvider); + } } shutdown = true; } diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java index 8d97e430765..34505839a1b 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java @@ -23,7 +23,9 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -57,7 +59,7 @@ public void setUp() { MockitoAnnotations.initMocks(this); } - private void prepareSupplier() { + private void prepareSupplier(boolean callUpdateSslContext) { upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); @@ -65,16 +67,18 @@ private void prepareSupplier() { doReturn(mockSslContextProvider) .when(mockTlsContextManager) .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); - mockCallback = mock(SslContextProvider.Callback.class); - Executor mockExecutor = mock(Executor.class); - doReturn(mockExecutor).when(mockCallback).getExecutor(); supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager); - supplier.updateSslContext(mockCallback); + if (callUpdateSslContext) { + mockCallback = mock(SslContextProvider.Callback.class); + Executor mockExecutor = mock(Executor.class); + doReturn(mockExecutor).when(mockCallback).getExecutor(); + supplier.updateSslContext(mockCallback); + } } @Test public void get_updateSecret() { - prepareSupplier(); + prepareSupplier(true); verify(mockTlsContextManager, times(2)) .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); verify(mockTlsContextManager, times(0)) @@ -96,7 +100,7 @@ public void get_updateSecret() { @Test public void get_onException() { - prepareSupplier(); + prepareSupplier(true); ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(null); verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); @@ -108,7 +112,7 @@ public void get_onException() { @Test public void testClose() { - prepareSupplier(); + prepareSupplier(true); supplier.close(); verify(mockTlsContextManager, times(1)) .releaseClientSslContextProvider(eq(mockSslContextProvider)); @@ -120,4 +124,21 @@ public void testClose() { assertThat(expected).hasMessageThat().isEqualTo("Supplier is shutdown!"); } } + + @Test + public void testClose_nullSslContextProvider() { + prepareSupplier(false); + doThrow(new NullPointerException()).when(mockTlsContextManager) + .releaseClientSslContextProvider(null); + supplier.close(); + verify(mockTlsContextManager, never()) + .releaseClientSslContextProvider(eq(mockSslContextProvider)); + SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); + try { + supplier.updateSslContext(mockCallback); + Assert.fail("no exception thrown"); + } catch (IllegalStateException expected) { + assertThat(expected).hasMessageThat().isEqualTo("Supplier is shutdown!"); + } + } } From f62629de42baf9b2f88d4f1dac19a3a81cb43152 Mon Sep 17 00:00:00 2001 From: Sanjay Pujare Date: Wed, 12 May 2021 10:09:31 -0700 Subject: [PATCH 2/2] address review comment --- .../sds/SslContextProviderSupplierTest.java | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java index 34505839a1b..8c5922b7fdb 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java @@ -59,7 +59,7 @@ public void setUp() { MockitoAnnotations.initMocks(this); } - private void prepareSupplier(boolean callUpdateSslContext) { + private void prepareSupplier() { upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); @@ -68,17 +68,19 @@ private void prepareSupplier(boolean callUpdateSslContext) { .when(mockTlsContextManager) .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager); - if (callUpdateSslContext) { - mockCallback = mock(SslContextProvider.Callback.class); - Executor mockExecutor = mock(Executor.class); - doReturn(mockExecutor).when(mockCallback).getExecutor(); - supplier.updateSslContext(mockCallback); - } + } + + private void callUpdateSslContext() { + mockCallback = mock(SslContextProvider.Callback.class); + Executor mockExecutor = mock(Executor.class); + doReturn(mockExecutor).when(mockCallback).getExecutor(); + supplier.updateSslContext(mockCallback); } @Test public void get_updateSecret() { - prepareSupplier(true); + prepareSupplier(); + callUpdateSslContext(); verify(mockTlsContextManager, times(2)) .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); verify(mockTlsContextManager, times(0)) @@ -100,7 +102,8 @@ public void get_updateSecret() { @Test public void get_onException() { - prepareSupplier(true); + prepareSupplier(); + callUpdateSslContext(); ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(null); verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); @@ -112,7 +115,8 @@ public void get_onException() { @Test public void testClose() { - prepareSupplier(true); + prepareSupplier(); + callUpdateSslContext(); supplier.close(); verify(mockTlsContextManager, times(1)) .releaseClientSslContextProvider(eq(mockSslContextProvider)); @@ -127,7 +131,7 @@ public void testClose() { @Test public void testClose_nullSslContextProvider() { - prepareSupplier(false); + prepareSupplier(); doThrow(new NullPointerException()).when(mockTlsContextManager) .releaseClientSslContextProvider(null); supplier.close();