diff --git a/closed_conn.go b/closed_conn.go index b9585e77d48..a888d054c0d 100644 --- a/closed_conn.go +++ b/closed_conn.go @@ -2,7 +2,7 @@ package quic import ( "math/bits" - "sync" + "net" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" @@ -12,81 +12,37 @@ import ( // When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame, // with an exponential backoff. type closedLocalConn struct { - conn sendConn - connClosePacket []byte - - closeOnce sync.Once - closeChan chan struct{} // is closed when the connection is closed or destroyed - - receivedPackets chan *receivedPacket - counter uint64 // number of packets received - + counter uint32 perspective protocol.Perspective + logger utils.Logger - logger utils.Logger + sendPacket func(net.Addr, *packetInfo) } var _ packetHandler = &closedLocalConn{} // newClosedLocalConn creates a new closedLocalConn and runs it. -func newClosedLocalConn( - conn sendConn, - connClosePacket []byte, - perspective protocol.Perspective, - logger utils.Logger, -) packetHandler { - s := &closedLocalConn{ - conn: conn, - connClosePacket: connClosePacket, - perspective: perspective, - logger: logger, - closeChan: make(chan struct{}), - receivedPackets: make(chan *receivedPacket, 64), +func newClosedLocalConn(sendPacket func(net.Addr, *packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler { + return &closedLocalConn{ + sendPacket: sendPacket, + perspective: pers, + logger: logger, } - go s.run() - return s } -func (s *closedLocalConn) run() { - for { - select { - case p := <-s.receivedPackets: - s.handlePacketImpl(p) - case <-s.closeChan: - return - } - } -} - -func (s *closedLocalConn) handlePacket(p *receivedPacket) { - select { - case s.receivedPackets <- p: - default: - } -} - -func (s *closedLocalConn) handlePacketImpl(_ *receivedPacket) { - s.counter++ +func (c *closedLocalConn) handlePacket(p *receivedPacket) { + c.counter++ // exponential backoff // only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving - if bits.OnesCount64(s.counter) != 1 { + if bits.OnesCount32(c.counter) != 1 { return } - s.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", s.counter) - if err := s.conn.Write(s.connClosePacket); err != nil { - s.logger.Debugf("Error retransmitting CONNECTION_CLOSE: %s", err) - } + c.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", c.counter) + c.sendPacket(p.remoteAddr, p.info) } -func (s *closedLocalConn) shutdown() { - s.destroy(nil) -} - -func (s *closedLocalConn) destroy(error) { - s.closeOnce.Do(func() { - close(s.closeChan) - }) -} +func (s *closedLocalConn) shutdown() {} +func (s *closedLocalConn) destroy(error) {} func (s *closedLocalConn) getPerspective() protocol.Perspective { return s.perspective diff --git a/closed_conn_test.go b/closed_conn_test.go index e81b0050ed7..ea29e10eb92 100644 --- a/closed_conn_test.go +++ b/closed_conn_test.go @@ -1,10 +1,8 @@ package quic import ( - "errors" - "time" + "net" - "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" @@ -13,44 +11,28 @@ import ( ) var _ = Describe("Closed local connection", func() { - var ( - conn packetHandler - mconn *MockSendConn - ) - - BeforeEach(func() { - mconn = NewMockSendConn(mockCtrl) - conn = newClosedLocalConn(mconn, []byte("close"), protocol.PerspectiveClient, utils.DefaultLogger) - }) - - AfterEach(func() { - Eventually(areClosedConnsRunning).Should(BeFalse()) - }) - It("tells its perspective", func() { + conn := newClosedLocalConn(nil, protocol.PerspectiveClient, utils.DefaultLogger) Expect(conn.getPerspective()).To(Equal(protocol.PerspectiveClient)) // stop the connection conn.shutdown() }) It("repeats the packet containing the CONNECTION_CLOSE frame", func() { - written := make(chan []byte) - mconn.EXPECT().Write(gomock.Any()).Do(func(p []byte) { written <- p }).AnyTimes() + written := make(chan net.Addr, 1) + conn := newClosedLocalConn( + func(addr net.Addr, _ *packetInfo) { written <- addr }, + protocol.PerspectiveClient, + utils.DefaultLogger, + ) + addr := &net.UDPAddr{IP: net.IPv4(127, 1, 2, 3), Port: 1337} for i := 1; i <= 20; i++ { - conn.handlePacket(&receivedPacket{}) + conn.handlePacket(&receivedPacket{remoteAddr: addr}) if i == 1 || i == 2 || i == 4 || i == 8 || i == 16 { - Eventually(written).Should(Receive(Equal([]byte("close")))) // receive the CONNECTION_CLOSE + Expect(written).To(Receive(Equal(addr))) // receive the CONNECTION_CLOSE } else { - Consistently(written, 10*time.Millisecond).Should(HaveLen(0)) + Expect(written).ToNot(Receive()) } } - // stop the connection - conn.shutdown() - }) - - It("destroys connections", func() { - Eventually(areClosedConnsRunning).Should(BeTrue()) - conn.destroy(errors.New("destroy")) - Eventually(areClosedConnsRunning).Should(BeFalse()) }) }) diff --git a/conn_id_generator.go b/conn_id_generator.go index b07c7e48a77..570045e64b3 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -20,7 +20,7 @@ type connIDGenerator struct { getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken removeConnectionID func(protocol.ConnectionID) retireConnectionID func(protocol.ConnectionID) - replaceWithClosed func([]protocol.ConnectionID, packetHandler) + replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte) queueControlFrame func(wire.Frame) version protocol.VersionNumber @@ -33,7 +33,7 @@ func newConnIDGenerator( getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken, removeConnectionID func(protocol.ConnectionID), retireConnectionID func(protocol.ConnectionID), - replaceWithClosed func([]protocol.ConnectionID, packetHandler), + replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte), queueControlFrame func(wire.Frame), version protocol.VersionNumber, ) *connIDGenerator { @@ -130,7 +130,7 @@ func (m *connIDGenerator) RemoveAll() { } } -func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) { +func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose []byte) { connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1) if m.initialClientDestConnID != nil { connIDs = append(connIDs, m.initialClientDestConnID) @@ -138,5 +138,5 @@ func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) { for _, connID := range m.activeSrcConnIDs { connIDs = append(connIDs, connID) } - m.replaceWithClosed(connIDs, handler) + m.replaceWithClosed(connIDs, pers, connClose) } diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index 9c832fd492e..98f1eb7dcdf 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -16,7 +16,7 @@ var _ = Describe("Connection ID Generator", func() { addedConnIDs []protocol.ConnectionID retiredConnIDs []protocol.ConnectionID removedConnIDs []protocol.ConnectionID - replacedWithClosed map[string]packetHandler + replacedWithClosed []protocol.ConnectionID queuedFrames []wire.Frame g *connIDGenerator ) @@ -32,7 +32,7 @@ var _ = Describe("Connection ID Generator", func() { retiredConnIDs = nil removedConnIDs = nil queuedFrames = nil - replacedWithClosed = make(map[string]packetHandler) + replacedWithClosed = nil g = newConnIDGenerator( initialConnID, initialClientDestConnID, @@ -40,10 +40,8 @@ var _ = Describe("Connection ID Generator", func() { connIDToToken, func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) }, func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) }, - func(cs []protocol.ConnectionID, h packetHandler) { - for _, c := range cs { - replacedWithClosed[string(c)] = h - } + func(cs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) { + replacedWithClosed = append(replacedWithClosed, cs...) }, func(f wire.Frame) { queuedFrames = append(queuedFrames, f) }, protocol.VersionDraft29, @@ -178,14 +176,13 @@ var _ = Describe("Connection ID Generator", func() { It("replaces with a closed connection for all connection IDs", func() { Expect(g.SetMaxActiveConnIDs(5)).To(Succeed()) Expect(queuedFrames).To(HaveLen(4)) - sess := NewMockPacketHandler(mockCtrl) - g.ReplaceWithClosed(sess) + g.ReplaceWithClosed(protocol.PerspectiveClient, []byte("foobar")) Expect(replacedWithClosed).To(HaveLen(6)) // initial conn ID, initial client dest conn id, and newly issued ones - Expect(replacedWithClosed).To(HaveKeyWithValue(string(initialClientDestConnID), sess)) - Expect(replacedWithClosed).To(HaveKeyWithValue(string(initialConnID), sess)) + Expect(replacedWithClosed).To(ContainElement(initialClientDestConnID)) + Expect(replacedWithClosed).To(ContainElement(initialConnID)) for _, f := range queuedFrames { nf := f.(*wire.NewConnectionIDFrame) - Expect(replacedWithClosed).To(HaveKeyWithValue(string(nf.ConnectionID), sess)) + Expect(replacedWithClosed).To(ContainElement(nf.ConnectionID)) } }) }) diff --git a/connection.go b/connection.go index 8c6f50e7610..df3dc20b594 100644 --- a/connection.go +++ b/connection.go @@ -95,7 +95,7 @@ type connRunner interface { GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken Retire(protocol.ConnectionID) Remove(protocol.ConnectionID) - ReplaceWithClosed([]protocol.ConnectionID, packetHandler) + ReplaceWithClosed([]protocol.ConnectionID, protocol.Perspective, []byte) AddResetToken(protocol.StatelessResetToken, packetHandler) RemoveResetToken(protocol.StatelessResetToken) } @@ -1521,7 +1521,7 @@ func (s *connection) handleCloseError(closeErr *closeError) { // If this is a remote close we're done here if closeErr.remote { - s.connIDGenerator.ReplaceWithClosed(newClosedRemoteConn(s.perspective)) + s.connIDGenerator.ReplaceWithClosed(s.perspective, nil) return } if closeErr.immediate { @@ -1538,8 +1538,7 @@ func (s *connection) handleCloseError(closeErr *closeError) { if err != nil { s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err) } - cs := newClosedLocalConn(s.conn, connClosePacket, s.perspective, s.logger) - s.connIDGenerator.ReplaceWithClosed(cs) + s.connIDGenerator.ReplaceWithClosed(s.perspective, connClosePacket) } func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) { diff --git a/connection_test.go b/connection_test.go index d7535d5fc34..f7608def30b 100644 --- a/connection_test.go +++ b/connection_test.go @@ -37,12 +37,6 @@ func areConnsRunning() bool { return strings.Contains(b.String(), "quic-go.(*connection).run") } -func areClosedConnsRunning() bool { - var b bytes.Buffer - pprof.Lookup("goroutine").WriteTo(&b, 1) - return strings.Contains(b.String(), "quic-go.(*closedLocalConn).run") -} - var _ = Describe("Connection", func() { var ( conn *connection @@ -72,14 +66,11 @@ var _ = Describe("Connection", func() { } expectReplaceWithClosed := func() { - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) { + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) { Expect(connIDs).To(ContainElement(srcConnID)) if len(connIDs) > 1 { Expect(connIDs).To(ContainElement(clientDestConnID)) } - Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{})) - s.shutdown() - Eventually(areClosedConnsRunning).Should(BeFalse()) }) } @@ -333,9 +324,8 @@ var _ = Describe("Connection", func() { ErrorMessage: "foobar", } streamManager.EXPECT().CloseWithError(expectedErr) - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) { + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) { Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID)) - Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{})) }) cryptoSetup.EXPECT().Close() gomock.InOrder( @@ -362,9 +352,8 @@ var _ = Describe("Connection", func() { ErrorMessage: "foobar", } streamManager.EXPECT().CloseWithError(testErr) - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) { + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) { Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID)) - Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{})) }) cryptoSetup.EXPECT().Close() gomock.InOrder( @@ -564,7 +553,7 @@ var _ = Describe("Connection", func() { runConn() cryptoSetup.EXPECT().Close() streamManager.EXPECT().CloseWithError(gomock.Any()) - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes() + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() buf := &bytes.Buffer{} hdr := &wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: srcConnID}, @@ -2432,10 +2421,7 @@ var _ = Describe("Client Connection", func() { } expectReplaceWithClosed := func() { - connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any()).Do(func(_ []protocol.ConnectionID, s packetHandler) { - s.shutdown() - Eventually(areClosedConnsRunning).Should(BeFalse()) - }) + connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any(), gomock.Any()) } BeforeEach(func() { @@ -2766,10 +2752,7 @@ var _ = Describe("Client Connection", func() { expectClose := func(applicationClose bool) { if !closed { - connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ []protocol.ConnectionID, s packetHandler) { - Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{})) - s.shutdown() - }) + connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()) if applicationClose { packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1) } else { diff --git a/mock_conn_runner_test.go b/mock_conn_runner_test.go index 02080834faf..81f2699a598 100644 --- a/mock_conn_runner_test.go +++ b/mock_conn_runner_test.go @@ -99,15 +99,15 @@ func (mr *MockConnRunnerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock } // ReplaceWithClosed mocks base method. -func (m *MockConnRunner) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 packetHandler) { +func (m *MockConnRunner) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 protocol.Perspective, arg2 []byte) { m.ctrl.T.Helper() - m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1) + m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1, arg2) } // ReplaceWithClosed indicates an expected call of ReplaceWithClosed. -func (mr *MockConnRunnerMockRecorder) ReplaceWithClosed(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnRunnerMockRecorder) ReplaceWithClosed(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockConnRunner)(nil).ReplaceWithClosed), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockConnRunner)(nil).ReplaceWithClosed), arg0, arg1, arg2) } // Retire mocks base method. diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index 331695c4686..9bfd55a23cd 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -139,15 +139,15 @@ func (mr *MockPacketHandlerManagerMockRecorder) RemoveResetToken(arg0 interface{ } // ReplaceWithClosed mocks base method. -func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 packetHandler) { +func (m *MockPacketHandlerManager) ReplaceWithClosed(arg0 []protocol.ConnectionID, arg1 protocol.Perspective, arg2 []byte) { m.ctrl.T.Helper() - m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1) + m.ctrl.Call(m, "ReplaceWithClosed", arg0, arg1, arg2) } // ReplaceWithClosed indicates an expected call of ReplaceWithClosed. -func (mr *MockPacketHandlerManagerMockRecorder) ReplaceWithClosed(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockPacketHandlerManagerMockRecorder) ReplaceWithClosed(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockPacketHandlerManager)(nil).ReplaceWithClosed), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceWithClosed", reflect.TypeOf((*MockPacketHandlerManager)(nil).ReplaceWithClosed), arg0, arg1, arg2) } // Retire mocks base method. diff --git a/mock_quic_conn_test.go b/mock_quic_conn_test.go index 880f1dd1446..ffee54b60de 100644 --- a/mock_quic_conn_test.go +++ b/mock_quic_conn_test.go @@ -322,7 +322,7 @@ func (mr *MockQuicConnMockRecorder) handlePacket(arg0 interface{}) *gomock.Call // run mocks base method. func (m *MockQuicConn) run() error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "run") + ret := m.ctrl.Call(m, "runCloseQueue") ret0, _ := ret[0].(error) return ret0 } @@ -330,7 +330,7 @@ func (m *MockQuicConn) run() error { // run indicates an expected call of run. func (mr *MockQuicConnMockRecorder) run() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "run", reflect.TypeOf((*MockQuicConn)(nil).run)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "runCloseQueue", reflect.TypeOf((*MockQuicConn)(nil).run)) } // shutdown mocks base method. diff --git a/packet_handler_map.go b/packet_handler_map.go index 3b37bf27af0..0caa4907557 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -30,6 +30,12 @@ type rawConn interface { io.Closer } +type closePacket struct { + payload []byte + addr net.Addr + info *packetInfo +} + // The packetHandlerMap stores packetHandlers, identified by connection ID. // It is used: // * by the server to store connections @@ -40,6 +46,8 @@ type packetHandlerMap struct { conn rawConn connIDLen int + closeQueue chan closePacket + handlers map[string] /* string(ConnectionID)*/ packetHandler resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler server unknownPacketHandler @@ -123,12 +131,14 @@ func newPacketHandlerMap( resetTokens: make(map[protocol.StatelessResetToken]packetHandler), deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout, zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration, + closeQueue: make(chan closePacket, 4), statelessResetEnabled: len(statelessResetKey) > 0, statelessResetHasher: hmac.New(sha256.New, statelessResetKey), tracer: tracer, logger: logger, } go m.listen() + go m.runCloseQueue() if logger.Debug() { go m.logUsage() @@ -219,7 +229,29 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) { }) } -func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, handler packetHandler) { +// ReplaceWithClosed is called when a connection is closed. +// Depending on which side closed the connection, we need to: +// * remote close: absorb delayed packets +// * local close: retransmit the CONNECTION_CLOSE packet, in case it was lost +func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers protocol.Perspective, connClosePacket []byte) { + var handler packetHandler + if connClosePacket != nil { + handler = newClosedLocalConn( + func(addr net.Addr, info *packetInfo) { + select { + case h.closeQueue <- closePacket{payload: connClosePacket, addr: addr, info: info}: + default: + // Oops, we're backlogged. + // Just drop the packet, sending CONNECTION_CLOSE copies is best effort anyway. + } + }, + pers, + h.logger, + ) + } else { + handler = newClosedRemoteConn(pers) + } + h.mutex.Lock() for _, id := range ids { h.handlers[string(id)] = handler @@ -238,6 +270,17 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, handle }) } +func (h *packetHandlerMap) runCloseQueue() { + for { + select { + case <-h.listening: + return + case p := <-h.closeQueue: + h.conn.WritePacket(p.payload, p.addr, p.info.OOB()) + } + } +} + func (h *packetHandlerMap) AddResetToken(token protocol.StatelessResetToken, handler packetHandler) { h.mutex.Lock() h.resetTokens[token] = handler