Skip to content

Commit

Permalink
Merge pull request #41 from yabeow/master
Browse files Browse the repository at this point in the history
Adding support for DNS over TLS (DOT) & Some bug fixes
  • Loading branch information
ehsandeep committed Jan 19, 2022
2 parents dc3d71d + 2a3c736 commit f7d02eb
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 63 deletions.
33 changes: 26 additions & 7 deletions client.go
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/projectdiscovery/iputil"
"github.com/projectdiscovery/retryabledns/doh"
"github.com/projectdiscovery/retryabledns/hostsfile"
"github.com/projectdiscovery/retryablehttp-go"
)

func init() {
Expand All @@ -28,8 +29,10 @@ type Client struct {
options Options
serversIndex uint32
TCPFallback bool
udpClient *dns.Client
tcpClient *dns.Client
dohClient *doh.Client
dotClient *dns.Client
knownHosts map[string][]string
}

Expand All @@ -45,11 +48,19 @@ func NewWithOptions(options Options) *Client {
if options.Hostsfile {
knownHosts, _ = hostsfile.ParseDefault()
}
httpOptions := retryablehttp.DefaultOptionsSingle
httpOptions.Timeout = options.Timeout
client := Client{
options: options,
resolvers: parsedBaseResolvers,
tcpClient: &dns.Client{Net: TCP.String(), Timeout: options.Timeout},
dohClient: doh.New(),
options: options,
resolvers: parsedBaseResolvers,
udpClient: &dns.Client{Net: "", Timeout: options.Timeout},
tcpClient: &dns.Client{Net: TCP.String(), Timeout: options.Timeout},
dohClient: doh.NewWithOptions(
doh.Options{
HttpClient: retryablehttp.NewClient(httpOptions),
},
),
dotClient: &dns.Client{Net: "tcp-tls", Timeout: options.Timeout},
knownHosts: knownHosts,
}
return &client
Expand Down Expand Up @@ -94,10 +105,16 @@ func (c *Client) Do(msg *dns.Msg) (*dns.Msg, error) {
case TCP:
resp, _, err = c.tcpClient.Exchange(msg, resolver.String())
case UDP:
resp, err = dns.Exchange(msg, resolver.String())
resp, _, err = c.udpClient.Exchange(msg, resolver.String())
case DOT:
resp, _, err = c.dotClient.Exchange(msg, resolver.String())
}
case *DohResolver:
resp, err = c.dohClient.QueryWithDOHMsg(doh.Method(r.Method()), doh.Resolver{URL: r.URL}, msg)
method := doh.MethodPost
if r.Protocol == GET {
method = doh.MethodGet
}
resp, err = c.dohClient.QueryWithDOHMsg(method, doh.Resolver{URL: r.URL}, msg)
}

if err != nil || resp == nil {
Expand Down Expand Up @@ -219,7 +236,9 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er
case TCP:
resp, _, err = c.tcpClient.Exchange(msg, resolver.String())
case UDP:
resp, err = dns.Exchange(msg, resolver.String())
resp, _, err = c.udpClient.Exchange(msg, resolver.String())
case DOT:
resp, _, err = c.dotClient.Exchange(msg, resolver.String())
}
case *DohResolver:
method := doh.MethodPost
Expand Down
30 changes: 30 additions & 0 deletions client_test.go
Expand Up @@ -23,6 +23,26 @@ func TestConsistentResolve(t *testing.T) {
}
}

func TestUDP(t *testing.T) {
client := New([]string{"1.1.1.1:53", "udp:8.8.8.8"}, 5)

d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA})
require.Nil(t, err)

// From current dig result
require.True(t, len(d.A) > 0)
}

func TestTCP(t *testing.T) {
client := New([]string{"tcp:1.1.1.1:53", "tcp:8.8.8.8"}, 5)

d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA})
require.Nil(t, err)

// From current dig result
require.True(t, len(d.A) > 0)
}

func TestDOH(t *testing.T) {
client := New([]string{"doh:https://doh.opendns.com/dns-query:post", "doh:https://doh.opendns.com/dns-query:get"}, 5)

Expand All @@ -33,6 +53,16 @@ func TestDOH(t *testing.T) {
require.True(t, len(d.A) > 0)
}

func TestDOT(t *testing.T) {
client := New([]string{"dot:dns.google:853", "dot:1dot1dot1dot1.cloudflare-dns.com"}, 5)

d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA})
require.Nil(t, err)

// From current dig result
require.True(t, len(d.A) > 0)
}

