Skip to content

Commit

Permalink
Merge pull request #84 from Shopify/seb-retry-ctx
Browse files Browse the repository at this point in the history
Cancel retry resolver sleep if context is canceled
  • Loading branch information
lavoiesl committed Jan 13, 2021
2 parents f9d1e81 + b0398e5 commit 2bfa546
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 35 deletions.
28 changes: 16 additions & 12 deletions resolver/retry_resolver.go
Expand Up @@ -25,13 +25,17 @@ func NewRetryResolver(resolver Resolver, backoffs []time.Duration) Resolver {
}
}

func (r *retryResolver) retry(fn func() error) (err error) {
func (r *retryResolver) retry(ctx context.Context, fn func() error) (err error) {
var dnsError *net.DNSError
err = fn()

for i := 0; i < len(r.backoffs) && errors.As(err, &dnsError) && dnsError.Temporary(); i++ {
time.Sleep(r.backoffs[i])
err = fn()
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(r.backoffs[i]):
err = fn()
}
}

if errors.As(err, &dnsError) && dnsError.Err == "server misbehaving" {
Expand All @@ -49,71 +53,71 @@ func (r *retryResolver) retry(fn func() error) (err error) {
}

func (r *retryResolver) LookupHost(ctx context.Context, host string) (addrs []string, err error) {
err = r.retry(func() (err error) {
err = r.retry(ctx, func() (err error) {
addrs, err = r.resolver.LookupHost(ctx, host)
return err
})
return addrs, err
}

func (r *retryResolver) LookupIPAddr(ctx context.Context, host string) (records []net.IPAddr, err error) {
err = r.retry(func() (err error) {
err = r.retry(ctx, func() (err error) {
records, err = r.resolver.LookupIPAddr(ctx, host)
return err
})
return records, err
}

func (r *retryResolver) LookupPort(ctx context.Context, network, service string) (port int, err error) {
err = r.retry(func() (err error) {
err = r.retry(ctx, func() (err error) {
port, err = r.resolver.LookupPort(ctx, network, service)
return err
})
return port, err
}

func (r *retryResolver) LookupCNAME(ctx context.Context, host string) (cname string, err error) {
err = r.retry(func() (err error) {
err = r.retry(ctx, func() (err error) {
cname, err = r.resolver.LookupCNAME(ctx, host)
return err
})
return cname, err
}

func (r *retryResolver) LookupSRV(ctx context.Context, service, proto, name string) (cname string, records []*net.SRV, err error) {
err = r.retry(func() (err error) {
err = r.retry(ctx, func() (err error) {
cname, records, err = r.resolver.LookupSRV(ctx, service, proto, name)
return err
})
return cname, records, err
}

func (r *retryResolver) LookupMX(ctx context.Context, name string) (records []*net.MX, err error) {
err = r.retry(func() (err error) {
err = r.retry(ctx, func() (err error) {
records, err = r.resolver.LookupMX(ctx, name)
return err
})
return records, err
}

func (r *retryResolver) LookupNS(ctx context.Context, name string) (records []*net.NS, err error) {
err = r.retry(func() (err error) {
err = r.retry(ctx, func() (err error) {
records, err = r.resolver.LookupNS(ctx, name)
return err
})
return records, err
}

func (r *retryResolver) LookupTXT(ctx context.Context, name string) (records []string, err error) {
err = r.retry(func() (err error) {
err = r.retry(ctx, func() (err error) {
records, err = r.resolver.LookupTXT(ctx, name)
return err
})
return records, err
}

func (r *retryResolver) LookupAddr(ctx context.Context, addr string) (names []string, err error) {
err = r.retry(func() (err error) {
err = r.retry(ctx, func() (err error) {
names, err = r.resolver.LookupAddr(ctx, addr)
return err
})
Expand Down
58 changes: 35 additions & 23 deletions resolver/retry_resolver_test.go
Expand Up @@ -8,6 +8,7 @@ import (
"testing"
"time"

"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -54,12 +55,12 @@ func TestNewRetryLookup(t *testing.T) {
tests := map[string]struct {
callArgs []interface{}
returnArgs []interface{}
call func(t *testing.T, r Resolver, success bool) error
call func(ctx context.Context, t *testing.T, r Resolver, success bool) error
}{
"LookupHost": {
callArgs: []interface{}{ctx, "foo"},
callArgs: []interface{}{mock.Anything, "foo"},
returnArgs: []interface{}{[]string{"bar"}, nil},
call: func(t *testing.T, r Resolver, success bool) error {
call: func(ctx context.Context, t *testing.T, r Resolver, success bool) error {
addrs, err := r.LookupHost(ctx, "foo")
if success {
require.Equal(t, []string{"bar"}, addrs)
Expand All @@ -68,9 +69,9 @@ func TestNewRetryLookup(t *testing.T) {
},
},
"LookupIPAddr": {
callArgs: []interface{}{ctx, "foo"},
callArgs: []interface{}{mock.Anything, "foo"},
returnArgs: []interface{}{[]net.IPAddr{{Zone: "bar"}}, nil},
call: func(t *testing.T, r Resolver, success bool) error {
call: func(ctx context.Context, t *testing.T, r Resolver, success bool) error {
records, err := r.LookupIPAddr(ctx, "foo")
if success {
require.Equal(t, []net.IPAddr{{Zone: "bar"}}, records)
Expand All @@ -79,9 +80,9 @@ func TestNewRetryLookup(t *testing.T) {
},
},
"LookupPort": {
callArgs: []interface{}{ctx, "a", "b"},
callArgs: []interface{}{mock.Anything, "a", "b"},
returnArgs: []interface{}{123, nil},
call: func(t *testing.T, r Resolver, success bool) error {
call: func(ctx context.Context, t *testing.T, r Resolver, success bool) error {
port, err := r.LookupPort(ctx, "a", "b")
if success {
require.Equal(t, 123, port)
Expand All @@ -90,9 +91,9 @@ func TestNewRetryLookup(t *testing.T) {
},
},
"LookupCNAME": {
callArgs: []interface{}{ctx, "foo"},
callArgs: []interface{}{mock.Anything, "foo"},
returnArgs: []interface{}{"bar", nil},
call: func(t *testing.T, r Resolver, success bool) error {
call: func(ctx context.Context, t *testing.T, r Resolver, success bool) error {
cname, err := r.LookupCNAME(ctx, "foo")
if success {
require.Equal(t, "bar", cname)
Expand All @@ -101,9 +102,9 @@ func TestNewRetryLookup(t *testing.T) {
},
},
"LookupSRV": {
callArgs: []interface{}{ctx, "a", "b", "c"},
callArgs: []interface{}{mock.Anything, "a", "b", "c"},
returnArgs: []interface{}{"bar", []*net.SRV{{Target: "bar"}}, nil},
call: func(t *testing.T, r Resolver, success bool) error {
call: func(ctx context.Context, t *testing.T, r Resolver, success bool) error {
cname, records, err := r.LookupSRV(ctx, "a", "b", "c")
if success {
require.Equal(t, "bar", cname)
Expand All @@ -113,9 +114,9 @@ func TestNewRetryLookup(t *testing.T) {
},
},
"LookupMX": {
callArgs: []interface{}{ctx, "foo"},
callArgs: []interface{}{mock.Anything, "foo"},
returnArgs: []interface{}{[]*net.MX{{Host: "bar"}}, nil},
call: func(t *testing.T, r Resolver, success bool) error {
call: func(ctx context.Context, t *testing.T, r Resolver, success bool) error {
records, err := r.LookupMX(ctx, "foo")
if success {
require.Equal(t, []*net.MX{{Host: "bar"}}, records)
Expand All @@ -124,9 +125,9 @@ func TestNewRetryLookup(t *testing.T) {
},
},
"LookupNS": {
callArgs: []interface{}{ctx, "foo"},
callArgs: []interface{}{mock.Anything, "foo"},
returnArgs: []interface{}{[]*net.NS{{Host: "bar"}}, nil},
call: func(t *testing.T, r Resolver, success bool) error {
call: func(ctx context.Context, t *testing.T, r Resolver, success bool) error {
records, err := r.LookupNS(ctx, "foo")
if success {
require.Equal(t, []*net.NS{{Host: "bar"}}, records)
Expand All @@ -135,9 +136,9 @@ func TestNewRetryLookup(t *testing.T) {
},
},
"LookupTXT": {
callArgs: []interface{}{ctx, "foo"},
callArgs: []interface{}{mock.Anything, "foo"},
returnArgs: []interface{}{[]string{"bar"}, nil},
call: func(t *testing.T, r Resolver, success bool) error {
call: func(ctx context.Context, t *testing.T, r Resolver, success bool) error {
records, err := r.LookupTXT(ctx, "foo")
if success {
require.Equal(t, []string{"bar"}, records)
Expand All @@ -146,9 +147,9 @@ func TestNewRetryLookup(t *testing.T) {
},
},
"LookupAddr": {
callArgs: []interface{}{ctx, "foo"},
callArgs: []interface{}{mock.Anything, "foo"},
returnArgs: []interface{}{[]string{"bar"}, nil},
call: func(t *testing.T, r Resolver, success bool) error {
call: func(ctx context.Context, t *testing.T, r Resolver, success bool) error {
names, err := r.LookupAddr(ctx, "foo")
if success {
require.Equal(t, []string{"bar"}, names)
Expand All @@ -164,15 +165,15 @@ func TestNewRetryLookup(t *testing.T) {
withRetry(t, func(m *mockResolver, r Resolver) {
m.On(method, tt.callArgs...).Return(makeErrorArgs(len(tt.returnArgs), temporaryError)...).Once()
m.On(method, tt.callArgs...).Return(makeErrorArgs(len(tt.returnArgs), permanentError)...).Once()
err := tt.call(t, r, false)
err := tt.call(ctx, t, r, false)
require.EqualError(t, err, "lookup foo: baz")
})
})

t.Run("retry servfail", func(t *testing.T) {
withRetry(t, func(m *mockResolver, r Resolver) {
m.On(method, tt.callArgs...).Return(makeErrorArgs(len(tt.returnArgs), &net.DNSError{Name: "foo", Err: "server misbehaving", IsTemporary: true})...).Times(3)
err := tt.call(t, r, false)
err := tt.call(ctx, t, r, false)

var dnsError *net.DNSError
require.True(t, errors.As(err, &dnsError))
Expand All @@ -183,16 +184,27 @@ func TestNewRetryLookup(t *testing.T) {
t.Run("retry exhaustion", func(t *testing.T) {
withRetry(t, func(m *mockResolver, r Resolver) {
m.On(method, tt.callArgs...).Return(makeErrorArgs(len(tt.returnArgs), temporaryError)...).Times(3)
err := tt.call(t, r, false)
err := tt.call(ctx, t, r, false)
require.EqualError(t, err, "lookup foo: bar")
})
})

t.Run("canceled", func(t *testing.T) {
withRetry(t, func(m *mockResolver, r Resolver) {
ctx, cancel := context.WithCancel(ctx)
cancel()

m.On(method, tt.callArgs...).Return(makeErrorArgs(len(tt.returnArgs), context.Canceled)...).Once()
err := tt.call(ctx, t, r, false)
require.EqualError(t, err, "context canceled")
})
})

t.Run("success", func(t *testing.T) {
withRetry(t, func(m *mockResolver, r Resolver) {
m.On(method, tt.callArgs...).Return(makeErrorArgs(len(tt.returnArgs), temporaryError)...).Once()
m.On(method, tt.callArgs...).Return(tt.returnArgs...).Once()
err := tt.call(t, r, true)
err := tt.call(ctx, t, r, true)
require.NoError(t, err)
})
})
Expand Down

0 comments on commit 2bfa546

Please sign in to comment.