diff --git a/client.go b/client.go index 5eb8539..9c992e8 100644 --- a/client.go +++ b/client.go @@ -13,6 +13,7 @@ import ( "time" "github.com/miekg/dns" + "github.com/projectdiscovery/retryabledns/doh" ) func init() { @@ -21,19 +22,27 @@ func init() { // Client is a DNS resolver client to resolve hostnames. type Client struct { - resolvers []string - maxRetries int + resolvers []Resolver + options Options serversIndex uint32 TCPFallback bool - Timeout time.Duration + tcpClient *dns.Client + dohClient *doh.Client } // New creates a new dns client func New(baseResolvers []string, maxRetries int) *Client { - baseResolvers = deduplicate(baseResolvers) + return NewWithOptions(Options{BaseResolvers: baseResolvers, MaxRetries: maxRetries}) +} + +// New creates a new dns client with options +func NewWithOptions(options Options) *Client { + parsedBaseResolvers := parseResolvers(deduplicate(options.BaseResolvers)) client := Client{ - maxRetries: maxRetries, - resolvers: baseResolvers, + options: options, + resolvers: parsedBaseResolvers, + tcpClient: &dns.Client{Net: TCP.String(), Timeout: options.Timeout}, + dohClient: doh.New(), } return &client } @@ -67,11 +76,22 @@ func (c *Client) Resolve(host string) (*DNSData, error) { func (c *Client) Do(msg *dns.Msg) (*dns.Msg, error) { var resp *dns.Msg var err error - for i := 0; i < c.maxRetries; i++ { + for i := 0; i < c.options.MaxRetries; i++ { index := atomic.AddUint32(&c.serversIndex, 1) resolver := c.resolvers[index%uint32(len(c.resolvers))] - resp, err = dns.Exchange(msg, resolver) + switch r := resolver.(type) { + case *NetworkResolver: + switch r.Protocol { + case TCP: + resp, _, err = c.tcpClient.Exchange(msg, resolver.String()) + case UDP: + resp, err = dns.Exchange(msg, resolver.String()) + } + case *DohResolver: + resp, err = c.dohClient.QueryWithDOHMsg(doh.Method(r.Method()), doh.Resolver{URL: r.URL}, msg) + } + if err != nil || resp == nil { continue } @@ -138,7 +158,7 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er err error ) - msg := dns.Msg{} + msg := &dns.Msg{} msg.Id = dns.Id() msg.RecursionDesired = true msg.Question = make([]dns.Question, 1) @@ -168,19 +188,32 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er msg.SetEdns0(4096, false) var resp *dns.Msg - for i := 0; i < c.maxRetries; i++ { + for i := 0; i < c.options.MaxRetries; i++ { index := atomic.AddUint32(&c.serversIndex, 1) resolver := c.resolvers[index%uint32(len(c.resolvers))] - resp, err = dns.Exchange(&msg, resolver) + switch r := resolver.(type) { + case *NetworkResolver: + switch r.Protocol { + case TCP: + resp, _, err = c.tcpClient.Exchange(msg, resolver.String()) + case UDP: + resp, err = dns.Exchange(msg, resolver.String()) + } + case *DohResolver: + method := doh.MethodPost + if r.Protocol == GET { + method = doh.MethodGet + } + resp, err = c.dohClient.QueryWithDOHMsg(method, doh.Resolver{URL: r.URL}, msg) + } if err != nil || resp == nil { continue } // https://github.com/projectdiscovery/retryabledns/issues/25 if resp.Truncated && c.TCPFallback { - tcpClient := dns.Client{Net: "tcp", Timeout: c.Timeout} - resp, _, _ = tcpClient.Exchange(&msg, resolver) + resp, _, _ = c.tcpClient.Exchange(msg, resolver.String()) } err = dnsdata.ParseFromMsg(resp) @@ -191,7 +224,7 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er dnsdata.StatusCodeRaw = resp.Rcode dnsdata.Timestamp = time.Now() dnsdata.Raw += resp.String() - dnsdata.Resolver = append(dnsdata.Resolver, resolver) + dnsdata.Resolver = append(dnsdata.Resolver, resolver.String()) if err != nil || !dnsdata.contains() { continue diff --git a/client_test.go b/client_test.go index 152182f..97e6930 100644 --- a/client_test.go +++ b/client_test.go @@ -23,10 +23,21 @@ func TestConsistentResolve(t *testing.T) { } } +func TestDOH(t *testing.T) { + client := New([]string{"doh:https://doh.opendns.com/dns-query:post", "doh:https://doh.opendns.com/dns-query:get"}, 5) + + d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA}) + require.Nil(t, err) + + // From current dig result + require.True(t, len(d.A) > 0) +} + func TestQueryMultiple(t *testing.T) { client := New([]string{"8.8.8.8:53", "1.1.1.1:53"}, 5) - d, _ := client.QueryMultiple("example.com", []uint16{dns.TypeA, dns.TypeAAAA}) + d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA, dns.TypeAAAA}) + require.Nil(t, err) // From current dig result require.True(t, len(d.A) > 0) diff --git a/go.mod b/go.mod index 770eebe..7e02482 100644 --- a/go.mod +++ b/go.mod @@ -6,5 +6,6 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/miekg/dns v1.1.29 github.com/projectdiscovery/retryablehttp-go v1.0.2 + github.com/projectdiscovery/stringsutil v0.0.0-20210823090203-2f5f137e8e1d github.com/stretchr/testify v1.7.0 ) diff --git a/go.sum b/go.sum index a857c2e..5d70fd3 100644 --- a/go.sum +++ b/go.sum @@ -7,6 +7,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/projectdiscovery/retryablehttp-go v1.0.2 h1:LV1/KAQU+yeWhNVlvveaYFsjBYRwXlNEq0PvrezMV0U= github.com/projectdiscovery/retryablehttp-go v1.0.2/go.mod h1:dx//aY9V247qHdsRf0vdWHTBZuBQ2vm6Dq5dagxrDYI= +github.com/projectdiscovery/stringsutil v0.0.0-20210823090203-2f5f137e8e1d h1:lrdpJCBOvRrTnm44Ov7O3tLd3oOWhCvVUhTKkWwibq4= +github.com/projectdiscovery/stringsutil v0.0.0-20210823090203-2f5f137e8e1d/go.mod h1:oTRc18WBv9t6BpaN9XBY+QmG28PUpsyDzRht56Qf49I= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/options.go b/options.go new file mode 100644 index 0000000..2eabaaa --- /dev/null +++ b/options.go @@ -0,0 +1,9 @@ +package retryabledns + +import "time" + +type Options struct { + BaseResolvers []string + MaxRetries int + Timeout time.Duration +} diff --git a/resolver.go b/resolver.go new file mode 100644 index 0000000..4eb52c2 --- /dev/null +++ b/resolver.go @@ -0,0 +1,152 @@ +package retryabledns + +import ( + "net" + "strings" + + "github.com/projectdiscovery/stringsutil" +) + +type Protocol int + +const ( + UDP Protocol = iota + TCP + DOH +) + +func (p Protocol) String() string { + switch p { + case DOH: + return "doh" + case UDP: + return "udp" + case TCP: + return "tcp" + } + + return "" +} + +func (p Protocol) StringWithSemicolon() string { + return p.String() + ":" +} + +type DohProtocol int + +const ( + JsonAPI DohProtocol = iota + GET + POST +) + +func (p DohProtocol) String() string { + switch p { + case JsonAPI: + return "jsonapi" + case GET: + return "get" + case POST: + return "post" + } + + return "" +} + +func (p DohProtocol) StringWithSemicolon() string { + return ":" + p.String() +} + +type Resolver interface { + String() string +} + +type NetworkResolver struct { + Protocol Protocol + Host string + Port string +} + +func (r NetworkResolver) String() string { + return net.JoinHostPort(r.Host, r.Port) +} + +type DohResolver struct { + Protocol DohProtocol + URL string +} + +func (r DohResolver) Method() string { + if r.Protocol == POST { + return POST.String() + } + + return GET.String() +} + +func (r DohResolver) String() string { + return r.URL +} + +func parseResolver(r string) (resolver Resolver) { + isTcp, isUDP, isDoh := hasProtocol(r, TCP.StringWithSemicolon()), hasProtocol(r, UDP.StringWithSemicolon()), hasProtocol(r, DOH.StringWithSemicolon()) + rNetworkTokens := trimProtocol(r) + if isTcp || isUDP { + networkResolver := &NetworkResolver{Protocol: UDP} + if isTcp { + networkResolver.Protocol = TCP + } + parseHostPort(networkResolver, rNetworkTokens) + resolver = networkResolver + } else if isDoh { + isJsonApi, isGet := hasDohProtocol(r, JsonAPI.StringWithSemicolon()), hasDohProtocol(r, GET.StringWithSemicolon()) + URL := trimDohProtocol(rNetworkTokens) + dohResolver := &DohResolver{URL: URL, Protocol: POST} + if isJsonApi { + dohResolver.Protocol = JsonAPI + } else if isGet { + dohResolver.Protocol = GET + } + resolver = dohResolver + } else { + networkResolver := &NetworkResolver{Protocol: UDP} + parseHostPort(networkResolver, rNetworkTokens) + resolver = networkResolver + } + + return +} + +func parseHostPort(networkResolver *NetworkResolver, r string) { + if host, port, err := net.SplitHostPort(r); err == nil { + networkResolver.Host = host + networkResolver.Port = port + } else { + networkResolver.Host = r + networkResolver.Port = "53" + } +} + +func hasProtocol(resolver, protocol string) bool { + return strings.HasPrefix(resolver, protocol) +} + +func hasDohProtocol(resolver, protocol string) bool { + return strings.HasSuffix(resolver, protocol) +} + +func trimProtocol(resolver string) string { + return stringsutil.TrimPrefixAny(resolver, TCP.StringWithSemicolon(), UDP.StringWithSemicolon(), DOH.StringWithSemicolon()) +} + +func trimDohProtocol(resolver string) string { + return stringsutil.TrimSuffixAny(resolver, GET.StringWithSemicolon(), POST.StringWithSemicolon(), JsonAPI.StringWithSemicolon()) +} + +func parseResolvers(resolvers []string) []Resolver { + var parsedResolvers []Resolver + for _, resolver := range resolvers { + parsedResolvers = append(parsedResolvers, parseResolver(resolver)) + } + return parsedResolvers +}