Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add a function to distinguish between long and short header packets #3498

Merged
merged 1 commit into from Aug 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion integrationtests/self/datagram_test.go
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/lucas-clemente/quic-go"
quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
Expand Down Expand Up @@ -73,7 +74,7 @@ var _ = Describe("Datagram test", func() {
return false
}
// don't drop Long Header packets
if packet[0]&0x80 == 1 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fwiw, the old version here was incorrect.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, this is selecting long header packets?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now I see it. It would have to be &0x80 == 0x80. Nice one.

if wire.IsLongHeaderPacket(packet[0]) {
return false
}
drop := mrand.Int()%10 == 0
Expand Down
2 changes: 1 addition & 1 deletion integrationtests/self/zero_rtt_test.go
Expand Up @@ -732,7 +732,7 @@ var _ = Describe("0-RTT", func() {
proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{
RemoteAddr: ln.Addr().String(),
DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration {
if dir == quicproxy.DirectionIncoming && data[0]&0x80 > 0 && data[0]&0x30>>4 == 0 { // Initial packet from client
if dir == quicproxy.DirectionIncoming && wire.IsLongHeaderPacket(data[0]) && data[0]&0x30>>4 == 0 { // Initial packet from client
return rtt/2 + rtt
}
return rtt / 2
Expand Down
14 changes: 9 additions & 5 deletions internal/wire/header.go
Expand Up @@ -19,8 +19,7 @@ func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.Connecti
if len(data) == 0 {
return nil, io.EOF
}
isLongHeader := data[0]&0x80 > 0
if !isLongHeader {
if !IsLongHeaderPacket(data[0]) {
if len(data) < shortHeaderConnIDLen+1 {
return nil, io.EOF
}
Expand All @@ -36,12 +35,17 @@ func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.Connecti
return protocol.ConnectionID(data[6 : 6+destConnIDLen]), nil
}

// IsLongHeaderPacket says if this is a Long Header packet
func IsLongHeaderPacket(firstByte byte) bool {
return firstByte&0x80 > 0
}

// IsVersionNegotiationPacket says if this is a version negotiation packet
func IsVersionNegotiationPacket(b []byte) bool {
if len(b) < 5 {
return false
}
return b[0]&0x80 > 0 && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0
return IsLongHeaderPacket(b[0]) && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0
}

// Is0RTTPacket says if this is a 0-RTT packet.
Expand All @@ -50,7 +54,7 @@ func Is0RTTPacket(b []byte) bool {
if len(b) < 5 {
return false
}
if b[0]&0x80 == 0 {
if !IsLongHeaderPacket(b[0]) {
return false
}
version := protocol.VersionNumber(binary.BigEndian.Uint32(b[1:5]))
Expand Down Expand Up @@ -129,7 +133,7 @@ func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error)

h := &Header{
typeByte: typeByte,
IsLongHeader: typeByte&0x80 > 0,
IsLongHeader: IsLongHeaderPacket(typeByte),
}

if !h.IsLongHeader {
Expand Down
5 changes: 5 additions & 0 deletions internal/wire/header_test.go
Expand Up @@ -576,6 +576,11 @@ var _ = Describe("Header Parsing", func() {
})
})

It("distinguishes long and short header packets", func() {
Expect(IsLongHeaderPacket(0x40)).To(BeFalse())
Expect(IsLongHeaderPacket(0x80 ^ 0x40 ^ 0x12)).To(BeTrue())
})

It("tells its packet type for logging", func() {
Expect((&Header{IsLongHeader: true, Type: protocol.PacketTypeHandshake}).PacketType()).To(Equal("Handshake"))
Expect((&Header{}).PacketType()).To(Equal("1-RTT"))
Expand Down
2 changes: 1 addition & 1 deletion internal/wire/version_negotiation_test.go
Expand Up @@ -56,7 +56,7 @@ var _ = Describe("Version Negotiation Packets", func() {
destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
versions := []protocol.VersionNumber{1001, 1003}
data := ComposeVersionNegotiation(destConnID, srcConnID, versions)
Expect(data[0] & 0x80).ToNot(BeZero())
Expect(IsLongHeaderPacket(data[0])).To(BeTrue())
hdr, supportedVersions, err := ParseVersionNegotiationPacket(bytes.NewReader(data))
Expect(err).ToNot(HaveOccurred())
Expect(hdr.DestConnectionID).To(Equal(destConnID))
Expand Down
4 changes: 2 additions & 2 deletions packet_handler_map.go
Expand Up @@ -390,7 +390,7 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) {
return
}
}
if p.data[0]&0x80 == 0 {
if !wire.IsLongHeaderPacket(p.data[0]) {
go h.maybeSendStatelessReset(p, connID)
return
}
Expand Down Expand Up @@ -433,7 +433,7 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) {

func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool {
// stateless resets are always short header packets
if data[0]&0x80 != 0 {
if wire.IsLongHeaderPacket(data[0]) {
return false
}
if len(data) < 17 /* type byte + 16 bytes for the reset token */ {
Expand Down
2 changes: 1 addition & 1 deletion packet_handler_map_test.go
Expand Up @@ -453,7 +453,7 @@ var _ = Describe("Packet Handler Map", func() {
done := make(chan struct{})
conn.EXPECT().WriteTo(gomock.Any(), addr).Do(func(b []byte, _ net.Addr) {
defer close(done)
Expect(b[0] & 0x80).To(BeZero()) // short header packet
Expect(wire.IsLongHeaderPacket(b[0])).To(BeFalse()) // short header packet
Expect(b).To(HaveLen(protocol.MinStatelessResetSize))
})
handler.handlePacket(&receivedPacket{
Expand Down