Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core,xds: Metrics recording in WRR LB #11129

Merged
merged 11 commits into from
Apr 26, 2024
88 changes: 81 additions & 7 deletions xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,17 @@
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.Deadline.Ticker;
import io.grpc.DoubleHistogramMetricInstrument;
import io.grpc.EquivalentAddressGroup;
import io.grpc.ExperimentalApi;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancerProvider;
import io.grpc.LongCounterMetricInstrument;
import io.grpc.MetricInstrumentRegistry;
import io.grpc.NameResolver;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
Expand All @@ -57,12 +61,17 @@
import java.util.logging.Logger;

/**
* A {@link LoadBalancer} that provides weighted-round-robin load-balancing over
* the {@link EquivalentAddressGroup}s from the {@link NameResolver}. The subchannel weights are
* A {@link LoadBalancer} that provides weighted-round-robin load-balancing over the
* {@link EquivalentAddressGroup}s from the {@link NameResolver}. The subchannel weights are
* determined by backend metrics using ORCA.
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/9885")
final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {

private static final LongCounterMetricInstrument RR_FALLBACK_COUNTER;
private static final LongCounterMetricInstrument ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER;
private static final LongCounterMetricInstrument ENDPOINT_WEIGHT_STALE_COUNTER;
private static final DoubleHistogramMetricInstrument ENDPOINT_WEIGHTS_HISTOGRAM;
private static final Logger log = Logger.getLogger(
WeightedRoundRobinLoadBalancer.class.getName());
private WeightedRoundRobinLoadBalancerConfig config;
Expand All @@ -74,6 +83,31 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
private final long infTime;
private final Ticker ticker;

// Metric instruments are registered a single time, when this class is initialized, and the
// resulting instrument handles are shared by every instance of this load balancer.
static {
  MetricInstrumentRegistry registry = MetricInstrumentRegistry.getDefaultRegistry();

  RR_FALLBACK_COUNTER = registry.registerLongCounter(
      "grpc.lb.wrr.rr_fallback",
      "Number of scheduler updates in which there were not enough endpoints with valid "
          + "weight, which caused the WRR policy to fall back to RR behavior",
      "update",
      Lists.newArrayList("grpc.target"),
      Lists.newArrayList("grpc.lb.locality"),
      true);

  ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER = registry.registerLongCounter(
      "grpc.lb.wrr.endpoint_weight_not_yet_usable",
      "Number of endpoints from each scheduler update that don't yet have usable weight "
          + "information",
      "endpoint",
      Lists.newArrayList("grpc.target"),
      Lists.newArrayList("grpc.lb.locality"),
      true);

  ENDPOINT_WEIGHT_STALE_COUNTER = registry.registerLongCounter(
      "grpc.lb.wrr.endpoint_weight_stale",
      "Number of endpoints from each scheduler update whose latest weight is older than the "
          + "expiration period",
      "endpoint",
      Lists.newArrayList("grpc.target"),
      Lists.newArrayList("grpc.lb.locality"),
      true);

  // The histogram is registered with an empty bucket-boundary list.
  ENDPOINT_WEIGHTS_HISTOGRAM = registry.registerDoubleHistogram(
      "grpc.lb.wrr.endpoint_weights",
      "The histogram buckets will be endpoint weight ranges.",
      "weight",
      Lists.newArrayList(),
      Lists.newArrayList("grpc.target"),
      Lists.newArrayList("grpc.lb.locality"),
      true);
}

/**
 * Creates a balancer by delegating to the three-argument constructor. The supplied
 * {@code helper} is first passed through {@code OrcaOobUtil.newOrcaReportingHelper} and then
 * wrapped in a {@code WrrHelper}; a newly created {@link Random} is supplied as the third
 * argument.
 *
 * @param helper the helper to wrap
 * @param ticker the ticker forwarded unchanged to the delegate constructor
 */
public WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker) {
this(new WrrHelper(OrcaOobUtil.newOrcaReportingHelper(helper)), ticker, new Random());
}
Expand Down Expand Up @@ -145,7 +179,7 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
/**
 * Builds the picker used when there are active children: a {@code WeightedRoundRobinPicker}
 * over an immutable copy of {@code activeList}, configured from the current {@code config}
 * (out-of-band load report flag and error utilization penalty) and sharing this balancer's
 * {@code sequence}.
 */
