Skip to content

Commit

Permalink
append to a byte slice instead of a bytes.Buffer when serializing frames
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Aug 28, 2022
1 parent 65dd82a commit 3ca1001
Show file tree
Hide file tree
Showing 50 changed files with 443 additions and 425 deletions.
12 changes: 6 additions & 6 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -561,11 +561,11 @@ var _ = Describe("Connection", func() {
}
Expect(hdr.Write(buf, conn.version)).To(Succeed())
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*wire.Header, time.Time, []byte) (*unpackedPacket, error) {
buf := &bytes.Buffer{}
Expect((&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Write(buf, conn.version)).To(Succeed())
b, err := (&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Write(nil, conn.version)
Expect(err).ToNot(HaveOccurred())
return &unpackedPacket{
hdr: hdr,
data: buf.Bytes(),
data: b,
encryptionLevel: protocol.Encryption1RTT,
}, nil
})
Expand Down Expand Up @@ -754,15 +754,15 @@ var _ = Describe("Connection", func() {
PacketNumberLen: protocol.PacketNumberLen1,
}
rcvTime := time.Now().Add(-10 * time.Second)
buf := &bytes.Buffer{}
Expect((&wire.PingFrame{}).Write(buf, conn.version)).To(Succeed())
b, err := (&wire.PingFrame{}).Write(nil, conn.version)
Expect(err).ToNot(HaveOccurred())
packet := getPacket(hdr, nil)
packet.ecn = protocol.ECT1
unpacker.EXPECT().Unpack(gomock.Any(), rcvTime, gomock.Any()).Return(&unpackedPacket{
packetNumber: 0x1337,
encryptionLevel: protocol.Encryption1RTT,
hdr: hdr,
data: buf.Bytes(),
data: b,
}, nil)
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
gomock.InOrder(
Expand Down
19 changes: 10 additions & 9 deletions fuzzing/frames/cmd/corpus.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package main

import (
"bytes"
"log"
"math/rand"
"time"
Expand Down Expand Up @@ -253,32 +252,34 @@ func getFrames() []wire.Frame {

func main() {
for _, f := range getFrames() {
b := &bytes.Buffer{}
if err := f.Write(b, version); err != nil {
b, err := f.Write(nil, version)
if err != nil {
log.Fatal(err)
}
if err := helper.WriteCorpusFileWithPrefix("corpus", b.Bytes(), 1); err != nil {
if err := helper.WriteCorpusFileWithPrefix("corpus", b, 1); err != nil {
log.Fatal(err)
}
}

for i := 0; i < 30; i++ {
frames := getFrames()

b := &bytes.Buffer{}
var b []byte
for j := 0; j < rand.Intn(30)+2; j++ {
if rand.Intn(10) == 0 { // write a PADDING frame
b.WriteByte(0x0)
b = append(b, 0)
}
f := frames[rand.Intn(len(frames))]
if err := f.Write(b, version); err != nil {
var err error
b, err = f.Write(b, version)
if err != nil {
log.Fatal(err)
}
if rand.Intn(10) == 0 { // write a PADDING frame
b.WriteByte(0x0)
b = append(b, 0)
}
}
if err := helper.WriteCorpusFileWithPrefix("corpus", b.Bytes(), 1); err != nil {
if err := helper.WriteCorpusFileWithPrefix("corpus", b, 1); err != nil {
log.Fatal(err)
}
}
Expand Down
15 changes: 8 additions & 7 deletions fuzzing/frames/fuzz.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ func Fuzz(data []byte) int {
return 0
}

b := &bytes.Buffer{}
var b []byte
for _, f := range frames {
if f == nil { // PADDING frame
b.WriteByte(0x0)
b = append(b, 0)
continue
}
// We accept empty STREAM frames, but we don't write them.
Expand All @@ -68,20 +68,21 @@ func Fuzz(data []byte) int {
continue
}
}
lenBefore := b.Len()
if err := f.Write(b, version); err != nil {
lenBefore := len(b)
b, err := f.Write(b, version)
if err != nil {
panic(fmt.Sprintf("Error writing frame %#v: %s", f, err))
}
frameLen := b.Len() - lenBefore
frameLen := len(b) - lenBefore
if f.Length(version) != protocol.ByteCount(frameLen) {
panic(fmt.Sprintf("Inconsistent frame length for %#v: expected %d, got %d", f, frameLen, f.Length(version)))
}
if sf, ok := f.(*wire.StreamFrame); ok {
sf.PutBack()
}
}
if b.Len() > parsedLen {
panic(fmt.Sprintf("Serialized length (%d) is longer than parsed length (%d)", b.Len(), parsedLen))
if len(b) > parsedLen {
panic(fmt.Sprintf("Serialized length (%d) is longer than parsed length (%d)", len(b), parsedLen))
}
return 1
}
10 changes: 7 additions & 3 deletions internal/testutils/testutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,15 @@ func writePacket(hdr *wire.ExtendedHeader, data []byte) []byte {

// packRawPayload returns a new raw payload containing given frames
func packRawPayload(version protocol.VersionNumber, frames []wire.Frame) []byte {
buf := new(bytes.Buffer)
var b []byte
for _, cf := range frames {
cf.Write(buf, version)
var err error
b, err = cf.Write(b, version)
if err != nil {
panic(err)
}
}
return buf.Bytes()
return b
}

// ComposeInitialPacket returns an Initial packet encrypted under key
Expand Down
28 changes: 14 additions & 14 deletions internal/wire/ack_frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,40 +107,40 @@ func parseAckFrame(r *bytes.Reader, ackDelayExponent uint8, _ protocol.VersionNu
}

// Write writes an ACK frame.
func (f *AckFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
func (f *AckFrame) Write(b []byte, _ protocol.VersionNumber) ([]byte, error) {
hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0
if hasECN {
b.WriteByte(0x3)
b = append(b, 0b11)
} else {
b.WriteByte(0x2)
b = append(b, 0b10)
}
quicvarint.Write(b, uint64(f.LargestAcked()))
quicvarint.Write(b, encodeAckDelay(f.DelayTime))
b = quicvarint.Append(b, uint64(f.LargestAcked()))
b = quicvarint.Append(b, encodeAckDelay(f.DelayTime))

numRanges := f.numEncodableAckRanges()
quicvarint.Write(b, uint64(numRanges-1))
b = quicvarint.Append(b, uint64(numRanges-1))

// write the first range
_, firstRange := f.encodeAckRange(0)
quicvarint.Write(b, firstRange)
b = quicvarint.Append(b, firstRange)

// write all the other range
for i := 1; i < numRanges; i++ {
gap, len := f.encodeAckRange(i)
quicvarint.Write(b, gap)
quicvarint.Write(b, len)
b = quicvarint.Append(b, gap)
b = quicvarint.Append(b, len)
}

if hasECN {
quicvarint.Write(b, f.ECT0)
quicvarint.Write(b, f.ECT1)
quicvarint.Write(b, f.ECNCE)
b = quicvarint.Append(b, f.ECT0)
b = quicvarint.Append(b, f.ECT1)
b = quicvarint.Append(b, f.ECNCE)
}
return nil
return b, nil
}

// Length of a written frame
func (f *AckFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
func (f *AckFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
largestAcked := f.AckRanges[0].Largest
numRanges := f.numEncodableAckRanges()

Expand Down
85 changes: 42 additions & 43 deletions internal/wire/ack_frame_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,15 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() {

It("uses the ack delay exponent", func() {
const delayTime = 1 << 10 * time.Millisecond
buf := &bytes.Buffer{}
f := &AckFrame{
AckRanges: []AckRange{{Smallest: 1, Largest: 1}},
DelayTime: delayTime,
}
Expect(f.Write(buf, protocol.Version1)).To(Succeed())
b, err := f.Write(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
for i := uint8(0); i < 8; i++ {
b := bytes.NewReader(buf.Bytes())
frame, err := parseAckFrame(b, protocol.AckDelayExponent+i, protocol.Version1)
r := bytes.NewReader(b)
frame, err := parseAckFrame(r, protocol.AckDelayExponent+i, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(frame.DelayTime).To(Equal(delayTime * (1 << i)))
}
Expand Down Expand Up @@ -202,29 +202,29 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() {

Context("when writing", func() {
It("writes a simple frame", func() {
buf := &bytes.Buffer{}
f := &AckFrame{
AckRanges: []AckRange{{Smallest: 100, Largest: 1337}},
}
Expect(f.Write(buf, protocol.Version1)).To(Succeed())
b, err := f.Write(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
expected := []byte{0x2}
expected = append(expected, encodeVarInt(1337)...) // largest acked
expected = append(expected, 0) // delay
expected = append(expected, encodeVarInt(0)...) // num ranges
expected = append(expected, encodeVarInt(1337-100)...)
Expect(buf.Bytes()).To(Equal(expected))
Expect(b).To(Equal(expected))
})

It("writes an ACK-ECN frame", func() {
buf := &bytes.Buffer{}
f := &AckFrame{
AckRanges: []AckRange{{Smallest: 10, Largest: 2000}},
ECT0: 13,
ECT1: 37,
ECNCE: 12345,
}
Expect(f.Write(buf, protocol.Version1)).To(Succeed())
Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len()))
b, err := f.Write(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(b).To(HaveLen(int(f.Length(protocol.Version1))))
expected := []byte{0x3}
expected = append(expected, encodeVarInt(2000)...) // largest acked
expected = append(expected, 0) // delay
Expand All @@ -233,63 +233,61 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() {
expected = append(expected, encodeVarInt(13)...)
expected = append(expected, encodeVarInt(37)...)
expected = append(expected, encodeVarInt(12345)...)
Expect(buf.Bytes()).To(Equal(expected))
Expect(b).To(Equal(expected))
})

It("writes a frame that acks a single packet", func() {
buf := &bytes.Buffer{}
f := &AckFrame{
AckRanges: []AckRange{{Smallest: 0x2eadbeef, Largest: 0x2eadbeef}},
DelayTime: 18 * time.Millisecond,
}
Expect(f.Write(buf, protocol.Version1)).To(Succeed())
Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len()))
b := bytes.NewReader(buf.Bytes())
frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1)
b, err := f.Write(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(b).To(HaveLen(int(f.Length(protocol.Version1))))
r := bytes.NewReader(b)
frame, err := parseAckFrame(r, protocol.AckDelayExponent, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(Equal(f))
Expect(frame.HasMissingRanges()).To(BeFalse())
Expect(frame.DelayTime).To(Equal(f.DelayTime))
Expect(b.Len()).To(BeZero())
Expect(r.Len()).To(BeZero())
})

It("writes a frame that acks many packets", func() {
buf := &bytes.Buffer{}
f := &AckFrame{
AckRanges: []AckRange{{Smallest: 0x1337, Largest: 0x2eadbeef}},
}
Expect(f.Write(buf, protocol.Version1)).To(Succeed())
Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len()))
b := bytes.NewReader(buf.Bytes())
frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1)
b, err := f.Write(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(b).To(HaveLen(int(f.Length(protocol.Version1))))
r := bytes.NewReader(b)
frame, err := parseAckFrame(r, protocol.AckDelayExponent, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(Equal(f))
Expect(frame.HasMissingRanges()).To(BeFalse())
Expect(b.Len()).To(BeZero())
Expect(r.Len()).To(BeZero())
})

It("writes a frame with a a single gap", func() {
buf := &bytes.Buffer{}
f := &AckFrame{
AckRanges: []AckRange{
{Smallest: 400, Largest: 1000},
{Smallest: 100, Largest: 200},
},
}
Expect(f.validateAckRanges()).To(BeTrue())
err := f.Write(buf, protocol.Version1)
b, err := f.Write(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len()))
b := bytes.NewReader(buf.Bytes())
frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1)
Expect(b).To(HaveLen(int(f.Length(protocol.Version1))))
r := bytes.NewReader(b)
frame, err := parseAckFrame(r, protocol.AckDelayExponent, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(Equal(f))
Expect(frame.HasMissingRanges()).To(BeTrue())
Expect(b.Len()).To(BeZero())
Expect(r.Len()).To(BeZero())
})

It("writes a frame with multiple ranges", func() {
buf := &bytes.Buffer{}
f := &AckFrame{
AckRanges: []AckRange{
{Smallest: 10, Largest: 10},
Expand All @@ -299,35 +297,36 @@ var _ = Describe("ACK Frame (for IETF QUIC)", func() {
},
}
Expect(f.validateAckRanges()).To(BeTrue())
Expect(f.Write(buf, protocol.Version1)).To(Succeed())
Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len()))
b := bytes.NewReader(buf.Bytes())
frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1)
b, err := f.Write(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(b).To(HaveLen(int(f.Length(protocol.Version1))))
r := bytes.NewReader(b)
frame, err := parseAckFrame(r, protocol.AckDelayExponent, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(frame).To(Equal(f))
Expect(frame.HasMissingRanges()).To(BeTrue())
Expect(b.Len()).To(BeZero())
Expect(r.Len()).To(BeZero())
})

It("limits the maximum size of the ACK frame", func() {
buf := &bytes.Buffer{}
const numRanges = 1000
ackRanges := make([]AckRange, numRanges)
for i := protocol.PacketNumber(1); i <= numRanges; i++ {
ackRanges[numRanges-i] = AckRange{Smallest: 2 * i, Largest: 2 * i}
}
f := &AckFrame{AckRanges: ackRanges}
Expect(f.validateAckRanges()).To(BeTrue())
Expect(f.Write(buf, protocol.Version1)).To(Succeed())
Expect(f.Length(protocol.Version1)).To(BeEquivalentTo(buf.Len()))
b, err := f.Write(nil, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(b).To(HaveLen(int(f.Length(protocol.Version1))))
// make sure the ACK frame is *a little bit* smaller than the MaxAckFrameSize
Expect(buf.Len()).To(BeNumerically(">", protocol.MaxAckFrameSize-5))
Expect(buf.Len()).To(BeNumerically("<=", protocol.MaxAckFrameSize))
b := bytes.NewReader(buf.Bytes())
frame, err := parseAckFrame(b, protocol.AckDelayExponent, protocol.Version1)
Expect(len(b)).To(BeNumerically(">", protocol.MaxAckFrameSize-5))
Expect(len(b)).To(BeNumerically("<=", protocol.MaxAckFrameSize))
r := bytes.NewReader(b)
frame, err := parseAckFrame(r, protocol.AckDelayExponent, protocol.Version1)
Expect(err).ToNot(HaveOccurred())
Expect(frame.HasMissingRanges()).To(BeTrue())
Expect(b.Len()).To(BeZero())
Expect(r.Len()).To(BeZero())
Expect(len(frame.AckRanges)).To(BeNumerically("<", numRanges)) // make sure we dropped some ranges
})
})
Expand Down

0 comments on commit 3ca1001

Please sign in to comment.