From e59604b7ce8cc3faee6f8252331ed6769d349f69 Mon Sep 17 00:00:00 2001 From: sanjaypujare Date: Wed, 12 May 2021 10:27:44 -0700 Subject: [PATCH] xds: add null reference checks in SslContextProviderSupplier (#8169) --- .../sds/SslContextProviderSupplier.java | 14 +++++----- .../sds/SslContextProviderSupplierTest.java | 27 ++++++++++++++++++- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java index f6803c10700..b319bb83110 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java @@ -41,8 +41,8 @@ public final class SslContextProviderSupplier implements Closeable { public SslContextProviderSupplier( BaseTlsContext tlsContext, TlsContextManager tlsContextManager) { - this.tlsContext = tlsContext; - this.tlsContextManager = tlsContextManager; + this.tlsContext = checkNotNull(tlsContext, "tlsContext"); + this.tlsContextManager = checkNotNull(tlsContextManager, "tlsContextManager"); } public BaseTlsContext getTlsContext() { @@ -92,10 +92,12 @@ private SslContextProvider getSslContextProvider() { /** Called by consumer when tlsContext changes. */ @Override public synchronized void close() { - if (tlsContext instanceof UpstreamTlsContext) { - tlsContextManager.releaseClientSslContextProvider(sslContextProvider); - } else { - tlsContextManager.releaseServerSslContextProvider(sslContextProvider); + if (sslContextProvider != null) { + if (tlsContext instanceof UpstreamTlsContext) { + tlsContextManager.releaseClientSslContextProvider(sslContextProvider); + } else { + tlsContextManager.releaseServerSslContextProvider(sslContextProvider); + } } shutdown = true; } diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java index 8d97e430765..8c5922b7fdb 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java @@ -23,7 +23,9 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -65,16 +67,20 @@ private void prepareSupplier() { doReturn(mockSslContextProvider) .when(mockTlsContextManager) .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager); + } + + private void callUpdateSslContext() { mockCallback = mock(SslContextProvider.Callback.class); Executor mockExecutor = mock(Executor.class); doReturn(mockExecutor).when(mockCallback).getExecutor(); - supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager); supplier.updateSslContext(mockCallback); } @Test public void get_updateSecret() { prepareSupplier(); + callUpdateSslContext(); verify(mockTlsContextManager, times(2)) .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); verify(mockTlsContextManager, times(0)) @@ -97,6 +103,7 @@ public void get_updateSecret() { @Test public void get_onException() { prepareSupplier(); + callUpdateSslContext(); ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(null); verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); @@ -109,6 +116,7 @@ public void get_onException() { @Test public void testClose() { prepareSupplier(); + callUpdateSslContext(); supplier.close(); verify(mockTlsContextManager, times(1)) .releaseClientSslContextProvider(eq(mockSslContextProvider)); @@ -120,4 +128,21 @@ public void testClose() { assertThat(expected).hasMessageThat().isEqualTo("Supplier is shutdown!"); } } + + @Test + public void testClose_nullSslContextProvider() { + prepareSupplier(); + doThrow(new NullPointerException()).when(mockTlsContextManager) + .releaseClientSslContextProvider(null); + supplier.close(); + verify(mockTlsContextManager, never()) + .releaseClientSslContextProvider(eq(mockSslContextProvider)); + SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); + try { + supplier.updateSslContext(mockCallback); + Assert.fail("no exception thrown"); + } catch (IllegalStateException expected) { + assertThat(expected).hasMessageThat().isEqualTo("Supplier is shutdown!"); + } + } }