diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java b/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java index 0c28c79ee22..df09e8bb247 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java @@ -26,9 +26,11 @@ public final class CommonTlsContextUtil { private CommonTlsContextUtil() {} static boolean hasCertProviderInstance(CommonTlsContext commonTlsContext) { - return commonTlsContext != null - && (commonTlsContext.hasTlsCertificateCertificateProviderInstance() - || hasCertProviderValidationContext(commonTlsContext)); + if (commonTlsContext == null) { + return false; + } + return hasIdentityCertificateProviderInstance(commonTlsContext) + || hasCertProviderValidationContext(commonTlsContext); } private static boolean hasCertProviderValidationContext(CommonTlsContext commonTlsContext) { @@ -37,6 +39,19 @@ private static boolean hasCertProviderValidationContext(CommonTlsContext commonT commonTlsContext.getCombinedValidationContext(); return combinedCertificateValidationContext.hasValidationContextCertificateProviderInstance(); } + return hasValidationProviderInstance(commonTlsContext); + } + + private static boolean hasIdentityCertificateProviderInstance(CommonTlsContext commonTlsContext) { + return commonTlsContext.hasTlsCertificateProviderInstance() + || commonTlsContext.hasTlsCertificateCertificateProviderInstance(); + } + + private static boolean hasValidationProviderInstance(CommonTlsContext commonTlsContext) { + if (commonTlsContext.hasValidationContext() && commonTlsContext.getValidationContext() + .hasCaCertificateProviderInstance()) { + return true; + } return commonTlsContext.hasValidationContextCertificateProviderInstance(); } diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java index 06a3198b263..adc96a36336 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java @@ -208,6 +208,72 @@ public void createCertProviderClientSslContextProvider_2providers() verifyWatcher(sslContextProvider, watcherCaptor[1]); } + @Test + public void createNewCertProviderClientSslContextProvider_withSans() { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[2]; + createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); + createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "file_watcher", 1); + + CertificateValidationContext staticCertValidationContext = + CertificateValidationContext.newBuilder() + .addAllMatchSubjectAltNames( + ImmutableSet.of( + StringMatcher.newBuilder().setExact("foo").build(), + StringMatcher.newBuilder().setExact("bar").build())) + .build(); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( + "gcp_id", + "cert-default", + "file_provider", + "root-default", + /* alpnProtocols= */ null, + staticCertValidationContext); + + Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); + clientSslContextProviderFactory = + new ClientSslContextProviderFactory( + bootstrapInfo, certProviderClientSslContextProviderFactory); + SslContextProvider sslContextProvider = + clientSslContextProviderFactory.create(upstreamTlsContext); + assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class); + verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[1]); + } + + @Test + public void createNewCertProviderClientSslContextProvider_onlyRootCert() { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); + CertificateValidationContext staticCertValidationContext = + CertificateValidationContext.newBuilder() + .addAllMatchSubjectAltNames( + ImmutableSet.of( + StringMatcher.newBuilder().setExact("foo").build(), + StringMatcher.newBuilder().setExact("bar").build())) + .build(); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( + /* certInstanceName= */ null, + /* certName= */ null, + "gcp_id", + "root-default", + /* alpnProtocols= */ null, + staticCertValidationContext); + + Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); + clientSslContextProviderFactory = + new ClientSslContextProviderFactory( + bootstrapInfo, certProviderClientSslContextProviderFactory); + SslContextProvider sslContextProvider = + clientSslContextProviderFactory.create(upstreamTlsContext); + assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class); + verifyWatcher(sslContextProvider, watcherCaptor[0]); + } + @Test public void createNullCommonTlsContext_exception() throws IOException { clientSslContextProviderFactory = diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java index a4bab618a36..7623b614001 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java @@ -206,4 +206,41 @@ public void createCertProviderServerSslContextProvider_2providers() verifyWatcher(sslContextProvider, watcherCaptor[0]); verifyWatcher(sslContextProvider, watcherCaptor[1]); } + + @Test + public void createNewCertProviderServerSslContextProvider_withSans() + throws XdsInitializationException { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[2]; + createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); + createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "file_watcher", 1); + CertificateValidationContext staticCertValidationContext = + CertificateValidationContext.newBuilder() + .addAllMatchSubjectAltNames( + ImmutableSet.of( + StringMatcher.newBuilder().setExact("foo").build(), + StringMatcher.newBuilder().setExact("bar").build())) + .build(); + + DownstreamTlsContext downstreamTlsContext = + CommonTlsContextTestsUtil.buildNewDownstreamTlsContextForCertProviderInstance( + "gcp_id", + "cert-default", + "file_provider", + "root-default", + /* alpnProtocols= */ null, + staticCertValidationContext, + /* requireClientCert= */ true); + + Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); + serverSslContextProviderFactory = + new ServerSslContextProviderFactory( + bootstrapInfo, certProviderServerSslContextProviderFactory); + SslContextProvider sslContextProvider = + serverSslContextProviderFactory.create(downstreamTlsContext); + assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class); + verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[1]); + } }