diff --git a/internal/resolver/dns/dns_resolver_test.go b/internal/resolver/dns/dns_resolver_test.go index 21f87a84e09..c76700a51ab 100644 --- a/internal/resolver/dns/dns_resolver_test.go +++ b/internal/resolver/dns/dns_resolver_test.go @@ -22,6 +22,7 @@ import ( "context" "errors" "fmt" + "google.golang.org/grpc/internal/testutils" "net" "os" "reflect" @@ -756,11 +757,11 @@ func TestDNSResolverExponentialBackoff(t *testing.T) { defer func(nt func(d time.Duration) *time.Timer) { newTimer = nt }(newTimer) - timerChan := make(chan *time.Timer, 1) + timerChan := testutils.NewChannel() newTimer = func(d time.Duration) *time.Timer { // Will never fire on its own, allows this test to call timer immediately. t := time.NewTimer(time.Hour) - timerChan <- t + timerChan.SendOrFail(t) return t } tests := []struct { @@ -770,19 +771,19 @@ func TestDNSResolverExponentialBackoff(t *testing.T) { scWant string }{ { - "happy-case-default-port", + "happy case default port", "foo.bar.com", []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}, generateSC("foo.bar.com"), }, { - "happy-case-specified-port", + "happy case specified port", "foo.bar.com:1234", []resolver.Address{{Addr: "1.2.3.4:1234"}, {Addr: "5.6.7.8:1234"}}, generateSC("foo.bar.com"), }, { - "happy-case-another-default-port", + "happy case another default port", "srv.ipv4.single.fake", []resolver.Address{{Addr: "2.4.6.8" + colonDefaultPort}}, generateSC("srv.ipv4.single.fake"), @@ -796,7 +797,7 @@ func TestDNSResolverExponentialBackoff(t *testing.T) { cc.updateStateErr = balancer.ErrBadResolverState r, err := b.Build(resolver.Target{Endpoint: test.target}, cc, resolver.BuildOptions{}) if err != nil { - t.Fatalf("%v\n", err) + t.Fatalf("Error building resolver for target %v: %v", test.target, err) } var state resolver.State var cnt int @@ -817,10 +818,16 @@ func TestDNSResolverExponentialBackoff(t *testing.T) { if test.scWant != sc { t.Errorf("Resolved service config of target: %q = %+v, want %+v", test.target, sc, test.scWant) } + ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer ctxCancel() // Cause timer to go off 10 times, and see if it calls updateState() correctly. for i := 0; i < 10; i++ { - timer := <-timerChan - timer.Reset(0) + timer, err := timerChan.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + } + timerPointer := timer.(*time.Timer) + timerPointer.Reset(0) } // Poll to see if DNS Resolver updated state the correct number of times, which allows time for the DNS Resolver to call // ClientConn update state. @@ -828,11 +835,10 @@ func TestDNSResolverExponentialBackoff(t *testing.T) { for { cc.m1.Lock() got := cc.updateStateCalls + cc.m1.Unlock() if got == 11 { - cc.m1.Unlock() break } - cc.m1.Unlock() if time.Now().After(deadline) { t.Fatalf("Exponential backoff is not working as expected - should update state 11 times instead of %d", got) @@ -843,28 +849,30 @@ func TestDNSResolverExponentialBackoff(t *testing.T) { // Update resolver.ClientConn to not return an error anymore - this should stop it from backing off. cc.updateStateErr = nil - timer := <-timerChan - timer.Reset(0) + timer, err := timerChan.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + } + timerPointer := timer.(*time.Timer) + timerPointer.Reset(0) // Poll to see if DNS Resolver updated state the correct number of times, which allows time for the DNS Resolver to call // ClientConn update state the final time. The DNS Resolver should then stop polling. deadline = time.Now().Add(defaultTestTimeout) for { cc.m1.Lock() got := cc.updateStateCalls + cc.m1.Unlock() if got == 12 { - cc.m1.Unlock() break } - cc.m1.Unlock() if time.Now().After(deadline) { t.Fatalf("Exponential backoff is not working as expected - should stop backing off at 12 total UpdateState calls instead of %d", got) } - select { - case <-timerChan: - t.Fatalf("Should not poll again after no more error") - default: + _, err := timerChan.ReceiveOrFail() + if err { + t.Fatalf("Should not poll again after Client Conn stops returning error.") } time.Sleep(time.Millisecond) @@ -1555,11 +1563,11 @@ func TestReportError(t *testing.T) { defer func(nt func(d time.Duration) *time.Timer) { newTimer = nt }(newTimer) - timerChan := make(chan *time.Timer, 1) + timerChan := testutils.NewChannel() newTimer = func(d time.Duration) *time.Timer { - // Will never fire on its own, allowing us to control it. + // Will never fire on its own, allows this test to call timer immediately. t := time.NewTimer(time.Hour) - timerChan <- t + timerChan.SendOrFail(t) return t } cc := &testClientConn{target: target, errChan: make(chan error)} @@ -1567,7 +1575,7 @@ func TestReportError(t *testing.T) { b := NewBuilder() r, err := b.Build(resolver.Target{Endpoint: target}, cc, resolver.BuildOptions{}) if err != nil { - t.Fatalf("%v\n", err) + t.Fatalf("Error building resolver for target %v: %v", target, err) } // Should receive first error. err = <-cc.errChan @@ -1575,8 +1583,14 @@ func TestReportError(t *testing.T) { t.Fatalf(`ReportError(err=%v) called; want err contains "hostLookupError"`, err) } totalTimesCalledError++ - timer := <-timerChan - timer.Reset(0) + ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer ctxCancel() + timer, err := timerChan.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + } + timerPointer := timer.(*time.Timer) + timerPointer.Reset(0) defer r.Close() // Cause timer to go off 10 times, and see if it matches DNS Resolver updating Error. @@ -1587,8 +1601,12 @@ func TestReportError(t *testing.T) { t.Fatalf(`ReportError(err=%v) called; want err contains "hostLookupError"`, err) } totalTimesCalledError++ - timer = <-timerChan - timer.Reset(0) + timer, err := timerChan.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + } + timerPointer := timer.(*time.Timer) + timerPointer.Reset(0) } if totalTimesCalledError != 11 { @@ -1596,5 +1614,8 @@ func TestReportError(t *testing.T) { } // Clean up final watcher iteration. <-cc.errChan - <-timerChan + _, err = timerChan.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + } }