From cb3eec4f6e4ba41e5e38382c065d329929d3f3f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jerry=20Lundstr=C3=B6m?= Date: Thu, 9 Jun 2022 11:29:43 +0200 Subject: [PATCH] TSIG Verify/Generate using TsigProvider - `tsig`: Expose `TsigVerifyProvider` and `TsigGenerateProvider` so that others can use these TSIG functions using a `TsigProvider` --- client.go | 4 ++-- server.go | 4 ++-- tsig.go | 10 +++++++--- tsig_test.go | 4 ++-- xfr.go | 4 ++-- 5 files changed, 15 insertions(+), 11 deletions(-) diff --git a/client.go b/client.go index fde5b5e31..7cf40f837 100644 --- a/client.go +++ b/client.go @@ -280,7 +280,7 @@ func (co *Conn) ReadMsg() (*Msg, error) { } if t := m.IsTsig(); t != nil { // Need to work on the original message p, as that was used to calculate the tsig. - err = tsigVerifyProvider(p, co.tsigProvider(), co.tsigRequestMAC, false) + err = TsigVerifyProvider(p, co.tsigProvider(), co.tsigRequestMAC, false) } return m, err } @@ -358,7 +358,7 @@ func (co *Conn) WriteMsg(m *Msg) (err error) { var out []byte if t := m.IsTsig(); t != nil { // Set tsigRequestMAC for the next read, although only used in zone transfers. - out, co.tsigRequestMAC, err = tsigGenerateProvider(m, co.tsigProvider(), co.tsigRequestMAC, false) + out, co.tsigRequestMAC, err = TsigGenerateProvider(m, co.tsigProvider(), co.tsigRequestMAC, false) } else { out, err = m.Pack() } diff --git a/server.go b/server.go index b962e6f35..050e1b60e 100644 --- a/server.go +++ b/server.go @@ -646,7 +646,7 @@ func (srv *Server) serveDNS(m []byte, w *response) { w.tsigStatus = nil if w.tsigProvider != nil { if t := req.IsTsig(); t != nil { - w.tsigStatus = tsigVerifyProvider(m, w.tsigProvider, "", false) + w.tsigStatus = TsigVerifyProvider(m, w.tsigProvider, "", false) w.tsigTimersOnly = false w.tsigRequestMAC = t.MAC } @@ -728,7 +728,7 @@ func (w *response) WriteMsg(m *Msg) (err error) { var data []byte if w.tsigProvider != nil { // if no provider, dont check for the tsig (which is a longer check) if t := m.IsTsig(); t != nil { - data, w.tsigRequestMAC, err = tsigGenerateProvider(m, w.tsigProvider, w.tsigRequestMAC, w.tsigTimersOnly) + data, w.tsigRequestMAC, err = TsigGenerateProvider(m, w.tsigProvider, w.tsigRequestMAC, w.tsigTimersOnly) if err != nil { return err } diff --git a/tsig.go b/tsig.go index 8b37cc841..5842e82d3 100644 --- a/tsig.go +++ b/tsig.go @@ -166,10 +166,12 @@ type timerWireFmt struct { // timersOnly is false. // If something goes wrong an error is returned, otherwise it is nil. func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, string, error) { - return tsigGenerateProvider(m, tsigHMACProvider(secret), requestMAC, timersOnly) + return TsigGenerateProvider(m, tsigHMACProvider(secret), requestMAC, timersOnly) } -func tsigGenerateProvider(m *Msg, provider TsigProvider, requestMAC string, timersOnly bool) ([]byte, string, error) { +// TsigGenerate fills out the TSIG record attached to the message using +// a TsigProvider, for more details and return see TsigGenerate. +func TsigGenerateProvider(m *Msg, provider TsigProvider, requestMAC string, timersOnly bool) ([]byte, string, error) { if m.IsTsig() == nil { panic("dns: TSIG not last RR in additional") } @@ -223,7 +225,9 @@ func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error { return tsigVerify(msg, tsigHMACProvider(secret), requestMAC, timersOnly, uint64(time.Now().Unix())) } -func tsigVerifyProvider(msg []byte, provider TsigProvider, requestMAC string, timersOnly bool) error { +// TsigVerify verifies the TSIG on a message using a TsigProvider, for +// more details and return see TsigVerify. +func TsigVerifyProvider(msg []byte, provider TsigProvider, requestMAC string, timersOnly bool) error { return tsigVerify(msg, provider, requestMAC, timersOnly, uint64(time.Now().Unix())) } diff --git a/tsig_test.go b/tsig_test.go index 25a3127b2..da6cb38dc 100644 --- a/tsig_test.go +++ b/tsig_test.go @@ -354,7 +354,7 @@ func TestTsigGenerateProvider(t *testing.T) { Extra: []RR{&tsig}, } - _, mac, err := tsigGenerateProvider(req, new(testProvider), "", false) + _, mac, err := TsigGenerateProvider(req, new(testProvider), "", false) if err != table.err { t.Fatalf("error doesn't match: expected '%s' but got '%s'", table.err, err) } @@ -397,7 +397,7 @@ func TestTsigVerifyProvider(t *testing.T) { } provider := &testProvider{true} - msgData, _, err := tsigGenerateProvider(req, provider, "", false) + msgData, _, err := TsigGenerateProvider(req, provider, "", false) if err != nil { t.Error(err) } diff --git a/xfr.go b/xfr.go index f0dcf61d4..5a71c87dc 100644 --- a/xfr.go +++ b/xfr.go @@ -237,7 +237,7 @@ func (t *Transfer) ReadMsg() (*Msg, error) { } if ts, tp := m.IsTsig(), t.tsigProvider(); ts != nil && tp != nil { // Need to work on the original message p, as that was used to calculate the tsig. - err = tsigVerifyProvider(p, tp, t.tsigRequestMAC, t.tsigTimersOnly) + err = TsigVerifyProvider(p, tp, t.tsigRequestMAC, t.tsigTimersOnly) t.tsigRequestMAC = ts.MAC } return m, err @@ -247,7 +247,7 @@ func (t *Transfer) ReadMsg() (*Msg, error) { func (t *Transfer) WriteMsg(m *Msg) (err error) { var out []byte if ts, tp := m.IsTsig(), t.tsigProvider(); ts != nil && tp != nil { - out, t.tsigRequestMAC, err = tsigGenerateProvider(m, tp, t.tsigRequestMAC, t.tsigTimersOnly) + out, t.tsigRequestMAC, err = TsigGenerateProvider(m, tp, t.tsigRequestMAC, t.tsigTimersOnly) } else { out, err = m.Pack() }