diff --git a/balancer/grpclb/grpclb.go b/balancer/grpclb/grpclb.go index adf59611160..fe423af182a 100644 --- a/balancer/grpclb/grpclb.go +++ b/balancer/grpclb/grpclb.go @@ -135,6 +135,7 @@ func (b *lbBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) bal lb := &lbBalancer{ cc: newLBCacheClientConn(cc), + dialTarget: opt.Target.Endpoint, target: opt.Target.Endpoint, opt: opt, fallbackTimeout: b.fallbackTimeout, @@ -164,9 +165,10 @@ func (b *lbBuilder) Build(cc balancer.ClientConn, opt balancer.BuildOptions) bal } type lbBalancer struct { - cc *lbCacheClientConn - target string - opt balancer.BuildOptions + cc *lbCacheClientConn + dialTarget string // user's dial target + target string // same as dialTarget unless overridden in service config + opt balancer.BuildOptions usePickFirst bool @@ -398,6 +400,30 @@ func (lb *lbBalancer) handleServiceConfig(gc *grpclbServiceConfig) { lb.mu.Lock() defer lb.mu.Unlock() + // grpclb uses the user's dial target to populate the `Name` field of the + // `InitialLoadBalanceRequest` message sent to the remote balancer. But when + // grpclb is used a child policy in the context of RLS, we want the `Name` + // field to be populated with the value received from the RLS server. To + // support this use case, an optional "target_name" field has been added to + // the grpclb LB policy's config. If specified, it overrides the name of + // the target to be sent to the remote balancer; if not, the target to be + // sent to the balancer will continue to be obtained from the target URI + // passed to the gRPC client channel. Whenever that target to be sent to the + // balancer is updated, we need to restart the stream to the balancer as + // this target is sent in the first message on the stream. + if gc != nil { + target := lb.dialTarget + if gc.TargetName != "" { + target = gc.TargetName + } + if target != lb.target { + lb.target = target + if lb.ccRemoteLB != nil { + lb.ccRemoteLB.cancelRemoteBalancerCall() + } + } + } + newUsePickFirst := childIsPickFirst(gc) if lb.usePickFirst == newUsePickFirst { return diff --git a/balancer/grpclb/grpclb_config.go b/balancer/grpclb/grpclb_config.go index aac3719631b..b4e23dee017 100644 --- a/balancer/grpclb/grpclb_config.go +++ b/balancer/grpclb/grpclb_config.go @@ -34,6 +34,7 @@ const ( type grpclbServiceConfig struct { serviceconfig.LoadBalancingConfig ChildPolicy *[]map[string]json.RawMessage + TargetName string } func (b *lbBuilder) ParseConfig(lbConfig json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { diff --git a/balancer/grpclb/grpclb_config_test.go b/balancer/grpclb/grpclb_config_test.go index 5a45de90494..0db2299157e 100644 --- a/balancer/grpclb/grpclb_config_test.go +++ b/balancer/grpclb/grpclb_config_test.go @@ -20,52 +20,68 @@ package grpclb import ( "encoding/json" - "errors" - "fmt" - "reflect" - "strings" "testing" + "github.com/google/go-cmp/cmp" "google.golang.org/grpc/serviceconfig" ) func (s) TestParse(t *testing.T) { tests := []struct { name string - s string + sc string want serviceconfig.LoadBalancingConfig - wantErr error + wantErr bool }{ { name: "empty", - s: "", + sc: "", want: nil, - wantErr: errors.New("unexpected end of JSON input"), + wantErr: true, }, { name: "success1", - s: `{"childPolicy":[{"pick_first":{}}]}`, + sc: ` +{ + "childPolicy": [ + {"pick_first":{}} + ], + "targetName": "foo-service" +}`, want: &grpclbServiceConfig{ ChildPolicy: &[]map[string]json.RawMessage{ {"pick_first": json.RawMessage("{}")}, }, + TargetName: "foo-service", }, }, { name: "success2", - s: `{"childPolicy":[{"round_robin":{}},{"pick_first":{}}]}`, + sc: ` +{ + "childPolicy": [ + {"round_robin":{}}, + {"pick_first":{}} + ], + "targetName": "foo-service" +}`, want: &grpclbServiceConfig{ ChildPolicy: &[]map[string]json.RawMessage{ {"round_robin": json.RawMessage("{}")}, {"pick_first": json.RawMessage("{}")}, }, + TargetName: "foo-service", }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got, err := (&lbBuilder{}).ParseConfig(json.RawMessage(tt.s)); !reflect.DeepEqual(got, tt.want) || !strings.Contains(fmt.Sprint(err), fmt.Sprint(tt.wantErr)) { - t.Errorf("parseFullServiceConfig() = %+v, %+v, want %+v, ", got, err, tt.want, tt.wantErr) + got, err := (&lbBuilder{}).ParseConfig(json.RawMessage(tt.sc)) + if (err != nil) != (tt.wantErr) { + t.Fatalf("ParseConfig(%q) returned error: %v, wantErr: %v", tt.sc, err, tt.wantErr) + } + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatalf("ParseConfig(%q) returned unexpected difference (-want +got):\n%s", tt.sc, diff) } }) } diff --git a/balancer/grpclb/grpclb_remote_balancer.go b/balancer/grpclb/grpclb_remote_balancer.go index 5ac8d86bd57..0210c012d7b 100644 --- a/balancer/grpclb/grpclb_remote_balancer.go +++ b/balancer/grpclb/grpclb_remote_balancer.go @@ -206,6 +206,9 @@ type remoteBalancerCCWrapper struct { backoff backoff.Strategy done chan struct{} + streamMu sync.Mutex + streamCancel func() + // waitgroup to wait for all goroutines to exit. wg sync.WaitGroup } @@ -319,10 +322,8 @@ func (ccw *remoteBalancerCCWrapper) sendLoadReport(s *balanceLoadClientStream, i } } -func (ccw *remoteBalancerCCWrapper) callRemoteBalancer() (backoff bool, _ error) { +func (ccw *remoteBalancerCCWrapper) callRemoteBalancer(ctx context.Context) (backoff bool, _ error) { lbClient := &loadBalancerClient{cc: ccw.cc} - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() stream, err := lbClient.BalanceLoad(ctx, grpc.WaitForReady(true)) if err != nil { return true, fmt.Errorf("grpclb: failed to perform RPC to the remote balancer %v", err) @@ -362,11 +363,43 @@ func (ccw *remoteBalancerCCWrapper) callRemoteBalancer() (backoff bool, _ error) return false, ccw.readServerList(stream) } +// cancelRemoteBalancerCall cancels the context used by the stream to the remote +// balancer. watchRemoteBalancer() takes care of restarting this call after the +// stream fails. +func (ccw *remoteBalancerCCWrapper) cancelRemoteBalancerCall() { + ccw.streamMu.Lock() + if ccw.streamCancel != nil { + ccw.streamCancel() + ccw.streamCancel = nil + } + ccw.streamMu.Unlock() +} + func (ccw *remoteBalancerCCWrapper) watchRemoteBalancer() { - defer ccw.wg.Done() + defer func() { + ccw.wg.Done() + ccw.streamMu.Lock() + if ccw.streamCancel != nil { + // This is to make sure that we don't leak the context when we are + // directly returning from inside of the below `for` loop. + ccw.streamCancel() + ccw.streamCancel = nil + } + ccw.streamMu.Unlock() + }() + var retryCount int + var ctx context.Context for { - doBackoff, err := ccw.callRemoteBalancer() + ccw.streamMu.Lock() + if ccw.streamCancel != nil { + ccw.streamCancel() + ccw.streamCancel = nil + } + ctx, ccw.streamCancel = context.WithCancel(context.Background()) + ccw.streamMu.Unlock() + + doBackoff, err := ccw.callRemoteBalancer(ctx) select { case <-ccw.done: return diff --git a/balancer/grpclb/grpclb_test.go b/balancer/grpclb/grpclb_test.go index d6275b657f9..3b666764728 100644 --- a/balancer/grpclb/grpclb_test.go +++ b/balancer/grpclb/grpclb_test.go @@ -31,12 +31,16 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc" "google.golang.org/grpc/balancer" grpclbstate "google.golang.org/grpc/balancer/grpclb/state" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/resolver" @@ -60,6 +64,13 @@ var ( fakeName = "fake.Name" ) +const ( + defaultTestTimeout = 10 * time.Second + defaultTestShortTimeout = 10 * time.Millisecond + testUserAgent = "test-user-agent" + grpclbConfig = `{"loadBalancingConfig": [{"grpclb": {}}]}` +) + type s struct { grpctest.Tester } @@ -136,18 +147,6 @@ func (s *rpcStats) merge(cs *lbpb.ClientStats) { s.mu.Unlock() } -func mapsEqual(a, b map[string]int64) bool { - if len(a) != len(b) { - return false - } - for k, v1 := range a { - if v2, ok := b[k]; !ok || v1 != v2 { - return false - } - } - return true -} - func atomicEqual(a, b *int64) bool { return atomic.LoadInt64(a) == atomic.LoadInt64(b) } @@ -172,7 +171,7 @@ func (s *rpcStats) equal(o *rpcStats) bool { defer s.mu.Unlock() o.mu.Lock() defer o.mu.Unlock() - return mapsEqual(s.numCallsDropped, o.numCallsDropped) + return cmp.Equal(s.numCallsDropped, o.numCallsDropped, cmpopts.EquateEmpty()) } func (s *rpcStats) String() string { @@ -188,24 +187,28 @@ func (s *rpcStats) String() string { type remoteBalancer struct { lbgrpc.UnimplementedLoadBalancerServer - sls chan *lbpb.ServerList - statsDura time.Duration - done chan struct{} - stats *rpcStats - statsChan chan *lbpb.ClientStats - fbChan chan struct{} - - customUserAgent string + sls chan *lbpb.ServerList + statsDura time.Duration + done chan struct{} + stats *rpcStats + statsChan chan *lbpb.ClientStats + fbChan chan struct{} + balanceLoadCh chan struct{} // notify successful invocation of BalanceLoad + + wantUserAgent string // expected user-agent in metadata of BalancerLoad + wantServerName string // expected server name in InitialLoadBalanceRequest } -func newRemoteBalancer(customUserAgent string, statsChan chan *lbpb.ClientStats) *remoteBalancer { +func newRemoteBalancer(wantUserAgent, wantServerName string, statsChan chan *lbpb.ClientStats) *remoteBalancer { return &remoteBalancer{ - sls: make(chan *lbpb.ServerList, 1), - done: make(chan struct{}), - stats: newRPCStats(), - statsChan: statsChan, - fbChan: make(chan struct{}), - customUserAgent: customUserAgent, + sls: make(chan *lbpb.ServerList, 1), + done: make(chan struct{}), + stats: newRPCStats(), + statsChan: statsChan, + fbChan: make(chan struct{}), + balanceLoadCh: make(chan struct{}, 1), + wantUserAgent: wantUserAgent, + wantServerName: wantServerName, } } @@ -218,15 +221,18 @@ func (b *remoteBalancer) fallbackNow() { b.fbChan <- struct{}{} } +func (b *remoteBalancer) updateServerName(name string) { + b.wantServerName = name +} + func (b *remoteBalancer) BalanceLoad(stream lbgrpc.LoadBalancer_BalanceLoadServer) error { md, ok := metadata.FromIncomingContext(stream.Context()) if !ok { return status.Error(codes.Internal, "failed to receive metadata") } - if b.customUserAgent != "" { - ua := md["user-agent"] - if len(ua) == 0 || !strings.HasPrefix(ua[0], b.customUserAgent) { - return status.Errorf(codes.InvalidArgument, "received unexpected user-agent: %v, want prefix %q", ua, b.customUserAgent) + if b.wantUserAgent != "" { + if ua := md["user-agent"]; len(ua) == 0 || !strings.HasPrefix(ua[0], b.wantUserAgent) { + return status.Errorf(codes.InvalidArgument, "received unexpected user-agent: %v, want prefix %q", ua, b.wantUserAgent) } } @@ -235,9 +241,10 @@ func (b *remoteBalancer) BalanceLoad(stream lbgrpc.LoadBalancer_BalanceLoadServe return err } initReq := req.GetInitialRequest() - if initReq.Name != beServerName { - return status.Errorf(codes.InvalidArgument, "invalid service name: %v", initReq.Name) + if initReq.Name != b.wantServerName { + return status.Errorf(codes.InvalidArgument, "invalid service name: %q, want: %q", initReq.Name, b.wantServerName) } + b.balanceLoadCh <- struct{}{} resp := &lbpb.LoadBalanceResponse{ LoadBalanceResponseType: &lbpb.LoadBalanceResponse_InitialResponse{ InitialResponse: &lbpb.InitialLoadBalanceResponse{ @@ -253,11 +260,8 @@ func (b *remoteBalancer) BalanceLoad(stream lbgrpc.LoadBalancer_BalanceLoadServe } go func() { for { - var ( - req *lbpb.LoadBalanceRequest - err error - ) - if req, err = stream.Recv(); err != nil { + req, err := stream.Recv() + if err != nil { return } b.stats.merge(req.GetClientStats()) @@ -347,7 +351,7 @@ type testServers struct { beListeners []net.Listener } -func newLoadBalancer(numberOfBackends int, customUserAgent string, statsChan chan *lbpb.ClientStats) (tss *testServers, cleanup func(), err error) { +func startBackendsAndRemoteLoadBalancer(numberOfBackends int, customUserAgent string, statsChan chan *lbpb.ClientStats) (tss *testServers, cleanup func(), err error) { var ( beListeners []net.Listener ls *remoteBalancer @@ -380,7 +384,7 @@ func newLoadBalancer(numberOfBackends int, customUserAgent string, statsChan cha sn: lbServerName, } lb = grpc.NewServer(grpc.Creds(lbCreds)) - ls = newRemoteBalancer(customUserAgent, statsChan) + ls = newRemoteBalancer(customUserAgent, beServerName, statsChan) lbgrpc.RegisterLoadBalancerServer(lb, ls) go func() { lb.Serve(lbLis) @@ -407,34 +411,29 @@ func newLoadBalancer(numberOfBackends int, customUserAgent string, statsChan cha return } -var grpclbConfig = `{"loadBalancingConfig": [{"grpclb": {}}]}` - func (s) TestGRPCLB(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") - const testUserAgent = "test-user-agent" - tss, cleanup, err := newLoadBalancer(1, testUserAgent, nil) + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(1, testUserAgent, nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() - be := &lbpb.Server{ - IpAddress: tss.beIPs[0], - Port: int32(tss.bePorts[0]), - LoadBalanceToken: lbToken, - } - var bes []*lbpb.Server - bes = append(bes, be) - sl := &lbpb.ServerList{ - Servers: bes, + tss.ls.sls <- &lbpb.ServerList{ + Servers: []*lbpb.Server{ + { + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + }, + }, } - tss.ls.sls <- sl - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer), + + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer), grpc.WithUserAgent(testUserAgent)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) @@ -445,12 +444,11 @@ func (s) TestGRPCLB(t *testing.T) { rs := grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig)}, &grpclbstate.State{BalancerAddresses: []resolver.Address{{ Addr: tss.lbAddr, - Type: resolver.Backend, ServerName: lbServerName, }}}) r.UpdateState(rs) - ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() if _, err := testC.EmptyCall(ctx, &testpb.Empty{}); err != nil { t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) @@ -461,7 +459,7 @@ func (s) TestGRPCLB(t *testing.T) { func (s) TestGRPCLBWeighted(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(2, "", nil) + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(2, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } @@ -481,23 +479,25 @@ func (s) TestGRPCLBWeighted(t *testing.T) { portsToIndex[tss.bePorts[i]] = i } - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - r.UpdateState(resolver.State{Addresses: []resolver.Address{{ - Addr: tss.lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }}}) + rs := grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig)}, + &grpclbstate.State{BalancerAddresses: []resolver.Address{{ + Addr: tss.lbAddr, + ServerName: lbServerName, + }}}) + r.UpdateState(rs) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() sequences := []string{"00101", "00011"} for _, seq := range sequences { var ( @@ -526,7 +526,7 @@ func (s) TestGRPCLBWeighted(t *testing.T) { func (s) TestDropRequest(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(2, "", nil) + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(2, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } @@ -546,22 +546,23 @@ func (s) TestDropRequest(t *testing.T) { Drop: true, }}, } - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - r.UpdateState(resolver.State{Addresses: []resolver.Address{{ - Addr: tss.lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }}}) + rs := grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig)}, + &grpclbstate.State{BalancerAddresses: []resolver.Address{{ + Addr: tss.lbAddr, + ServerName: lbServerName, + }}}) + r.UpdateState(rs) var ( i int @@ -573,6 +574,8 @@ func (s) TestDropRequest(t *testing.T) { sleepEachLoop = time.Millisecond loopCount = int(time.Second / sleepEachLoop) ) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() // Make a non-fail-fast RPC and wait for it to succeed. for i = 0; i < loopCount; i++ { if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err == nil { @@ -681,49 +684,51 @@ func (s) TestBalancerDisconnects(t *testing.T) { lbs []*grpc.Server ) for i := 0; i < 2; i++ { - tss, cleanup, err := newLoadBalancer(1, "", nil) + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(1, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() - be := &lbpb.Server{ - IpAddress: tss.beIPs[0], - Port: int32(tss.bePorts[0]), - LoadBalanceToken: lbToken, - } - var bes []*lbpb.Server - bes = append(bes, be) - sl := &lbpb.ServerList{ - Servers: bes, + tss.ls.sls <- &lbpb.ServerList{ + Servers: []*lbpb.Server{ + { + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + }, + }, } - tss.ls.sls <- sl tests = append(tests, tss) lbs = append(lbs, tss.lb) } - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - r.UpdateState(resolver.State{Addresses: []resolver.Address{{ - Addr: tests[0].lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }, { - Addr: tests[1].lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }}}) + rs := grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig)}, + &grpclbstate.State{BalancerAddresses: []resolver.Address{ + { + Addr: tests[0].lbAddr, + ServerName: lbServerName, + }, + { + Addr: tests[1].lbAddr, + ServerName: lbServerName, + }, + }}) + r.UpdateState(rs) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() var p peer.Peer if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil { t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) @@ -750,16 +755,27 @@ func (s) TestBalancerDisconnects(t *testing.T) { func (s) TestFallback(t *testing.T) { balancer.Register(newLBBuilderWithFallbackTimeout(100 * time.Millisecond)) defer balancer.Register(newLBBuilder()) - r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(1, "", nil) + // Start a remote balancer and a backend. Push the backend address to the + // remote balancer. + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(1, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() + sl := &lbpb.ServerList{ + Servers: []*lbpb.Server{ + { + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + }, + }, + } + tss.ls.sls <- sl - // Start a standalone backend. + // Start a standalone backend for fallback. beLis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Failed to listen %v", err) @@ -768,37 +784,29 @@ func (s) TestFallback(t *testing.T) { standaloneBEs := startBackends(beServerName, true, beLis) defer stopBackends(standaloneBEs) - be := &lbpb.Server{ - IpAddress: tss.beIPs[0], - Port: int32(tss.bePorts[0]), - LoadBalanceToken: lbToken, - } - var bes []*lbpb.Server - bes = append(bes, be) - sl := &lbpb.ServerList{ - Servers: bes, - } - tss.ls.sls <- sl - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - r.UpdateState(resolver.State{Addresses: []resolver.Address{{ - Addr: "invalid.address", - Type: resolver.GRPCLB, - ServerName: lbServerName, - }, { - Addr: beLis.Addr().String(), - Type: resolver.Backend, - }}}) + // Push an update to the resolver with fallback backend address stored in + // the `Addresses` field and an invalid remote balancer address stored in + // attributes, which will cause fallback behavior to be invoked. + rs := resolver.State{ + Addresses: []resolver.Address{{Addr: beLis.Addr().String()}}, + ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig), + } + rs = grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: "invalid.address", ServerName: lbServerName}}}) + r.UpdateState(rs) + // Make an RPC and verify that it got routed to the fallback backend. + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() var p peer.Peer if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil { t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, ", err) @@ -807,15 +815,21 @@ func (s) TestFallback(t *testing.T) { t.Fatalf("got peer: %v, want peer: %v", p.Addr, beLis.Addr()) } - r.UpdateState(resolver.State{Addresses: []resolver.Address{{ - Addr: tss.lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }, { - Addr: beLis.Addr().String(), - Type: resolver.Backend, - }}}) + // Push another update to the resolver, this time with a valid balancer + // address in the attributes field. + rs = resolver.State{ + ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig), + Addresses: []resolver.Address{{Addr: beLis.Addr().String()}}, + } + rs = grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: tss.lbAddr, ServerName: lbServerName}}}) + r.UpdateState(rs) + select { + case <-ctx.Done(): + t.Fatalf("timeout when waiting for BalanceLoad RPC to be called on the remote balancer") + case <-tss.ls.balanceLoadCh: + } + // Wait for RPCs to get routed to the backend behind the remote balancer. var backendUsed bool for i := 0; i < 1000; i++ { if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil { @@ -856,7 +870,7 @@ func (s) TestFallback(t *testing.T) { t.Fatalf("No RPC sent to fallback after 2 seconds") } - // Restart backend and remote balancer, should not use backends. + // Restart backend and remote balancer, should not use fallback backend. tss.beListeners[0].(*restartableListener).restart() tss.lbListener.(*restartableListener).restart() tss.ls.sls <- sl @@ -880,13 +894,25 @@ func (s) TestFallback(t *testing.T) { func (s) TestExplicitFallback(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(1, "", nil) + // Start a remote balancer and a backend. Push the backend address to the + // remote balancer. + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(1, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() + sl := &lbpb.ServerList{ + Servers: []*lbpb.Server{ + { + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + }, + }, + } + tss.ls.sls <- sl - // Start a standalone backend. + // Start a standalone backend for fallback. beLis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Failed to listen %v", err) @@ -895,37 +921,25 @@ func (s) TestExplicitFallback(t *testing.T) { standaloneBEs := startBackends(beServerName, true, beLis) defer stopBackends(standaloneBEs) - be := &lbpb.Server{ - IpAddress: tss.beIPs[0], - Port: int32(tss.bePorts[0]), - LoadBalanceToken: lbToken, - } - var bes []*lbpb.Server - bes = append(bes, be) - sl := &lbpb.ServerList{ - Servers: bes, - } - tss.ls.sls <- sl - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - r.UpdateState(resolver.State{Addresses: []resolver.Address{{ - Addr: tss.lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }, { - Addr: beLis.Addr().String(), - Type: resolver.Backend, - }}}) + rs := resolver.State{ + Addresses: []resolver.Address{{Addr: beLis.Addr().String()}}, + ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig), + } + rs = grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: tss.lbAddr, ServerName: lbServerName}}}) + r.UpdateState(rs) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() var p peer.Peer var backendUsed bool for i := 0; i < 2000; i++ { @@ -980,23 +994,34 @@ func (s) TestExplicitFallback(t *testing.T) { } func (s) TestFallBackWithNoServerAddress(t *testing.T) { - resolveNowCh := make(chan struct{}, 1) + resolveNowCh := testutils.NewChannel() r := manual.NewBuilderWithScheme("whatever") r.ResolveNowCallback = func(resolver.ResolveNowOptions) { - select { - case <-resolveNowCh: - default: + ctx, cancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer cancel() + if err := resolveNowCh.SendContext(ctx, nil); err != nil { + t.Error("timeout when attemping to send on resolverNowCh") } - resolveNowCh <- struct{}{} } - tss, cleanup, err := newLoadBalancer(1, "", nil) + // Start a remote balancer and a backend. Push the backend address to the + // remote balancer yet. + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(1, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } defer cleanup() + sl := &lbpb.ServerList{ + Servers: []*lbpb.Server{ + { + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + }, + }, + } - // Start a standalone backend. + // Start a standalone backend for fallback. beLis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Failed to listen %v", err) @@ -1005,81 +1030,61 @@ func (s) TestFallBackWithNoServerAddress(t *testing.T) { standaloneBEs := startBackends(beServerName, true, beLis) defer stopBackends(standaloneBEs) - be := &lbpb.Server{ - IpAddress: tss.beIPs[0], - Port: int32(tss.bePorts[0]), - LoadBalanceToken: lbToken, - } - var bes []*lbpb.Server - bes = append(bes, be) - sl := &lbpb.ServerList{ - Servers: bes, - } - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - // Select grpclb with service config. - const pfc = `{"loadBalancingConfig":[{"grpclb":{"childPolicy":[{"round_robin":{}}]}}]}` - scpr := r.CC.ParseServiceConfig(pfc) - if scpr.Err != nil { - t.Fatalf("Error parsing config %q: %v", pfc, scpr.Err) - } - + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() for i := 0; i < 2; i++ { - // Send an update with only backend address. grpclb should enter fallback - // and use the fallback backend. + // Send an update with only backend address. grpclb should enter + // fallback and use the fallback backend. r.UpdateState(resolver.State{ - Addresses: []resolver.Address{{ - Addr: beLis.Addr().String(), - Type: resolver.Backend, - }}, - ServiceConfig: scpr, + Addresses: []resolver.Address{{Addr: beLis.Addr().String()}}, + ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig), }) - select { - case <-resolveNowCh: - t.Errorf("unexpected resolveNow when grpclb gets no balancer address 1111, %d", i) - case <-time.After(time.Second): + sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if _, err := resolveNowCh.Receive(sCtx); err != context.DeadlineExceeded { + t.Fatalf("unexpected resolveNow when grpclb gets no balancer address 1111, %d", i) } var p peer.Peer - rpcCtx, rpcCancel := context.WithTimeout(context.Background(), time.Second) - defer rpcCancel() - if _, err := testC.EmptyCall(rpcCtx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil { + if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil { t.Fatalf("_.EmptyCall(_, _) = _, %v, want _, ", err) } if p.Addr.String() != beLis.Addr().String() { t.Fatalf("got peer: %v, want peer: %v", p.Addr, beLis.Addr()) } - select { - case <-resolveNowCh: + sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if _, err := resolveNowCh.Receive(sCtx); err != context.DeadlineExceeded { t.Errorf("unexpected resolveNow when grpclb gets no balancer address 2222, %d", i) - case <-time.After(time.Second): } tss.ls.sls <- sl // Send an update with balancer address. The backends behind grpclb should // be used. - r.UpdateState(resolver.State{ - Addresses: []resolver.Address{{ - Addr: tss.lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }, { - Addr: beLis.Addr().String(), - Type: resolver.Backend, - }}, - ServiceConfig: scpr, - }) + rs := resolver.State{ + Addresses: []resolver.Address{{Addr: beLis.Addr().String()}}, + ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig), + } + rs = grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: tss.lbAddr, ServerName: lbServerName}}}) + r.UpdateState(rs) + + select { + case <-ctx.Done(): + t.Fatalf("timeout when waiting for BalanceLoad RPC to be called on the remote balancer") + case <-tss.ls.balanceLoadCh: + } var backendUsed bool for i := 0; i < 1000; i++ { @@ -1101,7 +1106,7 @@ func (s) TestFallBackWithNoServerAddress(t *testing.T) { func (s) TestGRPCLBPickFirst(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(3, "", nil) + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(3, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } @@ -1125,11 +1130,10 @@ func (s) TestGRPCLBPickFirst(t *testing.T) { portsToIndex[tss.bePorts[i]] = i } - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } @@ -1143,21 +1147,11 @@ func (s) TestGRPCLBPickFirst(t *testing.T) { tss.ls.sls <- &lbpb.ServerList{Servers: beServers[0:3]} // Start with sub policy pick_first. - const pfc = `{"loadBalancingConfig":[{"grpclb":{"childPolicy":[{"pick_first":{}}]}}]}` - scpr := r.CC.ParseServiceConfig(pfc) - if scpr.Err != nil { - t.Fatalf("Error parsing config %q: %v", pfc, scpr.Err) - } - - r.UpdateState(resolver.State{ - Addresses: []resolver.Address{{ - Addr: tss.lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }}, - ServiceConfig: scpr, - }) + rs := resolver.State{ServiceConfig: r.CC.ParseServiceConfig(`{"loadBalancingConfig":[{"grpclb":{"childPolicy":[{"pick_first":{}}]}}]}`)} + r.UpdateState(grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: tss.lbAddr, ServerName: lbServerName}}})) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() result = "" for i := 0; i < 1000; i++ { if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true), grpc.Peer(&p)); err != nil { @@ -1194,19 +1188,12 @@ func (s) TestGRPCLBPickFirst(t *testing.T) { } // Switch sub policy to roundrobin. - grpclbServiceConfigEmpty := r.CC.ParseServiceConfig(`{}`) - if grpclbServiceConfigEmpty.Err != nil { - t.Fatalf("Error parsing config %q: %v", `{}`, grpclbServiceConfigEmpty.Err) - } - - r.UpdateState(resolver.State{ - Addresses: []resolver.Address{{ + rs = grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig)}, + &grpclbstate.State{BalancerAddresses: []resolver.Address{{ Addr: tss.lbAddr, - Type: resolver.GRPCLB, ServerName: lbServerName, - }}, - ServiceConfig: grpclbServiceConfigEmpty, - }) + }}}) + r.UpdateState(rs) result = "" for i := 0; i < 1000; i++ { @@ -1235,9 +1222,8 @@ func (s) TestGRPCLBPickFirst(t *testing.T) { func (s) TestGRPCLBBackendConnectionErrorPropagation(t *testing.T) { r := manual.NewBuilderWithScheme("whatever") - // Start up an LB which will tell the client to fall back - // right away. - tss, cleanup, err := newLoadBalancer(0, "", nil) + // Start up an LB which will tells the client to fall back right away. + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(0, "", nil) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } @@ -1254,30 +1240,27 @@ func (s) TestGRPCLBBackendConnectionErrorPropagation(t *testing.T) { standaloneBEs := startBackends("arbitrary.invalid.name", true, beLis) defer stopBackends(standaloneBEs) - creds := serverNameCheckCreds{} - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r), - grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer)) + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer)) if err != nil { t.Fatalf("Failed to dial to the backend %v", err) } defer cc.Close() testC := testpb.NewTestServiceClient(cc) - r.UpdateState(resolver.State{Addresses: []resolver.Address{{ - Addr: tss.lbAddr, - Type: resolver.GRPCLB, - ServerName: lbServerName, - }, { - Addr: beLis.Addr().String(), - Type: resolver.Backend, - }}}) + rs := resolver.State{ + Addresses: []resolver.Address{{Addr: beLis.Addr().String()}}, + ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig), + } + rs = grpclbstate.Set(rs, &grpclbstate.State{BalancerAddresses: []resolver.Address{{Addr: tss.lbAddr, ServerName: lbServerName}}}) + r.UpdateState(rs) // If https://github.com/grpc/grpc-go/blob/65cabd74d8e18d7347fecd414fa8d83a00035f5f/balancer/grpclb/grpclb_test.go#L103 // changes, then expectedErrMsg may need to be updated. const expectedErrMsg = "received unexpected server name" - ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() var wg sync.WaitGroup wg.Add(1) @@ -1291,6 +1274,87 @@ func (s) TestGRPCLBBackendConnectionErrorPropagation(t *testing.T) { wg.Wait() } +func (s) TestGRPCLBWithTargetNameFieldInConfig(t *testing.T) { + r := manual.NewBuilderWithScheme("whatever") + + // Start a remote balancer and a backend. Push the backend address to the + // remote balancer. + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(1, "", nil) + if err != nil { + t.Fatalf("failed to create new load balancer: %v", err) + } + defer cleanup() + sl := &lbpb.ServerList{ + Servers: []*lbpb.Server{ + { + IpAddress: tss.beIPs[0], + Port: int32(tss.bePorts[0]), + LoadBalanceToken: lbToken, + }, + }, + } + tss.ls.sls <- sl + + cc, err := grpc.Dial(r.Scheme()+":///"+beServerName, + grpc.WithResolvers(r), + grpc.WithTransportCredentials(&serverNameCheckCreds{}), + grpc.WithContextDialer(fakeNameDialer), + grpc.WithUserAgent(testUserAgent)) + if err != nil { + t.Fatalf("Failed to dial to the backend %v", err) + } + defer cc.Close() + testC := testpb.NewTestServiceClient(cc) + + // Push a resolver update with grpclb configuration which does not contain the + // target_name field. Our fake remote balancer is configured to always + // expect `beServerName` as the server name in the initial request. + rs := grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(grpclbConfig)}, + &grpclbstate.State{BalancerAddresses: []resolver.Address{{ + Addr: tss.lbAddr, + ServerName: lbServerName, + }}}) + r.UpdateState(rs) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + select { + case <-ctx.Done(): + t.Fatalf("timeout when waiting for BalanceLoad RPC to be called on the remote balancer") + case <-tss.ls.balanceLoadCh: + } + if _, err := testC.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) + } + + // When the value of target_field changes, grpclb will recreate the stream + // to the remote balancer. So, we need to update the fake remote balancer to + // expect a new server name in the initial request. + const newServerName = "new-server-name" + tss.ls.updateServerName(newServerName) + tss.ls.sls <- sl + + // Push the resolver update with target_field changed. + // Push a resolver update with grpclb configuration containing the + // target_name field. Our fake remote balancer has been updated above to expect the newServerName in the initial request. + lbCfg := fmt.Sprintf(`{"loadBalancingConfig": [{"grpclb": {"targetName": "%s"}}]}`, newServerName) + rs = grpclbstate.Set(resolver.State{ServiceConfig: r.CC.ParseServiceConfig(lbCfg)}, + &grpclbstate.State{BalancerAddresses: []resolver.Address{{ + Addr: tss.lbAddr, + ServerName: lbServerName, + }}}) + r.UpdateState(rs) + select { + case <-ctx.Done(): + t.Fatalf("timeout when waiting for BalanceLoad RPC to be called on the remote balancer") + case <-tss.ls.balanceLoadCh: + } + + if _, err := testC.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, ", testC, err) + } +} + type failPreRPCCred struct{} func (failPreRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { @@ -1314,7 +1378,7 @@ func checkStats(stats, expected *rpcStats) error { func runAndCheckStats(t *testing.T, drop bool, statsChan chan *lbpb.ClientStats, runRPCs func(*grpc.ClientConn), statsWant *rpcStats) error { r := manual.NewBuilderWithScheme("whatever") - tss, cleanup, err := newLoadBalancer(1, "", statsChan) + tss, cleanup, err := startBackendsAndRemoteLoadBalancer(1, "", statsChan) if err != nil { t.Fatalf("failed to create new load balancer: %v", err) } diff --git a/balancer/grpclb/grpclb_test_util_test.go b/balancer/grpclb/grpclb_test_util_test.go index 5d3e6ba7fed..c143e961754 100644 --- a/balancer/grpclb/grpclb_test_util_test.go +++ b/balancer/grpclb/grpclb_test_util_test.go @@ -48,19 +48,20 @@ func newRestartableListener(l net.Listener) *restartableListener { } } -func (l *restartableListener) Accept() (conn net.Conn, err error) { - conn, err = l.Listener.Accept() - if err == nil { - l.mu.Lock() - if l.closed { - conn.Close() - l.mu.Unlock() - return nil, &tempError{} - } - l.conns = append(l.conns, conn) - l.mu.Unlock() +func (l *restartableListener) Accept() (net.Conn, error) { + conn, err := l.Listener.Accept() + if err != nil { + return nil, err } - return + + l.mu.Lock() + defer l.mu.Unlock() + if l.closed { + conn.Close() + return nil, &tempError{} + } + l.conns = append(l.conns, conn) + return conn, nil } func (l *restartableListener) Close() error {