diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 9dbba574635..12ffb918aa7 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -23,6 +23,8 @@ const ( amplificationFactor = 3 // We use Retry packets to derive an RTT estimate. Make sure we don't set the RTT to a super low value yet. minRTTAfterRetry = 5 * time.Millisecond + // The PTO duration uses exponential backoff, but is truncated to a maximum value, as allowed by RFC 8961, section 4.4. + maxPTODuration = 60 * time.Second ) type packetNumberSpace struct { @@ -457,6 +459,14 @@ func (h *sentPacketHandler) getLossTimeAndSpace() (time.Time, protocol.Encryptio return lossTime, encLevel } +func (h *sentPacketHandler) getScaledPTO(includeMaxAckDelay bool) time.Duration { + pto := h.rttStats.PTO(includeMaxAckDelay) << h.ptoCount + if pto > maxPTODuration || pto <= 0 { + return maxPTODuration + } + return pto +} + // same logic as getLossTimeAndSpace, but for lastAckElicitingPacketTime instead of lossTime func (h *sentPacketHandler) getPTOTimeAndSpace() (pto time.Time, encLevel protocol.EncryptionLevel, ok bool) { // We only send application data probe packets once the handshake is confirmed, @@ -465,7 +475,7 @@ func (h *sentPacketHandler) getPTOTimeAndSpace() (pto time.Time, encLevel protoc if h.peerCompletedAddressValidation { return } - t := time.Now().Add(h.rttStats.PTO(false) << h.ptoCount) + t := time.Now().Add(h.getScaledPTO(false)) if h.initialPackets != nil { return t, protocol.EncryptionInitial, true } @@ -475,18 +485,18 @@ func (h *sentPacketHandler) getPTOTimeAndSpace() (pto time.Time, encLevel protoc if h.initialPackets != nil { encLevel = protocol.EncryptionInitial if t := h.initialPackets.lastAckElicitingPacketTime; !t.IsZero() { - pto = t.Add(h.rttStats.PTO(false) << h.ptoCount) + pto = t.Add(h.getScaledPTO(false)) } } if h.handshakePackets != nil && !h.handshakePackets.lastAckElicitingPacketTime.IsZero() { - t := h.handshakePackets.lastAckElicitingPacketTime.Add(h.rttStats.PTO(false) << h.ptoCount) + t := h.handshakePackets.lastAckElicitingPacketTime.Add(h.getScaledPTO(false)) if pto.IsZero() || (!t.IsZero() && t.Before(pto)) { pto = t encLevel = protocol.EncryptionHandshake } } if h.handshakeConfirmed && !h.appDataPackets.lastAckElicitingPacketTime.IsZero() { - t := h.appDataPackets.lastAckElicitingPacketTime.Add(h.rttStats.PTO(true) << h.ptoCount) + t := h.appDataPackets.lastAckElicitingPacketTime.Add(h.getScaledPTO(true)) if pto.IsZero() || (!t.IsZero() && t.Before(pto)) { pto = t encLevel = protocol.Encryption1RTT diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index c39b50f48df..2663a69bbe9 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -687,6 +687,14 @@ var _ = Describe("SentPacketHandler", func() { handler.ptoCount = 2 handler.setLossDetectionTimer() Expect(handler.GetLossDetectionTimeout().Sub(sendTime)).To(Equal(4 * timeout)) + // truncated when the exponential gets too large + handler.ptoCount = 20 + handler.setLossDetectionTimer() + Expect(handler.GetLossDetectionTimeout().Sub(sendTime)).To(Equal(maxPTODuration)) + // protected from rollover + handler.ptoCount = 100 + handler.setLossDetectionTimer() + Expect(handler.GetLossDetectionTimeout().Sub(sendTime)).To(Equal(maxPTODuration)) }) It("reset the PTO count when receiving an ACK", func() { @@ -1036,7 +1044,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("correctly sets the timer after the Initial packet number space has been dropped", func() { - handler.SentPacket(initialPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-42 * time.Second)})) + handler.SentPacket(initialPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-19 * time.Second)})) _, err := handler.ReceivedAck( &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}, protocol.EncryptionInitial, @@ -1048,6 +1056,8 @@ var _ = Describe("SentPacketHandler", func() { pto := handler.rttStats.PTO(false) Expect(pto).ToNot(BeZero()) + // pto is approximately 19 * 3. Using a number > 19 above will + // run into the maxPTODuration limit Expect(handler.GetLossDetectionTimeout()).To(BeTemporally("~", time.Now().Add(pto), 10*time.Millisecond)) })