Skip to content

Commit

Permalink
add a function to distinguish between long and short header packets (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Aug 11, 2022
1 parent bebff46 commit 80fd1b5
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 11 deletions.
3 changes: 2 additions & 1 deletion integrationtests/self/datagram_test.go
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/lucas-clemente/quic-go"
quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
Expand Down Expand Up @@ -73,7 +74,7 @@ var _ = Describe("Datagram test", func() {
return false
}
// don't drop Long Header packets
if packet[0]&0x80 == 1 {
if wire.IsLongHeaderPacket(packet[0]) {
return false
}
drop := mrand.Int()%10 == 0
Expand Down
2 changes: 1 addition & 1 deletion integrationtests/self/zero_rtt_test.go
Expand Up @@ -732,7 +732,7 @@ var _ = Describe("0-RTT", func() {
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: ln.Addr().String(),
DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration {
if dir == quicproxy.DirectionIncoming && data[0]&0x80 > 0 && data[0]&0x30>>4 == 0 { // Initial packet from client
if dir == quicproxy.DirectionIncoming && wire.IsLongHeaderPacket(data[0]) && data[0]&0x30>>4 == 0 { // Initial packet from client
return rtt/2 + rtt
}
return rtt / 2
Expand Down
14 changes: 9 additions & 5 deletions internal/wire/header.go
Expand Up @@ -19,8 +19,7 @@ func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.Connecti
if len(data) == 0 {
return nil, io.EOF
}
isLongHeader := data[0]&0x80 > 0
if !isLongHeader {
if !IsLongHeaderPacket(data[0]) {
if len(data) < shortHeaderConnIDLen+1 {
return nil, io.EOF
}
Expand All @@ -36,12 +35,17 @@ func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.Connecti
return protocol.ConnectionID(data[6 : 6+destConnIDLen]), nil
}

// IsLongHeaderPacket says if this is a Long Header packet
func IsLongHeaderPacket(firstByte byte) bool {
return firstByte&0x80 > 0
}

// IsVersionNegotiationPacket says if this is a version negotiation packet
func IsVersionNegotiationPacket(b []byte) bool {
if len(b) < 5 {
return false
}
return b[0]&0x80 > 0 && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0
return IsLongHeaderPacket(b[0]) && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0
}

// Is0RTTPacket says if this is a 0-RTT packet.
Expand All @@ -50,7 +54,7 @@ func Is0RTTPacket(b []byte) bool {
if len(b) < 5 {
return false
}
if b[0]&0x80 == 0 {
if !IsLongHeaderPacket(b[0]) {
return false
}
version := protocol.VersionNumber(binary.BigEndian.Uint32(b[1:5]))
Expand Down Expand Up @@ -129,7 +133,7 @@ func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error)

h := &Header{
typeByte: typeByte,
IsLongHeader: typeByte&0x80 > 0,
IsLongHeader: IsLongHeaderPacket(typeByte),
}

if !h.IsLongHeader {
Expand Down
5 changes: 5 additions & 0 deletions internal/wire/header_test.go
Expand Up @@ -576,6 +576,11 @@ var _ = Describe("Header Parsing", func() {
})
})

It("distinguishes long and short header packets", func() {
Expect(IsLongHeaderPacket(0x40)).To(BeFalse())
Expect(IsLongHeaderPacket(0x80 ^ 0x40 ^ 0x12)).To(BeTrue())
})

It("tells its packet type for logging", func() {
Expect((&Header{IsLongHeader: true, Type: protocol.PacketTypeHandshake}).PacketType()).To(Equal("Handshake"))
Expect((&Header{}).PacketType()).To(Equal("1-RTT"))
Expand Down
2 changes: 1 addition & 1 deletion internal/wire/version_negotiation_test.go
Expand Up @@ -56,7 +56,7 @@ var _ = Describe("Version Negotiation Packets", func() {
destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
versions := []protocol.VersionNumber{1001, 1003}
data := ComposeVersionNegotiation(destConnID, srcConnID, versions)
Expect(data[0] & 0x80).ToNot(BeZero())
Expect(IsLongHeaderPacket(data[0])).To(BeTrue())
hdr, supportedVersions, err := ParseVersionNegotiationPacket(bytes.NewReader(data))
Expect(err).ToNot(HaveOccurred())
Expect(hdr.DestConnectionID).To(Equal(destConnID))
Expand Down
4 changes: 2 additions & 2 deletions packet_handler_map.go
Expand Up @@ -390,7 +390,7 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) {
return
}
}
if p.data[0]&0x80 == 0 {
if !wire.IsLongHeaderPacket(p.data[0]) {
go h.maybeSendStatelessReset(p, connID)
return
}
Expand Down Expand Up @@ -433,7 +433,7 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) {

func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool {
// stateless resets are always short header packets
if data[0]&0x80 != 0 {
if wire.IsLongHeaderPacket(data[0]) {
return false
}
if len(data) < 17 /* type byte + 16 bytes for the reset token */ {
Expand Down
2 changes: 1 addition & 1 deletion packet_handler_map_test.go
Expand Up @@ -453,7 +453,7 @@ var _ = Describe("Packet Handler Map", func() {
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), addr).Do(func(b []byte, _ net.Addr) {
defer close(done)
Expect(b[0] & 0x80).To(BeZero()) // short header packet
Expect(wire.IsLongHeaderPacket(b[0])).To(BeFalse()) // short header packet
Expect(b).To(HaveLen(protocol.MinStatelessResetSize))
})
handler.handlePacket(&receivedPacket{
Expand Down

0 comments on commit 80fd1b5

Please sign in to comment.