Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for AXFR zone transfer #51

Merged
merged 6 commits into from May 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
233 changes: 164 additions & 69 deletions client.go
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/projectdiscovery/retryabledns/doh"
"github.com/projectdiscovery/retryabledns/hostsfile"
"github.com/projectdiscovery/retryablehttp-go"
"github.com/projectdiscovery/sliceutil"
)

var internalRangeCheckerInstance *internalRangeChecker
Expand Down Expand Up @@ -52,7 +53,7 @@ func New(baseResolvers []string, maxRetries int) *Client {

// New creates a new dns client with options
func NewWithOptions(options Options) *Client {
parsedBaseResolvers := parseResolvers(deduplicate(options.BaseResolvers))
parsedBaseResolvers := parseResolvers(sliceutil.Dedupe(options.BaseResolvers))
var knownHosts map[string][]string
if options.Hostsfile {
knownHosts, _ = hostsfile.ParseDefault()
Expand Down Expand Up @@ -185,13 +186,27 @@ func (c *Client) NS(host string) (*DNSData, error) {
return c.QueryMultiple(host, []uint16{dns.TypeNS})
}

func (c *Client) AXFR(host string) (*AXFRData, error) {
return c.axfr(host)
}

// QueryMultiple sends a provided dns request and return the data with a specific resolver
func (c *Client) QueryMultipleWithResolver(host string, requestTypes []uint16, resolver Resolver) (*DNSData, error) {
return c.queryMultiple(host, requestTypes, resolver)
}

// CAA helper function
func (c *Client) CAA(host string) (*DNSData, error) {
return c.QueryMultiple(host, []uint16{dns.TypeCAA})
}

// QueryMultiple sends a provided dns request and return the data
func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, error) {
return c.queryMultiple(host, requestTypes, nil)
}

// QueryMultiple sends a provided dns request and return the data
func (c *Client) queryMultiple(host string, requestTypes []uint16, resolver Resolver) (*DNSData, error) {
var (
dnsdata DNSData
err error
Expand All @@ -212,47 +227,73 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er

msg := &dns.Msg{}
msg.Id = dns.Id()
msg.RecursionDesired = true
msg.Question = make([]dns.Question, 1)
msg.SetEdns0(4096, false)

for _, requestType := range requestTypes {
name := dns.Fqdn(host)
msg.Question = make([]dns.Question, 1)

// In case of PTR adjust the domain name
if requestType == dns.TypePTR {
switch requestType {
case dns.TypeAXFR:
msg.SetAxfr(name)
case dns.TypePTR: // In case of PTR adjust the domain name
var err error
if net.ParseIP(host) != nil {
name, err = dns.ReverseAddr(host)
if err != nil {
return nil, err
}
}
fallthrough
default:
// Enable Extension Mechanisms for DNS for all messages
msg.RecursionDesired = true
question := dns.Question{
Name: name,
Qtype: requestType,
Qclass: dns.ClassINET,
}
msg.Question[0] = question
}

question := dns.Question{
Name: name,
Qtype: requestType,
Qclass: dns.ClassINET,
}
msg.Question[0] = question

// Enable Extension Mechanisms for DNS for all messages
msg.SetEdns0(4096, false)

var resp *dns.Msg
var (
resp *dns.Msg
trResp chan *dns.Envelope
)
for i := 0; i < c.options.MaxRetries; i++ {
index := atomic.AddUint32(&c.serversIndex, 1)
resolver := c.resolvers[index%uint32(len(c.resolvers))]

if resolver == nil {
resolver = c.resolvers[index%uint32(len(c.resolvers))]
}
switch r := resolver.(type) {
case *NetworkResolver:
switch r.Protocol {
case TCP:
resp, _, err = c.tcpClient.Exchange(msg, resolver.String())
case UDP:
resp, _, err = c.udpClient.Exchange(msg, resolver.String())
case DOT:
resp, _, err = c.dotClient.Exchange(msg, resolver.String())
if requestType == dns.TypeAXFR {
var dnsconn *dns.Conn
switch r.Protocol {
case TCP:
dnsconn, err = c.tcpClient.Dial(resolver.String())
case UDP:
dnsconn, err = c.udpClient.Dial(resolver.String())
case DOT:
dnsconn, err = c.dotClient.Dial(resolver.String())
default:
dnsconn, err = c.tcpClient.Dial(resolver.String())
}
if err != nil {
break
}
defer dnsconn.Close()
dnsTransfer := &dns.Transfer{Conn: dnsconn}
trResp, err = dnsTransfer.In(msg, resolver.String())
} else {
switch r.Protocol {
case TCP:
resp, _, err = c.tcpClient.Exchange(msg, resolver.String())
case UDP:
resp, _, err = c.udpClient.Exchange(msg, resolver.String())
case DOT:
resp, _, err = c.dotClient.Exchange(msg, resolver.String())
}
}
case *DohResolver:
method := doh.MethodPost
Expand All @@ -261,26 +302,37 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er
}
resp, err = c.dohClient.QueryWithDOHMsg(method, doh.Resolver{URL: r.URL}, msg)
}
if err != nil || resp == nil {

if err != nil || (trResp == nil && resp == nil) {
continue
}

// https://github.com/projectdiscovery/retryabledns/issues/25
if resp.Truncated && c.TCPFallback {
if resp != nil && resp.Truncated && c.TCPFallback {
resp, _, err = c.tcpClient.Exchange(msg, resolver.String())
if err != nil || resp == nil {
continue
}
}

err = dnsdata.ParseFromMsg(resp)
switch requestType {
case dns.TypeAXFR:
err = dnsdata.ParseFromEnvelopeChan(trResp)
default:
err = dnsdata.ParseFromMsg(resp)
}

// populate anyway basic info
dnsdata.Host = host
dnsdata.StatusCode = dns.RcodeToString[resp.Rcode]
dnsdata.StatusCodeRaw = resp.Rcode
switch {
case resp != nil:
dnsdata.StatusCode = dns.RcodeToString[resp.Rcode]
dnsdata.StatusCodeRaw = resp.Rcode
dnsdata.Raw += resp.String()
case trResp != nil:
// pass
}
dnsdata.Timestamp = time.Now()
dnsdata.Raw += resp.String()
dnsdata.Resolver = append(dnsdata.Resolver, resolver.String())

if err != nil || !dnsdata.contains() {
Expand All @@ -289,7 +341,10 @@ func (c *Client) QueryMultiple(host string, requestTypes []uint16) (*DNSData, er
dnsdata.dedupe()

// stop on success
if resp.Rcode == dns.RcodeSuccess {
if resp != nil && resp.Rcode == dns.RcodeSuccess {
break
}
if trResp != nil {
break
}
}
Expand Down Expand Up @@ -336,7 +391,7 @@ func (c *Client) QueryParallel(host string, requestType uint16, resolvers []stri
return dnsdatas, nil
}

// QueryMultiple sends a provided dns request and return the data
// Trace the requested domain with the provided query type
func (c *Client) Trace(host string, requestType uint16, maxrecursion int) (*TraceData, error) {
var tracedata TraceData
host = dns.CanonicalName(host)
Expand Down Expand Up @@ -388,7 +443,7 @@ func (c *Client) Trace(host string, requestType uint16, maxrecursion int) (*Trac
}
}
}
newNSResolvers = deduplicate(newNSResolvers)
newNSResolvers = sliceutil.Dedupe(newNSResolvers)

// if we have no new resolvers => return
if len(newNSResolvers) == 0 {
Expand All @@ -413,6 +468,40 @@ func (c *Client) Trace(host string, requestType uint16, maxrecursion int) (*Trac
return &tracedata, nil
}

func (c *Client) axfr(host string) (*AXFRData, error) {
// obtain ns servers
dnsData, err := c.NS(host)
if err != nil {
return nil, err
}
// resolve ns servers to ips
var resolvers []Resolver

for _, ns := range dnsData.NS {
nsData, err := c.A(ns)
if err != nil {
continue
}
for _, a := range nsData.A {
resolvers = append(resolvers, &NetworkResolver{Protocol: TCP, Host: a, Port: "53"})
}
}

resolvers = append(resolvers, c.resolvers...)

var data []*DNSData
// perform zone transfer for each ns
for _, resolver := range resolvers {
nsData, err := c.QueryMultipleWithResolver(host, []uint16{dns.TypeAXFR}, resolver)
if err != nil {
continue
}
data = append(data, nsData)
}

return &AXFRData{Host: host, DNSData: data}, nil
}

// DNSData is the data for a DNS request response
type DNSData struct {
Host string `json:"host,omitempty"`
Expand All @@ -426,27 +515,25 @@ type DNSData struct {
SOA []string `json:"soa,omitempty"`
NS []string `json:"ns,omitempty"`
TXT []string `json:"txt,omitempty"`
CAA []string `json:"caa,omitempty"`
AllRecords []string `json:"all,omitempty"`
Raw string `json:"raw,omitempty"`
HasInternalIPs bool `json:"has_internal_ips"`
HasInternalIPs bool `json:"has_internal_ips,omitempty"`
InternalIPs []string `json:"internal_ips,omitempty"`
StatusCode string `json:"status_code,omitempty"`
StatusCodeRaw int `json:"status_code_raw,omitempty"`
TraceData *TraceData `json:"trace,omitempty"`
AXFRData *AXFRData `json:"axfr,omitempty"`
RawResp *dns.Msg `json:"raw_resp,omitempty"`
Timestamp time.Time `json:"timestamp,omitempty"`
CAA []string `json:"caa,omitempty"`
}

// CheckInternalIPs when set to true returns if DNS response IPs
// belong to internal IP ranges.
var CheckInternalIPs = false

// ParseFromMsg and enrich data
func (d *DNSData) ParseFromMsg(msg *dns.Msg) error {
allRecords := append(msg.Answer, msg.Extra...)
allRecords = append(allRecords, msg.Ns...)

for _, record := range allRecords {
func (d *DNSData) ParseFromRR(rrs []dns.RR) error {
for _, record := range rrs {
switch recordType := record.(type) {
case *dns.A:
if CheckInternalIPs && internalRangeCheckerInstance != nil && internalRangeCheckerInstance.ContainsIPv4(recordType.A) {
Expand Down Expand Up @@ -478,11 +565,29 @@ func (d *DNSData) ParseFromMsg(msg *dns.Msg) error {
}
d.AAAA = append(d.AAAA, trimChars(recordType.AAAA.String()))
}
d.AllRecords = append(d.AllRecords, record.String())
}

return nil
}

// ParseFromMsg and enrich data
func (d *DNSData) ParseFromMsg(msg *dns.Msg) error {
allRecords := append(msg.Answer, msg.Extra...)
allRecords = append(allRecords, msg.Ns...)
return d.ParseFromRR(allRecords)
}

func (d *DNSData) ParseFromEnvelopeChan(envChan chan *dns.Envelope) error {
var allRecords []dns.RR
for env := range envChan {
if env.Error != nil {
return env.Error
}
allRecords = append(allRecords, env.RR...)
}
return d.ParseFromRR(allRecords)
}

func (d *DNSData) contains() bool {
return len(d.A) > 0 || len(d.AAAA) > 0 || len(d.CNAME) > 0 || len(d.MX) > 0 || len(d.NS) > 0 || len(d.PTR) > 0 || len(d.TXT) > 0 || len(d.SOA) > 0 || len(d.CAA) > 0
}
Expand All @@ -498,16 +603,17 @@ func trimChars(s string) string {
}

func (d *DNSData) dedupe() {
d.Resolver = deduplicate(d.Resolver)
d.A = deduplicate(d.A)
d.AAAA = deduplicate(d.AAAA)
d.CNAME = deduplicate(d.CNAME)
d.MX = deduplicate(d.MX)
d.PTR = deduplicate(d.PTR)
d.SOA = deduplicate(d.SOA)
d.NS = deduplicate(d.NS)
d.TXT = deduplicate(d.TXT)
d.CAA = deduplicate(d.CAA)
d.Resolver = sliceutil.Dedupe(d.Resolver)
d.A = sliceutil.Dedupe(d.A)
d.AAAA = sliceutil.Dedupe(d.AAAA)
d.CNAME = sliceutil.Dedupe(d.CNAME)
d.MX = sliceutil.Dedupe(d.MX)
d.PTR = sliceutil.Dedupe(d.PTR)
d.SOA = sliceutil.Dedupe(d.SOA)
d.NS = sliceutil.Dedupe(d.NS)
d.TXT = sliceutil.Dedupe(d.TXT)
d.CAA = sliceutil.Dedupe(d.CAA)
d.AllRecords = sliceutil.Dedupe(d.AllRecords)
}

// Marshal encodes the dnsdata to a binary representation
Expand All @@ -527,24 +633,13 @@ func (d *DNSData) Unmarshal(b []byte) error {
return dec.Decode(&d)
}

// deduplicate returns a new slice with duplicates values removed.
func deduplicate(s []string) []string {
if len(s) < 2 {
return s
}
var results []string
seen := make(map[string]struct{})
for _, val := range s {
if _, ok := seen[val]; !ok {
results = append(results, val)
seen[val] = struct{}{}
}
}
return results
}

// TraceData contains the trace information for a dns query
type TraceData struct {
Host string `json:"host,omitempty"`
DNSData []*DNSData `json:"chain,omitempty"`
}

type AXFRData struct {
Host string `json:"host,omitempty"`
DNSData []*DNSData `json:"chain,omitempty"`
}
1 change: 1 addition & 0 deletions go.mod
Expand Up @@ -7,6 +7,7 @@ require (
github.com/projectdiscovery/fileutil v0.0.0-20210926202739-6050d0acf73c
github.com/projectdiscovery/iputil v0.0.0-20210804143329-3a30fcde43f3
github.com/projectdiscovery/retryablehttp-go v1.0.2
github.com/projectdiscovery/sliceutil v0.0.0-20220225084130-8392ac12fa6d
github.com/projectdiscovery/stringsutil v0.0.0-20210823090203-2f5f137e8e1d
github.com/stretchr/testify v1.7.1
)
2 changes: 2 additions & 0 deletions go.sum
Expand Up @@ -46,6 +46,8 @@ github.com/projectdiscovery/mapcidr v0.0.7 h1:WK6WFimbWjUxfvcHEgofYNqIyqQh0vTDKz
github.com/projectdiscovery/mapcidr v0.0.7/go.mod h1:7CzdUdjuLVI0s33dQ33lWgjg3vPuLFw2rQzZ0RxkT00=
github.com/projectdiscovery/retryablehttp-go v1.0.2 h1:LV1/KAQU+yeWhNVlvveaYFsjBYRwXlNEq0PvrezMV0U=
github.com/projectdiscovery/retryablehttp-go v1.0.2/go.mod h1:dx//aY9V247qHdsRf0vdWHTBZuBQ2vm6Dq5dagxrDYI=
github.com/projectdiscovery/sliceutil v0.0.0-20220225084130-8392ac12fa6d h1:wIQPYRZEwTeJuoZLv3NT9r+il2fAv1ObRzTdHkNgOxk=
github.com/projectdiscovery/sliceutil v0.0.0-20220225084130-8392ac12fa6d/go.mod h1:QHXvznfPfA5f0AZUIBkbLapoUJJlsIDgUlkKva6dOr4=
github.com/projectdiscovery/stringsutil v0.0.0-20210823090203-2f5f137e8e1d h1:lrdpJCBOvRrTnm44Ov7O3tLd3oOWhCvVUhTKkWwibq4=
github.com/projectdiscovery/stringsutil v0.0.0-20210823090203-2f5f137e8e1d/go.mod h1:oTRc18WBv9t6BpaN9XBY+QmG28PUpsyDzRht56Qf49I=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
Expand Down