Skip to content

Commit

Permalink
Merge pull request #35 from projectdiscovery/feature-adding-dns-protocol
Browse files Browse the repository at this point in the history
Adding protocol support for resolvers
  • Loading branch information
Mzack9999 committed Sep 16, 2021
2 parents 74ea693 + 95b6cc0 commit 76c5b76
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 15 deletions.
61 changes: 47 additions & 14 deletions client.go
Expand Up @@ -13,6 +13,7 @@ import (
"time"

"github.com/miekg/dns"
"github.com/projectdiscovery/retryabledns/doh"
)

func init() {
Expand All @@ -21,19 +22,27 @@ func init() {

// Client is a DNS resolver client to resolve hostnames.
type Client struct {
resolvers []string
maxRetries int
resolvers []Resolver
options Options
serversIndex uint32
TCPFallback bool
Timeout time.Duration
tcpClient *dns.Client
dohClient *doh.Client
}

// New creates a new dns client
func New(baseResolvers []string, maxRetries int) *Client {
baseResolvers = deduplicate(baseResolvers)
return NewWithOptions(Options{BaseResolvers: baseResolvers, MaxRetries: maxRetries})
}

// New creates a new dns client with options
func NewWithOptions(options Options) *Client {
parsedBaseResolvers := parseResolvers(deduplicate(options.BaseResolvers))
client := Client{
maxRetries: maxRetries,
resolvers: baseResolvers,
options: options,
resolvers: parsedBaseResolvers,
tcpClient: &dns.Client{Net: TCP.String(), Timeout: options.Timeout},
dohClient: doh.New(),
}
return &client
}
Expand Down Expand Up @@ -67,11 +76,22 @@ func (c *Client) Resolve(host string) (*DNSData, error) {
func (c *Client) Do(msg *dns.Msg) (*dns.Msg, error) {
var resp *dns.Msg
var err error
for i := 0; i < c.maxRetries; i++ {
for i := 0; i < c.options.MaxRetries; i++ {
index := atomic.AddUint32(&c.serversIndex, 1)
resolver := c.resolvers[index%uint32(len(c.resolvers))]

resp, err = dns.Exchange(msg, resolver)
switch r := resolver.(type) {
case *NetworkResolver:
switch r.Protocol {
case TCP:
resp, _, err = c.tcpClient.Exchange(msg, resolver.String())
case UDP:
resp, err = dns.Exchange(msg, resolver.String())
}
case *DohResolver:
resp, err = c.dohClient.QueryWithDOHMsg(doh.Method(r.Method()), doh.Resolver{URL: r.URL}, msg)
}

if err != nil || resp == nil {
continue
}
Expand Down Expand Up @@ -138,7 +158,7 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er
err error
)

msg := dns.Msg{}
msg := &dns.Msg{}
msg.Id = dns.Id()
msg.RecursionDesired = true
msg.Question = make([]dns.Question, 1)
Expand Down Expand Up @@ -168,19 +188,32 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er
msg.SetEdns0(4096, false)

var resp *dns.Msg
for i := 0; i < c.maxRetries; i++ {
for i := 0; i < c.options.MaxRetries; i++ {
index := atomic.AddUint32(&c.serversIndex, 1)
resolver := c.resolvers[index%uint32(len(c.resolvers))]

resp, err = dns.Exchange(&msg, resolver)
switch r := resolver.(type) {
case *NetworkResolver:
switch r.Protocol {
case TCP:
resp, _, err = c.tcpClient.Exchange(msg, resolver.String())
case UDP:
resp, err = dns.Exchange(msg, resolver.String())
}
case *DohResolver:
method := doh.MethodPost
if r.Protocol == GET {
method = doh.MethodGet
}
resp, err = c.dohClient.QueryWithDOHMsg(method, doh.Resolver{URL: r.URL}, msg)
}
if err != nil || resp == nil {
continue
}

// https://github.com/projectdiscovery/retryabledns/issues/25
if resp.Truncated && c.TCPFallback {
tcpClient := dns.Client{Net: "tcp", Timeout: c.Timeout}
resp, _, _ = tcpClient.Exchange(&msg, resolver)
resp, _, _ = c.tcpClient.Exchange(msg, resolver.String())
}

err = dnsdata.ParseFromMsg(resp)
Expand All @@ -191,7 +224,7 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er
dnsdata.StatusCodeRaw = resp.Rcode
dnsdata.Timestamp = time.Now()
dnsdata.Raw += resp.String()
dnsdata.Resolver = append(dnsdata.Resolver, resolver)
dnsdata.Resolver = append(dnsdata.Resolver, resolver.String())

if err != nil || !dnsdata.contains() {
continue
Expand Down
13 changes: 12 additions & 1 deletion client_test.go
Expand Up @@ -23,10 +23,21 @@ func TestConsistentResolve(t *testing.T) {
}
}

func TestDOH(t *testing.T) {
client := New([]string{"doh:https://doh.opendns.com/dns-query:post", "doh:https://doh.opendns.com/dns-query:get"}, 5)

d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA})
require.Nil(t, err)

// From current dig result
require.True(t, len(d.A) > 0)
}

func TestQueryMultiple(t *testing.T) {
client := New([]string{"8.8.8.8:53", "1.1.1.1:53"}, 5)

d, _ := client.QueryMultiple("example.com", []uint16{dns.TypeA, dns.TypeAAAA})
d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA, dns.TypeAAAA})
require.Nil(t, err)

// From current dig result
require.True(t, len(d.A) > 0)
Expand Down
1 change: 1 addition & 0 deletions go.mod
Expand Up @@ -6,5 +6,6 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/miekg/dns v1.1.29
github.com/projectdiscovery/retryablehttp-go v1.0.2
github.com/projectdiscovery/stringsutil v0.0.0-20210823090203-2f5f137e8e1d
github.com/stretchr/testify v1.7.0
)
2 changes: 2 additions & 0 deletions go.sum
Expand Up @@ -7,6 +7,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/projectdiscovery/retryablehttp-go v1.0.2 h1:LV1/KAQU+yeWhNVlvveaYFsjBYRwXlNEq0PvrezMV0U=
github.com/projectdiscovery/retryablehttp-go v1.0.2/go.mod h1:dx//aY9V247qHdsRf0vdWHTBZuBQ2vm6Dq5dagxrDYI=
github.com/projectdiscovery/stringsutil v0.0.0-20210823090203-2f5f137e8e1d h1:lrdpJCBOvRrTnm44Ov7O3tLd3oOWhCvVUhTKkWwibq4=
github.com/projectdiscovery/stringsutil v0.0.0-20210823090203-2f5f137e8e1d/go.mod h1:oTRc18WBv9t6BpaN9XBY+QmG28PUpsyDzRht56Qf49I=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
Expand Down
9 changes: 9 additions & 0 deletions options.go
@@ -0,0 +1,9 @@
package retryabledns

import "time"

type Options struct {
BaseResolvers []string
MaxRetries int
Timeout time.Duration
}
152 changes: 152 additions & 0 deletions resolver.go
@@ -0,0 +1,152 @@
package retryabledns

import (
"net"
"strings"

"github.com/projectdiscovery/stringsutil"
)

type Protocol int

const (
UDP Protocol = iota
TCP
DOH
)

func (p Protocol) String() string {
switch p {
case DOH:
return "doh"
case UDP:
return "udp"
case TCP:
return "tcp"
}

return ""
}

func (p Protocol) StringWithSemicolon() string {
return p.String() + ":"
}

type DohProtocol int

const (
JsonAPI DohProtocol = iota
GET
POST
)

func (p DohProtocol) String() string {
switch p {
case JsonAPI:
return "jsonapi"
case GET:
return "get"
case POST:
return "post"
}

return ""
}

func (p DohProtocol) StringWithSemicolon() string {
return ":" + p.String()
}

type Resolver interface {
String() string
}

type NetworkResolver struct {
Protocol Protocol
Host string
Port string
}

func (r NetworkResolver) String() string {
return net.JoinHostPort(r.Host, r.Port)
}

type DohResolver struct {
Protocol DohProtocol
URL string
}

func (r DohResolver) Method() string {
if r.Protocol == POST {
return POST.String()
}

return GET.String()
}

func (r DohResolver) String() string {
return r.URL
}

func parseResolver(r string) (resolver Resolver) {
isTcp, isUDP, isDoh := hasProtocol(r, TCP.StringWithSemicolon()), hasProtocol(r, UDP.StringWithSemicolon()), hasProtocol(r, DOH.StringWithSemicolon())
rNetworkTokens := trimProtocol(r)
if isTcp || isUDP {
networkResolver := &NetworkResolver{Protocol: UDP}
if isTcp {
networkResolver.Protocol = TCP
}
parseHostPort(networkResolver, rNetworkTokens)
resolver = networkResolver
} else if isDoh {
isJsonApi, isGet := hasDohProtocol(r, JsonAPI.StringWithSemicolon()), hasDohProtocol(r, GET.StringWithSemicolon())
URL := trimDohProtocol(rNetworkTokens)
dohResolver := &DohResolver{URL: URL, Protocol: POST}
if isJsonApi {
dohResolver.Protocol = JsonAPI
} else if isGet {
dohResolver.Protocol = GET
}
resolver = dohResolver
} else {
networkResolver := &NetworkResolver{Protocol: UDP}
parseHostPort(networkResolver, rNetworkTokens)
resolver = networkResolver
}

return
}

func parseHostPort(networkResolver *NetworkResolver, r string) {
if host, port, err := net.SplitHostPort(r); err == nil {
networkResolver.Host = host
networkResolver.Port = port
} else {
networkResolver.Host = r
networkResolver.Port = "53"
}
}

func hasProtocol(resolver, protocol string) bool {
return strings.HasPrefix(resolver, protocol)
}

func hasDohProtocol(resolver, protocol string) bool {
return strings.HasSuffix(resolver, protocol)
}

func trimProtocol(resolver string) string {
return stringsutil.TrimPrefixAny(resolver, TCP.StringWithSemicolon(), UDP.StringWithSemicolon(), DOH.StringWithSemicolon())
}

func trimDohProtocol(resolver string) string {
return stringsutil.TrimSuffixAny(resolver, GET.StringWithSemicolon(), POST.StringWithSemicolon(), JsonAPI.StringWithSemicolon())
}

func parseResolvers(resolvers []string) []Resolver {
var parsedResolvers []Resolver
for _, resolver := range resolvers {
parsedResolvers = append(parsedResolvers, parseResolver(resolver))
}
return parsedResolvers
}

0 comments on commit 76c5b76

Please sign in to comment.