diff --git a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java index 7735199134f..7a346e01871 100644 --- a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java +++ b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java @@ -32,6 +32,7 @@ import io.grpc.xds.CdsLoadBalancerProvider.CdsConfig; import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig; import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig.DiscoveryMechanism; +import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestConfig; import io.grpc.xds.RingHashLoadBalancer.RingHashConfig; import io.grpc.xds.XdsClient.CdsResourceWatcher; import io.grpc.xds.XdsClient.CdsUpdate; @@ -190,6 +191,10 @@ private void handleClusterDiscovered() { lbProvider = lbRegistry.getProvider("ring_hash_experimental"); lbConfig = new RingHashConfig(root.result.minRingSize(), root.result.maxRingSize()); } + if (root.result.lbPolicy() == LbPolicy.LEAST_REQUEST) { + lbProvider = lbRegistry.getProvider("least_request_experimental"); + lbConfig = new LeastRequestConfig(root.result.choiceCount()); + } if (lbProvider == null) { lbProvider = lbRegistry.getProvider("round_robin"); lbConfig = null; diff --git a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java index 72f5db82ed0..1e090e164f4 100644 --- a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java +++ b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java @@ -44,6 +44,7 @@ import io.envoyproxy.envoy.config.cluster.v3.Cluster.CustomClusterType; import io.envoyproxy.envoy.config.cluster.v3.Cluster.DiscoveryType; import io.envoyproxy.envoy.config.cluster.v3.Cluster.LbPolicy; +import io.envoyproxy.envoy.config.cluster.v3.Cluster.LeastRequestLbConfig; import io.envoyproxy.envoy.config.cluster.v3.Cluster.RingHashLbConfig; import io.envoyproxy.envoy.config.core.v3.HttpProtocolOptions; import io.envoyproxy.envoy.config.core.v3.RoutingPriority; @@ -140,6 +141,8 @@ final class ClientXdsClient extends XdsClient implements XdsResponseHandler, Res @VisibleForTesting static final long DEFAULT_RING_HASH_LB_POLICY_MAX_RING_SIZE = 8 * 1024 * 1024L; @VisibleForTesting + static final int DEFAULT_LEAST_REQUEST_CHOICE_COUNT = 2; + @VisibleForTesting static final long MAX_RING_HASH_LB_POLICY_RING_SIZE = 8 * 1024 * 1024L; @VisibleForTesting static final String AGGREGATE_CLUSTER_TYPE_NAME = "envoy.clusters.aggregate"; @@ -161,6 +164,11 @@ final class ClientXdsClient extends XdsClient implements XdsResponseHandler, Res static boolean enableRouteLookup = !Strings.isNullOrEmpty(System.getenv("GRPC_EXPERIMENTAL_XDS_RLS_LB")) && Boolean.parseBoolean(System.getenv("GRPC_EXPERIMENTAL_XDS_RLS_LB")); + @VisibleForTesting + static boolean enableLeastRequest = + !Strings.isNullOrEmpty(System.getenv("GRPC_EXPERIMENTAL_ENABLE_LEAST_REQUEST")) + ? Boolean.parseBoolean(System.getenv("GRPC_EXPERIMENTAL_ENABLE_LEAST_REQUEST")) + : Boolean.parseBoolean(System.getProperty("io.grpc.xds.experimentalEnableLeastRequest")); private static final String TYPE_URL_HTTP_CONNECTION_MANAGER_V2 = "type.googleapis.com/envoy.config.filter.network.http_connection_manager.v2" + ".HttpConnectionManager"; @@ -1616,6 +1624,17 @@ static CdsUpdate processCluster(Cluster cluster, Set retainedEdsResource updateBuilder.ringHashLbPolicy(minRingSize, maxRingSize); } else if (cluster.getLbPolicy() == LbPolicy.ROUND_ROBIN) { updateBuilder.roundRobinLbPolicy(); + } else if (enableLeastRequest && cluster.getLbPolicy() == LbPolicy.LEAST_REQUEST) { + LeastRequestLbConfig lbConfig = cluster.getLeastRequestLbConfig(); + int choiceCount = + lbConfig.hasChoiceCount() + ? lbConfig.getChoiceCount().getValue() + : DEFAULT_LEAST_REQUEST_CHOICE_COUNT; + if (choiceCount < DEFAULT_LEAST_REQUEST_CHOICE_COUNT) { + throw new ResourceInvalidException( + "Cluster " + cluster.getName() + ": invalid least_request_lb_config: " + lbConfig); + } + updateBuilder.leastRequestLbPolicy(choiceCount); } else { throw new ResourceInvalidException( "Cluster " + cluster.getName() + ": unsupported lb policy: " + cluster.getLbPolicy()); diff --git a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java index e4800cf668c..309daf55a18 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java @@ -672,7 +672,7 @@ private static PriorityChildConfig generateDnsBasedPriorityChildConfig( * Generates configs to be used in the priority LB policy for priorities in an EDS cluster. * *

priority LB -> cluster_impl LB (one per priority) -> (weighted_target LB - * -> round_robin (one per locality)) / ring_hash_experimental + * -> round_robin / least_request_experimental (one per locality)) / ring_hash_experimental */ private static Map generateEdsBasedPriorityChildConfigs( String cluster, @Nullable String edsServiceName, @Nullable ServerInfo lrsServerInfo, @@ -684,13 +684,14 @@ private static Map generateEdsBasedPriorityChildCon for (String priority : prioritizedLocalityWeights.keySet()) { PolicySelection leafPolicy = endpointLbPolicy; // Depending on the endpoint-level load balancing policy, different LB hierarchy may be - // created. If the endpoint-level LB policy is round_robin, it creates a two-level LB - // hierarchy: a locality-level LB policy that balances load according to locality weights - // followed by an endpoint-level LB policy that simply rounds robin the endpoints within - // the locality. If the endpoint-level LB policy is ring_hash_experimental, it creates - // a unified LB policy that balances load by weighing the product of each endpoint's weight - // and the weight of the locality it belongs to. - if (endpointLbPolicy.getProvider().getPolicyName().equals("round_robin")) { + // created. If the endpoint-level LB policy is round_robin or least_request_experimental, + // it creates a two-level LB hierarchy: a locality-level LB policy that balances load + // according to locality weights followed by an endpoint-level LB policy that balances load + // between endpoints within the locality. If the endpoint-level LB policy is + // ring_hash_experimental, it creates a unified LB policy that balances load by weighing the + // product of each endpoint's weight and the weight of the locality it belongs to. + if (endpointLbPolicy.getProvider().getPolicyName().equals("round_robin") + || endpointLbPolicy.getProvider().getPolicyName().equals("least_request_experimental")) { Map localityWeights = prioritizedLocalityWeights.get(priority); Map targets = new HashMap<>(); for (Locality locality : localityWeights.keySet()) { diff --git a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancerProvider.java index 4aaf0dcadde..6f6f887e925 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancerProvider.java @@ -68,7 +68,8 @@ public LoadBalancer newLoadBalancer(Helper helper) { static final class ClusterResolverConfig { // Ordered list of clusters to be resolved. final List discoveryMechanisms; - // Endpoint-level load balancing policy with config (round_robin or ring_hash_experimental). + // Endpoint-level load balancing policy with config + // (round_robin, least_request_experimental or ring_hash_experimental). final PolicySelection lbPolicy; ClusterResolverConfig(List discoveryMechanisms, PolicySelection lbPolicy) { diff --git a/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java b/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java new file mode 100644 index 00000000000..584ac2dd16f --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java @@ -0,0 +1,430 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.ConnectivityState.CONNECTING; +import static io.grpc.ConnectivityState.IDLE; +import static io.grpc.ConnectivityState.READY; +import static io.grpc.ConnectivityState.SHUTDOWN; +import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.xds.LeastRequestLoadBalancerProvider.DEFAULT_CHOICE_COUNT; +import static io.grpc.xds.LeastRequestLoadBalancerProvider.MAX_CHOICE_COUNT; +import static io.grpc.xds.LeastRequestLoadBalancerProvider.MIN_CHOICE_COUNT; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.MoreObjects; +import com.google.common.base.Objects; +import com.google.common.base.Preconditions; +import io.grpc.Attributes; +import io.grpc.ClientStreamTracer; +import io.grpc.ClientStreamTracer.StreamInfo; +import io.grpc.ConnectivityState; +import io.grpc.ConnectivityStateInfo; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.xds.ThreadSafeRandom.ThreadSafeRandomImpl; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.Nonnull; + +/** + * A {@link LoadBalancer} that provides least request load balancing based on + * outstanding request counters. + * It works by sampling a number of subchannels and picking the one with the + * fewest amount of outstanding requests. + * The default sampling amount of two is also known as + * the "power of two choices" (P2C). + */ +final class LeastRequestLoadBalancer extends LoadBalancer { + @VisibleForTesting + static final Attributes.Key> STATE_INFO = + Attributes.Key.create("state-info"); + @VisibleForTesting + static final Attributes.Key IN_FLIGHTS = + Attributes.Key.create("in-flights"); + + private final Helper helper; + private final ThreadSafeRandom random; + private final Map subchannels = + new HashMap<>(); + + private ConnectivityState currentState; + private LeastRequestPicker currentPicker = new EmptyPicker(EMPTY_OK); + private int choiceCount = DEFAULT_CHOICE_COUNT; + + LeastRequestLoadBalancer(Helper helper) { + this(helper, ThreadSafeRandomImpl.instance); + } + + @VisibleForTesting + LeastRequestLoadBalancer(Helper helper, ThreadSafeRandom random) { + this.helper = checkNotNull(helper, "helper"); + this.random = checkNotNull(random, "random"); + } + + @Override + public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + LeastRequestConfig config = + (LeastRequestConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); + // Config may be null if least_request is used outside xDS + if (config != null) { + choiceCount = config.choiceCount; + } + + List servers = resolvedAddresses.getAddresses(); + Set currentAddrs = subchannels.keySet(); + Map latestAddrs = stripAttrs(servers); + Set removedAddrs = setsDifference(currentAddrs, latestAddrs.keySet()); + + for (Map.Entry latestEntry : + latestAddrs.entrySet()) { + EquivalentAddressGroup strippedAddressGroup = latestEntry.getKey(); + EquivalentAddressGroup originalAddressGroup = latestEntry.getValue(); + Subchannel existingSubchannel = subchannels.get(strippedAddressGroup); + if (existingSubchannel != null) { + // EAG's Attributes may have changed. + existingSubchannel.updateAddresses(Collections.singletonList(originalAddressGroup)); + continue; + } + // Create new subchannels for new addresses. + Attributes.Builder subchannelAttrs = Attributes.newBuilder() + .set(STATE_INFO, new Ref<>(ConnectivityStateInfo.forNonError(IDLE))) + // Used to track the in flight requests on this particular subchannel + .set(IN_FLIGHTS, new AtomicInteger(0)); + + final Subchannel subchannel = checkNotNull( + helper.createSubchannel(CreateSubchannelArgs.newBuilder() + .setAddresses(originalAddressGroup) + .setAttributes(subchannelAttrs.build()) + .build()), + "subchannel"); + subchannel.start(new SubchannelStateListener() { + @Override + public void onSubchannelState(ConnectivityStateInfo state) { + processSubchannelState(subchannel, state); + } + }); + subchannels.put(strippedAddressGroup, subchannel); + subchannel.requestConnection(); + } + + ArrayList removedSubchannels = new ArrayList<>(); + for (EquivalentAddressGroup addressGroup : removedAddrs) { + removedSubchannels.add(subchannels.remove(addressGroup)); + } + + // Update the picker before shutting down the subchannels, to reduce the chance of the race + // between picking a subchannel and shutting it down. + updateBalancingState(); + + // Shutdown removed subchannels + for (Subchannel removedSubchannel : removedSubchannels) { + shutdownSubchannel(removedSubchannel); + } + } + + @Override + public void handleNameResolutionError(Status error) { + if (currentState != READY) { + updateBalancingState(TRANSIENT_FAILURE, new EmptyPicker(error)); + } + } + + private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { + if (subchannels.get(stripAttrs(subchannel.getAddresses())) != subchannel) { + return; + } + if (stateInfo.getState() == TRANSIENT_FAILURE || stateInfo.getState() == IDLE) { + helper.refreshNameResolution(); + } + if (stateInfo.getState() == IDLE) { + subchannel.requestConnection(); + } + Ref subchannelStateRef = getSubchannelStateInfoRef(subchannel); + if (subchannelStateRef.value.getState().equals(TRANSIENT_FAILURE)) { + if (stateInfo.getState().equals(CONNECTING) || stateInfo.getState().equals(IDLE)) { + return; + } + } + subchannelStateRef.value = stateInfo; + updateBalancingState(); + } + + private void shutdownSubchannel(Subchannel subchannel) { + subchannel.shutdown(); + getSubchannelStateInfoRef(subchannel).value = + ConnectivityStateInfo.forNonError(SHUTDOWN); + } + + @Override + public void shutdown() { + for (Subchannel subchannel : getSubchannels()) { + shutdownSubchannel(subchannel); + } + subchannels.clear(); + } + + private static final Status EMPTY_OK = Status.OK.withDescription("no subchannels ready"); + + /** + * Updates picker with the list of active subchannels (state == READY). + */ + @SuppressWarnings("ReferenceEquality") + private void updateBalancingState() { + List activeList = filterNonFailingSubchannels(getSubchannels()); + if (activeList.isEmpty()) { + // No READY subchannels, determine aggregate state and error status + boolean isConnecting = false; + Status aggStatus = EMPTY_OK; + for (Subchannel subchannel : getSubchannels()) { + ConnectivityStateInfo stateInfo = getSubchannelStateInfoRef(subchannel).value; + // This subchannel IDLE is not because of channel IDLE_TIMEOUT, + // in which case LB is already shutdown. + // LRLB will request connection immediately on subchannel IDLE. + if (stateInfo.getState() == CONNECTING || stateInfo.getState() == IDLE) { + isConnecting = true; + } + if (aggStatus == EMPTY_OK || !aggStatus.isOk()) { + aggStatus = stateInfo.getStatus(); + } + } + updateBalancingState(isConnecting ? CONNECTING : TRANSIENT_FAILURE, + // If all subchannels are TRANSIENT_FAILURE, return the Status associated with + // an arbitrary subchannel, otherwise return OK. + new EmptyPicker(aggStatus)); + } else { + updateBalancingState(READY, new ReadyPicker(activeList, choiceCount, random)); + } + } + + private void updateBalancingState(ConnectivityState state, LeastRequestPicker picker) { + if (state != currentState || !picker.isEquivalentTo(currentPicker)) { + helper.updateBalancingState(state, picker); + currentState = state; + currentPicker = picker; + } + } + + /** + * Filters out non-ready subchannels. + */ + private static List filterNonFailingSubchannels( + Collection subchannels) { + List readySubchannels = new ArrayList<>(subchannels.size()); + for (Subchannel subchannel : subchannels) { + if (isReady(subchannel)) { + readySubchannels.add(subchannel); + } + } + return readySubchannels; + } + + /** + * Converts list of {@link EquivalentAddressGroup} to {@link EquivalentAddressGroup} set and + * remove all attributes. The values are the original EAGs. + */ + private static Map stripAttrs( + List groupList) { + Map addrs = new HashMap<>(groupList.size() * 2); + for (EquivalentAddressGroup group : groupList) { + addrs.put(stripAttrs(group), group); + } + return addrs; + } + + private static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) { + return new EquivalentAddressGroup(eag.getAddresses()); + } + + @VisibleForTesting + Collection getSubchannels() { + return subchannels.values(); + } + + private static Ref getSubchannelStateInfoRef( + Subchannel subchannel) { + return checkNotNull(subchannel.getAttributes().get(STATE_INFO), "STATE_INFO"); + } + + private static AtomicInteger getInFlights(Subchannel subchannel) { + return checkNotNull(subchannel.getAttributes().get(IN_FLIGHTS), "IN_FLIGHTS"); + } + + // package-private to avoid synthetic access + static boolean isReady(Subchannel subchannel) { + return getSubchannelStateInfoRef(subchannel).value.getState() == READY; + } + + private static Set setsDifference(Set a, Set b) { + Set aCopy = new HashSet<>(a); + aCopy.removeAll(b); + return aCopy; + } + + // Only subclasses are ReadyPicker or EmptyPicker + private abstract static class LeastRequestPicker extends SubchannelPicker { + abstract boolean isEquivalentTo(LeastRequestPicker picker); + } + + @VisibleForTesting + static final class ReadyPicker extends LeastRequestPicker { + private final List list; // non-empty + private final int choiceCount; + private final ThreadSafeRandom random; + + ReadyPicker(List list, int choiceCount, ThreadSafeRandom random) { + checkArgument(!list.isEmpty(), "empty list"); + this.list = list; + this.choiceCount = choiceCount; + this.random = checkNotNull(random, "random"); + } + + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + final Subchannel subchannel = nextSubchannel(); + final OutstandingRequestsTracingFactory factory = + new OutstandingRequestsTracingFactory(getInFlights(subchannel)); + return PickResult.withSubchannel(subchannel, factory); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(ReadyPicker.class) + .add("list", list) + .add("choiceCount", choiceCount) + .toString(); + } + + private Subchannel nextSubchannel() { + Subchannel candidate = list.get(random.nextInt(list.size())); + for (int i = 0; i < choiceCount - 1; ++i) { + Subchannel sampled = list.get(random.nextInt(list.size())); + if (getInFlights(sampled).get() < getInFlights(candidate).get()) { + candidate = sampled; + } + } + return candidate; + } + + @VisibleForTesting + List getList() { + return list; + } + + @Override + boolean isEquivalentTo(LeastRequestPicker picker) { + if (!(picker instanceof ReadyPicker)) { + return false; + } + ReadyPicker other = (ReadyPicker) picker; + // the lists cannot contain duplicate subchannels + return other == this + || ((list.size() == other.list.size() && new HashSet<>(list).containsAll(other.list)) + && choiceCount == other.choiceCount); + } + } + + @VisibleForTesting + static final class EmptyPicker extends LeastRequestPicker { + + private final Status status; + + EmptyPicker(@Nonnull Status status) { + this.status = Preconditions.checkNotNull(status, "status"); + } + + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + return status.isOk() ? PickResult.withNoResult() : PickResult.withError(status); + } + + @Override + boolean isEquivalentTo(LeastRequestPicker picker) { + return picker instanceof EmptyPicker && (Objects.equal(status, ((EmptyPicker) picker).status) + || (status.isOk() && ((EmptyPicker) picker).status.isOk())); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(EmptyPicker.class).add("status", status).toString(); + } + } + + /** + * A lighter weight Reference than AtomicReference. + */ + static final class Ref { + T value; + + Ref(T value) { + this.value = value; + } + } + + private static final class OutstandingRequestsTracingFactory extends + ClientStreamTracer.Factory { + private final AtomicInteger inFlights; + + private OutstandingRequestsTracingFactory(AtomicInteger inFlights) { + this.inFlights = checkNotNull(inFlights, "inFlights"); + } + + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + return new ClientStreamTracer() { + @Override + public void streamCreated(Attributes transportAttrs, Metadata headers) { + inFlights.incrementAndGet(); + } + + @Override + public void streamClosed(Status status) { + inFlights.decrementAndGet(); + } + }; + } + } + + static final class LeastRequestConfig { + final int choiceCount; + + LeastRequestConfig(int choiceCount) { + checkArgument(choiceCount >= MIN_CHOICE_COUNT, "choiceCount <= 1"); + // Even though a choiceCount value larger than 2 is currently considered valid in xDS + // we restrict it to 10 here as specified in "A48: xDS Least Request LB Policy". + this.choiceCount = Math.min(choiceCount, MAX_CHOICE_COUNT); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("choiceCount", choiceCount) + .toString(); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancerProvider.java new file mode 100644 index 00000000000..3abac1d2f0d --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancerProvider.java @@ -0,0 +1,80 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.Internal; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancerProvider; +import io.grpc.NameResolver.ConfigOrError; +import io.grpc.Status; +import io.grpc.internal.JsonUtil; +import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestConfig; +import java.util.Map; + +/** + * Provider for the "least_request_experimental" balancing policy. + */ +@Internal +public final class LeastRequestLoadBalancerProvider extends LoadBalancerProvider { + // Minimum number of choices allowed. + static final int MIN_CHOICE_COUNT = 2; + // Maximum number of choices allowed. + static final int MAX_CHOICE_COUNT = 10; + // Same as ClientXdsClient.DEFAULT_LEAST_REQUEST_CHOICE_COUNT + @VisibleForTesting + static final Integer DEFAULT_CHOICE_COUNT = 2; + + @Override + public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) { + return new LeastRequestLoadBalancer(helper); + } + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return "least_request_experimental"; + } + + @Override + public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { + try { + Integer choiceCount = JsonUtil.getNumberAsInteger(rawConfig, "choiceCount"); + if (choiceCount == null) { + choiceCount = DEFAULT_CHOICE_COUNT; + } + if (choiceCount < MIN_CHOICE_COUNT) { + return ConfigOrError.fromError(Status.INVALID_ARGUMENT.withDescription( + "Invalid 'choiceCount'")); + } + return ConfigOrError.fromConfig(new LeastRequestConfig(choiceCount)); + } catch (RuntimeException e) { + return ConfigOrError.fromError( + Status.fromThrowable(e).withDescription( + "Failed to parse least_request_experimental LB config: " + rawConfig)); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/XdsClient.java b/xds/src/main/java/io/grpc/xds/XdsClient.java index 6f2a661361e..789e576da06 100644 --- a/xds/src/main/java/io/grpc/xds/XdsClient.java +++ b/xds/src/main/java/io/grpc/xds/XdsClient.java @@ -120,6 +120,9 @@ abstract static class CdsUpdate implements ResourceUpdate { // Only valid if lbPolicy is "ring_hash_experimental". abstract long maxRingSize(); + // Only valid if lbPolicy is "least_request_experimental". + abstract int choiceCount(); + // Alternative resource name to be used in EDS requests. /// Only valid for EDS cluster. @Nullable @@ -158,6 +161,7 @@ static Builder forAggregate(String clusterName, List prioritizedClusterN .clusterType(ClusterType.AGGREGATE) .minRingSize(0) .maxRingSize(0) + .choiceCount(0) .prioritizedClusterNames(ImmutableList.copyOf(prioritizedClusterNames)); } @@ -169,6 +173,7 @@ static Builder forEds(String clusterName, @Nullable String edsServiceName, .clusterType(ClusterType.EDS) .minRingSize(0) .maxRingSize(0) + .choiceCount(0) .edsServiceName(edsServiceName) .lrsServerInfo(lrsServerInfo) .maxConcurrentRequests(maxConcurrentRequests) @@ -183,6 +188,7 @@ static Builder forLogicalDns(String clusterName, String dnsHostName, .clusterType(ClusterType.LOGICAL_DNS) .minRingSize(0) .maxRingSize(0) + .choiceCount(0) .dnsHostName(dnsHostName) .lrsServerInfo(lrsServerInfo) .maxConcurrentRequests(maxConcurrentRequests) @@ -194,7 +200,7 @@ enum ClusterType { } enum LbPolicy { - ROUND_ROBIN, RING_HASH + ROUND_ROBIN, RING_HASH, LEAST_REQUEST } // FIXME(chengyuanzhang): delete this after UpstreamTlsContext's toString() is fixed. @@ -206,6 +212,7 @@ public final String toString() { .add("lbPolicy", lbPolicy()) .add("minRingSize", minRingSize()) .add("maxRingSize", maxRingSize()) + .add("choiceCount", choiceCount()) .add("edsServiceName", edsServiceName()) .add("dnsHostName", dnsHostName()) .add("lrsServerInfo", lrsServerInfo()) @@ -234,6 +241,13 @@ Builder ringHashLbPolicy(long minRingSize, long maxRingSize) { return this.lbPolicy(LbPolicy.RING_HASH).minRingSize(minRingSize).maxRingSize(maxRingSize); } + Builder leastRequestLbPolicy(int choiceCount) { + return this.lbPolicy(LbPolicy.LEAST_REQUEST).choiceCount(choiceCount); + } + + // Private, use leastRequestLbPolicy(int). + protected abstract Builder choiceCount(int choiceCount); + // Private, use ringHashLbPolicy(long, long). protected abstract Builder minRingSize(long minRingSize); diff --git a/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider b/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider index 7ba3dcf22f5..8e5c2dd1c6a 100644 --- a/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider +++ b/xds/src/main/resources/META-INF/services/io.grpc.LoadBalancerProvider @@ -4,4 +4,5 @@ io.grpc.xds.WeightedTargetLoadBalancerProvider io.grpc.xds.ClusterManagerLoadBalancerProvider io.grpc.xds.ClusterResolverLoadBalancerProvider io.grpc.xds.ClusterImplLoadBalancerProvider +io.grpc.xds.LeastRequestLoadBalancerProvider io.grpc.xds.RingHashLoadBalancerProvider diff --git a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java index 7664660d4c6..78e6d6473ca 100644 --- a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java +++ b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java @@ -48,6 +48,7 @@ import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig; import io.grpc.xds.ClusterResolverLoadBalancerProvider.ClusterResolverConfig.DiscoveryMechanism; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; +import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestConfig; import io.grpc.xds.RingHashLoadBalancer.RingHashConfig; import io.grpc.xds.XdsClient.CdsUpdate; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; @@ -121,6 +122,7 @@ public void setUp() { lbRegistry.register(new FakeLoadBalancerProvider(CLUSTER_RESOLVER_POLICY_NAME)); lbRegistry.register(new FakeLoadBalancerProvider("round_robin")); lbRegistry.register(new FakeLoadBalancerProvider("ring_hash_experimental")); + lbRegistry.register(new FakeLoadBalancerProvider("least_request_experimental")); loadBalancer = new CdsLoadBalancer2(helper, lbRegistry); loadBalancer.handleResolvedAddresses( ResolvedAddresses.newBuilder() @@ -164,7 +166,7 @@ public void discoverTopLevelEdsCluster() { public void discoverTopLevelLogicalDnsCluster() { CdsUpdate update = CdsUpdate.forLogicalDns(CLUSTER, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, upstreamTlsContext) - .roundRobinLbPolicy().build(); + .leastRequestLbPolicy(3).build(); xdsClient.deliverCdsUpdate(CLUSTER, update); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); @@ -174,7 +176,9 @@ public void discoverTopLevelLogicalDnsCluster() { DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, upstreamTlsContext); - assertThat(childLbConfig.lbPolicy.getProvider().getPolicyName()).isEqualTo("round_robin"); + assertThat(childLbConfig.lbPolicy.getProvider().getPolicyName()) + .isEqualTo("least_request_experimental"); + assertThat(((LeastRequestConfig) childLbConfig.lbPolicy.getConfig()).choiceCount).isEqualTo(3); } @Test diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java index 940d96fd10e..33dbc622ac8 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java @@ -37,6 +37,7 @@ import io.envoyproxy.envoy.config.cluster.v3.Cluster.DiscoveryType; import io.envoyproxy.envoy.config.cluster.v3.Cluster.EdsClusterConfig; import io.envoyproxy.envoy.config.cluster.v3.Cluster.LbPolicy; +import io.envoyproxy.envoy.config.cluster.v3.Cluster.LeastRequestLbConfig; import io.envoyproxy.envoy.config.cluster.v3.Cluster.RingHashLbConfig; import io.envoyproxy.envoy.config.cluster.v3.Cluster.RingHashLbConfig.HashFunction; import io.envoyproxy.envoy.config.core.v3.Address; @@ -156,6 +157,7 @@ public class ClientXdsClientDataTest { private boolean originalEnableRetry; private boolean originalEnableRbac; private boolean originalEnableRouteLookup; + private boolean originalEnableLeastRequest; @Before public void setUp() { @@ -165,6 +167,8 @@ public void setUp() { assertThat(originalEnableRbac).isTrue(); originalEnableRouteLookup = ClientXdsClient.enableRouteLookup; assertThat(originalEnableRouteLookup).isFalse(); + originalEnableLeastRequest = ClientXdsClient.enableLeastRequest; + assertThat(originalEnableLeastRequest).isFalse(); } @After @@ -172,6 +176,7 @@ public void tearDown() { ClientXdsClient.enableRetry = originalEnableRetry; ClientXdsClient.enableRbac = originalEnableRbac; ClientXdsClient.enableRouteLookup = originalEnableRouteLookup; + ClientXdsClient.enableLeastRequest = originalEnableLeastRequest; } @Test @@ -1667,6 +1672,28 @@ public void parseCluster_ringHashLbPolicy_defaultLbConfig() throws ResourceInval .isEqualTo(ClientXdsClient.DEFAULT_RING_HASH_LB_POLICY_MAX_RING_SIZE); } + @Test + public void parseCluster_leastRequestLbPolicy_defaultLbConfig() throws ResourceInvalidException { + ClientXdsClient.enableLeastRequest = true; + Cluster cluster = Cluster.newBuilder() + .setName("cluster-foo.googleapis.com") + .setType(DiscoveryType.EDS) + .setEdsClusterConfig( + EdsClusterConfig.newBuilder() + .setEdsConfig( + ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.getDefaultInstance())) + .setServiceName("service-foo.googleapis.com")) + .setLbPolicy(LbPolicy.LEAST_REQUEST) + .build(); + + CdsUpdate update = ClientXdsClient.processCluster( + cluster, new HashSet(), null, LRS_SERVER_INFO); + assertThat(update.lbPolicy()).isEqualTo(CdsUpdate.LbPolicy.LEAST_REQUEST); + assertThat(update.choiceCount()) + .isEqualTo(ClientXdsClient.DEFAULT_LEAST_REQUEST_CHOICE_COUNT); + } + @Test public void parseCluster_transportSocketMatches_exception() throws ResourceInvalidException { Cluster cluster = Cluster.newBuilder() @@ -1741,6 +1768,31 @@ public void parseCluster_ringHashLbPolicy_invalidRingSizeConfig_tooLargeRingSize ClientXdsClient.processCluster(cluster, new HashSet(), null, LRS_SERVER_INFO); } + @Test + public void parseCluster_leastRequestLbPolicy_invalidChoiceCountConfig_tooSmallChoiceCount() + throws ResourceInvalidException { + ClientXdsClient.enableLeastRequest = true; + Cluster cluster = Cluster.newBuilder() + .setName("cluster-foo.googleapis.com") + .setType(DiscoveryType.EDS) + .setEdsClusterConfig( + EdsClusterConfig.newBuilder() + .setEdsConfig( + ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.getDefaultInstance())) + .setServiceName("service-foo.googleapis.com")) + .setLbPolicy(LbPolicy.LEAST_REQUEST) + .setLeastRequestLbConfig( + LeastRequestLbConfig.newBuilder() + .setChoiceCount(UInt32Value.newBuilder().setValue(1)) + ) + .build(); + + thrown.expect(ResourceInvalidException.class); + thrown.expectMessage("Cluster cluster-foo.googleapis.com: invalid least_request_lb_config"); + ClientXdsClient.processCluster(cluster, new HashSet(), null, LRS_SERVER_INFO); + } + @Test public void parseServerSideListener_invalidTrafficDirection() throws ResourceInvalidException { Listener listener = diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java index 48db92b4de5..9e81f45cfc5 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java @@ -205,8 +205,8 @@ public long currentTimeNanos() { // CDS test resources. private final Any testClusterRoundRobin = - Any.pack(mf.buildEdsCluster(CDS_RESOURCE, null, "round_robin", null, false, null, - "envoy.transport_sockets.tls", null + Any.pack(mf.buildEdsCluster(CDS_RESOURCE, null, "round_robin", null, + null, false, null, "envoy.transport_sockets.tls", null )); // EDS test resources. @@ -258,6 +258,7 @@ public long currentTimeNanos() { private ClientXdsClient xdsClient; private boolean originalEnableFaultInjection; private boolean originalEnableRbac; + private boolean originalEnableLeastRequest; @Before public void setUp() throws IOException { @@ -272,6 +273,8 @@ public void setUp() throws IOException { ClientXdsClient.enableFaultInjection = true; originalEnableRbac = ClientXdsClient.enableRbac; assertThat(originalEnableRbac).isTrue(); + originalEnableLeastRequest = ClientXdsClient.enableLeastRequest; + ClientXdsClient.enableLeastRequest = true; final String serverName = InProcessServerBuilder.generateName(); cleanupRule.register( InProcessServerBuilder @@ -345,6 +348,7 @@ SERVER_URI_EMPTY_AUTHORITY, CHANNEL_CREDENTIALS, useProtocolV3()))))) public void tearDown() { ClientXdsClient.enableFaultInjection = originalEnableFaultInjection; ClientXdsClient.enableRbac = originalEnableRbac; + ClientXdsClient.enableLeastRequest = originalEnableLeastRequest; xdsClient.shutdown(); channel.shutdown(); // channel not owned by XdsClient assertThat(adsEnded.get()).isTrue(); @@ -1264,9 +1268,9 @@ public void cdsResourceNotFound() { List clusters = ImmutableList.of( Any.pack(mf.buildEdsCluster("cluster-bar.googleapis.com", null, "round_robin", null, - false, null, "envoy.transport_sockets.tls", null)), + null, false, null, "envoy.transport_sockets.tls", null)), Any.pack(mf.buildEdsCluster("cluster-baz.googleapis.com", null, "round_robin", null, - false, null, "envoy.transport_sockets.tls", null))); + null, false, null, "envoy.transport_sockets.tls", null))); call.sendResponse(CDS, clusters, VERSION_1, "0000"); // Client sent an ACK CDS request. @@ -1340,13 +1344,13 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid() { // CDS -> {A, B, C}, version 1 ImmutableMap resourcesV1 = ImmutableMap.of( - "A", Any.pack(mf.buildEdsCluster("A", "A.1", "round_robin", null, false, null, + "A", Any.pack(mf.buildEdsCluster("A", "A.1", "round_robin", null, null, false, null, "envoy.transport_sockets.tls", null )), - "B", Any.pack(mf.buildEdsCluster("B", "B.1", "round_robin", null, false, null, + "B", Any.pack(mf.buildEdsCluster("B", "B.1", "round_robin", null, null, false, null, "envoy.transport_sockets.tls", null )), - "C", Any.pack(mf.buildEdsCluster("C", "C.1", "round_robin", null, false, null, + "C", Any.pack(mf.buildEdsCluster("C", "C.1", "round_robin", null, null, false, null, "envoy.transport_sockets.tls", null ))); call.sendResponse(CDS, resourcesV1.values().asList(), VERSION_1, "0000"); @@ -1359,7 +1363,7 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid() { // CDS -> {A, B}, version 2 // Failed to parse endpoint B ImmutableMap resourcesV2 = ImmutableMap.of( - "A", Any.pack(mf.buildEdsCluster("A", "A.2", "round_robin", null, false, null, + "A", Any.pack(mf.buildEdsCluster("A", "A.2", "round_robin", null, null, false, null, "envoy.transport_sockets.tls", null )), "B", Any.pack(mf.buildClusterInvalid("B"))); @@ -1376,10 +1380,10 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid() { // CDS -> {B, C} version 3 ImmutableMap resourcesV3 = ImmutableMap.of( - "B", Any.pack(mf.buildEdsCluster("B", "B.3", "round_robin", null, false, null, + "B", Any.pack(mf.buildEdsCluster("B", "B.3", "round_robin", null, null, false, null, "envoy.transport_sockets.tls", null )), - "C", Any.pack(mf.buildEdsCluster("C", "C.3", "round_robin", null, false, null, + "C", Any.pack(mf.buildEdsCluster("C", "C.3", "round_robin", null, null, false, null, "envoy.transport_sockets.tls", null ))); call.sendResponse(CDS, resourcesV3.values().asList(), VERSION_3, "0002"); @@ -1412,13 +1416,13 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid_withEdsSubscripti // CDS -> {A, B, C}, version 1 ImmutableMap resourcesV1 = ImmutableMap.of( - "A", Any.pack(mf.buildEdsCluster("A", "A.1", "round_robin", null, false, null, + "A", Any.pack(mf.buildEdsCluster("A", "A.1", "round_robin", null, null, false, null, "envoy.transport_sockets.tls", null )), - "B", Any.pack(mf.buildEdsCluster("B", "B.1", "round_robin", null, false, null, + "B", Any.pack(mf.buildEdsCluster("B", "B.1", "round_robin", null, null, false, null, "envoy.transport_sockets.tls", null )), - "C", Any.pack(mf.buildEdsCluster("C", "C.1", "round_robin", null, false, null, + "C", Any.pack(mf.buildEdsCluster("C", "C.1", "round_robin", null, null, false, null, "envoy.transport_sockets.tls", null ))); call.sendResponse(CDS, resourcesV1.values().asList(), VERSION_1, "0000"); @@ -1444,7 +1448,7 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid_withEdsSubscripti // CDS -> {A, B}, version 2 // Failed to parse endpoint B ImmutableMap resourcesV2 = ImmutableMap.of( - "A", Any.pack(mf.buildEdsCluster("A", "A.2", "round_robin", null, false, null, + "A", Any.pack(mf.buildEdsCluster("A", "A.2", "round_robin", null, null, false, null, "envoy.transport_sockets.tls", null )), "B", Any.pack(mf.buildClusterInvalid("B"))); @@ -1489,13 +1493,40 @@ public void cdsResourceFound() { verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); } + @Test + public void cdsResourceFound_leastRequestLbPolicy() { + DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); + Message leastRequestConfig = mf.buildLeastRequestLbConfig(3); + Any clusterRingHash = Any.pack( + mf.buildEdsCluster(CDS_RESOURCE, null, "least_request_experimental", null, + leastRequestConfig, false, null, "envoy.transport_sockets.tls", null + )); + call.sendResponse(ResourceType.CDS, clusterRingHash, VERSION_1, "0000"); + + // Client sent an ACK CDS request. + call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); + verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); + CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); + assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); + assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.EDS); + assertThat(cdsUpdate.edsServiceName()).isNull(); + assertThat(cdsUpdate.lbPolicy()).isEqualTo(LbPolicy.LEAST_REQUEST); + assertThat(cdsUpdate.choiceCount()).isEqualTo(3); + assertThat(cdsUpdate.lrsServerInfo()).isNull(); + assertThat(cdsUpdate.maxConcurrentRequests()).isNull(); + assertThat(cdsUpdate.upstreamTlsContext()).isNull(); + assertThat(fakeClock.getPendingTasks(CDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); + verifyResourceMetadataAcked(CDS, CDS_RESOURCE, clusterRingHash, VERSION_1, TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); + } + @Test public void cdsResourceFound_ringHashLbPolicy() { DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); Message ringHashConfig = mf.buildRingHashLbConfig("xx_hash", 10L, 100L); Any clusterRingHash = Any.pack( - mf.buildEdsCluster(CDS_RESOURCE, null, "ring_hash_experimental", ringHashConfig, false, - null, "envoy.transport_sockets.tls", null + mf.buildEdsCluster(CDS_RESOURCE, null, "ring_hash_experimental", ringHashConfig, null, + false, null, "envoy.transport_sockets.tls", null )); call.sendResponse(ResourceType.CDS, clusterRingHash, VERSION_1, "0000"); @@ -1523,7 +1554,7 @@ public void cdsResponseWithAggregateCluster() { List candidates = Arrays.asList( "cluster1.googleapis.com", "cluster2.googleapis.com", "cluster3.googleapis.com"); Any clusterAggregate = - Any.pack(mf.buildAggregateCluster(CDS_RESOURCE, "round_robin", null, candidates)); + Any.pack(mf.buildAggregateCluster(CDS_RESOURCE, "round_robin", null, null, candidates)); call.sendResponse(CDS, clusterAggregate, VERSION_1, "0000"); // Client sent an ACK CDS request. @@ -1542,7 +1573,7 @@ public void cdsResponseWithAggregateCluster() { public void cdsResponseWithCircuitBreakers() { DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); Any clusterCircuitBreakers = Any.pack( - mf.buildEdsCluster(CDS_RESOURCE, null, "round_robin", null, false, null, + mf.buildEdsCluster(CDS_RESOURCE, null, "round_robin", null, null, false, null, "envoy.transport_sockets.tls", mf.buildCircuitBreakers(50, 200))); call.sendResponse(CDS, clusterCircuitBreakers, VERSION_1, "0000"); @@ -1574,15 +1605,15 @@ public void cdsResponseWithUpstreamTlsContext() { // Management server sends back CDS response with UpstreamTlsContext. Any clusterEds = Any.pack(mf.buildEdsCluster(CDS_RESOURCE, "eds-cluster-foo.googleapis.com", "round_robin", - null, true, + null, null, true, mf.buildUpstreamTlsContext("cert-instance-name", "cert1"), "envoy.transport_sockets.tls", null)); List clusters = ImmutableList.of( Any.pack(mf.buildLogicalDnsCluster("cluster-bar.googleapis.com", - "dns-service-bar.googleapis.com", 443, "round_robin", null, false, null, null)), + "dns-service-bar.googleapis.com", 443, "round_robin", null, null,false, null, null)), clusterEds, - Any.pack(mf.buildEdsCluster("cluster-baz.googleapis.com", null, "round_robin", null, false, - null, "envoy.transport_sockets.tls", null))); + Any.pack(mf.buildEdsCluster("cluster-baz.googleapis.com", null, "round_robin", null, null, + false, null, "envoy.transport_sockets.tls", null))); call.sendResponse(CDS, clusters, VERSION_1, "0000"); // Client sent an ACK CDS request. @@ -1610,15 +1641,15 @@ public void cdsResponseWithNewUpstreamTlsContext() { // Management server sends back CDS response with UpstreamTlsContext. Any clusterEds = Any.pack(mf.buildEdsCluster(CDS_RESOURCE, "eds-cluster-foo.googleapis.com", "round_robin", - null, true, + null, null,true, mf.buildNewUpstreamTlsContext("cert-instance-name", "cert1"), "envoy.transport_sockets.tls", null)); List clusters = ImmutableList.of( Any.pack(mf.buildLogicalDnsCluster("cluster-bar.googleapis.com", - "dns-service-bar.googleapis.com", 443, "round_robin", null, false, null, null)), + "dns-service-bar.googleapis.com", 443, "round_robin", null, null, false, null, null)), clusterEds, - Any.pack(mf.buildEdsCluster("cluster-baz.googleapis.com", null, "round_robin", null, false, - null, "envoy.transport_sockets.tls", null))); + Any.pack(mf.buildEdsCluster("cluster-baz.googleapis.com", null, "round_robin", null, null, + false, null, "envoy.transport_sockets.tls", null))); call.sendResponse(CDS, clusters, VERSION_1, "0000"); // Client sent an ACK CDS request. @@ -1645,7 +1676,7 @@ public void cdsResponseErrorHandling_badUpstreamTlsContext() { // Management server sends back CDS response with UpstreamTlsContext. List clusters = ImmutableList.of(Any .pack(mf.buildEdsCluster(CDS_RESOURCE, "eds-cluster-foo.googleapis.com", "round_robin", - null, true, + null, null, true, mf.buildUpstreamTlsContext(null, null), "envoy.transport_sockets.tls", null))); call.sendResponse(CDS, clusters, VERSION_1, "0000"); @@ -1673,7 +1704,7 @@ public void cdsResponseErrorHandling_badTransportSocketName() { // Management server sends back CDS response with UpstreamTlsContext. List clusters = ImmutableList.of(Any .pack(mf.buildEdsCluster(CDS_RESOURCE, "eds-cluster-foo.googleapis.com", "round_robin", - null, true, + null, null, true, mf.buildUpstreamTlsContext("secret1", "cert1"), "envoy.transport_sockets.bad", null))); call.sendResponse(CDS, clusters, VERSION_1, "0000"); @@ -1737,7 +1768,7 @@ public void cdsResourceUpdated() { int dnsHostPort = 443; Any clusterDns = Any.pack(mf.buildLogicalDnsCluster(CDS_RESOURCE, dnsHostAddr, dnsHostPort, "round_robin", - null, false, null, null)); + null, null, false, null, null)); call.sendResponse(CDS, clusterDns, VERSION_1, "0000"); call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); @@ -1754,7 +1785,7 @@ public void cdsResourceUpdated() { // Updated CDS response. String edsService = "eds-service-bar.googleapis.com"; Any clusterEds = Any.pack( - mf.buildEdsCluster(CDS_RESOURCE, edsService, "round_robin", null, true, null, + mf.buildEdsCluster(CDS_RESOURCE, edsService, "round_robin", null, null, true, null, "envoy.transport_sockets.tls", null )); call.sendResponse(CDS, clusterEds, VERSION_2, "0001"); @@ -1828,9 +1859,9 @@ public void multipleCdsWatchers() { String edsService = "eds-service-bar.googleapis.com"; List clusters = ImmutableList.of( Any.pack(mf.buildLogicalDnsCluster(CDS_RESOURCE, dnsHostAddr, dnsHostPort, "round_robin", - null, false, null, null)), - Any.pack(mf.buildEdsCluster(cdsResourceTwo, edsService, "round_robin", null, true, null, - "envoy.transport_sockets.tls", null))); + null, null, false, null, null)), + Any.pack(mf.buildEdsCluster(cdsResourceTwo, edsService, "round_robin", null, null, true, + null, "envoy.transport_sockets.tls", null))); call.sendResponse(CDS, clusters, VERSION_1, "0000"); verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); @@ -2091,11 +2122,11 @@ public void edsResourceDeletedByCds() { DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); List clusters = ImmutableList.of( - Any.pack(mf.buildEdsCluster(resource, null, "round_robin", null, true, null, + Any.pack(mf.buildEdsCluster(resource, null, "round_robin", null, null, true, null, "envoy.transport_sockets.tls", null )), - Any.pack(mf.buildEdsCluster(CDS_RESOURCE, EDS_RESOURCE, "round_robin", null, false, null, - "envoy.transport_sockets.tls", null))); + Any.pack(mf.buildEdsCluster(CDS_RESOURCE, EDS_RESOURCE, "round_robin", null, null, false, + null, "envoy.transport_sockets.tls", null))); call.sendResponse(CDS, clusters, VERSION_1, "0000"); verify(cdsWatcher).onChanged(cdsUpdateCaptor.capture()); CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); @@ -2140,9 +2171,9 @@ public void edsResourceDeletedByCds() { verifySubscribedResourcesMetadataSizes(0, 2, 0, 2); clusters = ImmutableList.of( - Any.pack(mf.buildEdsCluster(resource, null, "round_robin", null, true, null, + Any.pack(mf.buildEdsCluster(resource, null, "round_robin", null, null, true, null, "envoy.transport_sockets.tls", null)), // no change - Any.pack(mf.buildEdsCluster(CDS_RESOURCE, null, "round_robin", null, false, null, + Any.pack(mf.buildEdsCluster(CDS_RESOURCE, null, "round_robin", null, null, false, null, "envoy.transport_sockets.tls", null ))); call.sendResponse(CDS, clusters, VERSION_2, "0001"); @@ -2691,20 +2722,24 @@ protected abstract Message buildVirtualHost( protected abstract Message buildClusterInvalid(String name); protected abstract Message buildEdsCluster(String clusterName, @Nullable String edsServiceName, - String lbPolicy, @Nullable Message ringHashLbConfig, boolean enableLrs, - @Nullable Message upstreamTlsContext, String transportSocketName, + String lbPolicy, @Nullable Message ringHashLbConfig, @Nullable Message leastRequestLbConfig, + boolean enableLrs, @Nullable Message upstreamTlsContext, String transportSocketName, @Nullable Message circuitBreakers); protected abstract Message buildLogicalDnsCluster(String clusterName, String dnsHostAddr, - int dnsHostPort, String lbPolicy, @Nullable Message ringHashLbConfig, boolean enableLrs, + int dnsHostPort, String lbPolicy, @Nullable Message ringHashLbConfig, + @Nullable Message leastRequestLbConfig, boolean enableLrs, @Nullable Message upstreamTlsContext, @Nullable Message circuitBreakers); protected abstract Message buildAggregateCluster(String clusterName, String lbPolicy, - @Nullable Message ringHashLbConfig, List clusters); + @Nullable Message ringHashLbConfig, @Nullable Message leastRequestLbConfig, + List clusters); protected abstract Message buildRingHashLbConfig(String hashFunction, long minRingSize, long maxRingSize); + protected abstract Message buildLeastRequestLbConfig(int choiceCount); + protected abstract Message buildUpstreamTlsContext(String instanceName, String certName); protected abstract Message buildNewUpstreamTlsContext(String instanceName, String certName); diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java index be28f07b73c..29c7fdc4c01 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java @@ -36,6 +36,7 @@ import io.envoyproxy.envoy.api.v2.Cluster.DiscoveryType; import io.envoyproxy.envoy.api.v2.Cluster.EdsClusterConfig; import io.envoyproxy.envoy.api.v2.Cluster.LbPolicy; +import io.envoyproxy.envoy.api.v2.Cluster.LeastRequestLbConfig; import io.envoyproxy.envoy.api.v2.Cluster.RingHashLbConfig; import io.envoyproxy.envoy.api.v2.Cluster.RingHashLbConfig.HashFunction; import io.envoyproxy.envoy.api.v2.ClusterLoadAssignment; @@ -392,10 +393,12 @@ protected Message buildClusterInvalid(String name) { @Override protected Message buildEdsCluster(String clusterName, @Nullable String edsServiceName, - String lbPolicy, @Nullable Message ringHashLbConfig, boolean enableLrs, + String lbPolicy, @Nullable Message ringHashLbConfig, @Nullable Message leastRequestLbConfig, + boolean enableLrs, @Nullable Message upstreamTlsContext, String transportSocketName, @Nullable Message circuitBreakers) { - Cluster.Builder builder = initClusterBuilder(clusterName, lbPolicy, ringHashLbConfig, + Cluster.Builder builder = initClusterBuilder( + clusterName, lbPolicy, ringHashLbConfig, leastRequestLbConfig, enableLrs, upstreamTlsContext, circuitBreakers); builder.setType(DiscoveryType.EDS); EdsClusterConfig.Builder edsClusterConfigBuilder = EdsClusterConfig.newBuilder(); @@ -410,9 +413,11 @@ protected Message buildEdsCluster(String clusterName, @Nullable String edsServic @Override protected Message buildLogicalDnsCluster(String clusterName, String dnsHostAddr, - int dnsHostPort, String lbPolicy, @Nullable Message ringHashLbConfig, boolean enableLrs, + int dnsHostPort, String lbPolicy, @Nullable Message ringHashLbConfig, + @Nullable Message leastRequestLbConfig, boolean enableLrs, @Nullable Message upstreamTlsContext, @Nullable Message circuitBreakers) { - Cluster.Builder builder = initClusterBuilder(clusterName, lbPolicy, ringHashLbConfig, + Cluster.Builder builder = initClusterBuilder( + clusterName, lbPolicy, ringHashLbConfig, leastRequestLbConfig, enableLrs, upstreamTlsContext, circuitBreakers); builder.setType(DiscoveryType.LOGICAL_DNS); builder.setLoadAssignment( @@ -428,7 +433,8 @@ protected Message buildLogicalDnsCluster(String clusterName, String dnsHostAddr, @Override protected Message buildAggregateCluster(String clusterName, String lbPolicy, - @Nullable Message ringHashLbConfig, List clusters) { + @Nullable Message ringHashLbConfig, @Nullable Message leastRequestLbConfig, + List clusters) { ClusterConfig clusterConfig = ClusterConfig.newBuilder().addAllClusters(clusters).build(); CustomClusterType type = CustomClusterType.newBuilder() @@ -441,6 +447,9 @@ protected Message buildAggregateCluster(String clusterName, String lbPolicy, } else if (lbPolicy.equals("ring_hash_experimental")) { builder.setLbPolicy(LbPolicy.RING_HASH); builder.setRingHashLbConfig((RingHashLbConfig) ringHashLbConfig); + } else if (lbPolicy.equals("least_request_experimental")) { + builder.setLbPolicy(LbPolicy.LEAST_REQUEST); + builder.setLeastRequestLbConfig((LeastRequestLbConfig) leastRequestLbConfig); } else { throw new AssertionError("Invalid LB policy"); } @@ -448,8 +457,9 @@ protected Message buildAggregateCluster(String clusterName, String lbPolicy, } private Cluster.Builder initClusterBuilder(String clusterName, String lbPolicy, - @Nullable Message ringHashLbConfig, boolean enableLrs, - @Nullable Message upstreamTlsContext, @Nullable Message circuitBreakers) { + @Nullable Message ringHashLbConfig, @Nullable Message leastRequestLbConfig, + boolean enableLrs, @Nullable Message upstreamTlsContext, + @Nullable Message circuitBreakers) { Cluster.Builder builder = Cluster.newBuilder(); builder.setName(clusterName); if (lbPolicy.equals("round_robin")) { @@ -457,6 +467,9 @@ private Cluster.Builder initClusterBuilder(String clusterName, String lbPolicy, } else if (lbPolicy.equals("ring_hash_experimental")) { builder.setLbPolicy(LbPolicy.RING_HASH); builder.setRingHashLbConfig((RingHashLbConfig) ringHashLbConfig); + } else if (lbPolicy.equals("least_request_experimental")) { + builder.setLbPolicy(LbPolicy.LEAST_REQUEST); + builder.setLeastRequestLbConfig((LeastRequestLbConfig) leastRequestLbConfig); } else { throw new AssertionError("Invalid LB policy"); } @@ -493,6 +506,13 @@ protected Message buildRingHashLbConfig(String hashFunction, long minRingSize, return builder.build(); } + @Override + protected Message buildLeastRequestLbConfig(int choiceCount) { + LeastRequestLbConfig.Builder builder = LeastRequestLbConfig.newBuilder(); + builder.setChoiceCount(UInt32Value.newBuilder().setValue(choiceCount)); + return builder.build(); + } + @Override protected Message buildUpstreamTlsContext(String instanceName, String certName) { GrpcService grpcService = diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java index 69e75292778..6a75d9ab068 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java @@ -38,6 +38,7 @@ import io.envoyproxy.envoy.config.cluster.v3.Cluster.DiscoveryType; import io.envoyproxy.envoy.config.cluster.v3.Cluster.EdsClusterConfig; import io.envoyproxy.envoy.config.cluster.v3.Cluster.LbPolicy; +import io.envoyproxy.envoy.config.cluster.v3.Cluster.LeastRequestLbConfig; import io.envoyproxy.envoy.config.cluster.v3.Cluster.RingHashLbConfig; import io.envoyproxy.envoy.config.cluster.v3.Cluster.RingHashLbConfig.HashFunction; import io.envoyproxy.envoy.config.core.v3.Address; @@ -448,10 +449,12 @@ protected Message buildClusterInvalid(String name) { @Override protected Message buildEdsCluster(String clusterName, @Nullable String edsServiceName, - String lbPolicy, @Nullable Message ringHashLbConfig, boolean enableLrs, + String lbPolicy, @Nullable Message ringHashLbConfig, + @Nullable Message leastRequestLbConfig, boolean enableLrs, @Nullable Message upstreamTlsContext, String transportSocketName, @Nullable Message circuitBreakers) { - Cluster.Builder builder = initClusterBuilder(clusterName, lbPolicy, ringHashLbConfig, + Cluster.Builder builder = initClusterBuilder( + clusterName, lbPolicy, ringHashLbConfig, leastRequestLbConfig, enableLrs, upstreamTlsContext, transportSocketName, circuitBreakers); builder.setType(DiscoveryType.EDS); EdsClusterConfig.Builder edsClusterConfigBuilder = EdsClusterConfig.newBuilder(); @@ -466,9 +469,11 @@ protected Message buildEdsCluster(String clusterName, @Nullable String edsServic @Override protected Message buildLogicalDnsCluster(String clusterName, String dnsHostAddr, - int dnsHostPort, String lbPolicy, @Nullable Message ringHashLbConfig, boolean enableLrs, + int dnsHostPort, String lbPolicy, @Nullable Message ringHashLbConfig, + @Nullable Message leastRequestLbConfig, boolean enableLrs, @Nullable Message upstreamTlsContext, @Nullable Message circuitBreakers) { - Cluster.Builder builder = initClusterBuilder(clusterName, lbPolicy, ringHashLbConfig, + Cluster.Builder builder = initClusterBuilder( + clusterName, lbPolicy, ringHashLbConfig, leastRequestLbConfig, enableLrs, upstreamTlsContext, "envoy.transport_sockets.tls", circuitBreakers); builder.setType(DiscoveryType.LOGICAL_DNS); builder.setLoadAssignment( @@ -484,7 +489,8 @@ protected Message buildLogicalDnsCluster(String clusterName, String dnsHostAddr, @Override protected Message buildAggregateCluster(String clusterName, String lbPolicy, - @Nullable Message ringHashLbConfig, List clusters) { + @Nullable Message ringHashLbConfig, @Nullable Message leastRequestLbConfig, + List clusters) { ClusterConfig clusterConfig = ClusterConfig.newBuilder().addAllClusters(clusters).build(); CustomClusterType type = CustomClusterType.newBuilder() @@ -497,6 +503,9 @@ protected Message buildAggregateCluster(String clusterName, String lbPolicy, } else if (lbPolicy.equals("ring_hash_experimental")) { builder.setLbPolicy(LbPolicy.RING_HASH); builder.setRingHashLbConfig((RingHashLbConfig) ringHashLbConfig); + } else if (lbPolicy.equals("least_request_experimental")) { + builder.setLbPolicy(LbPolicy.LEAST_REQUEST); + builder.setLeastRequestLbConfig((LeastRequestLbConfig) leastRequestLbConfig); } else { throw new AssertionError("Invalid LB policy"); } @@ -504,8 +513,8 @@ protected Message buildAggregateCluster(String clusterName, String lbPolicy, } private Cluster.Builder initClusterBuilder(String clusterName, String lbPolicy, - @Nullable Message ringHashLbConfig, boolean enableLrs, - @Nullable Message upstreamTlsContext, String transportSocketName, + @Nullable Message ringHashLbConfig, @Nullable Message leastRequestLbConfig, + boolean enableLrs, @Nullable Message upstreamTlsContext, String transportSocketName, @Nullable Message circuitBreakers) { Cluster.Builder builder = Cluster.newBuilder(); builder.setName(clusterName); @@ -514,6 +523,9 @@ private Cluster.Builder initClusterBuilder(String clusterName, String lbPolicy, } else if (lbPolicy.equals("ring_hash_experimental")) { builder.setLbPolicy(LbPolicy.RING_HASH); builder.setRingHashLbConfig((RingHashLbConfig) ringHashLbConfig); + } else if (lbPolicy.equals("least_request_experimental")) { + builder.setLbPolicy(LbPolicy.LEAST_REQUEST); + builder.setLeastRequestLbConfig((LeastRequestLbConfig) leastRequestLbConfig); } else { throw new AssertionError("Invalid LB policy"); } @@ -550,6 +562,13 @@ protected Message buildRingHashLbConfig(String hashFunction, long minRingSize, return builder.build(); } + @Override + protected Message buildLeastRequestLbConfig(int choiceCount) { + LeastRequestLbConfig.Builder builder = LeastRequestLbConfig.newBuilder(); + builder.setChoiceCount(UInt32Value.newBuilder().setValue(choiceCount)); + return builder.build(); + } + @Override @SuppressWarnings("deprecation") protected Message buildUpstreamTlsContext(String instanceName, String certName) { diff --git a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java index a85f476ed4b..51a7ce5066b 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java @@ -65,6 +65,7 @@ import io.grpc.xds.Endpoints.LbEndpoint; import io.grpc.xds.Endpoints.LocalityLbEndpoints; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; +import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestConfig; import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig; import io.grpc.xds.PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig; import io.grpc.xds.RingHashLoadBalancer.RingHashConfig; @@ -136,6 +137,8 @@ public void uncaughtException(Thread t, Throwable e) { new PolicySelection(new FakeLoadBalancerProvider("round_robin"), null); private final PolicySelection ringHash = new PolicySelection( new FakeLoadBalancerProvider("ring_hash_experimental"), new RingHashConfig(10L, 100L)); + private final PolicySelection leastRequest = new PolicySelection( + new FakeLoadBalancerProvider("least_request_experimental"), new LeastRequestConfig(3)); private final List childBalancers = new ArrayList<>(); private final List resolvers = new ArrayList<>(); private final FakeXdsClient xdsClient = new FakeXdsClient(); @@ -267,6 +270,45 @@ public void edsClustersWithRingHashEndpointLbPolicy() { assertThat(ringHashConfig.maxRingSize).isEqualTo(100L); } + @Test + public void edsClustersWithLeastRequestEndpointLbPolicy() { + ClusterResolverConfig config = new ClusterResolverConfig( + Collections.singletonList(edsDiscoveryMechanism1), leastRequest); + deliverLbConfig(config); + assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); + assertThat(childBalancers).isEmpty(); + + // Simple case with one priority and one locality + EquivalentAddressGroup endpoint = makeAddress("endpoint-addr-1"); + LocalityLbEndpoints localityLbEndpoints = + LocalityLbEndpoints.create( + Arrays.asList( + LbEndpoint.create(endpoint, 0 /* loadBalancingWeight */, true)), + 100 /* localityWeight */, 1 /* priority */); + xdsClient.deliverClusterLoadAssignment( + EDS_SERVICE_NAME1, + ImmutableMap.of(locality1, localityLbEndpoints)); + assertThat(childBalancers).hasSize(1); + FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); + assertThat(childBalancer.addresses).hasSize(1); + EquivalentAddressGroup addr = childBalancer.addresses.get(0); + assertThat(addr.getAddresses()).isEqualTo(endpoint.getAddresses()); + assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); + PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; + assertThat(priorityLbConfig.priorities).containsExactly(CLUSTER1 + "[priority1]"); + PriorityChildConfig priorityChildConfig = + Iterables.getOnlyElement(priorityLbConfig.childConfigs.values()); + assertThat(priorityChildConfig.policySelection.getProvider().getPolicyName()) + .isEqualTo(CLUSTER_IMPL_POLICY_NAME); + ClusterImplConfig clusterImplConfig = + (ClusterImplConfig) priorityChildConfig.policySelection.getConfig(); + assertClusterImplConfig(clusterImplConfig, CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_INFO, 100L, + tlsContext, Collections.emptyList(), WEIGHTED_TARGET_POLICY_NAME); + WeightedTargetConfig weightedTargetConfig = + (WeightedTargetConfig) clusterImplConfig.childPolicy.getConfig(); + assertThat(weightedTargetConfig.targets.keySet()).containsExactly(locality1.toString()); + } + @Test public void onlyEdsClusters_receivedEndpoints() { ClusterResolverConfig config = new ClusterResolverConfig( diff --git a/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerProviderTest.java new file mode 100644 index 00000000000..2e8519b150d --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerProviderTest.java @@ -0,0 +1,139 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import io.grpc.InternalServiceProviders; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancerProvider; +import io.grpc.NameResolver.ConfigOrError; +import io.grpc.Status.Code; +import io.grpc.SynchronizationContext; +import io.grpc.internal.JsonParser; +import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestConfig; +import java.io.IOException; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link LeastRequestLoadBalancerProvider}. */ +@RunWith(JUnit4.class) +public class LeastRequestLoadBalancerProviderTest { + private static final String AUTHORITY = "foo.googleapis.com"; + + private final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); + } + }); + private final LeastRequestLoadBalancerProvider provider = new LeastRequestLoadBalancerProvider(); + + @Test + public void provided() { + for (LoadBalancerProvider current : InternalServiceProviders.getCandidatesViaServiceLoader( + LoadBalancerProvider.class, getClass().getClassLoader())) { + if (current instanceof LeastRequestLoadBalancerProvider) { + return; + } + } + fail("LeastRequestLoadBalancerProvider not registered"); + } + + @Test + public void providesLoadBalancer() { + Helper helper = mock(Helper.class); + when(helper.getSynchronizationContext()).thenReturn(syncContext); + when(helper.getAuthority()).thenReturn(AUTHORITY); + assertThat(provider.newLoadBalancer(helper)) + .isInstanceOf(LeastRequestLoadBalancer.class); + } + + @Test + public void parseLoadBalancingConfig_valid() throws IOException { + String lbConfig = "{\"choiceCount\" : 3}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + LeastRequestConfig config = (LeastRequestConfig) configOrError.getConfig(); + assertThat(config.choiceCount).isEqualTo(3); + } + + @Test + public void parseLoadBalancingConfig_missingChoiceCount_useDefaults() throws IOException { + String lbConfig = "{}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + LeastRequestConfig config = (LeastRequestConfig) configOrError.getConfig(); + assertThat(config.choiceCount) + .isEqualTo(LeastRequestLoadBalancerProvider.DEFAULT_CHOICE_COUNT); + } + + @Test + public void parseLoadBalancingConfig_invalid_negativeSize() throws IOException { + String lbConfig = "{\"choiceCount\" : -10}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getError()).isNotNull(); + assertThat(configOrError.getError().getCode()).isEqualTo(Code.INVALID_ARGUMENT); + assertThat(configOrError.getError().getDescription()) + .isEqualTo("Invalid 'choiceCount'"); + } + + @Test + public void parseLoadBalancingConfig_invalid_tooSmallSize() throws IOException { + String lbConfig = "{\"choiceCount\" : 1}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getError()).isNotNull(); + assertThat(configOrError.getError().getCode()).isEqualTo(Code.INVALID_ARGUMENT); + assertThat(configOrError.getError().getDescription()) + .isEqualTo("Invalid 'choiceCount'"); + } + + @Test + public void parseLoadBalancingConfig_choiceCountCappedAtMax() throws IOException { + String lbConfig = "{\"choiceCount\" : 11}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + LeastRequestConfig config = (LeastRequestConfig) configOrError.getConfig(); + assertThat(config.choiceCount).isEqualTo(LeastRequestLoadBalancerProvider.MAX_CHOICE_COUNT); + } + + @Test + public void parseLoadBalancingConfig_invalidInteger() throws IOException { + Map lbConfig = parseJsonObject("{\"choiceCount\" : \"NaN\"}"); + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(lbConfig); + assertThat(configOrError.getError()).isNotNull(); + assertThat(configOrError.getError().getDescription()).isEqualTo( + "Failed to parse least_request_experimental LB config: " + lbConfig); + } + + @SuppressWarnings("unchecked") + private static Map parseJsonObject(String json) throws IOException { + return (Map) JsonParser.parse(json); + } +} diff --git a/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java new file mode 100644 index 00000000000..2d09dbfe1fc --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java @@ -0,0 +1,632 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.ConnectivityState.CONNECTING; +import static io.grpc.ConnectivityState.IDLE; +import static io.grpc.ConnectivityState.READY; +import static io.grpc.ConnectivityState.SHUTDOWN; +import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.xds.LeastRequestLoadBalancer.IN_FLIGHTS; +import static io.grpc.xds.LeastRequestLoadBalancer.STATE_INFO; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import io.grpc.Attributes; +import io.grpc.ClientStreamTracer; +import io.grpc.ClientStreamTracer.StreamInfo; +import io.grpc.ConnectivityState; +import io.grpc.ConnectivityStateInfo; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.CreateSubchannelArgs; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickResult; +import io.grpc.LoadBalancer.PickSubchannelArgs; +import io.grpc.LoadBalancer.ResolvedAddresses; +import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancer.SubchannelStateListener; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.xds.LeastRequestLoadBalancer.EmptyPicker; +import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestConfig; +import io.grpc.xds.LeastRequestLoadBalancer.ReadyPicker; +import io.grpc.xds.LeastRequestLoadBalancer.Ref; +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.InOrder; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +/** Unit test for {@link LeastRequestLoadBalancer}. */ +@RunWith(JUnit4.class) +public class LeastRequestLoadBalancerTest { + private static final Attributes.Key MAJOR_KEY = Attributes.Key.create("major-key"); + + private LeastRequestLoadBalancer loadBalancer; + private final List servers = Lists.newArrayList(); + private final Map, Subchannel> subchannels = Maps.newLinkedHashMap(); + private final Map subchannelStateListeners = + Maps.newLinkedHashMap(); + private final Attributes affinity = + Attributes.newBuilder().set(MAJOR_KEY, "I got the keys").build(); + + @Captor + private ArgumentCaptor pickerCaptor; + @Captor + private ArgumentCaptor stateCaptor; + @Captor + private ArgumentCaptor createArgsCaptor; + @Mock + private Helper mockHelper; + @Mock + private ThreadSafeRandom mockRandom; + + @Mock // This LoadBalancer doesn't use any of the arg fields, as verified in tearDown(). + private PickSubchannelArgs mockArgs; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + + for (int i = 0; i < 3; i++) { + SocketAddress addr = new FakeSocketAddress("server" + i); + EquivalentAddressGroup eag = new EquivalentAddressGroup(addr); + servers.add(eag); + Subchannel sc = mock(Subchannel.class); + subchannels.put(Arrays.asList(eag), sc); + } + + when(mockHelper.createSubchannel(any(CreateSubchannelArgs.class))) + .then(new Answer() { + @Override + public Subchannel answer(InvocationOnMock invocation) throws Throwable { + CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0]; + final Subchannel subchannel = subchannels.get(args.getAddresses()); + when(subchannel.getAllAddresses()).thenReturn(args.getAddresses()); + when(subchannel.getAttributes()).thenReturn(args.getAttributes()); + doAnswer( + new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + subchannelStateListeners.put( + subchannel, (SubchannelStateListener) invocation.getArguments()[0]); + return null; + } + }).when(subchannel).start(any(SubchannelStateListener.class)); + return subchannel; + } + }); + loadBalancer = new LeastRequestLoadBalancer(mockHelper, mockRandom); + } + + @After + public void tearDown() throws Exception { + verifyNoMoreInteractions(mockRandom); + verifyNoMoreInteractions(mockArgs); + } + + @Test + public void pickAfterResolved() throws Exception { + final Subchannel readySubchannel = subchannels.values().iterator().next(); + loadBalancer.handleResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); + deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); + + verify(mockHelper, times(3)).createSubchannel(createArgsCaptor.capture()); + List> capturedAddrs = new ArrayList<>(); + for (CreateSubchannelArgs arg : createArgsCaptor.getAllValues()) { + capturedAddrs.add(arg.getAddresses()); + } + + assertThat(capturedAddrs).containsAtLeastElementsIn(subchannels.keySet()); + for (Subchannel subchannel : subchannels.values()) { + verify(subchannel).requestConnection(); + verify(subchannel, never()).shutdown(); + } + + verify(mockHelper, times(2)) + .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + + assertEquals(CONNECTING, stateCaptor.getAllValues().get(0)); + assertEquals(READY, stateCaptor.getAllValues().get(1)); + assertThat(getList(pickerCaptor.getValue())).containsExactly(readySubchannel); + + verifyNoMoreInteractions(mockHelper); + } + + @Test + public void pickAfterResolvedUpdatedHosts() throws Exception { + Subchannel removedSubchannel = mock(Subchannel.class); + Subchannel oldSubchannel = mock(Subchannel.class); + Subchannel newSubchannel = mock(Subchannel.class); + + Attributes.Key key = Attributes.Key.create("check-that-it-is-propagated"); + FakeSocketAddress removedAddr = new FakeSocketAddress("removed"); + EquivalentAddressGroup removedEag = new EquivalentAddressGroup(removedAddr); + FakeSocketAddress oldAddr = new FakeSocketAddress("old"); + EquivalentAddressGroup oldEag1 = new EquivalentAddressGroup(oldAddr); + EquivalentAddressGroup oldEag2 = new EquivalentAddressGroup( + oldAddr, Attributes.newBuilder().set(key, "oldattr").build()); + FakeSocketAddress newAddr = new FakeSocketAddress("new"); + EquivalentAddressGroup newEag = new EquivalentAddressGroup( + newAddr, Attributes.newBuilder().set(key, "newattr").build()); + + subchannels.put(Collections.singletonList(removedEag), removedSubchannel); + subchannels.put(Collections.singletonList(oldEag1), oldSubchannel); + subchannels.put(Collections.singletonList(newEag), newSubchannel); + + List currentServers = Lists.newArrayList(removedEag, oldEag1); + + InOrder inOrder = inOrder(mockHelper); + + loadBalancer.handleResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(currentServers).setAttributes(affinity) + .build()); + + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + + deliverSubchannelState(removedSubchannel, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(oldSubchannel, ConnectivityStateInfo.forNonError(READY)); + + inOrder.verify(mockHelper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture()); + SubchannelPicker picker = pickerCaptor.getValue(); + assertThat(getList(picker)).containsExactly(removedSubchannel, oldSubchannel); + + verify(removedSubchannel, times(1)).requestConnection(); + verify(oldSubchannel, times(1)).requestConnection(); + + assertThat(loadBalancer.getSubchannels()).containsExactly(removedSubchannel, + oldSubchannel); + + // This time with Attributes + List latestServers = Lists.newArrayList(oldEag2, newEag); + + loadBalancer.handleResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(latestServers).setAttributes(affinity).build()); + + verify(newSubchannel, times(1)).requestConnection(); + verify(oldSubchannel, times(1)).updateAddresses(Arrays.asList(oldEag2)); + verify(removedSubchannel, times(1)).shutdown(); + + deliverSubchannelState(removedSubchannel, ConnectivityStateInfo.forNonError(SHUTDOWN)); + deliverSubchannelState(newSubchannel, ConnectivityStateInfo.forNonError(READY)); + + assertThat(loadBalancer.getSubchannels()).containsExactly(oldSubchannel, + newSubchannel); + + verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + inOrder.verify(mockHelper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture()); + + picker = pickerCaptor.getValue(); + assertThat(getList(picker)).containsExactly(oldSubchannel, newSubchannel); + + // test going from non-empty to empty + loadBalancer.handleResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(Collections.emptyList()) + .setAttributes(affinity) + .build()); + + inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + assertEquals(PickResult.withNoResult(), pickerCaptor.getValue().pickSubchannel(mockArgs)); + + verifyNoMoreInteractions(mockHelper); + } + + @Test + public void pickAfterStateChange() throws Exception { + InOrder inOrder = inOrder(mockHelper); + loadBalancer.handleResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) + .build()); + Subchannel subchannel = loadBalancer.getSubchannels().iterator().next(); + Ref subchannelStateInfo = subchannel.getAttributes().get( + STATE_INFO); + + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + assertThat(subchannelStateInfo.value).isEqualTo(ConnectivityStateInfo.forNonError(IDLE)); + + deliverSubchannelState(subchannel, + ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); + assertThat(pickerCaptor.getValue()).isInstanceOf(ReadyPicker.class); + assertThat(subchannelStateInfo.value).isEqualTo( + ConnectivityStateInfo.forNonError(READY)); + + Status error = Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯"); + deliverSubchannelState(subchannel, + ConnectivityStateInfo.forTransientFailure(error)); + assertThat(subchannelStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE); + assertThat(subchannelStateInfo.value.getStatus()).isEqualTo(error); + inOrder.verify(mockHelper).refreshNameResolution(); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + assertThat(pickerCaptor.getValue()).isInstanceOf(EmptyPicker.class); + + deliverSubchannelState(subchannel, + ConnectivityStateInfo.forNonError(IDLE)); + inOrder.verify(mockHelper).refreshNameResolution(); + assertThat(subchannelStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE); + assertThat(subchannelStateInfo.value.getStatus()).isEqualTo(error); + + verify(subchannel, times(2)).requestConnection(); + verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + verifyNoMoreInteractions(mockHelper); + } + + @Test + public void pickAfterConfigChange() { + final LeastRequestConfig oldConfig = new LeastRequestConfig(4); + final LeastRequestConfig newConfig = new LeastRequestConfig(6); + final Subchannel readySubchannel = subchannels.values().iterator().next(); + loadBalancer.handleResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity) + .setLoadBalancingPolicyConfig(oldConfig).build()); + deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); + verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + verify(mockHelper, times(2)) + .updateBalancingState(any(ConnectivityState.class), pickerCaptor.capture()); + + // At this point it should use a ReadyPicker with oldConfig + pickerCaptor.getValue().pickSubchannel(mockArgs); + verify(mockRandom, times(oldConfig.choiceCount)).nextInt(1); + + loadBalancer.handleResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity) + .setLoadBalancingPolicyConfig(newConfig).build()); + verify(mockHelper, times(3)) + .updateBalancingState(any(ConnectivityState.class), pickerCaptor.capture()); + + // At this point it should use a ReadyPicker with newConfig + pickerCaptor.getValue().pickSubchannel(mockArgs); + verify(mockRandom, times(oldConfig.choiceCount + newConfig.choiceCount)).nextInt(1); + verifyNoMoreInteractions(mockHelper); + } + + @Test + public void ignoreShutdownSubchannelStateChange() { + InOrder inOrder = inOrder(mockHelper); + loadBalancer.handleResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) + .build()); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + + loadBalancer.shutdown(); + for (Subchannel sc : loadBalancer.getSubchannels()) { + verify(sc).shutdown(); + // When the subchannel is being shut down, a SHUTDOWN connectivity state is delivered + // back to the subchannel state listener. + deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(SHUTDOWN)); + } + + inOrder.verifyNoMoreInteractions(); + } + + @Test + public void stayTransientFailureUntilReady() { + InOrder inOrder = inOrder(mockHelper); + loadBalancer.handleResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) + .build()); + + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + + // Simulate state transitions for each subchannel individually. + for (Subchannel sc : loadBalancer.getSubchannels()) { + Status error = Status.UNKNOWN.withDescription("connection broken"); + deliverSubchannelState( + sc, + ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockHelper).refreshNameResolution(); + deliverSubchannelState( + sc, + ConnectivityStateInfo.forNonError(CONNECTING)); + Ref scStateInfo = sc.getAttributes().get( + STATE_INFO); + assertThat(scStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE); + assertThat(scStateInfo.value.getStatus()).isEqualTo(error); + } + inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), isA(EmptyPicker.class)); + inOrder.verifyNoMoreInteractions(); + + Subchannel subchannel = loadBalancer.getSubchannels().iterator().next(); + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); + Ref subchannelStateInfo = subchannel.getAttributes().get( + STATE_INFO); + assertThat(subchannelStateInfo.value).isEqualTo(ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(mockHelper).updateBalancingState(eq(READY), isA(ReadyPicker.class)); + + verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + verifyNoMoreInteractions(mockHelper); + } + + @Test + public void refreshNameResolutionWhenSubchannelConnectionBroken() { + InOrder inOrder = inOrder(mockHelper); + loadBalancer.handleResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) + .build()); + + verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + + // Simulate state transitions for each subchannel individually. + for (Subchannel sc : loadBalancer.getSubchannels()) { + verify(sc).requestConnection(); + deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(CONNECTING)); + Status error = Status.UNKNOWN.withDescription("connection broken"); + deliverSubchannelState(sc, ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockHelper).refreshNameResolution(); + deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(mockHelper).updateBalancingState(eq(READY), isA(ReadyPicker.class)); + // Simulate receiving go-away so READY subchannels transit to IDLE. + deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(IDLE)); + inOrder.verify(mockHelper).refreshNameResolution(); + verify(sc, times(2)).requestConnection(); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + } + + verifyNoMoreInteractions(mockHelper); + } + + @Test + public void pickerLeastRequest() throws Exception { + int choiceCount = 2; + // This should add inFlight counters to all subchannels. + loadBalancer.handleResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) + .setLoadBalancingPolicyConfig(new LeastRequestConfig(choiceCount)) + .build()); + + assertEquals(3, loadBalancer.getSubchannels().size()); + + List subchannels = Lists.newArrayList(loadBalancer.getSubchannels()); + + // Make sure all inFlight counters have started at 0 + assertEquals(0, + subchannels.get(0).getAttributes().get(IN_FLIGHTS).get()); + assertEquals(0, + subchannels.get(1).getAttributes().get(IN_FLIGHTS).get()); + assertEquals(0, + subchannels.get(2).getAttributes().get(IN_FLIGHTS).get()); + + for (Subchannel sc : subchannels) { + deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(READY)); + } + + // Capture the active ReadyPicker once all subchannels are READY + verify(mockHelper, times(4)) + .updateBalancingState(any(ConnectivityState.class), pickerCaptor.capture()); + assertThat(pickerCaptor.getValue()).isInstanceOf(ReadyPicker.class); + + ReadyPicker picker = (ReadyPicker) pickerCaptor.getValue(); + + assertThat(picker.getList()).containsExactlyElementsIn(subchannels); + + // Make random return 0, then 2 for the sample indexes. + when(mockRandom.nextInt(subchannels.size())).thenReturn(0, 2); + PickResult pickResult1 = picker.pickSubchannel(mockArgs); + verify(mockRandom, times(choiceCount)).nextInt(subchannels.size()); + assertEquals(subchannels.get(0), pickResult1.getSubchannel()); + // This simulates sending the actual RPC on the picked channel + ClientStreamTracer streamTracer1 = + pickResult1.getStreamTracerFactory() + .newClientStreamTracer(StreamInfo.newBuilder().build(), new Metadata()); + streamTracer1.streamCreated(Attributes.EMPTY, new Metadata()); + assertEquals(1, + pickResult1.getSubchannel().getAttributes().get(IN_FLIGHTS).get()); + + // For the second pick it should pick the one with lower inFlight. + when(mockRandom.nextInt(subchannels.size())).thenReturn(0, 2); + PickResult pickResult2 = picker.pickSubchannel(mockArgs); + // Since this is the second pick we expect the total random samples to be choiceCount * 2 + verify(mockRandom, times(choiceCount * 2)).nextInt(subchannels.size()); + assertEquals(subchannels.get(2), pickResult2.getSubchannel()); + + // For the third pick we unavoidably pick subchannel with index 1. + when(mockRandom.nextInt(subchannels.size())).thenReturn(1, 1); + PickResult pickResult3 = picker.pickSubchannel(mockArgs); + verify(mockRandom, times(choiceCount * 3)).nextInt(subchannels.size()); + assertEquals(subchannels.get(1), pickResult3.getSubchannel()); + + // Finally ensure a finished RPC decreases inFlight + streamTracer1.streamClosed(Status.OK); + assertEquals(0, + pickResult1.getSubchannel().getAttributes().get(IN_FLIGHTS).get()); + } + + @Test + public void pickerEmptyList() throws Exception { + SubchannelPicker picker = new EmptyPicker(Status.UNKNOWN); + + assertEquals(null, picker.pickSubchannel(mockArgs).getSubchannel()); + assertEquals(Status.UNKNOWN, + picker.pickSubchannel(mockArgs).getStatus()); + } + + @Test + public void nameResolutionErrorWithNoChannels() throws Exception { + Status error = Status.NOT_FOUND.withDescription("nameResolutionError"); + loadBalancer.handleNameResolutionError(error); + verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + LoadBalancer.PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mockArgs); + assertNull(pickResult.getSubchannel()); + assertEquals(error, pickResult.getStatus()); + verifyNoMoreInteractions(mockHelper); + } + + @Test + public void nameResolutionErrorWithActiveChannels() throws Exception { + int choiceCount = 8; + final Subchannel readySubchannel = subchannels.values().iterator().next(); + loadBalancer.handleResolvedAddresses( + ResolvedAddresses.newBuilder() + .setLoadBalancingPolicyConfig(new LeastRequestConfig(choiceCount)) + .setAddresses(servers).setAttributes(affinity).build()); + deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); + loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("nameResolutionError")); + + verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + verify(mockHelper, times(2)) + .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + + Iterator stateIterator = stateCaptor.getAllValues().iterator(); + assertEquals(CONNECTING, stateIterator.next()); + assertEquals(READY, stateIterator.next()); + + LoadBalancer.PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mockArgs); + verify(mockRandom, times(choiceCount)).nextInt(1); + assertEquals(readySubchannel, pickResult.getSubchannel()); + assertEquals(Status.OK.getCode(), pickResult.getStatus().getCode()); + + LoadBalancer.PickResult pickResult2 = pickerCaptor.getValue().pickSubchannel(mockArgs); + verify(mockRandom, times(choiceCount * 2)).nextInt(1); + assertEquals(readySubchannel, pickResult2.getSubchannel()); + verifyNoMoreInteractions(mockHelper); + } + + @Test + public void subchannelStateIsolation() throws Exception { + Iterator subchannelIterator = subchannels.values().iterator(); + Subchannel sc1 = subchannelIterator.next(); + Subchannel sc2 = subchannelIterator.next(); + Subchannel sc3 = subchannelIterator.next(); + + loadBalancer.handleResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) + .build()); + verify(sc1, times(1)).requestConnection(); + verify(sc2, times(1)).requestConnection(); + verify(sc3, times(1)).requestConnection(); + + deliverSubchannelState(sc1, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(sc2, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(sc3, ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(sc2, ConnectivityStateInfo.forNonError(IDLE)); + deliverSubchannelState(sc3, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE)); + + verify(mockHelper, times(6)) + .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture()); + Iterator stateIterator = stateCaptor.getAllValues().iterator(); + Iterator pickers = pickerCaptor.getAllValues().iterator(); + // The picker is incrementally updated as subchannels become READY + assertEquals(CONNECTING, stateIterator.next()); + assertThat(pickers.next()).isInstanceOf(EmptyPicker.class); + assertEquals(READY, stateIterator.next()); + assertThat(getList(pickers.next())).containsExactly(sc1); + assertEquals(READY, stateIterator.next()); + assertThat(getList(pickers.next())).containsExactly(sc1, sc2); + assertEquals(READY, stateIterator.next()); + assertThat(getList(pickers.next())).containsExactly(sc1, sc2, sc3); + // The IDLE subchannel is dropped from the picker, but a reconnection is requested + assertEquals(READY, stateIterator.next()); + assertThat(getList(pickers.next())).containsExactly(sc1, sc3); + verify(sc2, times(2)).requestConnection(); + // The failing subchannel is dropped from the picker, with no requested reconnect + assertEquals(READY, stateIterator.next()); + assertThat(getList(pickers.next())).containsExactly(sc1); + verify(sc3, times(1)).requestConnection(); + assertThat(stateIterator.hasNext()).isFalse(); + assertThat(pickers.hasNext()).isFalse(); + } + + @Test(expected = IllegalArgumentException.class) + public void readyPicker_emptyList() { + // ready picker list must be non-empty + new ReadyPicker(Collections.emptyList(), 2, mockRandom); + } + + @Test + public void internalPickerComparisons() { + EmptyPicker emptyOk1 = new EmptyPicker(Status.OK); + EmptyPicker emptyOk2 = new EmptyPicker(Status.OK.withDescription("different OK")); + EmptyPicker emptyErr = new EmptyPicker(Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯")); + + Iterator subchannelIterator = subchannels.values().iterator(); + Subchannel sc1 = subchannelIterator.next(); + Subchannel sc2 = subchannelIterator.next(); + ReadyPicker ready1 = new ReadyPicker(Arrays.asList(sc1, sc2), 2, mockRandom); + ReadyPicker ready2 = new ReadyPicker(Arrays.asList(sc1), 2, mockRandom); + ReadyPicker ready3 = new ReadyPicker(Arrays.asList(sc2, sc1), 2, mockRandom); + ReadyPicker ready4 = new ReadyPicker(Arrays.asList(sc1, sc2), 2, mockRandom); + ReadyPicker ready5 = new ReadyPicker(Arrays.asList(sc2, sc1), 2, mockRandom); + ReadyPicker ready6 = new ReadyPicker(Arrays.asList(sc2, sc1), 8, mockRandom); + + assertTrue(emptyOk1.isEquivalentTo(emptyOk2)); + assertFalse(emptyOk1.isEquivalentTo(emptyErr)); + assertFalse(ready1.isEquivalentTo(ready2)); + assertTrue(ready1.isEquivalentTo(ready3)); + assertTrue(ready3.isEquivalentTo(ready4)); + assertTrue(ready4.isEquivalentTo(ready5)); + assertFalse(emptyOk1.isEquivalentTo(ready1)); + assertFalse(ready1.isEquivalentTo(emptyOk1)); + assertFalse(ready5.isEquivalentTo(ready6)); + } + + private static List getList(SubchannelPicker picker) { + return picker instanceof ReadyPicker ? ((ReadyPicker) picker).getList() : + Collections.emptyList(); + } + + private void deliverSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) { + subchannelStateListeners.get(subchannel).onSubchannelState(newState); + } + + private static class FakeSocketAddress extends SocketAddress { + final String name; + + FakeSocketAddress(String name) { + this.name = name; + } + + @Override + public String toString() { + return "FakeSocketAddress-" + name; + } + } +}