From 8422cc14bb2642c2e4b7d5b5fda115f05c1bc6e9 Mon Sep 17 00:00:00 2001 From: Terry Wilson Date: Mon, 23 May 2022 15:53:13 -0700 Subject: [PATCH] xds: Use weighted_target LB provider in wrr_locality Fixes a bug where WrrLocalityLoadBalancer would use the endpoint picking policy provider instead of WeightedTargetLoadBalancerProvider. Also adds a test to fake control plane integration test that caught this bug. The test scaffolding is also updated to have the test server echo all client headers back in the response. The test load balancer in the test is an almost straight copy of: https://github.com/grpc/grpc-java/blob/master/interop-testing/src/test/java/io/grpc/testing/integration/RpcBehaviorLoadBalancerProviderTest.java --- .../io/grpc/xds/WrrLocalityLoadBalancer.java | 10 +- .../FakeControlPlaneXdsIntegrationTest.java | 129 +++++++++++++ .../xds/MetadataLoadBalancerProvider.java | 173 ++++++++++++++++++ .../grpc/xds/WrrLocalityLoadBalancerTest.java | 60 ++++-- 4 files changed, 355 insertions(+), 17 deletions(-) create mode 100644 xds/src/test/java/io/grpc/xds/MetadataLoadBalancerProvider.java diff --git a/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java index 61c6b103d81..f320bc340ee 100644 --- a/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java @@ -18,10 +18,12 @@ import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.xds.XdsLbPolicies.WEIGHTED_TARGET_POLICY_NAME; import com.google.common.base.MoreObjects; import io.grpc.InternalLogId; import io.grpc.LoadBalancer; +import io.grpc.LoadBalancerRegistry; import io.grpc.Status; import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.util.GracefulSwitchLoadBalancer; @@ -43,9 +45,15 @@ final class WrrLocalityLoadBalancer extends LoadBalancer { private final XdsLogger logger; private final Helper helper; private final GracefulSwitchLoadBalancer switchLb; + private final LoadBalancerRegistry lbRegistry; WrrLocalityLoadBalancer(Helper helper) { + this(helper, LoadBalancerRegistry.getDefaultRegistry()); + } + + WrrLocalityLoadBalancer(Helper helper, LoadBalancerRegistry lbRegistry) { this.helper = checkNotNull(helper, "helper"); + this.lbRegistry = lbRegistry; switchLb = new GracefulSwitchLoadBalancer(helper); logger = XdsLogger.withLogId( InternalLogId.allocate("xds-wrr-locality-lb", helper.getAuthority())); @@ -88,7 +96,7 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { .setAttributes(resolvedAddresses.getAttributes().toBuilder() .discard(InternalXdsAttributes.ATTR_LOCALITY_WEIGHTS).build()).build(); - switchLb.switchTo(wrrLocalityConfig.childPolicy.getProvider()); + switchLb.switchTo(lbRegistry.getProvider(WEIGHTED_TARGET_POLICY_NAME)); switchLb.handleResolvedAddresses( resolvedAddresses.toBuilder() .setLoadBalancingPolicyConfig(new WeightedTargetConfig(weightedPolicySelections)) diff --git a/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java b/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java index 61064345344..02dcfd21635 100644 --- a/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java +++ b/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java @@ -17,23 +17,30 @@ package io.grpc.xds; +import static com.google.common.truth.Truth.assertThat; import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_CDS; import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_EDS; import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_LDS; import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_RDS; import static org.junit.Assert.assertEquals; +import com.github.xds.type.v3.TypedStruct; import com.google.common.collect.ImmutableMap; import com.google.protobuf.Any; import com.google.protobuf.Message; +import com.google.protobuf.Struct; import com.google.protobuf.UInt32Value; +import com.google.protobuf.Value; import io.envoyproxy.envoy.config.cluster.v3.Cluster; +import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy; +import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy.Policy; import io.envoyproxy.envoy.config.core.v3.Address; import io.envoyproxy.envoy.config.core.v3.AggregatedConfigSource; import io.envoyproxy.envoy.config.core.v3.ConfigSource; import io.envoyproxy.envoy.config.core.v3.HealthStatus; import io.envoyproxy.envoy.config.core.v3.SocketAddress; import io.envoyproxy.envoy.config.core.v3.TrafficDirection; +import io.envoyproxy.envoy.config.core.v3.TypedExtensionConfig; import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; import io.envoyproxy.envoy.config.endpoint.v3.Endpoint; import io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint; @@ -53,12 +60,27 @@ import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.Rds; +import io.envoyproxy.envoy.extensions.load_balancing_policies.wrr_locality.v3.WrrLocality; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; +import io.grpc.ForwardingClientCallListener; +import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; import io.grpc.Grpc; import io.grpc.InsecureChannelCredentials; import io.grpc.InsecureServerCredentials; +import io.grpc.LoadBalancerRegistry; import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; import io.grpc.NameResolverRegistry; import io.grpc.Server; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; import io.grpc.netty.NettyServerBuilder; import io.grpc.stub.StreamObserver; import io.grpc.testing.protobuf.SimpleRequest; @@ -100,6 +122,7 @@ public class FakeControlPlaneXdsIntegrationTest { private Server controlPlane; private XdsTestControlPlaneService controlPlaneService; private XdsNameResolverProvider nameResolverProvider; + private MetadataLoadBalancerProvider metadataLoadBalancerProvider; protected int testServerPort = 0; protected int controlPlaneServicePort; @@ -135,10 +158,13 @@ public class FakeControlPlaneXdsIntegrationTest { */ @Before public void setUp() throws Exception { + ClientXdsClient.enableCustomLbConfig = true; startControlPlane(); nameResolverProvider = XdsNameResolverProvider.createForTest(SCHEME, defaultBootstrapOverride()); NameResolverRegistry.getDefaultRegistry().register(nameResolverProvider); + metadataLoadBalancerProvider = new MetadataLoadBalancerProvider(); + LoadBalancerRegistry.getDefaultRegistry().register(metadataLoadBalancerProvider); } @After @@ -156,6 +182,7 @@ public void tearDown() throws Exception { } } NameResolverRegistry.getDefaultRegistry().deregister(nameResolverProvider); + LoadBalancerRegistry.getDefaultRegistry().deregister(metadataLoadBalancerProvider); } @Test @@ -186,7 +213,108 @@ serverHostName, clientListener(serverHostName) assertEquals(goldenResponse, blockingStub.unaryRpc(request)); } + @Test + public void pingPong_metadataLoadBalancer() throws Exception { + String tcpListenerName = SERVER_LISTENER_TEMPLATE_NO_REPLACEMENT; + String serverHostName = "test-server"; + controlPlaneService.setXdsConfig(ADS_TYPE_URL_LDS, ImmutableMap.of( + tcpListenerName, serverListener(tcpListenerName), + serverHostName, clientListener(serverHostName) + )); + startServer(defaultBootstrapOverride()); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_RDS, + ImmutableMap.of(RDS_NAME, rds(serverHostName))); + + // Use the LoadBalancingPolicy to configure a custom LB that adds a header to server calls. + Policy metadataLbPolicy = Policy.newBuilder().setTypedExtensionConfig( + TypedExtensionConfig.newBuilder().setTypedConfig(Any.pack( + TypedStruct.newBuilder().setTypeUrl("type.googleapis.com/test.MetadataLoadBalancer") + .setValue(Struct.newBuilder() + .putFields("metadataKey", Value.newBuilder().setStringValue("foo").build()) + .putFields("metadataValue", Value.newBuilder().setStringValue("bar").build())) + .build()))).build(); + Policy wrrLocalityPolicy = Policy.newBuilder() + .setTypedExtensionConfig(TypedExtensionConfig.newBuilder().setTypedConfig( + Any.pack(WrrLocality.newBuilder().setEndpointPickingPolicy( + LoadBalancingPolicy.newBuilder().addPolicies(metadataLbPolicy)).build()))).build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, + ImmutableMap.of(CLUSTER_NAME, cds().toBuilder().setLoadBalancingPolicy( + LoadBalancingPolicy.newBuilder() + .addPolicies(wrrLocalityPolicy)).build())); + + InetSocketAddress edsInetSocketAddress = (InetSocketAddress) server.getListenSockets().get(0); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, + ImmutableMap.of(EDS_NAME, eds(edsInetSocketAddress.getHostName(), + edsInetSocketAddress.getPort()))); + ManagedChannel channel = Grpc.newChannelBuilder(SCHEME + ":///" + serverHostName, + InsecureChannelCredentials.create()).build(); + ResponseHeaderClientInterceptor responseHeaderInterceptor + = new ResponseHeaderClientInterceptor(); + + // We add an interceptor to catch the response headers from the server. + blockingStub = SimpleServiceGrpc.newBlockingStub(channel) + .withInterceptors(responseHeaderInterceptor); + SimpleRequest request = SimpleRequest.newBuilder() + .build(); + SimpleResponse goldenResponse = SimpleResponse.newBuilder() + .setResponseMessage("Hi, xDS!") + .build(); + assertEquals(goldenResponse, blockingStub.unaryRpc(request)); + + // Make sure we got back the header we configured the LB with. + assertThat(responseHeaderInterceptor.reponseHeaders.get( + Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER))).isEqualTo("bar"); + } + + // Captures response headers from the server. + private class ResponseHeaderClientInterceptor implements ClientInterceptor { + Metadata reponseHeaders; + + @Override + public ClientCall interceptCall(MethodDescriptor method, + CallOptions callOptions, Channel next) { + + return new SimpleForwardingClientCall(next.newCall(method, callOptions)) { + @Override + public void start(ClientCall.Listener responseListener, Metadata headers) { + super.start(new ForwardingClientCallListener() { + @Override + protected ClientCall.Listener delegate() { + return responseListener; + } + + @Override + public void onHeaders(Metadata headers) { + reponseHeaders = headers; + } + }, headers); + } + }; + } + } + private void startServer(Map bootstrapOverride) throws Exception { + ServerInterceptor metadataInterceptor = new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata requestHeaders, ServerCallHandler next) { + logger.fine("Received following metadata: " + requestHeaders); + + return next.startCall(new SimpleForwardingServerCall(call) { + @Override + public void sendHeaders(Metadata responseHeaders) { + responseHeaders.merge(requestHeaders); + super.sendHeaders(responseHeaders); + } + + @Override + public void close(Status status, Metadata trailers) { + super.close(status, trailers); + } + }, requestHeaders); + } + }; + SimpleServiceGrpc.SimpleServiceImplBase simpleServiceImpl = new SimpleServiceGrpc.SimpleServiceImplBase() { @Override @@ -202,6 +330,7 @@ public void unaryRpc( XdsServerBuilder serverBuilder = XdsServerBuilder.forPort( 0, InsecureServerCredentials.create()) .addService(simpleServiceImpl) + .intercept(metadataInterceptor) .overrideBootstrapForTest(bootstrapOverride); server = serverBuilder.build().start(); testServerPort = server.getPort(); diff --git a/xds/src/test/java/io/grpc/xds/MetadataLoadBalancerProvider.java b/xds/src/test/java/io/grpc/xds/MetadataLoadBalancerProvider.java new file mode 100644 index 00000000000..f52e9c815c9 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/MetadataLoadBalancerProvider.java @@ -0,0 +1,173 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import io.grpc.ConnectivityState; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickResult; +import io.grpc.LoadBalancer.PickSubchannelArgs; +import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; +import io.grpc.Metadata; +import io.grpc.NameResolver; +import io.grpc.Status; +import io.grpc.internal.JsonUtil; +import io.grpc.util.ForwardingLoadBalancer; +import io.grpc.util.ForwardingLoadBalancerHelper; +import java.util.Map; +import javax.annotation.Nonnull; + +/** + * A custom LB for testing purposes that simply delegates to round_robin and adds a metadata entry + * to each request. + */ +public class MetadataLoadBalancerProvider extends LoadBalancerProvider { + + @Override + public NameResolver.ConfigOrError parseLoadBalancingPolicyConfig( + Map rawLoadBalancingPolicyConfig) { + String metadataKey = JsonUtil.getString(rawLoadBalancingPolicyConfig, "metadataKey"); + if (metadataKey == null) { + return NameResolver.ConfigOrError.fromError( + Status.INVALID_ARGUMENT.withDescription("no 'metadataKey' defined")); + } + + String metadataValue = JsonUtil.getString(rawLoadBalancingPolicyConfig, "metadataValue"); + if (metadataValue == null) { + return NameResolver.ConfigOrError.fromError( + Status.INVALID_ARGUMENT.withDescription("no 'metadataValue' defined")); + } + + return NameResolver.ConfigOrError.fromConfig( + new MetadataLoadBalancerConfig(metadataKey, metadataValue)); + } + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + MetadataHelper metadataHelper = new MetadataHelper(helper); + return new MetadataLoadBalancer(metadataHelper, + LoadBalancerRegistry.getDefaultRegistry().getProvider("round_robin") + .newLoadBalancer(metadataHelper)); + } + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return "test.MetadataLoadBalancer"; + } + + static class MetadataLoadBalancerConfig { + + final String metadataKey; + final String metadataValue; + + MetadataLoadBalancerConfig(String metadataKey, String metadataValue) { + this.metadataKey = metadataKey; + this.metadataValue = metadataValue; + } + } + + static class MetadataLoadBalancer extends ForwardingLoadBalancer { + + private final MetadataHelper helper; + private final LoadBalancer delegateLb; + + MetadataLoadBalancer(MetadataHelper helper, LoadBalancer delegateLb) { + this.helper = helper; + this.delegateLb = delegateLb; + } + + @Override + protected LoadBalancer delegate() { + return delegateLb; + } + + @Override + public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + MetadataLoadBalancerConfig config + = (MetadataLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); + helper.setMetadata(config.metadataKey, config.metadataValue); + delegateLb.handleResolvedAddresses(resolvedAddresses); + } + } + + /** + * Wraps the picker that is provided when the balancing change updates with the {@link + * MetadataPicker} that injects the metadata entry. + */ + static class MetadataHelper extends ForwardingLoadBalancerHelper { + + private final Helper delegateHelper; + private String metadataKey; + private String metadataValue; + + MetadataHelper(Helper delegateHelper) { + this.delegateHelper = delegateHelper; + } + + void setMetadata(String metadataKey, String metadataValue) { + this.metadataKey = metadataKey; + this.metadataValue = metadataValue; + } + + @Override + protected Helper delegate() { + return delegateHelper; + } + + @Override + public void updateBalancingState(@Nonnull ConnectivityState newState, + @Nonnull SubchannelPicker newPicker) { + delegateHelper.updateBalancingState(newState, + new MetadataPicker(newPicker, metadataKey, metadataValue)); + } + } + + /** + * Includes the rpc-behavior metadata entry on each subchannel pick. + */ + static class MetadataPicker extends SubchannelPicker { + + private final SubchannelPicker delegatePicker; + private final String metadataKey; + private final String metadataValue; + + MetadataPicker(SubchannelPicker delegatePicker, String metadataKey, String metadataValue) { + this.delegatePicker = delegatePicker; + this.metadataKey = metadataKey; + this.metadataValue = metadataValue; + } + + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + args.getHeaders() + .put(Metadata.Key.of(metadataKey, Metadata.ASCII_STRING_MARSHALLER), metadataValue); + return delegatePicker.pickSubchannel(args); + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerTest.java index f1474fcf150..29777a5284f 100644 --- a/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerTest.java @@ -17,6 +17,7 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.XdsLbPolicies.WEIGHTED_TARGET_POLICY_NAME; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.never; @@ -34,7 +35,9 @@ import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; import io.grpc.Status; +import io.grpc.SynchronizationContext; import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedPolicySelection; import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedTargetConfig; @@ -62,7 +65,11 @@ public class WrrLocalityLoadBalancerTest { public final MockitoRule mockito = MockitoJUnit.rule(); @Mock - private LoadBalancerProvider mockProvider; + private LoadBalancerProvider mockWeightedTargetProvider; + @Mock + private LoadBalancer mockWeightedTargetLb; + @Mock + private LoadBalancerProvider mockChildProvider; @Mock private LoadBalancer mockChildLb; @Mock @@ -80,12 +87,31 @@ public class WrrLocalityLoadBalancerTest { private final EquivalentAddressGroup eag = new EquivalentAddressGroup(mockSocketAddress); private WrrLocalityLoadBalancer loadBalancer; + private LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); + + private final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); + } + }); @Before public void setUp() { - when(mockProvider.newLoadBalancer(isA(Helper.class))).thenReturn(mockChildLb); - when(mockProvider.getPolicyName()).thenReturn("round_robin"); - loadBalancer = new WrrLocalityLoadBalancer(mockHelper); + when(mockHelper.getSynchronizationContext()).thenReturn(syncContext); + + when(mockWeightedTargetProvider.newLoadBalancer(isA(Helper.class))).thenReturn( + mockWeightedTargetLb); + when(mockWeightedTargetProvider.getPolicyName()).thenReturn(WEIGHTED_TARGET_POLICY_NAME); + when(mockWeightedTargetProvider.isAvailable()).thenReturn(true); + lbRegistry.register(mockWeightedTargetProvider); + + when(mockChildProvider.newLoadBalancer(isA(Helper.class))).thenReturn(mockChildLb); + when(mockChildProvider.getPolicyName()).thenReturn("round_robin"); + lbRegistry.register(mockWeightedTargetProvider); + + loadBalancer = new WrrLocalityLoadBalancer(mockHelper, lbRegistry); } @Test @@ -93,7 +119,7 @@ public void handleResolvedAddresses() { // A two locality cluster with a mock child LB policy. Locality localityOne = Locality.create("region1", "zone1", "subzone1"); Locality localityTwo = Locality.create("region2", "zone2", "subzone2"); - PolicySelection childPolicy = new PolicySelection(mockProvider, null); + PolicySelection childPolicy = new PolicySelection(mockChildProvider, null); // The child config is delivered wrapped in the wrr_locality config and the locality weights // in a ResolvedAddresses attribute. @@ -103,7 +129,7 @@ public void handleResolvedAddresses() { // Assert that the child policy and the locality weights were correctly mapped to a // WeightedTargetConfig. - verify(mockChildLb).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + verify(mockWeightedTargetLb).handleResolvedAddresses(resolvedAddressesCaptor.capture()); Object config = resolvedAddressesCaptor.getValue().getLoadBalancingPolicyConfig(); assertThat(config).isInstanceOf(WeightedTargetConfig.class); WeightedTargetConfig wtConfig = (WeightedTargetConfig) config; @@ -117,7 +143,7 @@ public void handleResolvedAddresses() { @Test public void handleResolvedAddresses_noLocalityWeights() { // A two locality cluster with a mock child LB policy. - PolicySelection childPolicy = new PolicySelection(mockProvider, null); + PolicySelection childPolicy = new PolicySelection(mockChildProvider, null); // The child config is delivered wrapped in the wrr_locality config and the locality weights // in a ResolvedAddresses attribute. @@ -143,20 +169,20 @@ public void handleNameResolutionError_noChildLb() { @Test public void handleNameResolutionError_withChildLb() { - deliverAddresses(new WrrLocalityConfig(new PolicySelection(mockProvider, null)), + deliverAddresses(new WrrLocalityConfig(new PolicySelection(mockChildProvider, null)), ImmutableMap.of( Locality.create("region", "zone", "subzone"), 1)); loadBalancer.handleNameResolutionError(Status.DEADLINE_EXCEEDED); verify(mockHelper, never()).updateBalancingState(isA(ConnectivityState.class), isA(ErrorPicker.class)); - verify(mockChildLb).handleNameResolutionError(Status.DEADLINE_EXCEEDED); + verify(mockWeightedTargetLb).handleNameResolutionError(Status.DEADLINE_EXCEEDED); } @Test public void localityWeightAttributeNotPropagated() { Locality locality = Locality.create("region1", "zone1", "subzone1"); - PolicySelection childPolicy = new PolicySelection(mockProvider, null); + PolicySelection childPolicy = new PolicySelection(mockChildProvider, null); WrrLocalityConfig wlConfig = new WrrLocalityConfig(childPolicy); Map localityWeights = ImmutableMap.of(locality, 1); @@ -164,27 +190,29 @@ public void localityWeightAttributeNotPropagated() { // Assert that the child policy and the locality weights were correctly mapped to a // WeightedTargetConfig. - verify(mockChildLb).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + verify(mockWeightedTargetLb).handleResolvedAddresses(resolvedAddressesCaptor.capture()); assertThat(resolvedAddressesCaptor.getValue().getAttributes() .get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHTS)).isNull(); } @Test public void shutdown() { - deliverAddresses(new WrrLocalityConfig(new PolicySelection(mockProvider, null)), + deliverAddresses(new WrrLocalityConfig(new PolicySelection(mockChildProvider, null)), ImmutableMap.of( Locality.create("region", "zone", "subzone"), 1)); loadBalancer.shutdown(); - verify(mockChildLb).shutdown(); + verify(mockWeightedTargetLb).shutdown(); } @Test public void configEquality() { - WrrLocalityConfig configOne = new WrrLocalityConfig(new PolicySelection(mockProvider, null)); - WrrLocalityConfig configTwo = new WrrLocalityConfig(new PolicySelection(mockProvider, null)); + WrrLocalityConfig configOne = new WrrLocalityConfig( + new PolicySelection(mockChildProvider, null)); + WrrLocalityConfig configTwo = new WrrLocalityConfig( + new PolicySelection(mockChildProvider, null)); WrrLocalityConfig differentConfig = new WrrLocalityConfig( - new PolicySelection(mockProvider, "config")); + new PolicySelection(mockChildProvider, "config")); new EqualsTester().addEqualityGroup(configOne, configTwo).addEqualityGroup(differentConfig) .testEquals();