Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

xds: fix the validation code to accept new-style CertificateProviderPluginInstance wherever used (v1.44.x backport) #8901

Merged
merged 1 commit into from Feb 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -26,9 +26,11 @@ public final class CommonTlsContextUtil {
private CommonTlsContextUtil() {}

static boolean hasCertProviderInstance(CommonTlsContext commonTlsContext) {
return commonTlsContext != null
&& (commonTlsContext.hasTlsCertificateCertificateProviderInstance()
|| hasCertProviderValidationContext(commonTlsContext));
if (commonTlsContext == null) {
return false;
}
return hasIdentityCertificateProviderInstance(commonTlsContext)
|| hasCertProviderValidationContext(commonTlsContext);
}

private static boolean hasCertProviderValidationContext(CommonTlsContext commonTlsContext) {
Expand All @@ -37,6 +39,19 @@ private static boolean hasCertProviderValidationContext(CommonTlsContext commonT
commonTlsContext.getCombinedValidationContext();
return combinedCertificateValidationContext.hasValidationContextCertificateProviderInstance();
}
return hasValidationProviderInstance(commonTlsContext);
}

private static boolean hasIdentityCertificateProviderInstance(CommonTlsContext commonTlsContext) {
return commonTlsContext.hasTlsCertificateProviderInstance()
|| commonTlsContext.hasTlsCertificateCertificateProviderInstance();
}

private static boolean hasValidationProviderInstance(CommonTlsContext commonTlsContext) {
if (commonTlsContext.hasValidationContext() && commonTlsContext.getValidationContext()
.hasCaCertificateProviderInstance()) {
return true;
}
return commonTlsContext.hasValidationContextCertificateProviderInstance();
}

Expand Down
Expand Up @@ -208,6 +208,72 @@ public void createCertProviderClientSslContextProvider_2providers()
verifyWatcher(sslContextProvider, watcherCaptor[1]);
}

@Test
public void createNewCertProviderClientSslContextProvider_withSans() {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[2];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
createAndRegisterProviderProvider(
certificateProviderRegistry, watcherCaptor, "file_watcher", 1);

CertificateValidationContext staticCertValidationContext =
CertificateValidationContext.newBuilder()
.addAllMatchSubjectAltNames(
ImmutableSet.of(
StringMatcher.newBuilder().setExact("foo").build(),
StringMatcher.newBuilder().setExact("bar").build()))
.build();
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"file_provider",
"root-default",
/* alpnProtocols= */ null,
staticCertValidationContext);

Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
clientSslContextProviderFactory =
new ClientSslContextProviderFactory(
bootstrapInfo, certProviderClientSslContextProviderFactory);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
verifyWatcher(sslContextProvider, watcherCaptor[1]);
}

@Test
public void createNewCertProviderClientSslContextProvider_onlyRootCert() {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
CertificateValidationContext staticCertValidationContext =
CertificateValidationContext.newBuilder()
.addAllMatchSubjectAltNames(
ImmutableSet.of(
StringMatcher.newBuilder().setExact("foo").build(),
StringMatcher.newBuilder().setExact("bar").build()))
.build();
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance(
/* certInstanceName= */ null,
/* certName= */ null,
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
staticCertValidationContext);

Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
clientSslContextProviderFactory =
new ClientSslContextProviderFactory(
bootstrapInfo, certProviderClientSslContextProviderFactory);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}

@Test
public void createNullCommonTlsContext_exception() throws IOException {
clientSslContextProviderFactory =
Expand Down
Expand Up @@ -206,4 +206,41 @@ public void createCertProviderServerSslContextProvider_2providers()
verifyWatcher(sslContextProvider, watcherCaptor[0]);
verifyWatcher(sslContextProvider, watcherCaptor[1]);
}

@Test
public void createNewCertProviderServerSslContextProvider_withSans()
throws XdsInitializationException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[2];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
createAndRegisterProviderProvider(
certificateProviderRegistry, watcherCaptor, "file_watcher", 1);
CertificateValidationContext staticCertValidationContext =
CertificateValidationContext.newBuilder()
.addAllMatchSubjectAltNames(
ImmutableSet.of(
StringMatcher.newBuilder().setExact("foo").build(),
StringMatcher.newBuilder().setExact("bar").build()))
.build();

DownstreamTlsContext downstreamTlsContext =
CommonTlsContextTestsUtil.buildNewDownstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"file_provider",
"root-default",
/* alpnProtocols= */ null,
staticCertValidationContext,
/* requireClientCert= */ true);

Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
serverSslContextProviderFactory =
new ServerSslContextProviderFactory(
bootstrapInfo, certProviderServerSslContextProviderFactory);
SslContextProvider sslContextProvider =
serverSslContextProviderFactory.create(downstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
verifyWatcher(sslContextProvider, watcherCaptor[1]);
}
}