diff --git a/xds/csds/csds.go b/xds/csds/csds.go index d32bebac81b..1b54a3a4c6e 100644 --- a/xds/csds/csds.go +++ b/xds/csds/csds.go @@ -38,33 +38,17 @@ import ( "google.golang.org/grpc/grpclog" "google.golang.org/grpc/status" "google.golang.org/grpc/xds/internal/xdsclient" - "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" "google.golang.org/protobuf/types/known/timestamppb" _ "google.golang.org/grpc/xds/internal/xdsclient/v2" // Register v2 xds_client. _ "google.golang.org/grpc/xds/internal/xdsclient/v3" // Register v3 xds_client. ) -// xdsClient contains methods from xdsClient.Client which are used by -// the server. This is useful for overriding in unit tests. -type xdsClient interface { - DumpLDS() (string, map[string]xdsclient.UpdateWithMD) - DumpRDS() (string, map[string]xdsclient.UpdateWithMD) - DumpCDS() (string, map[string]xdsclient.UpdateWithMD) - DumpEDS() (string, map[string]xdsclient.UpdateWithMD) - BootstrapConfig() *bootstrap.Config - Close() -} - var ( logger = grpclog.Component("xds") - newXDSClient = func() xdsClient { + newXDSClient = func() xdsclient.XDSClient { c, err := xdsclient.New() if err != nil { - // If err is not nil, c is a typed nil (of type *xdsclient.Client). - // If c is returned and assigned to the xdsClient field in the CSDS - // server, the nil checks in the handlers will not handle it - // properly. logger.Warningf("failed to create xds client: %v", err) return nil } @@ -76,7 +60,7 @@ var ( type ClientStatusDiscoveryServer struct { // xdsClient will always be the same in practice. But we keep a copy in each // server instance for testing. - xdsClient xdsClient + xdsClient xdsclient.XDSClient } // NewClientStatusDiscoveryServer returns an implementation of the CSDS server that can be diff --git a/xds/csds/csds_test.go b/xds/csds/csds_test.go index 7f0e90bebc1..98dc93e8671 100644 --- a/xds/csds/csds_test.go +++ b/xds/csds/csds_test.go @@ -59,13 +59,6 @@ const ( defaultTestTimeout = 10 * time.Second ) -type xdsClientWithWatch interface { - WatchListener(string, func(xdsclient.ListenerUpdate, error)) func() - WatchRouteConfig(string, func(xdsclient.RouteConfigUpdate, error)) func() - WatchCluster(string, func(xdsclient.ClusterUpdate, error)) func() - WatchEndpoints(string, func(xdsclient.EndpointsUpdate, error)) func() -} - var cmpOpts = cmp.Options{ cmpopts.EquateEmpty(), cmp.Comparer(func(a, b *timestamppb.Timestamp) bool { return true }), @@ -250,7 +243,7 @@ func TestCSDS(t *testing.T) { } } -func commonSetup(t *testing.T) (xdsClientWithWatch, *e2e.ManagementServer, string, v3statuspbgrpc.ClientStatusDiscoveryService_StreamClientStatusClient, func()) { +func commonSetup(t *testing.T) (xdsclient.XDSClient, *e2e.ManagementServer, string, v3statuspbgrpc.ClientStatusDiscoveryService_StreamClientStatusClient, func()) { t.Helper() // Spin up a xDS management server on a local port. @@ -275,7 +268,7 @@ func commonSetup(t *testing.T) (xdsClientWithWatch, *e2e.ManagementServer, strin t.Fatalf("failed to create xds client: %v", err) } oldNewXDSClient := newXDSClient - newXDSClient = func() xdsClient { return xdsC } + newXDSClient = func() xdsclient.XDSClient { return xdsC } // Initialize an gRPC server and register CSDS on it. server := grpc.NewServer() @@ -635,7 +628,7 @@ func protoToJSON(p proto.Message) string { func TestCSDSNoXDSClient(t *testing.T) { oldNewXDSClient := newXDSClient - newXDSClient = func() xdsClient { return nil } + newXDSClient = func() xdsclient.XDSClient { return nil } defer func() { newXDSClient = oldNewXDSClient }() // Initialize an gRPC server and register CSDS on it. diff --git a/xds/googledirectpath/googlec2p.go b/xds/googledirectpath/googlec2p.go index af487ec4a73..0c2f984fbcb 100644 --- a/xds/googledirectpath/googlec2p.go +++ b/xds/googledirectpath/googlec2p.go @@ -62,15 +62,11 @@ const ( dnsName, xdsName = "dns", "xds" ) -type xdsClient interface { - Close() -} - // For overriding in unittests. var ( onGCE = googlecloud.OnGCE - newClientWithConfig = func(config *bootstrap.Config) (xdsClient, error) { + newClientWithConfig = func(config *bootstrap.Config) (xdsclient.XDSClient, error) { return xdsclient.NewWithConfig(config) } @@ -139,7 +135,7 @@ func (c2pResolverBuilder) Scheme() string { type c2pResolver struct { resolver.Resolver - client xdsClient + client xdsclient.XDSClient } func (r *c2pResolver) Close() { diff --git a/xds/googledirectpath/googlec2p_test.go b/xds/googledirectpath/googlec2p_test.go index fb68fa23a1d..8f98d3159d3 100644 --- a/xds/googledirectpath/googlec2p_test.go +++ b/xds/googledirectpath/googlec2p_test.go @@ -32,6 +32,7 @@ import ( "google.golang.org/grpc/internal/xds/env" "google.golang.org/grpc/resolver" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" "google.golang.org/protobuf/testing/protocmp" "google.golang.org/protobuf/types/known/structpb" @@ -130,6 +131,7 @@ func TestBuildNotOnGCE(t *testing.T) { } type testXDSClient struct { + xdsclient.XDSClient closed chan struct{} } @@ -177,7 +179,7 @@ func TestBuildXDS(t *testing.T) { configCh := make(chan *bootstrap.Config, 1) oldNewClient := newClientWithConfig - newClientWithConfig = func(config *bootstrap.Config) (xdsClient, error) { + newClientWithConfig = func(config *bootstrap.Config) (xdsclient.XDSClient, error) { configCh <- config return tXDSClient, nil } diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer.go b/xds/internal/balancer/cdsbalancer/cdsbalancer.go index 7278c624361..a710e498316 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer.go @@ -36,7 +36,6 @@ import ( "google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/xds/internal/balancer/edsbalancer" "google.golang.org/grpc/xds/internal/xdsclient" - "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" ) const ( @@ -59,7 +58,6 @@ var ( // not deal with subConns. return builder.Build(cc, opts), nil } - newXDSClient func() (xdsClient, error) buildProvider = buildProviderFunc ) @@ -84,17 +82,6 @@ func (bb) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Bal } b.logger = prefixLogger((b)) b.logger.Infof("Created") - - if newXDSClient != nil { - // For tests - client, err := newXDSClient() - if err != nil { - b.logger.Errorf("failed to create xds-client: %v", err) - return nil - } - b.xdsClient = client - } - var creds credentials.TransportCredentials switch { case opts.DialCreds != nil: @@ -137,14 +124,6 @@ func (bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, err return &cfg, nil } -// xdsClient contains methods from xdsClient.Client which are used by -// the cdsBalancer. This will be faked out in unittests. -type xdsClient interface { - WatchCluster(string, func(xdsclient.ClusterUpdate, error)) func() - BootstrapConfig() *bootstrap.Config - Close() -} - // ccUpdate wraps a clientConn update received from gRPC (pushed from the // xdsResolver). A valid clusterName causes the cdsBalancer to register a CDS // watcher with the xdsClient, while a non-nil error causes it to cancel the @@ -184,7 +163,7 @@ type cdsBalancer struct { ccw *ccWrapper // ClientConn interface passed to child LB. bOpts balancer.BuildOptions // BuildOptions passed to child LB. updateCh *buffer.Unbounded // Channel for gRPC and xdsClient updates. - xdsClient xdsClient // xDS client to watch Cluster resource. + xdsClient xdsclient.XDSClient // xDS client to watch Cluster resource. cancelWatch func() // Cluster watch cancel func. edsLB balancer.Balancer // EDS child policy. clusterToWatch string @@ -361,15 +340,8 @@ func (b *cdsBalancer) handleWatchUpdate(update *watchUpdate) { lbCfg.LrsLoadReportingServerName = new(string) } - resolverState := resolver.State{} - // Include the xds client for the child LB policies to use. For unit - // tests, b.xdsClient may not be a full *xdsclient.Client, but it will - // always be in production. - if c, ok := b.xdsClient.(*xdsclient.Client); ok { - resolverState = xdsclient.SetClient(resolverState, c) - } ccState := balancer.ClientConnState{ - ResolverState: resolverState, + ResolverState: xdsclient.SetClient(resolver.State{}, b.xdsClient), BalancerConfig: lbCfg, } if err := b.edsLB.UpdateClientConnState(ccState); err != nil { @@ -407,9 +379,6 @@ func (b *cdsBalancer) run() { b.edsLB.Close() b.edsLB = nil } - if newXDSClient != nil { - b.xdsClient.Close() - } if b.cachedRoot != nil { b.cachedRoot.Close() } diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go b/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go index 9964b9de925..067bc2b0536 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go @@ -133,11 +133,7 @@ func (p *fakeProvider) Close() { // xDSCredentials. func setupWithXDSCreds(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBalancer, *xdstestutils.TestClientConn, func()) { t.Helper() - xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClient, error) { return xdsC, nil } - builder := balancer.Get(cdsName) if builder == nil { t.Fatalf("balancer.Get(%q) returned nil", cdsName) @@ -164,7 +160,7 @@ func setupWithXDSCreds(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDS } // Push a ClientConnState update to the CDS balancer with a cluster name. - if err := cdsB.UpdateClientConnState(cdsCCS(clusterName)); err != nil { + if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != nil { t.Fatalf("cdsBalancer.UpdateClientConnState failed with error: %v", err) } @@ -181,8 +177,8 @@ func setupWithXDSCreds(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDS } return xdsC, cdsB.(*cdsBalancer), edsB, tcc, func() { - newXDSClient = oldNewXDSClient newEDSBalancer = oldEDSBalancerBuilder + xdsC.Close() } } diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go b/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go index 5c5161807be..f36117620e6 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go @@ -28,7 +28,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "google.golang.org/grpc/attributes" "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/internal" @@ -129,7 +128,10 @@ func (tb *testEDSBalancer) waitForClientConnUpdate(ctx context.Context, wantCCS return err } gotCCS := ccs.(balancer.ClientConnState) - if !cmp.Equal(gotCCS, wantCCS, cmpopts.IgnoreUnexported(attributes.Attributes{})) { + if xdsclient.FromResolverState(gotCCS.ResolverState) == nil { + return fmt.Errorf("want resolver state with XDSClient attached, got one without") + } + if !cmp.Equal(gotCCS, wantCCS, cmpopts.IgnoreFields(resolver.State{}, "Attributes")) { return fmt.Errorf("received ClientConnState: %+v, want %+v", gotCCS, wantCCS) } return nil @@ -173,7 +175,7 @@ func (tb *testEDSBalancer) waitForClose(ctx context.Context) error { // cdsCCS is a helper function to construct a good update passed from the // xdsResolver to the cdsBalancer. -func cdsCCS(cluster string) balancer.ClientConnState { +func cdsCCS(cluster string, xdsC xdsclient.XDSClient) balancer.ClientConnState { const cdsLBConfig = `{ "loadBalancingConfig":[ { @@ -185,9 +187,9 @@ func cdsCCS(cluster string) balancer.ClientConnState { }` jsonSC := fmt.Sprintf(cdsLBConfig, cluster) return balancer.ClientConnState{ - ResolverState: resolver.State{ + ResolverState: xdsclient.SetClient(resolver.State{ ServiceConfig: internal.ParseServiceConfigForTesting.(func(string) *serviceconfig.ParseResult)(jsonSC), - }, + }, xdsC), BalancerConfig: &lbConfig{ClusterName: clusterName}, } } @@ -211,11 +213,7 @@ func edsCCS(service string, countMax *uint32, enableLRS bool) balancer.ClientCon // newEDSBalancer function to return it), and also returns a cleanup function. func setup(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBalancer, *xdstestutils.TestClientConn, func()) { t.Helper() - xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClient, error) { return xdsC, nil } - builder := balancer.Get(cdsName) if builder == nil { t.Fatalf("balancer.Get(%q) returned nil", cdsName) @@ -232,7 +230,7 @@ func setup(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBalancer, *x return xdsC, cdsB.(*cdsBalancer), edsB, tcc, func() { newEDSBalancer = oldEDSBalancerBuilder - newXDSClient = oldNewXDSClient + xdsC.Close() } } @@ -242,7 +240,7 @@ func setupWithWatch(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBal t.Helper() xdsC, cdsB, edsB, tcc, cancel := setup(t) - if err := cdsB.UpdateClientConnState(cdsCCS(clusterName)); err != nil { + if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != nil { t.Fatalf("cdsBalancer.UpdateClientConnState failed with error: %v", err) } @@ -262,6 +260,9 @@ func setupWithWatch(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBal // cdsBalancer with different inputs and verifies that the CDS watch API on the // provided xdsClient is invoked appropriately. func (s) TestUpdateClientConnState(t *testing.T) { + xdsC := fakeclient.NewClient() + defer xdsC.Close() + tests := []struct { name string ccs balancer.ClientConnState @@ -280,14 +281,14 @@ func (s) TestUpdateClientConnState(t *testing.T) { }, { name: "happy-good-case", - ccs: cdsCCS(clusterName), + ccs: cdsCCS(clusterName, xdsC), wantCluster: clusterName, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - xdsC, cdsB, _, _, cancel := setup(t) + _, cdsB, _, _, cancel := setup(t) defer func() { cancel() cdsB.Close() @@ -324,7 +325,7 @@ func (s) TestUpdateClientConnStateWithSameState(t *testing.T) { }() // This is the same clientConn update sent in setupWithWatch(). - if err := cdsB.UpdateClientConnState(cdsCCS(clusterName)); err != nil { + if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != nil { t.Fatalf("cdsBalancer.UpdateClientConnState failed with error: %v", err) } // The above update should not result in a new watch being registered. @@ -660,7 +661,7 @@ func (s) TestClose(t *testing.T) { // Make sure that the UpdateClientConnState() method on the CDS balancer // returns error. - if err := cdsB.UpdateClientConnState(cdsCCS(clusterName)); err != errBalancerClosed { + if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != errBalancerClosed { t.Fatalf("UpdateClientConnState() after close returned %v, want %v", err, errBalancerClosed) } diff --git a/xds/internal/balancer/cdsbalancer/cluster_handler.go b/xds/internal/balancer/cdsbalancer/cluster_handler.go index c38d1a6c31a..09d945cd0c3 100644 --- a/xds/internal/balancer/cdsbalancer/cluster_handler.go +++ b/xds/internal/balancer/cdsbalancer/cluster_handler.go @@ -40,7 +40,7 @@ type clusterHandler struct { // CDS Balancer cares about is the most recent update. updateChannel chan clusterHandlerUpdate - xdsClient xdsClient + xdsClient xdsclient.XDSClient } func (ch *clusterHandler) updateRootCluster(rootClusterName string) { @@ -112,7 +112,7 @@ type clusterNode struct { // CreateClusterNode creates a cluster node from a given clusterName. This will // also start the watch for that cluster. -func createClusterNode(clusterName string, xdsClient xdsClient, topLevelHandler *clusterHandler) *clusterNode { +func createClusterNode(clusterName string, xdsClient xdsclient.XDSClient, topLevelHandler *clusterHandler) *clusterNode { c := &clusterNode{ clusterHandler: topLevelHandler, } diff --git a/xds/internal/balancer/clusterimpl/balancer_test.go b/xds/internal/balancer/clusterimpl/balancer_test.go index 404dfb22d00..ab3613bec31 100644 --- a/xds/internal/balancer/clusterimpl/balancer_test.go +++ b/xds/internal/balancer/clusterimpl/balancer_test.go @@ -74,9 +74,7 @@ func init() { func TestDropByCategory(t *testing.T) { defer xdsclient.ClearCounterForTesting(testClusterName) xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClient, error) { return xdsC, nil } - defer func() { newXDSClient = oldNewXDSClient }() + defer xdsC.Close() builder := balancer.Get(Name) cc := testutils.NewTestClientConn(t) @@ -89,9 +87,7 @@ func TestDropByCategory(t *testing.T) { dropDenominator = 2 ) if err := b.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: testBackendAddrs, - }, + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), BalancerConfig: &LBConfig{ Cluster: testClusterName, EDSServiceName: testServiceName, @@ -176,9 +172,7 @@ func TestDropByCategory(t *testing.T) { dropDenominator2 = 4 ) if err := b.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: testBackendAddrs, - }, + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), BalancerConfig: &LBConfig{ Cluster: testClusterName, EDSServiceName: testServiceName, @@ -232,9 +226,7 @@ func TestDropByCategory(t *testing.T) { func TestDropCircuitBreaking(t *testing.T) { defer xdsclient.ClearCounterForTesting(testClusterName) xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClient, error) { return xdsC, nil } - defer func() { newXDSClient = oldNewXDSClient }() + defer xdsC.Close() builder := balancer.Get(Name) cc := testutils.NewTestClientConn(t) @@ -243,9 +235,7 @@ func TestDropCircuitBreaking(t *testing.T) { var maxRequest uint32 = 50 if err := b.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: testBackendAddrs, - }, + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), BalancerConfig: &LBConfig{ Cluster: testClusterName, EDSServiceName: testServiceName, @@ -344,9 +334,7 @@ func TestDropCircuitBreaking(t *testing.T) { func TestPickerUpdateAfterClose(t *testing.T) { defer xdsclient.ClearCounterForTesting(testClusterName) xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClient, error) { return xdsC, nil } - defer func() { newXDSClient = oldNewXDSClient }() + defer xdsC.Close() builder := balancer.Get(Name) cc := testutils.NewTestClientConn(t) @@ -354,9 +342,7 @@ func TestPickerUpdateAfterClose(t *testing.T) { var maxRequest uint32 = 50 if err := b.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: testBackendAddrs, - }, + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), BalancerConfig: &LBConfig{ Cluster: testClusterName, EDSServiceName: testServiceName, @@ -389,9 +375,7 @@ func TestPickerUpdateAfterClose(t *testing.T) { func TestClusterNameInAddressAttributes(t *testing.T) { defer xdsclient.ClearCounterForTesting(testClusterName) xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClient, error) { return xdsC, nil } - defer func() { newXDSClient = oldNewXDSClient }() + defer xdsC.Close() builder := balancer.Get(Name) cc := testutils.NewTestClientConn(t) @@ -399,9 +383,7 @@ func TestClusterNameInAddressAttributes(t *testing.T) { defer b.Close() if err := b.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: testBackendAddrs, - }, + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), BalancerConfig: &LBConfig{ Cluster: testClusterName, EDSServiceName: testServiceName, @@ -450,9 +432,7 @@ func TestClusterNameInAddressAttributes(t *testing.T) { const testClusterName2 = "test-cluster-2" var addr2 = resolver.Address{Addr: "2.2.2.2"} if err := b.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: []resolver.Address{addr2}, - }, + ResolverState: xdsclient.SetClient(resolver.State{Addresses: []resolver.Address{addr2}}, xdsC), BalancerConfig: &LBConfig{ Cluster: testClusterName2, EDSServiceName: testServiceName, @@ -480,9 +460,7 @@ func TestClusterNameInAddressAttributes(t *testing.T) { func TestReResolution(t *testing.T) { defer xdsclient.ClearCounterForTesting(testClusterName) xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClient, error) { return xdsC, nil } - defer func() { newXDSClient = oldNewXDSClient }() + defer xdsC.Close() builder := balancer.Get(Name) cc := testutils.NewTestClientConn(t) @@ -490,9 +468,7 @@ func TestReResolution(t *testing.T) { defer b.Close() if err := b.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: testBackendAddrs, - }, + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), BalancerConfig: &LBConfig{ Cluster: testClusterName, EDSServiceName: testServiceName, diff --git a/xds/internal/balancer/clusterimpl/clusterimpl.go b/xds/internal/balancer/clusterimpl/clusterimpl.go index 9f3acafbc92..f5fa7c12589 100644 --- a/xds/internal/balancer/clusterimpl/clusterimpl.go +++ b/xds/internal/balancer/clusterimpl/clusterimpl.go @@ -52,8 +52,6 @@ func init() { balancer.Register(bb{}) } -var newXDSClient func() (xdsClient, error) - type bb struct{} func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { @@ -67,18 +65,7 @@ func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Ba requestCountMax: defaultRequestCountMax, } b.logger = prefixLogger(b) - - if newXDSClient != nil { - // For tests - client, err := newXDSClient() - if err != nil { - b.logger.Errorf("failed to create xds-client: %v", err) - return nil - } - b.xdsClient = client - } go b.run() - b.logger.Infof("Created") return b } @@ -91,13 +78,6 @@ func (bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, err return parseConfig(c) } -// xdsClient contains only the xds_client methods needed by LRS -// balancer. It's defined so we can override xdsclient in tests. -type xdsClient interface { - ReportLoad(server string) (*load.Store, func()) - Close() -} - type clusterImplBalancer struct { balancer.ClientConn @@ -115,7 +95,7 @@ type clusterImplBalancer struct { bOpts balancer.BuildOptions logger *grpclog.PrefixLogger - xdsClient xdsClient + xdsClient xdsclient.XDSClient config *LBConfig childLB balancer.Balancer @@ -328,9 +308,6 @@ func (b *clusterImplBalancer) Close() { b.childLB.Close() b.childLB = nil } - if newXDSClient != nil { - b.xdsClient.Close() - } <-b.done.Done() b.logger.Infof("Shutdown") } diff --git a/xds/internal/balancer/edsbalancer/eds.go b/xds/internal/balancer/edsbalancer/eds.go index ffc46cea469..ea11b2f8a25 100644 --- a/xds/internal/balancer/edsbalancer/eds.go +++ b/xds/internal/balancer/edsbalancer/eds.go @@ -41,19 +41,10 @@ import ( const edsName = "eds_experimental" -// xdsClient contains only the xds_client methods needed by EDS -// balancer. It's defined so we can override xdsclient.New function in tests. -type xdsClient interface { - WatchEndpoints(clusterName string, edsCb func(xdsclient.EndpointsUpdate, error)) (cancel func()) - ReportLoad(server string) (loadStore *load.Store, cancel func()) - Close() -} - var ( newEDSBalancer = func(cc balancer.ClientConn, opts balancer.BuildOptions, enqueueState func(priorityType, balancer.State), lw load.PerClusterReporter, logger *grpclog.PrefixLogger) edsBalancerImplInterface { return newEDSBalancerImpl(cc, opts, enqueueState, lw, logger) } - newXDSClient func() (xdsClient, error) ) func init() { @@ -74,17 +65,6 @@ func (bb) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Bal config: &EDSConfig{}, } x.logger = prefixLogger(x) - - if newXDSClient != nil { - // For tests - client, err := newXDSClient() - if err != nil { - x.logger.Errorf("xds: failed to create xds-client: %v", err) - return nil - } - x.xdsClient = client - } - x.edsImpl = newEDSBalancer(x.cc, opts, x.enqueueChildBalancerState, x.loadWrapper, x.logger) x.logger.Infof("Created") go x.run() @@ -144,7 +124,7 @@ type edsBalancer struct { xdsClientUpdate chan *edsUpdate childPolicyUpdate *buffer.Unbounded - xdsClient xdsClient + xdsClient xdsclient.XDSClient loadWrapper *loadstore.Wrapper config *EDSConfig // may change when passed a different service config edsImpl edsBalancerImplInterface @@ -174,9 +154,6 @@ func (b *edsBalancer) run() { b.edsImpl.updateState(u.priority, u.s) case <-b.closed.Done(): b.cancelWatch() - if newXDSClient != nil { - b.xdsClient.Close() - } b.edsImpl.close() b.logger.Infof("Shutdown") b.done.Fire() diff --git a/xds/internal/balancer/edsbalancer/eds_test.go b/xds/internal/balancer/edsbalancer/eds_test.go index 7e16076751a..c20e8206b9e 100644 --- a/xds/internal/balancer/edsbalancer/eds_test.go +++ b/xds/internal/balancer/edsbalancer/eds_test.go @@ -255,8 +255,6 @@ func waitForNewEDSLB(ctx context.Context, ch *testutils.Channel) (*fakeEDSBalanc // cleanup. func setup(edsLBCh *testutils.Channel) (*fakeclient.Client, func()) { xdsC := fakeclient.NewClientWithName(testBalancerNameFooBar) - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClient, error) { return xdsC, nil } origNewEDSBalancer := newEDSBalancer newEDSBalancer = func(cc balancer.ClientConn, _ balancer.BuildOptions, _ func(priorityType, balancer.State), _ load.PerClusterReporter, _ *grpclog.PrefixLogger) edsBalancerImplInterface { @@ -266,7 +264,7 @@ func setup(edsLBCh *testutils.Channel) (*fakeclient.Client, func()) { } return xdsC, func() { newEDSBalancer = origNewEDSBalancer - newXDSClient = oldNewXDSClient + xdsC.Close() } } @@ -348,6 +346,7 @@ func (s) TestConfigChildPolicyUpdate(t *testing.T) { Config: json.RawMessage("{}"), } if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), BalancerConfig: &EDSConfig{ ChildPolicy: lbCfgA, ClusterName: testEDSClusterName, @@ -377,6 +376,7 @@ func (s) TestConfigChildPolicyUpdate(t *testing.T) { Config: json.RawMessage("{}"), } if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), BalancerConfig: &EDSConfig{ ChildPolicy: lbCfgB, ClusterName: testEDSClusterName, @@ -421,6 +421,7 @@ func (s) TestSubConnStateChange(t *testing.T) { } if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), BalancerConfig: &EDSConfig{EDSServiceName: testServiceName}, }); err != nil { t.Fatalf("edsB.UpdateClientConnState() failed: %v", err) @@ -467,6 +468,7 @@ func (s) TestErrorFromXDSClientUpdate(t *testing.T) { } if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), BalancerConfig: &EDSConfig{EDSServiceName: testServiceName}, }); err != nil { t.Fatal(err) @@ -511,6 +513,7 @@ func (s) TestErrorFromXDSClientUpdate(t *testing.T) { // An update with the same service name should not trigger a new watch. if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), BalancerConfig: &EDSConfig{EDSServiceName: testServiceName}, }); err != nil { t.Fatal(err) @@ -549,6 +552,7 @@ func (s) TestErrorFromResolver(t *testing.T) { } if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), BalancerConfig: &EDSConfig{EDSServiceName: testServiceName}, }); err != nil { t.Fatal(err) @@ -589,6 +593,7 @@ func (s) TestErrorFromResolver(t *testing.T) { // An update with the same service name should trigger a new watch, because // the previous watch was canceled. if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), BalancerConfig: &EDSConfig{EDSServiceName: testServiceName}, }); err != nil { t.Fatal(err) @@ -640,6 +645,7 @@ func (s) TestClientWatchEDS(t *testing.T) { defer cancel() // If eds service name is not set, should watch for cluster name. if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), BalancerConfig: &EDSConfig{ClusterName: "cluster-1"}, }); err != nil { t.Fatal(err) @@ -651,6 +657,7 @@ func (s) TestClientWatchEDS(t *testing.T) { // Update with an non-empty edsServiceName should trigger an EDS watch for // the same. if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), BalancerConfig: &EDSConfig{EDSServiceName: "foobar-1"}, }); err != nil { t.Fatal(err) @@ -664,6 +671,7 @@ func (s) TestClientWatchEDS(t *testing.T) { // registered watch will be cancelled, which will result in an EDS request // with no resource names being sent to the server. if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), BalancerConfig: &EDSConfig{EDSServiceName: "foobar-2"}, }); err != nil { t.Fatal(err) @@ -677,7 +685,7 @@ func (s) TestClientWatchEDS(t *testing.T) { // service name from an update's config. func (s) TestCounterUpdate(t *testing.T) { edsLBCh := testutils.NewChannel() - _, cleanup := setup(edsLBCh) + xdsC, cleanup := setup(edsLBCh) defer cleanup() builder := balancer.Get(edsName) @@ -690,6 +698,7 @@ func (s) TestCounterUpdate(t *testing.T) { var testCountMax uint32 = 100 // Update should trigger counter update with provided service name. if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), BalancerConfig: &EDSConfig{ ClusterName: "foobar-1", MaxConcurrentRequests: &testCountMax, @@ -724,6 +733,7 @@ func (s) TestClusterNameUpdateInAddressAttributes(t *testing.T) { // Update should trigger counter update with provided service name. if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), BalancerConfig: &EDSConfig{ ClusterName: "foobar-1", }, @@ -743,6 +753,7 @@ func (s) TestClusterNameUpdateInAddressAttributes(t *testing.T) { // Update should trigger counter update with provided service name. if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), BalancerConfig: &EDSConfig{ ClusterName: "foobar-2", }, diff --git a/xds/internal/balancer/edsbalancer/xds_lrs_test.go b/xds/internal/balancer/edsbalancer/xds_lrs_test.go index d5b40dd98d3..3dcbf5e259c 100644 --- a/xds/internal/balancer/edsbalancer/xds_lrs_test.go +++ b/xds/internal/balancer/edsbalancer/xds_lrs_test.go @@ -25,7 +25,9 @@ import ( "testing" "google.golang.org/grpc/balancer" + "google.golang.org/grpc/resolver" "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" ) // TestXDSLoadReporting verifies that the edsBalancer starts the loadReport @@ -33,9 +35,7 @@ import ( // server (empty string). func (s) TestXDSLoadReporting(t *testing.T) { xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClient, error) { return xdsC, nil } - defer func() { newXDSClient = oldNewXDSClient }() + defer xdsC.Close() builder := balancer.Get(edsName) edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{}) @@ -45,6 +45,7 @@ func (s) TestXDSLoadReporting(t *testing.T) { defer edsB.Close() if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), BalancerConfig: &EDSConfig{ EDSServiceName: testEDSClusterName, LrsLoadReportingServerName: new(string), diff --git a/xds/internal/balancer/lrs/balancer.go b/xds/internal/balancer/lrs/balancer.go index 75a8cbb0dd7..ed7fb38c854 100644 --- a/xds/internal/balancer/lrs/balancer.go +++ b/xds/internal/balancer/lrs/balancer.go @@ -36,8 +36,6 @@ func init() { balancer.Register(bb{}) } -var newXDSClient func() (xdsClient, error) - // Name is the name of the LRS balancer. const Name = "lrs_experimental" @@ -50,17 +48,6 @@ func (bb) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Bal } b.logger = prefixLogger(b) b.logger.Infof("Created") - - if newXDSClient != nil { - // For tests - client, err := newXDSClient() - if err != nil { - b.logger.Errorf("failed to create xds-client: %v", err) - return nil - } - b.xdsClient = newXDSClientWrapper(client) - } - return b } @@ -169,15 +156,8 @@ func (ccw *ccWrapper) UpdateState(s balancer.State) { ccw.ClientConn.UpdateState(s) } -// xdsClient contains only the xds_client methods needed by LRS -// balancer. It's defined so we can override xdsclient in tests. -type xdsClient interface { - ReportLoad(server string) (*load.Store, func()) - Close() -} - type xdsClientWrapper struct { - c xdsClient + c xdsclient.XDSClient cancelLoadReport func() clusterName string edsServiceName string @@ -187,7 +167,7 @@ type xdsClientWrapper struct { loadWrapper *loadstore.Wrapper } -func newXDSClientWrapper(c xdsClient) *xdsClientWrapper { +func newXDSClientWrapper(c xdsclient.XDSClient) *xdsClientWrapper { return &xdsClientWrapper{ c: c, loadWrapper: loadstore.NewWrapper(), @@ -256,7 +236,4 @@ func (w *xdsClientWrapper) close() { w.cancelLoadReport() w.cancelLoadReport = nil } - if newXDSClient != nil { - w.c.Close() - } } diff --git a/xds/internal/balancer/lrs/balancer_test.go b/xds/internal/balancer/lrs/balancer_test.go index 9ffa2894dad..c0ec9cc41dd 100644 --- a/xds/internal/balancer/lrs/balancer_test.go +++ b/xds/internal/balancer/lrs/balancer_test.go @@ -35,6 +35,7 @@ import ( xdsinternal "google.golang.org/grpc/xds/internal" "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" ) const defaultTestTimeout = 1 * time.Second @@ -55,9 +56,7 @@ var ( // server (empty string). func TestLoadReporting(t *testing.T) { xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClient, error) { return xdsC, nil } - defer func() { newXDSClient = oldNewXDSClient }() + defer xdsC.Close() builder := balancer.Get(Name) cc := testutils.NewTestClientConn(t) @@ -65,9 +64,7 @@ func TestLoadReporting(t *testing.T) { defer lrsB.Close() if err := lrsB.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: testBackendAddrs, - }, + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), BalancerConfig: &LBConfig{ ClusterName: testClusterName, EDSServiceName: testServiceName, diff --git a/xds/internal/resolver/watch_service.go b/xds/internal/resolver/watch_service.go index 591cc383393..bea5bbcda14 100644 --- a/xds/internal/resolver/watch_service.go +++ b/xds/internal/resolver/watch_service.go @@ -54,7 +54,7 @@ type ldsConfig struct { // Note that during race (e.g. an xDS response is received while the user is // calling cancel()), there's a small window where the callback can be called // after the watcher is canceled. The caller needs to handle this case. -func watchService(c xdsClient, serviceName string, cb func(serviceUpdate, error), logger *grpclog.PrefixLogger) (cancel func()) { +func watchService(c xdsclient.XDSClient, serviceName string, cb func(serviceUpdate, error), logger *grpclog.PrefixLogger) (cancel func()) { w := &serviceUpdateWatcher{ logger: logger, c: c, @@ -70,7 +70,7 @@ func watchService(c xdsClient, serviceName string, cb func(serviceUpdate, error) // callback at the right time. type serviceUpdateWatcher struct { logger *grpclog.PrefixLogger - c xdsClient + c xdsclient.XDSClient serviceName string ldsCancel func() serviceCb func(serviceUpdate, error) diff --git a/xds/internal/resolver/xds_resolver.go b/xds/internal/resolver/xds_resolver.go index a6a013698ac..19ee01773e8 100644 --- a/xds/internal/resolver/xds_resolver.go +++ b/xds/internal/resolver/xds_resolver.go @@ -27,10 +27,8 @@ import ( "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/pretty" - "google.golang.org/grpc/resolver" - "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" - iresolver "google.golang.org/grpc/internal/resolver" + "google.golang.org/grpc/resolver" "google.golang.org/grpc/xds/internal/xdsclient" ) @@ -41,21 +39,21 @@ const xdsScheme = "xds" // the same time. func NewBuilder(config []byte) (resolver.Builder, error) { return &xdsResolverBuilder{ - newXDSClient: func() (xdsClient, error) { + newXDSClient: func() (xdsclient.XDSClient, error) { return xdsclient.NewClientWithBootstrapContents(config) }, }, nil } // For overriding in unittests. -var newXDSClient = func() (xdsClient, error) { return xdsclient.New() } +var newXDSClient = func() (xdsclient.XDSClient, error) { return xdsclient.New() } func init() { resolver.Register(&xdsResolverBuilder{}) } type xdsResolverBuilder struct { - newXDSClient func() (xdsClient, error) + newXDSClient func() (xdsclient.XDSClient, error) } // Build helps implement the resolver.Builder interface. @@ -119,15 +117,6 @@ func (*xdsResolverBuilder) Scheme() string { return xdsScheme } -// xdsClient contains methods from xdsClient.Client which are used by -// the resolver. This will be faked out in unittests. -type xdsClient interface { - WatchListener(serviceName string, cb func(xdsclient.ListenerUpdate, error)) func() - WatchRouteConfig(routeName string, cb func(xdsclient.RouteConfigUpdate, error)) func() - BootstrapConfig() *bootstrap.Config - Close() -} - // suWithError wraps the ServiceUpdate and error received through a watch API // callback, so that it can pushed onto the update channel as a single entity. type suWithError struct { @@ -149,7 +138,7 @@ type xdsResolver struct { logger *grpclog.PrefixLogger // The underlying xdsClient which performs all xDS requests and responses. - client xdsClient + client xdsclient.XDSClient // A channel for the watch API callback to write service updates on to. The // updates are read by the run goroutine and passed on to the ClientConn. updateCh chan suWithError @@ -196,14 +185,7 @@ func (r *xdsResolver) sendNewServiceConfig(cs *configSelector) bool { state := iresolver.SetConfigSelector(resolver.State{ ServiceConfig: r.cc.ParseServiceConfig(string(sc)), }, cs) - - // Include the xds client for the LB policies to use. For unit tests, - // r.client may not be a full *xdsclient.Client, but it will always be in - // production. - if c, ok := r.client.(*xdsclient.Client); ok { - state = xdsclient.SetClient(state, c) - } - r.cc.UpdateState(state) + r.cc.UpdateState(xdsclient.SetClient(state, r.client)) return true } diff --git a/xds/internal/resolver/xds_resolver_test.go b/xds/internal/resolver/xds_resolver_test.go index d588ff157cd..a4192099827 100644 --- a/xds/internal/resolver/xds_resolver_test.go +++ b/xds/internal/resolver/xds_resolver_test.go @@ -114,19 +114,19 @@ func newTestClientConn() *testClientConn { func (s) TestResolverBuilder(t *testing.T) { tests := []struct { name string - xdsClientFunc func() (xdsClient, error) + xdsClientFunc func() (xdsclient.XDSClient, error) wantErr bool }{ { name: "simple-good", - xdsClientFunc: func() (xdsClient, error) { + xdsClientFunc: func() (xdsclient.XDSClient, error) { return fakeclient.NewClient(), nil }, wantErr: false, }, { name: "newXDSClient-throws-error", - xdsClientFunc: func() (xdsClient, error) { + xdsClientFunc: func() (xdsclient.XDSClient, error) { return nil, errors.New("newXDSClient-throws-error") }, wantErr: true, @@ -167,7 +167,7 @@ func (s) TestResolverBuilder_xdsCredsBootstrapMismatch(t *testing.T) { // Fake out the xdsClient creation process by providing a fake, which does // not have any certificate provider configuration. oldClientMaker := newXDSClient - newXDSClient = func() (xdsClient, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { fc := fakeclient.NewClient() fc.SetBootstrapConfig(&bootstrap.Config{}) return fc, nil @@ -194,7 +194,7 @@ func (s) TestResolverBuilder_xdsCredsBootstrapMismatch(t *testing.T) { } type setupOpts struct { - xdsClientFunc func() (xdsClient, error) + xdsClientFunc func() (xdsclient.XDSClient, error) } func testSetup(t *testing.T, opts setupOpts) (*xdsResolver, *testClientConn, func()) { @@ -254,7 +254,7 @@ func waitForWatchRouteConfig(ctx context.Context, t *testing.T, xdsC *fakeclient func (s) TestXDSResolverWatchCallbackAfterClose(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClient, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer cancel() @@ -286,7 +286,7 @@ func (s) TestXDSResolverWatchCallbackAfterClose(t *testing.T) { func (s) TestXDSResolverCloseClosesXDSClient(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, _, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClient, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer cancel() xdsR.Close() @@ -300,7 +300,7 @@ func (s) TestXDSResolverCloseClosesXDSClient(t *testing.T) { func (s) TestXDSResolverBadServiceUpdate(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClient, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer xdsR.Close() defer cancel() @@ -326,7 +326,7 @@ func (s) TestXDSResolverBadServiceUpdate(t *testing.T) { func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClient, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer xdsR.Close() defer cancel() @@ -460,7 +460,7 @@ func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) { func (s) TestXDSResolverRemovedWithRPCs(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClient, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer cancel() defer xdsR.Close() @@ -520,7 +520,7 @@ func (s) TestXDSResolverRemovedWithRPCs(t *testing.T) { func (s) TestXDSResolverRemovedResource(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClient, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer cancel() defer xdsR.Close() @@ -628,7 +628,7 @@ func (s) TestXDSResolverRemovedResource(t *testing.T) { func (s) TestXDSResolverWRR(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClient, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer xdsR.Close() defer cancel() @@ -689,7 +689,7 @@ func (s) TestXDSResolverMaxStreamDuration(t *testing.T) { defer func(old bool) { env.TimeoutSupport = old }(env.TimeoutSupport) xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClient, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer xdsR.Close() defer cancel() @@ -792,7 +792,7 @@ func (s) TestXDSResolverMaxStreamDuration(t *testing.T) { func (s) TestXDSResolverDelayedOnCommitted(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClient, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer xdsR.Close() defer cancel() @@ -941,7 +941,7 @@ func (s) TestXDSResolverDelayedOnCommitted(t *testing.T) { func (s) TestXDSResolverGoodUpdateAfterError(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClient, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer xdsR.Close() defer cancel() @@ -995,7 +995,7 @@ func (s) TestXDSResolverGoodUpdateAfterError(t *testing.T) { func (s) TestXDSResolverResourceNotFoundError(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClient, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer xdsR.Close() defer cancel() @@ -1041,7 +1041,7 @@ func (s) TestXDSResolverResourceNotFoundError(t *testing.T) { func (s) TestXDSResolverMultipleLDSUpdates(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClient, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer xdsR.Close() defer cancel() @@ -1216,7 +1216,7 @@ func (s) TestXDSResolverHTTPFilters(t *testing.T) { t.Run(tc.name, func(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClient, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer xdsR.Close() defer cancel() diff --git a/xds/internal/testutils/fakeclient/client.go b/xds/internal/testutils/fakeclient/client.go index 37e84f998b9..2538b59255c 100644 --- a/xds/internal/testutils/fakeclient/client.go +++ b/xds/internal/testutils/fakeclient/client.go @@ -32,6 +32,11 @@ import ( // Client is a fake implementation of an xds client. It exposes a bunch of // channels to signal the occurrence of various events. type Client struct { + // Embed XDSClient so this fake client implements the interface, but it's + // never set (it's always nil). This may cause nil panic since not all the + // methods are implemented. + xdsclient.XDSClient + name string ldsWatchCh *testutils.Channel rdsWatchCh *testutils.Channel diff --git a/xds/internal/xdsclient/attributes.go b/xds/internal/xdsclient/attributes.go index 99060177e1e..d2357df0727 100644 --- a/xds/internal/xdsclient/attributes.go +++ b/xds/internal/xdsclient/attributes.go @@ -17,20 +17,43 @@ package xdsclient -import "google.golang.org/grpc/resolver" +import ( + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" + "google.golang.org/grpc/xds/internal/xdsclient/load" +) type clientKeyType string const clientKey = clientKeyType("grpc.xds.internal.client.Client") +// XDSClient is a full fledged gRPC client which queries a set of discovery APIs +// (collectively termed as xDS) on a remote management server, to discover +// various dynamic resources. +type XDSClient interface { + WatchListener(string, func(ListenerUpdate, error)) func() + WatchRouteConfig(string, func(RouteConfigUpdate, error)) func() + WatchCluster(string, func(ClusterUpdate, error)) func() + WatchEndpoints(clusterName string, edsCb func(EndpointsUpdate, error)) (cancel func()) + ReportLoad(server string) (*load.Store, func()) + + DumpLDS() (string, map[string]UpdateWithMD) + DumpRDS() (string, map[string]UpdateWithMD) + DumpCDS() (string, map[string]UpdateWithMD) + DumpEDS() (string, map[string]UpdateWithMD) + + BootstrapConfig() *bootstrap.Config + Close() +} + // FromResolverState returns the Client from state, or nil if not present. -func FromResolverState(state resolver.State) *Client { - cs, _ := state.Attributes.Value(clientKey).(*Client) +func FromResolverState(state resolver.State) XDSClient { + cs, _ := state.Attributes.Value(clientKey).(XDSClient) return cs } // SetClient sets c in state and returns the new state. -func SetClient(state resolver.State, c *Client) resolver.State { +func SetClient(state resolver.State, c XDSClient) resolver.State { state.Attributes = state.Attributes.WithValues(clientKey, c) return state } diff --git a/xds/internal/xdsclient/client.go b/xds/internal/xdsclient/client.go index ac832f205d5..13ef973807c 100644 --- a/xds/internal/xdsclient/client.go +++ b/xds/internal/xdsclient/client.go @@ -579,7 +579,7 @@ func newWithConfig(config *bootstrap.Config, watchExpiryTimeout time.Duration) ( // BootstrapConfig returns the configuration read from the bootstrap file. // Callers must treat the return value as read-only. -func (c *Client) BootstrapConfig() *bootstrap.Config { +func (c *clientRefCounted) BootstrapConfig() *bootstrap.Config { return c.config } diff --git a/xds/internal/xdsclient/client_test.go b/xds/internal/xdsclient/client_test.go index c1d4b38e576..12590408e6c 100644 --- a/xds/internal/xdsclient/client_test.go +++ b/xds/internal/xdsclient/client_test.go @@ -263,7 +263,7 @@ func (s) TestClientNewSingleton(t *testing.T) { defer cleanup() // The first New(). Should create a Client and a new APIClient. - client, err := New() + client, err := newRefCounted() if err != nil { t.Fatalf("failed to create client: %v", err) } @@ -280,7 +280,7 @@ func (s) TestClientNewSingleton(t *testing.T) { // and should not create new API client. const count = 9 for i := 0; i < count; i++ { - tc, terr := New() + tc, terr := newRefCounted() if terr != nil { client.Close() t.Fatalf("%d-th call to New() failed with error: %v", i, terr) @@ -324,7 +324,7 @@ func (s) TestClientNewSingleton(t *testing.T) { // Call New() again after the previous Client is actually closed. Should // create a Client and a new APIClient. - client2, err2 := New() + client2, err2 := newRefCounted() if err2 != nil { t.Fatalf("failed to create client: %v", err) } diff --git a/xds/internal/xdsclient/singleton.go b/xds/internal/xdsclient/singleton.go index 8d0e10f2c31..f045790e2a4 100644 --- a/xds/internal/xdsclient/singleton.go +++ b/xds/internal/xdsclient/singleton.go @@ -32,18 +32,14 @@ const defaultWatchExpiryTimeout = 15 * time.Second // This is the Client returned by New(). It contains one client implementation, // and maintains the refcount. -var singletonClient = &Client{} +var singletonClient = &clientRefCounted{} // To override in tests. var bootstrapNewConfig = bootstrap.NewConfig -// Client is a full fledged gRPC client which queries a set of discovery APIs -// (collectively termed as xDS) on a remote management server, to discover -// various dynamic resources. -// -// The xds client is a singleton. It will be shared by the xds resolver and +// clientRefCounted is ref-counted, and to be shared by the xds resolver and // balancer implementations, across multiple ClientConns and Servers. -type Client struct { +type clientRefCounted struct { *clientImpl // This mu protects all the fields, including the embedded clientImpl above. @@ -60,7 +56,18 @@ type Client struct { // Note that the first invocation of New() or NewWithConfig() sets the client // singleton. The following calls will return the singleton xds client without // checking or using the config. -func New() (*Client, error) { +func New() (XDSClient, error) { + // This cannot just return newRefCounted(), because in error cases, the + // returned nil is a typed nil (*clientRefCounted), which may cause nil + // checks fail. + c, err := newRefCounted() + if err != nil { + return nil, err + } + return c, nil +} + +func newRefCounted() (*clientRefCounted, error) { singletonClient.mu.Lock() defer singletonClient.mu.Unlock() // If the client implementation was created, increment ref count and return @@ -96,7 +103,7 @@ func New() (*Client, error) { // // This function is internal only, for c2p resolver and testing to use. DO NOT // use this elsewhere. Use New() instead. -func NewWithConfig(config *bootstrap.Config) (*Client, error) { +func NewWithConfig(config *bootstrap.Config) (XDSClient, error) { singletonClient.mu.Lock() defer singletonClient.mu.Unlock() // If the client implementation was created, increment ref count and return @@ -120,7 +127,7 @@ func NewWithConfig(config *bootstrap.Config) (*Client, error) { // Close closes the client. It does ref count of the xds client implementation, // and closes the gRPC connection to the management server when ref count // reaches 0. -func (c *Client) Close() { +func (c *clientRefCounted) Close() { c.mu.Lock() defer c.mu.Unlock() c.refCount-- @@ -136,18 +143,18 @@ func (c *Client) Close() { // // Note that this function doesn't set the singleton, so that the testing states // don't leak. -func NewWithConfigForTesting(config *bootstrap.Config, watchExpiryTimeout time.Duration) (*Client, error) { +func NewWithConfigForTesting(config *bootstrap.Config, watchExpiryTimeout time.Duration) (XDSClient, error) { cl, err := newWithConfig(config, watchExpiryTimeout) if err != nil { return nil, err } - return &Client{clientImpl: cl, refCount: 1}, nil + return &clientRefCounted{clientImpl: cl, refCount: 1}, nil } // NewClientWithBootstrapContents returns an xds client for this config, // separate from the global singleton. This should be used for testing // purposes only. -func NewClientWithBootstrapContents(contents []byte) (*Client, error) { +func NewClientWithBootstrapContents(contents []byte) (XDSClient, error) { // Normalize the contents buf := bytes.Buffer{} err := json.Indent(&buf, contents, "", "") @@ -180,12 +187,12 @@ func NewClientWithBootstrapContents(contents []byte) (*Client, error) { return nil, err } - c := &Client{clientImpl: cImpl, refCount: 1} + c := &clientRefCounted{clientImpl: cImpl, refCount: 1} clients[string(contents)] = c return c, nil } var ( - clients = map[string]*Client{} + clients = map[string]*clientRefCounted{} clientsMu sync.Mutex ) diff --git a/xds/internal/xdsclient/tests/dump_test.go b/xds/internal/xdsclient/tests/dump_test.go index 541f5901c12..64c78f67285 100644 --- a/xds/internal/xdsclient/tests/dump_test.go +++ b/xds/internal/xdsclient/tests/dump_test.go @@ -85,6 +85,7 @@ func (s) TestLDSConfigDump(t *testing.T) { t.Fatalf("failed to create client: %v", err) } defer client.Close() + updateHandler := client.(xdsclient.UpdateHandler) // Expected unknown. if err := compareDump(client.DumpLDS, "", map[string]xdsclient.UpdateWithMD{}); err != nil { @@ -111,7 +112,7 @@ func (s) TestLDSConfigDump(t *testing.T) { Raw: r, } } - client.NewListeners(update0, xdsclient.UpdateMetadata{Version: testVersion}) + updateHandler.NewListeners(update0, xdsclient.UpdateMetadata{Version: testVersion}) // Expect ACK. if err := compareDump(client.DumpLDS, testVersion, want0); err != nil { @@ -120,7 +121,7 @@ func (s) TestLDSConfigDump(t *testing.T) { const nackVersion = "lds-version-nack" var nackErr = fmt.Errorf("lds nack error") - client.NewListeners( + updateHandler.NewListeners( map[string]xdsclient.ListenerUpdate{ ldsTargets[0]: {}, }, @@ -195,6 +196,7 @@ func (s) TestRDSConfigDump(t *testing.T) { t.Fatalf("failed to create client: %v", err) } defer client.Close() + updateHandler := client.(xdsclient.UpdateHandler) // Expected unknown. if err := compareDump(client.DumpRDS, "", map[string]xdsclient.UpdateWithMD{}); err != nil { @@ -221,7 +223,7 @@ func (s) TestRDSConfigDump(t *testing.T) { Raw: r, } } - client.NewRouteConfigs(update0, xdsclient.UpdateMetadata{Version: testVersion}) + updateHandler.NewRouteConfigs(update0, xdsclient.UpdateMetadata{Version: testVersion}) // Expect ACK. if err := compareDump(client.DumpRDS, testVersion, want0); err != nil { @@ -230,7 +232,7 @@ func (s) TestRDSConfigDump(t *testing.T) { const nackVersion = "rds-version-nack" var nackErr = fmt.Errorf("rds nack error") - client.NewRouteConfigs( + updateHandler.NewRouteConfigs( map[string]xdsclient.RouteConfigUpdate{ rdsTargets[0]: {}, }, @@ -305,6 +307,7 @@ func (s) TestCDSConfigDump(t *testing.T) { t.Fatalf("failed to create client: %v", err) } defer client.Close() + updateHandler := client.(xdsclient.UpdateHandler) // Expected unknown. if err := compareDump(client.DumpCDS, "", map[string]xdsclient.UpdateWithMD{}); err != nil { @@ -331,7 +334,7 @@ func (s) TestCDSConfigDump(t *testing.T) { Raw: r, } } - client.NewClusters(update0, xdsclient.UpdateMetadata{Version: testVersion}) + updateHandler.NewClusters(update0, xdsclient.UpdateMetadata{Version: testVersion}) // Expect ACK. if err := compareDump(client.DumpCDS, testVersion, want0); err != nil { @@ -340,7 +343,7 @@ func (s) TestCDSConfigDump(t *testing.T) { const nackVersion = "cds-version-nack" var nackErr = fmt.Errorf("cds nack error") - client.NewClusters( + updateHandler.NewClusters( map[string]xdsclient.ClusterUpdate{ cdsTargets[0]: {}, }, @@ -401,6 +404,7 @@ func (s) TestEDSConfigDump(t *testing.T) { t.Fatalf("failed to create client: %v", err) } defer client.Close() + updateHandler := client.(xdsclient.UpdateHandler) // Expected unknown. if err := compareDump(client.DumpEDS, "", map[string]xdsclient.UpdateWithMD{}); err != nil { @@ -427,7 +431,7 @@ func (s) TestEDSConfigDump(t *testing.T) { Raw: r, } } - client.NewEndpoints(update0, xdsclient.UpdateMetadata{Version: testVersion}) + updateHandler.NewEndpoints(update0, xdsclient.UpdateMetadata{Version: testVersion}) // Expect ACK. if err := compareDump(client.DumpEDS, testVersion, want0); err != nil { @@ -436,7 +440,7 @@ func (s) TestEDSConfigDump(t *testing.T) { const nackVersion = "eds-version-nack" var nackErr = fmt.Errorf("eds nack error") - client.NewEndpoints( + updateHandler.NewEndpoints( map[string]xdsclient.EndpointsUpdate{ edsTargets[0]: {}, }, diff --git a/xds/server.go b/xds/server.go index 989859bc65c..cfbea1a1bca 100644 --- a/xds/server.go +++ b/xds/server.go @@ -35,14 +35,13 @@ import ( "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/xds/internal/server" "google.golang.org/grpc/xds/internal/xdsclient" - "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" ) const serverPrefix = "[xds-server %p] " var ( // These new functions will be overridden in unit tests. - newXDSClient = func() (xdsClient, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { return xdsclient.New() } newGRPCServer = func(opts ...grpc.ServerOption) grpcServer { @@ -58,14 +57,6 @@ func prefixLogger(p *GRPCServer) *internalgrpclog.PrefixLogger { return internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(serverPrefix, p)) } -// xdsClient contains methods from xdsClient.Client which are used by -// the server. This is useful for overriding in unit tests. -type xdsClient interface { - WatchListener(string, func(xdsclient.ListenerUpdate, error)) func() - BootstrapConfig() *bootstrap.Config - Close() -} - // grpcServer contains methods from grpc.Server which are used by the // GRPCServer type here. This is useful for overriding in unit tests. type grpcServer interface { @@ -90,7 +81,7 @@ type GRPCServer struct { // beginning of Serve(), where we have to decide if we have to create a // client or use an existing one. clientMu sync.Mutex - xdsC xdsClient + xdsC xdsclient.XDSClient } // NewGRPCServer creates an xDS-enabled gRPC server using the passed in opts. @@ -156,7 +147,7 @@ func (s *GRPCServer) initXDSClient() error { newXDSClient := newXDSClient if s.opts.bootstrapContents != nil { - newXDSClient = func() (xdsClient, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { return xdsclient.NewClientWithBootstrapContents(s.opts.bootstrapContents) } } diff --git a/xds/server_test.go b/xds/server_test.go index 27a33da091d..45df8b76fca 100644 --- a/xds/server_test.go +++ b/xds/server_test.go @@ -247,7 +247,7 @@ func (p *fakeProvider) Close() { func setupOverrides() (*fakeGRPCServer, *testutils.Channel, func()) { clientCh := testutils.NewChannel() origNewXDSClient := newXDSClient - newXDSClient = func() (xdsClient, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { c := fakeclient.NewClient() c.SetBootstrapConfig(&bootstrap.Config{ BalancerName: "dummyBalancer", @@ -277,7 +277,7 @@ func setupOverrides() (*fakeGRPCServer, *testutils.Channel, func()) { func setupOverridesForXDSCreds(includeCertProviderCfg bool) (*testutils.Channel, func()) { clientCh := testutils.NewChannel() origNewXDSClient := newXDSClient - newXDSClient = func() (xdsClient, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { c := fakeclient.NewClient() bc := &bootstrap.Config{ BalancerName: "dummyBalancer", @@ -544,7 +544,7 @@ func (s) TestServeBootstrapConfigInvalid(t *testing.T) { // xdsClient with the specified bootstrap configuration. clientCh := testutils.NewChannel() origNewXDSClient := newXDSClient - newXDSClient = func() (xdsClient, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { c := fakeclient.NewClient() c.SetBootstrapConfig(test.bootstrapConfig) clientCh.Send(c) @@ -587,7 +587,7 @@ func (s) TestServeBootstrapConfigInvalid(t *testing.T) { // verifies that Server() exits with a non-nil error. func (s) TestServeNewClientFailure(t *testing.T) { origNewXDSClient := newXDSClient - newXDSClient = func() (xdsClient, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { return nil, errors.New("xdsClient creation failed") } defer func() { newXDSClient = origNewXDSClient }()