Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass DNS Msg via context into TSIG Verify and Generate functions #1348

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion acceptfunc_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dns

import (
"context"
"testing"
)

Expand Down Expand Up @@ -28,7 +29,7 @@ func TestAcceptNotify(t *testing.T) {
}
}

func handleNotify(w ResponseWriter, req *Msg) {
func handleNotify(ctx context.Context, w ResponseWriter, req *Msg) {
m := new(Msg)
m.SetReply(req)
w.WriteMsg(m)
Expand Down
2 changes: 1 addition & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ func (co *Conn) WriteMsg(m *Msg) (err error) {
var out []byte
if t := m.IsTsig(); t != nil {
// Set tsigRequestMAC for the next read, although only used in zone transfers.
out, co.tsigRequestMAC, err = tsigGenerateProvider(m, co.tsigProvider(), co.tsigRequestMAC, false)
out, co.tsigRequestMAC, err = tsigGenerateProvider(context.Background(), m, co.tsigProvider(), co.tsigRequestMAC, false)
} else {
out, err = m.Pack()
}
Expand Down
4 changes: 2 additions & 2 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ func TestClientEDNS0Local(t *testing.T) {
optStr1 := "1979:0x0707"
optStr2 := strconv.Itoa(EDNS0LOCALSTART) + ":0x0601"

handler := func(w ResponseWriter, req *Msg) {
handler := func(ctx context.Context, w ResponseWriter, req *Msg) {
m := new(Msg)
m.SetReply(req)

Expand Down Expand Up @@ -667,7 +667,7 @@ func TestConcurrentExchanges(t *testing.T) {

for _, m := range cases {
mm := m // redeclare m so as not to trip the race detector
handler := func(w ResponseWriter, req *Msg) {
handler := func(ctx context.Context, w ResponseWriter, req *Msg) {
r := mm.Copy()
r.SetReply(req)

Expand Down
11 changes: 6 additions & 5 deletions serve_mux.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dns

import (
"context"
"sync"
)

Expand Down Expand Up @@ -70,7 +71,7 @@ func (mux *ServeMux) Handle(pattern string, handler Handler) {
}

// HandleFunc adds a handler function to the ServeMux for pattern.
func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
func (mux *ServeMux) HandleFunc(pattern string, handler func(context.Context, ResponseWriter, *Msg)) {
mux.Handle(pattern, HandlerFunc(handler))
}

Expand All @@ -93,16 +94,16 @@ func (mux *ServeMux) HandleRemove(pattern string) {
//
// If no handler is found, or there is no question, a standard REFUSED
// message is returned
func (mux *ServeMux) ServeDNS(w ResponseWriter, req *Msg) {
func (mux *ServeMux) ServeDNS(ctx context.Context, w ResponseWriter, req *Msg) {
var h Handler
if len(req.Question) >= 1 { // allow more than one question
h = mux.match(req.Question[0].Name, req.Question[0].Qtype)
}

if h != nil {
h.ServeDNS(w, req)
h.ServeDNS(ctx, w, req)
} else {
handleRefused(w, req)
handleRefused(ctx, w, req)
}
}

Expand All @@ -117,6 +118,6 @@ func HandleRemove(pattern string) { DefaultServeMux.HandleRemove(pattern) }

// HandleFunc registers the handler function with the given pattern
// in the DefaultServeMux.
func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
func HandleFunc(pattern string, handler func(context.Context, ResponseWriter, *Msg)) {
DefaultServeMux.HandleFunc(pattern, handler)
}
71 changes: 36 additions & 35 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,18 @@ var aLongTimeAgo = time.Unix(1, 0)

// Handler is implemented by any value that implements ServeDNS.
type Handler interface {
ServeDNS(w ResponseWriter, r *Msg)
ServeDNS(ctx context.Context, w ResponseWriter, r *Msg)
}

// The HandlerFunc type is an adapter to allow the use of
// ordinary functions as DNS handlers. If f is a function
// with the appropriate signature, HandlerFunc(f) is a
// Handler object that calls f.
type HandlerFunc func(ResponseWriter, *Msg)
type HandlerFunc func(context.Context, ResponseWriter, *Msg)

// ServeDNS calls f(w, r).
func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
f(w, r)
func (f HandlerFunc) ServeDNS(ctx context.Context, w ResponseWriter, r *Msg) {
f(ctx, w, r)
}

// A ResponseWriter interface is used by an DNS handler to
Expand Down Expand Up @@ -80,15 +80,15 @@ type response struct {
}

// handleRefused returns a HandlerFunc that returns REFUSED for every request it gets.
func handleRefused(w ResponseWriter, r *Msg) {
func handleRefused(ctx context.Context, w ResponseWriter, r *Msg) {
m := new(Msg)
m.SetRcode(r, RcodeRefused)
w.WriteMsg(m)
}

// HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets.
// Deprecated: This function is going away.
func HandleFailed(w ResponseWriter, r *Msg) {
func HandleFailed(ctx context.Context, w ResponseWriter, r *Msg) {
m := new(Msg)
m.SetRcode(r, RcodeServerFailure)
// does not matter if this write fails
Expand Down Expand Up @@ -142,10 +142,10 @@ type Writer interface {
type Reader interface {
// ReadTCP reads a raw message from a TCP connection. Implementations may alter
// connection properties, for example the read-deadline.
ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error)
ReadTCP(ctx context.Context, conn net.Conn, timeout time.Duration) ([]byte, context.Context, error)
// ReadUDP reads a raw message from a UDP connection. Implementations may alter
// connection properties, for example the read-deadline.
ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
ReadUDP(ctx context.Context, conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, context.Context, error)
}

// PacketConnReader is an optional interface that Readers can implement to support using generic net.PacketConns.
Expand All @@ -154,7 +154,7 @@ type PacketConnReader interface {

// ReadPacketConn reads a raw message from a generic net.PacketConn UDP connection. Implementations may
// alter connection properties, for example the read-deadline.
ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error)
ReadPacketConn(ctx context.Context, conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, context.Context, error)
}

// defaultReader is an adapter for the Server struct that implements the Reader and
Expand All @@ -166,16 +166,16 @@ type defaultReader struct {

var _ PacketConnReader = defaultReader{}

func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
return dr.readTCP(conn, timeout)
func (dr defaultReader) ReadTCP(ctx context.Context, conn net.Conn, timeout time.Duration) ([]byte, context.Context, error) {
return dr.readTCP(ctx, conn, timeout)
}

func (dr defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
return dr.readUDP(conn, timeout)
func (dr defaultReader) ReadUDP(ctx context.Context, conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, context.Context, error) {
return dr.readUDP(ctx, conn, timeout)
}

func (dr defaultReader) ReadPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
return dr.readPacketConn(conn, timeout)
func (dr defaultReader) ReadPacketConn(ctx context.Context, conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, context.Context, error) {
return dr.readPacketConn(ctx, conn, timeout)
}

// DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader.
Expand Down Expand Up @@ -509,10 +509,11 @@ func (srv *Server) serveUDP(l net.PacketConn) error {
sUDP *SessionUDP
err error
)
ctx := context.Background()
if isUDP {
m, sUDP, err = reader.ReadUDP(lUDP, rtimeout)
m, sUDP, ctx, err = reader.ReadUDP(ctx, lUDP, rtimeout)
} else {
m, sPC, err = readerPC.ReadPacketConn(l, rtimeout)
m, sPC, ctx, err = readerPC.ReadPacketConn(ctx, l, rtimeout)
}
if err != nil {
if !srv.isStarted() {
Expand All @@ -530,7 +531,7 @@ func (srv *Server) serveUDP(l net.PacketConn) error {
continue
}
wg.Add(1)
go srv.serveUDPPacket(&wg, m, l, sUDP, sPC)
go srv.serveUDPPacket(ctx, &wg, m, l, sUDP, sPC)
}

return nil
Expand Down Expand Up @@ -563,12 +564,12 @@ func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
}

for q := 0; (q < limit || limit == -1) && srv.isStarted(); q++ {
m, err := reader.ReadTCP(w.tcp, timeout)
m, ctx, err := reader.ReadTCP(context.Background(), w.tcp, timeout)
if err != nil {
// TODO(tmthrgd): handle error
break
}
srv.serveDNS(m, w)
srv.serveDNS(ctx, m, w)
if w.closed {
break // Close() was called
}
Expand All @@ -592,19 +593,19 @@ func (srv *Server) serveTCPConn(wg *sync.WaitGroup, rw net.Conn) {
}

// Serve a new UDP request.
func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) {
func (srv *Server) serveUDPPacket(ctx context.Context, wg *sync.WaitGroup, m []byte, u net.PacketConn, udpSession *SessionUDP, pcSession net.Addr) {
w := &response{tsigProvider: srv.tsigProvider(), udp: u, udpSession: udpSession, pcSession: pcSession}
if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w)
} else {
w.writer = w
}

srv.serveDNS(m, w)
srv.serveDNS(ctx, m, w)
wg.Done()
}

func (srv *Server) serveDNS(m []byte, w *response) {
func (srv *Server) serveDNS(ctx context.Context, m []byte, w *response) {
dh, off, err := unpackMsgHdr(m, 0)
if err != nil {
// Let client hang, they are sending crap; any reply can be used to amplify.
Expand Down Expand Up @@ -656,10 +657,10 @@ func (srv *Server) serveDNS(m []byte, w *response) {
srv.udpPool.Put(m[:srv.UDPSize])
}

srv.Handler.ServeDNS(w, req) // Writes back to the client
srv.Handler.ServeDNS(ctx, w, req) // Writes back to the client
}

func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
func (srv *Server) readTCP(ctx context.Context, conn net.Conn, timeout time.Duration) ([]byte, context.Context, error) {
// If we race with ShutdownContext, the read deadline may
// have been set in the distant past to unblock the read
// below. We must not override it, otherwise we may block
Expand All @@ -672,18 +673,18 @@ func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error)

var length uint16
if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
return nil, err
return nil, ctx, err
}

m := make([]byte, length)
if _, err := io.ReadFull(conn, m); err != nil {
return nil, err
return nil, ctx, err
}

return m, nil
return m, ctx, nil
}

func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
func (srv *Server) readUDP(ctx context.Context, conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, context.Context, error) {
srv.lock.RLock()
if srv.started {
// See the comment in readTCP above.
Expand All @@ -695,13 +696,13 @@ func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *S
n, s, err := ReadFromSessionUDP(conn, m)
if err != nil {
srv.udpPool.Put(m)
return nil, nil, err
return nil, nil, ctx, err
}
m = m[:n]
return m, s, nil
return m, s, ctx, nil
}

func (srv *Server) readPacketConn(conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, error) {
func (srv *Server) readPacketConn(ctx context.Context, conn net.PacketConn, timeout time.Duration) ([]byte, net.Addr, context.Context, error) {
srv.lock.RLock()
if srv.started {
// See the comment in readTCP above.
Expand All @@ -713,10 +714,10 @@ func (srv *Server) readPacketConn(conn net.PacketConn, timeout time.Duration) ([
n, addr, err := conn.ReadFrom(m)
if err != nil {
srv.udpPool.Put(m)
return nil, nil, err
return nil, nil, ctx, err
}
m = m[:n]
return m, addr, nil
return m, addr, ctx, nil
}

// WriteMsg implements the ResponseWriter.WriteMsg method.
Expand All @@ -728,7 +729,7 @@ func (w *response) WriteMsg(m *Msg) (err error) {
var data []byte
if w.tsigProvider != nil { // if no provider, dont check for the tsig (which is a longer check)
if t := m.IsTsig(); t != nil {
data, w.tsigRequestMAC, err = tsigGenerateProvider(m, w.tsigProvider, w.tsigRequestMAC, w.tsigTimersOnly)
data, w.tsigRequestMAC, err = tsigGenerateProvider(context.Background(), m, w.tsigProvider, w.tsigRequestMAC, w.tsigTimersOnly)
if err != nil {
return err
}
Expand Down