Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support TsigProvider for Server and Transfer #1331

Merged
merged 4 commits into from Feb 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
32 changes: 12 additions & 20 deletions client.go
Expand Up @@ -39,6 +39,14 @@ type Conn struct {
tsigRequestMAC string
}

func (co *Conn) tsigProvider() TsigProvider {
if co.TsigProvider != nil {
return co.TsigProvider
}
// tsigSecretProvider will return ErrSecret if co.TsigSecret is nil.
return tsigSecretProvider(co.TsigSecret)
}

// A Client defines parameters for a DNS client.
type Client struct {
Net string // if "tcp" or "tcp-tls" (DNS over TLS) a TCP query will be initiated, otherwise an UDP one (default is "" for UDP)
Expand Down Expand Up @@ -271,15 +279,8 @@ func (co *Conn) ReadMsg() (*Msg, error) {
return m, err
}
if t := m.IsTsig(); t != nil {
if co.TsigProvider != nil {
err = tsigVerifyProvider(p, co.TsigProvider, co.tsigRequestMAC, false)
} else {
if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
return m, ErrSecret
}
// Need to work on the original message p, as that was used to calculate the tsig.
err = TsigVerify(p, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
}
// Need to work on the original message p, as that was used to calculate the tsig.
err = tsigVerifyProvider(p, co.tsigProvider(), co.tsigRequestMAC, false)
}
return m, err
}
Expand Down Expand Up @@ -356,17 +357,8 @@ func (co *Conn) Read(p []byte) (n int, err error) {
func (co *Conn) WriteMsg(m *Msg) (err error) {
var out []byte
if t := m.IsTsig(); t != nil {
mac := ""
if co.TsigProvider != nil {
out, mac, err = tsigGenerateProvider(m, co.TsigProvider, co.tsigRequestMAC, false)
} else {
if _, ok := co.TsigSecret[t.Hdr.Name]; !ok {
return ErrSecret
}
out, mac, err = TsigGenerate(m, co.TsigSecret[t.Hdr.Name], co.tsigRequestMAC, false)
}
// Set for the next read, although only used in zone transfers
co.tsigRequestMAC = mac
// Set tsigRequestMAC for the next read, although only used in zone transfers.
out, co.tsigRequestMAC, err = tsigGenerateProvider(m, co.tsigProvider(), co.tsigRequestMAC, false)
} else {
out, err = m.Pack()
}
Expand Down
42 changes: 25 additions & 17 deletions server.go
Expand Up @@ -71,12 +71,12 @@ type response struct {
tsigTimersOnly bool
tsigStatus error
tsigRequestMAC string
tsigSecret map[string]string // the tsig secrets
udp net.PacketConn // i/o connection if UDP was used
tcp net.Conn // i/o connection if TCP was used
udpSession *SessionUDP // oob data to get egress interface right
pcSession net.Addr // address to use when writing to a generic net.PacketConn
writer Writer // writer to output the raw DNS bits
tsigProvider TsigProvider
udp net.PacketConn // i/o connection if UDP was used
tcp net.Conn // i/o connection if TCP was used
udpSession *SessionUDP // oob data to get egress interface right
pcSession net.Addr // address to use when writing to a generic net.PacketConn
writer Writer // writer to output the raw DNS bits
}

// handleRefused returns a HandlerFunc that returns REFUSED for every request it gets.
Expand Down Expand Up @@ -211,6 +211,8 @@ type Server struct {
WriteTimeout time.Duration
// TCP idle timeout for multiple queries, if nil, defaults to 8 * time.Second (RFC 5966).
IdleTimeout func() time.Duration
// An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations.
TsigProvider TsigProvider
// Secret(s) for Tsig map[<zonename>]<base64 secret>. The zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2).
TsigSecret map[string]string
// If NotifyStartedFunc is set it is called once the server has started listening.
Expand Down Expand Up @@ -238,6 +240,16 @@ type Server struct {
udpPool sync.Pool
}

func (srv *Server) tsigProvider() TsigProvider {
tmthrgd marked this conversation as resolved.
Show resolved Hide resolved
if srv.TsigProvider != nil {
return srv.TsigProvider
}
if srv.TsigSecret != nil {
return tsigSecretProvider(srv.TsigSecret)
}
return nil
}

func (srv *Server) isStarted() bool {
srv.lock.RLock()
started := srv.started
Expand Down Expand Up @@ -526,7 +538,7 @@ func (srv *Server) serveUDP(l net.PacketConn) error {

// Serve a new TCP connection.
func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
w := &response{tsigSecret: srv.TsigSecret, tcp: rw}
w := &response{tsigProvider: srv.tsigProvider(), tcp: rw}
if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w)
} else {
Expand Down Expand Up @@ -581,7 +593,7 @@ func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {

// Serve a new UDP request.
func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) {
w := &response{tsigSecret: srv.TsigSecret, udp: u, udpSession: udpSession, pcSession: pcSession}
w := &response{tsigProvider: srv.tsigProvider(), udp: u, udpSession: udpSession, pcSession: pcSession}
if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w)
} else {
Expand Down Expand Up @@ -632,15 +644,11 @@ func (srv *Server) serveDNS(m []byte, w *response) {
}

w.tsigStatus = nil
if w.tsigSecret != nil {
if w.tsigProvider != nil {
if t := req.IsTsig(); t != nil {
if secret, ok := w.tsigSecret[t.Hdr.Name]; ok {
w.tsigStatus = TsigVerify(m, secret, "", false)
} else {
w.tsigStatus = ErrSecret
}
w.tsigStatus = tsigVerifyProvider(m, w.tsigProvider, "", false)
w.tsigTimersOnly = false
w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
w.tsigRequestMAC = t.MAC
}
}

Expand Down Expand Up @@ -718,9 +726,9 @@ func (w *response) WriteMsg(m *Msg) (err error) {
}

var data []byte
if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check)
if w.tsigProvider != nil { // if no provider, dont check for the tsig (which is a longer check)
if t := m.IsTsig(); t != nil {
data, w.tsigRequestMAC, err = TsigGenerate(m, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly)
data, w.tsigRequestMAC, err = tsigGenerateProvider(m, w.tsigProvider, w.tsigRequestMAC, w.tsigTimersOnly)
if err != nil {
return err
}
Expand Down
18 changes: 18 additions & 0 deletions tsig.go
Expand Up @@ -74,6 +74,24 @@ func (key tsigHMACProvider) Verify(msg []byte, t *TSIG) error {
return nil
}

type tsigSecretProvider map[string]string
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is good for a default TsigProvider, however it still doesn't solve the problem described in

#1325 (comment)

We need some way for Generate to be able to unwrap the Message so that we can see the query name. This is specifically useful in situations where multiple tsigs have the same name, but exist under a different zone.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That can be handled either in a separate PR or externally, but yes we do need to do something.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How common a use case is that do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say it's fairly common. I ran into this problem early on with this library when trying to serve multiple zones that had different primaries, therefore different TSIGs on each primary meaning they could not use the TSIG name as the zone Name.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @tmthrgd have you thought any more about how we can best obtain the query name from the Generate function. I think having the opposite of TsigBuffer available in a function would make this possible. Alternatively we could consider passing in a context? It would be useful to have contexts everywhere anyway so that we could pass data around without needing to alter signatures.


func (ts tsigSecretProvider) Generate(msg []byte, t *TSIG) ([]byte, error) {
key, ok := ts[t.Hdr.Name]
if !ok {
return nil, ErrSecret
}
return tsigHMACProvider(key).Generate(msg, t)
}

func (ts tsigSecretProvider) Verify(msg []byte, t *TSIG) error {
key, ok := ts[t.Hdr.Name]
if !ok {
return ErrSecret
}
return tsigHMACProvider(key).Verify(msg, t)
}

// TSIG is the RR the holds the transaction signature of a message.
// See RFC 2845 and RFC 4635.
type TSIG struct {
Expand Down
27 changes: 16 additions & 11 deletions xfr.go
Expand Up @@ -17,11 +17,22 @@ type Transfer struct {
DialTimeout time.Duration // net.DialTimeout, defaults to 2 seconds
ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds
WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds
TsigProvider TsigProvider // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations.
TsigSecret map[string]string // Secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
tsigTimersOnly bool
}

// Think we need to away to stop the transfer
func (t *Transfer) tsigProvider() TsigProvider {
if t.TsigProvider != nil {
return t.TsigProvider
}
if t.TsigSecret != nil {
return tsigSecretProvider(t.TsigSecret)
}
return nil
}

// TODO: Think we need to away to stop the transfer

// In performs an incoming transfer with the server in a.
// If you would like to set the source IP, or some other attribute
Expand Down Expand Up @@ -224,12 +235,9 @@ func (t *Transfer) ReadMsg() (*Msg, error) {
if err := m.Unpack(p); err != nil {
return nil, err
}
if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil {
if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok {
return m, ErrSecret
}
if ts, tp := m.IsTsig(), t.tsigProvider(); ts != nil && tp != nil {
// Need to work on the original message p, as that was used to calculate the tsig.
err = TsigVerify(p, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly)
err = tsigVerifyProvider(p, tp, t.tsigRequestMAC, t.tsigTimersOnly)
t.tsigRequestMAC = ts.MAC
}
return m, err
Expand All @@ -238,11 +246,8 @@ func (t *Transfer) ReadMsg() (*Msg, error) {
// WriteMsg writes a message through the transfer connection t.
func (t *Transfer) WriteMsg(m *Msg) (err error) {
var out []byte
if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil {
if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok {
return ErrSecret
}
out, t.tsigRequestMAC, err = TsigGenerate(m, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly)
if ts, tp := m.IsTsig(), t.tsigProvider(); ts != nil && tp != nil {
out, t.tsigRequestMAC, err = tsigGenerateProvider(m, tp, t.tsigRequestMAC, t.tsigTimersOnly)
} else {
out, err = m.Pack()
}
Expand Down
56 changes: 55 additions & 1 deletion xfr_test.go
@@ -1,6 +1,9 @@
package dns

import "testing"
import (
"testing"
"time"
)

var (
tsigSecret = map[string]string{"axfr.": "so6ZGir4GPAqINNh9U5c3A=="}
Expand Down Expand Up @@ -127,3 +130,54 @@ func axfrTestingSuite(t *testing.T, addrstr string) {
}
}
}

func axfrTestingSuiteWithCustomTsig(t *testing.T, addrstr string, provider TsigProvider) {
tr := new(Transfer)
m := new(Msg)
var err error
tr.Conn, err = Dial("tcp", addrstr)
if err != nil {
t.Fatal("failed to dial", err)
}
tr.TsigProvider = provider
m.SetAxfr("miek.nl.")
m.SetTsig("axfr.", HmacSHA256, 300, time.Now().Unix())

c, err := tr.In(m, addrstr)
if err != nil {
t.Fatal("failed to zone transfer in", err)
}

var records []RR
for msg := range c {
if msg.Error != nil {
t.Fatal(msg.Error)
}
records = append(records, msg.RR...)
}

if len(records) != len(xfrTestData) {
t.Fatalf("bad axfr: expected %v, got %v", records, xfrTestData)
}

for i, rr := range records {
if !IsDuplicate(rr, xfrTestData[i]) {
t.Errorf("bad axfr: expected %v, got %v", records, xfrTestData)
}
}
}

func TestCustomTsigProvider(t *testing.T) {
HandleFunc("miek.nl.", SingleEnvelopeXfrServer)
defer HandleRemove("miek.nl.")

s, addrstr, _, err := RunLocalTCPServer(":0", func(srv *Server) {
srv.TsigProvider = tsigSecretProvider(tsigSecret)
})
if err != nil {
t.Fatalf("unable to run test server: %s", err)
}
defer s.Shutdown()

axfrTestingSuiteWithCustomTsig(t, addrstr, tsigSecretProvider(tsigSecret))
}