Skip to content

Commit

Permalink
remove the wire.ShortHeader in favor of more return values (#3535)
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Aug 30, 2022
1 parent 5cd5341 commit 656f3d2
Show file tree
Hide file tree
Showing 12 changed files with 125 additions and 131 deletions.
54 changes: 36 additions & 18 deletions connection.go
Expand Up @@ -26,7 +26,7 @@ import (

type unpacker interface {
UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error)
UnpackShortHeader(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error)
UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)
}

type streamGetter interface {
Expand Down Expand Up @@ -856,11 +856,13 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
data := rp.data
p := rp
for len(data) > 0 {
var destConnID protocol.ConnectionID
if counter > 0 {
p = p.Clone()
p.data = data

destConnID, err := wire.ParseConnectionID(p.data, s.srcConnIDLen)
var err error
destConnID, err = wire.ParseConnectionID(p.data, s.srcConnIDLen)
if err != nil {
if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError)
Expand Down Expand Up @@ -920,7 +922,7 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
if counter > 0 {
p.buffer.Split()
}
processed = s.handleShortHeaderPacket(p)
processed = s.handleShortHeaderPacket(p, destConnID)
break
}
}
Expand All @@ -929,7 +931,7 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
return processed
}

func (s *connection) handleShortHeaderPacket(p *receivedPacket) bool {
func (s *connection) handleShortHeaderPacket(p *receivedPacket, destConnID protocol.ConnectionID) bool {
var wasQueued bool

defer func() {
Expand All @@ -939,26 +941,41 @@ func (s *connection) handleShortHeaderPacket(p *receivedPacket) bool {
}
}()

hdr, data, err := s.unpacker.UnpackShortHeader(p.rcvTime, p.data)
pn, pnLen, keyPhase, data, err := s.unpacker.UnpackShortHeader(p.rcvTime, p.data)
if err != nil {
wasQueued = s.handleUnpackError(err, p, logging.PacketType1RTT)
return false
}

if s.logger.Debug() {
s.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, 1-RTT", hdr.PacketNumber, p.Size(), hdr.DestConnectionID)
hdr.Log(s.logger)
s.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, 1-RTT", pn, p.Size(), destConnID)
wire.LogShortHeader(s.logger, destConnID, pn, pnLen, keyPhase)
}

if s.receivedPacketHandler.IsPotentiallyDuplicate(hdr.PacketNumber, protocol.Encryption1RTT) {
if s.receivedPacketHandler.IsPotentiallyDuplicate(pn, protocol.Encryption1RTT) {
s.logger.Debugf("Dropping (potentially) duplicate packet.")
if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketType1RTT, p.Size(), logging.PacketDropDuplicate)
}
return false
}

if err := s.handleUnpackedShortHeaderPacket(hdr, data, p.ecn, p.rcvTime, p.Size()); err != nil {
var log func([]logging.Frame)
if s.tracer != nil {
log = func(frames []logging.Frame) {
s.tracer.ReceivedShortHeaderPacket(
&logging.ShortHeader{
DestConnectionID: destConnID,
PacketNumber: pn,
PacketNumberLen: pnLen,
KeyPhase: keyPhase,
},
p.Size(),
frames,
)
}
}
if err := s.handleUnpackedShortHeaderPacket(destConnID, pn, data, p.ecn, p.rcvTime, log); err != nil {
s.closeLocal(err)
return false
}
Expand Down Expand Up @@ -1241,22 +1258,23 @@ func (s *connection) handleUnpackedPacket(
return s.receivedPacketHandler.ReceivedPacket(packet.hdr.PacketNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting)
}

func (s *connection) handleUnpackedShortHeaderPacket(hdr *wire.ShortHeader, data []byte, ecn protocol.ECN, rcvTime time.Time, packetSize protocol.ByteCount) error {
func (s *connection) handleUnpackedShortHeaderPacket(
destConnID protocol.ConnectionID,
pn protocol.PacketNumber,
data []byte,
ecn protocol.ECN,
rcvTime time.Time,
log func([]logging.Frame),
) error {
s.lastPacketReceivedTime = rcvTime
s.firstAckElicitingPacketAfterIdleSentTime = time.Time{}
s.keepAlivePingSent = false

var log func([]logging.Frame)
if s.tracer != nil {
log = func(frames []logging.Frame) {
s.tracer.ReceivedShortHeaderPacket(hdr, packetSize, frames)
}
}
isAckEliciting, err := s.handleFrames(data, hdr.DestConnectionID, protocol.Encryption1RTT, log)
isAckEliciting, err := s.handleFrames(data, destConnID, protocol.Encryption1RTT, log)
if err != nil {
return err
}
return s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, ecn, protocol.Encryption1RTT, rcvTime, isAckEliciting)
return s.receivedPacketHandler.ReceivedPacket(pn, ecn, protocol.Encryption1RTT, rcvTime, isAckEliciting)
}

func (s *connection) handleFrames(
Expand Down
31 changes: 15 additions & 16 deletions connection_test.go
Expand Up @@ -562,10 +562,10 @@ var _ = Describe("Connection", func() {
}
Expect(hdr.Write(buf, conn.version)).To(Succeed())

unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(time.Time, []byte) (*wire.ShortHeader, []byte, error) {
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(time.Time, []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
b, err := (&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Append(nil, conn.version)
Expect(err).ToNot(HaveOccurred())
return &wire.ShortHeader{PacketNumber: 3}, b, nil
return 3, protocol.PacketNumberLen2, protocol.KeyPhaseOne, b, nil
})
gomock.InOrder(
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()),
Expand Down Expand Up @@ -766,15 +766,15 @@ var _ = Describe("Connection", func() {
Expect(err).ToNot(HaveOccurred())
packet := getPacket(hdr, nil)
packet.ecn = protocol.ECT1
unpacker.EXPECT().UnpackShortHeader(rcvTime, gomock.Any()).Return(&wire.ShortHeader{PacketNumber: 0x1337}, b, nil)
unpacker.EXPECT().UnpackShortHeader(rcvTime, gomock.Any()).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2, protocol.KeyPhaseZero, b, nil)
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
gomock.InOrder(
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT),
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECT1, protocol.Encryption1RTT, rcvTime, true),
)
conn.receivedPacketHandler = rph
packet.rcvTime = rcvTime
tracer.EXPECT().ReceivedShortHeaderPacket(&wire.ShortHeader{PacketNumber: 0x1337}, protocol.ByteCount(len(packet.data)), []logging.Frame{&logging.PingFrame{}})
tracer.EXPECT().ReceivedShortHeaderPacket(&logging.ShortHeader{PacketNumber: 0x1337, PacketNumberLen: 2, KeyPhase: protocol.KeyPhaseZero}, protocol.ByteCount(len(packet.data)), []logging.Frame{&logging.PingFrame{}})
Expect(conn.handlePacketImpl(packet)).To(BeTrue())
})

Expand All @@ -785,7 +785,7 @@ var _ = Describe("Connection", func() {
PacketNumberLen: protocol.PacketNumberLen1,
}
packet := getPacket(hdr, nil)
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(&wire.ShortHeader{PacketNumber: 0x1337}, []byte("foobar"), nil)
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2, protocol.KeyPhaseOne, []byte("foobar"), nil)
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT).Return(true)
conn.receivedPacketHandler = rph
Expand Down Expand Up @@ -829,11 +829,11 @@ var _ = Describe("Connection", func() {
It("processes multiple received packets before sending one", func() {
conn.creationTime = time.Now()
var pn protocol.PacketNumber
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) {
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
pn++
return &wire.ShortHeader{PacketNumber: pn}, []byte{0} /* PADDING frame */, nil
return pn, protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* PADDING frame */, nil
}).Times(3)
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) {
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *logging.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) {
}).Times(3)
packer.EXPECT().PackCoalescedPacket() // only expect a single call

Expand Down Expand Up @@ -868,11 +868,11 @@ var _ = Describe("Connection", func() {
conn.handshakeComplete = false
conn.creationTime = time.Now()
var pn protocol.PacketNumber
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) {
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
pn++
return &wire.ShortHeader{PacketNumber: pn}, []byte{0} /* PADDING frame */, nil
return pn, protocol.PacketNumberLen4, protocol.KeyPhaseZero, []byte{0} /* PADDING frame */, nil
}).Times(3)
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) {
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *logging.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) {
}).Times(3)
packer.EXPECT().PackCoalescedPacket().Times(3) // only expect a single call