func TestQueryMultiple(t *testing.T) {
client := New([]string{"8.8.8.8:53", "1.1.1.1:53"}, 5)

Expand Down
4 changes: 2 additions & 2 deletions doh/doh_client.go
Expand Up @@ -18,11 +18,11 @@ type Client struct {
}

func NewWithOptions(options Options) *Client {
return &Client{DefaultResolver: options.DefaultResolver, httpClient: options.httpClient}
return &Client{DefaultResolver: options.DefaultResolver, httpClient: options.HttpClient}
}

func New() *Client {
return NewWithOptions(Options{DefaultResolver: Cloudflare, httpClient: retryablehttp.NewClient(retryablehttp.DefaultOptionsSingle)})
return NewWithOptions(Options{DefaultResolver: Cloudflare, HttpClient: retryablehttp.NewClient(retryablehttp.DefaultOptionsSingle)})
}

func (c *Client) Query(name string, question QuestionType) (*Response, error) {
Expand Down
2 changes: 1 addition & 1 deletion doh/options.go
Expand Up @@ -9,7 +9,7 @@ import (

type Options struct {
DefaultResolver Resolver
httpClient *retryablehttp.Client
HttpClient *retryablehttp.Client
}

type Resolver struct {
Expand Down
96 changes: 43 additions & 53 deletions resolver.go
Expand Up @@ -7,50 +7,33 @@ import (
"github.com/projectdiscovery/stringsutil"
)

type Protocol int
type Protocol string

const (
UDP Protocol = iota
TCP
DOH
UDP Protocol = "udp"
TCP Protocol = "tcp"
DOH Protocol = "doh"
DOT Protocol = "dot"
)

func (p Protocol) String() string {
switch p {
case DOH:
return "doh"
case UDP:
return "udp"
case TCP:
return "tcp"
}

return ""
return string(p)
}

func (p Protocol) StringWithSemicolon() string {
return p.String() + ":"
}

type DohProtocol int
type DohProtocol string

const (
JsonAPI DohProtocol = iota
GET
POST
JsonAPI DohProtocol = "jsonapi"
GET DohProtocol = "get"
POST DohProtocol = "post"
)

func (p DohProtocol) String() string {
switch p {
case JsonAPI:
return "jsonapi"
case GET:
return "get"
case POST:
return "post"
}

return ""
return string(p)
}

func (p DohProtocol) StringWithSemicolon() string {
Expand Down Expand Up @@ -89,27 +72,34 @@ func (r DohResolver) String() string {
}

func parseResolver(r string) (resolver Resolver) {
isTcp, isUDP, isDoh := hasProtocol(r, TCP.StringWithSemicolon()), hasProtocol(r, UDP.StringWithSemicolon()), hasProtocol(r, DOH.StringWithSemicolon())
rNetworkTokens := trimProtocol(r)
if isTcp || isUDP {
networkResolver := &NetworkResolver{Protocol: UDP}
if isTcp {
networkResolver.Protocol = TCP
}
parseHostPort(networkResolver, rNetworkTokens)
resolver = networkResolver
} else if isDoh {
isJsonApi, isGet := hasDohProtocol(r, JsonAPI.StringWithSemicolon()), hasDohProtocol(r, GET.StringWithSemicolon())
URL := trimDohProtocol(rNetworkTokens)
dohResolver := &DohResolver{URL: URL, Protocol: POST}
if isJsonApi {
dohResolver.Protocol = JsonAPI
} else if isGet {
dohResolver.Protocol = GET
protocol := UDP

if len(r) >= 4 && r[3] == 58 { // 58 is ":"
switch r[0:3] {
case "udp":
case "tcp":
protocol = TCP
case "dot":
protocol = DOT
case "doh":
protocol = DOH
isJsonApi, isGet := hasDohProtocol(r, JsonAPI.StringWithSemicolon()), hasDohProtocol(r, GET.StringWithSemicolon())
URL := trimDohProtocol(rNetworkTokens)
dohResolver := &DohResolver{URL: URL, Protocol: POST}
if isJsonApi {
dohResolver.Protocol = JsonAPI
} else if isGet {
dohResolver.Protocol = GET
}
resolver = dohResolver
default:
// unsupported protocol?
}
resolver = dohResolver
} else {
networkResolver := &NetworkResolver{Protocol: UDP}
}

if protocol != DOH {
networkResolver := &NetworkResolver{Protocol: protocol}
parseHostPort(networkResolver, rNetworkTokens)
resolver = networkResolver
}
Expand All @@ -123,20 +113,20 @@ func parseHostPort(networkResolver *NetworkResolver, r string) {
networkResolver.Port = port
} else {
networkResolver.Host = r
networkResolver.Port = "53"
if networkResolver.Protocol == DOT {
networkResolver.Port = "853"
} else {
networkResolver.Port = "53"
}
}
}

func hasProtocol(resolver, protocol string) bool {
return strings.HasPrefix(resolver, protocol)
}

func hasDohProtocol(resolver, protocol string) bool {
return strings.HasSuffix(resolver, protocol)
}

func trimProtocol(resolver string) string {
return stringsutil.TrimPrefixAny(resolver, TCP.StringWithSemicolon(), UDP.StringWithSemicolon(), DOH.StringWithSemicolon())
return stringsutil.TrimPrefixAny(resolver, TCP.StringWithSemicolon(), UDP.StringWithSemicolon(), DOH.StringWithSemicolon(), DOT.StringWithSemicolon())
}

func trimDohProtocol(resolver string) string {
Expand Down

0 comments on commit f7d02eb

Please sign in to comment.