Skip to content

Commit

Permalink
Pass DNS Msg via context into TSIG Verify and Generate functions
Browse files Browse the repository at this point in the history
This gives the server operator a lot more flexibility as to how they manage the verification and generation of tsigs.
The test TestServerRoundtripTsigProvider displays some of these possibilities.

For example:

1. Establish a sync.RWMutex on the secret map so that secrets can be dynamically updated even after the server has been started.
2. Establish more granular secret maps. I.E to ensure that many zones can share TSIG names without collisions in the map.

You could do other things like have zone specific verification or generation.
  • Loading branch information
Fattouche committed Mar 10, 2022
1 parent af1ebf5 commit f47d13e
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 33 deletions.
2 changes: 1 addition & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ func (co *Conn) WriteMsg(m *Msg) (err error) {
var out []byte
if t := m.IsTsig(); t != nil {
// Set tsigRequestMAC for the next read, although only used in zone transfers.
out, co.tsigRequestMAC, err = tsigGenerateProvider(m, co.tsigProvider(), co.tsigRequestMAC, false)
out, co.tsigRequestMAC, err = tsigGenerateProvider(context.Background(), m, co.tsigProvider(), co.tsigRequestMAC, false)
} else {
out, err = m.Pack()
}
Expand Down
2 changes: 1 addition & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ func (w *response) WriteMsg(m *Msg) (err error) {
var data []byte
if w.tsigProvider != nil { // if no provider, dont check for the tsig (which is a longer check)
if t := m.IsTsig(); t != nil {
data, w.tsigRequestMAC, err = tsigGenerateProvider(m, w.tsigProvider, w.tsigRequestMAC, w.tsigTimersOnly)
data, w.tsigRequestMAC, err = tsigGenerateProvider(context.Background(), m, w.tsigProvider, w.tsigRequestMAC, w.tsigTimersOnly)
if err != nil {
return err
}
Expand Down
154 changes: 154 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,15 @@ package dns

import (
"context"
"crypto/hmac"
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"crypto/tls"
"encoding/hex"
"fmt"
"hash"
"io"
"net"
"runtime"
Expand Down Expand Up @@ -1256,3 +1263,150 @@ zDCJkckCgYEAndqM5KXGk5xYo+MAA1paZcbTUXwaWwjLU+XSRSSoyBEi5xMtfvUb
kFsxKCqxAnBVGEWAvVZAiiTOxleQFjz5RnL0BQp9Lg2cQe+dvuUmIAA=
-----END RSA PRIVATE KEY-----`)
)

type customTsigProvider struct {
secrets map[string]map[string]string
lock sync.RWMutex
}

func TestServerRoundtripTsigProvider(t *testing.T) {
// Let's use a read write lock so that we can show how to properly reload secrets async/
// We will also make the map value be another map so that we can properly
// separate tsigs based on zone ID
// map[tsigName]map[zoneName]tsigSecret
secrets := map[string]map[string]string{}
secrets["test."] = make(map[string]string)
secrets["test."]["example.com."] = "so6ZGir4GPAqINNh9U5c3A=="
secrets["test."]["otherexample.com."] = "blahblah"

secret := map[string]string{"test.": "so6ZGir4GPAqINNh9U5c3A=="}

s, addrstr, _, err := RunLocalUDPServer(":0", func(srv *Server) {
srv.TsigProvider = customTsigProvider{secrets: secrets}
srv.MsgAcceptFunc = func(dh Header) MsgAcceptAction {
// defaultMsgAcceptFunc does reject UPDATE queries
return MsgAccept
}
})
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
defer s.Shutdown()

handlerFired := make(chan struct{})
HandleFunc("example.com.", func(w ResponseWriter, r *Msg) {
close(handlerFired)

m := new(Msg)
m.SetReply(r)
if r.IsTsig() != nil {
status := w.TsigStatus()
if status == nil {
// *Msg r has an TSIG record and it was validated
m.SetTsig("test.", HmacSHA256, 300, time.Now().Unix())
} else {
// *Msg r has an TSIG records and it was not validated
t.Errorf("invalid TSIG: %v", status)
}
} else {
t.Error("missing TSIG")
}
if err := w.WriteMsg(m); err != nil {
t.Error("writemsg failed", err)
}
})

c := new(Client)
m := new(Msg)
m.Opcode = OpcodeUpdate
m.SetQuestion("example.com.", TypeSOA)
m.Ns = []RR{&CNAME{
Hdr: RR_Header{
Name: "foo.example.com.",
Rrtype: TypeCNAME,
Class: ClassINET,
Ttl: 300,
},
Target: "bar.example.com.",
}}
c.TsigSecret = secret
m.SetTsig("test.", HmacSHA256, 300, time.Now().Unix())
_, _, err = c.Exchange(m, addrstr)
if err != nil {
t.Fatal("failed to exchange", err)
}
select {
case <-handlerFired:
// ok, handler was actually called
default:
t.Error("handler was not called")
}
}

func (ctp customTsigProvider) Generate(ctx context.Context, msg []byte, t *TSIG) ([]byte, error) {
// Readlock
ctp.lock.RLock()
defer ctp.lock.RUnlock()

// We can put anything in here, but for this example we will put the dnsMsg in the context so that we can get any information we need
dnsMsg, ok := ctx.Value(DNSMsgKey).(*Msg)
if !ok {
return nil, fmt.Errorf("Failed to find dnsMsg in context for tsig %s", t.Hdr.Name)
}
if len(dnsMsg.Question) == 0 {
return nil, fmt.Errorf("Failed to grab zoneName from dnsMsg question for %s", t.Hdr.Name)
}
zoneName := dnsMsg.Question[0].Name

// Check if we have a secret with this name
secretsWithTsigName, ok := ctp.secrets[t.Hdr.Name]
if !ok {
return nil, fmt.Errorf("Failed to find tsig with this name %s", t.Hdr.Name)
}
// Make sure this tsig is actually for the specific zone in question
secret, ok := secretsWithTsigName[zoneName]
if !ok {
return nil, fmt.Errorf("Failed to find tsig with this name %s and zone %s", t.Hdr.Name, zoneName)
}

rawsecret, err := fromBase64([]byte(secret))
if err != nil {
return nil, err
}

var h hash.Hash
switch CanonicalName(t.Algorithm) {
case HmacSHA1:
h = hmac.New(sha1.New, rawsecret)
case HmacSHA224:
h = hmac.New(sha256.New224, rawsecret)
// Deprecated
case HmacMD5:
h = hmac.New(md5.New, rawsecret)
case HmacSHA256:
h = hmac.New(sha256.New, rawsecret)
case HmacSHA384:
h = hmac.New(sha512.New384, rawsecret)
case HmacSHA512:
h = hmac.New(sha512.New, rawsecret)
default:
return nil, ErrKeyAlg
}
h.Write(msg)
return h.Sum(nil), nil
}

func (key customTsigProvider) Verify(ctx context.Context, msg []byte, t *TSIG) error {
b, err := key.Generate(ctx, msg, t)
if err != nil {
return err
}
mac, err := hex.DecodeString(t.MAC)
if err != nil {
return err
}
if !hmac.Equal(b, mac) {
return ErrSig
}
return nil
}
53 changes: 35 additions & 18 deletions tsig.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dns

import (
"context"
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
Expand All @@ -24,17 +25,21 @@ const (
HmacMD5 = "hmac-md5.sig-alg.reg.int." // Deprecated: HmacMD5 is no longer supported.
)

type GenerateContextKey string

var DNSMsgKey GenerateContextKey = "dnsMsg"

// TsigProvider provides the API to plug-in a custom TSIG implementation.
type TsigProvider interface {
// Generate is passed the DNS message to be signed and the partial TSIG RR. It returns the signature and nil, otherwise an error.
Generate(msg []byte, t *TSIG) ([]byte, error)
// Verify is passed the DNS message to be verified and the TSIG RR. If the signature is valid it will return nil, otherwise an error.
Verify(msg []byte, t *TSIG) error
// Generate is passed the context, the stripped DNS message to be signed and the partial TSIG RR. It returns the signature and nil, otherwise an error.
Generate(ctx context.Context, msg []byte, t *TSIG) ([]byte, error)
// Verify is passed the context, the DNS message to be verified and the TSIG RR. If the signature is valid it will return nil, otherwise an error.
Verify(ctx context.Context, msg []byte, t *TSIG) error
}

type tsigHMACProvider string

func (key tsigHMACProvider) Generate(msg []byte, t *TSIG) ([]byte, error) {
func (key tsigHMACProvider) Generate(ctx context.Context, msg []byte, t *TSIG) ([]byte, error) {
// If we barf here, the caller is to blame
rawsecret, err := fromBase64([]byte(key))
if err != nil {
Expand All @@ -59,8 +64,8 @@ func (key tsigHMACProvider) Generate(msg []byte, t *TSIG) ([]byte, error) {
return h.Sum(nil), nil
}

func (key tsigHMACProvider) Verify(msg []byte, t *TSIG) error {
b, err := key.Generate(msg, t)
func (key tsigHMACProvider) Verify(ctx context.Context, msg []byte, t *TSIG) error {
b, err := key.Generate(ctx, msg, t)
if err != nil {
return err
}
Expand All @@ -76,20 +81,20 @@ func (key tsigHMACProvider) Verify(msg []byte, t *TSIG) error {

type tsigSecretProvider map[string]string

func (ts tsigSecretProvider) Generate(msg []byte, t *TSIG) ([]byte, error) {
func (ts tsigSecretProvider) Generate(ctx context.Context, msg []byte, t *TSIG) ([]byte, error) {
key, ok := ts[t.Hdr.Name]
if !ok {
return nil, ErrSecret
}
return tsigHMACProvider(key).Generate(msg, t)
return tsigHMACProvider(key).Generate(ctx, msg, t)
}

func (ts tsigSecretProvider) Verify(msg []byte, t *TSIG) error {
func (ts tsigSecretProvider) Verify(ctx context.Context, msg []byte, t *TSIG) error {
key, ok := ts[t.Hdr.Name]
if !ok {
return ErrSecret
}
return tsigHMACProvider(key).Verify(msg, t)
return tsigHMACProvider(key).Verify(ctx, msg, t)
}

// TSIG is the RR the holds the transaction signature of a message.
Expand Down Expand Up @@ -166,14 +171,17 @@ type timerWireFmt struct {
// timersOnly is false.
// If something goes wrong an error is returned, otherwise it is nil.
func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, string, error) {
return tsigGenerateProvider(m, tsigHMACProvider(secret), requestMAC, timersOnly)
return tsigGenerateProvider(context.Background(), m, tsigHMACProvider(secret), requestMAC, timersOnly)
}

func tsigGenerateProvider(m *Msg, provider TsigProvider, requestMAC string, timersOnly bool) ([]byte, string, error) {
func tsigGenerateProvider(ctx context.Context, m *Msg, provider TsigProvider, requestMAC string, timersOnly bool) ([]byte, string, error) {
if m.IsTsig() == nil {
panic("dns: TSIG not last RR in additional")
}

// Add the dns message into the context so that the tsigProvider has access to it.
ctx = context.WithValue(ctx, DNSMsgKey, m)

rr := m.Extra[len(m.Extra)-1].(*TSIG)
m.Extra = m.Extra[0 : len(m.Extra)-1] // kill the TSIG from the msg
mbuf, err := m.Pack()
Expand All @@ -195,7 +203,7 @@ func tsigGenerateProvider(m *Msg, provider TsigProvider, requestMAC string, time

// Sign unless there is a key or MAC validation error (RFC 8945 5.3.2)
if rr.Error != RcodeBadKey && rr.Error != RcodeBadSig {
mac, err := provider.Generate(buf, rr)
mac, err := provider.Generate(ctx, buf, rr)
if err != nil {
return nil, "", err
}
Expand All @@ -220,15 +228,24 @@ func tsigGenerateProvider(m *Msg, provider TsigProvider, requestMAC string, time
// If the signature does not validate err contains the
// error, otherwise it is nil.
func TsigVerify(msg []byte, secret, requestMAC string, timersOnly bool) error {
return tsigVerify(msg, tsigHMACProvider(secret), requestMAC, timersOnly, uint64(time.Now().Unix()))
return tsigVerify(context.Background(), msg, tsigHMACProvider(secret), requestMAC, timersOnly, uint64(time.Now().Unix()))
}

func tsigVerifyProvider(msg []byte, provider TsigProvider, requestMAC string, timersOnly bool) error {
return tsigVerify(msg, provider, requestMAC, timersOnly, uint64(time.Now().Unix()))
return tsigVerify(context.Background(), msg, provider, requestMAC, timersOnly, uint64(time.Now().Unix()))
}

// actual implementation of TsigVerify, taking the current time ('now') as a parameter for the convenience of tests.
func tsigVerify(msg []byte, provider TsigProvider, requestMAC string, timersOnly bool, now uint64) error {
func tsigVerify(ctx context.Context, msg []byte, provider TsigProvider, requestMAC string, timersOnly bool, now uint64) error {
dnsMsg := &Msg{}
err := dnsMsg.Unpack(msg)
if err != nil {
return err
}

// Add the dns message into the context so that the tsigProvider has access to it.
ctx = context.WithValue(ctx, DNSMsgKey, dnsMsg)

// Strip the TSIG from the incoming msg
stripped, tsig, err := stripTsig(msg)
if err != nil {
Expand All @@ -240,7 +257,7 @@ func tsigVerify(msg []byte, provider TsigProvider, requestMAC string, timersOnly
return err
}

if err := provider.Verify(buf, tsig); err != nil {
if err := provider.Verify(ctx, buf, tsig); err != nil {
return err
}

Expand Down

0 comments on commit f47d13e

Please sign in to comment.