Expand Down Expand Up @@ -904,7 +904,7 @@ var _ = Describe("Connection", func() {
})

It("closes the connection when unpacking fails because the reserved bits were incorrect", func() {
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(nil, nil, wire.ErrInvalidReservedBits)
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, wire.ErrInvalidReservedBits)
streamManager.EXPECT().CloseWithError(gomock.Any())
cryptoSetup.EXPECT().Close()
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
Expand Down Expand Up @@ -932,7 +932,7 @@ var _ = Describe("Connection", func() {

It("ignores packets when unpacking the header fails", func() {
testErr := &headerParseError{errors.New("test error")}
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(nil, nil, testErr)
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, testErr)
streamManager.EXPECT().CloseWithError(gomock.Any())
cryptoSetup.EXPECT().Close()
runErr := make(chan error)
Expand All @@ -958,7 +958,7 @@ var _ = Describe("Connection", func() {
})

It("closes the connection when unpacking fails because of an error other than a decryption error", func() {
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(nil, nil, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError})
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError})
streamManager.EXPECT().CloseWithError(gomock.Any())
cryptoSetup.EXPECT().Close()
packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
Expand Down Expand Up @@ -1050,8 +1050,7 @@ var _ = Describe("Connection", func() {

Context("updating the remote address", func() {
It("doesn't support connection migration", func() {
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(&wire.ShortHeader{},
[]byte{0} /* one PADDING frame */, nil)
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(10), protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* one PADDING frame */, nil)
packet := getPacket(&wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: srcConnID},
PacketNumberLen: protocol.PacketNumberLen1,
Expand Down
2 changes: 1 addition & 1 deletion internal/mocks/logging/connection_tracer.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 9 additions & 27 deletions internal/wire/short_header.go
Expand Up @@ -9,28 +9,20 @@ import (
"github.com/lucas-clemente/quic-go/internal/utils"
)

type ShortHeader struct {
DestConnectionID protocol.ConnectionID
PacketNumber protocol.PacketNumber
PacketNumberLen protocol.PacketNumberLen
KeyPhase protocol.KeyPhaseBit
}

func ParseShortHeader(data []byte, connIDLen int) (*ShortHeader, error) {
func ParseShortHeader(data []byte, connIDLen int) (length int, _ protocol.PacketNumber, _ protocol.PacketNumberLen, _ protocol.KeyPhaseBit, _ error) {
if len(data) == 0 {
return nil, io.EOF
return 0, 0, 0, 0, io.EOF
}
if data[0]&0x80 > 0 {
return nil, errors.New("not a short header packet")
return 0, 0, 0, 0, errors.New("not a short header packet")
}
if data[0]&0x40 == 0 {
return nil, errors.New("not a QUIC packet")
return 0, 0, 0, 0, errors.New("not a QUIC packet")
}
pnLen := protocol.PacketNumberLen(data[0]&0b11) + 1
if len(data) < 1+int(pnLen)+connIDLen {
return nil, io.EOF
return 0, 0, 0, 0, io.EOF
}
destConnID := protocol.ParseConnectionID(data[1 : 1+connIDLen])

pos := 1 + connIDLen
var pn protocol.PacketNumber
Expand All @@ -44,7 +36,7 @@ func ParseShortHeader(data []byte, connIDLen int) (*ShortHeader, error) {
case protocol.PacketNumberLen4:
pn = protocol.PacketNumber(utils.BigEndian.Uint32(data[pos : pos+4]))
default:
return nil, fmt.Errorf("invalid packet number length: %d", pnLen)
return 0, 0, 0, 0, fmt.Errorf("invalid packet number length: %d", pnLen)
}
kp := protocol.KeyPhaseZero
if data[0]&0b100 > 0 {
Expand All @@ -55,19 +47,9 @@ func ParseShortHeader(data []byte, connIDLen int) (*ShortHeader, error) {
if data[0]&0x18 != 0 {
err = ErrInvalidReservedBits
}
return &ShortHeader{
DestConnectionID: destConnID,
PacketNumber: pn,
PacketNumberLen: pnLen,
KeyPhase: kp,
}, err
}

func (h *ShortHeader) Len() protocol.ByteCount {
return 1 + protocol.ByteCount(h.DestConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen)
return 1 + connIDLen + int(pnLen), pn, pnLen, kp, err
}

// Log logs the Header
func (h *ShortHeader) Log(logger utils.Logger) {
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
func LogShortHeader(logger utils.Logger, dest protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) {
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", dest, pn, pnLen, kp)
}
40 changes: 13 additions & 27 deletions internal/wire/short_header_test.go
Expand Up @@ -21,12 +21,12 @@ var _ = Describe("Short Header", func() {
0xde, 0xad, 0xbe, 0xef,
0x13, 0x37, 0x99,
}
hdr, err := ParseShortHeader(data, 4)
l, pn, pnLen, kp, err := ParseShortHeader(data, 4)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.DestConnectionID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})))
Expect(hdr.KeyPhase).To(Equal(protocol.KeyPhaseOne))
Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x133799)))
Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen3))
Expect(l).To(Equal(len(data)))
Expect(kp).To(Equal(protocol.KeyPhaseOne))
Expect(pn).To(Equal(protocol.PacketNumber(0x133799)))
Expect(pnLen).To(Equal(protocol.PacketNumberLen3))
})

It("errors when the QUIC bit is not set", func() {
Expand All @@ -35,7 +35,7 @@ var _ = Describe("Short Header", func() {
0xde, 0xad, 0xbe, 0xef,
0x13, 0x37,
}
_, err := ParseShortHeader(data, 4)
_, _, _, _, err := ParseShortHeader(data, 4)
Expect(err).To(MatchError("not a QUIC packet"))
})

Expand All @@ -45,14 +45,13 @@ var _ = Describe("Short Header", func() {
0xde, 0xad, 0xbe, 0xef,
0x13, 0x37,
}
hdr, err := ParseShortHeader(data, 4)
_, pn, _, _, err := ParseShortHeader(data, 4)
Expect(err).To(MatchError(ErrInvalidReservedBits))
Expect(hdr).ToNot(BeNil())
Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1337)))
Expect(pn).To(Equal(protocol.PacketNumber(0x1337)))
})

