diff --git a/memberlist_test.go b/memberlist_test.go index 70b5d780c..6a1eb70a0 100644 --- a/memberlist_test.go +++ b/memberlist_test.go @@ -1200,9 +1200,9 @@ func TestMemberlist_UserData(t *testing.T) { bindPort := m1.config.BindPort - bcasts := [][]byte{ - []byte("test"), - []byte("foobar"), + bcasts := make([][]byte, 256) + for i := range bcasts { + bcasts[i] = []byte(fmt.Sprintf("%d", i)) } // Create a second node diff --git a/net.go b/net.go index bac73bd89..fe4acbc2e 100644 --- a/net.go +++ b/net.go @@ -739,11 +739,17 @@ func (m *Memberlist) sendMsg(a Address, msg []byte) error { msgs = append(msgs, msg) msgs = append(msgs, extra...) - // Create a compound message - compound := makeCompoundMessage(msgs) + // Create one or more compound messages. + compounds := makeCompoundMessages(msgs) - // Send the message - return m.rawSendMsgPacket(a, nil, compound.Bytes()) + // Send the messages. + for _, compound := range compounds { + if err := m.rawSendMsgPacket(a, nil, compound.Bytes()); err != nil { + return err + } + } + + return nil } // rawSendMsgPacket is used to send message via packet to another host without diff --git a/state.go b/state.go index 5e4f7fdd7..95e6ddc48 100644 --- a/state.go +++ b/state.go @@ -606,10 +606,12 @@ func (m *Memberlist) gossip() { m.logger.Printf("[ERR] memberlist: Failed to send gossip to %s: %s", addr, err) } } else { - // Otherwise create and send a compound message - compound := makeCompoundMessage(msgs) - if err := m.rawSendMsgPacket(node.FullAddress(), &node, compound.Bytes()); err != nil { - m.logger.Printf("[ERR] memberlist: Failed to send gossip to %s: %s", addr, err) + // Otherwise create and send one or more compound messages + compounds := makeCompoundMessages(msgs) + for _, compound := range compounds { + if err := m.rawSendMsgPacket(node.FullAddress(), &node, compound.Bytes()); err != nil { + m.logger.Printf("[ERR] memberlist: Failed to send gossip to %s: %s", addr, err) + } } } } diff --git a/util.go b/util.go index 16a7d36d0..e7be4ad88 100644 --- a/util.go +++ b/util.go @@ -152,6 +152,49 @@ OUTER: return kNodes } +// makeCompoundMessages takes a list of messages and packs +// them into one or multiple messages based on the limitations +// of compound messages (255 messages each, 64KB max message size). +// +// The input msgs can be modified in-place. +func makeCompoundMessages(msgs [][]byte) []*bytes.Buffer { + const ( + maxMsgs = math.MaxUint8 + maxMsgLength = math.MaxUint16 + ) + + // Optimistically assume there will be no big message. + bufs := make([]*bytes.Buffer, 0, (len(msgs)+(maxMsgs-1))/maxMsgs) + + // Do not add to a compound message any message bigger than the max message length + // we can store. + r, w := 0, 0 + for r < len(msgs) { + if len(msgs[r]) <= maxMsgLength { + // Keep it. + msgs[w] = msgs[r] + r++ + w++ + continue + } + + // This message is a large one, so we send it alone. + bufs = append(bufs, bytes.NewBuffer(msgs[r])) + r++ + } + msgs = msgs[:w] + + // Group remaining messages in compound message(s). + for ; len(msgs) > maxMsgs; msgs = msgs[maxMsgs:] { + bufs = append(bufs, makeCompoundMessage(msgs[:maxMsgs])) + } + if len(msgs) > 0 { + bufs = append(bufs, makeCompoundMessage(msgs)) + } + + return bufs +} + // makeCompoundMessage takes a list of messages and generates // a single compound message containing all of them func makeCompoundMessage(msgs [][]byte) *bytes.Buffer { diff --git a/util_test.go b/util_test.go index f97eb1703..0b43f2aa6 100644 --- a/util_test.go +++ b/util_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -365,3 +366,137 @@ func TestCompressDecompressPayload(t *testing.T) { t.Fatalf("bad payload: %v", decomp) } } + +func TestMakeCompoundMessages(t *testing.T) { + const ( + smallMsgSeqNo = uint32(1) + smallMsgPayloadLength = 1 + bigMsgSeqNo = uint32(2) + bigMsgPayloadLength = 70000 + ) + + // Generate some fixtures. + smallMessages := make([][]byte, 300) + for i := 0; i < len(smallMessages); i++ { + msg := &ackResp{SeqNo: smallMsgSeqNo, Payload: []byte{byte(i)}} + encoded, err := encode(ackRespMsg, msg) + require.NoError(t, err) + + smallMessages[i] = encoded.Bytes() + } + + bigMessages := make([][]byte, 3) + for i := 0; i < len(bigMessages); i++ { + payload := []byte{bigMsgPayloadLength - 1: byte(i)} + require.Len(t, payload, bigMsgPayloadLength) + + msg := &ackResp{SeqNo: bigMsgSeqNo, Payload: payload} + encoded, err := encode(ackRespMsg, msg) + require.NoError(t, err) + + bigMessages[i] = encoded.Bytes() + } + + tests := map[string]struct { + input [][]byte + expected [][]byte + }{ + "no input": { + input: [][]byte{}, + expected: [][]byte{}, + }, + "one small message": { + input: smallMessages[0:1], + expected: [][]byte{makeCompoundMessage(smallMessages[0:1]).Bytes()}, + }, + "few small messages": { + input: smallMessages[0:3], + expected: [][]byte{makeCompoundMessage(smallMessages[0:3]).Bytes()}, + }, + "many small messages (more than 255)": { + input: smallMessages[0:300], + expected: [][]byte{ + makeCompoundMessage(smallMessages[0:255]).Bytes(), + makeCompoundMessage(smallMessages[255:300]).Bytes(), + }, + }, + "one big message": { + input: bigMessages[0:1], + expected: bigMessages[0:1], + }, + "few big messages": { + input: bigMessages[0:3], + expected: bigMessages[0:3], + }, + "mix of many small and big messages": { + input: func() [][]byte { + var out [][]byte + + out = append(out, bigMessages[0]) + out = append(out, smallMessages[0:20]...) + out = append(out, bigMessages[1]) + out = append(out, smallMessages[20:260]...) + out = append(out, bigMessages[2]) + out = append(out, smallMessages[260:300]...) + + return out + }(), + expected: [][]byte{ + bigMessages[0], + bigMessages[1], + bigMessages[2], + makeCompoundMessage(smallMessages[0:255]).Bytes(), + makeCompoundMessage(smallMessages[255:300]).Bytes(), + }, + }, + } + + for testName, testData := range tests { + t.Run(testName, func(t *testing.T) { + actual := makeCompoundMessages(testData.input) + + // Get the actual []byte of each message. + actualBytes := make([][]byte, 0, len(actual)) + for _, data := range actual { + actualBytes = append(actualBytes, data.Bytes()) + } + + assert.Equal(t, testData.expected, actualBytes) + + // Ensure we can successfully decode every message. + for i := 0; i < len(actual); i++ { + msg := actualBytes[i] + typ := messageType(msg[0]) + + switch typ { + case ackRespMsg: + var got ackResp + require.NoError(t, decode(msg[1:], &got)) + + if got.SeqNo == smallMsgSeqNo { + assert.Len(t, got.Payload, smallMsgPayloadLength) + } else if got.SeqNo == bigMsgSeqNo { + assert.Len(t, got.Payload, bigMsgPayloadLength) + } else { + require.Fail(t, "unexpected seq no") + } + case compoundMsg: + trunc, parts, err := decodeCompoundMessage(msg[1:]) + require.NoError(t, err) + require.Equal(t, 0, trunc) + + for _, part := range parts { + require.Equal(t, ackRespMsg, messageType(part[0])) + + var got ackResp + require.NoError(t, decode(part[1:], &got)) + assert.Equal(t, smallMsgSeqNo, got.SeqNo) + assert.Len(t, got.Payload, smallMsgPayloadLength) + } + default: + require.Fail(t, "unexpected message") + } + } + }) + } +}