Skip to content

Commit

Permalink
avoid allocation a wire.Header when parsing short header packets
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Dec 8, 2020
1 parent 0ab2ef3 commit ff69ca6
Show file tree
Hide file tree
Showing 17 changed files with 516 additions and 562 deletions.
5 changes: 4 additions & 1 deletion example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"io/ioutil"
"log"
"mime/multipart"
"net"
"net/http"
"os"
"strconv"
Expand Down Expand Up @@ -163,7 +164,9 @@ func main() {
}

handler := setupHandler(*www)
quicConf := &quic.Config{}
quicConf := &quic.Config{
AcceptToken: func(clientAddr net.Addr, token *quic.Token) bool { return true },
}
if *enableQlog {
quicConf.Tracer = qlog.NewTracer(func(_ logging.Perspective, connID []byte) io.WriteCloser {
filename := fmt.Sprintf("server_%x.qlog", connID)
Expand Down
5 changes: 4 additions & 1 deletion fuzzing/header/fuzz.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ func Fuzz(data []byte) int {
if err != nil {
return 0
}
hdr, _, _, err := wire.ParsePacket(data, connIDLen)
if !wire.IsLongHeader(data[0]) {
return 0
}
hdr, _, _, err := wire.ParseLongHeaderPacket(data)
if err != nil {
return 0
}
Expand Down
58 changes: 39 additions & 19 deletions integrationtests/self/mitm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,24 +88,35 @@ var _ = Describe("MITM test", func() {

sendRandomPacketsOfSameType := func(conn net.PacketConn, remoteAddr net.Addr, raw []byte) {
defer GinkgoRecover()
hdr, _, _, err := wire.ParsePacket(raw, connIDLen)
Expect(err).ToNot(HaveOccurred())
replyHdr := &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: hdr.IsLongHeader,
DestConnectionID: hdr.DestConnectionID,
SrcConnectionID: hdr.SrcConnectionID,
Type: hdr.Type,
Version: hdr.Version,
},
PacketNumber: protocol.PacketNumber(mrand.Int31n(math.MaxInt32 / 4)),
PacketNumberLen: protocol.PacketNumberLen(mrand.Int31n(4) + 1),
var replyHdr *wire.ExtendedHeader
if wire.IsLongHeader(raw[0]) {
hdr, _, _, err := wire.ParseLongHeaderPacket(raw)
Expect(err).ToNot(HaveOccurred())
replyHdr = &wire.ExtendedHeader{
Header: wire.Header{
IsLongHeader: hdr.IsLongHeader,
DestConnectionID: hdr.DestConnectionID,
SrcConnectionID: hdr.SrcConnectionID,
Type: hdr.Type,
Version: hdr.Version,
},
PacketNumber: protocol.PacketNumber(mrand.Int31n(math.MaxInt32 / 4)),
PacketNumberLen: protocol.PacketNumberLen(mrand.Int31n(4) + 1),
}
} else {
destConnID, err := wire.ParseConnectionID(raw, connIDLen)
Expect(err).ToNot(HaveOccurred())
replyHdr = &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: destConnID},
}
}

const numPackets = 10
ticker := time.NewTicker(rtt / numPackets)
for i := 0; i < numPackets; i++ {
payloadLen := mrand.Int31n(100)
replyHdr.PacketNumber = protocol.PacketNumber(mrand.Int31())
replyHdr.PacketNumberLen = protocol.PacketNumberLen(mrand.Int31n(4) + 1)
replyHdr.Length = protocol.ByteCount(mrand.Int31n(payloadLen + 1))
buf := &bytes.Buffer{}
Expect(replyHdr.Write(buf, version)).To(Succeed())
Expand Down Expand Up @@ -350,9 +361,11 @@ var _ = Describe("MITM test", func() {
if dir == quicproxy.DirectionIncoming {
defer GinkgoRecover()

hdr, _, _, err := wire.ParsePacket(raw, connIDLen)
if !wire.IsLongHeader(raw[0]) {
return 0
}
hdr, _, _, err := wire.ParseLongHeaderPacket(raw)
Expect(err).ToNot(HaveOccurred())

if hdr.Type != protocol.PacketTypeInitial {
return 0
}
Expand All @@ -375,9 +388,11 @@ var _ = Describe("MITM test", func() {
if dir == quicproxy.DirectionIncoming && !initialPacketIntercepted {
defer GinkgoRecover()

hdr, _, _, err := wire.ParsePacket(raw, connIDLen)
if !wire.IsLongHeader(raw[0]) {
return 0
}
hdr, _, _, err := wire.ParseLongHeaderPacket(raw)
Expect(err).ToNot(HaveOccurred())

if hdr.Type != protocol.PacketTypeInitial {
return 0
}
Expand All @@ -400,9 +415,11 @@ var _ = Describe("MITM test", func() {
if dir == quicproxy.DirectionIncoming {
defer GinkgoRecover()

hdr, _, _, err := wire.ParsePacket(raw, connIDLen)
if !wire.IsLongHeader(raw[0]) {
return 0
}
hdr, _, _, err := wire.ParseLongHeaderPacket(raw)
Expect(err).ToNot(HaveOccurred())

if hdr.Type != protocol.PacketTypeInitial {
return 0
}
Expand All @@ -421,7 +438,10 @@ var _ = Describe("MITM test", func() {
clientAddr := clientConn.LocalAddr()
delayCb := func(dir quicproxy.Direction, raw []byte) time.Duration {
if dir == quicproxy.DirectionIncoming {
hdr, _, _, err := wire.ParsePacket(raw, connIDLen)
if !wire.IsLongHeader(raw[0]) {
return 0
}
hdr, _, _, err := wire.ParseLongHeaderPacket(raw)
Expect(err).ToNot(HaveOccurred())
if hdr.Type != protocol.PacketTypeInitial {
return 0
Expand Down
43 changes: 26 additions & 17 deletions integrationtests/self/zero_rtt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ var _ = Describe("0-RTT", func() {
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration {
hdr, _, _, err := wire.ParsePacket(data, 0)
Expect(err).ToNot(HaveOccurred())
if hdr.Type == protocol.PacketType0RTT {
atomic.AddUint32(&num0RTTPackets, 1)
if wire.IsLongHeader(data[0]) {
hdr, _, _, err := wire.ParseLongHeaderPacket(data)
Expect(err).ToNot(HaveOccurred())
if hdr.Type == protocol.PacketType0RTT {
atomic.AddUint32(&num0RTTPackets, 1)
}
}
return rtt / 2
},
Expand Down Expand Up @@ -222,23 +224,27 @@ var _ = Describe("0-RTT", func() {
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: fmt.Sprintf("localhost:%d", serverPort),
DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration {
hdr, _, _, err := wire.ParsePacket(data, 0)
Expect(err).ToNot(HaveOccurred())
if hdr.Type == protocol.PacketType0RTT {
atomic.AddUint32(&num0RTTPackets, 1)
if wire.IsLongHeader(data[0]) {
hdr, _, _, err := wire.ParseLongHeaderPacket(data)
Expect(err).ToNot(HaveOccurred())
if hdr.Type == protocol.PacketType0RTT {
atomic.AddUint32(&num0RTTPackets, 1)
}
}
return rtt / 2
},
DropPacket: func(_ quicproxy.Direction, data []byte) bool {
hdr, _, _, err := wire.ParsePacket(data, 0)
Expect(err).ToNot(HaveOccurred())
if hdr.Type == protocol.PacketType0RTT {
// drop 25% of the 0-RTT packets
drop := mrand.Intn(4) == 0
if drop {
atomic.AddUint32(&num0RTTDropped, 1)
if wire.IsLongHeader(data[0]) {
hdr, _, _, err := wire.ParseLongHeaderPacket(data)
Expect(err).ToNot(HaveOccurred())
if hdr.Type == protocol.PacketType0RTT {
// drop 25% of the 0-RTT packets
drop := mrand.Intn(4) == 0
if drop {
atomic.AddUint32(&num0RTTDropped, 1)
}
return drop
}
return drop
}
return false
},
Expand Down Expand Up @@ -272,7 +278,10 @@ var _ = Describe("0-RTT", func() {

countZeroRTTBytes := func(data []byte) (n protocol.ByteCount) {
for len(data) > 0 {
hdr, _, rest, err := wire.ParsePacket(data, 0)
if !wire.IsLongHeader(data[0]) {
return
}
hdr, _, rest, err := wire.ParseLongHeaderPacket(data)
if err != nil {
return
}
Expand Down
3 changes: 2 additions & 1 deletion integrationtests/tools/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ var _ = Describe("QUIC Proxy", func() {
}

readPacketNumber := func(b []byte) protocol.PacketNumber {
hdr, data, _, err := wire.ParsePacket(b, 0)
Expect(wire.IsLongHeader(b[0])).To(BeTrue())
hdr, data, _, err := wire.ParseLongHeaderPacket(b)
ExpectWithOffset(1, err).ToNot(HaveOccurred())
Expect(hdr.Type).To(Equal(protocol.PacketTypeInitial))
extHdr, err := hdr.ParseExtended(bytes.NewReader(data), protocol.VersionTLS)
Expand Down
42 changes: 11 additions & 31 deletions internal/wire/extended_header.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ type ExtendedHeader struct {

typeByte byte

KeyPhase protocol.KeyPhaseBit
KeyPhase protocol.KeyPhaseBit // TODO: remove. Unused for unpacking. Only used for packing and logging.

PacketNumberLen protocol.PacketNumberLen
PacketNumber protocol.PacketNumber
Expand All @@ -93,12 +93,7 @@ func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool
if _, err := b.Seek(int64(h.Header.ParsedLen())-1, io.SeekCurrent); err != nil {
return false, err
}
var reservedBitsValid bool
if h.IsLongHeader {
reservedBitsValid, err = h.parseLongHeader(b, v)
} else {
reservedBitsValid, err = h.parseShortHeader(b, v)
}
reservedBitsValid, err := h.parseLongHeader(b, v)
if err != nil {
return false, err
}
Expand All @@ -119,21 +114,6 @@ func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumb
return true, nil
}

func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) {
h.KeyPhase = ReadKeyPhaseBit(h.typeByte)

pn, pnLen, err := ReadPacketNumber(b, h.typeByte)
if err != nil {
return false, err
}
h.PacketNumber = pn
h.PacketNumberLen = pnLen
if !CheckShortHeaderReservedBits(h.typeByte) {
return false, nil
}
return true, nil
}

// Write writes the Header.
func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error {
if h.DestConnectionID.Len() > protocol.MaxConnIDLen {
Expand Down Expand Up @@ -184,7 +164,7 @@ func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, _ protocol.VersionNumb
b.Write(h.Token)
}
utils.WriteVarIntWithLen(b, uint64(h.Length), 2)
return h.writePacketNumber(b)
return writePacketNumber(b, h.PacketNumber, h.PacketNumberLen)
}

func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, _ protocol.VersionNumber) error {
Expand All @@ -195,21 +175,21 @@ func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, _ protocol.VersionNum

b.WriteByte(typeByte)
b.Write(h.DestConnectionID.Bytes())
return h.writePacketNumber(b)
return writePacketNumber(b, h.PacketNumber, h.PacketNumberLen)
}

func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error {
switch h.PacketNumberLen {
func writePacketNumber(b *bytes.Buffer, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen) error {
switch pnLen {
case protocol.PacketNumberLen1:
b.WriteByte(uint8(h.PacketNumber))
b.WriteByte(uint8(pn))
case protocol.PacketNumberLen2:
utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
utils.BigEndian.WriteUint16(b, uint16(pn))
case protocol.PacketNumberLen3:
utils.BigEndian.WriteUint24(b, uint32(h.PacketNumber))
utils.BigEndian.WriteUint24(b, uint32(pn))
case protocol.PacketNumberLen4:
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
utils.BigEndian.WriteUint32(b, uint32(pn))
default:
return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
return fmt.Errorf("invalid packet number length: %d", pnLen)
}
return nil
}
Expand Down
59 changes: 16 additions & 43 deletions internal/wire/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ var ErrUnsupportedVersion = errors.New("unsupported version")

// The Header is the version independent part of the header
type Header struct {
IsLongHeader bool
IsLongHeader bool // TODO: remove. Currently only needed for logging.
typeByte byte
Type protocol.PacketType

Expand All @@ -66,72 +66,45 @@ type Header struct {
parsedLen protocol.ByteCount // how many bytes were read while parsing this header
}

// ParsePacket parses a packet.
// ParseLongHeaderPacket parses a long header packet.
// If the packet has a long header, the packet is cut according to the length field.
// If we understand the version, the packet is header up unto the packet number.
// Otherwise, only the invariant part of the header is parsed.
func ParsePacket(data []byte, shortHeaderConnIDLen int) (*Header, []byte /* packet data */, []byte /* rest */, error) {
hdr, err := parseHeader(bytes.NewReader(data), shortHeaderConnIDLen)
func ParseLongHeaderPacket(data []byte) (*Header, []byte /* packet data */, []byte /* rest */, error) {
hdr, err := parseLongHeader(bytes.NewReader(data))
if err != nil {
if err == ErrUnsupportedVersion {
return hdr, nil, nil, ErrUnsupportedVersion
}
return nil, nil, nil, err
}
var rest []byte
if hdr.IsLongHeader {
if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length {
return nil, nil, nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length)
}
packetLen := int(hdr.ParsedLen() + hdr.Length)
rest = data[packetLen:]
data = data[:packetLen]
if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length {
return nil, nil, nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length)
}
packetLen := int(hdr.ParsedLen() + hdr.Length)
rest := data[packetLen:]
data = data[:packetLen]
return hdr, data, rest, nil
}

// ParseHeader parses the header.
// For short header packets: up to the packet number.
// For long header packets:
// parseHeader parses the Long Header.
// * if we understand the version: up to the packet number
// * if not, only the invariant part of the header
func parseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
func parseLongHeader(b *bytes.Reader) (*Header, error) {
startLen := b.Len()
h, err := parseHeaderImpl(b, shortHeaderConnIDLen)
if err != nil {
return h, err
}
h.parsedLen = protocol.ByteCount(startLen - b.Len())
return h, err
}

func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
typeByte, err := b.ReadByte()
if err != nil {
return nil, err
}

h := &Header{
typeByte: typeByte,
IsLongHeader: typeByte&0x80 > 0,
IsLongHeader: true,
}

if !h.IsLongHeader {
if h.typeByte&0x40 == 0 {
return nil, errors.New("not a QUIC packet")
}
if err := h.parseShortHeader(b, shortHeaderConnIDLen); err != nil {
return nil, err
}
return h, nil
if err := h.parseLongHeader(b); err != nil {
return h, err
}
return h, h.parseLongHeader(b)
}

func (h *Header) parseShortHeader(b *bytes.Reader, shortHeaderConnIDLen int) error {
var err error
h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen)
return err
h.parsedLen = protocol.ByteCount(startLen - b.Len())
return h, err
}

func (h *Header) parseLongHeader(b *bytes.Reader) error {
Expand Down

0 comments on commit ff69ca6

Please sign in to comment.