diff --git a/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java index 61c6b103d81..f320bc340ee 100644 --- a/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java @@ -18,10 +18,12 @@ import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.xds.XdsLbPolicies.WEIGHTED_TARGET_POLICY_NAME; import com.google.common.base.MoreObjects; import io.grpc.InternalLogId; import io.grpc.LoadBalancer; +import io.grpc.LoadBalancerRegistry; import io.grpc.Status; import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.util.GracefulSwitchLoadBalancer; @@ -43,9 +45,15 @@ final class WrrLocalityLoadBalancer extends LoadBalancer { private final XdsLogger logger; private final Helper helper; private final GracefulSwitchLoadBalancer switchLb; + private final LoadBalancerRegistry lbRegistry; WrrLocalityLoadBalancer(Helper helper) { + this(helper, LoadBalancerRegistry.getDefaultRegistry()); + } + + WrrLocalityLoadBalancer(Helper helper, LoadBalancerRegistry lbRegistry) { this.helper = checkNotNull(helper, "helper"); + this.lbRegistry = lbRegistry; switchLb = new GracefulSwitchLoadBalancer(helper); logger = XdsLogger.withLogId( InternalLogId.allocate("xds-wrr-locality-lb", helper.getAuthority())); @@ -88,7 +96,7 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { .setAttributes(resolvedAddresses.getAttributes().toBuilder() .discard(InternalXdsAttributes.ATTR_LOCALITY_WEIGHTS).build()).build(); - switchLb.switchTo(wrrLocalityConfig.childPolicy.getProvider()); + switchLb.switchTo(lbRegistry.getProvider(WEIGHTED_TARGET_POLICY_NAME)); switchLb.handleResolvedAddresses( resolvedAddresses.toBuilder() .setLoadBalancingPolicyConfig(new WeightedTargetConfig(weightedPolicySelections)) diff --git a/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java b/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java index 61064345344..02dcfd21635 100644 --- a/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java +++ b/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java @@ -17,23 +17,30 @@ package io.grpc.xds; +import static com.google.common.truth.Truth.assertThat; import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_CDS; import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_EDS; import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_LDS; import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_RDS; import static org.junit.Assert.assertEquals; +import com.github.xds.type.v3.TypedStruct; import com.google.common.collect.ImmutableMap; import com.google.protobuf.Any; import com.google.protobuf.Message; +import com.google.protobuf.Struct; import com.google.protobuf.UInt32Value; +import com.google.protobuf.Value; import io.envoyproxy.envoy.config.cluster.v3.Cluster; +import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy; +import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy.Policy; import io.envoyproxy.envoy.config.core.v3.Address; import io.envoyproxy.envoy.config.core.v3.AggregatedConfigSource; import io.envoyproxy.envoy.config.core.v3.ConfigSource; import io.envoyproxy.envoy.config.core.v3.HealthStatus; import io.envoyproxy.envoy.config.core.v3.SocketAddress; import io.envoyproxy.envoy.config.core.v3.TrafficDirection; +import io.envoyproxy.envoy.config.core.v3.TypedExtensionConfig; import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; import io.envoyproxy.envoy.config.endpoint.v3.Endpoint; import io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint; @@ -53,12 +60,27 @@ import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.Rds; +import io.envoyproxy.envoy.extensions.load_balancing_policies.wrr_locality.v3.WrrLocality; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; +import io.grpc.ForwardingClientCallListener; +import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; import io.grpc.Grpc; import io.grpc.InsecureChannelCredentials; import io.grpc.InsecureServerCredentials; +import io.grpc.LoadBalancerRegistry; import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; import io.grpc.NameResolverRegistry; import io.grpc.Server; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; import io.grpc.netty.NettyServerBuilder; import io.grpc.stub.StreamObserver; import io.grpc.testing.protobuf.SimpleRequest; @@ -100,6 +122,7 @@ public class FakeControlPlaneXdsIntegrationTest { private Server controlPlane; private XdsTestControlPlaneService controlPlaneService; private XdsNameResolverProvider nameResolverProvider; + private MetadataLoadBalancerProvider metadataLoadBalancerProvider; protected int testServerPort = 0; protected int controlPlaneServicePort; @@ -135,10 +158,13 @@ public class FakeControlPlaneXdsIntegrationTest { */ @Before public void setUp() throws Exception { + ClientXdsClient.enableCustomLbConfig = true; startControlPlane(); nameResolverProvider = XdsNameResolverProvider.createForTest(SCHEME, defaultBootstrapOverride()); NameResolverRegistry.getDefaultRegistry().register(nameResolverProvider); + metadataLoadBalancerProvider = new MetadataLoadBalancerProvider(); + LoadBalancerRegistry.getDefaultRegistry().register(metadataLoadBalancerProvider); } @After @@ -156,6 +182,7 @@ public void tearDown() throws Exception { } } NameResolverRegistry.getDefaultRegistry().deregister(nameResolverProvider); + LoadBalancerRegistry.getDefaultRegistry().deregister(metadataLoadBalancerProvider); } @Test @@ -186,7 +213,108 @@ serverHostName, clientListener(serverHostName) assertEquals(goldenResponse, blockingStub.unaryRpc(request)); } + @Test + public void pingPong_metadataLoadBalancer() throws Exception { + String tcpListenerName = SERVER_LISTENER_TEMPLATE_NO_REPLACEMENT; + String serverHostName = "test-server"; + controlPlaneService.setXdsConfig(ADS_TYPE_URL_LDS, ImmutableMap.of( + tcpListenerName, serverListener(tcpListenerName), + serverHostName, clientListener(serverHostName) + )); + startServer(defaultBootstrapOverride()); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_RDS, + ImmutableMap.of(RDS_NAME, rds(serverHostName))); + + // Use the LoadBalancingPolicy to configure a custom LB that adds a header to server calls. + Policy metadataLbPolicy = Policy.newBuilder().setTypedExtensionConfig( + TypedExtensionConfig.newBuilder().setTypedConfig(Any.pack( + TypedStruct.newBuilder().setTypeUrl("type.googleapis.com/test.MetadataLoadBalancer") + .setValue(Struct.newBuilder() + .putFields("metadataKey", Value.newBuilder().setStringValue("foo").build()) + .putFields("metadataValue", Value.newBuilder().setStringValue("bar").build())) + .build()))).build(); + Policy wrrLocalityPolicy = Policy.newBuilder() + .setTypedExtensionConfig(TypedExtensionConfig.newBuilder().setTypedConfig( + Any.pack(WrrLocality.newBuilder().setEndpointPickingPolicy( + LoadBalancingPolicy.newBuilder().addPolicies(metadataLbPolicy)).build()))).build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, + ImmutableMap.of(CLUSTER_NAME, cds().toBuilder().setLoadBalancingPolicy( + LoadBalancingPolicy.newBuilder() + .addPolicies(wrrLocalityPolicy)).build())); + + InetSocketAddress edsInetSocketAddress = (InetSocketAddress) server.getListenSockets().get(0); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, + ImmutableMap.of(EDS_NAME, eds(edsInetSocketAddress.getHostName(), + edsInetSocketAddress.getPort()))); + ManagedChannel channel = Grpc.newChannelBuilder(SCHEME + ":///" + serverHostName, + InsecureChannelCredentials.create()).build(); + ResponseHeaderClientInterceptor responseHeaderInterceptor + = new ResponseHeaderClientInterceptor(); + + // We add an interceptor to catch the response headers from the server. + blockingStub = SimpleServiceGrpc.newBlockingStub(channel) + .withInterceptors(responseHeaderInterceptor); + SimpleRequest request = SimpleRequest.newBuilder() + .build(); + SimpleResponse goldenResponse = SimpleResponse.newBuilder() + .setResponseMessage("Hi, xDS!") + .build(); + assertEquals(goldenResponse, blockingStub.unaryRpc(request)); + + // Make sure we got back the header we configured the LB with. + assertThat(responseHeaderInterceptor.reponseHeaders.get( + Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER))).isEqualTo("bar"); + } + + // Captures response headers from the server. + private class ResponseHeaderClientInterceptor implements ClientInterceptor { + Metadata reponseHeaders; + + @Override + public ClientCall interceptCall(MethodDescriptor method, + CallOptions callOptions, Channel next) { + + return new SimpleForwardingClientCall(next.newCall(method, callOptions)) { + @Override + public void start(ClientCall.Listener responseListener, Metadata headers) { + super.start(new ForwardingClientCallListener() { + @Override + protected ClientCall.Listener delegate() { + return responseListener; + } + + @Override + public void onHeaders(Metadata headers) { + reponseHeaders = headers; + } + }, headers); + } + }; + } + } + private void startServer(Map bootstrapOverride) throws Exception { + ServerInterceptor metadataInterceptor = new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata requestHeaders, ServerCallHandler next) { + logger.fine("Received following metadata: " + requestHeaders); + + return next.startCall(new SimpleForwardingServerCall(call) { + @Override + public void sendHeaders(Metadata responseHeaders) { + responseHeaders.merge(requestHeaders); + super.sendHeaders(responseHeaders); + } + + @Override + public void close(Status status, Metadata trailers) { + super.close(status, trailers); + } + }, requestHeaders); + } + }; + SimpleServiceGrpc.SimpleServiceImplBase simpleServiceImpl = new SimpleServiceGrpc.SimpleServiceImplBase() { @Override @@ -202,6 +330,7 @@ public void unaryRpc( XdsServerBuilder serverBuilder = XdsServerBuilder.forPort( 0, InsecureServerCredentials.create()) .addService(simpleServiceImpl) + .intercept(metadataInterceptor) .overrideBootstrapForTest(bootstrapOverride); server = serverBuilder.build().start(); testServerPort = server.getPort(); diff --git a/xds/src/test/java/io/grpc/xds/MetadataLoadBalancerProvider.java b/xds/src/test/java/io/grpc/xds/MetadataLoadBalancerProvider.java new file mode 100644 index 00000000000..f52e9c815c9 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/MetadataLoadBalancerProvider.java @@ -0,0 +1,173 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import io.grpc.ConnectivityState; +import io.grpc.LoadBalancer; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickResult; +import io.grpc.LoadBalancer.PickSubchannelArgs; +import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; +import io.grpc.Metadata; +import io.grpc.NameResolver; +import io.grpc.Status; +import io.grpc.internal.JsonUtil; +import io.grpc.util.ForwardingLoadBalancer; +import io.grpc.util.ForwardingLoadBalancerHelper; +import java.util.Map; +import javax.annotation.Nonnull; + +/** + * A custom LB for testing purposes that simply delegates to round_robin and adds a metadata entry + * to each request. + */ +public class MetadataLoadBalancerProvider extends LoadBalancerProvider { + + @Override + public NameResolver.ConfigOrError parseLoadBalancingPolicyConfig( + Map rawLoadBalancingPolicyConfig) { + String metadataKey = JsonUtil.getString(rawLoadBalancingPolicyConfig, "metadataKey"); + if (metadataKey == null) { + return NameResolver.ConfigOrError.fromError( + Status.INVALID_ARGUMENT.withDescription("no 'metadataKey' defined")); + } + + String metadataValue = JsonUtil.getString(rawLoadBalancingPolicyConfig, "metadataValue"); + if (metadataValue == null) { + return NameResolver.ConfigOrError.fromError( + Status.INVALID_ARGUMENT.withDescription("no 'metadataValue' defined")); + } + + return NameResolver.ConfigOrError.fromConfig( + new MetadataLoadBalancerConfig(metadataKey, metadataValue)); + } + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + MetadataHelper metadataHelper = new MetadataHelper(helper); + return new MetadataLoadBalancer(metadataHelper, + LoadBalancerRegistry.getDefaultRegistry().getProvider("round_robin") + .newLoadBalancer(metadataHelper)); + } + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return "test.MetadataLoadBalancer"; + } + + static class MetadataLoadBalancerConfig { + + final String metadataKey; + final String metadataValue; + + MetadataLoadBalancerConfig(String metadataKey, String metadataValue) { + this.metadataKey = metadataKey; + this.metadataValue = metadataValue; + } + } + + static class MetadataLoadBalancer extends ForwardingLoadBalancer { + + private final MetadataHelper helper; + private final LoadBalancer delegateLb; + + MetadataLoadBalancer(MetadataHelper helper, LoadBalancer delegateLb) { + this.helper = helper; + this.delegateLb = delegateLb; + } + + @Override + protected LoadBalancer delegate() { + return delegateLb; + } + + @Override + public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + MetadataLoadBalancerConfig config + = (MetadataLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); + helper.setMetadata(config.metadataKey, config.metadataValue); + delegateLb.handleResolvedAddresses(resolvedAddresses); + } + } + + /** + * Wraps the picker that is provided when the balancing change updates with the {@link + * MetadataPicker} that injects the metadata entry. + */ + static class MetadataHelper extends ForwardingLoadBalancerHelper { + + private final Helper delegateHelper; + private String metadataKey; + private String metadataValue; + + MetadataHelper(Helper delegateHelper) { + this.delegateHelper = delegateHelper; + } + + void setMetadata(String metadataKey, String metadataValue) { + this.metadataKey = metadataKey; + this.metadataValue = metadataValue; + } + + @Override + protected Helper delegate() { + return delegateHelper; + } + + @Override + public void updateBalancingState(@Nonnull ConnectivityState newState, + @Nonnull SubchannelPicker newPicker) { + delegateHelper.updateBalancingState(newState, + new MetadataPicker(newPicker, metadataKey, metadataValue)); + } + } + + /** + * Includes the rpc-behavior metadata entry on each subchannel pick. + */ + static class MetadataPicker extends SubchannelPicker { + + private final SubchannelPicker delegatePicker; + private final String metadataKey; + private final String metadataValue; + + MetadataPicker(SubchannelPicker delegatePicker, String metadataKey, String metadataValue) { + this.delegatePicker = delegatePicker; + this.metadataKey = metadataKey; + this.metadataValue = metadataValue; + } + + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + args.getHeaders() + .put(Metadata.Key.of(metadataKey, Metadata.ASCII_STRING_MARSHALLER), metadataValue); + return delegatePicker.pickSubchannel(args); + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerTest.java index f1474fcf150..29777a5284f 100644 --- a/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WrrLocalityLoadBalancerTest.java @@ -17,6 +17,7 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.XdsLbPolicies.WEIGHTED_TARGET_POLICY_NAME; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.never; @@ -34,7 +35,9 @@ import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; import io.grpc.Status; +import io.grpc.SynchronizationContext; import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedPolicySelection; import io.grpc.xds.WeightedTargetLoadBalancerProvider.WeightedTargetConfig; @@ -62,7 +65,11 @@ public class WrrLocalityLoadBalancerTest { public final MockitoRule mockito = MockitoJUnit.rule(); @Mock - private LoadBalancerProvider mockProvider; + private LoadBalancerProvider mockWeightedTargetProvider; + @Mock + private LoadBalancer mockWeightedTargetLb; + @Mock + private LoadBalancerProvider mockChildProvider; @Mock private LoadBalancer mockChildLb; @Mock @@ -80,12 +87,31 @@ public class WrrLocalityLoadBalancerTest { private final EquivalentAddressGroup eag = new EquivalentAddressGroup(mockSocketAddress); private WrrLocalityLoadBalancer loadBalancer; + private LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); + + private final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); + } + }); @Before public void setUp() { - when(mockProvider.newLoadBalancer(isA(Helper.class))).thenReturn(mockChildLb); - when(mockProvider.getPolicyName()).thenReturn("round_robin"); - loadBalancer = new WrrLocalityLoadBalancer(mockHelper); + when(mockHelper.getSynchronizationContext()).thenReturn(syncContext); + + when(mockWeightedTargetProvider.newLoadBalancer(isA(Helper.class))).thenReturn( + mockWeightedTargetLb); + when(mockWeightedTargetProvider.getPolicyName()).thenReturn(WEIGHTED_TARGET_POLICY_NAME); + when(mockWeightedTargetProvider.isAvailable()).thenReturn(true); + lbRegistry.register(mockWeightedTargetProvider); + + when(mockChildProvider.newLoadBalancer(isA(Helper.class))).thenReturn(mockChildLb); + when(mockChildProvider.getPolicyName()).thenReturn("round_robin"); + lbRegistry.register(mockWeightedTargetProvider); + + loadBalancer = new WrrLocalityLoadBalancer(mockHelper, lbRegistry); } @Test @@ -93,7 +119,7 @@ public void handleResolvedAddresses() { // A two locality cluster with a mock child LB policy. Locality localityOne = Locality.create("region1", "zone1", "subzone1"); Locality localityTwo = Locality.create("region2", "zone2", "subzone2"); - PolicySelection childPolicy = new PolicySelection(mockProvider, null); + PolicySelection childPolicy = new PolicySelection(mockChildProvider, null); // The child config is delivered wrapped in the wrr_locality config and the locality weights // in a ResolvedAddresses attribute. @@ -103,7 +129,7 @@ public void handleResolvedAddresses() { // Assert that the child policy and the locality weights were correctly mapped to a // WeightedTargetConfig. - verify(mockChildLb).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + verify(mockWeightedTargetLb).handleResolvedAddresses(resolvedAddressesCaptor.capture()); Object config = resolvedAddressesCaptor.getValue().getLoadBalancingPolicyConfig(); assertThat(config).isInstanceOf(WeightedTargetConfig.class); WeightedTargetConfig wtConfig = (WeightedTargetConfig) config; @@ -117,7 +143,7 @@ public void handleResolvedAddresses() { @Test public void handleResolvedAddresses_noLocalityWeights() { // A two locality cluster with a mock child LB policy. - PolicySelection childPolicy = new PolicySelection(mockProvider, null); + PolicySelection childPolicy = new PolicySelection(mockChildProvider, null); // The child config is delivered wrapped in the wrr_locality config and the locality weights // in a ResolvedAddresses attribute. @@ -143,20 +169,20 @@ public void handleNameResolutionError_noChildLb() { @Test public void handleNameResolutionError_withChildLb() { - deliverAddresses(new WrrLocalityConfig(new PolicySelection(mockProvider, null)), + deliverAddresses(new WrrLocalityConfig(new PolicySelection(mockChildProvider, null)), ImmutableMap.of( Locality.create("region", "zone", "subzone"), 1)); loadBalancer.handleNameResolutionError(Status.DEADLINE_EXCEEDED); verify(mockHelper, never()).updateBalancingState(isA(ConnectivityState.class), isA(ErrorPicker.class)); - verify(mockChildLb).handleNameResolutionError(Status.DEADLINE_EXCEEDED); + verify(mockWeightedTargetLb).handleNameResolutionError(Status.DEADLINE_EXCEEDED); } @Test public void localityWeightAttributeNotPropagated() { Locality locality = Locality.create("region1", "zone1", "subzone1"); - PolicySelection childPolicy = new PolicySelection(mockProvider, null); + PolicySelection childPolicy = new PolicySelection(mockChildProvider, null); WrrLocalityConfig wlConfig = new WrrLocalityConfig(childPolicy); Map localityWeights = ImmutableMap.of(locality, 1); @@ -164,27 +190,29 @@ public void localityWeightAttributeNotPropagated() { // Assert that the child policy and the locality weights were correctly mapped to a // WeightedTargetConfig. - verify(mockChildLb).handleResolvedAddresses(resolvedAddressesCaptor.capture()); + verify(mockWeightedTargetLb).handleResolvedAddresses(resolvedAddressesCaptor.capture()); assertThat(resolvedAddressesCaptor.getValue().getAttributes() .get(InternalXdsAttributes.ATTR_LOCALITY_WEIGHTS)).isNull(); } @Test public void shutdown() { - deliverAddresses(new WrrLocalityConfig(new PolicySelection(mockProvider, null)), + deliverAddresses(new WrrLocalityConfig(new PolicySelection(mockChildProvider, null)), ImmutableMap.of( Locality.create("region", "zone", "subzone"), 1)); loadBalancer.shutdown(); - verify(mockChildLb).shutdown(); + verify(mockWeightedTargetLb).shutdown(); } @Test public void configEquality() { - WrrLocalityConfig configOne = new WrrLocalityConfig(new PolicySelection(mockProvider, null)); - WrrLocalityConfig configTwo = new WrrLocalityConfig(new PolicySelection(mockProvider, null)); + WrrLocalityConfig configOne = new WrrLocalityConfig( + new PolicySelection(mockChildProvider, null)); + WrrLocalityConfig configTwo = new WrrLocalityConfig( + new PolicySelection(mockChildProvider, null)); WrrLocalityConfig differentConfig = new WrrLocalityConfig( - new PolicySelection(mockProvider, "config")); + new PolicySelection(mockChildProvider, "config")); new EqualsTester().addEqualityGroup(configOne, configTwo).addEqualityGroup(differentConfig) .testEquals();