diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index 5e167fc498e..92e23b13b67 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -264,7 +264,7 @@ private void updateMaxConcurrentRequests(@Nullable Long maxConcurrentRequests) { private void updateSslContextProviderSupplier(@Nullable UpstreamTlsContext tlsContext) { UpstreamTlsContext currentTlsContext = sslContextProviderSupplier != null - ? sslContextProviderSupplier.getUpstreamTlsContext() + ? (UpstreamTlsContext)sslContextProviderSupplier.getTlsContext() : null; if (Objects.equals(currentTlsContext, tlsContext)) { return; diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java index 715b81f35a9..020acd8eee2 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java @@ -19,31 +19,33 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; +import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.netty.handler.ssl.SslContext; /** - * Enables the CDS policy to initialize this object with the received {@link UpstreamTlsContext} & - * communicate it to the consumer i.e. {@link SdsProtocolNegotiators.ClientSdsProtocolNegotiator} + * Enables Client or server side to initialize this object with the received {@link BaseTlsContext} + * and communicate it to the consumer i.e. {@link SdsProtocolNegotiators} * to lazily evaluate the {@link SslContextProvider}. The supplier prevents credentials leakage in - * cases where the user is not using xDS credentials but the CDS policy contains a non-default - * {@link UpstreamTlsContext}. + * cases where the user is not using xDS credentials but the client/server contains a non-default + * {@link BaseTlsContext}. */ public final class SslContextProviderSupplier implements Closeable { - private final UpstreamTlsContext upstreamTlsContext; + private final BaseTlsContext tlsContext; private final TlsContextManager tlsContextManager; private SslContextProvider sslContextProvider; private boolean shutdown; public SslContextProviderSupplier( - UpstreamTlsContext upstreamTlsContext, TlsContextManager tlsContextManager) { - this.upstreamTlsContext = upstreamTlsContext; + BaseTlsContext tlsContext, TlsContextManager tlsContextManager) { + this.tlsContext = tlsContext; this.tlsContextManager = tlsContextManager; } - public UpstreamTlsContext getUpstreamTlsContext() { - return upstreamTlsContext; + public BaseTlsContext getTlsContext() { + return tlsContext; } /** Updates SslContext via the passed callback. */ @@ -51,34 +53,48 @@ public synchronized void updateSslContext(final SslContextProvider.Callback call checkNotNull(callback, "callback"); checkState(!shutdown, "Supplier is shutdown!"); if (sslContextProvider == null) { - sslContextProvider = - tlsContextManager.findOrCreateClientSslContextProvider(upstreamTlsContext); + sslContextProvider = getSslContextProvider(); } // we want to increment the ref-count so call findOrCreate again... - final SslContextProvider toRelease = - tlsContextManager.findOrCreateClientSslContextProvider(upstreamTlsContext); + final SslContextProvider toRelease = getSslContextProvider(); sslContextProvider.addCallback( new SslContextProvider.Callback(callback.getExecutor()) { @Override public void updateSecret(SslContext sslContext) { callback.updateSecret(sslContext); - tlsContextManager.releaseClientSslContextProvider(toRelease); + releaseSslContextProvider(toRelease); } @Override public void onException(Throwable throwable) { callback.onException(throwable); - tlsContextManager.releaseClientSslContextProvider(toRelease); + releaseSslContextProvider(toRelease); } }); } - /** Called by {@link io.grpc.xds.CdsLoadBalancer} when upstreamTlsContext changes. */ + private void releaseSslContextProvider(SslContextProvider toRelease) { + if (tlsContext instanceof UpstreamTlsContext) { + tlsContextManager.releaseClientSslContextProvider(toRelease); + } else { + tlsContextManager.releaseServerSslContextProvider(toRelease); + } + } + + private SslContextProvider getSslContextProvider() { + return tlsContext instanceof UpstreamTlsContext + ? tlsContextManager.findOrCreateClientSslContextProvider((UpstreamTlsContext) tlsContext) + : tlsContextManager.findOrCreateServerSslContextProvider((DownstreamTlsContext) tlsContext); + } + + /** Called by consumer when tlsContext changes. */ @Override public synchronized void close() { - if (sslContextProvider != null) { + if (tlsContext instanceof UpstreamTlsContext) { tlsContextManager.releaseClientSslContextProvider(sslContextProvider); + } else { + tlsContextManager.releaseServerSslContextProvider(sslContextProvider); } shutdown = true; } diff --git a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java index a0ae9e00377..b1e5dceefac 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java @@ -520,7 +520,7 @@ private void subtest_endpointAddressesAttachedWithTlsConfig(boolean enableSecuri SslContextProviderSupplier supplier = eag.getAttributes().get(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); if (enableSecurity) { - assertThat(supplier.getUpstreamTlsContext()).isEqualTo(upstreamTlsContext); + assertThat(supplier.getTlsContext()).isEqualTo(upstreamTlsContext); } else { assertThat(supplier).isNull(); } @@ -554,7 +554,7 @@ private void subtest_endpointAddressesAttachedWithTlsConfig(boolean enableSecuri SslContextProviderSupplier supplier = eag.getAttributes().get(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); if (enableSecurity) { - assertThat(supplier.getUpstreamTlsContext()).isEqualTo(upstreamTlsContext); + assertThat(supplier.getTlsContext()).isEqualTo(upstreamTlsContext); } else { assertThat(supplier).isNull(); }