Skip to content

Commit

Permalink
dns: fix flaky TestRateLimitedResolve (#4387)
Browse files Browse the repository at this point in the history
* Rewrote TestRateLimitedResolve in dns resolver test to get rid of flakiness.
  • Loading branch information
zasweq committed May 7, 2021
1 parent cb39647 commit c7ea734
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 59 deletions.
10 changes: 7 additions & 3 deletions internal/resolver/dns/dns_resolver.go
Expand Up @@ -47,8 +47,12 @@ var EnableSRVLookups = false

var logger = grpclog.Component("dns")

// A global to stub out in tests.
var newTimer = time.NewTimer
// Globals to stub out in tests. TODO: Perhaps these two can be combined into a
// single variable for testing the resolver?
var (
newTimer = time.NewTimer
newTimerDNSResRate = time.NewTimer
)

func init() {
resolver.Register(NewBuilder())
Expand Down Expand Up @@ -219,7 +223,7 @@ func (d *dnsResolver) watcher() {
// Success resolving, wait for the next ResolveNow. However, also wait 30 seconds at the very least
// to prevent constantly re-resolving.
backoffIndex = 1
timer = time.NewTimer(minDNSResRate)
timer = newTimerDNSResRate(minDNSResRate)
select {
case <-d.ctx.Done():
timer.Stop()
Expand Down
141 changes: 85 additions & 56 deletions internal/resolver/dns/dns_resolver_test.go
Expand Up @@ -43,14 +43,15 @@ func TestMain(m *testing.M) {
// Set a non-zero duration only for tests which are actually testing that
// feature.
replaceDNSResRate(time.Duration(0)) // No need to clean up since we os.Exit
replaceNetFunc(nil) // No need to clean up since we os.Exit
overrideDefaultResolver(false) // No need to clean up since we os.Exit
code := m.Run()
os.Exit(code)
}

const (
txtBytesLimit = 255
defaultTestTimeout = 10 * time.Second
txtBytesLimit = 255
defaultTestTimeout = 10 * time.Second
defaultTestShortTimeout = 10 * time.Millisecond
)

type testClientConn struct {
Expand Down Expand Up @@ -106,12 +107,12 @@ type testResolver struct {
// A write to this channel is made when this resolver receives a resolution
// request. Tests can rely on reading from this channel to be notified about
// resolution requests instead of sleeping for a predefined period of time.
ch chan struct{}
lookupHostCh *testutils.Channel
}

func (tr *testResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
if tr.ch != nil {
tr.ch <- struct{}{}
if tr.lookupHostCh != nil {
tr.lookupHostCh.Send(nil)
}
return hostLookup(host)
}
Expand All @@ -124,9 +125,17 @@ func (*testResolver) LookupTXT(ctx context.Context, host string) ([]string, erro
return txtLookup(host)
}

func replaceNetFunc(ch chan struct{}) func() {
// overrideDefaultResolver overrides the defaultResolver used by the code with
// an instance of the testResolver. pushOnLookup controls whether the
// testResolver created here pushes lookupHost events on its channel.
func overrideDefaultResolver(pushOnLookup bool) func() {
oldResolver := defaultResolver
defaultResolver = &testResolver{ch: ch}

var lookupHostCh *testutils.Channel
if pushOnLookup {
lookupHostCh = testutils.NewChannel()
}
defaultResolver = &testResolver{lookupHostCh: lookupHostCh}

return func() {
defaultResolver = oldResolver
Expand Down Expand Up @@ -1451,23 +1460,33 @@ func TestCustomAuthority(t *testing.T) {
// requests. It sets the re-resolution rate to a small value and repeatedly
// calls ResolveNow() and ensures only the expected number of resolution
// requests are made.

func TestRateLimitedResolve(t *testing.T) {
defer leakcheck.Check(t)
defer func(nt func(d time.Duration) *time.Timer) {
newTimer = nt
}(newTimer)
newTimer = func(d time.Duration) *time.Timer {
// Will never fire on its own, will protect from triggering exponential backoff.
// Will never fire on its own, will protect from triggering exponential
// backoff.
return time.NewTimer(time.Hour)
}
defer func(nt func(d time.Duration) *time.Timer) {
newTimerDNSResRate = nt
}(newTimerDNSResRate)

const dnsResRate = 10 * time.Millisecond
dc := replaceDNSResRate(dnsResRate)
defer dc()
timerChan := testutils.NewChannel()
newTimerDNSResRate = func(d time.Duration) *time.Timer {
// Will never fire on its own; the timer is handed to the test over
// timerChan so the test can make it fire on demand.
t := time.NewTimer(time.Hour)
timerChan.Send(t)
return t
}

// Create a new testResolver{} for this test because we want the exact count
// of the number of times the resolver was invoked.
nc := replaceNetFunc(make(chan struct{}))
nc := overrideDefaultResolver(true)
defer nc()

target := "foo.bar.com"
Expand All @@ -1490,55 +1509,65 @@ func TestRateLimitedResolve(t *testing.T) {
t.Fatalf("delegate resolver returned unexpected type: %T\n", tr)
}

// Observe the time before unblocking the lookupHost call. The 100ms rate
// limiting timer will begin immediately after that. This means the next
// resolution could happen less than 100ms if we read the time *after*
// receiving from tr.ch
start := time.Now()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

// Wait for the first resolution request to be done. This happens as part
// of the first iteration of the for loop in watcher() because we call
// ResolveNow in Build.
<-tr.ch

// Here we start a couple of goroutines. One repeatedly calls ResolveNow()
// until asked to stop, and the other waits for two resolution requests to be
// made to our testResolver and stops the former. We measure the start and
// end times, and expect the duration elapsed to be in the interval
// {wantCalls*dnsResRate, wantCalls*dnsResRate}
done := make(chan struct{})
go func() {
for {
select {
case <-done:
return
default:
r.ResolveNow(resolver.ResolveNowOptions{})
time.Sleep(1 * time.Millisecond)
}
}
}()
// of the first iteration of the for loop in watcher().
if _, err := tr.lookupHostCh.Receive(ctx); err != nil {
t.Fatalf("Timed out waiting for lookup() call.")
}

gotCalls := 0
const wantCalls = 3
min, max := wantCalls*dnsResRate, (wantCalls+1)*dnsResRate
tMax := time.NewTimer(max)
for gotCalls != wantCalls {
select {
case <-tr.ch:
gotCalls++
case <-tMax.C:
t.Fatalf("Timed out waiting for %v calls after %v; got %v", wantCalls, max, gotCalls)
}
// Call ResolveNow 101 times (i runs from 0 through 100). The watcher
// shouldn't continue onto its next iteration, and thus shouldn't look up again.
for i := 0; i <= 100; i++ {
r.ResolveNow(resolver.ResolveNowOptions{})
}

continueCtx, continueCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
defer continueCancel()

if _, err := tr.lookupHostCh.Receive(continueCtx); err == nil {
t.Fatalf("Should not have looked up again as DNS Min Res Rate timer has not gone off.")
}

// Make the DNSMinResRate timer fire immediately (by receiving it, then
// resetting to 0), this will unblock the resolver which is currently
// blocked on the DNS Min Res Rate timer going off, which will allow it to
// continue to the next iteration of the watcher loop.
timer, err := timerChan.Receive(ctx)
if err != nil {
t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
}
close(done)
elapsed := time.Since(start)
timerPointer := timer.(*time.Timer)
timerPointer.Reset(0)

if gotCalls != wantCalls {
t.Fatalf("resolve count mismatch for target: %q = %+v, want %+v\n", target, gotCalls, wantCalls)
// Now that DNS Min Res Rate timer has gone off, it should lookup again.
if _, err := tr.lookupHostCh.Receive(ctx); err != nil {
t.Fatalf("Timed out waiting for lookup() call.")
}
if elapsed < min {
t.Fatalf("elapsed time: %v, wanted it to be between {%v and %v}", elapsed, min, max)

// Resolve Now 1000 more times, shouldn't lookup again as DNS Min Res Rate
// timer has not gone off.
for i := 0; i < 1000; i++ {
r.ResolveNow(resolver.ResolveNowOptions{})
}

if _, err = tr.lookupHostCh.Receive(continueCtx); err == nil {
t.Fatalf("Should not have looked up again as DNS Min Res Rate timer has not gone off.")
}

// Make the DNSMinResRate timer fire immediately again.
timer, err = timerChan.Receive(ctx)
if err != nil {
t.Fatalf("Error receiving timer from mock NewTimer call: %v", err)
}
timerPointer = timer.(*time.Timer)
timerPointer.Reset(0)

// Now that DNS Min Res Rate timer has gone off, it should lookup again.
if _, err = tr.lookupHostCh.Receive(ctx); err != nil {
t.Fatalf("Timed out waiting for lookup() call.")
}

wantAddrs := []resolver.Address{{Addr: "1.2.3.4" + colonDefaultPort}, {Addr: "5.6.7.8" + colonDefaultPort}}
Expand Down

0 comments on commit c7ea734

Please sign in to comment.