Skip to content

Commit

Permalink
xds: do not expose orca proto in ORCA api (#9366)
Browse files Browse the repository at this point in the history
The fix avoids shaded dependency of orca protos
  • Loading branch information
YifeiZhuang committed Jul 28, 2022
1 parent 289a442 commit 25183ed
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 36 deletions.
Expand Up @@ -19,11 +19,11 @@
import static io.grpc.testing.integration.AbstractInteropTest.ORCA_OOB_REPORT_KEY;
import static io.grpc.testing.integration.AbstractInteropTest.ORCA_RPC_REPORT_KEY;

import com.github.xds.data.orca.v3.OrcaLoadReport;
import io.grpc.ConnectivityState;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancerProvider;
import io.grpc.LoadBalancerRegistry;
import io.grpc.services.CallMetricRecorder;
import io.grpc.testing.integration.Messages.TestOrcaReport;
import io.grpc.util.ForwardingLoadBalancer;
import io.grpc.util.ForwardingLoadBalancerHelper;
Expand Down Expand Up @@ -87,8 +87,8 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) {
Subchannel subchannel = super.createSubchannel(args);
OrcaOobUtil.setListener(subchannel, new OrcaOobUtil.OrcaOobReportListener() {
@Override
public void onLoadReport(OrcaLoadReport orcaLoadReport) {
latestOobReport = fromOrcaLoadReport(orcaLoadReport);
public void onLoadReport(CallMetricRecorder.CallMetricReport orcaLoadReport) {
latestOobReport = fromCallMetricReport(orcaLoadReport);
}
},
OrcaOobUtil.OrcaReportingConfig.newBuilder()
Expand Down Expand Up @@ -133,22 +133,23 @@ public PickResult pickSubchannel(PickSubchannelArgs args) {
OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
new OrcaPerRequestUtil.OrcaPerRequestReportListener() {
@Override
public void onLoadReport(OrcaLoadReport orcaLoadReport) {
public void onLoadReport(CallMetricRecorder.CallMetricReport callMetricReport) {
AtomicReference<TestOrcaReport> reportRef =
args.getCallOptions().getOption(ORCA_RPC_REPORT_KEY);
reportRef.set(fromOrcaLoadReport(orcaLoadReport));
reportRef.set(fromCallMetricReport(callMetricReport));
}
}));
}
}
}

