From 656f3d2d7de3382b66f2bdb501db548ff5e22ea2 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 30 Aug 2022 14:37:36 +0300 Subject: [PATCH] remove the wire.ShortHeader in favor of more return values (#3535) --- connection.go | 54 ++++++++++++++------- connection_test.go | 31 ++++++------ internal/mocks/logging/connection_tracer.go | 2 +- internal/wire/short_header.go | 36 ++++---------- internal/wire/short_header_test.go | 40 +++++---------- logging/interface.go | 10 +++- logging/mock_connection_tracer_test.go | 2 +- mock_unpacker_test.go | 13 +++-- packet_unpacker.go | 41 ++++++++-------- packet_unpacker_test.go | 17 ++++--- qlog/qlog.go | 3 +- qlog/qlog_test.go | 7 ++- 12 files changed, 125 insertions(+), 131 deletions(-) diff --git a/connection.go b/connection.go index 41cab3ddae2..178506095ba 100644 --- a/connection.go +++ b/connection.go @@ -26,7 +26,7 @@ import ( type unpacker interface { UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) - UnpackShortHeader(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) + UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) } type streamGetter interface { @@ -856,11 +856,13 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool { data := rp.data p := rp for len(data) > 0 { + var destConnID protocol.ConnectionID if counter > 0 { p = p.Clone() p.data = data - destConnID, err := wire.ParseConnectionID(p.data, s.srcConnIDLen) + var err error + destConnID, err = wire.ParseConnectionID(p.data, s.srcConnIDLen) if err != nil { if s.tracer != nil { s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError) @@ -920,7 +922,7 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool { if counter > 0 { p.buffer.Split() } - processed = s.handleShortHeaderPacket(p) + processed = s.handleShortHeaderPacket(p, destConnID) break } } @@ -929,7 +931,7 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool { return processed } -func (s *connection) handleShortHeaderPacket(p *receivedPacket) bool { +func (s *connection) handleShortHeaderPacket(p *receivedPacket, destConnID protocol.ConnectionID) bool { var wasQueued bool defer func() { @@ -939,18 +941,18 @@ func (s *connection) handleShortHeaderPacket(p *receivedPacket) bool { } }() - hdr, data, err := s.unpacker.UnpackShortHeader(p.rcvTime, p.data) + pn, pnLen, keyPhase, data, err := s.unpacker.UnpackShortHeader(p.rcvTime, p.data) if err != nil { wasQueued = s.handleUnpackError(err, p, logging.PacketType1RTT) return false } if s.logger.Debug() { - s.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, 1-RTT", hdr.PacketNumber, p.Size(), hdr.DestConnectionID) - hdr.Log(s.logger) + s.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, 1-RTT", pn, p.Size(), destConnID) + wire.LogShortHeader(s.logger, destConnID, pn, pnLen, keyPhase) } - if s.receivedPacketHandler.IsPotentiallyDuplicate(hdr.PacketNumber, protocol.Encryption1RTT) { + if s.receivedPacketHandler.IsPotentiallyDuplicate(pn, protocol.Encryption1RTT) { s.logger.Debugf("Dropping (potentially) duplicate packet.") if s.tracer != nil { s.tracer.DroppedPacket(logging.PacketType1RTT, p.Size(), logging.PacketDropDuplicate) @@ -958,7 +960,22 @@ func (s *connection) handleShortHeaderPacket(p *receivedPacket) bool { return false } - if err := s.handleUnpackedShortHeaderPacket(hdr, data, p.ecn, p.rcvTime, p.Size()); err != nil { + var log func([]logging.Frame) + if s.tracer != nil { + log = func(frames []logging.Frame) { + s.tracer.ReceivedShortHeaderPacket( + &logging.ShortHeader{ + DestConnectionID: destConnID, + PacketNumber: pn, + PacketNumberLen: pnLen, + KeyPhase: keyPhase, + }, + p.Size(), + frames, + ) + } + } + if err := s.handleUnpackedShortHeaderPacket(destConnID, pn, data, p.ecn, p.rcvTime, log); err != nil { s.closeLocal(err) return false } @@ -1241,22 +1258,23 @@ func (s *connection) handleUnpackedPacket( return s.receivedPacketHandler.ReceivedPacket(packet.hdr.PacketNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting) } -func (s *connection) handleUnpackedShortHeaderPacket(hdr *wire.ShortHeader, data []byte, ecn protocol.ECN, rcvTime time.Time, packetSize protocol.ByteCount) error { +func (s *connection) handleUnpackedShortHeaderPacket( + destConnID protocol.ConnectionID, + pn protocol.PacketNumber, + data []byte, + ecn protocol.ECN, + rcvTime time.Time, + log func([]logging.Frame), +) error { s.lastPacketReceivedTime = rcvTime s.firstAckElicitingPacketAfterIdleSentTime = time.Time{} s.keepAlivePingSent = false - var log func([]logging.Frame) - if s.tracer != nil { - log = func(frames []logging.Frame) { - s.tracer.ReceivedShortHeaderPacket(hdr, packetSize, frames) - } - } - isAckEliciting, err := s.handleFrames(data, hdr.DestConnectionID, protocol.Encryption1RTT, log) + isAckEliciting, err := s.handleFrames(data, destConnID, protocol.Encryption1RTT, log) if err != nil { return err } - return s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, ecn, protocol.Encryption1RTT, rcvTime, isAckEliciting) + return s.receivedPacketHandler.ReceivedPacket(pn, ecn, protocol.Encryption1RTT, rcvTime, isAckEliciting) } func (s *connection) handleFrames( diff --git a/connection_test.go b/connection_test.go index af559b9851b..db746d3b293 100644 --- a/connection_test.go +++ b/connection_test.go @@ -562,10 +562,10 @@ var _ = Describe("Connection", func() { } Expect(hdr.Write(buf, conn.version)).To(Succeed()) - unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(time.Time, []byte) (*wire.ShortHeader, []byte, error) { + unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(time.Time, []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) { b, err := (&wire.ConnectionCloseFrame{ErrorCode: uint64(qerr.StreamLimitError)}).Append(nil, conn.version) Expect(err).ToNot(HaveOccurred()) - return &wire.ShortHeader{PacketNumber: 3}, b, nil + return 3, protocol.PacketNumberLen2, protocol.KeyPhaseOne, b, nil }) gomock.InOrder( tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()), @@ -766,7 +766,7 @@ var _ = Describe("Connection", func() { Expect(err).ToNot(HaveOccurred()) packet := getPacket(hdr, nil) packet.ecn = protocol.ECT1 - unpacker.EXPECT().UnpackShortHeader(rcvTime, gomock.Any()).Return(&wire.ShortHeader{PacketNumber: 0x1337}, b, nil) + unpacker.EXPECT().UnpackShortHeader(rcvTime, gomock.Any()).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2, protocol.KeyPhaseZero, b, nil) rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) gomock.InOrder( rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT), @@ -774,7 +774,7 @@ var _ = Describe("Connection", func() { ) conn.receivedPacketHandler = rph packet.rcvTime = rcvTime - tracer.EXPECT().ReceivedShortHeaderPacket(&wire.ShortHeader{PacketNumber: 0x1337}, protocol.ByteCount(len(packet.data)), []logging.Frame{&logging.PingFrame{}}) + tracer.EXPECT().ReceivedShortHeaderPacket(&logging.ShortHeader{PacketNumber: 0x1337, PacketNumberLen: 2, KeyPhase: protocol.KeyPhaseZero}, protocol.ByteCount(len(packet.data)), []logging.Frame{&logging.PingFrame{}}) Expect(conn.handlePacketImpl(packet)).To(BeTrue()) }) @@ -785,7 +785,7 @@ var _ = Describe("Connection", func() { PacketNumberLen: protocol.PacketNumberLen1, } packet := getPacket(hdr, nil) - unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(&wire.ShortHeader{PacketNumber: 0x1337}, []byte("foobar"), nil) + unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2, protocol.KeyPhaseOne, []byte("foobar"), nil) rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT).Return(true) conn.receivedPacketHandler = rph @@ -829,11 +829,11 @@ var _ = Describe("Connection", func() { It("processes multiple received packets before sending one", func() { conn.creationTime = time.Now() var pn protocol.PacketNumber - unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) { + unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) { pn++ - return &wire.ShortHeader{PacketNumber: pn}, []byte{0} /* PADDING frame */, nil + return pn, protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* PADDING frame */, nil }).Times(3) - tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) { + tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *logging.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) { }).Times(3) packer.EXPECT().PackCoalescedPacket() // only expect a single call @@ -868,11 +868,11 @@ var _ = Describe("Connection", func() { conn.handshakeComplete = false conn.creationTime = time.Now() var pn protocol.PacketNumber - unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) { + unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(func(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) { pn++ - return &wire.ShortHeader{PacketNumber: pn}, []byte{0} /* PADDING frame */, nil + return pn, protocol.PacketNumberLen4, protocol.KeyPhaseZero, []byte{0} /* PADDING frame */, nil }).Times(3) - tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) { + tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *logging.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) { }).Times(3) packer.EXPECT().PackCoalescedPacket().Times(3) // only expect a single call @@ -904,7 +904,7 @@ var _ = Describe("Connection", func() { }) It("closes the connection when unpacking fails because the reserved bits were incorrect", func() { - unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(nil, nil, wire.ErrInvalidReservedBits) + unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, wire.ErrInvalidReservedBits) streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) @@ -932,7 +932,7 @@ var _ = Describe("Connection", func() { It("ignores packets when unpacking the header fails", func() { testErr := &headerParseError{errors.New("test error")} - unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(nil, nil, testErr) + unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, testErr) streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() runErr := make(chan error) @@ -958,7 +958,7 @@ var _ = Describe("Connection", func() { }) It("closes the connection when unpacking fails because of an error other than a decryption error", func() { - unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(nil, nil, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}) + unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}) streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) @@ -1050,8 +1050,7 @@ var _ = Describe("Connection", func() { Context("updating the remote address", func() { It("doesn't support connection migration", func() { - unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(&wire.ShortHeader{}, - []byte{0} /* one PADDING frame */, nil) + unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(10), protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* one PADDING frame */, nil) packet := getPacket(&wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: srcConnID}, PacketNumberLen: protocol.PacketNumberLen1, diff --git a/internal/mocks/logging/connection_tracer.go b/internal/mocks/logging/connection_tracer.go index 9e24129a2b1..748d0d7a8e4 100644 --- a/internal/mocks/logging/connection_tracer.go +++ b/internal/mocks/logging/connection_tracer.go @@ -208,7 +208,7 @@ func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gom } // ReceivedShortHeaderPacket mocks base method. -func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *wire.ShortHeader, arg1 protocol.ByteCount, arg2 []logging.Frame) { +func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *logging.ShortHeader, arg1 protocol.ByteCount, arg2 []logging.Frame) { m.ctrl.T.Helper() m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2) } diff --git a/internal/wire/short_header.go b/internal/wire/short_header.go index 0121f9ada68..9639b5d4aec 100644 --- a/internal/wire/short_header.go +++ b/internal/wire/short_header.go @@ -9,28 +9,20 @@ import ( "github.com/lucas-clemente/quic-go/internal/utils" ) -type ShortHeader struct { - DestConnectionID protocol.ConnectionID - PacketNumber protocol.PacketNumber - PacketNumberLen protocol.PacketNumberLen - KeyPhase protocol.KeyPhaseBit -} - -func ParseShortHeader(data []byte, connIDLen int) (*ShortHeader, error) { +func ParseShortHeader(data []byte, connIDLen int) (length int, _ protocol.PacketNumber, _ protocol.PacketNumberLen, _ protocol.KeyPhaseBit, _ error) { if len(data) == 0 { - return nil, io.EOF + return 0, 0, 0, 0, io.EOF } if data[0]&0x80 > 0 { - return nil, errors.New("not a short header packet") + return 0, 0, 0, 0, errors.New("not a short header packet") } if data[0]&0x40 == 0 { - return nil, errors.New("not a QUIC packet") + return 0, 0, 0, 0, errors.New("not a QUIC packet") } pnLen := protocol.PacketNumberLen(data[0]&0b11) + 1 if len(data) < 1+int(pnLen)+connIDLen { - return nil, io.EOF + return 0, 0, 0, 0, io.EOF } - destConnID := protocol.ParseConnectionID(data[1 : 1+connIDLen]) pos := 1 + connIDLen var pn protocol.PacketNumber @@ -44,7 +36,7 @@ func ParseShortHeader(data []byte, connIDLen int) (*ShortHeader, error) { case protocol.PacketNumberLen4: pn = protocol.PacketNumber(utils.BigEndian.Uint32(data[pos : pos+4])) default: - return nil, fmt.Errorf("invalid packet number length: %d", pnLen) + return 0, 0, 0, 0, fmt.Errorf("invalid packet number length: %d", pnLen) } kp := protocol.KeyPhaseZero if data[0]&0b100 > 0 { @@ -55,19 +47,9 @@ func ParseShortHeader(data []byte, connIDLen int) (*ShortHeader, error) { if data[0]&0x18 != 0 { err = ErrInvalidReservedBits } - return &ShortHeader{ - DestConnectionID: destConnID, - PacketNumber: pn, - PacketNumberLen: pnLen, - KeyPhase: kp, - }, err -} - -func (h *ShortHeader) Len() protocol.ByteCount { - return 1 + protocol.ByteCount(h.DestConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + return 1 + connIDLen + int(pnLen), pn, pnLen, kp, err } -// Log logs the Header -func (h *ShortHeader) Log(logger utils.Logger) { - logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) +func LogShortHeader(logger utils.Logger, dest protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) { + logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", dest, pn, pnLen, kp) } diff --git a/internal/wire/short_header_test.go b/internal/wire/short_header_test.go index 8f1efaea7ca..b7884def815 100644 --- a/internal/wire/short_header_test.go +++ b/internal/wire/short_header_test.go @@ -21,12 +21,12 @@ var _ = Describe("Short Header", func() { 0xde, 0xad, 0xbe, 0xef, 0x13, 0x37, 0x99, } - hdr, err := ParseShortHeader(data, 4) + l, pn, pnLen, kp, err := ParseShortHeader(data, 4) Expect(err).ToNot(HaveOccurred()) - Expect(hdr.DestConnectionID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}))) - Expect(hdr.KeyPhase).To(Equal(protocol.KeyPhaseOne)) - Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x133799))) - Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen3)) + Expect(l).To(Equal(len(data))) + Expect(kp).To(Equal(protocol.KeyPhaseOne)) + Expect(pn).To(Equal(protocol.PacketNumber(0x133799))) + Expect(pnLen).To(Equal(protocol.PacketNumberLen3)) }) It("errors when the QUIC bit is not set", func() { @@ -35,7 +35,7 @@ var _ = Describe("Short Header", func() { 0xde, 0xad, 0xbe, 0xef, 0x13, 0x37, } - _, err := ParseShortHeader(data, 4) + _, _, _, _, err := ParseShortHeader(data, 4) Expect(err).To(MatchError("not a QUIC packet")) }) @@ -45,14 +45,13 @@ var _ = Describe("Short Header", func() { 0xde, 0xad, 0xbe, 0xef, 0x13, 0x37, } - hdr, err := ParseShortHeader(data, 4) + _, pn, _, _, err := ParseShortHeader(data, 4) Expect(err).To(MatchError(ErrInvalidReservedBits)) - Expect(hdr).ToNot(BeNil()) - Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x1337))) + Expect(pn).To(Equal(protocol.PacketNumber(0x1337))) }) It("errors when passed a long header packet", func() { - _, err := ParseShortHeader([]byte{0x80}, 4) + _, _, _, _, err := ParseShortHeader([]byte{0x80}, 4) Expect(err).To(MatchError("not a short header packet")) }) @@ -62,10 +61,10 @@ var _ = Describe("Short Header", func() { 0xde, 0xad, 0xbe, 0xef, 0x13, 0x37, 0x99, } - _, err := ParseShortHeader(data, 4) + _, _, _, _, err := ParseShortHeader(data, 4) Expect(err).ToNot(HaveOccurred()) for i := range data { - _, err := ParseShortHeader(data[:i], 4) + _, _, _, _, err := ParseShortHeader(data[:i], 4) Expect(err).To(MatchError(io.EOF)) } }) @@ -89,22 +88,9 @@ var _ = Describe("Short Header", func() { }) It("logs Short Headers containing a connection ID", func() { - (&ShortHeader{ - DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}), - KeyPhase: protocol.KeyPhaseOne, - PacketNumber: 1337, - PacketNumberLen: 4, - }).Log(logger) + connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}) + LogShortHeader(logger, connID, 1337, protocol.PacketNumberLen4, protocol.KeyPhaseOne) Expect(buf.String()).To(ContainSubstring("Short Header{DestConnectionID: deadbeefcafe1337, PacketNumber: 1337, PacketNumberLen: 4, KeyPhase: 1}")) }) }) - - It("determines the length", func() { - Expect((&ShortHeader{ - DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xaf}), - PacketNumber: 0x1337, - PacketNumberLen: protocol.PacketNumberLen3, - KeyPhase: protocol.KeyPhaseOne, - }).Len()).To(Equal(protocol.ByteCount(1 + 2 + 3))) - }) }) diff --git a/logging/interface.go b/logging/interface.go index 12a57206536..7f772dd6cfb 100644 --- a/logging/interface.go +++ b/logging/interface.go @@ -46,8 +46,6 @@ type ( Header = wire.Header // The ExtendedHeader is the QUIC Long Header packet header, after removing header protection. ExtendedHeader = wire.ExtendedHeader - // The ShortHeader is the QUIC Short Header packet header, after removing header protection. - ShortHeader = wire.ShortHeader // The TransportParameters are QUIC transport parameters. TransportParameters = wire.TransportParameters // The PreferredAddress is the preferred address sent in the transport parameters. @@ -94,6 +92,14 @@ const ( StreamTypeBidi = protocol.StreamTypeBidi ) +// The ShortHeader is the QUIC Short Header packet header, after removing header protection. +type ShortHeader struct { + DestConnectionID ConnectionID + PacketNumber PacketNumber + PacketNumberLen protocol.PacketNumberLen + KeyPhase KeyPhaseBit +} + // A Tracer traces events. type Tracer interface { // TracerForConnection requests a new tracer for a connection. diff --git a/logging/mock_connection_tracer_test.go b/logging/mock_connection_tracer_test.go index 971e474e752..95b9b02ad66 100644 --- a/logging/mock_connection_tracer_test.go +++ b/logging/mock_connection_tracer_test.go @@ -207,7 +207,7 @@ func (mr *MockConnectionTracerMockRecorder) ReceivedRetry(arg0 interface{}) *gom } // ReceivedShortHeaderPacket mocks base method. -func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *wire.ShortHeader, arg1 protocol.ByteCount, arg2 []Frame) { +func (m *MockConnectionTracer) ReceivedShortHeaderPacket(arg0 *ShortHeader, arg1 protocol.ByteCount, arg2 []Frame) { m.ctrl.T.Helper() m.ctrl.Call(m, "ReceivedShortHeaderPacket", arg0, arg1, arg2) } diff --git a/mock_unpacker_test.go b/mock_unpacker_test.go index 1410f8ecbbe..cc8ee3eea1b 100644 --- a/mock_unpacker_test.go +++ b/mock_unpacker_test.go @@ -9,6 +9,7 @@ import ( time "time" gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" wire "github.com/lucas-clemente/quic-go/internal/wire" ) @@ -51,13 +52,15 @@ func (mr *MockUnpackerMockRecorder) UnpackLongHeader(hdr, rcvTime, data interfac } // UnpackShortHeader mocks base method. -func (m *MockUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) { +func (m *MockUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UnpackShortHeader", rcvTime, data) - ret0, _ := ret[0].(*wire.ShortHeader) - ret1, _ := ret[1].([]byte) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 + ret0, _ := ret[0].(protocol.PacketNumber) + ret1, _ := ret[1].(protocol.PacketNumberLen) + ret2, _ := ret[2].(protocol.KeyPhaseBit) + ret3, _ := ret[3].([]byte) + ret4, _ := ret[4].(error) + return ret0, ret1, ret2, ret3, ret4 } // UnpackShortHeader indicates an expected call of UnpackShortHeader. diff --git a/packet_unpacker.go b/packet_unpacker.go index 688724f81fb..e7754145b03 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -109,22 +109,22 @@ func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, d }, nil } -func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) { +func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) { opener, err := u.cs.Get1RTTOpener() if err != nil { - return nil, nil, err + return 0, 0, 0, nil, err } - hdr, decrypted, err := u.unpackShortHeaderPacket(opener, rcvTime, data) + pn, pnLen, kp, decrypted, err := u.unpackShortHeaderPacket(opener, rcvTime, data) if err != nil { - return nil, nil, err + return 0, 0, 0, nil, err } if len(decrypted) == 0 { - return nil, nil, &qerr.TransportError{ + return 0, 0, 0, nil, &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: "empty packet", } } - return hdr, decrypted, nil + return pn, pnLen, kp, decrypted, nil } func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) { @@ -147,27 +147,26 @@ func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpene return extHdr, decrypted, nil } -func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error) { - hdr, parseErr := u.unpackShortHeader(opener, data) +func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) { + l, pn, pnLen, kp, parseErr := u.unpackShortHeader(opener, data) // If the reserved bits are set incorrectly, we still need to continue unpacking. // This avoids a timing side-channel, which otherwise might allow an attacker // to gain information about the header encryption. if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { - return nil, nil, &headerParseError{parseErr} + return 0, 0, 0, nil, &headerParseError{parseErr} } - hdr.PacketNumber = opener.DecodePacketNumber(hdr.PacketNumber, hdr.PacketNumberLen) - l := hdr.Len() - decrypted, err := opener.Open(data[l:l], data[l:], rcvTime, hdr.PacketNumber, hdr.KeyPhase, data[:l]) + pn = opener.DecodePacketNumber(pn, pnLen) + decrypted, err := opener.Open(data[l:l], data[l:], rcvTime, pn, kp, data[:l]) if err != nil { - return nil, nil, err + return 0, 0, 0, nil, err } - return hdr, decrypted, parseErr + return pn, pnLen, kp, decrypted, parseErr } -func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (*wire.ShortHeader, error) { +func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int, protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, error) { hdrLen := 1 /* first header byte */ + u.shortHdrConnIDLen if len(data) < hdrLen+4+16 { - return nil, fmt.Errorf("packet too small, expected at least 20 bytes after the header, got %d", len(data)-hdrLen) + return 0, 0, 0, 0, fmt.Errorf("packet too small, expected at least 20 bytes after the header, got %d", len(data)-hdrLen) } origPNBytes := make([]byte, 4) copy(origPNBytes, data[hdrLen:hdrLen+4]) @@ -178,15 +177,15 @@ func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (*wi data[hdrLen:hdrLen+4], ) // 3. parse the header (and learn the actual length of the packet number) - hdr, parseErr := wire.ParseShortHeader(data, u.shortHdrConnIDLen) + l, pn, pnLen, kp, parseErr := wire.ParseShortHeader(data, u.shortHdrConnIDLen) if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { - return nil, parseErr + return l, pn, pnLen, kp, parseErr } // 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier - if hdr.PacketNumberLen != protocol.PacketNumberLen4 { - copy(data[hdrLen+int(hdr.PacketNumberLen):hdrLen+4], origPNBytes[int(hdr.PacketNumberLen):]) + if pnLen != protocol.PacketNumberLen4 { + copy(data[hdrLen+int(pnLen):hdrLen+4], origPNBytes[int(pnLen):]) } - return hdr, parseErr + return l, pn, pnLen, kp, parseErr } // The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError. diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 3857f70a344..e418bfb80d4 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -75,7 +75,7 @@ var _ = Describe("Packet Unpacker", func() { data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) opener := mocks.NewMockShortHeaderOpener(mockCtrl) cs.EXPECT().Get1RTTOpener().Return(opener, nil) - _, _, err := unpacker.UnpackShortHeader(time.Now(), data) + _, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), data) Expect(err).To(BeAssignableToTypeOf(&headerParseError{})) Expect(err).To(MatchError("packet too small, expected at least 20 bytes after the header, got 19")) }) @@ -148,10 +148,11 @@ var _ = Describe("Packet Unpacker", func() { opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(99), protocol.PacketNumberLen4).Return(protocol.PacketNumber(321)), opener.EXPECT().Open(gomock.Any(), payload, now, protocol.PacketNumber(321), protocol.KeyPhaseOne, hdrRaw).Return([]byte("decrypted"), nil), ) - hdr, data, err := unpacker.UnpackShortHeader(now, append(hdrRaw, payload...)) + pn, pnLen, kp, data, err := unpacker.UnpackShortHeader(now, append(hdrRaw, payload...)) Expect(err).ToNot(HaveOccurred()) - Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(321))) - Expect(hdr.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) + Expect(pn).To(Equal(protocol.PacketNumber(321))) + Expect(pnLen).To(Equal(protocol.PacketNumberLen4)) + Expect(kp).To(Equal(protocol.KeyPhaseOne)) Expect(data).To(Equal([]byte("decrypted"))) }) @@ -163,7 +164,7 @@ var _ = Describe("Packet Unpacker", func() { } _, hdrRaw := getHeader(extHdr) cs.EXPECT().Get1RTTOpener().Return(nil, handshake.ErrKeysNotYetAvailable) - _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...)) + _, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...)) Expect(err).To(MatchError(handshake.ErrKeysNotYetAvailable)) }) @@ -208,7 +209,7 @@ var _ = Describe("Packet Unpacker", func() { opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(321)), opener.EXPECT().Open(gomock.Any(), payload, now, protocol.PacketNumber(321), protocol.KeyPhaseOne, hdrRaw).Return([]byte(""), nil), ) - _, _, err := unpacker.UnpackShortHeader(now, append(hdrRaw, payload...)) + _, _, _, _, err := unpacker.UnpackShortHeader(now, append(hdrRaw, payload...)) Expect(err).To(MatchError(&qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: "empty packet", @@ -273,7 +274,7 @@ var _ = Describe("Packet Unpacker", func() { cs.EXPECT().Get1RTTOpener().Return(opener, nil) opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil) - _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...)) + _, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...)) Expect(err).To(MatchError(wire.ErrInvalidReservedBits)) }) @@ -312,7 +313,7 @@ var _ = Describe("Packet Unpacker", func() { cs.EXPECT().Get1RTTOpener().Return(opener, nil) opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any()) opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed) - _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...)) + _, _, _, _, err := unpacker.UnpackShortHeader(time.Now(), append(hdrRaw, payload...)) Expect(err).To(MatchError(handshake.ErrDecryptionFailed)) }) diff --git a/qlog/qlog.go b/qlog/qlog.go index 38ec342dbc6..3fd58a0ecf3 100644 --- a/qlog/qlog.go +++ b/qlog/qlog.go @@ -319,11 +319,12 @@ func (t *connectionTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, p fs[i] = frame{Frame: f} } header := *transformShortHeader(hdr) + hdrLen := 1 + hdr.DestConnectionID.Len() + int(hdr.PacketNumberLen) t.mutex.Lock() t.recordEvent(time.Now(), &eventPacketReceived{ Header: header, Length: packetSize, - PayloadLength: packetSize - hdr.Len(), + PayloadLength: packetSize - protocol.ByteCount(hdrLen), Frames: fs, }) t.mutex.Unlock() diff --git a/qlog/qlog_test.go b/qlog/qlog_test.go index 4f109cf6576..0f5337b857a 100644 --- a/qlog/qlog_test.go +++ b/qlog/qlog_test.go @@ -11,8 +11,6 @@ import ( "os" "time" - "github.com/lucas-clemente/quic-go/internal/wire" - "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qerr" @@ -529,9 +527,10 @@ var _ = Describe("Tracing", func() { }) It("records a received Short Header packet", func() { - shdr := &wire.ShortHeader{ + shdr := &logging.ShortHeader{ DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), PacketNumber: 1337, + PacketNumberLen: protocol.PacketNumberLen3, KeyPhase: protocol.KeyPhaseZero, } tracer.ReceivedShortHeaderPacket( @@ -549,7 +548,7 @@ var _ = Describe("Tracing", func() { Expect(ev).To(HaveKey("raw")) raw := ev["raw"].(map[string]interface{}) Expect(raw).To(HaveKeyWithValue("length", float64(789))) - Expect(raw).To(HaveKeyWithValue("payload_length", float64(789-shdr.Len()))) + Expect(raw).To(HaveKeyWithValue("payload_length", float64(789-(1+8+3)))) Expect(ev).To(HaveKey("header")) hdr := ev["header"].(map[string]interface{}) Expect(hdr).To(HaveKeyWithValue("packet_type", "1RTT"))