Skip to content

Commit

Permalink
Merge pull request #1 from grafana/cherry-pick-compound-fix
Browse files Browse the repository at this point in the history
Fix for compound messages containing >255 messages or messages > 64KB
  • Loading branch information
pracucci committed Dec 1, 2021
2 parents 619135c + ffbe0d2 commit c7bc8e9
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 11 deletions.
6 changes: 3 additions & 3 deletions memberlist_test.go
Expand Up @@ -1200,9 +1200,9 @@ func TestMemberlist_UserData(t *testing.T) {

bindPort := m1.config.BindPort

bcasts := [][]byte{
[]byte("test"),
[]byte("foobar"),
bcasts := make([][]byte, 256)
for i := range bcasts {
bcasts[i] = []byte(fmt.Sprintf("%d", i))
}

// Create a second node
Expand Down
14 changes: 10 additions & 4 deletions net.go
Expand Up @@ -739,11 +739,17 @@ func (m *Memberlist) sendMsg(a Address, msg []byte) error {
msgs = append(msgs, msg)
msgs = append(msgs, extra...)

// Create a compound message
compound := makeCompoundMessage(msgs)
// Create one or more compound messages.
compounds := makeCompoundMessages(msgs)

// Send the message
return m.rawSendMsgPacket(a, nil, compound.Bytes())
// Send the messages.
for _, compound := range compounds {
if err := m.rawSendMsgPacket(a, nil, compound.Bytes()); err != nil {
return err
}
}

return nil
}

// rawSendMsgPacket is used to send message via packet to another host without
Expand Down
10 changes: 6 additions & 4 deletions state.go
Expand Up @@ -606,10 +606,12 @@ func (m *Memberlist) gossip() {
m.logger.Printf("[ERR] memberlist: Failed to send gossip to %s: %s", addr, err)
}
} else {
// Otherwise create and send a compound message
compound := makeCompoundMessage(msgs)
if err := m.rawSendMsgPacket(node.FullAddress(), &node, compound.Bytes()); err != nil {
m.logger.Printf("[ERR] memberlist: Failed to send gossip to %s: %s", addr, err)
// Otherwise create and send one or more compound messages
compounds := makeCompoundMessages(msgs)
for _, compound := range compounds {
if err := m.rawSendMsgPacket(node.FullAddress(), &node, compound.Bytes()); err != nil {
m.logger.Printf("[ERR] memberlist: Failed to send gossip to %s: %s", addr, err)
}
}
}
}
Expand Down
43 changes: 43 additions & 0 deletions util.go
Expand Up @@ -152,6 +152,49 @@ OUTER:
return kNodes
}

// makeCompoundMessages takes a list of messages and packs
// them into one or multiple messages based on the limitations
// of compound messages (255 messages each, 64KB max message size).
//
// The input msgs can be modified in-place.
func makeCompoundMessages(msgs [][]byte) []*bytes.Buffer {
const (
maxMsgs = math.MaxUint8
maxMsgLength = math.MaxUint16
)

// Optimistically assume there will be no big message.
bufs := make([]*bytes.Buffer, 0, (len(msgs)+(maxMsgs-1))/maxMsgs)

// Do not add to a compound message any message bigger than the max message length
// we can store.
r, w := 0, 0
for r < len(msgs) {
if len(msgs[r]) <= maxMsgLength {
// Keep it.
msgs[w] = msgs[r]
r++
w++
continue
}

// This message is a large one, so we send it alone.
bufs = append(bufs, bytes.NewBuffer(msgs[r]))
r++
}
msgs = msgs[:w]

// Group remaining messages in compound message(s).
for ; len(msgs) > maxMsgs; msgs = msgs[maxMsgs:] {
bufs = append(bufs, makeCompoundMessage(msgs[:maxMsgs]))
}
if len(msgs) > 0 {
bufs = append(bufs, makeCompoundMessage(msgs))
}

return bufs
}

// makeCompoundMessage takes a list of messages and generates
// a single compound message containing all of them
func makeCompoundMessage(msgs [][]byte) *bytes.Buffer {
Expand Down
135 changes: 135 additions & 0 deletions util_test.go
Expand Up @@ -6,6 +6,7 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -365,3 +366,137 @@ func TestCompressDecompressPayload(t *testing.T) {
t.Fatalf("bad payload: %v", decomp)
}
}