private static TestOrcaReport fromOrcaLoadReport(OrcaLoadReport orcaLoadReport) {
private static TestOrcaReport fromCallMetricReport(
CallMetricRecorder.CallMetricReport callMetricReport) {
return TestOrcaReport.newBuilder()
.setCpuUtilization(orcaLoadReport.getCpuUtilization())
.setMemoryUtilization(orcaLoadReport.getMemUtilization())
.putAllRequestCost(orcaLoadReport.getRequestCostMap())
.putAllUtilization(orcaLoadReport.getUtilizationMap())
.setCpuUtilization(callMetricReport.getCpuUtilization())
.setMemoryUtilization(callMetricReport.getMemoryUtilization())
.putAllRequestCost(callMetricReport.getRequestCostMetrics())
.putAllUtilization(callMetricReport.getUtilizationMetrics())
.build();
}
}
Expand Up @@ -53,7 +53,8 @@ public static final class CallMetricReport {
private Map<String, Double> utilizationMetrics;

/**
* Create a report for all backend metrics.
* A gRPC object of orca load report. LB policies listening at per-rpc or oob orca load reports
* will be notified of the metrics data in this data format.
*/
CallMetricReport(double cpuUtilization, double memoryUtilization,
Map<String, Double> requestCostMetrics,
Expand Down
Expand Up @@ -44,4 +44,12 @@ public static Map<String, Double> finalizeAndDump(CallMetricRecorder recorder) {
public static CallMetricRecorder.CallMetricReport finalizeAndDump2(CallMetricRecorder recorder) {
return recorder.finalizeAndDump2();
}

public static CallMetricRecorder.CallMetricReport createMetricReport(
double cpuUtilization, double memoryUtilization,
Map<String, Double> requestCostMetrics,
Map<String, Double> utilizationMetrics) {
return new CallMetricRecorder.CallMetricReport(cpuUtilization, memoryUtilization,
requestCostMetrics, utilizationMetrics);
}
}
9 changes: 6 additions & 3 deletions xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java
Expand Up @@ -51,6 +51,7 @@
import io.grpc.internal.BackoffPolicy;
import io.grpc.internal.ExponentialBackoffPolicy;
import io.grpc.internal.GrpcUtil;
import io.grpc.services.CallMetricRecorder;
import io.grpc.util.ForwardingLoadBalancerHelper;
import io.grpc.util.ForwardingSubchannel;
import java.util.HashMap;
Expand Down Expand Up @@ -168,9 +169,9 @@ public interface OrcaOobReportListener {
* <p>Note this callback will be invoked from the {@link SynchronizationContext} of the
* delegated helper, implementations should not block.
*
* @param report load report in the format of ORCA protocol.
* @param report load report in the format of grpc {@link CallMetricRecorder.CallMetricReport}.
*/
void onLoadReport(OrcaLoadReport report);
void onLoadReport(CallMetricRecorder.CallMetricReport report);
}

static final Attributes.Key<SubchannelImpl> ORCA_REPORTING_STATE_KEY =
Expand Down Expand Up @@ -450,8 +451,10 @@ void handleResponse(OrcaLoadReport response) {
callHasResponded = true;
backoffPolicy = null;
subchannelLogger.log(ChannelLogLevel.DEBUG, "Received an ORCA report: {0}", response);
CallMetricRecorder.CallMetricReport metricReport =
OrcaPerRequestUtil.fromOrcaLoadReport(response);
for (OrcaOobReportListener listener : configs.keySet()) {
listener.onLoadReport(response);
listener.onLoadReport(metricReport);
}
call.request(1);
}
Expand Down
17 changes: 13 additions & 4 deletions xds/src/main/java/io/grpc/xds/orca/OrcaPerRequestUtil.java
Expand Up @@ -28,6 +28,8 @@
import io.grpc.Metadata;
import io.grpc.internal.ForwardingClientStreamTracer;
import io.grpc.protobuf.ProtoUtils;
import io.grpc.services.CallMetricRecorder;
import io.grpc.services.InternalCallMetricRecorder;
import java.util.ArrayList;
import java.util.List;

Expand Down Expand Up @@ -175,14 +177,14 @@ public abstract ClientStreamTracer.Factory newOrcaClientStreamTracerFactory(
public interface OrcaPerRequestReportListener {

/**
* Invoked when an per-request ORCA report is received.
* Invoked when a per-request ORCA report is received.
*
* <p>Note this callback will be invoked from the network thread as the RPC finishes,
* implementations should not block.
*
* @param report load report in the format of ORCA format.
* @param report load report in the format of grpc {@link CallMetricRecorder.CallMetricReport}.
*/
void onLoadReport(OrcaLoadReport report);
void onLoadReport(CallMetricRecorder.CallMetricReport report);
}

/**
Expand Down Expand Up @@ -250,6 +252,12 @@ public void inboundTrailers(Metadata trailers) {
}
}

static CallMetricRecorder.CallMetricReport fromOrcaLoadReport(OrcaLoadReport loadReport) {
return InternalCallMetricRecorder.createMetricReport(loadReport.getCpuUtilization(),
loadReport.getMemUtilization(), loadReport.getRequestCostMap(),
loadReport.getUtilizationMap());
}

/**
* A container class to hold registered {@link OrcaPerRequestReportListener}s and invoke all of
* them when an {@link OrcaLoadReport} is received.
Expand All @@ -263,8 +271,9 @@ void addListener(OrcaPerRequestReportListener listener) {
}

void onReport(OrcaLoadReport report) {
CallMetricRecorder.CallMetricReport metricReport = fromOrcaLoadReport(report);
for (OrcaPerRequestReportListener listener : listeners) {
listener.onLoadReport(report);
listener.onLoadReport(metricReport);
}
}
}
Expand Down
29 changes: 20 additions & 9 deletions xds/src/test/java/io/grpc/xds/orca/OrcaOobUtilTest.java
Expand Up @@ -25,6 +25,7 @@
import static io.grpc.ConnectivityState.SHUTDOWN;
import static org.junit.Assert.fail;
import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.inOrder;
Expand Down Expand Up @@ -61,6 +62,7 @@
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.internal.BackoffPolicy;
import io.grpc.internal.FakeClock;
import io.grpc.services.CallMetricRecorder;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.GrpcCleanupRule;
import io.grpc.util.ForwardingLoadBalancerHelper;
Expand Down Expand Up @@ -285,7 +287,9 @@ public void singlePolicyTypicalWorkflow() {
OrcaLoadReport report = OrcaLoadReport.getDefaultInstance();
serverCall.responseObserver.onNext(report);
assertLog(subchannel.logs, "DEBUG: Received an ORCA report: " + report);
verify(mockOrcaListener0, times(i + 1)).onLoadReport(eq(report));
verify(mockOrcaListener0, times(i + 1)).onLoadReport(
argThat(new OrcaPerRequestUtilTest.MetricsReportMatcher(
OrcaPerRequestUtil.fromOrcaLoadReport(report))));
}

for (int i = 0; i < NUM_SUBCHANNELS; i++) {
Expand Down Expand Up @@ -369,7 +373,9 @@ public void twoLevelPoliciesTypicalWorkflow() {
OrcaLoadReport report = OrcaLoadReport.getDefaultInstance();
serverCall.responseObserver.onNext(report);
assertLog(subchannel.logs, "DEBUG: Received an ORCA report: " + report);
verify(mockOrcaListener1, times(i + 1)).onLoadReport(eq(report));
verify(mockOrcaListener1, times(i + 1)).onLoadReport(
argThat(new OrcaPerRequestUtilTest.MetricsReportMatcher(
OrcaPerRequestUtil.fromOrcaLoadReport(report))));
}

for (int i = 0; i < NUM_SUBCHANNELS; i++) {
Expand Down Expand Up @@ -425,8 +431,9 @@ public void orcReportingDisabledWhenServiceNotImplemented() {
OrcaLoadReport report = OrcaLoadReport.getDefaultInstance();
serverCall.responseObserver.onNext(report);
assertLog(subchannel.logs, "DEBUG: Received an ORCA report: " + report);
verify(mockOrcaListener0).onLoadReport(eq(report));

verify(mockOrcaListener0).onLoadReport(
argThat(new OrcaPerRequestUtilTest.MetricsReportMatcher(
OrcaPerRequestUtil.fromOrcaLoadReport(report))));
verifyNoInteractions(backoffPolicyProvider);
}

Expand Down Expand Up @@ -471,8 +478,9 @@ public void orcaReportingStreamClosedAndRetried() {
OrcaLoadReport report = OrcaLoadReport.getDefaultInstance();
orcaServiceImp.calls.peek().responseObserver.onNext(report);
assertLog(subchannel.logs, "DEBUG: Received an ORCA report: " + report);
inOrder.verify(mockOrcaListener0).onLoadReport(eq(report));

inOrder.verify(mockOrcaListener0).onLoadReport(
argThat(new OrcaPerRequestUtilTest.MetricsReportMatcher(
OrcaPerRequestUtil.fromOrcaLoadReport(report))));
// Server closes the ORCA reporting RPC after a response, will restart immediately.
orcaServiceImp.calls.poll().responseObserver.onCompleted();
assertThat(subchannel.logs).containsExactly(
Expand Down Expand Up @@ -659,17 +667,20 @@ public void policiesReceiveSameReportIndependently() {
orcaServiceImps[0].calls.peek().responseObserver.onNext(report);
assertLog(subchannels[0].logs, "DEBUG: Received an ORCA report: " + report);
// Only parent helper's listener receives the report.
ArgumentCaptor<OrcaLoadReport> parentReportCaptor = ArgumentCaptor.forClass(null);
ArgumentCaptor<CallMetricRecorder.CallMetricReport> parentReportCaptor =
ArgumentCaptor.forClass(null);
verify(mockOrcaListener1).onLoadReport(parentReportCaptor.capture());
assertThat(parentReportCaptor.getValue()).isEqualTo(report);
assertThat(OrcaPerRequestUtilTest.reportEqual(parentReportCaptor.getValue(),
OrcaPerRequestUtil.fromOrcaLoadReport(report))).isTrue();
verifyNoMoreInteractions(mockOrcaListener2);

// Now child helper also wants to receive reports.
OrcaOobUtil.setListener(childSubchannel, mockOrcaListener2, SHORT_INTERVAL_CONFIG);
orcaServiceImps[0].calls.peek().responseObserver.onNext(report);
assertLog(subchannels[0].logs, "DEBUG: Received an ORCA report: " + report);
// Both helper receives the same report instance.
ArgumentCaptor<OrcaLoadReport> childReportCaptor = ArgumentCaptor.forClass(null);
ArgumentCaptor<CallMetricRecorder.CallMetricReport> childReportCaptor =
ArgumentCaptor.forClass(null);
verify(mockOrcaListener1, times(2))
.onLoadReport(parentReportCaptor.capture());
verify(mockOrcaListener2)
Expand Down
49 changes: 40 additions & 9 deletions xds/src/test/java/io/grpc/xds/orca/OrcaPerRequestUtilTest.java
Expand Up @@ -19,7 +19,7 @@
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
Expand All @@ -28,15 +28,18 @@
import static org.mockito.Mockito.when;

import com.github.xds.data.orca.v3.OrcaLoadReport;
import com.google.common.base.Objects;
import io.grpc.ClientStreamTracer;
import io.grpc.Metadata;
import io.grpc.services.CallMetricRecorder;
import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener;
import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaReportingTracerFactory;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatcher;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

Expand Down Expand Up @@ -96,9 +99,33 @@ public void singlePolicyTypicalWorkflow() {
OrcaReportingTracerFactory.ORCA_ENDPOINT_LOAD_METRICS_KEY,
OrcaLoadReport.getDefaultInstance());
tracer.inboundTrailers(trailer);
ArgumentCaptor<OrcaLoadReport> reportCaptor = ArgumentCaptor.forClass(null);
ArgumentCaptor<CallMetricRecorder.CallMetricReport> reportCaptor =
ArgumentCaptor.forClass(null);
verify(orcaListener1).onLoadReport(reportCaptor.capture());
assertThat(reportCaptor.getValue()).isEqualTo(OrcaLoadReport.getDefaultInstance());
assertThat(reportEqual(reportCaptor.getValue(),
OrcaPerRequestUtil.fromOrcaLoadReport(OrcaLoadReport.getDefaultInstance()))).isTrue();
}

static final class MetricsReportMatcher implements
ArgumentMatcher<CallMetricRecorder.CallMetricReport> {
private CallMetricRecorder.CallMetricReport original;

public MetricsReportMatcher(CallMetricRecorder.CallMetricReport report) {
this.original = report;
}

@Override
public boolean matches(CallMetricRecorder.CallMetricReport argument) {
return reportEqual(original, argument);
}
}

static boolean reportEqual(CallMetricRecorder.CallMetricReport a,
CallMetricRecorder.CallMetricReport b) {
return a.getCpuUtilization() == b.getCpuUtilization()
&& a.getMemoryUtilization() == b.getMemoryUtilization()
&& Objects.equal(a.getRequestCostMetrics(), b.getRequestCostMetrics())
&& Objects.equal(a.getUtilizationMetrics(), b.getUtilizationMetrics());
}

/**
Expand Down Expand Up @@ -136,11 +163,14 @@ public void twoLevelPoliciesTypicalWorkflow() {
OrcaReportingTracerFactory.ORCA_ENDPOINT_LOAD_METRICS_KEY,
OrcaLoadReport.getDefaultInstance());
childTracer.inboundTrailers(trailer);
ArgumentCaptor<OrcaLoadReport> parentReportCap = ArgumentCaptor.forClass(null);
ArgumentCaptor<OrcaLoadReport> childReportCap = ArgumentCaptor.forClass(null);
ArgumentCaptor<CallMetricRecorder.CallMetricReport> parentReportCap =
ArgumentCaptor.forClass(null);
ArgumentCaptor<CallMetricRecorder.CallMetricReport> childReportCap =
ArgumentCaptor.forClass(null);
verify(orcaListener1).onLoadReport(parentReportCap.capture());
verify(orcaListener2).onLoadReport(childReportCap.capture());
assertThat(parentReportCap.getValue()).isEqualTo(OrcaLoadReport.getDefaultInstance());
assertThat(reportEqual(parentReportCap.getValue(),
OrcaPerRequestUtil.fromOrcaLoadReport(OrcaLoadReport.getDefaultInstance()))).isTrue();
assertThat(childReportCap.getValue()).isSameInstanceAs(parentReportCap.getValue());
}

Expand All @@ -159,11 +189,12 @@ public void onlyParentPolicyReceivesReportsIfCreatesOwnTracer() {
ClientStreamTracer parentTracer =
parentFactory.newClientStreamTracer(STREAM_INFO, new Metadata());
Metadata trailer = new Metadata();
OrcaLoadReport report = OrcaLoadReport.getDefaultInstance();
trailer.put(
OrcaReportingTracerFactory.ORCA_ENDPOINT_LOAD_METRICS_KEY,
OrcaLoadReport.getDefaultInstance());
OrcaReportingTracerFactory.ORCA_ENDPOINT_LOAD_METRICS_KEY, report);
parentTracer.inboundTrailers(trailer);
verify(orcaListener1).onLoadReport(eq(OrcaLoadReport.getDefaultInstance()));
verify(orcaListener1).onLoadReport(
argThat(new MetricsReportMatcher(OrcaPerRequestUtil.fromOrcaLoadReport(report))));
verifyNoInteractions(childFactory);
verifyNoInteractions(orcaListener2);
}
Expand Down

0 comments on commit 25183ed

Please sign in to comment.