Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding protocol support for resolvers #35

Merged
merged 6 commits into from Sep 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
61 changes: 47 additions & 14 deletions client.go
Expand Up @@ -13,6 +13,7 @@ import (
"time"

"github.com/miekg/dns"
"github.com/projectdiscovery/retryabledns/doh"
)

func init() {
Expand All @@ -21,19 +22,27 @@ func init() {

// Client is a DNS resolver client to resolve hostnames.
type Client struct {
resolvers []string
maxRetries int
resolvers []Resolver
options Options
serversIndex uint32
TCPFallback bool
Timeout time.Duration
tcpClient *dns.Client
dohClient *doh.Client
}

// New creates a new dns client
func New(baseResolvers []string, maxRetries int) *Client {
baseResolvers = deduplicate(baseResolvers)
return NewWithOptions(Options{BaseResolvers: baseResolvers, MaxRetries: maxRetries})
}

// New creates a new dns client with options
func NewWithOptions(options Options) *Client {
parsedBaseResolvers := parseResolvers(deduplicate(options.BaseResolvers))
client := Client{
maxRetries: maxRetries,
resolvers: baseResolvers,
options: options,
resolvers: parsedBaseResolvers,
tcpClient: &dns.Client{Net: TCP.String(), Timeout: options.Timeout},
dohClient: doh.New(),
}
return &client
}
Expand Down Expand Up @@ -67,11 +76,22 @@ func (c *Client) Resolve(host string) (*DNSData, error) {
func (c *Client) Do(msg *dns.Msg) (*dns.Msg, error) {
var resp *dns.Msg
var err error
for i := 0; i < c.maxRetries; i++ {
for i := 0; i < c.options.MaxRetries; i++ {
index := atomic.AddUint32(&c.serversIndex, 1)
resolver := c.resolvers[index%uint32(len(c.resolvers))]

resp, err = dns.Exchange(msg, resolver)
switch r := resolver.(type) {
case *NetworkResolver:
switch r.Protocol {
case TCP:
resp, _, err = c.tcpClient.Exchange(msg, resolver.String())
case UDP:
resp, err = dns.Exchange(msg, resolver.String())
}
case *DohResolver:
resp, err = c.dohClient.QueryWithDOHMsg(doh.Method(r.Method()), doh.Resolver{URL: r.URL}, msg)
}

if err != nil || resp == nil {
continue
}
Expand Down Expand Up @@ -138,7 +158,7 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er
err error
)

msg := dns.Msg{}
msg := &dns.Msg{}
msg.Id = dns.Id()
msg.RecursionDesired = true
msg.Question = make([]dns.Question, 1)
Expand Down Expand Up @@ -168,19 +188,32 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er
msg.SetEdns0(4096, false)

var resp *dns.Msg
for i := 0; i < c.maxRetries; i++ {
for i := 0; i < c.options.MaxRetries; i++ {
index := atomic.AddUint32(&c.serversIndex, 1)
resolver := c.resolvers[index%uint32(len(c.resolvers))]

resp, err = dns.Exchange(&msg, resolver)
switch r := resolver.(type) {
case *NetworkResolver:
switch r.Protocol {
case TCP:
resp, _, err = c.tcpClient.Exchange(msg, resolver.String())
case UDP:
resp, err = dns.Exchange(msg, resolver.String())
}
case *DohResolver:
method := doh.MethodPost
if r.Protocol == GET {
method = doh.MethodGet
}
resp, err = c.dohClient.QueryWithDOHMsg(method, doh.Resolver{URL: r.URL}, msg)
}
if err != nil || resp == nil {
continue
}

// https://github.com/projectdiscovery/retryabledns/issues/25
if resp.Truncated && c.TCPFallback {
tcpClient := dns.Client{Net: "tcp", Timeout: c.Timeout}
resp, _, _ = tcpClient.Exchange(&msg, resolver)
resp, _, _ = c.tcpClient.Exchange(msg, resolver.String())
}

err = dnsdata.ParseFromMsg(resp)
Expand All @@ -191,7 +224,7 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er
dnsdata.StatusCodeRaw = resp.Rcode
dnsdata.Timestamp = time.Now()
dnsdata.Raw += resp.String()
dnsdata.Resolver = append(dnsdata.Resolver, resolver)
dnsdata.Resolver = append(dnsdata.Resolver, resolver.String())

if err != nil || !dnsdata.contains() {
continue
Expand Down
13 changes: 12 additions & 1 deletion client_test.go
Expand Up @@ -23,10 +23,21 @@ func TestConsistentResolve(t *testing.T) {
}
}

func TestDOH(t *testing.T) {
client := New([]string{"doh:https://doh.opendns.com/dns-query:post", "doh:https://doh.opendns.com/dns-query:get"}, 5)

d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA})
require.Nil(t, err)

// From current dig result
require.True(t, len(d.A) > 0)
}

func TestQueryMultiple(t *testing.T) {
client := New([]string{"8.8.8.8:53", "1.1.1.1:53"}, 5)

d, _ := client.QueryMultiple("example.com", []uint16{dns.TypeA, dns.TypeAAAA})
d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA, dns.TypeAAAA})
require.Nil(t, err)

// From current dig result
require.True(t, len(d.A) > 0)
Expand Down
1 change: 1 addition & 0 deletions go.mod
Expand Up @@ -6,5 +6,6 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/miekg/dns v1.1.29
github.com/projectdiscovery/retryablehttp-go v1.0.2
github.com/projectdiscovery/stringsutil v0.0.0-20210823090203-2f5f137e8e1d
github.com/stretchr/testify v1.7.0
)
2 changes: 2 additions & 0 deletions go.sum
Expand Up @@ -7,6 +7,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/projectdiscovery/retryablehttp-go v1.0.2 h1:LV1/KAQU+yeWhNVlvveaYFsjBYRwXlNEq0PvrezMV0U=
github.com/projectdiscovery/retryablehttp-go v1.0.2/go.mod h1:dx//aY9V247qHdsRf0vdWHTBZuBQ2vm6Dq5dagxrDYI=
github.com/projectdiscovery/stringsutil v0.0.0-20210823090203-2f5f137e8e1d h1:lrdpJCBOvRrTnm44Ov7O3tLd3oOWhCvVUhTKkWwibq4=
github.com/projectdiscovery/stringsutil v0.0.0-20210823090203-2f5f137e8e1d/go.mod h1:oTRc18WBv9t6BpaN9XBY+QmG28PUpsyDzRht56Qf49I=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
Expand Down
9 changes: 9 additions & 0 deletions options.go
@@ -0,0 +1,9 @@
package retryabledns

import "time"

type Options struct {
BaseResolvers []string
MaxRetries int
Timeout time.Duration
}
152 changes: 152 additions & 0 deletions resolver.go
@@ -0,0 +1,152 @@
package retryabledns

import (
"net"
"strings"

"github.com/projectdiscovery/stringsutil"
)

type Protocol int

const (
UDP Protocol = iota
TCP
DOH
)

func (p Protocol) String() string {
switch p {
case DOH:
return "doh"
case UDP:
return "udp"
case TCP:
return "tcp"
}

return ""
}

func (p Protocol) StringWithSemicolon() string {
return p.String() + ":"
}

type DohProtocol int

const (
JsonAPI DohProtocol = iota
GET
POST
)

func (p DohProtocol) String() string {
switch p {
case JsonAPI:
return "jsonapi"
case GET:
return "get"
case POST:
return "post"
}

return ""
}

func (p DohProtocol) StringWithSemicolon() string {
return ":" + p.String()
}

type Resolver interface {
String() string
}

type NetworkResolver struct {
Protocol Protocol
Host string
Port string
}

func (r NetworkResolver) String() string {
return net.JoinHostPort(r.Host, r.Port)
}

type DohResolver struct {
Protocol DohProtocol
URL string
}

func (r DohResolver) Method() string {
if r.Protocol == POST {
return POST.String()
}

return GET.String()
}

func (r DohResolver) String() string {
return r.URL
}

func parseResolver(r string) (resolver Resolver) {
isTcp, isUDP, isDoh := hasProtocol(r, TCP.StringWithSemicolon()), hasProtocol(r, UDP.StringWithSemicolon()), hasProtocol(r, DOH.StringWithSemicolon())
rNetworkTokens := trimProtocol(r)
if isTcp || isUDP {
networkResolver := &NetworkResolver{Protocol: UDP}
if isTcp {
networkResolver.Protocol = TCP
}
parseHostPort(networkResolver, rNetworkTokens)
resolver = networkResolver
} else if isDoh {
isJsonApi, isGet := hasDohProtocol(r, JsonAPI.StringWithSemicolon()), hasDohProtocol(r, GET.StringWithSemicolon())
URL := trimDohProtocol(rNetworkTokens)
dohResolver := &DohResolver{URL: URL, Protocol: POST}
if isJsonApi {
dohResolver.Protocol = JsonAPI
} else if isGet {
dohResolver.Protocol = GET
}
resolver = dohResolver
} else {
networkResolver := &NetworkResolver{Protocol: UDP}
parseHostPort(networkResolver, rNetworkTokens)
resolver = networkResolver
}

return
}

func parseHostPort(networkResolver *NetworkResolver, r string) {
if host, port, err := net.SplitHostPort(r); err == nil {
networkResolver.Host = host
networkResolver.Port = port
} else {
networkResolver.Host = r
networkResolver.Port = "53"
}
}

func hasProtocol(resolver, protocol string) bool {
return strings.HasPrefix(resolver, protocol)
}

func hasDohProtocol(resolver, protocol string) bool {
return strings.HasSuffix(resolver, protocol)
}

func trimProtocol(resolver string) string {
return stringsutil.TrimPrefixAny(resolver, TCP.StringWithSemicolon(), UDP.StringWithSemicolon(), DOH.StringWithSemicolon())
}

func trimDohProtocol(resolver string) string {
return stringsutil.TrimSuffixAny(resolver, GET.StringWithSemicolon(), POST.StringWithSemicolon(), JsonAPI.StringWithSemicolon())
}

func parseResolvers(resolvers []string) []Resolver {
var parsedResolvers []Resolver
for _, resolver := range resolvers {
parsedResolvers = append(parsedResolvers, parseResolver(resolver))
}
return parsedResolvers
}