Skip to content

Commit

Permalink
Merge pull request #236 from hashicorp/dnephin/fix-some-errors
Browse files Browse the repository at this point in the history
Prevent a couple panics on malformed input
  • Loading branch information
dnephin committed Apr 30, 2021
2 parents 838073f + 45d05f1 commit 619135c
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 4 deletions.
4 changes: 4 additions & 0 deletions net.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,10 @@ func (m *Memberlist) ingestPacket(buf []byte, from net.Addr, timestamp time.Time
}

func (m *Memberlist) handleCommand(buf []byte, from net.Addr, timestamp time.Time) {
if len(buf) < 1 {
m.logger.Printf("[ERR] memberlist: missing message type byte %s", LogAddress(from))
return
}
// Decode the message type
msgType := messageType(buf[0])
buf = buf[1:]
Expand Down
9 changes: 9 additions & 0 deletions net_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -866,3 +866,12 @@ func listenUDP(t *testing.T) *net.UDPConn {
}
return udp
}

func TestHandleCommand(t *testing.T) {
var buf bytes.Buffer
m := Memberlist{
logger: log.New(&buf, "", 0),
}
m.handleCommand(nil, &net.TCPAddr{Port: 12345}, time.Now())
require.Contains(t, buf.String(), "missing message type byte")
}
8 changes: 4 additions & 4 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,26 +185,26 @@ func decodeCompoundMessage(buf []byte) (trunc int, parts [][]byte, err error) {
err = fmt.Errorf("missing compound length byte")
return
}
numParts := uint8(buf[0])
numParts := int(buf[0])
buf = buf[1:]

// Check we have enough bytes
if len(buf) < int(numParts*2) {
if len(buf) < numParts*2 {
err = fmt.Errorf("truncated len slice")
return
}

// Decode the lengths
lengths := make([]uint16, numParts)
for i := 0; i < int(numParts); i++ {
for i := 0; i < numParts; i++ {
lengths[i] = binary.BigEndian.Uint16(buf[i*2 : i*2+2])
}
buf = buf[numParts*2:]

// Split each message
for idx, msgLen := range lengths {
if len(buf) < int(msgLen) {
trunc = int(numParts) - idx
trunc = numParts - idx
return
}

Expand Down
9 changes: 9 additions & 0 deletions util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"reflect"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestUtil_PortFunctions(t *testing.T) {
Expand Down Expand Up @@ -314,6 +316,13 @@ func TestDecodeCompoundMessage(t *testing.T) {
}
}

func TestDecodeCompoundMessage_NumberOfPartsOverflow(t *testing.T) {
buf := []byte{0x80}
_, _, err := decodeCompoundMessage(buf)
require.Error(t, err)
require.Equal(t, err.Error(), "truncated len slice")
}

func TestDecodeCompoundMessage_Trunc(t *testing.T) {
msg := &ping{SeqNo: 100}
buf, err := encode(pingMsg, msg)
Expand Down

0 comments on commit 619135c

Please sign in to comment.