diff --git a/internal/balancergroup/balancergroup.go b/internal/balancergroup/balancergroup.go index 9776158dd986..45a15af10935 100644 --- a/internal/balancergroup/balancergroup.go +++ b/internal/balancergroup/balancergroup.go @@ -115,6 +115,7 @@ func (sbc *subBalancerWrapper) exitIdle() { sc.Connect() } } + // sbc.group.connect(sbc) } func (sbc *subBalancerWrapper) updateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { @@ -383,6 +384,17 @@ func (bg *BalancerGroup) cleanupSubConns(config *subBalancerWrapper) { bg.incomingMu.Unlock() } +// connect attempts to connect to all subConns belonging to sb. +func (bg *BalancerGroup) connect(sb *subBalancerWrapper) { + bg.incomingMu.Lock() + for sc, b := range bg.scToSubBalancer { + if b == sb { + sc.Connect() + } + } + bg.incomingMu.Unlock() +} + // Following are actions from the parent grpc.ClientConn, forward to sub-balancers. // UpdateSubConnState handles the state for the subconn. It finds the diff --git a/internal/balancergroup/balancergroup_test.go b/internal/balancergroup/balancergroup_test.go index 4942f8a7da87..32d271a8469c 100644 --- a/internal/balancergroup/balancergroup_test.go +++ b/internal/balancergroup/balancergroup_test.go @@ -18,6 +18,7 @@ package balancergroup import ( "fmt" + "sync" "testing" "time" @@ -535,3 +536,104 @@ func (s) TestBalancerExitIdleOne(t *testing.T) { case <-exitIdleCh: } } + +type nonExitIdlerBalancerBuilder struct { + name string +} + +func (bb *nonExitIdlerBalancerBuilder) Name() string { + return bb.name +} + +func (bb *nonExitIdlerBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { + rr := balancer.Get(roundrobin.Name) + return &nonExitIdlerBalancer{rrBalancer: rr.Build(cc, opts)} +} + +type nonExitIdlerBalancer struct { + rrBalancer balancer.Balancer +} + +func (b *nonExitIdlerBalancer) UpdateClientConnState(state balancer.ClientConnState) error { + err := b.rrBalancer.UpdateClientConnState(state) + fmt.Println("error from pickfirst balancer is:", err) + return err +} + +func (b *nonExitIdlerBalancer) ResolverError(err error) { + b.rrBalancer.ResolverError(err) +} + +func (b *nonExitIdlerBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { + b.rrBalancer.UpdateSubConnState(sc, state) +} + +func (b *nonExitIdlerBalancer) Close() { + b.rrBalancer.Close() +} + +func (s) TestBalancerExitIdleRace(t *testing.T) { + bBuilder := &nonExitIdlerBalancerBuilder{name: t.Name()} + balancer.Register(bBuilder) + + // Create a balancer group with a weighted target state aggregator. + cc := testutils.NewTestClientConn(t) + gator := weightedaggregator.New(cc, nil, testutils.NewTestWRR) + gator.Start() + bg := New(cc, balancer.BuildOptions{}, gator, nil) + defer func() { + gator.Stop() + bg.Close() + }() + + // Add a balancer which does not implement the ExitIdler interface to the + // group, and add backends to the balancer. + gator.Add(testBalancerIDs[0], 1) + bg.Add(testBalancerIDs[0], bBuilder) + bg.UpdateClientConnState(testBalancerIDs[0], balancer.ClientConnState{ResolverState: resolver.State{Addresses: testBackendAddrs[0:2]}}) + bg.Start() + + // Move both backends to READY. + addrToSC := make(map[resolver.Address]balancer.SubConn) + for i := 0; i < 2; i++ { + addrs := <-cc.NewSubConnAddrsCh + sc := <-cc.NewSubConnCh + addrToSC[addrs[0]] = sc + bg.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + bg.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + } + + // Test roundrobin on the last picker. + p := <-cc.NewPickerCh + want := []balancer.SubConn{addrToSC[testBackendAddrs[0]], addrToSC[testBackendAddrs[1]]} + if err := testutils.IsRoundRobin(want, subConnFromPicker(p)); err != nil { + t.Fatalf("want %v, got %v", want, err) + } + + var wg sync.WaitGroup + wg.Add(2) + startCh := make(chan struct{}, 1) + go func() { + defer wg.Done() + + <-startCh + for i := 0; i < 100; i++ { + for j := 0; j < 2; j++ { + bg.UpdateSubConnState(addrToSC[testBackendAddrs[j]], balancer.SubConnState{ConnectivityState: connectivity.Idle}) + } + time.Sleep(10 * time.Millisecond) + } + }() + + go func() { + defer wg.Done() + + close(startCh) + for i := 0; i < 100; i++ { + bg.ExitIdle() + time.Sleep(10 * time.Millisecond) + } + }() + + wg.Wait() +}