It("errors when passed a long header packet", func() {
_, err := ParseShortHeader([]byte{0x80}, 4)
_, _, _, _, err := ParseShortHeader([]byte{0x80}, 4)
Expect(err).To(MatchError("not a short header packet"))
})

Expand All @@ -62,10 +61,10 @@ var _ = Describe("Short Header", func() {
0xde, 0xad, 0xbe, 0xef,
0x13, 0x37, 0x99,
}
_, err := ParseShortHeader(data, 4)
_, _, _, _, err := ParseShortHeader(data, 4)
Expect(err).ToNot(HaveOccurred())
for i := range data {
_, err := ParseShortHeader(data[:i], 4)
_, _, _, _, err := ParseShortHeader(data[:i], 4)
Expect(err).To(MatchError(io.EOF))
}
})
Expand All @@ -89,22 +88,9 @@ var _ = Describe("Short Header", func() {
})

It("logs Short Headers containing a connection ID", func() {
(&ShortHeader{
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}),
KeyPhase: protocol.KeyPhaseOne,
PacketNumber: 1337,
PacketNumberLen: 4,
}).Log(logger)
connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37})
LogShortHeader(logger, connID, 1337, protocol.PacketNumberLen4, protocol.KeyPhaseOne)
Expect(buf.String()).To(ContainSubstring("Short Header{DestConnectionID: deadbeefcafe1337, PacketNumber: 1337, PacketNumberLen: 4, KeyPhase: 1}"))
})
})

It("determines the length", func() {
Expect((&ShortHeader{
DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xaf}),
PacketNumber: 0x1337,
PacketNumberLen: protocol.PacketNumberLen3,
KeyPhase: protocol.KeyPhaseOne,
}).Len()).To(Equal(protocol.ByteCount(1 + 2 + 3)))
})
})

0 comments on commit 656f3d2

Please sign in to comment.