diff --git a/connection.go b/connection.go index ce19fbb48c2..c192850db6e 100644 --- a/connection.go +++ b/connection.go @@ -542,7 +542,7 @@ func (s *connection) preSetup() { s.creationTime = now s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.connFlowController, s.framer.QueueControlFrame) - s.datagramQueue = newDatagramQueue(s.scheduleSending, s.logger) + s.datagramQueue = newDatagramQueue(s.scheduleSending, s.logger, s.version) } // run the connection main loop diff --git a/datagram_queue.go b/datagram_queue.go index b1cbbf6dcc3..2f62d8c188c 100644 --- a/datagram_queue.go +++ b/datagram_queue.go @@ -1,14 +1,20 @@ package quic import ( + "sync" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" ) type datagramQueue struct { - sendQueue chan *wire.DatagramFrame - rcvQueue chan []byte + mx sync.Mutex + nextToSend *wire.DatagramFrame + + sending chan struct{} // semaphore + + rcvQueue chan []byte closeErr error closed chan struct{} @@ -17,17 +23,19 @@ type datagramQueue struct { dequeued chan struct{} - logger utils.Logger + logger utils.Logger + version protocol.VersionNumber } -func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue { +func newDatagramQueue(hasData func(), logger utils.Logger, v protocol.VersionNumber) *datagramQueue { return &datagramQueue{ - hasData: hasData, - sendQueue: make(chan *wire.DatagramFrame, 1), - rcvQueue: make(chan []byte, protocol.DatagramRcvQueueLen), - dequeued: make(chan struct{}), - closed: make(chan struct{}), - logger: logger, + hasData: hasData, + sending: make(chan struct{}, 1), + rcvQueue: make(chan []byte, protocol.DatagramRcvQueueLen), + dequeued: make(chan struct{}), + closed: make(chan struct{}), + logger: logger, + version: v, } } @@ -35,7 +43,10 @@ func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue { // It blocks until the frame has been dequeued. func (h *datagramQueue) AddAndWait(f *wire.DatagramFrame) error { select { - case h.sendQueue <- f: + case h.sending <- struct{}{}: + h.mx.Lock() + h.nextToSend = f + h.mx.Unlock() h.hasData() case <-h.closed: return h.closeErr @@ -52,7 +63,8 @@ func (h *datagramQueue) AddAndWait(f *wire.DatagramFrame) error { // Get dequeues a DATAGRAM frame for sending. func (h *datagramQueue) Get() *wire.DatagramFrame { select { - case f := <-h.sendQueue: + case <-h.sending: + f := h.nextToSend h.dequeued <- struct{}{} return f default: @@ -60,6 +72,16 @@ func (h *datagramQueue) Get() *wire.DatagramFrame { } } +func (h *datagramQueue) NextFrameSize() protocol.ByteCount { + h.mx.Lock() + defer h.mx.Unlock() + + if h.nextToSend == nil { + return 0 + } + return h.nextToSend.Length(h.version) +} + // HandleDatagramFrame handles a received DATAGRAM frame. func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) { data := make([]byte, len(f.Data)) diff --git a/datagram_queue_test.go b/datagram_queue_test.go index 0ff7b96efa5..347b17b98e5 100644 --- a/datagram_queue_test.go +++ b/datagram_queue_test.go @@ -3,6 +3,7 @@ package quic import ( "errors" + "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" @@ -18,7 +19,7 @@ var _ = Describe("Datagram Queue", func() { queued = make(chan struct{}, 100) queue = newDatagramQueue(func() { queued <- struct{}{} - }, utils.DefaultLogger) + }, utils.DefaultLogger, protocol.Version1) }) Context("sending", func() { diff --git a/packet_packer.go b/packet_packer.go index 222d07f4b9b..f167031699e 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -555,18 +555,18 @@ func (p *packetPacker) maybeGetAppDataPacketFor0RTT(sealer sealer, maxPacketSize hdr := p.getLongHeader(protocol.Encryption0RTT) maxPayloadSize := maxPacketSize - hdr.GetLength(p.version) - protocol.ByteCount(sealer.Overhead()) - payload := p.maybeGetAppDataPacketWithEncLevel(maxPayloadSize, false) + payload := p.maybeGetAppDataPacket(maxPayloadSize, false) return hdr, payload } func (p *packetPacker) maybeGetShortHeaderPacket(sealer handshake.ShortHeaderSealer, maxPacketSize, currentSize protocol.ByteCount) (*wire.ExtendedHeader, *payload) { hdr := p.getShortHeader(sealer.KeyPhase()) maxPayloadSize := maxPacketSize - hdr.GetLength(p.version) - protocol.ByteCount(sealer.Overhead()) - payload := p.maybeGetAppDataPacketWithEncLevel(maxPayloadSize, currentSize == 0) + payload := p.maybeGetAppDataPacket(maxPayloadSize, currentSize == 0) return hdr, payload } -func (p *packetPacker) maybeGetAppDataPacketWithEncLevel(maxPayloadSize protocol.ByteCount, ackAllowed bool) *payload { +func (p *packetPacker) maybeGetAppDataPacket(maxPayloadSize protocol.ByteCount, ackAllowed bool) *payload { payload := p.composeNextPacket(maxPayloadSize, ackAllowed) // check if we have anything to send @@ -593,32 +593,35 @@ func (p *packetPacker) maybeGetAppDataPacketWithEncLevel(maxPayloadSize protocol func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount, ackAllowed bool) *payload { payload := &payload{frames: make([]ackhandler.Frame, 0, 1)} - var hasDatagram bool + hasData := p.framer.HasData() + hasRetransmission := p.retransmissionQueue.HasAppData() + + var hasAck bool + if ackAllowed { + if ack := p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData); ack != nil { + payload.ack = ack + payload.length += ack.Length(p.version) + hasAck = true + } + } + if p.datagramQueue != nil { - if datagram := p.datagramQueue.Get(); datagram != nil { + size := p.datagramQueue.NextFrameSize() + if size > 0 && size <= maxFrameSize-payload.length { + datagram := p.datagramQueue.Get() + if datagram.Length(p.version) != size { + panic("packet packer BUG: inconsistent DATAGRAM frame length") + } payload.frames = append(payload.frames, ackhandler.Frame{ Frame: datagram, // set it to a no-op. Then we won't set the default callback, which would retransmit the frame. OnLost: func(wire.Frame) {}, }) payload.length += datagram.Length(p.version) - hasDatagram = true - } - } - - var ack *wire.AckFrame - hasData := p.framer.HasData() - hasRetransmission := p.retransmissionQueue.HasAppData() - // TODO: make sure ACKs are sent when a lot of DATAGRAMs are queued - if !hasDatagram && ackAllowed { - ack = p.acks.GetAckFrame(protocol.Encryption1RTT, !hasRetransmission && !hasData) - if ack != nil { - payload.ack = ack - payload.length += ack.Length(p.version) } } - if ack == nil && !hasData && !hasRetransmission { + if hasAck && !hasData && !hasRetransmission { return payload } @@ -675,7 +678,7 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) ( } sealer = oneRTTSealer hdr = p.getShortHeader(oneRTTSealer.KeyPhase()) - payload = p.maybeGetAppDataPacketWithEncLevel(p.maxPacketSize-protocol.ByteCount(sealer.Overhead())-hdr.GetLength(p.version), true) + payload = p.maybeGetAppDataPacket(p.maxPacketSize-protocol.ByteCount(sealer.Overhead())-hdr.GetLength(p.version), true) default: panic("unknown encryption level") } diff --git a/packet_packer_test.go b/packet_packer_test.go index f1e51e1b35d..3d5d456b8f9 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -91,7 +91,7 @@ var _ = Describe("Packet packer", func() { ackFramer = NewMockAckFrameSource(mockCtrl) sealingManager = NewMockSealingManager(mockCtrl) pnManager = mockackhandler.NewMockSentPacketHandler(mockCtrl) - datagramQueue = newDatagramQueue(func() {}, utils.DefaultLogger) + datagramQueue = newDatagramQueue(func() {}, utils.DefaultLogger, version) packer = newPacketPacker( protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), @@ -554,6 +554,7 @@ var _ = Describe("Packet packer", func() { }) It("packs DATAGRAM frames", func() { + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true) pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) @@ -580,6 +581,36 @@ var _ = Describe("Packet packer", func() { Eventually(done).Should(BeClosed()) }) + It("doesn't pack a DATAGRAM frame if the ACK frame is too large", func() { + ackFramer.EXPECT().GetAckFrame(protocol.Encryption1RTT, true).Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 100}}}) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)) + sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil) + f := &wire.DatagramFrame{ + DataLenPresent: true, + Data: make([]byte, maxPacketSize-10), + } + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + datagramQueue.AddAndWait(f) + }() + // make sure the DATAGRAM has actually been queued + time.Sleep(scaleDuration(20 * time.Millisecond)) + + framer.EXPECT().HasData() + p, err := packer.PackPacket() + Expect(p).ToNot(BeNil()) + Expect(err).ToNot(HaveOccurred()) + Expect(p.ack).ToNot(BeNil()) + Expect(p.frames).To(BeEmpty()) + Expect(p.buffer.Data).ToNot(BeEmpty()) + Expect(done).ToNot(BeClosed()) + datagramQueue.CloseWithError(nil) + Eventually(done).Should(BeClosed()) + }) + It("accounts for the space consumed by control frames", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) sealingManager.EXPECT().Get1RTTSealer().Return(getSealer(), nil)