diff --git a/connection.go b/connection.go index aa5abf3c0ea..d9ee5fd7baa 100644 --- a/connection.go +++ b/connection.go @@ -1821,7 +1821,24 @@ func (s *connection) sendPackets() error { } func (s *connection) maybeSendAckOnlyPacket() error { - packet, err := s.packer.MaybePackAckPacket(s.handshakeConfirmed) + if !s.handshakeConfirmed { + packet, err := s.packer.PackCoalescedPacket(true) + if err != nil { + return err + } + if packet == nil { + return nil + } + s.logCoalescedPacket(packet) + for _, p := range packet.packets { + s.sentPacketHandler.SentPacket(p.ToAckHandlerPacket(time.Now(), s.retransmissionQueue)) + } + s.connIDManager.SentPacket() + s.sendQueue.Send(packet.buffer) + return nil + } + + packet, err := s.packer.PackPacket(true) if err != nil { return err } @@ -1882,7 +1899,7 @@ func (s *connection) sendPacket() (bool, error) { now := time.Now() if !s.handshakeConfirmed { - packet, err := s.packer.PackCoalescedPacket() + packet, err := s.packer.PackCoalescedPacket(false) if err != nil || packet == nil { return false, err } @@ -1906,7 +1923,7 @@ func (s *connection) sendPacket() (bool, error) { s.sendPackedPacket(packet, now) return true, nil } - packet, err := s.packer.PackPacket() + packet, err := s.packer.PackPacket(false) if err != nil || packet == nil { return false, err } diff --git a/connection_test.go b/connection_test.go index db746d3b293..9c6d0a512bf 100644 --- a/connection_test.go +++ b/connection_test.go @@ -602,8 +602,8 @@ var _ = Describe("Connection", func() { cryptoSetup.EXPECT().Close() conn.sentPacketHandler = sph p := getPacket(1) - packer.EXPECT().PackPacket().Return(p, nil) - packer.EXPECT().PackPacket().Return(nil, nil).AnyTimes() + packer.EXPECT().PackPacket(false).Return(p, nil) + packer.EXPECT().PackPacket(false).Return(nil, nil).AnyTimes() runConn() conn.queueControlFrame(&wire.PingFrame{}) conn.scheduleSending() @@ -835,7 +835,7 @@ var _ = Describe("Connection", func() { }).Times(3) tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *logging.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) { }).Times(3) - packer.EXPECT().PackCoalescedPacket() // only expect a single call + packer.EXPECT().PackCoalescedPacket(false) // only expect a single call for i := 0; i < 3; i++ { conn.handlePacket(getPacket(&wire.ExtendedHeader{ @@ -874,7 +874,7 @@ var _ = Describe("Connection", func() { }).Times(3) tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *logging.ShortHeader, _ protocol.ByteCount, _ []logging.Frame) { }).Times(3) - packer.EXPECT().PackCoalescedPacket().Times(3) // only expect a single call + packer.EXPECT().PackCoalescedPacket(false).Times(3) // only expect a single call for i := 0; i < 3; i++ { conn.handlePacket(getPacket(&wire.ExtendedHeader{ @@ -1229,8 +1229,8 @@ var _ = Describe("Connection", func() { conn.sentPacketHandler = sph runConn() p := getPacket(1) - packer.EXPECT().PackPacket().Return(p, nil) - packer.EXPECT().PackPacket().Return(nil, nil).AnyTimes() + packer.EXPECT().PackPacket(false).Return(p, nil) + packer.EXPECT().PackPacket(false).Return(nil, nil).AnyTimes() sent := make(chan struct{}) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()).Do(func(packet *packetBuffer) { close(sent) }) @@ -1242,7 +1242,7 @@ var _ = Describe("Connection", func() { It("doesn't send packets if there's nothing to send", func() { conn.handshakeConfirmed = true runConn() - packer.EXPECT().PackPacket().Return(nil, nil).AnyTimes() + packer.EXPECT().PackPacket(false).Return(nil, nil).AnyTimes() conn.receivedPacketHandler.ReceivedPacket(0x035e, protocol.ECNNon, protocol.Encryption1RTT, time.Now(), true) conn.scheduleSending() time.Sleep(50 * time.Millisecond) // make sure there are no calls to mconn.Write() @@ -1254,7 +1254,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAck) done := make(chan struct{}) - packer.EXPECT().MaybePackAckPacket(false).Do(func(bool) { close(done) }) + packer.EXPECT().PackCoalescedPacket(true).Do(func(bool) { close(done) }) conn.sentPacketHandler = sph runConn() conn.scheduleSending() @@ -1274,8 +1274,8 @@ var _ = Describe("Connection", func() { fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) fc.EXPECT().IsNewlyBlocked() p := getPacket(1) - packer.EXPECT().PackPacket().Return(p, nil) - packer.EXPECT().PackPacket().Return(nil, nil).AnyTimes() + packer.EXPECT().PackPacket(false).Return(p, nil) + packer.EXPECT().PackPacket(false).Return(nil, nil).AnyTimes() conn.connFlowController = fc runConn() sent := make(chan struct{}) @@ -1406,8 +1406,8 @@ var _ = Describe("Connection", func() { sph.EXPECT().HasPacingBudget() sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(3) - packer.EXPECT().PackPacket().Return(getPacket(10), nil) - packer.EXPECT().PackPacket().Return(getPacket(11), nil) + packer.EXPECT().PackPacket(false).Return(getPacket(10), nil) + packer.EXPECT().PackPacket(false).Return(getPacket(11), nil) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()).Times(2) go func() { @@ -1423,8 +1423,8 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any()) sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(2) - packer.EXPECT().PackPacket().Return(getPacket(10), nil) - packer.EXPECT().PackPacket().Return(nil, nil) + packer.EXPECT().PackPacket(false).Return(getPacket(10), nil) + packer.EXPECT().PackPacket(false).Return(nil, nil) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()) go func() { @@ -1441,7 +1441,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().HasPacingBudget() sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().SendMode().Return(ackhandler.SendAny) - packer.EXPECT().MaybePackAckPacket(gomock.Any()).Return(getPacket(10), nil) + packer.EXPECT().PackPacket(true).Return(getPacket(10), nil) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()) go func() { @@ -1460,7 +1460,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().HasPacingBudget().Return(true) sph.EXPECT().SendMode().Return(ackhandler.SendAny) sph.EXPECT().SendMode().Return(ackhandler.SendAck) - packer.EXPECT().PackPacket().Return(getPacket(100), nil) + packer.EXPECT().PackPacket(false).Return(getPacket(100), nil) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()) go func() { @@ -1477,12 +1477,12 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() gomock.InOrder( sph.EXPECT().HasPacingBudget().Return(true), - packer.EXPECT().PackPacket().Return(getPacket(100), nil), + packer.EXPECT().PackPacket(false).Return(getPacket(100), nil), sph.EXPECT().SentPacket(gomock.Any()), sph.EXPECT().HasPacingBudget(), sph.EXPECT().TimeUntilSend().Return(time.Now().Add(pacingDelay)), sph.EXPECT().HasPacingBudget().Return(true), - packer.EXPECT().PackPacket().Return(getPacket(101), nil), + packer.EXPECT().PackPacket(false).Return(getPacket(101), nil), sph.EXPECT().SentPacket(gomock.Any()), sph.EXPECT().HasPacingBudget(), sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)), @@ -1507,9 +1507,9 @@ var _ = Describe("Connection", func() { sph.EXPECT().HasPacingBudget() sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(4) - packer.EXPECT().PackPacket().Return(getPacket(1000), nil) - packer.EXPECT().PackPacket().Return(getPacket(1001), nil) - packer.EXPECT().PackPacket().Return(getPacket(1002), nil) + packer.EXPECT().PackPacket(false).Return(getPacket(1000), nil) + packer.EXPECT().PackPacket(false).Return(getPacket(1001), nil) + packer.EXPECT().PackPacket(false).Return(getPacket(1002), nil) written := make(chan struct{}, 3) sender.EXPECT().WouldBlock().AnyTimes() sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }).Times(3) @@ -1539,8 +1539,8 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any()) sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - packer.EXPECT().PackPacket().Return(getPacket(1000), nil) - packer.EXPECT().PackPacket().Return(nil, nil) + packer.EXPECT().PackPacket(false).Return(getPacket(1000), nil) + packer.EXPECT().PackPacket(false).Return(nil, nil) sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { close(written) }) available <- struct{}{} Eventually(written).Should(BeClosed()) @@ -1562,8 +1562,8 @@ var _ = Describe("Connection", func() { }) sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - packer.EXPECT().PackPacket().Return(getPacket(1000), nil) - packer.EXPECT().PackPacket().Return(nil, nil) + packer.EXPECT().PackPacket(false).Return(getPacket(1000), nil) + packer.EXPECT().PackPacket(false).Return(nil, nil) sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { close(written) }) conn.scheduleSending() @@ -1576,7 +1576,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().SentPacket(gomock.Any()) sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny) - packer.EXPECT().PackPacket().Return(getPacket(1000), nil) + packer.EXPECT().PackPacket(false).Return(getPacket(1000), nil) written := make(chan struct{}, 1) sender.EXPECT().WouldBlock() sender.EXPECT().WouldBlock().Return(true).Times(2) @@ -1597,8 +1597,8 @@ var _ = Describe("Connection", func() { sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() sender.EXPECT().WouldBlock().AnyTimes() - packer.EXPECT().PackPacket().Return(getPacket(1001), nil) - packer.EXPECT().PackPacket().Return(nil, nil) + packer.EXPECT().PackPacket(false).Return(getPacket(1001), nil) + packer.EXPECT().PackPacket(false).Return(nil, nil) sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }) available <- struct{}{} Eventually(written).Should(Receive()) @@ -1612,7 +1612,7 @@ var _ = Describe("Connection", func() { sph.EXPECT().HasPacingBudget().Return(true) sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() sender.EXPECT().WouldBlock().AnyTimes() - packer.EXPECT().PackPacket() + packer.EXPECT().PackPacket(false) // don't EXPECT any calls to mconn.Write() go func() { defer GinkgoRecover() @@ -1681,8 +1681,8 @@ var _ = Describe("Connection", func() { sph.EXPECT().HasPacingBudget().Return(true).AnyTimes() sph.EXPECT().SentPacket(gomock.Any()) conn.sentPacketHandler = sph - packer.EXPECT().PackPacket().Return(getPacket(1), nil) - packer.EXPECT().PackPacket().Return(nil, nil) + packer.EXPECT().PackPacket(false).Return(getPacket(1), nil) + packer.EXPECT().PackPacket(false).Return(nil, nil) go func() { defer GinkgoRecover() @@ -1700,8 +1700,8 @@ var _ = Describe("Connection", func() { }) It("sets the timer to the ack timer", func() { - packer.EXPECT().PackPacket().Return(getPacket(1234), nil) - packer.EXPECT().PackPacket().Return(nil, nil) + packer.EXPECT().PackPacket(false).Return(getPacket(1234), nil) + packer.EXPECT().PackPacket(false).Return(nil, nil) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() @@ -1735,7 +1735,7 @@ var _ = Describe("Connection", func() { conn.sentPacketHandler = sph buffer := getPacketBuffer() buffer.Data = append(buffer.Data, []byte("foobar")...) - packer.EXPECT().PackCoalescedPacket().Return(&coalescedPacket{ + packer.EXPECT().PackCoalescedPacket(false).Return(&coalescedPacket{ buffer: buffer, packets: []*packetContents{ { @@ -1760,7 +1760,7 @@ var _ = Describe("Connection", func() { }, }, }, nil) - packer.EXPECT().PackCoalescedPacket().AnyTimes() + packer.EXPECT().PackCoalescedPacket(false).AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() @@ -1811,7 +1811,7 @@ var _ = Describe("Connection", func() { }) It("cancels the HandshakeComplete context when the handshake completes", func() { - packer.EXPECT().PackCoalescedPacket().AnyTimes() + packer.EXPECT().PackCoalescedPacket(false).AnyTimes() finishHandshake := make(chan struct{}) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) conn.sentPacketHandler = sph @@ -1847,7 +1847,7 @@ var _ = Describe("Connection", func() { It("sends a connection ticket when the handshake completes", func() { const size = protocol.MaxPostHandshakeCryptoFrameSize * 3 / 2 - packer.EXPECT().PackCoalescedPacket().AnyTimes() + packer.EXPECT().PackCoalescedPacket(false).AnyTimes() finishHandshake := make(chan struct{}) connRunner.EXPECT().Retire(clientDestConnID) go func() { @@ -1891,7 +1891,7 @@ var _ = Describe("Connection", func() { }) It("doesn't cancel the HandshakeComplete context when the handshake fails", func() { - packer.EXPECT().PackCoalescedPacket().AnyTimes() + packer.EXPECT().PackCoalescedPacket(false).AnyTimes() streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) @@ -1924,7 +1924,7 @@ var _ = Describe("Connection", func() { conn.sentPacketHandler = sph done := make(chan struct{}) connRunner.EXPECT().Retire(clientDestConnID) - packer.EXPECT().PackPacket().DoAndReturn(func() (*packedPacket, error) { + packer.EXPECT().PackPacket(false).DoAndReturn(func(bool) (*packedPacket, error) { frames, _ := conn.framer.AppendControlFrames(nil, protocol.MaxByteCount) Expect(frames).ToNot(BeEmpty()) Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{})) @@ -1936,7 +1936,7 @@ var _ = Describe("Connection", func() { buffer: getPacketBuffer(), }, nil }) - packer.EXPECT().PackPacket().AnyTimes() + packer.EXPECT().PackPacket(false).AnyTimes() go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake() @@ -2014,7 +2014,7 @@ var _ = Describe("Connection", func() { } streamManager.EXPECT().UpdateLimits(params) packer.EXPECT().HandleTransportParameters(params) - packer.EXPECT().PackCoalescedPacket().MaxTimes(3) + packer.EXPECT().PackCoalescedPacket(false).MaxTimes(3) Expect(conn.earlyConnReady()).ToNot(BeClosed()) connRunner.EXPECT().GetStatelessResetToken(gomock.Any()).Times(2) connRunner.EXPECT().Add(gomock.Any(), conn).Times(2) @@ -2066,7 +2066,7 @@ var _ = Describe("Connection", func() { setRemoteIdleTimeout(5 * time.Second) conn.lastPacketReceivedTime = time.Now().Add(-5 * time.Second / 2) sent := make(chan struct{}) - packer.EXPECT().PackCoalescedPacket().Do(func() (*packedPacket, error) { + packer.EXPECT().PackCoalescedPacket(false).Do(func(bool) (*packedPacket, error) { close(sent) return nil, nil }) @@ -2079,7 +2079,7 @@ var _ = Describe("Connection", func() { setRemoteIdleTimeout(time.Hour) conn.lastPacketReceivedTime = time.Now().Add(-protocol.MaxKeepAliveInterval).Add(-time.Millisecond) sent := make(chan struct{}) - packer.EXPECT().PackCoalescedPacket().Do(func() (*packedPacket, error) { + packer.EXPECT().PackCoalescedPacket(false).Do(func(bool) (*packedPacket, error) { close(sent) return nil, nil }) @@ -2198,7 +2198,7 @@ var _ = Describe("Connection", func() { It("closes the connection due to the idle timeout before handshake", func() { conn.config.HandshakeIdleTimeout = 0 - packer.EXPECT().PackCoalescedPacket().AnyTimes() + packer.EXPECT().PackCoalescedPacket(false).AnyTimes() connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() cryptoSetup.EXPECT().Close() gomock.InOrder( @@ -2224,7 +2224,7 @@ var _ = Describe("Connection", func() { }) It("closes the connection due to the idle timeout after handshake", func() { - packer.EXPECT().PackCoalescedPacket().AnyTimes() + packer.EXPECT().PackCoalescedPacket(false).AnyTimes() gomock.InOrder( connRunner.EXPECT().Retire(clientDestConnID), connRunner.EXPECT().Remove(gomock.Any()), @@ -2743,7 +2743,7 @@ var _ = Describe("Client Connection", func() { }, } packer.EXPECT().HandleTransportParameters(gomock.Any()) - packer.EXPECT().PackCoalescedPacket().MaxTimes(1) + packer.EXPECT().PackCoalescedPacket(false).MaxTimes(1) tracer.EXPECT().ReceivedTransportParameters(params) conn.handleTransportParameters(params) conn.handleHandshakeComplete() diff --git a/mock_packer_test.go b/mock_packer_test.go index 54b7e482621..1e76e1b4b42 100644 --- a/mock_packer_test.go +++ b/mock_packer_test.go @@ -49,21 +49,6 @@ func (mr *MockPackerMockRecorder) HandleTransportParameters(arg0 interface{}) *g return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleTransportParameters", reflect.TypeOf((*MockPacker)(nil).HandleTransportParameters), arg0) } -// MaybePackAckPacket mocks base method. -func (m *MockPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MaybePackAckPacket", handshakeConfirmed) - ret0, _ := ret[0].(*packedPacket) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// MaybePackAckPacket indicates an expected call of MaybePackAckPacket. -func (mr *MockPackerMockRecorder) MaybePackAckPacket(handshakeConfirmed interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackAckPacket", reflect.TypeOf((*MockPacker)(nil).MaybePackAckPacket), handshakeConfirmed) -} - // MaybePackProbePacket mocks base method. func (m *MockPacker) MaybePackProbePacket(arg0 protocol.EncryptionLevel) (*packedPacket, error) { m.ctrl.T.Helper() @@ -95,18 +80,18 @@ func (mr *MockPackerMockRecorder) PackApplicationClose(arg0 interface{}) *gomock } // PackCoalescedPacket mocks base method. -func (m *MockPacker) PackCoalescedPacket() (*coalescedPacket, error) { +func (m *MockPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PackCoalescedPacket") + ret := m.ctrl.Call(m, "PackCoalescedPacket", onlyAck) ret0, _ := ret[0].(*coalescedPacket) ret1, _ := ret[1].(error) return ret0, ret1 } // PackCoalescedPacket indicates an expected call of PackCoalescedPacket. -func (mr *MockPackerMockRecorder) PackCoalescedPacket() *gomock.Call { +func (mr *MockPackerMockRecorder) PackCoalescedPacket(onlyAck interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackCoalescedPacket", reflect.TypeOf((*MockPacker)(nil).PackCoalescedPacket)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackCoalescedPacket", reflect.TypeOf((*MockPacker)(nil).PackCoalescedPacket), onlyAck) } // PackConnectionClose mocks base method. @@ -140,18 +125,18 @@ func (mr *MockPackerMockRecorder) PackMTUProbePacket(ping, size interface{}) *go } // PackPacket mocks base method. -func (m *MockPacker) PackPacket() (*packedPacket, error) { +func (m *MockPacker) PackPacket(onlyAck bool) (*packedPacket, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PackPacket") + ret := m.ctrl.Call(m, "PackPacket", onlyAck) ret0, _ := ret[0].(*packedPacket) ret1, _ := ret[1].(error) return ret0, ret1 } // PackPacket indicates an expected call of PackPacket. -func (mr *MockPackerMockRecorder) PackPacket() *gomock.Call { +func (mr *MockPackerMockRecorder) PackPacket(onlyAck interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackPacket", reflect.TypeOf((*MockPacker)(nil).PackPacket)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackPacket", reflect.TypeOf((*MockPacker)(nil).PackPacket), onlyAck) } // SetMaxPacketSize mocks base method. diff --git a/packet_packer.go b/packet_packer.go index 8e4aa73552d..8b1d4772351 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -16,10 +16,9 @@ import ( ) type packer interface { - PackCoalescedPacket() (*coalescedPacket, error) - PackPacket() (*packedPacket, error) + PackCoalescedPacket(onlyAck bool) (*coalescedPacket, error) + PackPacket(onlyAck bool) (*packedPacket, error) MaybePackProbePacket(protocol.EncryptionLevel) (*packedPacket, error) - MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) PackConnectionClose(*qerr.TransportError) (*coalescedPacket, error) PackApplicationClose(*qerr.ApplicationError) (*coalescedPacket, error) @@ -320,39 +319,6 @@ func (p *packetPacker) packetLength(hdr *wire.ExtendedHeader, payload *payload) return hdr.GetLength(p.version) + payload.length + paddingLen } -func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) { - var encLevel protocol.EncryptionLevel - var ack *wire.AckFrame - if !handshakeConfirmed { - ack = p.acks.GetAckFrame(protocol.EncryptionInitial, true) - if ack != nil { - encLevel = protocol.EncryptionInitial - } else { - ack = p.acks.GetAckFrame(protocol.EncryptionHandshake, true) - if ack != nil { - encLevel = protocol.EncryptionHandshake - } - } - } - if ack == nil { - ack = p.acks.GetAckFrame(protocol.Encryption1RTT, true) - if ack == nil { - return nil, nil - } - encLevel = protocol.Encryption1RTT - } - payload := &payload{ - ack: ack, - length: ack.Length(p.version), - } - - sealer, hdr, err := p.getSealerAndHeader(encLevel) - if err != nil { - return nil, err - } - return p.writeSinglePacket(hdr, payload, encLevel, sealer) -} - // size is the expected size of the packet, if no padding was applied. func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, size protocol.ByteCount) protocol.ByteCount { // For the server, only ack-eliciting Initial packets need to be padded. @@ -368,7 +334,7 @@ func (p *packetPacker) initialPaddingLen(frames []ackhandler.Frame, size protoco // PackCoalescedPacket packs a new packet. // It packs an Initial / Handshake if there is data to send in these packet number spaces. // It should only be called before the handshake is confirmed. -func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) { +func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, error) { maxPacketSize := p.maxPacketSize if p.perspective == protocol.PerspectiveClient { maxPacketSize = protocol.MinInitialPacketSize @@ -383,7 +349,7 @@ func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) { } var size protocol.ByteCount if initialSealer != nil { - initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), size, protocol.EncryptionInitial) + initialHdr, initialPayload = p.maybeGetCryptoPacket(maxPacketSize-protocol.ByteCount(initialSealer.Overhead()), protocol.EncryptionInitial, onlyAck, true) if initialPayload != nil { size += p.packetLength(initialHdr, initialPayload) + protocol.ByteCount(initialSealer.Overhead()) numPackets++ @@ -392,14 +358,14 @@ func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) { // Add a Handshake packet. var handshakeSealer sealer - if size < maxPacketSize-protocol.MinCoalescedPacketSize { + if (onlyAck && size == 0) || (!onlyAck && size < maxPacketSize-protocol.MinCoalescedPacketSize) { var err error handshakeSealer, err = p.cryptoSetup.GetHandshakeSealer() if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { return nil, err } if handshakeSealer != nil { - handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), size, protocol.EncryptionHandshake) + handshakeHdr, handshakePayload = p.maybeGetCryptoPacket(maxPacketSize-size-protocol.ByteCount(handshakeSealer.Overhead()), protocol.EncryptionHandshake, onlyAck, size == 0) if handshakePayload != nil { s := p.packetLength(handshakeHdr, handshakePayload) + protocol.ByteCount(handshakeSealer.Overhead()) size += s @@ -411,7 +377,7 @@ func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) { // Add a 0-RTT / 1-RTT packet. var appDataSealer sealer appDataEncLevel := protocol.Encryption1RTT - if size < maxPacketSize-protocol.MinCoalescedPacketSize { + if (onlyAck && size == 0) || (!onlyAck && size < maxPacketSize-protocol.MinCoalescedPacketSize) { var sErr error var oneRTTSealer handshake.ShortHeaderSealer oneRTTSealer, sErr = p.cryptoSetup.Get1RTTSealer() @@ -426,7 +392,7 @@ func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) { case protocol.Encryption0RTT: appDataHdr, appDataPayload = p.maybeGetAppDataPacketFor0RTT(appDataSealer, maxPacketSize-size) case protocol.Encryption1RTT: - appDataHdr, appDataPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, maxPacketSize-size, size) + appDataHdr, appDataPayload = p.maybeGetShortHeaderPacket(oneRTTSealer, maxPacketSize-size, onlyAck, size == 0) } if appDataHdr != nil && appDataPayload != nil { size += p.packetLength(appDataHdr, appDataPayload) + protocol.ByteCount(appDataSealer.Overhead()) @@ -471,12 +437,12 @@ func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) { // PackPacket packs a packet in the application data packet number space. // It should be called after the handshake is confirmed. -func (p *packetPacker) PackPacket() (*packedPacket, error) { +func (p *packetPacker) PackPacket(onlyAck bool) (*packedPacket, error) { sealer, err := p.cryptoSetup.Get1RTTSealer() if err != nil { return nil, err } - hdr, payload := p.maybeGetShortHeaderPacket(sealer, p.maxPacketSize, 0) + hdr, payload := p.maybeGetShortHeaderPacket(sealer, p.maxPacketSize, onlyAck, true) if payload == nil { return nil, nil } @@ -491,7 +457,17 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { }, nil } -func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize, currentSize protocol.ByteCount, encLevel protocol.EncryptionLevel) (*wire.ExtendedHeader, *payload) { +func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool) (*wire.ExtendedHeader, *payload) { + if onlyAck { + if ack := p.acks.GetAckFrame(encLevel, true); ack != nil { + var payload payload + payload.ack = ack + payload.length = ack.Length(p.version) + return p.getLongHeader(encLevel), &payload + } + return nil, nil + } + var s cryptoStream var hasRetransmission bool //nolint:exhaustive // Initial and Handshake are the only two encryption levels here. @@ -506,7 +482,7 @@ func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize, currentSize protocol. hasData := s.HasData() var ack *wire.AckFrame - if encLevel == protocol.EncryptionInitial || currentSize == 0 { + if ackAllowed { ack = p.acks.GetAckFrame(encLevel, !hasRetransmission && !hasData) } if !hasData && !hasRetransmission && ack == nil { @@ -555,19 +531,19 @@ func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize hdr := p.getLongHeader(protocol.Encryption0RTT) maxPayloadSize := maxPacketSize - hdr.GetLength(p.version) - protocol.ByteCount(sealer.Overhead()) - payload := p.maybeGetAppDataPacket(maxPayloadSize, false) + payload := p.maybeGetAppDataPacket(maxPayloadSize, false, false) return hdr, payload } -func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, maxPacketSize, currentSize protocol.ByteCount) (*wire.ExtendedHeader, *payload) { +func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, maxPacketSize protocol.ByteCount, onlyAck, ackAllowed bool) (*wire.ExtendedHeader, *payload) { hdr := p.getShortHeader(sealer.KeyPhase()) maxPayloadSize := maxPacketSize - hdr.GetLength(p.version) - protocol.ByteCount(sealer.Overhead()) - payload := p.maybeGetAppDataPacket(maxPayloadSize, currentSize == 0) + payload := p.maybeGetAppDataPacket(maxPayloadSize, onlyAck, ackAllowed) return hdr, payload } -func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, ackAllowed bool) *payload { - payload := p.composeNextPacket(maxPayloadSize, ackAllowed) +func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, onlyAck, ackAllowed bool) *payload { + payload := p.composeNextPacket(maxPayloadSize, onlyAck, ackAllowed) // check if we have anything to send if len(payload.frames) == 0 { @@ -590,7 +566,17 @@ func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, return payload } -func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, ackAllowed bool) *payload { +func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, onlyAck, ackAllowed bool) *payload { + if onlyAck { + if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, true); ack != nil { + payload := &payload{} + payload.ack = ack + payload.length += ack.Length(p.version) + return payload + } + return &payload{} + } + payload := &payload{frames: make([]ackhandler.Frame, 0, 1)} hasData := p.framer.HasData() @@ -663,14 +649,14 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) ( if err != nil { return nil, err } - hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.EncryptionInitial) + hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionInitial, false, true) case protocol.EncryptionHandshake: var err error sealer, err = p.cryptoSetup.GetHandshakeSealer() if err != nil { return nil, err } - hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), 0, protocol.EncryptionHandshake) + hdr, payload = p.maybeGetCryptoPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead()), protocol.EncryptionHandshake, false, true) case protocol.Encryption1RTT: oneRTTSealer, err := p.cryptoSetup.Get1RTTSealer() if err != nil { @@ -678,7 +664,7 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) ( } sealer = oneRTTSealer hdr = p.getShortHeader(oneRTTSealer.KeyPhase()) - payload = p.maybeGetAppDataPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead())-hdr.GetLength(p.version), true) + payload = p.maybeGetAppDataPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead())-hdr.GetLength(p.version), false, true) default: panic("unknown encryption level") } @@ -724,41 +710,6 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B }, nil } -func (p *packetPacker) getSealerAndHeader(encLevel protocol.EncryptionLevel) (sealer, *wire.ExtendedHeader, error) { - switch encLevel { - case protocol.EncryptionInitial: - sealer, err := p.cryptoSetup.GetInitialSealer() - if err != nil { - return nil, nil, err - } - hdr := p.getLongHeader(protocol.EncryptionInitial) - return sealer, hdr, nil - case protocol.Encryption0RTT: - sealer, err := p.cryptoSetup.Get0RTTSealer() - if err != nil { - return nil, nil, err - } - hdr := p.getLongHeader(protocol.Encryption0RTT) - return sealer, hdr, nil - case protocol.EncryptionHandshake: - sealer, err := p.cryptoSetup.GetHandshakeSealer() - if err != nil { - return nil, nil, err - } - hdr := p.getLongHeader(protocol.EncryptionHandshake) - return sealer, hdr, nil - case protocol.Encryption1RTT: - sealer, err := p.cryptoSetup.Get1RTTSealer() - if err != nil { - return nil, nil, err - } - hdr := p.getShortHeader(sealer.KeyPhase()) - return sealer, hdr, nil - default: - return nil, nil, fmt.Errorf("unexpected encryption level: %s", encLevel) - } -} - func (p *packetPacker) getShortHeader(kp protocol.KeyPhaseBit) *wire.ExtendedHeader { pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) hdr := &wire.ExtendedHeader{} @@ -793,28 +744,6 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.Ex return hdr } -// writeSinglePacket packs a single packet. -func (p *packetPacker) writeSinglePacket( - hdr *wire.ExtendedHeader, - payload *payload, - encLevel protocol.EncryptionLevel, - sealer sealer, -) (*packedPacket, error) { - buffer := getPacketBuffer() - var paddingLen protocol.ByteCount - if encLevel == protocol.EncryptionInitial { - paddingLen = p.initialPaddingLen(payload.frames, hdr.GetLength(p.version)+payload.length+protocol.ByteCount(sealer.Overhead())) - } - contents, err := p.appendPacket(buffer, hdr, payload, paddingLen, encLevel, sealer, false) - if err != nil { - return nil, err - } - return &packedPacket{ - buffer: buffer, - packetContents: contents, - }, nil -} - func (p *packetPacker) appendPacket(buffer *packetBuffer, header *wire.ExtendedHeader, payload *payload, padding protocol.ByteCount, encLevel protocol.EncryptionLevel, sealer sealer, isMTUProbePacket bool) (*packetContents, error) { var paddingLen protocol.ByteCount pnLen := protocol.ByteCount(header.PacketNumberLen) diff --git a/packet_packer_test.go b/packet_packer_test.go index 3d5d456b8f9..e1e9cb95125 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -191,7 +191,7 @@ var _ = Describe("Packet packer", func() { expectAppendControlFrames() f := &wire.StreamFrame{Data: []byte{0xde, 0xca, 0xfb, 0xad}} expectAppendStreamFrames(ackhandler.Frame{Frame: f}) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.packets).To(HaveLen(1)) @@ -220,10 +220,14 @@ var _ = Describe("Packet packer", func() { Context("packing ACK packets", func() { It("doesn't pack a packet if there's no ACK to send", func() { + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, true) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true) - p, err := packer.MaybePackAckPacket(false) + p, err := packer.PackCoalescedPacket(true) Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) }) @@ -235,11 +239,13 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true).Return(ack) - p, err := packer.MaybePackAckPacket(false) + p, err := packer.PackCoalescedPacket(true) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) - Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.ack).To(Equal(ack)) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.packets[0].ack).To(Equal(ack)) + Expect(p.packets[0].frames).To(BeEmpty()) Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) parsePacket(p.buffer.Data) }) @@ -250,25 +256,47 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true).Return(ack) - p, err := packer.MaybePackAckPacket(false) + p, err := packer.PackCoalescedPacket(true) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) - Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) - Expect(p.ack).To(Equal(ack)) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) + Expect(p.packets[0].ack).To(Equal(ack)) + Expect(p.packets[0].frames).To(BeEmpty()) + Expect(p.buffer.Len()).To(BeNumerically("<", 100)) + parsePacket(p.buffer.Data) + }) + + It("packs 1-RTT ACK-only packets, before handshake confirmation", func() { + sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, true) + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true).Return(ack) + p, err := packer.PackCoalescedPacket(true) + Expect(err).NotTo(HaveOccurred()) + Expect(p).ToNot(BeNil()) + Expect(p.packets).To(HaveLen(1)) + Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) + Expect(p.packets[0].ack).To(Equal(ack)) + Expect(p.packets[0].frames).To(BeEmpty()) parsePacket(p.buffer.Data) }) - It("packs 1-RTT ACK-only packets", func() { + It("packs 1-RTT ACK-only packets, after handshake confirmation", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true).Return(ack) - p, err := packer.MaybePackAckPacket(true) + p, err := packer.PackPacket(true) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) - Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) Expect(p.ack).To(Equal(ack)) + Expect(p.frames).To(BeEmpty()) parsePacket(p.buffer.Data) }) }) @@ -300,7 +328,7 @@ var _ = Describe("Packet packer", func() { framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(frames []ackhandler.Frame, _ protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { return frames, 0 }) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(false) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(1)) @@ -479,7 +507,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true) framer.EXPECT().HasData() - p, err := packer.PackPacket() + p, err := packer.PackPacket(false) Expect(p).To(BeNil()) Expect(err).ToNot(HaveOccurred()) }) @@ -496,7 +524,7 @@ var _ = Describe("Packet packer", func() { Data: []byte{0xde, 0xca, 0xfb, 0xad}, } expectAppendStreamFrames(ackhandler.Frame{Frame: f}) - p, err := packer.PackPacket() + p, err := packer.PackPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) b, err := f.Append(nil, packer.version) @@ -516,7 +544,7 @@ var _ = Describe("Packet packer", func() { StreamID: 5, Data: []byte("foobar"), }}) - p, err := packer.PackPacket() + p, err := packer.PackPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) }) @@ -528,7 +556,7 @@ var _ = Describe("Packet packer", func() { framer.EXPECT().HasData() ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true).Return(ack) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) - p, err := packer.PackPacket() + p, err := packer.PackPacket(false) Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.ack).To(Equal(ack)) @@ -546,7 +574,7 @@ var _ = Describe("Packet packer", func() { } expectAppendControlFrames(frames...) expectAppendStreamFrames() - p, err := packer.PackPacket() + p, err := packer.PackPacket(false) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(Equal(frames)) @@ -572,7 +600,7 @@ var _ = Describe("Packet packer", func() { time.Sleep(scaleDuration(20 * time.Millisecond)) framer.EXPECT().HasData() - p, err := packer.PackPacket() + p, err := packer.PackPacket(false) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) @@ -600,7 +628,7 @@ var _ = Describe("Packet packer", func() { time.Sleep(scaleDuration(20 * time.Millisecond)) framer.EXPECT().HasData() - p, err := packer.PackPacket() + p, err := packer.PackPacket(false) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.ack).ToNot(BeNil()) @@ -627,7 +655,7 @@ var _ = Describe("Packet packer", func() { return fs, 0 }), ) - _, err := packer.PackPacket() + _, err := packer.PackPacket(false) Expect(err).ToNot(HaveOccurred()) }) @@ -641,7 +669,7 @@ var _ = Describe("Packet packer", func() { packer.retransmissionQueue.AddHandshake(&wire.PingFrame{}) handshakeStream.EXPECT().HasData() ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) - packet, err := packer.PackCoalescedPacket() + packet, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(packet).ToNot(BeNil()) Expect(packet.packets).To(HaveLen(1)) @@ -684,7 +712,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.Frame{Frame: f}) - packet, err := packer.PackPacket() + packet, err := packer.PackPacket(false) Expect(err).ToNot(HaveOccurred()) // cut off the tag that the mock sealer added packet.buffer.Data = packet.buffer.Data[:packet.buffer.Len()-protocol.ByteCount(sealer.Overhead())] @@ -735,7 +763,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.Frame{Frame: f1}, ackhandler.Frame{Frame: f2}, ackhandler.Frame{Frame: f3}) - p, err := packer.PackPacket() + p, err := packer.PackPacket(false) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(3)) @@ -754,7 +782,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() - p, err := packer.PackPacket() + p, err := packer.PackPacket(false) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.ack).ToNot(BeNil()) @@ -771,7 +799,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() - p, err := packer.PackPacket() + p, err := packer.PackPacket(false) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) var hasPing bool @@ -790,7 +818,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) expectAppendControlFrames() expectAppendStreamFrames() - p, err = packer.PackPacket() + p, err = packer.PackPacket(false) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.ack).ToNot(BeNil()) @@ -806,7 +834,7 @@ var _ = Describe("Packet packer", func() { expectAppendControlFrames() expectAppendStreamFrames() ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) - p, err := packer.PackPacket() + p, err := packer.PackPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) // now add some frame to send @@ -818,7 +846,7 @@ var _ = Describe("Packet packer", func() { framer.EXPECT().HasData().Return(true) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false).Return(ack) - p, err = packer.PackPacket() + p, err = packer.PackPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p.ack).To(Equal(ack)) var hasPing bool @@ -840,7 +868,7 @@ var _ = Describe("Packet packer", func() { ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, false) expectAppendStreamFrames() expectAppendControlFrames(ackhandler.Frame{Frame: &wire.MaxDataFrame{}}) - p, err := packer.PackPacket() + p, err := packer.PackPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.frames).ToNot(ContainElement(&wire.PingFrame{})) @@ -859,7 +887,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err := packer.PackPacket() + _, err := packer.PackPacket(false) Expect(err).ToNot(HaveOccurred()) // now reduce the maxPacketSize packer.HandleTransportParameters(&wire.TransportParameters{ @@ -870,7 +898,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err = packer.PackPacket() + _, err = packer.PackPacket(false) Expect(err).ToNot(HaveOccurred()) }) @@ -885,7 +913,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err := packer.PackPacket() + _, err := packer.PackPacket(false) Expect(err).ToNot(HaveOccurred()) // now try to increase the maxPacketSize packer.HandleTransportParameters(&wire.TransportParameters{ @@ -896,7 +924,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err = packer.PackPacket() + _, err = packer.PackPacket(false) Expect(err).ToNot(HaveOccurred()) }) }) @@ -913,7 +941,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err := packer.PackPacket() + _, err := packer.PackPacket(false) Expect(err).ToNot(HaveOccurred()) // now reduce the maxPacketSize const packetSizeIncrease = 50 @@ -923,7 +951,7 @@ var _ = Describe("Packet packer", func() { return nil, 0 }) expectAppendStreamFrames() - _, err = packer.PackPacket() + _, err = packer.PackPacket(false) Expect(err).ToNot(HaveOccurred()) }) }) @@ -943,7 +971,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) parsePacket(p.buffer.Data) @@ -962,7 +990,7 @@ var _ = Describe("Packet packer", func() { initialStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { return &wire.CryptoFrame{Offset: 0x42, Data: []byte("initial")} }) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) @@ -991,7 +1019,7 @@ var _ = Describe("Packet packer", func() { Expect(f.Length(packer.version)).To(Equal(size)) return f }) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(1)) Expect(p.packets[0].frames).To(HaveLen(1)) @@ -1018,7 +1046,7 @@ var _ = Describe("Packet packer", func() { handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { return &wire.CryptoFrame{Offset: 0x1337, Data: []byte("handshake")} }) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) Expect(p.packets).To(HaveLen(2)) @@ -1050,7 +1078,7 @@ var _ = Describe("Packet packer", func() { }) handshakeStream.EXPECT().HasData() packer.retransmissionQueue.AddHandshake(&wire.PingFrame{}) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) Expect(p.packets).To(HaveLen(2)) @@ -1079,7 +1107,7 @@ var _ = Describe("Packet packer", func() { handshakeStream.EXPECT().HasData() packer.retransmissionQueue.AddInitial(&wire.PingFrame{}) packer.retransmissionQueue.AddHandshake(&wire.PingFrame{}) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) Expect(p.packets).To(HaveLen(2)) @@ -1112,7 +1140,7 @@ var _ = Describe("Packet packer", func() { expectAppendStreamFrames() framer.EXPECT().HasData().Return(true) packer.retransmissionQueue.AddAppData(&wire.PingFrame{}) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeEquivalentTo(packer.maxPacketSize)) Expect(p.packets).To(HaveLen(2)) @@ -1147,7 +1175,7 @@ var _ = Describe("Packet packer", func() { }) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.Frame{Frame: &wire.StreamFrame{Data: []byte("foobar")}}) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) @@ -1181,7 +1209,7 @@ var _ = Describe("Packet packer", func() { }) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.Frame{Frame: &wire.StreamFrame{Data: []byte("foobar")}}) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeNumerically("<", 100)) Expect(p.packets).To(HaveLen(2)) @@ -1215,7 +1243,7 @@ var _ = Describe("Packet packer", func() { Expect(f.Length(packer.version)).To(Equal(s)) return f }) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(1)) Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) @@ -1233,7 +1261,7 @@ var _ = Describe("Packet packer", func() { packer.retransmissionQueue.AddHandshake(&wire.PingFrame{}) handshakeStream.EXPECT().HasData() ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake, false) - packet, err := packer.PackCoalescedPacket() + packet, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(packet).ToNot(BeNil()) Expect(packet.packets).To(HaveLen(1)) @@ -1273,7 +1301,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, false) initialStream.EXPECT().HasData() - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(1)) Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) @@ -1290,7 +1318,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(1)) Expect(p.packets[0].ack).To(Equal(ack)) @@ -1302,7 +1330,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) initialStream.EXPECT().HasData() ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial, true) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) }) @@ -1318,7 +1346,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(1)) Expect(p.packets[0].ack).To(Equal(ack)) @@ -1341,7 +1369,7 @@ var _ = Describe("Packet packer", func() { initialStream.EXPECT().HasData().Return(true).Times(2) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) packer.perspective = protocol.PerspectiveClient - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeNumerically(">=", protocol.MinInitialPacketSize)) Expect(p.buffer.Len()).To(BeEquivalentTo(maxPacketSize)) @@ -1367,7 +1395,7 @@ var _ = Describe("Packet packer", func() { initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) packer.version = protocol.VersionTLS packer.perspective = protocol.PerspectiveClient - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(false) Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(1)) Expect(p.packets[0].ack).To(Equal(ack))