From 2a3c7367f641a68f6c8237592e1df6401f8e73f3 Mon Sep 17 00:00:00 2001 From: "Gia. Bui Dai" Date: Wed, 19 Jan 2022 15:01:13 +0700 Subject: [PATCH] Improve performance in hot parts --- client_test.go | 20 +++++++++++ resolver.go | 93 ++++++++++++++++++++------------------------------ 2 files changed, 57 insertions(+), 56 deletions(-) diff --git a/client_test.go b/client_test.go index 7a5ff95..6c5ccac 100644 --- a/client_test.go +++ b/client_test.go @@ -23,6 +23,26 @@ func TestConsistentResolve(t *testing.T) { } } +func TestUDP(t *testing.T) { + client := New([]string{"1.1.1.1:53", "udp:8.8.8.8"}, 5) + + d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA}) + require.Nil(t, err) + + // From current dig result + require.True(t, len(d.A) > 0) +} + +func TestTCP(t *testing.T) { + client := New([]string{"tcp:1.1.1.1:53", "tcp:8.8.8.8"}, 5) + + d, err := client.QueryMultiple("example.com", []uint16{dns.TypeA}) + require.Nil(t, err) + + // From current dig result + require.True(t, len(d.A) > 0) +} + func TestDOH(t *testing.T) { client := New([]string{"doh:https://doh.opendns.com/dns-query:post", "doh:https://doh.opendns.com/dns-query:get"}, 5) diff --git a/resolver.go b/resolver.go index 5e30d84..a4a5228 100644 --- a/resolver.go +++ b/resolver.go @@ -7,53 +7,33 @@ import ( "github.com/projectdiscovery/stringsutil" ) -type Protocol int +type Protocol string const ( - UDP Protocol = iota - TCP - DOH - DOT + UDP Protocol = "udp" + TCP Protocol = "tcp" + DOH Protocol = "doh" + DOT Protocol = "dot" ) func (p Protocol) String() string { - switch p { - case DOH: - return "doh" - case UDP: - return "udp" - case TCP: - return "tcp" - case DOT: - return "dot" - } - - return "" + return string(p) } func (p Protocol) StringWithSemicolon() string { return p.String() + ":" } -type DohProtocol int +type DohProtocol string const ( - JsonAPI DohProtocol = iota - GET - POST + JsonAPI DohProtocol = "jsonapi" + GET DohProtocol = "get" + POST DohProtocol = "post" ) func (p DohProtocol) String() string { - switch p { - case JsonAPI: - return "jsonapi" - case GET: - return "get" - case POST: - return "post" - } - - return "" + return string(p) } func (p DohProtocol) StringWithSemicolon() string { @@ -92,29 +72,34 @@ func (r DohResolver) String() string { } func parseResolver(r string) (resolver Resolver) { - isTcp, isUDP, isDoh, isDot := hasProtocol(r, TCP.StringWithSemicolon()), hasProtocol(r, UDP.StringWithSemicolon()), hasProtocol(r, DOH.StringWithSemicolon()), hasProtocol(r, DOT.StringWithSemicolon()) rNetworkTokens := trimProtocol(r) - if isTcp || isUDP || isDot { - networkResolver := &NetworkResolver{Protocol: UDP} - if isTcp { - networkResolver.Protocol = TCP - } else if isDot { - networkResolver.Protocol = DOT - } - parseHostPort(networkResolver, rNetworkTokens) - resolver = networkResolver - } else if isDoh { - isJsonApi, isGet := hasDohProtocol(r, JsonAPI.StringWithSemicolon()), hasDohProtocol(r, GET.StringWithSemicolon()) - URL := trimDohProtocol(rNetworkTokens) - dohResolver := &DohResolver{URL: URL, Protocol: POST} - if isJsonApi { - dohResolver.Protocol = JsonAPI - } else if isGet { - dohResolver.Protocol = GET + protocol := UDP + + if len(r) >= 4 && r[3] == 58 { // 58 is ":" + switch r[0:3] { + case "udp": + case "tcp": + protocol = TCP + case "dot": + protocol = DOT + case "doh": + protocol = DOH + isJsonApi, isGet := hasDohProtocol(r, JsonAPI.StringWithSemicolon()), hasDohProtocol(r, GET.StringWithSemicolon()) + URL := trimDohProtocol(rNetworkTokens) + dohResolver := &DohResolver{URL: URL, Protocol: POST} + if isJsonApi { + dohResolver.Protocol = JsonAPI + } else if isGet { + dohResolver.Protocol = GET + } + resolver = dohResolver + default: + // unsupported protocol? } - resolver = dohResolver - } else { - networkResolver := &NetworkResolver{Protocol: UDP} + } + + if protocol != DOH { + networkResolver := &NetworkResolver{Protocol: protocol} parseHostPort(networkResolver, rNetworkTokens) resolver = networkResolver } @@ -136,10 +121,6 @@ func parseHostPort(networkResolver *NetworkResolver, r string) { } } -func hasProtocol(resolver, protocol string) bool { - return strings.HasPrefix(resolver, protocol) -} - func hasDohProtocol(resolver, protocol string) bool { return strings.HasSuffix(resolver, protocol) }