From 796a6b0e17dd7da9bac6581efae68f1147a9374d Mon Sep 17 00:00:00 2001 From: mzack Date: Sun, 10 Jul 2022 22:35:43 +0200 Subject: [PATCH] Adding retry positive validation --- client.go | 9 ++++++--- client_test.go | 14 +++++++------- options.go | 12 +++++++++++- options_test.go | 13 +++++++++++++ 4 files changed, 37 insertions(+), 11 deletions(-) create mode 100644 options_test.go diff --git a/client.go b/client.go index 7d53240..ec96309 100644 --- a/client.go +++ b/client.go @@ -47,12 +47,15 @@ type Client struct { } // New creates a new dns client -func New(baseResolvers []string, maxRetries int) *Client { +func New(baseResolvers []string, maxRetries int) (*Client, error) { return NewWithOptions(Options{BaseResolvers: baseResolvers, MaxRetries: maxRetries}) } // New creates a new dns client with options -func NewWithOptions(options Options) *Client { +func NewWithOptions(options Options) (*Client, error) { + if err := options.Validate(); err != nil { + return nil, err + } parsedBaseResolvers := parseResolvers(sliceutil.Dedupe(options.BaseResolvers)) var knownHosts map[string][]string if options.Hostsfile { @@ -73,7 +76,7 @@ func NewWithOptions(options Options) *Client { dotClient: &dns.Client{Net: "tcp-tls", Timeout: options.Timeout}, knownHosts: knownHosts, } - return &client + return &client, nil } // ResolveWithSyscall attempts to resolve the host through system calls diff --git a/client_test.go b/client_test.go index 6c5ccac..ee40b7b 100644 --- a/client_test.go +++ b/client_test.go @@ -8,7 +8,7 @@ import ( ) func TestConsistentResolve(t *testing.T) { - client := New([]string{"8.8.8.8:53", "1.1.1.1:53"}, 5) + client, _ := New([]string{"8.8.8.8:53", "1.1.1.1:53"}, 5) var last string for i := 0; i < 10; i++ { @@ -24,7 +24,7 @@ func TestConsistentResolve(t *testing.T) { } func TestUDP(t *testing.T) { - client := New([]string{"1.1.1.1:53", "udp:8.8.8.8"}, 5) + client, _ := New([]string{"1.1.1.1:53", "udp:8.8.8.8"}, 5) d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA}) require.Nil(t, err) @@ -34,7 +34,7 @@ func TestUDP(t *testing.T) { } func TestTCP(t *testing.T) { - client := New([]string{"tcp:1.1.1.1:53", "tcp:8.8.8.8"}, 5) + client, _ := New([]string{"tcp:1.1.1.1:53", "tcp:8.8.8.8"}, 5) d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA}) require.Nil(t, err) @@ -44,7 +44,7 @@ func TestTCP(t *testing.T) { } func TestDOH(t *testing.T) { - client := New([]string{"doh:https://doh.opendns.com/dns-query:post", "doh:https://doh.opendns.com/dns-query:get"}, 5) + client, _ := New([]string{"doh:https://doh.opendns.com/dns-query:post", "doh:https://doh.opendns.com/dns-query:get"}, 5) d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA}) require.Nil(t, err) @@ -54,7 +54,7 @@ func TestDOH(t *testing.T) { } func TestDOT(t *testing.T) { - client := New([]string{"dot:dns.google:853", "dot:1dot1dot1dot1.cloudflare-dns.com"}, 5) + client, _ := New([]string{"dot:dns.google:853", "dot:1dot1dot1dot1.cloudflare-dns.com"}, 5) d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA}) require.Nil(t, err) @@ -64,7 +64,7 @@ func TestDOT(t *testing.T) { } func TestQueryMultiple(t *testing.T) { - client := New([]string{"8.8.8.8:53", "1.1.1.1:53"}, 5) + client, _ := New([]string{"8.8.8.8:53", "1.1.1.1:53"}, 5) d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA, dns.TypeAAAA}) require.Nil(t, err) @@ -75,7 +75,7 @@ func TestQueryMultiple(t *testing.T) { } func TestTrace(t *testing.T) { - client := New([]string{"8.8.8.8:53", "1.1.1.1:53"}, 5) + client, _ := New([]string{"8.8.8.8:53", "1.1.1.1:53"}, 5) _, err := client.Trace("www.projectdiscovery.io", dns.TypeA, 100) require.Nil(t, err, "could not resolve dns") diff --git a/options.go b/options.go index e2df88b..be28988 100644 --- a/options.go +++ b/options.go @@ -1,6 +1,9 @@ package retryabledns -import "time" +import ( + "errors" + "time" +) type Options struct { BaseResolvers []string @@ -8,3 +11,10 @@ type Options struct { Timeout time.Duration Hostsfile bool } + +func (options *Options) Validate() error { + if options.MaxRetries == 0 { + return errors.New("retries must be at least 1") + } + return nil +} diff --git a/options_test.go b/options_test.go new file mode 100644 index 0000000..5241869 --- /dev/null +++ b/options_test.go @@ -0,0 +1,13 @@ +package retryabledns + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestValidateOptions(t *testing.T) { + options := Options{} + err := options.Validate() + require.NotNil(t, err) +}