Skip to content

Commit

Permalink
use a single Go routine to send copies of CONNECTION_CLOSE packets
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Aug 21, 2022
1 parent c3ab9c4 commit 5d947bf
Show file tree
Hide file tree
Showing 10 changed files with 103 additions and 143 deletions.
76 changes: 16 additions & 60 deletions closed_conn.go
Expand Up @@ -2,7 +2,7 @@ package quic

import (
"math/bits"
"sync"
"net"

"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
Expand All @@ -12,81 +12,37 @@ import (
// When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame,
// with an exponential backoff.
type closedLocalConn struct {
conn sendConn
connClosePacket []byte

closeOnce sync.Once
closeChan chan struct{} // is closed when the connection is closed or destroyed

receivedPackets chan *receivedPacket
counter uint64 // number of packets received

counter uint32
perspective protocol.Perspective
logger utils.Logger

logger utils.Logger
sendPacket func(net.Addr, *packetInfo)
}

var _ packetHandler = &closedLocalConn{}

// newClosedLocalConn creates a new closedLocalConn and runs it.
func newClosedLocalConn(
conn sendConn,
connClosePacket []byte,
perspective protocol.Perspective,
logger utils.Logger,
) packetHandler {
s := &closedLocalConn{
conn: conn,
connClosePacket: connClosePacket,
perspective: perspective,
logger: logger,
closeChan: make(chan struct{}),
receivedPackets: make(chan *receivedPacket, 64),
func newClosedLocalConn(sendPacket func(net.Addr, *packetInfo), pers protocol.Perspective, logger utils.Logger) packetHandler {
return &closedLocalConn{
sendPacket: sendPacket,
perspective: pers,
logger: logger,
}
go s.run()
return s
}

func (s *closedLocalConn) run() {
for {
select {
case p := <-s.receivedPackets:
s.handlePacketImpl(p)
case <-s.closeChan:
return
}
}
}

func (s *closedLocalConn) handlePacket(p *receivedPacket) {
select {
case s.receivedPackets <- p:
default:
}
}

func (s *closedLocalConn) handlePacketImpl(_ *receivedPacket) {
s.counter++
func (c *closedLocalConn) handlePacket(p *receivedPacket) {
c.counter++
// exponential backoff
// only send a CONNECTION_CLOSE for the 1st, 2nd, 4th, 8th, 16th, ... packet arriving
if bits.OnesCount64(s.counter) != 1 {
if bits.OnesCount32(c.counter) != 1 {
return
}
s.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", s.counter)
if err := s.conn.Write(s.connClosePacket); err != nil {
s.logger.Debugf("Error retransmitting CONNECTION_CLOSE: %s", err)
}
c.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", c.counter)
c.sendPacket(p.remoteAddr, p.info)
}

func (s *closedLocalConn) shutdown() {
s.destroy(nil)
}

func (s *closedLocalConn) destroy(error) {
s.closeOnce.Do(func() {
close(s.closeChan)
})
}
func (s *closedLocalConn) shutdown() {}
func (s *closedLocalConn) destroy(error) {}

func (s *closedLocalConn) getPerspective() protocol.Perspective {
return s.perspective
Expand Down
42 changes: 12 additions & 30 deletions closed_conn_test.go
@@ -1,10 +1,8 @@
package quic

import (
"errors"
"time"
"net"

"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"

Expand All @@ -13,44 +11,28 @@ import (
)

var _ = Describe("Closed local connection", func() {
var (
conn packetHandler
mconn *MockSendConn
)

BeforeEach(func() {
mconn = NewMockSendConn(mockCtrl)
conn = newClosedLocalConn(mconn, []byte("close"), protocol.PerspectiveClient, utils.DefaultLogger)
})

AfterEach(func() {
Eventually(areClosedConnsRunning).Should(BeFalse())
})

It("tells its perspective", func() {
conn := newClosedLocalConn(nil, protocol.PerspectiveClient, utils.DefaultLogger)
Expect(conn.getPerspective()).To(Equal(protocol.PerspectiveClient))
// stop the connection
conn.shutdown()
})

It("repeats the packet containing the CONNECTION_CLOSE frame", func() {
written := make(chan []byte)
mconn.EXPECT().Write(gomock.Any()).Do(func(p []byte) { written <- p }).AnyTimes()
written := make(chan net.Addr, 1)
conn := newClosedLocalConn(
func(addr net.Addr, _ *packetInfo) { written <- addr },
protocol.PerspectiveClient,
utils.DefaultLogger,
)
addr := &net.UDPAddr{IP: net.IPv4(127, 1, 2, 3), Port: 1337}
for i := 1; i <= 20; i++ {
conn.handlePacket(&receivedPacket{})
conn.handlePacket(&receivedPacket{remoteAddr: addr})
if i == 1 || i == 2 || i == 4 || i == 8 || i == 16 {
Eventually(written).Should(Receive(Equal([]byte("close")))) // receive the CONNECTION_CLOSE
Expect(written).To(Receive(Equal(addr))) // receive the CONNECTION_CLOSE
} else {
Consistently(written, 10*time.Millisecond).Should(HaveLen(0))
Expect(written).ToNot(Receive())
}
}
// stop the connection
conn.shutdown()
})

It("destroys connections", func() {
Eventually(areClosedConnsRunning).Should(BeTrue())
conn.destroy(errors.New("destroy"))
Eventually(areClosedConnsRunning).Should(BeFalse())
})
})
8 changes: 4 additions & 4 deletions conn_id_generator.go
Expand Up @@ -20,7 +20,7 @@ type connIDGenerator struct {
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
removeConnectionID func(protocol.ConnectionID)
retireConnectionID func(protocol.ConnectionID)
replaceWithClosed func([]protocol.ConnectionID, packetHandler)
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte)
queueControlFrame func(wire.Frame)

version protocol.VersionNumber
Expand All @@ -33,7 +33,7 @@ func newConnIDGenerator(
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
removeConnectionID func(protocol.ConnectionID),
retireConnectionID func(protocol.ConnectionID),
replaceWithClosed func([]protocol.ConnectionID, packetHandler),
replaceWithClosed func([]protocol.ConnectionID, protocol.Perspective, []byte),
queueControlFrame func(wire.Frame),
version protocol.VersionNumber,
) *connIDGenerator {
Expand Down Expand Up @@ -130,13 +130,13 @@ func (m *connIDGenerator) RemoveAll() {
}
}

func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) {
func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose []byte) {
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
if m.initialClientDestConnID != nil {
connIDs = append(connIDs, m.initialClientDestConnID)
}
for _, connID := range m.activeSrcConnIDs {
connIDs = append(connIDs, connID)
}
m.replaceWithClosed(connIDs, handler)
m.replaceWithClosed(connIDs, pers, connClose)
}
19 changes: 8 additions & 11 deletions conn_id_generator_test.go
Expand Up @@ -16,7 +16,7 @@ var _ = Describe("Connection ID Generator", func() {
addedConnIDs []protocol.ConnectionID
retiredConnIDs []protocol.ConnectionID
removedConnIDs []protocol.ConnectionID
replacedWithClosed map[string]packetHandler
replacedWithClosed []protocol.ConnectionID
queuedFrames []wire.Frame
g *connIDGenerator
)
Expand All @@ -32,18 +32,16 @@ var _ = Describe("Connection ID Generator", func() {
retiredConnIDs = nil
removedConnIDs = nil
queuedFrames = nil
replacedWithClosed = make(map[string]packetHandler)
replacedWithClosed = nil
g = newConnIDGenerator(
initialConnID,
initialClientDestConnID,
func(c protocol.ConnectionID) { addedConnIDs = append(addedConnIDs, c) },
connIDToToken,
func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) },
func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) },
func(cs []protocol.ConnectionID, h packetHandler) {
for _, c := range cs {
replacedWithClosed[string(c)] = h
}
func(cs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) {
replacedWithClosed = append(replacedWithClosed, cs...)
},
func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
protocol.VersionDraft29,
Expand Down Expand Up @@ -178,14 +176,13 @@ var _ = Describe("Connection ID Generator", func() {
It("replaces with a closed connection for all connection IDs", func() {
Expect(g.SetMaxActiveConnIDs(5)).To(Succeed())
Expect(queuedFrames).To(HaveLen(4))
sess := NewMockPacketHandler(mockCtrl)
g.ReplaceWithClosed(sess)
g.ReplaceWithClosed(protocol.PerspectiveClient, []byte("foobar"))
Expect(replacedWithClosed).To(HaveLen(6)) // initial conn ID, initial client dest conn id, and newly issued ones
Expect(replacedWithClosed).To(HaveKeyWithValue(string(initialClientDestConnID), sess))
Expect(replacedWithClosed).To(HaveKeyWithValue(string(initialConnID), sess))
Expect(replacedWithClosed).To(ContainElement(initialClientDestConnID))
Expect(replacedWithClosed).To(ContainElement(initialConnID))
for _, f := range queuedFrames {
nf := f.(*wire.NewConnectionIDFrame)
Expect(replacedWithClosed).To(HaveKeyWithValue(string(nf.ConnectionID), sess))
Expect(replacedWithClosed).To(ContainElement(nf.ConnectionID))
}
})
})
7 changes: 3 additions & 4 deletions connection.go
Expand Up @@ -95,7 +95,7 @@ type connRunner interface {
GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken
Retire(protocol.ConnectionID)
Remove(protocol.ConnectionID)
ReplaceWithClosed([]protocol.ConnectionID, packetHandler)
ReplaceWithClosed([]protocol.ConnectionID, protocol.Perspective, []byte)
AddResetToken(protocol.StatelessResetToken, packetHandler)
RemoveResetToken(protocol.StatelessResetToken)
}
Expand Down Expand Up @@ -1521,7 +1521,7 @@ func (s *connection) handleCloseError(closeErr *closeError) {

// If this is a remote close we're done here
if closeErr.remote {
s.connIDGenerator.ReplaceWithClosed(newClosedRemoteConn(s.perspective))
s.connIDGenerator.ReplaceWithClosed(s.perspective, nil)
return
}
if closeErr.immediate {
Expand All @@ -1538,8 +1538,7 @@ func (s *connection) handleCloseError(closeErr *closeError) {
if err != nil {
s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err)
}
cs := newClosedLocalConn(s.conn, connClosePacket, s.perspective, s.logger)
s.connIDGenerator.ReplaceWithClosed(cs)
s.connIDGenerator.ReplaceWithClosed(s.perspective, connClosePacket)
}

func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) {
Expand Down
29 changes: 6 additions & 23 deletions connection_test.go
Expand Up @@ -37,12 +37,6 @@ func areConnsRunning() bool {
return strings.Contains(b.String(), "quic-go.(*connection).run")
}

func areClosedConnsRunning() bool {
var b bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&b, 1)
return strings.Contains(b.String(), "quic-go.(*closedLocalConn).run")
}

var _ = Describe("Connection", func() {
var (
conn *connection
Expand Down Expand Up @@ -72,14 +66,11 @@ var _ = Describe("Connection", func() {
}

expectReplaceWithClosed := func() {
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) {
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) {
Expect(connIDs).To(ContainElement(srcConnID))
if len(connIDs) > 1 {
Expect(connIDs).To(ContainElement(clientDestConnID))
}
Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{}))
s.shutdown()
Eventually(areClosedConnsRunning).Should(BeFalse())
})
}

Expand Down Expand Up @@ -333,9 +324,8 @@ var _ = Describe("Connection", func() {
ErrorMessage: "foobar",
}
streamManager.EXPECT().CloseWithError(expectedErr)
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) {
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) {
Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID))
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
})
cryptoSetup.EXPECT().Close()
gomock.InOrder(
Expand All @@ -362,9 +352,8 @@ var _ = Describe("Connection", func() {
ErrorMessage: "foobar",
}
streamManager.EXPECT().CloseWithError(testErr)
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) {
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) {
Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID))
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
})
cryptoSetup.EXPECT().Close()
gomock.InOrder(
Expand Down Expand Up @@ -564,7 +553,7 @@ var _ = Describe("Connection", func() {
runConn()
cryptoSetup.EXPECT().Close()
streamManager.EXPECT().CloseWithError(gomock.Any())
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes()
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
buf := &bytes.Buffer{}
hdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: srcConnID},
Expand Down Expand Up @@ -2432,10 +2421,7 @@ var _ = Describe("Client Connection", func() {
}

expectReplaceWithClosed := func() {
connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any()).Do(func(_ []protocol.ConnectionID, s packetHandler) {
s.shutdown()
Eventually(areClosedConnsRunning).Should(BeFalse())
})
connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any(), gomock.Any())
}

BeforeEach(func() {
Expand Down Expand Up @@ -2766,10 +2752,7 @@ var _ = Describe("Client Connection", func() {

expectClose := func(applicationClose bool) {
if !closed {
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ []protocol.ConnectionID, s packetHandler) {
Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{}))
s.shutdown()
})
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any())
if applicationClose {
packer.EXPECT().PackApplicationClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil).MaxTimes(1)
} else {
Expand Down
8 changes: 4 additions & 4 deletions mock_conn_runner_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 5d947bf

Please sign in to comment.