diff --git a/client.go b/client.go index d25f6d5..7d53240 100644 --- a/client.go +++ b/client.go @@ -18,6 +18,7 @@ import ( "github.com/projectdiscovery/retryabledns/doh" "github.com/projectdiscovery/retryabledns/hostsfile" "github.com/projectdiscovery/retryablehttp-go" + "github.com/projectdiscovery/sliceutil" ) var internalRangeCheckerInstance *internalRangeChecker @@ -52,7 +53,7 @@ func New(baseResolvers []string, maxRetries int) *Client { // New creates a new dns client with options func NewWithOptions(options Options) *Client { - parsedBaseResolvers := parseResolvers(deduplicate(options.BaseResolvers)) + parsedBaseResolvers := parseResolvers(sliceutil.Dedupe(options.BaseResolvers)) var knownHosts map[string][]string if options.Hostsfile { knownHosts, _ = hostsfile.ParseDefault() @@ -185,6 +186,15 @@ func (c *Client) NS(host string) (*DNSData, error) { return c.QueryMultiple(host, []uint16{dns.TypeNS}) } +func (c *Client) AXFR(host string) (*AXFRData, error) { + return c.axfr(host) +} + +// QueryMultiple sends a provided dns request and return the data with a specific resolver +func (c *Client) QueryMultipleWithResolver(host string, requestTypes []uint16, resolver Resolver) (*DNSData, error) { + return c.queryMultiple(host, requestTypes, resolver) +} + // CAA helper function func (c *Client) CAA(host string) (*DNSData, error) { return c.QueryMultiple(host, []uint16{dns.TypeCAA}) @@ -192,6 +202,11 @@ func (c *Client) CAA(host string) (*DNSData, error) { // QueryMultiple sends a provided dns request and return the data func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, error) { + return c.queryMultiple(host, requestTypes, nil) +} + +// QueryMultiple sends a provided dns request and return the data +func (c *Client) queryMultiple(host string, requestTypes []uint16, resolver Resolver) (*DNSData, error) { var ( dnsdata DNSData err error @@ -212,14 +227,16 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er msg := &dns.Msg{} msg.Id = dns.Id() - msg.RecursionDesired = true - msg.Question = make([]dns.Question, 1) + msg.SetEdns0(4096, false) for _, requestType := range requestTypes { name := dns.Fqdn(host) + msg.Question = make([]dns.Question, 1) - // In case of PTR adjust the domain name - if requestType == dns.TypePTR { + switch requestType { + case dns.TypeAXFR: + msg.SetAxfr(name) + case dns.TypePTR: // In case of PTR adjust the domain name var err error if net.ParseIP(host) != nil { name, err = dns.ReverseAddr(host) @@ -227,32 +244,56 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er return nil, err } } + fallthrough + default: + // Enable Extension Mechanisms for DNS for all messages + msg.RecursionDesired = true + question := dns.Question{ + Name: name, + Qtype: requestType, + Qclass: dns.ClassINET, + } + msg.Question[0] = question } - question := dns.Question{ - Name: name, - Qtype: requestType, - Qclass: dns.ClassINET, - } - msg.Question[0] = question - - // Enable Extension Mechanisms for DNS for all messages - msg.SetEdns0(4096, false) - - var resp *dns.Msg + var ( + resp *dns.Msg + trResp chan *dns.Envelope + ) for i := 0; i < c.options.MaxRetries; i++ { index := atomic.AddUint32(&c.serversIndex, 1) - resolver := c.resolvers[index%uint32(len(c.resolvers))] - + if resolver == nil { + resolver = c.resolvers[index%uint32(len(c.resolvers))] + } switch r := resolver.(type) { case *NetworkResolver: - switch r.Protocol { - case TCP: - resp, _, err = c.tcpClient.Exchange(msg, resolver.String()) - case UDP: - resp, _, err = c.udpClient.Exchange(msg, resolver.String()) - case DOT: - resp, _, err = c.dotClient.Exchange(msg, resolver.String()) + if requestType == dns.TypeAXFR { + var dnsconn *dns.Conn + switch r.Protocol { + case TCP: + dnsconn, err = c.tcpClient.Dial(resolver.String()) + case UDP: + dnsconn, err = c.udpClient.Dial(resolver.String()) + case DOT: + dnsconn, err = c.dotClient.Dial(resolver.String()) + default: + dnsconn, err = c.tcpClient.Dial(resolver.String()) + } + if err != nil { + break + } + defer dnsconn.Close() + dnsTransfer := &dns.Transfer{Conn: dnsconn} + trResp, err = dnsTransfer.In(msg, resolver.String()) + } else { + switch r.Protocol { + case TCP: + resp, _, err = c.tcpClient.Exchange(msg, resolver.String()) + case UDP: + resp, _, err = c.udpClient.Exchange(msg, resolver.String()) + case DOT: + resp, _, err = c.dotClient.Exchange(msg, resolver.String()) + } } case *DohResolver: method := doh.MethodPost @@ -261,26 +302,37 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er } resp, err = c.dohClient.QueryWithDOHMsg(method, doh.Resolver{URL: r.URL}, msg) } - if err != nil || resp == nil { + + if err != nil || (trResp == nil && resp == nil) { continue } // https://github.com/projectdiscovery/retryabledns/issues/25 - if resp.Truncated && c.TCPFallback { + if resp != nil && resp.Truncated && c.TCPFallback { resp, _, err = c.tcpClient.Exchange(msg, resolver.String()) if err != nil || resp == nil { continue } } - err = dnsdata.ParseFromMsg(resp) + switch requestType { + case dns.TypeAXFR: + err = dnsdata.ParseFromEnvelopeChan(trResp) + default: + err = dnsdata.ParseFromMsg(resp) + } // populate anyway basic info dnsdata.Host = host - dnsdata.StatusCode = dns.RcodeToString[resp.Rcode] - dnsdata.StatusCodeRaw = resp.Rcode + switch { + case resp != nil: + dnsdata.StatusCode = dns.RcodeToString[resp.Rcode] + dnsdata.StatusCodeRaw = resp.Rcode + dnsdata.Raw += resp.String() + case trResp != nil: + // pass + } dnsdata.Timestamp = time.Now() - dnsdata.Raw += resp.String() dnsdata.Resolver = append(dnsdata.Resolver, resolver.String()) if err != nil || !dnsdata.contains() { @@ -289,7 +341,10 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er dnsdata.dedupe() // stop on success - if resp.Rcode == dns.RcodeSuccess { + if resp != nil && resp.Rcode == dns.RcodeSuccess { + break + } + if trResp != nil { break } } @@ -336,7 +391,7 @@ func (c *Client) QueryParallel(host string, requestType uint16, resolvers []stri return dnsdatas, nil } -// QueryMultiple sends a provided dns request and return the data +// Trace the requested domain with the provided query type func (c *Client) Trace(host string, requestType uint16, maxrecursion int) (*TraceData, error) { var tracedata TraceData host = dns.CanonicalName(host) @@ -388,7 +443,7 @@ func (c *Client) Trace(host string, requestType uint16, maxrecursion int) (*Trac } } } - newNSResolvers = deduplicate(newNSResolvers) + newNSResolvers = sliceutil.Dedupe(newNSResolvers) // if we have no new resolvers => return if len(newNSResolvers) == 0 { @@ -413,6 +468,40 @@ func (c *Client) Trace(host string, requestType uint16, maxrecursion int) (*Trac return &tracedata, nil } +func (c *Client) axfr(host string) (*AXFRData, error) { + // obtain ns servers + dnsData, err := c.NS(host) + if err != nil { + return nil, err + } + // resolve ns servers to ips + var resolvers []Resolver + + for _, ns := range dnsData.NS { + nsData, err := c.A(ns) + if err != nil { + continue + } + for _, a := range nsData.A { + resolvers = append(resolvers, &NetworkResolver{Protocol: TCP, Host: a, Port: "53"}) + } + } + + resolvers = append(resolvers, c.resolvers...) + + var data []*DNSData + // perform zone transfer for each ns + for _, resolver := range resolvers { + nsData, err := c.QueryMultipleWithResolver(host, []uint16{dns.TypeAXFR}, resolver) + if err != nil { + continue + } + data = append(data, nsData) + } + + return &AXFRData{Host: host, DNSData: data}, nil +} + // DNSData is the data for a DNS request response type DNSData struct { Host string `json:"host,omitempty"` @@ -426,27 +515,25 @@ type DNSData struct { SOA []string `json:"soa,omitempty"` NS []string `json:"ns,omitempty"` TXT []string `json:"txt,omitempty"` + CAA []string `json:"caa,omitempty"` + AllRecords []string `json:"all,omitempty"` Raw string `json:"raw,omitempty"` - HasInternalIPs bool `json:"has_internal_ips"` + HasInternalIPs bool `json:"has_internal_ips,omitempty"` InternalIPs []string `json:"internal_ips,omitempty"` StatusCode string `json:"status_code,omitempty"` StatusCodeRaw int `json:"status_code_raw,omitempty"` TraceData *TraceData `json:"trace,omitempty"` + AXFRData *AXFRData `json:"axfr,omitempty"` RawResp *dns.Msg `json:"raw_resp,omitempty"` Timestamp time.Time `json:"timestamp,omitempty"` - CAA []string `json:"caa,omitempty"` } // CheckInternalIPs when set to true returns if DNS response IPs // belong to internal IP ranges. var CheckInternalIPs = false -// ParseFromMsg and enrich data -func (d *DNSData) ParseFromMsg(msg *dns.Msg) error { - allRecords := append(msg.Answer, msg.Extra...) - allRecords = append(allRecords, msg.Ns...) - - for _, record := range allRecords { +func (d *DNSData) ParseFromRR(rrs []dns.RR) error { + for _, record := range rrs { switch recordType := record.(type) { case *dns.A: if CheckInternalIPs && internalRangeCheckerInstance != nil && internalRangeCheckerInstance.ContainsIPv4(recordType.A) { @@ -478,11 +565,29 @@ func (d *DNSData) ParseFromMsg(msg *dns.Msg) error { } d.AAAA = append(d.AAAA, trimChars(recordType.AAAA.String())) } + d.AllRecords = append(d.AllRecords, record.String()) } - return nil } +// ParseFromMsg and enrich data +func (d *DNSData) ParseFromMsg(msg *dns.Msg) error { + allRecords := append(msg.Answer, msg.Extra...) + allRecords = append(allRecords, msg.Ns...) + return d.ParseFromRR(allRecords) +} + +func (d *DNSData) ParseFromEnvelopeChan(envChan chan *dns.Envelope) error { + var allRecords []dns.RR + for env := range envChan { + if env.Error != nil { + return env.Error + } + allRecords = append(allRecords, env.RR...) + } + return d.ParseFromRR(allRecords) +} + func (d *DNSData) contains() bool { return len(d.A) > 0 || len(d.AAAA) > 0 || len(d.CNAME) > 0 || len(d.MX) > 0 || len(d.NS) > 0 || len(d.PTR) > 0 || len(d.TXT) > 0 || len(d.SOA) > 0 || len(d.CAA) > 0 } @@ -498,16 +603,17 @@ func trimChars(s string) string { } func (d *DNSData) dedupe() { - d.Resolver = deduplicate(d.Resolver) - d.A = deduplicate(d.A) - d.AAAA = deduplicate(d.AAAA) - d.CNAME = deduplicate(d.CNAME) - d.MX = deduplicate(d.MX) - d.PTR = deduplicate(d.PTR) - d.SOA = deduplicate(d.SOA) - d.NS = deduplicate(d.NS) - d.TXT = deduplicate(d.TXT) - d.CAA = deduplicate(d.CAA) + d.Resolver = sliceutil.Dedupe(d.Resolver) + d.A = sliceutil.Dedupe(d.A) + d.AAAA = sliceutil.Dedupe(d.AAAA) + d.CNAME = sliceutil.Dedupe(d.CNAME) + d.MX = sliceutil.Dedupe(d.MX) + d.PTR = sliceutil.Dedupe(d.PTR) + d.SOA = sliceutil.Dedupe(d.SOA) + d.NS = sliceutil.Dedupe(d.NS) + d.TXT = sliceutil.Dedupe(d.TXT) + d.CAA = sliceutil.Dedupe(d.CAA) + d.AllRecords = sliceutil.Dedupe(d.AllRecords) } // Marshal encodes the dnsdata to a binary representation @@ -527,24 +633,13 @@ func (d *DNSData) Unmarshal(b []byte) error { return dec.Decode(&d) } -// deduplicate returns a new slice with duplicates values removed. -func deduplicate(s []string) []string { - if len(s) < 2 { - return s - } - var results []string - seen := make(map[string]struct{}) - for _, val := range s { - if _, ok := seen[val]; !ok { - results = append(results, val) - seen[val] = struct{}{} - } - } - return results -} - // TraceData contains the trace information for a dns query type TraceData struct { Host string `json:"host,omitempty"` DNSData []*DNSData `json:"chain,omitempty"` } + +type AXFRData struct { + Host string `json:"host,omitempty"` + DNSData []*DNSData `json:"chain,omitempty"` +} diff --git a/go.mod b/go.mod index 4a0a18e..6223cf4 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/projectdiscovery/fileutil v0.0.0-20210926202739-6050d0acf73c github.com/projectdiscovery/iputil v0.0.0-20210804143329-3a30fcde43f3 github.com/projectdiscovery/retryablehttp-go v1.0.2 + github.com/projectdiscovery/sliceutil v0.0.0-20220225084130-8392ac12fa6d github.com/projectdiscovery/stringsutil v0.0.0-20210823090203-2f5f137e8e1d github.com/stretchr/testify v1.7.1 ) diff --git a/go.sum b/go.sum index 98fa556..356e48d 100644 --- a/go.sum +++ b/go.sum @@ -46,6 +46,8 @@ github.com/projectdiscovery/mapcidr v0.0.7 h1:WK6WFimbWjUxfvcHEgofYNqIyqQh0vTDKz github.com/projectdiscovery/mapcidr v0.0.7/go.mod h1:7CzdUdjuLVI0s33dQ33lWgjg3vPuLFw2rQzZ0RxkT00= github.com/projectdiscovery/retryablehttp-go v1.0.2 h1:LV1/KAQU+yeWhNVlvveaYFsjBYRwXlNEq0PvrezMV0U= github.com/projectdiscovery/retryablehttp-go v1.0.2/go.mod h1:dx//aY9V247qHdsRf0vdWHTBZuBQ2vm6Dq5dagxrDYI= +github.com/projectdiscovery/sliceutil v0.0.0-20220225084130-8392ac12fa6d h1:wIQPYRZEwTeJuoZLv3NT9r+il2fAv1ObRzTdHkNgOxk= +github.com/projectdiscovery/sliceutil v0.0.0-20220225084130-8392ac12fa6d/go.mod h1:QHXvznfPfA5f0AZUIBkbLapoUJJlsIDgUlkKva6dOr4= github.com/projectdiscovery/stringsutil v0.0.0-20210823090203-2f5f137e8e1d h1:lrdpJCBOvRrTnm44Ov7O3tLd3oOWhCvVUhTKkWwibq4= github.com/projectdiscovery/stringsutil v0.0.0-20210823090203-2f5f137e8e1d/go.mod h1:oTRc18WBv9t6BpaN9XBY+QmG28PUpsyDzRht56Qf49I= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=