Skip to content

Commit

Permalink
Fix edns keepalive (#1317)
Browse files Browse the repository at this point in the history
* Fix un/packing of EDNS0 TCP keepalive extension

The current code is assuming the OPT code and length are part of the
data passed to the unpacker and that it should be appenedded during
packing while those fields are parsed by the caller.

This change fixes the parsing and make sure a request with a keepalive
won't fail.

* added tests for EDNS0_TCP_KEEPALIVE pack/unpack

* mark Length as deprecated

* removed named returns

* restored String

Co-authored-by: Olivier Poitrey <rs@rhapsodyk.net>
  • Loading branch information
ameshkov and rs committed Dec 23, 2021
1 parent f4af582 commit ba44371
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 32 deletions.
61 changes: 29 additions & 32 deletions edns.go
Expand Up @@ -98,6 +98,8 @@ func (rr *OPT) String() string {
s += "\n; SUBNET: " + o.String()
case *EDNS0_COOKIE:
s += "\n; COOKIE: " + o.String()
case *EDNS0_TCP_KEEPALIVE:
s += "\n; KEEPALIVE: " + o.String()
case *EDNS0_UL:
s += "\n; UPDATE LEASE: " + o.String()
case *EDNS0_LLQ:
Expand Down Expand Up @@ -657,57 +659,52 @@ func (e *EDNS0_LOCAL) unpack(b []byte) error {
// EDNS0_TCP_KEEPALIVE is an EDNS0 option that instructs the server to keep
// the TCP connection alive. See RFC 7828.
type EDNS0_TCP_KEEPALIVE struct {
Code uint16 // Always EDNSTCPKEEPALIVE
Length uint16 // the value 0 if the TIMEOUT is omitted, the value 2 if it is present;
Timeout uint16 // an idle timeout value for the TCP connection, specified in units of 100 milliseconds, encoded in network byte order.
Code uint16 // Always EDNSTCPKEEPALIVE

// Timeout is an idle timeout value for the TCP connection, specified in
// units of 100 milliseconds, encoded in network byte order. If set to 0,
// pack will return a nil slice.
Timeout uint16

// Length is the option's length.
// Deprecated: this field is deprecated and is always equal to 0.
Length uint16
}

// Option implements the EDNS0 interface.
func (e *EDNS0_TCP_KEEPALIVE) Option() uint16 { return EDNS0TCPKEEPALIVE }

func (e *EDNS0_TCP_KEEPALIVE) pack() ([]byte, error) {
if e.Timeout != 0 && e.Length != 2 {
return nil, errors.New("dns: timeout specified but length is not 2")
}
if e.Timeout == 0 && e.Length != 0 {
return nil, errors.New("dns: timeout not specified but length is not 0")
if e.Timeout > 0 {
b := make([]byte, 2)
binary.BigEndian.PutUint16(b, e.Timeout)
return b, nil
}
b := make([]byte, 4+e.Length)
binary.BigEndian.PutUint16(b[0:], e.Code)
binary.BigEndian.PutUint16(b[2:], e.Length)
if e.Length == 2 {
binary.BigEndian.PutUint16(b[4:], e.Timeout)
}
return b, nil
return nil, nil
}

func (e *EDNS0_TCP_KEEPALIVE) unpack(b []byte) error {
if len(b) < 4 {
return ErrBuf
}
e.Length = binary.BigEndian.Uint16(b[2:4])
if e.Length != 0 && e.Length != 2 {
return errors.New("dns: length mismatch, want 0/2 but got " + strconv.FormatUint(uint64(e.Length), 10))
}
if e.Length == 2 {
if len(b) < 6 {
return ErrBuf
}
e.Timeout = binary.BigEndian.Uint16(b[4:6])
switch len(b) {
case 0:
case 2:
e.Timeout = binary.BigEndian.Uint16(b)
default:
return fmt.Errorf("dns: length mismatch, want 0/2 but got %d", len(b))
}
return nil
}

func (e *EDNS0_TCP_KEEPALIVE) String() (s string) {
s = "use tcp keep-alive"
if e.Length == 0 {
func (e *EDNS0_TCP_KEEPALIVE) String() string {
s := "use tcp keep-alive"
if e.Timeout == 0 {
s += ", timeout omitted"
} else {
s += fmt.Sprintf(", timeout %dms", e.Timeout*100)
}
return
return s
}
func (e *EDNS0_TCP_KEEPALIVE) copy() EDNS0 { return &EDNS0_TCP_KEEPALIVE{e.Code, e.Length, e.Timeout} }

func (e *EDNS0_TCP_KEEPALIVE) copy() EDNS0 { return &EDNS0_TCP_KEEPALIVE{e.Code, e.Timeout, e.Length} }

// EDNS0_PADDING option is used to add padding to a request/response. The default
// value of padding SHOULD be 0x0 but other values MAY be used, for instance if
Expand Down
85 changes: 85 additions & 0 deletions edns_test.go
@@ -1,6 +1,7 @@
package dns

import (
"bytes"
"net"
"testing"
)
Expand Down Expand Up @@ -192,3 +193,87 @@ func TestEDNS0_ESU(t *testing.T) {
t.Errorf("unpacked option is different; expected %v, got %v", expect, esu.Uri)
}
}

func TestEDNS0_TCP_KEEPALIVE_unpack(t *testing.T) {
cases := []struct {
name string
b []byte
expected uint16
expectedErr bool
}{
{
name: "empty",
b: []byte{},
expected: 0,
},
{
name: "timeout 1",
b: []byte{0, 1},
expected: 1,
},
{
name: "invalid",
b: []byte{0, 1, 3},
expectedErr: true,
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
e := &EDNS0_TCP_KEEPALIVE{}
err := e.unpack(tc.b)
if err != nil && !tc.expectedErr {
t.Error("failed to unpack, expected no error")
}
if err == nil && tc.expectedErr {
t.Error("unpacked, but expected an error")
}
if e.Timeout != tc.expected {
t.Errorf("invalid timeout, actual: %d, expected: %d", e.Timeout, tc.expected)
}
})
}
}

func TestEDNS0_TCP_KEEPALIVE_pack(t *testing.T) {
cases := []struct {
name string
edns *EDNS0_TCP_KEEPALIVE
expected []byte
}{
{
name: "empty",
edns: &EDNS0_TCP_KEEPALIVE{
Code: EDNS0TCPKEEPALIVE,
Timeout: 0,
},
expected: nil,
},
{
name: "timeout 1",
edns: &EDNS0_TCP_KEEPALIVE{
Code: EDNS0TCPKEEPALIVE,
Timeout: 1,
},
expected: []byte{0, 1},
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
b, err := tc.edns.pack()
if err != nil {
t.Error("expected no error")
}

if tc.expected == nil && b != nil {
t.Errorf("invalid result, expected nil")
}

res := bytes.Compare(b, tc.expected)
if res != 0 {
t.Errorf("invalid result, expected: %v, actual: %v", tc.expected, b)
}
})
}
}

0 comments on commit ba44371

Please sign in to comment.