From d5efd340c79e5a19ac69a2523527b56be4371715 Mon Sep 17 00:00:00 2001 From: Toby Date: Sun, 24 Jul 2022 12:50:41 -0700 Subject: [PATCH] optimize FirstOutstanding in the sent packet history (#3467) * optimize FirstOutstanding * fix variable naming * bug fix * minor code improvements * add a test to make sure that `Iterate` iterates in the right order * add comment --- internal/ackhandler/interfaces.go | 4 + internal/ackhandler/sent_packet_handler.go | 4 +- internal/ackhandler/sent_packet_history.go | 99 ++++++++++++++----- .../ackhandler/sent_packet_history_test.go | 18 ++-- 4 files changed, 89 insertions(+), 36 deletions(-) diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index 5777d97a7b6..226bfcbbcc5 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -23,6 +23,10 @@ type Packet struct { skippedPacket bool } +func (p *Packet) outstanding() bool { + return !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket +} + // SentPacketHandler handles ACKs received for outgoing packets type SentPacketHandler interface { // SentPacket may modify the packet diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index 7df91f23f7d..ff8cc8d2bdf 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -598,7 +598,7 @@ func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.E pnSpace.lossTime = lossTime } if packetLost { - p.declaredLost = true + p = pnSpace.history.DeclareLost(p) // the bytes in flight need to be reduced no matter if the frames in this packet will be retransmitted h.removeFromBytesInFlight(p) h.queueFramesForRetransmission(p) @@ -767,7 +767,7 @@ func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel) // TODO: don't declare the packet lost here. // Keep track of acknowledged frames instead. h.removeFromBytesInFlight(p) - p.declaredLost = true + pnSpace.history.DeclareLost(p) return true } diff --git a/internal/ackhandler/sent_packet_history.go b/internal/ackhandler/sent_packet_history.go index 36489367dcf..d5704dd2c75 100644 --- a/internal/ackhandler/sent_packet_history.go +++ b/internal/ackhandler/sent_packet_history.go @@ -9,18 +9,20 @@ import ( ) type sentPacketHistory struct { - rttStats *utils.RTTStats - packetList *PacketList - packetMap map[protocol.PacketNumber]*PacketElement - highestSent protocol.PacketNumber + rttStats *utils.RTTStats + outstandingPacketList *PacketList + etcPacketList *PacketList + packetMap map[protocol.PacketNumber]*PacketElement + highestSent protocol.PacketNumber } func newSentPacketHistory(rttStats *utils.RTTStats) *sentPacketHistory { return &sentPacketHistory{ - rttStats: rttStats, - packetList: NewPacketList(), - packetMap: make(map[protocol.PacketNumber]*PacketElement), - highestSent: protocol.InvalidPacketNumber, + rttStats: rttStats, + outstandingPacketList: NewPacketList(), + etcPacketList: NewPacketList(), + packetMap: make(map[protocol.PacketNumber]*PacketElement), + highestSent: protocol.InvalidPacketNumber, } } @@ -30,7 +32,7 @@ func (h *sentPacketHistory) SentPacket(p *Packet, isAckEliciting bool) { } // Skipped packet numbers. for pn := h.highestSent + 1; pn < p.PacketNumber; pn++ { - el := h.packetList.PushBack(Packet{ + el := h.etcPacketList.PushBack(Packet{ PacketNumber: pn, EncryptionLevel: p.EncryptionLevel, SendTime: p.SendTime, @@ -41,7 +43,12 @@ func (h *sentPacketHistory) SentPacket(p *Packet, isAckEliciting bool) { h.highestSent = p.PacketNumber if isAckEliciting { - el := h.packetList.PushBack(*p) + var el *PacketElement + if p.outstanding() { + el = h.outstandingPacketList.PushBack(*p) + } else { + el = h.etcPacketList.PushBack(*p) + } h.packetMap[p.PacketNumber] = el } } @@ -49,10 +56,25 @@ func (h *sentPacketHistory) SentPacket(p *Packet, isAckEliciting bool) { // Iterate iterates through all packets. func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) error { cont := true - var next *PacketElement - for el := h.packetList.Front(); cont && el != nil; el = next { + outstandingEl := h.outstandingPacketList.Front() + etcEl := h.etcPacketList.Front() + var el *PacketElement + // whichever has the next packet number is returned first + for cont { + if outstandingEl == nil || (etcEl != nil && etcEl.Value.PacketNumber < outstandingEl.Value.PacketNumber) { + el = etcEl + } else { + el = outstandingEl + } + if el == nil { + return nil + } + if el == outstandingEl { + outstandingEl = outstandingEl.Next() + } else { + etcEl = etcEl.Next() + } var err error - next = el.Next() cont, err = cb(&el.Value) if err != nil { return err @@ -61,15 +83,13 @@ func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) err return nil } -// FirstOutStanding returns the first outstanding packet. +// FirstOutstanding returns the first outstanding packet. func (h *sentPacketHistory) FirstOutstanding() *Packet { - for el := h.packetList.Front(); el != nil; el = el.Next() { - p := &el.Value - if !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket { - return p - } + el := h.outstandingPacketList.Front() + if el == nil { + return nil } - return nil + return &el.Value } func (h *sentPacketHistory) Len() int { @@ -81,28 +101,53 @@ func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error { if !ok { return fmt.Errorf("packet %d not found in sent packet history", p) } - h.packetList.Remove(el) + h.outstandingPacketList.Remove(el) + h.etcPacketList.Remove(el) delete(h.packetMap, p) return nil } func (h *sentPacketHistory) HasOutstandingPackets() bool { - return h.FirstOutstanding() != nil + return h.outstandingPacketList.Len() > 0 } func (h *sentPacketHistory) DeleteOldPackets(now time.Time) { maxAge := 3 * h.rttStats.PTO(false) var nextEl *PacketElement - for el := h.packetList.Front(); el != nil; el = nextEl { + // we don't iterate outstandingPacketList, as we should not delete outstanding packets. + // being outstanding for more than 3*PTO should only happen in the case of drastic RTT changes. + for el := h.etcPacketList.Front(); el != nil; el = nextEl { nextEl = el.Next() p := el.Value if p.SendTime.After(now.Add(-maxAge)) { break } - if !p.skippedPacket && !p.declaredLost { // should only happen in the case of drastic RTT changes - continue - } delete(h.packetMap, p.PacketNumber) - h.packetList.Remove(el) + h.etcPacketList.Remove(el) + } +} + +func (h *sentPacketHistory) DeclareLost(p *Packet) *Packet { + el, ok := h.packetMap[p.PacketNumber] + if !ok { + return nil + } + // try to remove it from both lists, as we don't know which one it currently belongs to. + // Remove is a no-op for elements that are not in the list. + h.outstandingPacketList.Remove(el) + h.etcPacketList.Remove(el) + p.declaredLost = true + // move it to the correct position in the etc list (based on the packet number) + for el = h.etcPacketList.Back(); el != nil; el = el.Prev() { + if el.Value.PacketNumber < p.PacketNumber { + break + } + } + if el == nil { + el = h.etcPacketList.PushFront(*p) + } else { + el = h.etcPacketList.InsertAfter(*p, el) } + h.packetMap[p.PacketNumber] = el + return &el.Value } diff --git a/internal/ackhandler/sent_packet_history_test.go b/internal/ackhandler/sent_packet_history_test.go index bf876d3214f..e539e4e4426 100644 --- a/internal/ackhandler/sent_packet_history_test.go +++ b/internal/ackhandler/sent_packet_history_test.go @@ -25,11 +25,12 @@ var _ = Describe("SentPacketHistory", func() { } } var listLen int - for el := hist.packetList.Front(); el != nil; el = el.Next() { - if !el.Value.skippedPacket { + hist.Iterate(func(p *Packet) (bool, error) { + if !p.skippedPacket { listLen++ } - } + return true, nil + }) ExpectWithOffset(1, mapLen).To(Equal(len(packetNumbers))) ExpectWithOffset(1, listLen).To(Equal(len(packetNumbers))) i := 0 @@ -63,9 +64,10 @@ var _ = Describe("SentPacketHistory", func() { hist.SentPacket(&Packet{PacketNumber: 3}, false) hist.SentPacket(&Packet{PacketNumber: 4}, true) expectInHistory([]protocol.PacketNumber{1, 4}) - for el := hist.packetList.Front(); el != nil; el = el.Next() { - Expect(el.Value.PacketNumber).ToNot(Equal(protocol.PacketNumber(3))) - } + hist.Iterate(func(p *Packet) (bool, error) { + Expect(p.PacketNumber).ToNot(Equal(protocol.PacketNumber(3))) + return true, nil + }) }) It("gets the length", func() { @@ -132,17 +134,19 @@ var _ = Describe("SentPacketHistory", func() { }) It("also iterates over skipped packets", func() { - var packets, skippedPackets []protocol.PacketNumber + var packets, skippedPackets, allPackets []protocol.PacketNumber Expect(hist.Iterate(func(p *Packet) (bool, error) { if p.skippedPacket { skippedPackets = append(skippedPackets, p.PacketNumber) } else { packets = append(packets, p.PacketNumber) } + allPackets = append(allPackets, p.PacketNumber) return true, nil })).To(Succeed()) Expect(packets).To(Equal([]protocol.PacketNumber{1, 4, 8})) Expect(skippedPackets).To(Equal([]protocol.PacketNumber{0, 2, 3, 5, 6, 7})) + Expect(allPackets).To(Equal([]protocol.PacketNumber{0, 1, 2, 3, 4, 5, 6, 7, 8})) }) It("stops iterating", func() {