Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use a struct containing an array to represent Connection IDs #3529

Merged
merged 7 commits into from Aug 29, 2022
12 changes: 6 additions & 6 deletions client_test.go
Expand Up @@ -50,7 +50,7 @@ var _ = Describe("Client", func() {

BeforeEach(func() {
tlsConf = &tls.Config{NextProtos: []string{"proto1"}}
connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}
connID = protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0x13, 0x37})
originalClientConnConstructor = newClientConnection
tracer = mocklogging.NewMockConnectionTracer(mockCtrl)
tr := mocklogging.NewMockTracer(mockCtrl)
Expand Down Expand Up @@ -518,7 +518,7 @@ var _ = Describe("Client", func() {
manager.EXPECT().Add(connID, gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)

config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockedConnIDGenerator{ConnID: connID}}
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockConnIDGenerator{ConnID: connID}}
c := make(chan struct{})
var cconn sendConn
var version protocol.VersionNumber
Expand Down Expand Up @@ -596,7 +596,7 @@ var _ = Describe("Client", func() {
return conn
}

config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockedConnIDGenerator{ConnID: connID}}
config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockConnIDGenerator{ConnID: connID}}
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
_, err := DialAddr("localhost:7890", tlsConf, config)
Expect(err).ToNot(HaveOccurred())
Expand All @@ -605,14 +605,14 @@ var _ = Describe("Client", func() {
})
})

type mockedConnIDGenerator struct {
type mockConnIDGenerator struct {
ConnID protocol.ConnectionID
}

func (m *mockedConnIDGenerator) GenerateConnectionID() ([]byte, error) {
func (m *mockConnIDGenerator) GenerateConnectionID() (protocol.ConnectionID, error) {
return m.ConnID, nil
}

func (m *mockedConnIDGenerator) ConnectionIDLen() int {
func (m *mockConnIDGenerator) ConnectionIDLen() int {
return m.ConnID.Len()
}
12 changes: 6 additions & 6 deletions conn_id_generator.go
Expand Up @@ -14,7 +14,7 @@ type connIDGenerator struct {
highestSeq uint64

activeSrcConnIDs map[uint64]protocol.ConnectionID
initialClientDestConnID protocol.ConnectionID
initialClientDestConnID *protocol.ConnectionID // nil for the client

addConnectionID func(protocol.ConnectionID)
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
Expand All @@ -28,7 +28,7 @@ type connIDGenerator struct {

func newConnIDGenerator(
initialConnectionID protocol.ConnectionID,
initialClientDestConnID protocol.ConnectionID, // nil for the client
initialClientDestConnID *protocol.ConnectionID, // nil for the client
addConnectionID func(protocol.ConnectionID),
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
removeConnectionID func(protocol.ConnectionID),
Expand Down Expand Up @@ -84,7 +84,7 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect
if !ok {
return nil
}
if connID.Equal(sentWithDestConnID) {
if connID == sentWithDestConnID {
return &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID),
Expand Down Expand Up @@ -117,14 +117,14 @@ func (m *connIDGenerator) issueNewConnID() error {

func (m *connIDGenerator) SetHandshakeComplete() {
if m.initialClientDestConnID != nil {
m.retireConnectionID(m.initialClientDestConnID)
m.retireConnectionID(*m.initialClientDestConnID)
m.initialClientDestConnID = nil
}
}

func (m *connIDGenerator) RemoveAll() {
if m.initialClientDestConnID != nil {
m.removeConnectionID(m.initialClientDestConnID)
m.removeConnectionID(*m.initialClientDestConnID)
}
for _, connID := range m.activeSrcConnIDs {
m.removeConnectionID(connID)
Expand All @@ -134,7 +134,7 @@ func (m *connIDGenerator) RemoveAll() {
func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose []byte) {
connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
if m.initialClientDestConnID != nil {
connIDs = append(connIDs, m.initialClientDestConnID)
connIDs = append(connIDs, *m.initialClientDestConnID)
}
for _, connID := range m.activeSrcConnIDs {
connIDs = append(connIDs, connID)
Expand Down
9 changes: 5 additions & 4 deletions conn_id_generator_test.go
Expand Up @@ -20,11 +20,12 @@ var _ = Describe("Connection ID Generator", func() {
queuedFrames []wire.Frame
g *connIDGenerator
)
initialConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7}
initialClientDestConnID := protocol.ConnectionID{0xa, 0xb, 0xc, 0xd, 0xe}
initialConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7})
initialClientDestConnID := protocol.ParseConnectionID([]byte{0xa, 0xb, 0xc, 0xd, 0xe})

connIDToToken := func(c protocol.ConnectionID) protocol.StatelessResetToken {
return protocol.StatelessResetToken{c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0]}
b := c.Bytes()[0]
return protocol.StatelessResetToken{b, b, b, b, b, b, b, b, b, b, b, b, b, b, b, b}
}

BeforeEach(func() {
Expand All @@ -35,7 +36,7 @@ var _ = Describe("Connection ID Generator", func() {
replacedWithClosed = nil
g = newConnIDGenerator(
initialConnID,
initialClientDestConnID,
&initialClientDestConnID,
func(c protocol.ConnectionID) { addedConnIDs = append(addedConnIDs, c) },
connIDToToken,
func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) },
Expand Down
2 changes: 1 addition & 1 deletion conn_id_manager.go
Expand Up @@ -121,7 +121,7 @@ func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID
// insert a new element somewhere in the middle
for el := h.queue.Front(); el != nil; el = el.Next() {
if el.Value.SequenceNumber == seq {
if !el.Value.ConnectionID.Equal(connID) {
if el.Value.ConnectionID != connID {
return fmt.Errorf("received conflicting connection IDs for sequence number %d", seq)
}
if el.Value.StatelessResetToken != resetToken {
Expand Down