diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java index c3f60623a95..1f9d6908126 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java @@ -24,13 +24,17 @@ import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.Deadline.Ticker; +import io.grpc.DoubleHistogramMetricInstrument; import io.grpc.EquivalentAddressGroup; import io.grpc.ExperimentalApi; import io.grpc.LoadBalancer; import io.grpc.LoadBalancerProvider; +import io.grpc.LongCounterMetricInstrument; +import io.grpc.MetricInstrumentRegistry; import io.grpc.NameResolver; import io.grpc.Status; import io.grpc.SynchronizationContext; @@ -57,12 +61,17 @@ import java.util.logging.Logger; /** - * A {@link LoadBalancer} that provides weighted-round-robin load-balancing over - * the {@link EquivalentAddressGroup}s from the {@link NameResolver}. The subchannel weights are + * A {@link LoadBalancer} that provides weighted-round-robin load-balancing over the + * {@link EquivalentAddressGroup}s from the {@link NameResolver}. The subchannel weights are * determined by backend metrics using ORCA. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/9885") final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { + + private static final LongCounterMetricInstrument RR_FALLBACK_COUNTER; + private static final LongCounterMetricInstrument ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER; + private static final LongCounterMetricInstrument ENDPOINT_WEIGHT_STALE_COUNTER; + private static final DoubleHistogramMetricInstrument ENDPOINT_WEIGHTS_HISTOGRAM; private static final Logger log = Logger.getLogger( WeightedRoundRobinLoadBalancer.class.getName()); private WeightedRoundRobinLoadBalancerConfig config; @@ -74,6 +83,31 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { private final long infTime; private final Ticker ticker; + // The metric instruments are only registered once and shared by all instances of this LB. + static { + MetricInstrumentRegistry metricInstrumentRegistry + = MetricInstrumentRegistry.getDefaultRegistry(); + RR_FALLBACK_COUNTER = metricInstrumentRegistry.registerLongCounter("grpc.lb.wrr.rr_fallback", + "Number of scheduler updates in which there were not enough endpoints with valid " + + "weight, which caused the WRR policy to fall back to RR behavior", "update", + Lists.newArrayList("grpc.target"), Lists.newArrayList("grpc.lb.locality"), true); + ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER = metricInstrumentRegistry.registerLongCounter( + "grpc.lb.wrr.endpoint_weight_not_yet_usable", + "Number of endpoints from each scheduler update that don't yet have usable weight " + + "information", "endpoint", Lists.newArrayList("grpc.target"), + Lists.newArrayList("grpc.lb.locality"), true); + ENDPOINT_WEIGHT_STALE_COUNTER = metricInstrumentRegistry.registerLongCounter( + "grpc.lb.wrr.endpoint_weight_stale", + "Number of endpoints from each scheduler update whose latest weight is older than the " + + "expiration period", "endpoint", Lists.newArrayList("grpc.target"), + Lists.newArrayList("grpc.lb.locality"), true); + ENDPOINT_WEIGHTS_HISTOGRAM = metricInstrumentRegistry.registerDoubleHistogram( + "grpc.lb.wrr.endpoint_weights", "The histogram buckets will be endpoint weight ranges.", + "weight", Lists.newArrayList(), Lists.newArrayList("grpc.target"), + Lists.newArrayList("grpc.lb.locality"), + true); + } + public WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker) { this(new WrrHelper(OrcaOobUtil.newOrcaReportingHelper(helper)), ticker, new Random()); } @@ -145,7 +179,7 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { @Override public SubchannelPicker createReadyPicker(Collection activeList) { return new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList), - config.enableOobLoadReport, config.errorUtilizationPenalty, sequence); + config.enableOobLoadReport, config.errorUtilizationPenalty, sequence, getHelper()); } @VisibleForTesting @@ -163,16 +197,18 @@ public WeightedChildLbState(Object key, LoadBalancerProvider policyProvider, Obj super(key, policyProvider, childConfig, initialPicker); } - private double getWeight() { + private double getWeight(AtomicInteger staleEndpoints, AtomicInteger notYetUsableEndpoints) { if (config == null) { return 0; } long now = ticker.nanoTime(); if (now - lastUpdated >= config.weightExpirationPeriodNanos) { nonEmptySince = infTime; + staleEndpoints.incrementAndGet(); return 0; } else if (now - nonEmptySince < config.blackoutPeriodNanos && config.blackoutPeriodNanos > 0) { + notYetUsableEndpoints.incrementAndGet(); return 0; } else { return weight; @@ -336,10 +372,11 @@ static final class WeightedRoundRobinPicker extends SubchannelPicker { private final float errorUtilizationPenalty; private final AtomicInteger sequence; private final int hashCode; + private final LoadBalancer.Helper helper; private volatile StaticStrideScheduler scheduler; WeightedRoundRobinPicker(List children, boolean enableOobLoadReport, - float errorUtilizationPenalty, AtomicInteger sequence) { + float errorUtilizationPenalty, AtomicInteger sequence, LoadBalancer.Helper helper) { checkNotNull(children, "children"); Preconditions.checkArgument(!children.isEmpty(), "empty child list"); this.children = children; @@ -353,6 +390,7 @@ static final class WeightedRoundRobinPicker extends SubchannelPicker { this.enableOobLoadReport = enableOobLoadReport; this.errorUtilizationPenalty = errorUtilizationPenalty; this.sequence = checkNotNull(sequence, "sequence"); + this.helper = helper; // For equality we treat children as a set; use hash code as defined by Set int sum = 0; @@ -387,11 +425,37 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { private void updateWeight() { float[] newWeights = new float[children.size()]; + AtomicInteger staleEndpoints = new AtomicInteger(); + AtomicInteger notYetUsableEndpoints = new AtomicInteger(); for (int i = 0; i < children.size(); i++) { - double newWeight = ((WeightedChildLbState)children.get(i)).getWeight(); + double newWeight = ((WeightedChildLbState) children.get(i)).getWeight(staleEndpoints, + notYetUsableEndpoints); + // TODO: add target and locality labels once available + helper.getMetricRecorder() + .recordDoubleHistogram(ENDPOINT_WEIGHTS_HISTOGRAM, newWeight, ImmutableList.of(""), + ImmutableList.of("")); newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f; } + if (staleEndpoints.get() > 0) { + // TODO: add target and locality labels once available + helper.getMetricRecorder() + .addLongCounter(ENDPOINT_WEIGHT_STALE_COUNTER, staleEndpoints.get(), + ImmutableList.of(""), + ImmutableList.of("")); + } + if (notYetUsableEndpoints.get() > 0) { + // TODO: add target and locality labels once available + helper.getMetricRecorder() + .addLongCounter(ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER, notYetUsableEndpoints.get(), + ImmutableList.of(""), ImmutableList.of("")); + } + this.scheduler = new StaticStrideScheduler(newWeights, sequence); + if (this.scheduler.usesRoundRobin()) { + // TODO: add target and locality labels once available + helper.getMetricRecorder() + .addLongCounter(RR_FALLBACK_COUNTER, 1, ImmutableList.of(""), ImmutableList.of("")); + } } @Override @@ -454,6 +518,7 @@ public boolean equals(Object o) { static final class StaticStrideScheduler { private final short[] scaledWeights; private final AtomicInteger sequence; + private final boolean usesRoundRobin; private static final int K_MAX_WEIGHT = 0xFFFF; // Assuming the mean of all known weights is M, StaticStrideScheduler will clamp @@ -494,8 +559,10 @@ static final class StaticStrideScheduler { if (numWeightedChannels > 0) { unscaledMeanWeight = sumWeight / numWeightedChannels; unscaledMaxWeight = Math.min(unscaledMaxWeight, (float) (K_MAX_RATIO * unscaledMeanWeight)); + usesRoundRobin = false; } else { // Fall back to round robin if all values are non-positives + usesRoundRobin = true; unscaledMeanWeight = 1; unscaledMaxWeight = 1; } @@ -521,7 +588,14 @@ static final class StaticStrideScheduler { this.sequence = sequence; } - /** Returns the next sequence number and atomically increases sequence with wraparound. */ + // Without properly weighted channels, we do plain vanilla round_robin. + boolean usesRoundRobin() { + return usesRoundRobin; + } + + /** + * Returns the next sequence number and atomically increases sequence with wraparound. + */ private long nextSequence() { return Integer.toUnsignedLong(sequence.getAndIncrement()); } diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java index 63077ddcf69..e43de19b517 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java @@ -19,6 +19,7 @@ import static com.google.common.truth.Truth.assertThat; import static io.grpc.ConnectivityState.CONNECTING; import static org.mockito.AdditionalAnswers.delegatesTo; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.any; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; @@ -40,6 +41,7 @@ import io.grpc.ClientCall; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; +import io.grpc.DoubleHistogramMetricInstrument; import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.CreateSubchannelArgs; @@ -49,6 +51,8 @@ import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelStateListener; +import io.grpc.LongCounterMetricInstrument; +import io.grpc.MetricRecorder; import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.internal.FakeClock; @@ -82,6 +86,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatcher; import org.mockito.Captor; import org.mockito.InOrder; import org.mockito.Mock; @@ -120,6 +125,9 @@ public class WeightedRoundRobinLoadBalancerTest { private final FakeClock fakeClock = new FakeClock(); + @Mock + private MetricRecorder mockMetricRecorder; + private WeightedRoundRobinLoadBalancerConfig weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder().build(); @@ -1121,6 +1129,130 @@ public void removingAddressShutsdownSubchannel() { inOrder.verify(subchannel2).shutdown(); } + + @Test + public void metrics() { + // Give WRR some valid addresses to work with. + syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) + .setAttributes(affinity).build())); + + // Flip the three subchannels to READY state to initiate the WRR logic + Iterator it = subchannels.values().iterator(); + Subchannel readySubchannel1 = it.next(); + getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + Subchannel readySubchannel2 = it.next(); + getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + Subchannel readySubchannel3 = it.next(); + getSubchannelStateListener(readySubchannel3).onSubchannelState(ConnectivityStateInfo + .forNonError(ConnectivityState.READY)); + + // WRR creates a picker that updates the weights for each of the child subchannels. This should + // give us three "rr_fallback" metric events as we don't yet have any weights to do weighted + // round-robin. + verifyLongCounterRecord("grpc.lb.wrr.rr_fallback", 3, 1); + + // We should also see six records of endpoint weights. They should all be for 0 as we don't yet + // have valid weights. + verifyDoubleHistogramRecord("grpc.lb.wrr.endpoint_weights", 6, 0); + + // We should not yet be seeing any "endpoint_weight_stale" events since we don't even have + // valid weights yet. + verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_stale", 0, 1); + + // Each time weights are updated, WRR will see if each subchannel weight is useable. As we have + // no weights yet, we should see three "endpoint_weight_not_yet_usable" metric events with the + // value increasing by one each time as all the endpoints come online. + verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_not_yet_usable", 1, 1); + verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_not_yet_usable", 1, 2); + verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_not_yet_usable", 1, 3); + + // Send each child LB state an ORCA update with some valid utilization/qps data so that weights + // can be calculated. + Iterator childLbStates = wrr.getChildLbStates().iterator(); + ((WeightedChildLbState)childLbStates.next()).new OrcaReportListener( + weightedConfig.errorUtilizationPenalty).onLoadReport( + InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, new HashMap<>(), + new HashMap<>(), new HashMap<>())); + ((WeightedChildLbState)childLbStates.next()).new OrcaReportListener( + weightedConfig.errorUtilizationPenalty).onLoadReport( + InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, new HashMap<>(), + new HashMap<>(), new HashMap<>())); + ((WeightedChildLbState)childLbStates.next()).new OrcaReportListener( + weightedConfig.errorUtilizationPenalty).onLoadReport( + InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, new HashMap<>(), + new HashMap<>(), new HashMap<>())); + + // Let's reset the mock MetricsRecorder so that it's easier to verify what happened after the + // weights were updated + reset(mockMetricRecorder); + + // We go forward in time past the default 10s blackout period before weights can be considered + // for wrr. The eights would get updated as the default update interval is 1s. + fakeClock.forwardTime(11, TimeUnit.SECONDS); + + // Since we have weights on all the child LB states, the weight update should not result in + // further rr_fallback metric entries. + verifyLongCounterRecord("grpc.lb.wrr.rr_fallback", 0, 1); + + // We should not see an increase to the earlier count of "endpoint_weight_not_yet_usable". + verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_not_yet_usable", 0, 1); + + // No endpoints should have gotten stale yet either. + verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_stale", 0, 1); + + // Now with valid weights we should have seen the value in the endpoint weights histogram. + verifyDoubleHistogramRecord("grpc.lb.wrr.endpoint_weights", 3, 10); + + reset(mockMetricRecorder); + + // Weights become stale in three minutes. Let's move ahead in time by 3 minutes and make sure + // we get metrics events for each endpoint. + fakeClock.forwardTime(3, TimeUnit.MINUTES); + + verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_stale", 1, 3); + + // With the weights stale each three endpoints should report 0 weights. + verifyDoubleHistogramRecord("grpc.lb.wrr.endpoint_weights", 3, 0); + + // Since the weights are now stale the update should have triggered an additional rr_fallback + // event. + verifyLongCounterRecord("grpc.lb.wrr.rr_fallback", 1, 1); + + // No further weights-not-useable events should occur, since we have received weights and + // are out of the blackout. + verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_not_yet_usable", 0, 1); + + // All metric events should be accounted for. + verifyNoMoreInteractions(mockMetricRecorder); + } + + // Verifies that the MetricRecorder has been called to record a long counter value of 1 for the + // given metric name, the given number of times + private void verifyLongCounterRecord(String name, int times, long value) { + verify(mockMetricRecorder, times(times)).addLongCounter( + argThat(new ArgumentMatcher() { + @Override + public boolean matches(LongCounterMetricInstrument longCounterInstrument) { + return longCounterInstrument.getName().equals(name); + } + }), eq(value), eq(Lists.newArrayList("")), eq(Lists.newArrayList(""))); + } + + // Verifies that the MetricRecorder has been called to record a given double histogram value the + // given amount of times. + private void verifyDoubleHistogramRecord(String name, int times, double value) { + verify(mockMetricRecorder, times(times)).recordDoubleHistogram( + argThat(new ArgumentMatcher() { + @Override + public boolean matches(DoubleHistogramMetricInstrument doubleHistogramInstrument) { + return doubleHistogramInstrument.getName().equals(name); + } + }), eq(value), eq(Lists.newArrayList("")), eq(Lists.newArrayList(""))); + } + private int getNumFilteredPendingTasks() { return AbstractTestHelper.getNumFilteredPendingTasks(fakeClock); } @@ -1189,5 +1321,10 @@ public Map getMockToRealSubChannelMap() { public Map getSubchannelStateListeners() { return subchannelStateListeners; } + + @Override + public MetricRecorder getMetricRecorder() { + return mockMetricRecorder; + } } }