From 84b9e429c424d0b1bc2e0a19cad64167da84e7b4 Mon Sep 17 00:00:00 2001 From: "Gia. Bui Dai" Date: Tue, 18 Jan 2022 14:02:14 +0700 Subject: [PATCH 1/3] Adding support for DNS over TLS (DOT) & some bug fixes --- client.go | 18 +++++++++++++++--- client_test.go | 10 ++++++++++ resolver.go | 17 +++++++++++++---- 3 files changed, 38 insertions(+), 7 deletions(-) diff --git a/client.go b/client.go index b7efc38..74e60bb 100644 --- a/client.go +++ b/client.go @@ -28,8 +28,10 @@ type Client struct { options Options serversIndex uint32 TCPFallback bool + udpClient *dns.Client tcpClient *dns.Client dohClient *doh.Client + dotClient *dns.Client knownHosts map[string][]string } @@ -48,8 +50,10 @@ func NewWithOptions(options Options) *Client { client := Client{ options: options, resolvers: parsedBaseResolvers, + udpClient: &dns.Client{Net: "", Timeout: options.Timeout}, tcpClient: &dns.Client{Net: TCP.String(), Timeout: options.Timeout}, dohClient: doh.New(), + dotClient: &dns.Client{Net: "tcp-tls", Timeout: options.Timeout}, knownHosts: knownHosts, } return &client @@ -94,10 +98,16 @@ func (c *Client) Do(msg *dns.Msg) (*dns.Msg, error) { case TCP: resp, _, err = c.tcpClient.Exchange(msg, resolver.String()) case UDP: - resp, err = dns.Exchange(msg, resolver.String()) + resp, _, err = c.udpClient.Exchange(msg, resolver.String()) + case DOT: + resp, _, err = c.dotClient.Exchange(msg, resolver.String()) } case *DohResolver: - resp, err = c.dohClient.QueryWithDOHMsg(doh.Method(r.Method()), doh.Resolver{URL: r.URL}, msg) + method := doh.MethodPost + if r.Protocol == GET { + method = doh.MethodGet + } + resp, err = c.dohClient.QueryWithDOHMsg(method, doh.Resolver{URL: r.URL}, msg) } if err != nil || resp == nil { @@ -219,7 +229,9 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er case TCP: resp, _, err = c.tcpClient.Exchange(msg, resolver.String()) case UDP: - resp, err = dns.Exchange(msg, resolver.String()) + resp, _, err = c.udpClient.Exchange(msg, resolver.String()) + case DOT: + resp, _, err = c.dotClient.Exchange(msg, resolver.String()) } case *DohResolver: method := doh.MethodPost diff --git a/client_test.go b/client_test.go index 97e6930..7a5ff95 100644 --- a/client_test.go +++ b/client_test.go @@ -33,6 +33,16 @@ func TestDOH(t *testing.T) { require.True(t, len(d.A) > 0) } +func TestDOT(t *testing.T) { + client := New([]string{"dot:dns.google:853", "dot:1dot1dot1dot1.cloudflare-dns.com"}, 5) + + d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA}) + require.Nil(t, err) + + // From current dig result + require.True(t, len(d.A) > 0) +} + func TestQueryMultiple(t *testing.T) { client := New([]string{"8.8.8.8:53", "1.1.1.1:53"}, 5) diff --git a/resolver.go b/resolver.go index 4eb52c2..5e30d84 100644 --- a/resolver.go +++ b/resolver.go @@ -13,6 +13,7 @@ const ( UDP Protocol = iota TCP DOH + DOT ) func (p Protocol) String() string { @@ -23,6 +24,8 @@ func (p Protocol) String() string { return "udp" case TCP: return "tcp" + case DOT: + return "dot" } return "" @@ -89,12 +92,14 @@ func (r DohResolver) String() string { } func parseResolver(r string) (resolver Resolver) { - isTcp, isUDP, isDoh := hasProtocol(r, TCP.StringWithSemicolon()), hasProtocol(r, UDP.StringWithSemicolon()), hasProtocol(r, DOH.StringWithSemicolon()) + isTcp, isUDP, isDoh, isDot := hasProtocol(r, TCP.StringWithSemicolon()), hasProtocol(r, UDP.StringWithSemicolon()), hasProtocol(r, DOH.StringWithSemicolon()), hasProtocol(r, DOT.StringWithSemicolon()) rNetworkTokens := trimProtocol(r) - if isTcp || isUDP { + if isTcp || isUDP || isDot { networkResolver := &NetworkResolver{Protocol: UDP} if isTcp { networkResolver.Protocol = TCP + } else if isDot { + networkResolver.Protocol = DOT } parseHostPort(networkResolver, rNetworkTokens) resolver = networkResolver @@ -123,7 +128,11 @@ func parseHostPort(networkResolver *NetworkResolver, r string) { networkResolver.Port = port } else { networkResolver.Host = r - networkResolver.Port = "53" + if networkResolver.Protocol == DOT { + networkResolver.Port = "853" + } else { + networkResolver.Port = "53" + } } } @@ -136,7 +145,7 @@ func hasDohProtocol(resolver, protocol string) bool { } func trimProtocol(resolver string) string { - return stringsutil.TrimPrefixAny(resolver, TCP.StringWithSemicolon(), UDP.StringWithSemicolon(), DOH.StringWithSemicolon()) + return stringsutil.TrimPrefixAny(resolver, TCP.StringWithSemicolon(), UDP.StringWithSemicolon(), DOH.StringWithSemicolon(), DOT.StringWithSemicolon()) } func trimDohProtocol(resolver string) string { From 177540bd16bc22bafcb4874a42eae84dacbac2d6 Mon Sep 17 00:00:00 2001 From: "Gia. Bui Dai" Date: Tue, 18 Jan 2022 15:12:05 +0700 Subject: [PATCH 2/3] Respect timeout option when using DNS over HTTPS (DOH) --- client.go | 17 ++++++++++++----- doh/doh_client.go | 4 ++-- doh/options.go | 2 +- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/client.go b/client.go index 74e60bb..3180238 100644 --- a/client.go +++ b/client.go @@ -16,6 +16,7 @@ import ( "github.com/projectdiscovery/iputil" "github.com/projectdiscovery/retryabledns/doh" "github.com/projectdiscovery/retryabledns/hostsfile" + "github.com/projectdiscovery/retryablehttp-go" ) func init() { @@ -47,12 +48,18 @@ func NewWithOptions(options Options) *Client { if options.Hostsfile { knownHosts, _ = hostsfile.ParseDefault() } + httpOptions := retryablehttp.DefaultOptionsSingle + httpOptions.Timeout = options.Timeout client := Client{ - options: options, - resolvers: parsedBaseResolvers, - udpClient: &dns.Client{Net: "", Timeout: options.Timeout}, - tcpClient: &dns.Client{Net: TCP.String(), Timeout: options.Timeout}, - dohClient: doh.New(), + options: options, + resolvers: parsedBaseResolvers, + udpClient: &dns.Client{Net: "", Timeout: options.Timeout}, + tcpClient: &dns.Client{Net: TCP.String(), Timeout: options.Timeout}, + dohClient: doh.NewWithOptions( + doh.Options{ + HttpClient: retryablehttp.NewClient(httpOptions), + }, + ), dotClient: &dns.Client{Net: "tcp-tls", Timeout: options.Timeout}, knownHosts: knownHosts, } diff --git a/doh/doh_client.go b/doh/doh_client.go index 081a93a..2b8db75 100644 --- a/doh/doh_client.go +++ b/doh/doh_client.go @@ -18,11 +18,11 @@ type Client struct { } func NewWithOptions(options Options) *Client { - return &Client{DefaultResolver: options.DefaultResolver, httpClient: options.httpClient} + return &Client{DefaultResolver: options.DefaultResolver, httpClient: options.HttpClient} } func New() *Client { - return NewWithOptions(Options{DefaultResolver: Cloudflare, httpClient: retryablehttp.NewClient(retryablehttp.DefaultOptionsSingle)}) + return NewWithOptions(Options{DefaultResolver: Cloudflare, HttpClient: retryablehttp.NewClient(retryablehttp.DefaultOptionsSingle)}) } func (c *Client) Query(name string, question QuestionType) (*Response, error) { diff --git a/doh/options.go b/doh/options.go index b740d3f..0ea9ba7 100644 --- a/doh/options.go +++ b/doh/options.go @@ -9,7 +9,7 @@ import ( type Options struct { DefaultResolver Resolver - httpClient *retryablehttp.Client + HttpClient *retryablehttp.Client } type Resolver struct { From 2a3c7367f641a68f6c8237592e1df6401f8e73f3 Mon Sep 17 00:00:00 2001 From: "Gia. Bui Dai" Date: Wed, 19 Jan 2022 15:01:13 +0700 Subject: [PATCH 3/3] Improve performance in hot parts --- client_test.go | 20 +++++++++++ resolver.go | 93 ++++++++++++++++++++------------------------------ 2 files changed, 57 insertions(+), 56 deletions(-) diff --git a/client_test.go b/client_test.go index 7a5ff95..6c5ccac 100644 --- a/client_test.go +++ b/client_test.go @@ -23,6 +23,26 @@ func TestConsistentResolve(t *testing.T) { } } +func TestUDP(t *testing.T) { + client := New([]string{"1.1.1.1:53", "udp:8.8.8.8"}, 5) + + d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA}) + require.Nil(t, err) + + // From current dig result + require.True(t, len(d.A) > 0) +} + +func TestTCP(t *testing.T) { + client := New([]string{"tcp:1.1.1.1:53", "tcp:8.8.8.8"}, 5) + + d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA}) + require.Nil(t, err) + + // From current dig result + require.True(t, len(d.A) > 0) +} + func TestDOH(t *testing.T) { client := New([]string{"doh:https://doh.opendns.com/dns-query:post", "doh:https://doh.opendns.com/dns-query:get"}, 5) diff --git a/resolver.go b/resolver.go index 5e30d84..a4a5228 100644 --- a/resolver.go +++ b/resolver.go @@ -7,53 +7,33 @@ import ( "github.com/projectdiscovery/stringsutil" ) -type Protocol int +type Protocol string const ( - UDP Protocol = iota - TCP - DOH - DOT + UDP Protocol = "udp" + TCP Protocol = "tcp" + DOH Protocol = "doh" + DOT Protocol = "dot" ) func (p Protocol) String() string { - switch p { - case DOH: - return "doh" - case UDP: - return "udp" - case TCP: - return "tcp" - case DOT: - return "dot" - } - - return "" + return string(p) } func (p Protocol) StringWithSemicolon() string { return p.String() + ":" } -type DohProtocol int +type DohProtocol string const ( - JsonAPI DohProtocol = iota - GET - POST + JsonAPI DohProtocol = "jsonapi" + GET DohProtocol = "get" + POST DohProtocol = "post" ) func (p DohProtocol) String() string { - switch p { - case JsonAPI: - return "jsonapi" - case GET: - return "get" - case POST: - return "post" - } - - return "" + return string(p) } func (p DohProtocol) StringWithSemicolon() string { @@ -92,29 +72,34 @@ func (r DohResolver) String() string { } func parseResolver(r string) (resolver Resolver) { - isTcp, isUDP, isDoh, isDot := hasProtocol(r, TCP.StringWithSemicolon()), hasProtocol(r, UDP.StringWithSemicolon()), hasProtocol(r, DOH.StringWithSemicolon()), hasProtocol(r, DOT.StringWithSemicolon()) rNetworkTokens := trimProtocol(r) - if isTcp || isUDP || isDot { - networkResolver := &NetworkResolver{Protocol: UDP} - if isTcp { - networkResolver.Protocol = TCP - } else if isDot { - networkResolver.Protocol = DOT - } - parseHostPort(networkResolver, rNetworkTokens) - resolver = networkResolver - } else if isDoh { - isJsonApi, isGet := hasDohProtocol(r, JsonAPI.StringWithSemicolon()), hasDohProtocol(r, GET.StringWithSemicolon()) - URL := trimDohProtocol(rNetworkTokens) - dohResolver := &DohResolver{URL: URL, Protocol: POST} - if isJsonApi { - dohResolver.Protocol = JsonAPI - } else if isGet { - dohResolver.Protocol = GET + protocol := UDP + + if len(r) >= 4 && r[3] == 58 { // 58 is ":" + switch r[0:3] { + case "udp": + case "tcp": + protocol = TCP + case "dot": + protocol = DOT + case "doh": + protocol = DOH + isJsonApi, isGet := hasDohProtocol(r, JsonAPI.StringWithSemicolon()), hasDohProtocol(r, GET.StringWithSemicolon()) + URL := trimDohProtocol(rNetworkTokens) + dohResolver := &DohResolver{URL: URL, Protocol: POST} + if isJsonApi { + dohResolver.Protocol = JsonAPI + } else if isGet { + dohResolver.Protocol = GET + } + resolver = dohResolver + default: + // unsupported protocol? } - resolver = dohResolver - } else { - networkResolver := &NetworkResolver{Protocol: UDP} + } + + if protocol != DOH { + networkResolver := &NetworkResolver{Protocol: protocol} parseHostPort(networkResolver, rNetworkTokens) resolver = networkResolver } @@ -136,10 +121,6 @@ func parseHostPort(networkResolver *NetworkResolver, r string) { } } -func hasProtocol(resolver, protocol string) bool { - return strings.HasPrefix(resolver, protocol) -} - func hasDohProtocol(resolver, protocol string) bool { return strings.HasSuffix(resolver, protocol) }