diff --git a/internal/resolver/dns/dns_resolver.go b/internal/resolver/dns/dns_resolver.go index 9d86460ab6f..03825bbe7b5 100644 --- a/internal/resolver/dns/dns_resolver.go +++ b/internal/resolver/dns/dns_resolver.go @@ -47,8 +47,12 @@ var EnableSRVLookups = false var logger = grpclog.Component("dns") -// A global to stub out in tests. -var newTimer = time.NewTimer +// Globals to stub out in tests. TODO: Perhaps these two can be combined into a +// single variable for testing the resolver? +var ( + newTimer = time.NewTimer + newTimerDNSResRate = time.NewTimer +) func init() { resolver.Register(NewBuilder()) @@ -219,7 +223,7 @@ func (d *dnsResolver) watcher() { // Success resolving, wait for the next ResolveNow. However, also wait 30 seconds at the very least // to prevent constantly re-resolving. backoffIndex = 1 - timer = time.NewTimer(minDNSResRate) + timer = newTimerDNSResRate(minDNSResRate) select { case <-d.ctx.Done(): timer.Stop() diff --git a/internal/resolver/dns/dns_resolver_test.go b/internal/resolver/dns/dns_resolver_test.go index 52067e39cc6..73749eca44b 100644 --- a/internal/resolver/dns/dns_resolver_test.go +++ b/internal/resolver/dns/dns_resolver_test.go @@ -43,14 +43,15 @@ func TestMain(m *testing.M) { // Set a non-zero duration only for tests which are actually testing that // feature. replaceDNSResRate(time.Duration(0)) // No nead to clean up since we os.Exit - replaceNetFunc(nil) // No nead to clean up since we os.Exit + overrideDefaultResolver(false) // No nead to clean up since we os.Exit code := m.Run() os.Exit(code) } const ( - txtBytesLimit = 255 - defaultTestTimeout = 10 * time.Second + txtBytesLimit = 255 + defaultTestTimeout = 10 * time.Second + defaultTestShortTimeout = 10 * time.Millisecond ) type testClientConn struct { @@ -106,12 +107,12 @@ type testResolver struct { // A write to this channel is made when this resolver receives a resolution // request. Tests can rely on reading from this channel to be notified about // resolution requests instead of sleeping for a predefined period of time. - ch chan struct{} + lookupHostCh *testutils.Channel } func (tr *testResolver) LookupHost(ctx context.Context, host string) ([]string, error) { - if tr.ch != nil { - tr.ch <- struct{}{} + if tr.lookupHostCh != nil { + tr.lookupHostCh.Send(nil) } return hostLookup(host) } @@ -124,9 +125,17 @@ func (*testResolver) LookupTXT(ctx context.Context, host string) ([]string, erro return txtLookup(host) } -func replaceNetFunc(ch chan struct{}) func() { +// overrideDefaultResolver overrides the defaultResolver used by the code with +// an instance of the testResolver. pushOnLookup controls whether the +// testResolver created here pushes lookupHost events on its channel. +func overrideDefaultResolver(pushOnLookup bool) func() { oldResolver := defaultResolver - defaultResolver = &testResolver{ch: ch} + + var lookupHostCh *testutils.Channel + if pushOnLookup { + lookupHostCh = testutils.NewChannel() + } + defaultResolver = &testResolver{lookupHostCh: lookupHostCh} return func() { defaultResolver = oldResolver @@ -1451,23 +1460,33 @@ func TestCustomAuthority(t *testing.T) { // requests. It sets the re-resolution rate to a small value and repeatedly // calls ResolveNow() and ensures only the expected number of resolution // requests are made. + func TestRateLimitedResolve(t *testing.T) { defer leakcheck.Check(t) defer func(nt func(d time.Duration) *time.Timer) { newTimer = nt }(newTimer) newTimer = func(d time.Duration) *time.Timer { - // Will never fire on its own, will protect from triggering exponential backoff. + // Will never fire on its own, will protect from triggering exponential + // backoff. return time.NewTimer(time.Hour) } + defer func(nt func(d time.Duration) *time.Timer) { + newTimerDNSResRate = nt + }(newTimerDNSResRate) - const dnsResRate = 10 * time.Millisecond - dc := replaceDNSResRate(dnsResRate) - defer dc() + timerChan := testutils.NewChannel() + newTimerDNSResRate = func(d time.Duration) *time.Timer { + // Will never fire on its own, allows this test to call timer + // immediately. + t := time.NewTimer(time.Hour) + timerChan.Send(t) + return t + } // Create a new testResolver{} for this test because we want the exact count // of the number of times the resolver was invoked. - nc := replaceNetFunc(make(chan struct{})) + nc := overrideDefaultResolver(true) defer nc() target := "foo.bar.com" @@ -1490,55 +1509,65 @@ func TestRateLimitedResolve(t *testing.T) { t.Fatalf("delegate resolver returned unexpected type: %T\n", tr) } - // Observe the time before unblocking the lookupHost call. The 100ms rate - // limiting timer will begin immediately after that. This means the next - // resolution could happen less than 100ms if we read the time *after* - // receiving from tr.ch - start := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() // Wait for the first resolution request to be done. This happens as part - // of the first iteration of the for loop in watcher() because we call - // ResolveNow in Build. - <-tr.ch - - // Here we start a couple of goroutines. One repeatedly calls ResolveNow() - // until asked to stop, and the other waits for two resolution requests to be - // made to our testResolver and stops the former. We measure the start and - // end times, and expect the duration elapsed to be in the interval - // {wantCalls*dnsResRate, wantCalls*dnsResRate} - done := make(chan struct{}) - go func() { - for { - select { - case <-done: - return - default: - r.ResolveNow(resolver.ResolveNowOptions{}) - time.Sleep(1 * time.Millisecond) - } - } - }() + // of the first iteration of the for loop in watcher(). + if _, err := tr.lookupHostCh.Receive(ctx); err != nil { + t.Fatalf("Timed out waiting for lookup() call.") + } - gotCalls := 0 - const wantCalls = 3 - min, max := wantCalls*dnsResRate, (wantCalls+1)*dnsResRate - tMax := time.NewTimer(max) - for gotCalls != wantCalls { - select { - case <-tr.ch: - gotCalls++ - case <-tMax.C: - t.Fatalf("Timed out waiting for %v calls after %v; got %v", wantCalls, max, gotCalls) - } + // Call Resolve Now 100 times, shouldn't continue onto next iteration of + // watcher, thus shouldn't lookup again. + for i := 0; i <= 100; i++ { + r.ResolveNow(resolver.ResolveNowOptions{}) + } + + continueCtx, continueCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer continueCancel() + + if _, err := tr.lookupHostCh.Receive(continueCtx); err == nil { + t.Fatalf("Should not have looked up again as DNS Min Res Rate timer has not gone off.") + } + + // Make the DNSMinResRate timer fire immediately (by receiving it, then + // resetting to 0), this will unblock the resolver which is currently + // blocked on the DNS Min Res Rate timer going off, which will allow it to + // continue to the next iteration of the watcher loop. + timer, err := timerChan.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) } - close(done) - elapsed := time.Since(start) + timerPointer := timer.(*time.Timer) + timerPointer.Reset(0) - if gotCalls != wantCalls { - t.Fatalf("resolve count mismatch for target: %q = %+v, want %+v\n", target, gotCalls, wantCalls) + // Now that DNS Min Res Rate timer has gone off, it should lookup again. + if _, err := tr.lookupHostCh.Receive(ctx); err != nil { + t.Fatalf("Timed out waiting for lookup() call.") } - if elapsed < min { - t.Fatalf("elapsed time: %v, wanted it to be between {%v and %v}", elapsed, min, max) + + // Resolve Now 1000 more times, shouldn't lookup again as DNS Min Res Rate + // timer has not gone off. + for i := 0; i < 1000; i++ { + r.ResolveNow(resolver.ResolveNowOptions{}) + } + + if _, err = tr.lookupHostCh.Receive(continueCtx); err == nil { + t.Fatalf("Should not have looked up again as DNS Min Res Rate timer has not gone off.") + } + + // Make the DNSMinResRate timer fire immediately again. + timer, err = timerChan.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving timer from mock NewTimer call: %v", err) + } + timerPointer = timer.(*time.Timer) + timerPointer.Reset(0) + + // Now that DNS Min Res Rate timer has gone off, it should lookup again. + if _, err = tr.lookupHostCh.Receive(ctx); err != nil { + t.Fatalf("Timed out waiting for lookup() call.") } wantAddrs := []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}