From 015ae6654b817b2f095e0f3daf93cced9d8d2bb5 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 10 Aug 2022 10:27:30 +0200 Subject: [PATCH] add a function to distinguish between long and short header packets --- integrationtests/self/datagram_test.go | 3 ++- integrationtests/self/zero_rtt_test.go | 2 +- internal/wire/header.go | 14 +++++++++----- internal/wire/header_test.go | 5 +++++ internal/wire/version_negotiation_test.go | 2 +- packet_handler_map.go | 4 ++-- packet_handler_map_test.go | 2 +- 7 files changed, 21 insertions(+), 11 deletions(-) diff --git a/integrationtests/self/datagram_test.go b/integrationtests/self/datagram_test.go index 12564f6ec0d..3dd47ac3ee1 100644 --- a/integrationtests/self/datagram_test.go +++ b/integrationtests/self/datagram_test.go @@ -13,6 +13,7 @@ import ( "github.com/lucas-clemente/quic-go" quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy" "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -73,7 +74,7 @@ var _ = Describe("Datagram test", func() { return false } // don't drop Long Header packets - if packet[0]&0x80 == 1 { + if wire.IsLongHeaderPacket(packet[0]) { return false } drop := mrand.Int()%10 == 0 diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 274a9183932..39be9eadbbb 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -732,7 +732,7 @@ var _ = Describe("0-RTT", func() { proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ RemoteAddr: ln.Addr().String(), DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration { - if dir == quicproxy.DirectionIncoming && data[0]&0x80 > 0 && data[0]&0x30>>4 == 0 { // Initial packet from client + if dir == quicproxy.DirectionIncoming && wire.IsLongHeaderPacket(data[0]) && data[0]&0x30>>4 == 0 { // Initial packet from client return rtt/2 + rtt } return rtt / 2 diff --git a/internal/wire/header.go b/internal/wire/header.go index f6a31ee0ec4..a01f40ca45d 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -19,8 +19,7 @@ func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.Connecti if len(data) == 0 { return nil, io.EOF } - isLongHeader := data[0]&0x80 > 0 - if !isLongHeader { + if !IsLongHeaderPacket(data[0]) { if len(data) < shortHeaderConnIDLen+1 { return nil, io.EOF } @@ -36,12 +35,17 @@ func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.Connecti return protocol.ConnectionID(data[6 : 6+destConnIDLen]), nil } +// IsLongHeaderPacket says if this is a Long Header packet +func IsLongHeaderPacket(firstByte byte) bool { + return firstByte&0x80 > 0 +} + // IsVersionNegotiationPacket says if this is a version negotiation packet func IsVersionNegotiationPacket(b []byte) bool { if len(b) < 5 { return false } - return b[0]&0x80 > 0 && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0 + return IsLongHeaderPacket(b[0]) && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0 } // Is0RTTPacket says if this is a 0-RTT packet. @@ -50,7 +54,7 @@ func Is0RTTPacket(b []byte) bool { if len(b) < 5 { return false } - if b[0]&0x80 == 0 { + if !IsLongHeaderPacket(b[0]) { return false } version := protocol.VersionNumber(binary.BigEndian.Uint32(b[1:5])) @@ -129,7 +133,7 @@ func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) h := &Header{ typeByte: typeByte, - IsLongHeader: typeByte&0x80 > 0, + IsLongHeader: IsLongHeaderPacket(typeByte), } if !h.IsLongHeader { diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index 77b196e3603..de7d26730ae 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -576,6 +576,11 @@ var _ = Describe("Header Parsing", func() { }) }) + It("distinguishes long and short header packets", func() { + Expect(IsLongHeaderPacket(0x40)).To(BeFalse()) + Expect(IsLongHeaderPacket(0x80 ^ 0x40 ^ 0x12)).To(BeTrue()) + }) + It("tells its packet type for logging", func() { Expect((&Header{IsLongHeader: true, Type: protocol.PacketTypeHandshake}).PacketType()).To(Equal("Handshake")) Expect((&Header{}).PacketType()).To(Equal("1-RTT")) diff --git a/internal/wire/version_negotiation_test.go b/internal/wire/version_negotiation_test.go index 31ad5d93f86..2783cb1765a 100644 --- a/internal/wire/version_negotiation_test.go +++ b/internal/wire/version_negotiation_test.go @@ -56,7 +56,7 @@ var _ = Describe("Version Negotiation Packets", func() { destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.VersionNumber{1001, 1003} data := ComposeVersionNegotiation(destConnID, srcConnID, versions) - Expect(data[0] & 0x80).ToNot(BeZero()) + Expect(IsLongHeaderPacket(data[0])).To(BeTrue()) hdr, supportedVersions, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) Expect(err).ToNot(HaveOccurred()) Expect(hdr.DestConnectionID).To(Equal(destConnID)) diff --git a/packet_handler_map.go b/packet_handler_map.go index 2d55a95ef86..ff7bd7b74b9 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -390,7 +390,7 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) { return } } - if p.data[0]&0x80 == 0 { + if !wire.IsLongHeaderPacket(p.data[0]) { go h.maybeSendStatelessReset(p, connID) return } @@ -433,7 +433,7 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) { func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool { // stateless resets are always short header packets - if data[0]&0x80 != 0 { + if wire.IsLongHeaderPacket(data[0]) { return false } if len(data) < 17 /* type byte + 16 bytes for the reset token */ { diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index d678d6dbf82..63f8e853d95 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -453,7 +453,7 @@ var _ = Describe("Packet Handler Map", func() { done := make(chan struct{}) conn.EXPECT().WriteTo(gomock.Any(), addr).Do(func(b []byte, _ net.Addr) { defer close(done) - Expect(b[0] & 0x80).To(BeZero()) // short header packet + Expect(wire.IsLongHeaderPacket(b[0])).To(BeFalse()) // short header packet Expect(b).To(HaveLen(protocol.MinStatelessResetSize)) }) handler.handlePacket(&receivedPacket{