diff --git a/api/src/main/java/io/grpc/ManagedChannelProvider.java b/api/src/main/java/io/grpc/ManagedChannelProvider.java index f57340d9ba9..42941dfc809 100644 --- a/api/src/main/java/io/grpc/ManagedChannelProvider.java +++ b/api/src/main/java/io/grpc/ManagedChannelProvider.java @@ -17,6 +17,8 @@ package io.grpc; import com.google.common.base.Preconditions; +import java.net.SocketAddress; +import java.util.Collection; /** * Provider of managed channels for transport agnostic consumption. @@ -79,6 +81,11 @@ protected NewChannelBuilderResult newChannelBuilder(String target, ChannelCreden return NewChannelBuilderResult.error("ChannelCredentials are unsupported"); } + /** + * Returns the {@link SocketAddress} types this ManagedChannelProvider supports. + */ + protected abstract Collection> getSupportedSocketAddressTypes(); + public static final class NewChannelBuilderResult { private final ManagedChannelBuilder channelBuilder; private final String error; diff --git a/api/src/main/java/io/grpc/ManagedChannelRegistry.java b/api/src/main/java/io/grpc/ManagedChannelRegistry.java index 8eb1cce14ac..677856ed8d8 100644 --- a/api/src/main/java/io/grpc/ManagedChannelRegistry.java +++ b/api/src/main/java/io/grpc/ManagedChannelRegistry.java @@ -18,7 +18,12 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import java.net.SocketAddress; +import java.net.URI; +import java.net.URISyntaxException; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.LinkedHashSet; @@ -144,6 +149,28 @@ static List> getHardCodedClasses() { } ManagedChannelBuilder newChannelBuilder(String target, ChannelCredentials creds) { + return newChannelBuilder(NameResolverRegistry.getDefaultRegistry(), target, creds); + } + + @VisibleForTesting + ManagedChannelBuilder newChannelBuilder(NameResolverRegistry nameResolverRegistry, + String target, ChannelCredentials creds) { + NameResolverProvider nameResolverProvider = null; + try { + URI uri = new URI(target); + nameResolverProvider = nameResolverRegistry.providers().get(uri.getScheme()); + } catch (URISyntaxException ignore) { + // bad URI found, just ignore and continue + } + if (nameResolverProvider == null) { + nameResolverProvider = nameResolverRegistry.providers().get( + nameResolverRegistry.asFactory().getDefaultScheme()); + } + Collection> nameResolverSocketAddressTypes + = (nameResolverProvider != null) + ? nameResolverProvider.getProducedSocketAddressTypes() : + Collections.emptySet(); + List providers = providers(); if (providers.isEmpty()) { throw new ProviderNotFoundException("No functional channel service provider found. " @@ -152,6 +179,15 @@ ManagedChannelBuilder newChannelBuilder(String target, ChannelCredentials cre } StringBuilder error = new StringBuilder(); for (ManagedChannelProvider provider : providers()) { + Collection> channelProviderSocketAddressTypes + = provider.getSupportedSocketAddressTypes(); + if (!channelProviderSocketAddressTypes.containsAll(nameResolverSocketAddressTypes)) { + error.append("; "); + error.append(provider.getClass().getName()); + error.append(": does not support 1 or more of "); + error.append(Arrays.toString(nameResolverSocketAddressTypes.toArray())); + continue; + } ManagedChannelProvider.NewChannelBuilderResult result = provider.newChannelBuilder(target, creds); if (result.getChannelBuilder() != null) { diff --git a/api/src/main/java/io/grpc/NameResolverProvider.java b/api/src/main/java/io/grpc/NameResolverProvider.java index 2c337cd5052..e7cddfc36d0 100644 --- a/api/src/main/java/io/grpc/NameResolverProvider.java +++ b/api/src/main/java/io/grpc/NameResolverProvider.java @@ -17,6 +17,10 @@ package io.grpc; import io.grpc.NameResolver.Factory; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; /** * Provider of name resolvers for name agnostic consumption. @@ -62,4 +66,14 @@ public abstract class NameResolverProvider extends NameResolver.Factory { protected String getScheme() { return getDefaultScheme(); } + + /** + * Returns the {@link SocketAddress} types this provider's name-resolver is capable of producing. + * This enables selection of the appropriate {@link ManagedChannelProvider} for a channel. + * + * @return the {@link SocketAddress} types this provider's name-resolver is capable of producing. + */ + protected Collection> getProducedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } } diff --git a/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java b/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java index 6f25f620576..283c1792777 100644 --- a/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java +++ b/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java @@ -19,6 +19,12 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import com.google.common.collect.ImmutableSet; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.net.URI; +import java.util.Collection; +import java.util.Collections; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -156,6 +162,256 @@ public void newChannelBuilder_noProvider() { } } + @Test + public void newChannelBuilder_usesScheme() { + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + class SocketAddress1 extends SocketAddress { + } + + class SocketAddress2 extends SocketAddress { + } + + nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") { + @Override + protected Collection> getProducedSocketAddressTypes() { + return Collections.singleton(SocketAddress1.class); + } + }); + nameResolverRegistry.register(new BaseNameResolverProvider(true, 6, "sc2") { + @Override + protected Collection> getProducedSocketAddressTypes() { + fail("Should not be called"); + throw new AssertionError(); + } + }); + + ManagedChannelRegistry registry = new ManagedChannelRegistry(); + registry.register(new BaseProvider(true, 5) { + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(SocketAddress2.class); + } + + @Override + public NewChannelBuilderResult newChannelBuilder( + String passedTarget, ChannelCredentials passedCreds) { + fail("Should not be called"); + throw new AssertionError(); + } + }); + class MockChannelBuilder extends ForwardingChannelBuilder { + @Override public ManagedChannelBuilder delegate() { + throw new UnsupportedOperationException(); + } + } + + final ManagedChannelBuilder mcb = new MockChannelBuilder(); + registry.register(new BaseProvider(true, 4) { + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(SocketAddress1.class); + } + + @Override + public NewChannelBuilderResult newChannelBuilder( + String passedTarget, ChannelCredentials passedCreds) { + return NewChannelBuilderResult.channelBuilder(mcb); + } + }); + assertThat( + registry.newChannelBuilder(nameResolverRegistry, "sc1:" + target, creds)).isSameInstanceAs( + mcb); + } + + @Test + public void newChannelBuilder_unsupportedSocketAddressTypes() { + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + class SocketAddress1 extends SocketAddress { + } + + class SocketAddress2 extends SocketAddress { + } + + nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") { + @Override + protected Collection> getProducedSocketAddressTypes() { + return ImmutableSet.of(SocketAddress1.class, SocketAddress2.class); + } + }); + + ManagedChannelRegistry registry = new ManagedChannelRegistry(); + registry.register(new BaseProvider(true, 5) { + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(SocketAddress2.class); + } + + @Override + public NewChannelBuilderResult newChannelBuilder( + String passedTarget, ChannelCredentials passedCreds) { + fail("Should not be called"); + throw new AssertionError(); + } + }); + class MockChannelBuilder extends ForwardingChannelBuilder { + @Override public ManagedChannelBuilder delegate() { + throw new UnsupportedOperationException(); + } + } + + registry.register(new BaseProvider(true, 4) { + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(SocketAddress1.class); + } + + @Override + public NewChannelBuilderResult newChannelBuilder( + String passedTarget, ChannelCredentials passedCreds) { + fail("Should not be called"); + throw new AssertionError(); + } + }); + try { + registry.newChannelBuilder(nameResolverRegistry, "sc1:" + target, creds); + fail("expected exception"); + } catch (ManagedChannelRegistry.ProviderNotFoundException ex) { + assertThat(ex).hasMessageThat().contains("does not support 1 or more of"); + assertThat(ex).hasMessageThat().contains("SocketAddress1"); + assertThat(ex).hasMessageThat().contains("SocketAddress2"); + } + } + + @Test + public void newChannelBuilder_emptySet_asDefault() { + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + + ManagedChannelRegistry registry = new ManagedChannelRegistry(); + class MockChannelBuilder extends ForwardingChannelBuilder { + @Override public ManagedChannelBuilder delegate() { + throw new UnsupportedOperationException(); + } + } + + final ManagedChannelBuilder mcb = new MockChannelBuilder(); + registry.register(new BaseProvider(true, 4) { + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.emptySet(); + } + + @Override + public NewChannelBuilderResult newChannelBuilder( + String passedTarget, ChannelCredentials passedCreds) { + return NewChannelBuilderResult.channelBuilder(mcb); + } + }); + assertThat( + registry.newChannelBuilder(nameResolverRegistry, "sc1:" + target, creds)).isSameInstanceAs( + mcb); + } + + @Test + public void newChannelBuilder_noSchemeUsesDefaultScheme() { + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + class SocketAddress1 extends SocketAddress { + } + + nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") { + @Override + protected Collection> getProducedSocketAddressTypes() { + return Collections.singleton(SocketAddress1.class); + } + }); + + ManagedChannelRegistry registry = new ManagedChannelRegistry(); + class MockChannelBuilder extends ForwardingChannelBuilder { + @Override public ManagedChannelBuilder delegate() { + throw new UnsupportedOperationException(); + } + } + + final ManagedChannelBuilder mcb = new MockChannelBuilder(); + registry.register(new BaseProvider(true, 4) { + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(SocketAddress1.class); + } + + @Override + public NewChannelBuilderResult newChannelBuilder( + String passedTarget, ChannelCredentials passedCreds) { + return NewChannelBuilderResult.channelBuilder(mcb); + } + }); + assertThat(registry.newChannelBuilder(nameResolverRegistry, target, creds)).isSameInstanceAs( + mcb); + } + + @Test + public void newChannelBuilder_badUri() { + NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + class SocketAddress1 extends SocketAddress { + } + + ManagedChannelRegistry registry = new ManagedChannelRegistry(); + + class MockChannelBuilder extends ForwardingChannelBuilder { + @Override public ManagedChannelBuilder delegate() { + throw new UnsupportedOperationException(); + } + } + + final ManagedChannelBuilder mcb = new MockChannelBuilder(); + registry.register(new BaseProvider(true, 4) { + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(SocketAddress1.class); + } + + @Override + public NewChannelBuilderResult newChannelBuilder( + String passedTarget, ChannelCredentials passedCreds) { + return NewChannelBuilderResult.channelBuilder(mcb); + } + }); + assertThat( + registry.newChannelBuilder(nameResolverRegistry, ":testing123", creds)).isSameInstanceAs( + mcb); + } + + private static class BaseNameResolverProvider extends NameResolverProvider { + private final boolean isAvailable; + private final int priority; + private final String defaultScheme; + + public BaseNameResolverProvider(boolean isAvailable, int priority, String defaultScheme) { + this.isAvailable = isAvailable; + this.priority = priority; + this.defaultScheme = defaultScheme; + } + + @Override + public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + return null; + } + + @Override + public String getDefaultScheme() { + return defaultScheme; + } + + @Override + protected boolean isAvailable() { + return isAvailable; + } + + @Override + protected int priority() { + return priority; + } + } + private static class BaseProvider extends ManagedChannelProvider { private final boolean isAvailable; private final int priority; @@ -184,5 +440,10 @@ protected ManagedChannelBuilder builderForAddress(String name, int port) { protected ManagedChannelBuilder builderForTarget(String target) { throw new UnsupportedOperationException(); } + + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } } } diff --git a/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java b/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java index 1c9290d2fc0..8078aa0d4c9 100644 --- a/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java +++ b/core/src/main/java/io/grpc/internal/DnsNameResolverProvider.java @@ -21,7 +21,11 @@ import io.grpc.InternalServiceProviders; import io.grpc.NameResolver; import io.grpc.NameResolverProvider; +import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.net.URI; +import java.util.Collection; +import java.util.Collections; /** * A provider for {@link DnsNameResolver}. @@ -75,4 +79,9 @@ protected boolean isAvailable() { public int priority() { return 5; } + + @Override + protected Collection> getProducedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } } diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/LoggingChannelProvider.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/LoggingChannelProvider.java index ffbc24be69b..81c3501e1d4 100644 --- a/gcp-observability/src/main/java/io/grpc/gcp/observability/LoggingChannelProvider.java +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/LoggingChannelProvider.java @@ -24,6 +24,10 @@ import io.grpc.ManagedChannelProvider; import io.grpc.ManagedChannelRegistry; import io.grpc.gcp.observability.interceptors.InternalLoggingChannelInterceptor; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; /** A channel provider that injects logging interceptor. */ final class LoggingChannelProvider extends ManagedChannelProvider { @@ -90,4 +94,9 @@ protected NewChannelBuilderResult newChannelBuilder(String target, ChannelCreden checkNotNull(result.getError(), "Expected error to be set!"); return result; } + + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } } diff --git a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java index ac39ab1e625..b431cba4d27 100644 --- a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java +++ b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolverProvider.java @@ -22,7 +22,11 @@ import io.grpc.NameResolverProvider; import io.grpc.internal.GrpcUtil; import io.grpc.xds.InternalSharedXdsClientPoolProvider; +import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.net.URI; +import java.util.Collection; +import java.util.Collections; import java.util.Map; /** @@ -58,6 +62,11 @@ protected int priority() { return 4; } + @Override + protected Collection> getProducedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } + private static final class SharedXdsClientPoolProviderBootstrapSetter implements GoogleCloudToProdNameResolver.BootstrapSetter { @Override diff --git a/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java b/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java index bc25f28f94c..da5b7c3353e 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java +++ b/grpclb/src/main/java/io/grpc/grpclb/SecretGrpclbNameResolverProvider.java @@ -22,7 +22,11 @@ import io.grpc.NameResolver.Args; import io.grpc.NameResolverProvider; import io.grpc.internal.GrpcUtil; +import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.net.URI; +import java.util.Collection; +import java.util.Collections; /** * A provider for {@code io.grpc.grpclb.GrpclbNameResolver}. @@ -85,5 +89,10 @@ public int priority() { // Must be higher than DnsNameResolverProvider#priority. return 6; } + + @Override + protected Collection> getProducedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } } } diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java b/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java index bf3df4fa6aa..7cc77c150a0 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java @@ -19,6 +19,10 @@ import io.grpc.ChannelCredentials; import io.grpc.Internal; import io.grpc.ManagedChannelProvider; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; /** Provider for {@link NettyChannelBuilder} instances. */ @Internal @@ -52,4 +56,9 @@ public NewChannelBuilderResult newChannelBuilder(String target, ChannelCredentia return NewChannelBuilderResult.channelBuilder( new NettyChannelBuilder(target, creds, result.callCredentials, result.negotiator)); } + + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java index 19f99d05029..17a2512a66a 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java @@ -20,6 +20,10 @@ import io.grpc.Internal; import io.grpc.InternalServiceProviders; import io.grpc.ManagedChannelProvider; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; /** * Provider for {@link OkHttpChannelBuilder} instances. @@ -57,4 +61,9 @@ public NewChannelBuilderResult newChannelBuilder(String target, ChannelCredentia return NewChannelBuilderResult.channelBuilder(new OkHttpChannelBuilder( target, creds, result.callCredentials, result.factory)); } + + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } } diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java b/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java index a02e27c37c7..0eb51c91281 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java @@ -23,7 +23,11 @@ import io.grpc.NameResolver.Args; import io.grpc.NameResolverProvider; import io.grpc.internal.ObjectPool; +import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.net.URI; +import java.util.Collection; +import java.util.Collections; import java.util.Map; import java.util.concurrent.atomic.AtomicLong; import javax.annotation.Nullable; @@ -99,6 +103,11 @@ protected int priority() { return 4; } + @Override + protected Collection> getProducedSocketAddressTypes() { + return Collections.singleton(InetSocketAddress.class); + } + interface XdsClientPoolFactory { void setBootstrapOverride(Map bootstrap);