func TestMakeCompoundMessages(t *testing.T) {
const (
smallMsgSeqNo = uint32(1)
smallMsgPayloadLength = 1
bigMsgSeqNo = uint32(2)
bigMsgPayloadLength = 70000
)

// Generate some fixtures.
smallMessages := make([][]byte, 300)
for i := 0; i < len(smallMessages); i++ {
msg := &ackResp{SeqNo: smallMsgSeqNo, Payload: []byte{byte(i)}}
encoded, err := encode(ackRespMsg, msg)
require.NoError(t, err)

smallMessages[i] = encoded.Bytes()
}

bigMessages := make([][]byte, 3)
for i := 0; i < len(bigMessages); i++ {
payload := []byte{bigMsgPayloadLength - 1: byte(i)}
require.Len(t, payload, bigMsgPayloadLength)

msg := &ackResp{SeqNo: bigMsgSeqNo, Payload: payload}
encoded, err := encode(ackRespMsg, msg)
require.NoError(t, err)

bigMessages[i] = encoded.Bytes()
}

tests := map[string]struct {
input [][]byte
expected [][]byte
}{
"no input": {
input: [][]byte{},
expected: [][]byte{},
},
"one small message": {
input: smallMessages[0:1],
expected: [][]byte{makeCompoundMessage(smallMessages[0:1]).Bytes()},
},
"few small messages": {
input: smallMessages[0:3],
expected: [][]byte{makeCompoundMessage(smallMessages[0:3]).Bytes()},
},
"many small messages (more than 255)": {
input: smallMessages[0:300],
expected: [][]byte{
makeCompoundMessage(smallMessages[0:255]).Bytes(),
makeCompoundMessage(smallMessages[255:300]).Bytes(),
},
},
"one big message": {
input: bigMessages[0:1],
expected: bigMessages[0:1],
},
"few big messages": {
input: bigMessages[0:3],
expected: bigMessages[0:3],
},
"mix of many small and big messages": {
input: func() [][]byte {
var out [][]byte

out = append(out, bigMessages[0])
out = append(out, smallMessages[0:20]...)
out = append(out, bigMessages[1])
out = append(out, smallMessages[20:260]...)
out = append(out, bigMessages[2])
out = append(out, smallMessages[260:300]...)

return out
}(),
expected: [][]byte{
bigMessages[0],
bigMessages[1],
bigMessages[2],
makeCompoundMessage(smallMessages[0:255]).Bytes(),
makeCompoundMessage(smallMessages[255:300]).Bytes(),
},
},
}

for testName, testData := range tests {
t.Run(testName, func(t *testing.T) {
actual := makeCompoundMessages(testData.input)

// Get the actual []byte of each message.
actualBytes := make([][]byte, 0, len(actual))
for _, data := range actual {
actualBytes = append(actualBytes, data.Bytes())
}

assert.Equal(t, testData.expected, actualBytes)

// Ensure we can successfully decode every message.
for i := 0; i < len(actual); i++ {
msg := actualBytes[i]
typ := messageType(msg[0])

switch typ {
case ackRespMsg:
var got ackResp
require.NoError(t, decode(msg[1:], &got))

if got.SeqNo == smallMsgSeqNo {
assert.Len(t, got.Payload, smallMsgPayloadLength)
} else if got.SeqNo == bigMsgSeqNo {
assert.Len(t, got.Payload, bigMsgPayloadLength)
} else {
require.Fail(t, "unexpected seq no")
}
case compoundMsg:
trunc, parts, err := decodeCompoundMessage(msg[1:])
require.NoError(t, err)
require.Equal(t, 0, trunc)

for _, part := range parts {
require.Equal(t, ackRespMsg, messageType(part[0]))

var got ackResp
require.NoError(t, decode(part[1:], &got))
assert.Equal(t, smallMsgSeqNo, got.SeqNo)
assert.Len(t, got.Payload, smallMsgPayloadLength)
}
default:
require.Fail(t, "unexpected message")
}
}
})
}
}

0 comments on commit c7bc8e9

Please sign in to comment.