Skip to content

Commit

Permalink
replace all connection IDs at the same time when connection is closed
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Aug 21, 2022
1 parent 635dc90 commit c3ab9c4
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 25 deletions.
10 changes: 6 additions & 4 deletions conn_id_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type connIDGenerator struct {
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
removeConnectionID func(protocol.ConnectionID)
retireConnectionID func(protocol.ConnectionID)
replaceWithClosed func(protocol.ConnectionID, packetHandler)
replaceWithClosed func([]protocol.ConnectionID, packetHandler)
queueControlFrame func(wire.Frame)

version protocol.VersionNumber
Expand All @@ -33,7 +33,7 @@ func newConnIDGenerator(
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
removeConnectionID func(protocol.ConnectionID),
retireConnectionID func(protocol.ConnectionID),
replaceWithClosed func(protocol.ConnectionID, packetHandler),
replaceWithClosed func([]protocol.ConnectionID, packetHandler),
queueControlFrame func(wire.Frame),
version protocol.VersionNumber,
) *connIDGenerator {
Expand Down Expand Up @@ -131,10 +131,12 @@ func (m *connIDGenerator) RemoveAll() {
}

func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) {
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
if m.initialClientDestConnID != nil {
m.replaceWithClosed(m.initialClientDestConnID, handler)
connIDs = append(connIDs, m.initialClientDestConnID)
}
for _, connID := range m.activeSrcConnIDs {
m.replaceWithClosed(connID, handler)
connIDs = append(connIDs, connID)
}
m.replaceWithClosed(connIDs, handler)
}
6 changes: 5 additions & 1 deletion conn_id_generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ var _ = Describe("Connection ID Generator", func() {
connIDToToken,
func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) },
func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) },
func(c protocol.ConnectionID, h packetHandler) { replacedWithClosed[string(c)] = h },
func(cs []protocol.ConnectionID, h packetHandler) {
for _, c := range cs {
replacedWithClosed[string(c)] = h
}
},
func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
protocol.VersionDraft29,
)
Expand Down
2 changes: 1 addition & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ type connRunner interface {
GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken
Retire(protocol.ConnectionID)
Remove(protocol.ConnectionID)
ReplaceWithClosed(protocol.ConnectionID, packetHandler)
ReplaceWithClosed([]protocol.ConnectionID, packetHandler)
AddResetToken(protocol.StatelessResetToken, packetHandler)
RemoveResetToken(protocol.StatelessResetToken)
}
Expand Down
23 changes: 11 additions & 12 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,11 @@ var _ = Describe("Connection", func() {
}

expectReplaceWithClosed := func() {
connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).MaxTimes(1)
connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) {
Expect(connIDs).To(ContainElement(srcConnID))
if len(connIDs) > 1 {
Expect(connIDs).To(ContainElement(clientDestConnID))
}
Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{}))
s.shutdown()
Eventually(areClosedConnsRunning).Should(BeFalse())
Expand Down Expand Up @@ -330,10 +333,8 @@ var _ = Describe("Connection", func() {
ErrorMessage: "foobar",
}
streamManager.EXPECT().CloseWithError(expectedErr)
connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
})
connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) {
Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID))
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
})
cryptoSetup.EXPECT().Close()
Expand Down Expand Up @@ -361,10 +362,8 @@ var _ = Describe("Connection", func() {
ErrorMessage: "foobar",
}
streamManager.EXPECT().CloseWithError(testErr)
connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
})
connRunner.EXPECT().ReplaceWithClosed(clientDestConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(connIDs []protocol.ConnectionID, s packetHandler) {
Expect(connIDs).To(ConsistOf(clientDestConnID, srcConnID))
Expect(s).To(BeAssignableToTypeOf(&closedRemoteConn{}))
})
cryptoSetup.EXPECT().Close()
Expand Down Expand Up @@ -2433,7 +2432,7 @@ var _ = Describe("Client Connection", func() {
}

expectReplaceWithClosed := func() {
connRunner.EXPECT().ReplaceWithClosed(srcConnID, gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
connRunner.EXPECT().ReplaceWithClosed([]protocol.ConnectionID{srcConnID}, gomock.Any()).Do(func(_ []protocol.ConnectionID, s packetHandler) {
s.shutdown()
Eventually(areClosedConnsRunning).Should(BeFalse())
})
Expand Down Expand Up @@ -2767,7 +2766,7 @@ var _ = Describe("Client Connection", func() {

expectClose := func(applicationClose bool) {
if !closed {
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ protocol.ConnectionID, s packetHandler) {
connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).Do(func(_ []protocol.ConnectionID, s packetHandler) {
Expect(s).To(BeAssignableToTypeOf(&closedLocalConn{}))
s.shutdown()
})
Expand Down
2 changes: 1 addition & 1 deletion mock_conn_runner_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion mock_packet_handler_manager_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 9 additions & 5 deletions packet_handler_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,18 +219,22 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) {
})
}

func (h *packetHandlerMap) ReplaceWithClosed(id protocol.ConnectionID, handler packetHandler) {
func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, handler packetHandler) {
h.mutex.Lock()
h.handlers[string(id)] = handler
for _, id := range ids {
h.handlers[string(id)] = handler
}
h.mutex.Unlock()
h.logger.Debugf("Replacing connection for connection ID %s with a closed connection.", id)
h.logger.Debugf("Replacing connection for connection IDs %s with a closed connection.", ids)

time.AfterFunc(h.deleteRetiredConnsAfter, func() {
h.mutex.Lock()
handler.shutdown()
delete(h.handlers, string(id))
for _, id := range ids {
delete(h.handlers, string(id))
}
h.mutex.Unlock()
h.logger.Debugf("Removing connection ID %s for a closed connection after it has been retired.", id)
h.logger.Debugf("Removing connection IDs %s for a closed connection after it has been retired.", ids)
})
}

Expand Down

0 comments on commit c3ab9c4

Please sign in to comment.