@Override
public SubchannelPicker createReadyPicker(Collection<ChildLbState> activeList) {
return new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList),
config.enableOobLoadReport, config.errorUtilizationPenalty, sequence);
config.enableOobLoadReport, config.errorUtilizationPenalty, sequence, getHelper());
}

@VisibleForTesting
Expand All @@ -163,16 +197,18 @@ public WeightedChildLbState(Object key, LoadBalancerProvider policyProvider, Obj
super(key, policyProvider, childConfig, initialPicker);
}

private double getWeight() {
private double getWeight(AtomicInteger staleEndpoints, AtomicInteger notYetUsableEndpoints) {
if (config == null) {
return 0;
}
long now = ticker.nanoTime();
if (now - lastUpdated >= config.weightExpirationPeriodNanos) {
nonEmptySince = infTime;
staleEndpoints.incrementAndGet();
return 0;
} else if (now - nonEmptySince < config.blackoutPeriodNanos
&& config.blackoutPeriodNanos > 0) {
notYetUsableEndpoints.incrementAndGet();
return 0;
} else {
return weight;
Expand Down Expand Up @@ -336,10 +372,11 @@ static final class WeightedRoundRobinPicker extends SubchannelPicker {
private final float errorUtilizationPenalty;
private final AtomicInteger sequence;
private final int hashCode;
private final LoadBalancer.Helper helper;
private volatile StaticStrideScheduler scheduler;

WeightedRoundRobinPicker(List<ChildLbState> children, boolean enableOobLoadReport,
float errorUtilizationPenalty, AtomicInteger sequence) {
float errorUtilizationPenalty, AtomicInteger sequence, LoadBalancer.Helper helper) {
checkNotNull(children, "children");
Preconditions.checkArgument(!children.isEmpty(), "empty child list");
this.children = children;
Expand All @@ -353,6 +390,7 @@ static final class WeightedRoundRobinPicker extends SubchannelPicker {
this.enableOobLoadReport = enableOobLoadReport;
this.errorUtilizationPenalty = errorUtilizationPenalty;
this.sequence = checkNotNull(sequence, "sequence");
this.helper = helper;

// For equality we treat children as a set; use hash code as defined by Set
int sum = 0;
Expand Down Expand Up @@ -387,11 +425,37 @@ public PickResult pickSubchannel(PickSubchannelArgs args) {

/**
 * Recomputes the per-child weights, publishes them to a fresh {@code StaticStrideScheduler},
 * and reports metrics about the update through the helper's metric recorder.
 *
 * <p>For every child the current weight is recorded in the endpoint-weights histogram. The
 * numbers of stale and not-yet-usable endpoints are accumulated while the weights are read and
 * are each reported once per update, and only when non-zero. If the new scheduler reports that
 * it uses round robin, the rr_fallback counter is incremented by one.
 */
private void updateWeight() {
float[] newWeights = new float[children.size()];
// Mutable tallies that getWeight() increments as a side effect of reading each child's weight.
AtomicInteger staleEndpoints = new AtomicInteger();
AtomicInteger notYetUsableEndpoints = new AtomicInteger();
for (int i = 0; i < children.size(); i++) {
double newWeight = ((WeightedChildLbState)children.get(i)).getWeight();
double newWeight = ((WeightedChildLbState) children.get(i)).getWeight(staleEndpoints,
notYetUsableEndpoints);
// TODO: add target and locality labels once available
helper.getMetricRecorder()
.recordDoubleHistogram(ENDPOINT_WEIGHTS_HISTOGRAM, newWeight, ImmutableList.of(""),
ImmutableList.of(""));
// Non-positive weights are handed to the scheduler as exactly 0.
newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f;
}
if (staleEndpoints.get() > 0) {
// TODO: add target and locality labels once available
helper.getMetricRecorder()
.addLongCounter(ENDPOINT_WEIGHT_STALE_COUNTER, staleEndpoints.get(),
ImmutableList.of(""),
ImmutableList.of(""));
}
if (notYetUsableEndpoints.get() > 0) {
// TODO: add target and locality labels once available
helper.getMetricRecorder()
.addLongCounter(ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER, notYetUsableEndpoints.get(),
ImmutableList.of(""), ImmutableList.of(""));
}

// Replace the scheduler stored in the volatile field with one built from the new weights.
this.scheduler = new StaticStrideScheduler(newWeights, sequence);
if (this.scheduler.usesRoundRobin()) {
// TODO: add target and locality labels once available
helper.getMetricRecorder()
.addLongCounter(RR_FALLBACK_COUNTER, 1, ImmutableList.of(""), ImmutableList.of(""));
}
}

@Override
Expand Down Expand Up @@ -454,6 +518,7 @@ public boolean equals(Object o) {
static final class StaticStrideScheduler {
private final short[] scaledWeights;
private final AtomicInteger sequence;
private final boolean usesRoundRobin;
private static final int K_MAX_WEIGHT = 0xFFFF;

// Assuming the mean of all known weights is M, StaticStrideScheduler will clamp
Expand Down Expand Up @@ -494,8 +559,10 @@ static final class StaticStrideScheduler {
if (numWeightedChannels > 0) {
unscaledMeanWeight = sumWeight / numWeightedChannels;
unscaledMaxWeight = Math.min(unscaledMaxWeight, (float) (K_MAX_RATIO * unscaledMeanWeight));
usesRoundRobin = false;
} else {
// Fall back to round robin if all values are non-positives
usesRoundRobin = true;
unscaledMeanWeight = 1;
unscaledMaxWeight = 1;
}
Expand All @@ -521,7 +588,14 @@ static final class StaticStrideScheduler {
this.sequence = sequence;
}

/** Returns the next sequence number and atomically increases sequence with wraparound. */
// Without properly weighted channels, we do plain vanilla round_robin. Returns the flag that was
// fixed when this scheduler was constructed: true if it fell back to round robin, false if
// weighted channels were available.
boolean usesRoundRobin() {
return usesRoundRobin;
}

/**
 * Atomically advances the shared sequence counter (which wraps around on overflow) and returns
 * the value it held before the increment, widened to a {@code long} as an unsigned 32-bit
 * number.
 */
private long nextSequence() {
  int current = sequence.getAndIncrement();
  return Integer.toUnsignedLong(current);
}
Expand Down
137 changes: 137 additions & 0 deletions xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.ConnectivityState.CONNECTING;
import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
Expand All @@ -40,6 +41,7 @@
import io.grpc.ClientCall;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.DoubleHistogramMetricInstrument;
import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancer.CreateSubchannelArgs;
Expand All @@ -49,6 +51,8 @@
import io.grpc.LoadBalancer.Subchannel;
import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.LoadBalancer.SubchannelStateListener;
import io.grpc.LongCounterMetricInstrument;
import io.grpc.MetricRecorder;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
import io.grpc.internal.FakeClock;
Expand Down Expand Up @@ -82,6 +86,7 @@
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher;
import org.mockito.Captor;
import org.mockito.InOrder;
import org.mockito.Mock;
Expand Down Expand Up @@ -120,6 +125,9 @@ public class WeightedRoundRobinLoadBalancerTest {

private final FakeClock fakeClock = new FakeClock();

@Mock
private MetricRecorder mockMetricRecorder;

private WeightedRoundRobinLoadBalancerConfig weightedConfig =
WeightedRoundRobinLoadBalancerConfig.newBuilder().build();

Expand Down Expand Up @@ -1121,6 +1129,130 @@ public void removingAddressShutsdownSubchannel() {
inOrder.verify(subchannel2).shutdown();
}


// Walks the balancer through three phases (no weights yet, valid weights after the blackout
// period, and stale weights after expiration) and verifies the rr_fallback, endpoint_weights,
// endpoint_weight_stale and endpoint_weight_not_yet_usable metrics recorded in each phase.
@Test
public void metrics() {
// Give WRR some valid addresses to work with.
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));

// Flip the three subchannels to READY state to initiate the WRR logic
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel3 = it.next();
getSubchannelStateListener(readySubchannel3).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));

// WRR creates a picker that updates the weights for each of the child subchannels. This should
// give us three "rr_fallback" metric events as we don't yet have any weights to do weighted
// round-robin.
verifyLongCounterRecord("grpc.lb.wrr.rr_fallback", 3, 1);

// We should also see six records of endpoint weights. They should all be for 0 as we don't yet
// have valid weights.
verifyDoubleHistogramRecord("grpc.lb.wrr.endpoint_weights", 6, 0);

// We should not yet be seeing any "endpoint_weight_stale" events since we don't even have
// valid weights yet.
verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_stale", 0, 1);

// Each time weights are updated, WRR will see if each subchannel weight is usable. As we have
// no weights yet, we should see three "endpoint_weight_not_yet_usable" metric events with the
// value increasing by one each time as all the endpoints come online.
verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_not_yet_usable", 1, 1);
verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_not_yet_usable", 1, 2);
verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_not_yet_usable", 1, 3);

// Send each child LB state an ORCA update with some valid utilization/qps data so that weights
// can be calculated.
Iterator<ChildLbState> childLbStates = wrr.getChildLbStates().iterator();
((WeightedChildLbState)childLbStates.next()).new OrcaReportListener(
weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, new HashMap<>(),
new HashMap<>(), new HashMap<>()));
((WeightedChildLbState)childLbStates.next()).new OrcaReportListener(
weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, new HashMap<>(),
new HashMap<>(), new HashMap<>()));
((WeightedChildLbState)childLbStates.next()).new OrcaReportListener(
weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(0.1, 0, 0.1, 1, 0, new HashMap<>(),
new HashMap<>(), new HashMap<>()));

// Let's reset the mock MetricRecorder so that it's easier to verify what happened after the
// weights were updated
reset(mockMetricRecorder);

// We go forward in time past the default 10s blackout period before weights can be considered
// for wrr. The weights would get updated as the default update interval is 1s.
fakeClock.forwardTime(11, TimeUnit.SECONDS);

// Since we have weights on all the child LB states, the weight update should not result in
// further rr_fallback metric entries.
verifyLongCounterRecord("grpc.lb.wrr.rr_fallback", 0, 1);

// We should not see an increase to the earlier count of "endpoint_weight_not_yet_usable".
verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_not_yet_usable", 0, 1);

// No endpoints should have gotten stale yet either.
verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_stale", 0, 1);

// Now with valid weights we should have seen the value in the endpoint weights histogram.
verifyDoubleHistogramRecord("grpc.lb.wrr.endpoint_weights", 3, 10);

reset(mockMetricRecorder);

// Weights become stale in three minutes. Let's move ahead in time by 3 minutes and make sure
// we get metrics events for each endpoint.
fakeClock.forwardTime(3, TimeUnit.MINUTES);

verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_stale", 1, 3);

// With the weights stale, each of the three endpoints should report a weight of 0.
verifyDoubleHistogramRecord("grpc.lb.wrr.endpoint_weights", 3, 0);

// Since the weights are now stale the update should have triggered an additional rr_fallback
// event.
verifyLongCounterRecord("grpc.lb.wrr.rr_fallback", 1, 1);

// No further weight-not-yet-usable events should occur, since we have received weights and
// are out of the blackout.
verifyLongCounterRecord("grpc.lb.wrr.endpoint_weight_not_yet_usable", 0, 1);

// All metric events should be accounted for.
verifyNoMoreInteractions(mockMetricRecorder);
}

// Verifies that the MetricRecorder was asked, the given number of times, to add the given value
// to the long counter instrument with the given metric name, using a single empty-string label
// for both the required and the optional label values.
private void verifyLongCounterRecord(String name, int times, long value) {
  ArgumentMatcher<LongCounterMetricInstrument> hasExpectedName =
      instrument -> instrument.getName().equals(name);
  verify(mockMetricRecorder, times(times))
      .addLongCounter(
          argThat(hasExpectedName),
          eq(value),
          eq(Lists.newArrayList("")),
          eq(Lists.newArrayList("")));
}

// Verifies that the MetricRecorder was asked, the given number of times, to record the given
// value in the double histogram instrument with the given metric name, using a single
// empty-string label for both the required and the optional label values.
private void verifyDoubleHistogramRecord(String name, int times, double value) {
  ArgumentMatcher<DoubleHistogramMetricInstrument> hasExpectedName =
      instrument -> instrument.getName().equals(name);
  verify(mockMetricRecorder, times(times))
      .recordDoubleHistogram(
          argThat(hasExpectedName),
          eq(value),
          eq(Lists.newArrayList("")),
          eq(Lists.newArrayList("")));
}

// Convenience wrapper: returns AbstractTestHelper.getNumFilteredPendingTasks for this test's
// fake clock.
private int getNumFilteredPendingTasks() {
return AbstractTestHelper.getNumFilteredPendingTasks(fakeClock);
}
Expand Down Expand Up @@ -1189,5 +1321,10 @@ public Map<Subchannel, Subchannel> getMockToRealSubChannelMap() {
public Map<Subchannel, SubchannelStateListener> getSubchannelStateListeners() {
return subchannelStateListeners;
}

// Hands out the test's mock MetricRecorder so that metric calls made through this helper can be
// verified by the test.
@Override
public MetricRecorder getMetricRecorder() {
return mockMetricRecorder;
}
}
}