subchannelStateRef = getSubchannelStateInfoRef(subchannel);
+
+ // Don't proactively reconnect if the subchannel enters IDLE, even if previously was connected.
+ // If the subchannel was previously in TRANSIENT_FAILURE, it is considered to stay in
+ // TRANSIENT_FAILURE until it becomes READY.
+ if (subchannelStateRef.value.getState() == TRANSIENT_FAILURE) {
+ if (stateInfo.getState() == CONNECTING || stateInfo.getState() == IDLE) {
+ return;
+ }
+ }
+ subchannelStateRef.value = stateInfo;
+ updateBalancingState();
+ }
+
+ /**
+ * Aggregates the connectivity states of a group of subchannels for overall connectivity state.
+ *
+ * Aggregation rules (in order of dominance):
+ *
+ * - If there is at least one subchannel in READY state, overall state is READY
+ * - If there are 2 or more subchannels in TRANSIENT_FAILURE, overall state is
+ * TRANSIENT_FAILURE
+ * - If there is at least one subchannel in CONNECTING state, overall state is
+ * CONNECTING
+ * - If there is at least one subchannel in IDLE state, overall state is IDLE
+ * - Otherwise, overall state is TRANSIENT_FAILURE
+ *
+ */
+ private static ConnectivityState aggregateState(Iterable subchannels) {
+ int failureCount = 0;
+ boolean hasIdle = false;
+ boolean hasConnecting = false;
+ for (Subchannel subchannel : subchannels) {
+ ConnectivityState state = getSubchannelStateInfoRef(subchannel).value.getState();
+ if (state == READY) {
+ return state;
+ }
+ if (state == TRANSIENT_FAILURE) {
+ failureCount++;
+ } else if (state == CONNECTING) {
+ hasConnecting = true;
+ } else if (state == IDLE) {
+ hasIdle = true;
+ }
+ }
+ if (failureCount >= 2) {
+ return TRANSIENT_FAILURE;
+ }
+ if (hasConnecting) {
+ return CONNECTING;
+ }
+ return hasIdle ? IDLE : TRANSIENT_FAILURE;
+ }
+
+ private static void shutdownSubchannel(Subchannel subchannel) {
+ subchannel.shutdown();
+ getSubchannelStateInfoRef(subchannel).value = ConnectivityStateInfo.forNonError(SHUTDOWN);
+ }
+
+ /**
+ * Converts list of {@link EquivalentAddressGroup} to {@link EquivalentAddressGroup} set and
+ * remove all attributes. The values are the original EAGs.
+ */
+ private static Map stripAttrs(
+ List groupList) {
+ Map addrs =
+ new HashMap<>(groupList.size() * 2);
+ for (EquivalentAddressGroup group : groupList) {
+ addrs.put(stripAttrs(group), group);
+ }
+ return addrs;
+ }
+
+ private static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) {
+ return new EquivalentAddressGroup(eag.getAddresses());
+ }
+
+ private static Ref getSubchannelStateInfoRef(
+ Subchannel subchannel) {
+ return checkNotNull(subchannel.getAttributes().get(STATE_INFO), "STATE_INFO");
+ }
+
+ private static final class RingHashPicker extends SubchannelPicker {
+ private final SynchronizationContext syncContext;
+ private final List ring;
+ // Avoid synchronization between pickSubchannel and subchannel's connectivity state change,
+ // freeze picker's view of subchannel's connectivity state.
+ // TODO(chengyuanzhang): can be more performance-friendly with
+ // IdentityHashMap and RingEntry contains Subchannel.
+ private final Map pickableSubchannels; // read-only
+
+ private RingHashPicker(
+ SynchronizationContext syncContext, List ring,
+ Map subchannels) {
+ this.syncContext = syncContext;
+ this.ring = ring;
+ pickableSubchannels = new HashMap<>(subchannels.size());
+ for (Map.Entry entry : subchannels.entrySet()) {
+ Subchannel subchannel = entry.getValue();
+ ConnectivityStateInfo stateInfo = subchannel.getAttributes().get(STATE_INFO).value;
+ pickableSubchannels.put(entry.getKey(), new SubchannelView(subchannel, stateInfo));
+ }
+ }
+
+ @Override
+ public PickResult pickSubchannel(PickSubchannelArgs args) {
+ Long requestHash = args.getCallOptions().getOption(XdsNameResolver.RPC_HASH_KEY);
+ if (requestHash == null) {
+ return PickResult.withError(RPC_HASH_NOT_FOUND);
+ }
+
+ // Find the ring entry with hash next to (clockwise) the RPC's hash.
+ int low = 0;
+ int high = ring.size();
+ int mid;
+ while (true) {
+ mid = (low + high) / 2;
+ if (mid == ring.size()) {
+ mid = 0;
+ break;
+ }
+ long midVal = ring.get(mid).hash;
+ long midValL = mid == 0 ? 0 : ring.get(mid - 1).hash;
+ if (requestHash <= midVal && requestHash > midValL) {
+ break;
+ }
+ if (midVal < requestHash) {
+ low = mid + 1;
+ } else {
+ high = mid - 1;
+ }
+ if (low > high) {
+ mid = 0;
+ break;
+ }
+ }
+
+ // Try finding a READY subchannel. Starting from the ring entry next to the RPC's hash.
+ // If the one of the first two subchannels is not in TRANSIENT_FAILURE, return result
+ // based on that subchannel. Otherwise, fail the pick unless a READY subchannel is found.
+ // Meanwhile, trigger connection for the first subchannel that is in IDLE if no subchannel
+ // before it is in CONNECTING or READY.
+ boolean hasPending = false; // true if having subchannel(s) in CONNECTING or IDLE
+ boolean canBuffer = true; // true if RPCs can be buffered with a pending subchannel
+ Subchannel firstSubchannel = null;
+ Subchannel secondSubchannel = null;
+ for (int i = 0; i < ring.size(); i++) {
+ int index = (mid + i) % ring.size();
+ EquivalentAddressGroup addrKey = ring.get(index).addrKey;
+ SubchannelView subchannel = pickableSubchannels.get(addrKey);
+ if (subchannel.stateInfo.getState() == READY) {
+ return PickResult.withSubchannel(subchannel.subchannel);
+ }
+
+ // RPCs can be buffered if any of the first two subchannels is pending. Otherwise, RPCs
+ // are failed unless there is a READY connection.
+ if (firstSubchannel == null) {
+ firstSubchannel = subchannel.subchannel;
+ } else if (subchannel.subchannel != firstSubchannel) {
+ if (secondSubchannel == null) {
+ secondSubchannel = subchannel.subchannel;
+ } else if (subchannel.subchannel != secondSubchannel) {
+ canBuffer = false;
+ }
+ }
+ if (subchannel.stateInfo.getState() == TRANSIENT_FAILURE) {
+ continue;
+ }
+ if (!hasPending) { // first non-failing subchannel
+ if (subchannel.stateInfo.getState() == IDLE) {
+ final Subchannel finalSubchannel = subchannel.subchannel;
+ syncContext.execute(new Runnable() {
+ @Override
+ public void run() {
+ finalSubchannel.requestConnection();
+ }
+ });
+ }
+ if (canBuffer) { // done if this is the first or second two subchannel
+ return PickResult.withNoResult(); // queue the pick and re-process later
+ }
+ hasPending = true;
+ }
+ }
+ // Fail the pick with error status of the original subchannel hit by hash.
+ SubchannelView originalSubchannel = pickableSubchannels.get(ring.get(mid).addrKey);
+ return PickResult.withError(originalSubchannel.stateInfo.getStatus());
+ }
+ }
+
+ /**
+ * An unmodifiable view of a subchannel with state not subject to its real connectivity
+ * state changes.
+ */
+ private static final class SubchannelView {
+ private final Subchannel subchannel;
+ private final ConnectivityStateInfo stateInfo;
+
+ private SubchannelView(Subchannel subchannel, ConnectivityStateInfo stateInfo) {
+ this.subchannel = subchannel;
+ this.stateInfo = stateInfo;
+ }
+ }
+
+ private static final class RingEntry implements Comparable {
+ private final long hash;
+ private final EquivalentAddressGroup addrKey;
+
+ private RingEntry(long hash, EquivalentAddressGroup addrKey) {
+ this.hash = hash;
+ this.addrKey = addrKey;
+ }
+
+ @Override
+ public int compareTo(RingEntry entry) {
+ return Long.compare(hash, entry.hash);
+ }
+ }
+
+ /**
+ * A lighter weight Reference than AtomicReference.
+ */
+ private static final class Ref {
+ T value;
+
+ Ref(T value) {
+ this.value = value;
+ }
+ }
+
+ /**
+ * Configures the ring property. The larger the ring is (that is, the more hashes there are
+ * for each provided host) the better the request distribution will reflect the desired weights.
+ */
+ static final class RingHashConfig {
+ final long minRingSize;
+ final long maxRingSize;
+
+ RingHashConfig(long minRingSize, long maxRingSize) {
+ checkArgument(minRingSize > 0, "minRingSize <= 0");
+ checkArgument(maxRingSize > 0, "maxRingSize <= 0");
+ checkArgument(minRingSize <= maxRingSize, "minRingSize > maxRingSize");
+ this.minRingSize = minRingSize;
+ this.maxRingSize = maxRingSize;
+ }
+
+ @Override
+ public String toString() {
+ return MoreObjects.toStringHelper(this)
+ .add("minRingSize", minRingSize)
+ .add("maxRingSize", maxRingSize)
+ .toString();
+ }
+ }
+}
diff --git a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java
new file mode 100644
index 00000000000..fcbd527bf5c
--- /dev/null
+++ b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2021 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.xds;
+
+import io.grpc.Internal;
+import io.grpc.LoadBalancer;
+import io.grpc.LoadBalancer.Helper;
+import io.grpc.LoadBalancerProvider;
+import io.grpc.NameResolver.ConfigOrError;
+import io.grpc.Status;
+import io.grpc.internal.JsonUtil;
+import io.grpc.xds.RingHashLoadBalancer.RingHashConfig;
+import java.util.Map;
+
+/**
+ * The provider for the "ring_hash" balancing policy.
+ */
+@Internal
+public final class RingHashLoadBalancerProvider extends LoadBalancerProvider {
+
+ private static final boolean enableRingHash =
+ Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RING_HASH"));
+
+ @Override
+ public LoadBalancer newLoadBalancer(Helper helper) {
+ return new RingHashLoadBalancer(helper);
+ }
+
+ @Override
+ public boolean isAvailable() {
+ return enableRingHash;
+ }
+
+ @Override
+ public int getPriority() {
+ return 5;
+ }
+
+ @Override
+ public String getPolicyName() {
+ return "ring_hash";
+ }
+
+ @Override
+ public ConfigOrError parseLoadBalancingPolicyConfig(Map rawLoadBalancingPolicyConfig) {
+ Long minRingSize = JsonUtil.getNumberAsLong(rawLoadBalancingPolicyConfig, "minRingSize");
+ Long maxRingSize = JsonUtil.getNumberAsLong(rawLoadBalancingPolicyConfig, "maxRingSize");
+ if (minRingSize == null || maxRingSize == null) {
+ return ConfigOrError.fromError(Status.INVALID_ARGUMENT.withDescription(
+ "Missing 'mingRingSize'/'maxRingSize'"));
+ }
+ if (minRingSize <= 0 || maxRingSize <= 0 || minRingSize > maxRingSize) {
+ return ConfigOrError.fromError(Status.INVALID_ARGUMENT.withDescription(
+ "Invalid 'mingRingSize'/'maxRingSize'"));
+ }
+ return ConfigOrError.fromConfig(new RingHashConfig(minRingSize, maxRingSize));
+ }
+}
diff --git a/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider b/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider
index 6929006c103..7ba3dcf22f5 100644
--- a/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider
+++ b/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider
@@ -4,3 +4,4 @@ io.grpc.xds.WeightedTargetLoadBalancerProvider
io.grpc.xds.ClusterManagerLoadBalancerProvider
io.grpc.xds.ClusterResolverLoadBalancerProvider
io.grpc.xds.ClusterImplLoadBalancerProvider
+io.grpc.xds.RingHashLoadBalancerProvider
diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java
index 87cc2ceb16d..692bf9ec9e3 100644
--- a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java
+++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java
@@ -2127,7 +2127,6 @@ protected final Message buildListener(String name, Message routeConfiguration) {
return buildListener(name, routeConfiguration, Collections.emptyList());
}
- @SuppressWarnings("unchecked")
protected abstract Message buildListener(
String name, Message routeConfiguration, List extends Message> httpFilters);
diff --git a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java
new file mode 100644
index 00000000000..2d7eb4fd59f
--- /dev/null
+++ b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java
@@ -0,0 +1,121 @@
+/*
+ * Copyright 2021 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.xds;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import io.grpc.InternalServiceProviders;
+import io.grpc.LoadBalancer.Helper;
+import io.grpc.LoadBalancerProvider;
+import io.grpc.NameResolver.ConfigOrError;
+import io.grpc.Status.Code;
+import io.grpc.SynchronizationContext;
+import io.grpc.internal.JsonParser;
+import io.grpc.xds.RingHashLoadBalancer.RingHashConfig;
+import java.io.IOException;
+import java.lang.Thread.UncaughtExceptionHandler;
+import java.util.Map;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link RingHashLoadBalancerProvider}. */
+@RunWith(JUnit4.class)
+public class RingHashLoadBalancerProviderTest {
+ private static final String AUTHORITY = "foo.googleapis.com";
+
+ private final SynchronizationContext syncContext = new SynchronizationContext(
+ new UncaughtExceptionHandler() {
+ @Override
+ public void uncaughtException(Thread t, Throwable e) {
+ throw new AssertionError(e);
+ }
+ });
+ private final RingHashLoadBalancerProvider provider = new RingHashLoadBalancerProvider();
+
+ @Test
+ public void provided() {
+ for (LoadBalancerProvider current : InternalServiceProviders.getCandidatesViaServiceLoader(
+ LoadBalancerProvider.class, getClass().getClassLoader())) {
+ if (current instanceof RingHashLoadBalancerProvider) {
+ return;
+ }
+ }
+ fail("RingHashLoadBalancerProvider not registered");
+ }
+
+ @Test
+ public void providesLoadBalancer() {
+ Helper helper = mock(Helper.class);
+ when(helper.getSynchronizationContext()).thenReturn(syncContext);
+ when(helper.getAuthority()).thenReturn(AUTHORITY);
+ assertThat(provider.newLoadBalancer(helper))
+ .isInstanceOf(RingHashLoadBalancer.class);
+ }
+
+ @Test
+ public void parseLoadBalancingConfig_valid() throws IOException {
+ String lbConfig = "{\"minRingSize\" : 10, \"maxRingSize\" : 100}";
+ ConfigOrError configOrError =
+ provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig));
+ assertThat(configOrError.getConfig()).isNotNull();
+ RingHashConfig config = (RingHashConfig) configOrError.getConfig();
+ assertThat(config.minRingSize).isEqualTo(10L);
+ assertThat(config.maxRingSize).isEqualTo(100L);
+ }
+
+ @Test
+ public void parseLoadBalancingConfig_missingRingSize() throws IOException {
+ String lbConfig = "{\"minRingSize\" : 10}";
+ ConfigOrError configOrError =
+ provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig));
+ assertThat(configOrError.getError()).isNotNull();
+ assertThat(configOrError.getError().getCode()).isEqualTo(Code.INVALID_ARGUMENT);
+ assertThat(configOrError.getError().getDescription())
+ .isEqualTo("Missing 'mingRingSize'/'maxRingSize'");
+ }
+
+ @Test
+ public void parseLoadBalancingConfig_zeroMinRingSize() throws IOException {
+ String lbConfig = "{\"minRingSize\" : 0, \"maxRingSize\" : 100}";
+ ConfigOrError configOrError =
+ provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig));
+ assertThat(configOrError.getError()).isNotNull();
+ assertThat(configOrError.getError().getCode()).isEqualTo(Code.INVALID_ARGUMENT);
+ assertThat(configOrError.getError().getDescription())
+ .isEqualTo("Invalid 'mingRingSize'/'maxRingSize'");
+ }
+
+ @Test
+ public void parseLoadBalancingConfig_minRingSizeGreaterThanMaxRingSize() throws IOException {
+ String lbConfig = "{\"minRingSize\" : 100, \"maxRingSize\" : 10}";
+ ConfigOrError configOrError =
+ provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig));
+ assertThat(configOrError.getError()).isNotNull();
+ assertThat(configOrError.getError().getCode()).isEqualTo(Code.INVALID_ARGUMENT);
+ assertThat(configOrError.getError().getDescription())
+ .isEqualTo("Invalid 'mingRingSize'/'maxRingSize'");
+ }
+
+ @SuppressWarnings("unchecked")
+ private static Map parseJsonObject(String json) throws IOException {
+ return (Map) JsonParser.parse(json);
+ }
+}
diff --git a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java
new file mode 100644
index 00000000000..6b70e5974df
--- /dev/null
+++ b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java
@@ -0,0 +1,728 @@
+/*
+ * Copyright 2021 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.xds;
+
+import static com.google.common.truth.Truth.assertThat;
+import static io.grpc.ConnectivityState.CONNECTING;
+import static io.grpc.ConnectivityState.IDLE;
+import static io.grpc.ConnectivityState.READY;
+import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.atLeastOnce;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.reset;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.when;
+
+import com.google.common.collect.Iterables;
+import io.grpc.Attributes;
+import io.grpc.CallOptions;
+import io.grpc.ConnectivityStateInfo;
+import io.grpc.EquivalentAddressGroup;
+import io.grpc.LoadBalancer.CreateSubchannelArgs;
+import io.grpc.LoadBalancer.Helper;
+import io.grpc.LoadBalancer.PickResult;
+import io.grpc.LoadBalancer.PickSubchannelArgs;
+import io.grpc.LoadBalancer.ResolvedAddresses;
+import io.grpc.LoadBalancer.Subchannel;
+import io.grpc.LoadBalancer.SubchannelPicker;
+import io.grpc.LoadBalancer.SubchannelStateListener;
+import io.grpc.Metadata;
+import io.grpc.Status;
+import io.grpc.Status.Code;
+import io.grpc.SynchronizationContext;
+import io.grpc.internal.PickSubchannelArgsImpl;
+import io.grpc.testing.TestMethodDescriptors;
+import io.grpc.xds.RingHashLoadBalancer.RingHashConfig;
+import java.lang.Thread.UncaughtExceptionHandler;
+import java.net.SocketAddress;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.InOrder;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
+import org.mockito.stubbing.Answer;
+
+/** Unit test for {@link io.grpc.LoadBalancer}. */
+@RunWith(JUnit4.class)
+public class RingHashLoadBalancerTest {
+ private static final String AUTHORITY = "foo.googleapis.com";
+ private static final Attributes.Key CUSTOM_KEY = Attributes.Key.create("custom-key");
+
+ @Rule
+ public final MockitoRule mocks = MockitoJUnit.rule();
+ private final SynchronizationContext syncContext = new SynchronizationContext(
+ new UncaughtExceptionHandler() {
+ @Override
+ public void uncaughtException(Thread t, Throwable e) {
+ throw new AssertionError(e);
+ }
+ });
+ private final Map, Subchannel> subchannels = new HashMap<>();
+ private final Map subchannelStateListeners =
+ new HashMap<>();
+ private final XxHash64 hashFunc = XxHash64.INSTANCE;
+ @Mock
+ private Helper helper;
+ @Captor
+ private ArgumentCaptor pickerCaptor;
+ private RingHashLoadBalancer loadBalancer;
+
+ @Before
+ public void setUp() {
+ when(helper.getAuthority()).thenReturn(AUTHORITY);
+ when(helper.getSynchronizationContext()).thenReturn(syncContext);
+ when(helper.createSubchannel(any(CreateSubchannelArgs.class))).thenAnswer(
+ new Answer() {
+ @Override
+ public Subchannel answer(InvocationOnMock invocation) throws Throwable {
+ CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0];
+ final Subchannel subchannel = mock(Subchannel.class);
+ when(subchannel.getAllAddresses()).thenReturn(args.getAddresses());
+ when(subchannel.getAttributes()).thenReturn(args.getAttributes());
+ subchannels.put(args.getAddresses(), subchannel);
+ doAnswer(new Answer() {
+ @Override
+ public Void answer(InvocationOnMock invocation) throws Throwable {
+ subchannelStateListeners.put(
+ subchannel, (SubchannelStateListener) invocation.getArguments()[0]);
+ return null;
+ }
+ }).when(subchannel).start(any(SubchannelStateListener.class));
+ return subchannel;
+ }
+ });
+ loadBalancer = new RingHashLoadBalancer(helper);
+ // Skip uninterested interactions.
+ verify(helper).getAuthority();
+ verify(helper).getSynchronizationContext();
+ }
+
+ @After
+ public void tearDown() {
+ loadBalancer.shutdown();
+ for (Subchannel subchannel : subchannels.values()) {
+ verify(subchannel).shutdown();
+ }
+ }
+
+ @Test
+ public void subchannelLazyConnectUntilPicked() {
+ RingHashConfig config = new RingHashConfig(10, 100);
+ List servers = createWeightedServerAddrs(1); // one server
+ loadBalancer.handleResolvedAddresses(
+ ResolvedAddresses.newBuilder()
+ .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
+ verify(helper).createSubchannel(any(CreateSubchannelArgs.class));
+ Subchannel subchannel = Iterables.getOnlyElement(subchannels.values());
+ verify(subchannel, never()).requestConnection();
+ verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
+
+ // Picking subchannel triggers connection.
+ PickSubchannelArgs args = new PickSubchannelArgsImpl(
+ TestMethodDescriptors.voidMethod(), new Metadata(),
+ CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid()));
+ PickResult result = pickerCaptor.getValue().pickSubchannel(args);
+ assertThat(result.getStatus().isOk()).isTrue();
+ assertThat(result.getSubchannel()).isNull();
+ verify(subchannel).requestConnection();
+ deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING));
+ verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class));
+
+ // Subchannel becomes ready, triggers pick again.
+ deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
+ verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
+ result = pickerCaptor.getValue().pickSubchannel(args);
+ assertThat(result.getSubchannel()).isSameInstanceAs(subchannel);
+ verifyNoMoreInteractions(helper);
+ }
+
+ @Test
+ public void subchannelNotAutoReconnectAfterReenteringIdle() {
+ RingHashConfig config = new RingHashConfig(10, 100);
+ List servers = createWeightedServerAddrs(1); // one server
+ loadBalancer.handleResolvedAddresses(
+ ResolvedAddresses.newBuilder()
+ .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
+ Subchannel subchannel = Iterables.getOnlyElement(subchannels.values());
+ InOrder inOrder = Mockito.inOrder(helper, subchannel);
+ inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
+ inOrder.verify(subchannel, never()).requestConnection();
+
+ // Picking subchannel triggers connection.
+ PickSubchannelArgs args = new PickSubchannelArgsImpl(
+ TestMethodDescriptors.voidMethod(), new Metadata(),
+ CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid()));
+ pickerCaptor.getValue().pickSubchannel(args);
+ inOrder.verify(subchannel).requestConnection();
+ deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
+ inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class));
+ deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE));
+ inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
+ inOrder.verify(subchannel, never()).requestConnection();
+
+ // Picking again triggers reconnection.
+ pickerCaptor.getValue().pickSubchannel(args);
+ inOrder.verify(subchannel).requestConnection();
+ }
+
+ @Test
+ public void aggregateSubchannelStates_connectingReadyIdleFailure() {
+ RingHashConfig config = new RingHashConfig(10, 100);
+ List servers = createWeightedServerAddrs(1, 1);
+ InOrder inOrder = Mockito.inOrder(helper);
+ loadBalancer.handleResolvedAddresses(
+ ResolvedAddresses.newBuilder()
+ .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
+ inOrder.verify(helper, times(2)).createSubchannel(any(CreateSubchannelArgs.class));
+ inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+
+ // one in CONNECTING, one in IDLE
+ deliverSubchannelState(
+ subchannels.get(Collections.singletonList(servers.get(0))),
+ ConnectivityStateInfo.forNonError(CONNECTING));
+ inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class));
+
+ // two in CONNECTING
+ deliverSubchannelState(
+ subchannels.get(Collections.singletonList(servers.get(1))),
+ ConnectivityStateInfo.forNonError(CONNECTING));
+ inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class));
+
+ // one in CONNECTING, one in READY
+ deliverSubchannelState(
+ subchannels.get(Collections.singletonList(servers.get(1))),
+ ConnectivityStateInfo.forNonError(READY));
+ inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class));
+
+ // one in TRANSIENT_FAILURE, one in READY
+ deliverSubchannelState(
+ subchannels.get(Collections.singletonList(servers.get(0))),
+ ConnectivityStateInfo.forTransientFailure(
+ Status.UNKNOWN.withDescription("unknown failure")));
+ inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class));
+
+ // one in TRANSIENT_FAILURE, one in IDLE
+ deliverSubchannelState(
+ subchannels.get(Collections.singletonList(servers.get(1))),
+ ConnectivityStateInfo.forNonError(IDLE));
+ inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+
+ verifyNoMoreInteractions(helper);
+ }
+
+ @Test
+ public void aggregateSubchannelStates_twoOrMoreSubchannelsInTransientFailure() {
+ RingHashConfig config = new RingHashConfig(10, 100);
+ List servers = createWeightedServerAddrs(1, 1, 1, 1);
+ InOrder inOrder = Mockito.inOrder(helper);
+ loadBalancer.handleResolvedAddresses(
+ ResolvedAddresses.newBuilder()
+ .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
+ inOrder.verify(helper, times(4)).createSubchannel(any(CreateSubchannelArgs.class));
+ inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+
+ // one in TRANSIENT_FAILURE, three in IDLE
+ deliverSubchannelState(
+ subchannels.get(Collections.singletonList(servers.get(0))),
+ ConnectivityStateInfo.forTransientFailure(
+ Status.UNAVAILABLE.withDescription("not found")));
+ inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+
+ // two in TRANSIENT_FAILURE, two in IDLE
+ deliverSubchannelState(
+ subchannels.get(Collections.singletonList(servers.get(1))),
+ ConnectivityStateInfo.forTransientFailure(
+ Status.UNAVAILABLE.withDescription("also not found")));
+ inOrder.verify(helper)
+ .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class));
+
+ // two in TRANSIENT_FAILURE, one in CONNECTING, one in IDLE
+ // The overall state is dominated by the two in TRANSIENT_FAILURE.
+ deliverSubchannelState(
+ subchannels.get(Collections.singletonList(servers.get(2))),
+ ConnectivityStateInfo.forNonError(CONNECTING));
+ inOrder.verify(helper)
+ .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class));
+
+ // three in TRANSIENT_FAILURE, one in CONNECTING
+ deliverSubchannelState(
+ subchannels.get(Collections.singletonList(servers.get(3))),
+ ConnectivityStateInfo.forTransientFailure(
+ Status.UNAVAILABLE.withDescription("connection lost")));
+ inOrder.verify(helper)
+ .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class));
+
+ // three in TRANSIENT_FAILURE, one in READY
+ deliverSubchannelState(
+ subchannels.get(Collections.singletonList(servers.get(2))),
+ ConnectivityStateInfo.forNonError(READY));
+ inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class));
+
+ verifyNoMoreInteractions(helper);
+ }
+
+ @Test
+ public void subchannelStayInTransientFailureUntilBecomeReady() {
+ RingHashConfig config = new RingHashConfig(10, 100);
+ List servers = createWeightedServerAddrs(1, 1, 1);
+ loadBalancer.handleResolvedAddresses(
+ ResolvedAddresses.newBuilder()
+ .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
+ verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
+ verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+ reset(helper);
+
+ // Simulate picks have taken place and subchannels have requested connection.
+ for (Subchannel subchannel : subchannels.values()) {
+ deliverSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure(
+ Status.UNAUTHENTICATED.withDescription("Permission denied")));
+ }
+
+ // Stays in IDLE when until there are two or more subchannels in TRANSIENT_FAILURE.
+ verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+ verify(helper, times(2))
+ .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class));
+
+ verifyNoMoreInteractions(helper);
+ // Simulate underlying subchannel auto reconnect after backoff.
+ for (Subchannel subchannel : subchannels.values()) {
+ deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING));
+ }
+ verifyNoMoreInteractions(helper);
+
+ // Simulate one subchannel enters READY.
+ deliverSubchannelState(
+ subchannels.values().iterator().next(), ConnectivityStateInfo.forNonError(READY));
+ verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class));
+ }
+
+ @Test
+ public void deterministicPickWithHostsPartiallyRemoved() {
+ RingHashConfig config = new RingHashConfig(10, 100);
+ List servers = createWeightedServerAddrs(1, 1, 1, 1, 1);
+ loadBalancer.handleResolvedAddresses(
+ ResolvedAddresses.newBuilder()
+ .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
+ InOrder inOrder = Mockito.inOrder(helper);
+ inOrder.verify(helper, times(5)).createSubchannel(any(CreateSubchannelArgs.class));
+ inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+
+ // Bring all subchannels to READY so that next pick always succeeds.
+ for (Subchannel subchannel : subchannels.values()) {
+ deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
+ inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
+ }
+
+ // Simulate rpc hash hits one ring entry exactly for server1.
+ long rpcHash = hashFunc.hashAsciiString("[FakeSocketAddress-server1]_0");
+ PickSubchannelArgs args = new PickSubchannelArgsImpl(
+ TestMethodDescriptors.voidMethod(), new Metadata(),
+ CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, rpcHash));
+ pickerCaptor.getValue().pickSubchannel(args);
+ PickResult result = pickerCaptor.getValue().pickSubchannel(args);
+ Subchannel subchannel = result.getSubchannel();
+ assertThat(subchannel.getAddresses()).isEqualTo(servers.get(1));
+
+ List updatedServers = new ArrayList<>();
+ for (EquivalentAddressGroup addr : servers.subList(0, 2)) { // only server0 and server1 left
+ Attributes attr = addr.getAttributes().toBuilder().set(CUSTOM_KEY, "custom value").build();
+ updatedServers.add(new EquivalentAddressGroup(addr.getAddresses(), attr));
+ }
+ loadBalancer.handleResolvedAddresses(
+ ResolvedAddresses.newBuilder()
+ .setAddresses(updatedServers).setLoadBalancingPolicyConfig(config).build());
+ verify(subchannels.get(Collections.singletonList(servers.get(0))))
+ .updateAddresses(Collections.singletonList(updatedServers.get(0)));
+ verify(subchannels.get(Collections.singletonList(servers.get(1))))
+ .updateAddresses(Collections.singletonList(updatedServers.get(1)));
+ inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
+ assertThat(pickerCaptor.getValue().pickSubchannel(args).getSubchannel())
+ .isSameInstanceAs(subchannel);
+ verifyNoMoreInteractions(helper);
+ }
+
+ @Test
+ public void deterministicPickWithNewHostsAdded() {
+ RingHashConfig config = new RingHashConfig(10, 100);
+ List servers = createWeightedServerAddrs(1, 1); // server0 and server1
+ loadBalancer.handleResolvedAddresses(
+ ResolvedAddresses.newBuilder()
+ .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
+ InOrder inOrder = Mockito.inOrder(helper);
+ inOrder.verify(helper, times(2)).createSubchannel(any(CreateSubchannelArgs.class));
+ inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
+
+ // Bring all subchannels to READY so that next pick always succeeds.
+ for (Subchannel subchannel : subchannels.values()) {
+ deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
+ inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
+ }
+
+ // Simulate rpc hash hits one ring entry exactly for server1.
+ long rpcHash = hashFunc.hashAsciiString("[FakeSocketAddress-server1]_0");
+ PickSubchannelArgs args = new PickSubchannelArgsImpl(
+ TestMethodDescriptors.voidMethod(), new Metadata(),
+ CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, rpcHash));
+ pickerCaptor.getValue().pickSubchannel(args);
+ PickResult result = pickerCaptor.getValue().pickSubchannel(args);
+ Subchannel subchannel = result.getSubchannel();
+ assertThat(subchannel.getAddresses()).isEqualTo(servers.get(1));
+
+ servers = createWeightedServerAddrs(1, 1, 1, 1, 1); // server2, server3, server4 added
+ loadBalancer.handleResolvedAddresses(
+ ResolvedAddresses.newBuilder()
+ .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
+ inOrder.verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
+ inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
+ assertThat(pickerCaptor.getValue().pickSubchannel(args).getSubchannel())
+ .isSameInstanceAs(subchannel);
+ verifyNoMoreInteractions(helper);
+ }
+
+ @Test
+ public void skipFailingHosts_pickNextNonFailingHostInFirstTwoHosts() {
+ // Map each server address to exactly one ring entry.
+ RingHashConfig config = new RingHashConfig(3, 3);
+ List servers = createWeightedServerAddrs(1, 1, 1);
+ loadBalancer.handleResolvedAddresses(
+ ResolvedAddresses.newBuilder()
+ .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
+ verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
+ verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); // initial IDLE
+ reset(helper);
+ // ring:
+ // "[FakeSocketAddress-server1]_0"
+ // "[FakeSocketAddress-server0]_0"
+ // "[FakeSocketAddress-server2]_0"
+
+ long rpcHash = hashFunc.hashAsciiString("[FakeSocketAddress-server0]_0");
+ PickSubchannelArgs args = new PickSubchannelArgsImpl(
+ TestMethodDescriptors.voidMethod(), new Metadata(),
+ CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, rpcHash));
+
+ // Bring down server0 to force trying server2.
+ deliverSubchannelState(
+ subchannels.get(Collections.singletonList(servers.get(0))),
+ ConnectivityStateInfo.forTransientFailure(
+ Status.UNAVAILABLE.withDescription("unreachable")));
+ verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
+
+ PickResult result = pickerCaptor.getValue().pickSubchannel(args);
+ assertThat(result.getStatus().isOk()).isTrue();
+ assertThat(result.getSubchannel()).isNull(); // buffer request
+ verify(subchannels.get(Collections.singletonList(servers.get(2))))
+ .requestConnection(); // kick off connection to server2
+
+ deliverSubchannelState(
+ subchannels.get(Collections.singletonList(servers.get(2))),
+ ConnectivityStateInfo.forNonError(CONNECTING));
+ verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
+
+ result = pickerCaptor.getValue().pickSubchannel(args);
+ assertThat(result.getStatus().isOk()).isTrue();
+ assertThat(result.getSubchannel()).isNull(); // buffer request
+
+ deliverSubchannelState(
+ subchannels.get(Collections.singletonList(servers.get(2))),
+ ConnectivityStateInfo.forNonError(READY));
+ verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
+
+ result = pickerCaptor.getValue().pickSubchannel(args);
+ assertThat(result.getStatus().isOk()).isTrue();
+ assertThat(result.getSubchannel().getAddresses()).isEqualTo(servers.get(2));
+ }
+
+ @Test
+ public void skipFailingHosts_firstTwoHostsFailed_pickNextFirstReady() {
+ // Map each server address to exactly one ring entry.
+ RingHashConfig config = new RingHashConfig(4, 4);
+ List servers = createWeightedServerAddrs(1, 1, 1, 1);
+ loadBalancer.handleResolvedAddresses(
+ ResolvedAddresses.newBuilder()
+ .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
+ verify(helper, times(4)).createSubchannel(any(CreateSubchannelArgs.class));
+ verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); // initial IDLE
+ reset(helper);
+ // ring:
+ // "[FakeSocketAddress-server3]_0"
+ // "[FakeSocketAddress-server1]_0"
+ // "[FakeSocketAddress-server0]_0"
+ // "[FakeSocketAddress-server2]_0"
+
+ long rpcHash = hashFunc.hashAsciiString("[FakeSocketAddress-server0]_0");
+ PickSubchannelArgs args = new PickSubchannelArgsImpl(
+ TestMethodDescriptors.voidMethod(), new Metadata(),
+ CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, rpcHash));
+
+ // Bring down server0 and server2 to force trying other servers.
+ deliverSubchannelState(
+ subchannels.get(Collections.singletonList(servers.get(0))),
+ ConnectivityStateInfo.forTransientFailure(
+ Status.UNAVAILABLE.withDescription("unreachable")));
+ deliverSubchannelState(
+ subchannels.get(Collections.singletonList(servers.get(2))),
+ ConnectivityStateInfo.forTransientFailure(
+ Status.PERMISSION_DENIED.withDescription("permission denied")));
+ verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
+
+ PickResult result = pickerCaptor.getValue().pickSubchannel(args);
+ assertThat(result.getStatus().isOk()).isFalse(); // fail the RPC
+ assertThat(result.getStatus().getCode())
+ .isEqualTo(Code.UNAVAILABLE); // with error status for the original server hit by hash
+ assertThat(result.getStatus().getDescription()).isEqualTo("unreachable");
+ verify(subchannels.get(Collections.singletonList(servers.get(3))))
+ .requestConnection(); // kickoff connection to server3 (next first non-failing)
+ verify(subchannels.get(Collections.singletonList(servers.get(1))), never())
+ .requestConnection(); // no excessive connection
+
+ // Now connecting to server3.
+ deliverSubchannelState(
+ subchannels.get(Collections.singletonList(servers.get(3))),
+ ConnectivityStateInfo.forNonError(CONNECTING));
+ verify(helper, times(2)).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
+
+ result = pickerCaptor.getValue().pickSubchannel(args);
+ assertThat(result.getStatus().isOk()).isFalse(); // fail the RPC
+ assertThat(result.getStatus().getCode())
+ .isEqualTo(Code.UNAVAILABLE); // with error status for the original server hit by hash
+ assertThat(result.getStatus().getDescription()).isEqualTo("unreachable");
+ verify(subchannels.get(Collections.singletonList(servers.get(1))), never())
+ .requestConnection(); // no excessive connection (server3 connection already in progress)
+
+ // Simulate server1 becomes READY.
+ deliverSubchannelState(
+ subchannels.get(Collections.singletonList(servers.get(1))),
+ ConnectivityStateInfo.forNonError(READY));
+ verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
+
+ result = pickerCaptor.getValue().pickSubchannel(args);
+ assertThat(result.getStatus().isOk()).isTrue(); // succeed
+ assertThat(result.getSubchannel().getAddresses()).isEqualTo(servers.get(1)); // with server1
+ }
+
+ @Test
+ public void allSubchannelsInTransientFailure() {
+ // Map each server address to exactly one ring entry.
+ RingHashConfig config = new RingHashConfig(3, 3);
+ List servers = createWeightedServerAddrs(1, 1, 1);
+ loadBalancer.handleResolvedAddresses(
+ ResolvedAddresses.newBuilder()
+ .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
+ verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
+ verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+
+ // Bring all subchannels to TRANSIENT_FAILURE.
+ for (Subchannel subchannel : subchannels.values()) {
+ deliverSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure(
+ Status.UNAVAILABLE.withDescription(
+ subchannel.getAddresses().getAddresses() + " unreachable")));
+ }
+ verify(helper, atLeastOnce())
+ .updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
+
+ // Picking subchannel triggers connection. RPC hash hits server0.
+ PickSubchannelArgs args = new PickSubchannelArgsImpl(
+ TestMethodDescriptors.voidMethod(), new Metadata(),
+ CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid()));
+ PickResult result = pickerCaptor.getValue().pickSubchannel(args);
+ assertThat(result.getStatus().isOk()).isFalse();
+ assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE);
+ assertThat(result.getStatus().getDescription())
+ .isEqualTo("[FakeSocketAddress-server0] unreachable");
+ }
+
+ @Test
+ public void hostSelectionProportionalToWeights() {
+ RingHashConfig config = new RingHashConfig(10000, 100000); // large ring
+ List servers = createWeightedServerAddrs(1, 10, 100); // 1:10:100
+ loadBalancer.handleResolvedAddresses(
+ ResolvedAddresses.newBuilder()
+ .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
+ verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
+ verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+
+ // Bring all subchannels to READY.
+ Map pickCounts = new HashMap<>();
+ for (Subchannel subchannel : subchannels.values()) {
+ deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
+ pickCounts.put(subchannel.getAddresses(), 0);
+ }
+ verify(helper, times(3)).updateBalancingState(eq(READY), pickerCaptor.capture());
+ SubchannelPicker picker = pickerCaptor.getValue();
+
+ for (int i = 0; i < 10000; i++) {
+ long hash = hashFunc.hashInt(i);
+ PickSubchannelArgs args = new PickSubchannelArgsImpl(
+ TestMethodDescriptors.voidMethod(), new Metadata(),
+ CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hash));
+ Subchannel pickedSubchannel = picker.pickSubchannel(args).getSubchannel();
+ EquivalentAddressGroup addr = pickedSubchannel.getAddresses();
+ pickCounts.put(addr, pickCounts.get(addr) + 1);
+ }
+
+ // Actual distribution: server0 = 104, server1 = 808, server2 = 9088
+ double ratio01 = (double) pickCounts.get(servers.get(0)) / pickCounts.get(servers.get(1));
+ double ratio12 = (double) pickCounts.get(servers.get(1)) / pickCounts.get(servers.get(2));
+ assertThat(ratio01).isWithin(0.03).of((double) 1 / 10);
+ assertThat(ratio12).isWithin(0.03).of((double) 10 / 100);
+ }
+
+ @Test
+ public void hostSelectionProportionalToRepeatedAddressCount() {
+ RingHashConfig config = new RingHashConfig(10000, 100000);
+ List servers = createRepeatedServerAddrs(1, 10, 100); // 1:10:100
+ loadBalancer.handleResolvedAddresses(
+ ResolvedAddresses.newBuilder()
+ .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
+ verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
+ verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+
+ // Bring all subchannels to READY.
+ Map pickCounts = new HashMap<>();
+ for (Subchannel subchannel : subchannels.values()) {
+ deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
+ pickCounts.put(subchannel.getAddresses(), 0);
+ }
+ verify(helper, times(3)).updateBalancingState(eq(READY), pickerCaptor.capture());
+ SubchannelPicker picker = pickerCaptor.getValue();
+
+ for (int i = 0; i < 10000; i++) {
+ long hash = hashFunc.hashInt(i);
+ PickSubchannelArgs args = new PickSubchannelArgsImpl(
+ TestMethodDescriptors.voidMethod(), new Metadata(),
+ CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hash));
+ Subchannel pickedSubchannel = picker.pickSubchannel(args).getSubchannel();
+ EquivalentAddressGroup addr = pickedSubchannel.getAddresses();
+ pickCounts.put(addr, pickCounts.get(addr) + 1);
+ }
+
+ // Actual distribution: server0 = 104, server1 = 808, server2 = 9088
+ double ratio01 = (double) pickCounts.get(servers.get(0)) / pickCounts.get(servers.get(1));
+ double ratio12 = (double) pickCounts.get(servers.get(1)) / pickCounts.get(servers.get(11));
+ assertThat(ratio01).isWithin(0.03).of((double) 1 / 10);
+ assertThat(ratio12).isWithin(0.03).of((double) 10 / 100);
+ }
+
+ @Test
+ public void nameResolutionErrorWithNoActiveSubchannels() {
+ Status error = Status.UNAVAILABLE.withDescription("not reachable");
+ loadBalancer.handleNameResolutionError(error);
+ verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
+ PickResult result = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class));
+ assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE);
+ assertThat(result.getStatus().getDescription()).isEqualTo("not reachable");
+ assertThat(result.getSubchannel()).isNull();
+ verifyNoMoreInteractions(helper);
+ }
+
+ @Test
+ public void nameResolutionErrorWithActiveSubchannels() {
+ RingHashConfig config = new RingHashConfig(10, 100);
+ List servers = createWeightedServerAddrs(1);
+ loadBalancer.handleResolvedAddresses(
+ ResolvedAddresses.newBuilder()
+ .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
+ verify(helper).createSubchannel(any(CreateSubchannelArgs.class));
+ verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
+
+ // Picking subchannel triggers subchannel creation and connection.
+ PickSubchannelArgs args = new PickSubchannelArgsImpl(
+ TestMethodDescriptors.voidMethod(), new Metadata(),
+ CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid()));
+ pickerCaptor.getValue().pickSubchannel(args);
+ deliverSubchannelState(
+ Iterables.getOnlyElement(subchannels.values()), ConnectivityStateInfo.forNonError(READY));
+ verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class));
+
+ loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("target not found"));
+ verifyNoMoreInteractions(helper);
+ }
+
+ private void deliverSubchannelState(Subchannel subchannel, ConnectivityStateInfo state) {
+ subchannelStateListeners.get(subchannel).onSubchannelState(state);
+ }
+
+ private static List createWeightedServerAddrs(long... weights) {
+ List addrs = new ArrayList<>();
+ for (int i = 0; i < weights.length; i++) {
+ SocketAddress addr = new FakeSocketAddress("server" + i);
+ Attributes attr = Attributes.newBuilder().set(
+ InternalXdsAttributes.ATTR_SERVER_WEIGHT, weights[i]).build();
+ EquivalentAddressGroup eag = new EquivalentAddressGroup(addr, attr);
+ addrs.add(eag);
+ }
+ return addrs;
+ }
+
+ private static List createRepeatedServerAddrs(long... weights) {
+ List addrs = new ArrayList<>();
+ for (int i = 0; i < weights.length; i++) {
+ SocketAddress addr = new FakeSocketAddress("server" + i);
+ for (int j = 0; j < weights[i]; j++) {
+ EquivalentAddressGroup eag = new EquivalentAddressGroup(addr);
+ addrs.add(eag);
+ }
+ }
+ return addrs;
+ }
+
+ private static class FakeSocketAddress extends SocketAddress {
+ private final String name;
+
+ FakeSocketAddress(String name) {
+ this.name = name;
+ }
+
+ @Override
+ public int hashCode() {
+ return name.hashCode();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof FakeSocketAddress)) {
+ return false;
+ }
+ return name.equals(((FakeSocketAddress) other).name);
+ }
+
+ @Override
+ public String toString() {
+ return "FakeSocketAddress-" + name;
+ }
+ }
+}