diff --git a/client_test.go b/client_test.go index 84b55d0fdcf..9c373c6e8c8 100644 --- a/client_test.go +++ b/client_test.go @@ -50,7 +50,7 @@ var _ = Describe("Client", func() { BeforeEach(func() { tlsConf = &tls.Config{NextProtos: []string{"proto1"}} - connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37} + connID = protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0x13, 0x37}) originalClientConnConstructor = newClientConnection tracer = mocklogging.NewMockConnectionTracer(mockCtrl) tr := mocklogging.NewMockTracer(mockCtrl) @@ -518,7 +518,7 @@ var _ = Describe("Client", func() { manager.EXPECT().Add(connID, gomock.Any()) mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) - config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockedConnIDGenerator{ConnID: connID}} + config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockConnIDGenerator{ConnID: connID}} c := make(chan struct{}) var cconn sendConn var version protocol.VersionNumber @@ -596,7 +596,7 @@ var _ = Describe("Client", func() { return conn } - config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockedConnIDGenerator{ConnID: connID}} + config := &Config{Tracer: config.Tracer, Versions: []protocol.VersionNumber{protocol.VersionTLS}, ConnectionIDGenerator: &mockConnIDGenerator{ConnID: connID}} tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) _, err := DialAddr("localhost:7890", tlsConf, config) Expect(err).ToNot(HaveOccurred()) @@ -605,14 +605,14 @@ var _ = Describe("Client", func() { }) }) -type mockedConnIDGenerator struct { +type mockConnIDGenerator struct { ConnID protocol.ConnectionID } -func (m *mockedConnIDGenerator) GenerateConnectionID() ([]byte, error) { +func (m *mockConnIDGenerator) GenerateConnectionID() (protocol.ConnectionID, error) { return m.ConnID, nil } -func (m *mockedConnIDGenerator) ConnectionIDLen() int { +func (m *mockConnIDGenerator) ConnectionIDLen() int { return m.ConnID.Len() } diff --git a/conn_id_generator.go b/conn_id_generator.go index 0421d678b2f..c56e8a4c16f 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -14,7 +14,7 @@ type connIDGenerator struct { highestSeq uint64 activeSrcConnIDs map[uint64]protocol.ConnectionID - initialClientDestConnID protocol.ConnectionID + initialClientDestConnID *protocol.ConnectionID // nil for the client addConnectionID func(protocol.ConnectionID) getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken @@ -28,7 +28,7 @@ type connIDGenerator struct { func newConnIDGenerator( initialConnectionID protocol.ConnectionID, - initialClientDestConnID protocol.ConnectionID, // nil for the client + initialClientDestConnID *protocol.ConnectionID, // nil for the client addConnectionID func(protocol.ConnectionID), getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken, removeConnectionID func(protocol.ConnectionID), @@ -84,7 +84,7 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect if !ok { return nil } - if connID.Equal(sentWithDestConnID) { + if connID == sentWithDestConnID { return &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID), @@ -117,14 +117,14 @@ func (m *connIDGenerator) issueNewConnID() error { func (m *connIDGenerator) SetHandshakeComplete() { if m.initialClientDestConnID != nil { - m.retireConnectionID(m.initialClientDestConnID) + m.retireConnectionID(*m.initialClientDestConnID) m.initialClientDestConnID = nil } } func (m *connIDGenerator) RemoveAll() { if m.initialClientDestConnID != nil { - m.removeConnectionID(m.initialClientDestConnID) + m.removeConnectionID(*m.initialClientDestConnID) } for _, connID := range m.activeSrcConnIDs { m.removeConnectionID(connID) @@ -134,7 +134,7 @@ func (m *connIDGenerator) RemoveAll() { func (m *connIDGenerator) ReplaceWithClosed(pers protocol.Perspective, connClose []byte) { connIDs := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1) if m.initialClientDestConnID != nil { - connIDs = append(connIDs, m.initialClientDestConnID) + connIDs = append(connIDs, *m.initialClientDestConnID) } for _, connID := range m.activeSrcConnIDs { connIDs = append(connIDs, connID) diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index 167a70d6a3a..dc3a1223ea8 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -20,11 +20,12 @@ var _ = Describe("Connection ID Generator", func() { queuedFrames []wire.Frame g *connIDGenerator ) - initialConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7} - initialClientDestConnID := protocol.ConnectionID{0xa, 0xb, 0xc, 0xd, 0xe} + initialConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7}) + initialClientDestConnID := protocol.ParseConnectionID([]byte{0xa, 0xb, 0xc, 0xd, 0xe}) connIDToToken := func(c protocol.ConnectionID) protocol.StatelessResetToken { - return protocol.StatelessResetToken{c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0]} + b := c.Bytes()[0] + return protocol.StatelessResetToken{b, b, b, b, b, b, b, b, b, b, b, b, b, b, b, b} } BeforeEach(func() { @@ -35,7 +36,7 @@ var _ = Describe("Connection ID Generator", func() { replacedWithClosed = nil g = newConnIDGenerator( initialConnID, - initialClientDestConnID, + &initialClientDestConnID, func(c protocol.ConnectionID) { addedConnIDs = append(addedConnIDs, c) }, connIDToToken, func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) }, diff --git a/conn_id_manager.go b/conn_id_manager.go index c1bb42bef2e..b878b027c03 100644 --- a/conn_id_manager.go +++ b/conn_id_manager.go @@ -121,7 +121,7 @@ func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID // insert a new element somewhere in the middle for el := h.queue.Front(); el != nil; el = el.Next() { if el.Value.SequenceNumber == seq { - if !el.Value.ConnectionID.Equal(connID) { + if el.Value.ConnectionID != connID { return fmt.Errorf("received conflicting connection IDs for sequence number %d", seq) } if el.Value.StatelessResetToken != resetToken { diff --git a/conn_id_manager_test.go b/conn_id_manager_test.go index 6c8490591e9..f0e24e86f6c 100644 --- a/conn_id_manager_test.go +++ b/conn_id_manager_test.go @@ -16,7 +16,7 @@ var _ = Describe("Connection ID Manager", func() { tokenAdded *protocol.StatelessResetToken removedTokens []protocol.StatelessResetToken ) - initialConnID := protocol.ConnectionID{0, 0, 0, 0} + initialConnID := protocol.ParseConnectionID([]byte{0, 0, 0, 0}) BeforeEach(func() { frameQueue = nil @@ -34,7 +34,7 @@ var _ = Describe("Connection ID Manager", func() { get := func() (protocol.ConnectionID, protocol.StatelessResetToken) { if m.queue.Len() == 0 { - return nil, protocol.StatelessResetToken{} + return protocol.ConnectionID{}, protocol.StatelessResetToken{} } val := m.queue.Remove(m.queue.Front()) return val.ConnectionID, val.StatelessResetToken @@ -45,8 +45,8 @@ var _ = Describe("Connection ID Manager", func() { }) It("changes the initial connection ID", func() { - m.ChangeInitialConnID(protocol.ConnectionID{1, 2, 3, 4, 5}) - Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5})) + m.ChangeInitialConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5})) + Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}))) }) It("sets the token for the first connection ID", func() { @@ -59,81 +59,81 @@ var _ = Describe("Connection ID Manager", func() { It("adds and gets connection IDs", func() { Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 10, - ConnectionID: protocol.ConnectionID{2, 3, 4, 5}, + ConnectionID: protocol.ParseConnectionID([]byte{2, 3, 4, 5}), StatelessResetToken: protocol.StatelessResetToken{0xe, 0xd, 0xc, 0xb, 0xa, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, })).To(Succeed()) Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 4, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, })).To(Succeed()) c1, rt1 := get() - Expect(c1).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + Expect(c1).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4}))) Expect(rt1).To(Equal(protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe})) c2, rt2 := get() - Expect(c2).To(Equal(protocol.ConnectionID{2, 3, 4, 5})) + Expect(c2).To(Equal(protocol.ParseConnectionID([]byte{2, 3, 4, 5}))) Expect(rt2).To(Equal(protocol.StatelessResetToken{0xe, 0xd, 0xc, 0xb, 0xa, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0})) c3, _ := get() - Expect(c3).To(BeNil()) + Expect(c3).To(BeZero()) }) It("accepts duplicates", func() { f1 := &wire.NewConnectionIDFrame{ SequenceNumber: 1, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, } f2 := &wire.NewConnectionIDFrame{ SequenceNumber: 1, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, } Expect(m.Add(f1)).To(Succeed()) Expect(m.Add(f2)).To(Succeed()) c1, rt1 := get() - Expect(c1).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + Expect(c1).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4}))) Expect(rt1).To(Equal(protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe})) c2, _ := get() - Expect(c2).To(BeNil()) + Expect(c2).To(BeZero()) }) It("ignores duplicates for the currently used connection ID", func() { f := &wire.NewConnectionIDFrame{ SequenceNumber: 1, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, } m.SetHandshakeComplete() Expect(m.Add(f)).To(Succeed()) - Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4}))) c, _ := get() - Expect(c).To(BeNil()) + Expect(c).To(BeZero()) // Now send the same connection ID again. It should not be queued. Expect(m.Add(f)).To(Succeed()) c, _ = get() - Expect(c).To(BeNil()) + Expect(c).To(BeZero()) }) It("rejects duplicates with different connection IDs", func() { Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 42, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), })).To(Succeed()) Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 42, - ConnectionID: protocol.ConnectionID{2, 3, 4, 5}, + ConnectionID: protocol.ParseConnectionID([]byte{2, 3, 4, 5}), })).To(MatchError("received conflicting connection IDs for sequence number 42")) }) It("rejects duplicates with different connection IDs", func() { Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 42, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe}, })).To(Succeed()) Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 42, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{0xe, 0xd, 0xc, 0xb, 0xa, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, })).To(MatchError("received conflicting stateless reset tokens for sequence number 42")) }) @@ -141,29 +141,29 @@ var _ = Describe("Connection ID Manager", func() { It("retires connection IDs", func() { Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 10, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), })).To(Succeed()) Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 13, - ConnectionID: protocol.ConnectionID{2, 3, 4, 5}, + ConnectionID: protocol.ParseConnectionID([]byte{2, 3, 4, 5}), })).To(Succeed()) Expect(frameQueue).To(BeEmpty()) Expect(m.Add(&wire.NewConnectionIDFrame{ RetirePriorTo: 14, SequenceNumber: 17, - ConnectionID: protocol.ConnectionID{3, 4, 5, 6}, + ConnectionID: protocol.ParseConnectionID([]byte{3, 4, 5, 6}), })).To(Succeed()) Expect(frameQueue).To(HaveLen(3)) Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(10)) Expect(frameQueue[1].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(13)) Expect(frameQueue[2].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeZero()) - Expect(m.Get()).To(Equal(protocol.ConnectionID{3, 4, 5, 6})) + Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{3, 4, 5, 6}))) }) It("ignores reordered connection IDs, if their sequence number was already retired", func() { Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 10, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), RetirePriorTo: 5, })).To(Succeed()) Expect(frameQueue).To(HaveLen(1)) @@ -173,7 +173,7 @@ var _ = Describe("Connection ID Manager", func() { // Make sure it gets retired immediately now. Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 4, - ConnectionID: protocol.ConnectionID{4, 3, 2, 1}, + ConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}), })).To(Succeed()) Expect(frameQueue).To(HaveLen(1)) Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(4)) @@ -182,17 +182,17 @@ var _ = Describe("Connection ID Manager", func() { It("ignores reordered connection IDs, if their sequence number was already retired or less than active", func() { Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 10, - ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), RetirePriorTo: 5, })).To(Succeed()) Expect(frameQueue).To(HaveLen(1)) Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeZero()) frameQueue = nil - Expect(m.Get()).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) + Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}))) Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 9, - ConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), RetirePriorTo: 5, })).To(Succeed()) Expect(frameQueue).To(HaveLen(1)) @@ -200,7 +200,7 @@ var _ = Describe("Connection ID Manager", func() { }) It("accepts retransmissions for the connection ID that is in use", func() { - connID := protocol.ConnectionID{1, 2, 3, 4} + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 1, @@ -225,13 +225,13 @@ var _ = Describe("Connection ID Manager", func() { for i := uint8(1); i < protocol.MaxActiveConnectionIDs; i++ { Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: uint64(i), - ConnectionID: protocol.ConnectionID{i, i, i, i}, + ConnectionID: protocol.ParseConnectionID([]byte{i, i, i, i}), StatelessResetToken: protocol.StatelessResetToken{i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i}, })).To(Succeed()) } Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: uint64(9999), - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, })).To(MatchError(&qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError})) }) @@ -241,22 +241,22 @@ var _ = Describe("Connection ID Manager", func() { m.SetHandshakeComplete() Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 1, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, })).To(Succeed()) - Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4}))) }) It("waits until handshake completion before initiating a connection ID update", func() { Expect(m.Get()).To(Equal(initialConnID)) Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 1, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, })).To(Succeed()) Expect(m.Get()).To(Equal(initialConnID)) m.SetHandshakeComplete() - Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4}))) }) It("initiates subsequent updates when enough packets are sent", func() { @@ -264,28 +264,28 @@ var _ = Describe("Connection ID Manager", func() { for s = uint8(1); s < protocol.MaxActiveConnectionIDs; s++ { Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: uint64(s), - ConnectionID: protocol.ConnectionID{s, s, s, s}, + ConnectionID: protocol.ParseConnectionID([]byte{s, s, s, s}), StatelessResetToken: protocol.StatelessResetToken{s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s}, })).To(Succeed()) } m.SetHandshakeComplete() lastConnID := m.Get() - Expect(lastConnID).To(Equal(protocol.ConnectionID{1, 1, 1, 1})) + Expect(lastConnID).To(Equal(protocol.ParseConnectionID([]byte{1, 1, 1, 1}))) var counter int for i := 0; i < 50*protocol.PacketsPerConnectionID; i++ { m.SentPacket() connID := m.Get() - if !connID.Equal(lastConnID) { + if connID != lastConnID { counter++ lastConnID = connID Expect(removedTokens).To(HaveLen(1)) removedTokens = nil Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: uint64(s), - ConnectionID: protocol.ConnectionID{s, s, s, s}, + ConnectionID: protocol.ParseConnectionID([]byte{s, s, s, s}), StatelessResetToken: protocol.StatelessResetToken{s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s}, })).To(Succeed()) s++ @@ -298,28 +298,28 @@ var _ = Describe("Connection ID Manager", func() { for s := uint8(10); s <= 10+protocol.MaxActiveConnectionIDs/2; s++ { Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: uint64(s), - ConnectionID: protocol.ConnectionID{s, s, s, s}, + ConnectionID: protocol.ParseConnectionID([]byte{s, s, s, s}), StatelessResetToken: protocol.StatelessResetToken{s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s}, })).To(Succeed()) } m.SetHandshakeComplete() - Expect(m.Get()).To(Equal(protocol.ConnectionID{10, 10, 10, 10})) + Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{10, 10, 10, 10}))) for { m.SentPacket() - if m.Get().Equal(protocol.ConnectionID{11, 11, 11, 11}) { + if m.Get() == protocol.ParseConnectionID([]byte{11, 11, 11, 11}) { break } } // The active conn ID is now {11, 11, 11, 11} - Expect(m.queue.Front().Value.ConnectionID).To(Equal(protocol.ConnectionID{12, 12, 12, 12})) + Expect(m.queue.Front().Value.ConnectionID).To(Equal(protocol.ParseConnectionID([]byte{12, 12, 12, 12}))) // Add a delayed connection ID. It should just be ignored now. frameQueue = nil Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: uint64(5), - ConnectionID: protocol.ConnectionID{5, 5, 5, 5}, + ConnectionID: protocol.ParseConnectionID([]byte{5, 5, 5, 5}), StatelessResetToken: protocol.StatelessResetToken{5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5}, })).To(Succeed()) - Expect(m.queue.Front().Value.ConnectionID).To(Equal(protocol.ConnectionID{12, 12, 12, 12})) + Expect(m.queue.Front().Value.ConnectionID).To(Equal(protocol.ParseConnectionID([]byte{12, 12, 12, 12}))) Expect(frameQueue).To(HaveLen(1)) Expect(frameQueue[0].(*wire.RetireConnectionIDFrame).SequenceNumber).To(BeEquivalentTo(5)) }) @@ -328,21 +328,21 @@ var _ = Describe("Connection ID Manager", func() { for i := uint8(1); i <= protocol.MaxActiveConnectionIDs/2; i++ { Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: uint64(i), - ConnectionID: protocol.ConnectionID{i, i, i, i}, + ConnectionID: protocol.ParseConnectionID([]byte{i, i, i, i}), StatelessResetToken: protocol.StatelessResetToken{i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i}, })).To(Succeed()) } m.SetHandshakeComplete() - Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 1, 1, 1})) + Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 1, 1, 1}))) for i := 0; i < 2*protocol.PacketsPerConnectionID; i++ { m.SentPacket() } - Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 1, 1, 1})) + Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 1, 1, 1}))) Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 1337, - ConnectionID: protocol.ConnectionID{1, 3, 3, 7}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 3, 3, 7}), })).To(Succeed()) - Expect(m.Get()).To(Equal(protocol.ConnectionID{2, 2, 2, 2})) + Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{2, 2, 2, 2}))) Expect(removedTokens).To(HaveLen(1)) Expect(removedTokens[0]).To(Equal(protocol.StatelessResetToken{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1})) }) @@ -352,11 +352,11 @@ var _ = Describe("Connection ID Manager", func() { Expect(removedTokens).To(BeEmpty()) Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 1, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, })).To(Succeed()) m.SetHandshakeComplete() - Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4}))) m.Close() Expect(removedTokens).To(HaveLen(1)) Expect(removedTokens[0]).To(Equal(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1})) diff --git a/connection.go b/connection.go index 015d98adcbe..39dd0621f86 100644 --- a/connection.go +++ b/connection.go @@ -261,7 +261,7 @@ var newConnection = func( logger: logger, version: v, } - if origDestConnID != nil { + if origDestConnID.Len() > 0 { s.logID = origDestConnID.String() } else { s.logID = destConnID.String() @@ -274,7 +274,7 @@ var newConnection = func( ) s.connIDGenerator = newConnIDGenerator( srcConnID, - clientDestConnID, + &clientDestConnID, func(connID protocol.ConnectionID) { runner.Add(connID, s) }, runner.GetStatelessResetToken, runner.Remove, @@ -881,7 +881,7 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool { break } - if counter > 0 && !hdr.DestConnectionID.Equal(lastConnID) { + if counter > 0 && hdr.DestConnectionID != lastConnID { if s.tracer != nil { s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID) } @@ -925,7 +925,7 @@ func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) boo // The server can change the source connection ID with the first Handshake packet. // After this, all packets with a different source connection have to be ignored. - if s.receivedFirstPacket && hdr.IsLongHeader && hdr.Type == protocol.PacketTypeInitial && !hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { + if s.receivedFirstPacket && hdr.IsLongHeader && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != s.handshakeDestConnID { if s.tracer != nil { s.tracer.DroppedPacket(logging.PacketTypeInitial, p.Size(), logging.PacketDropUnknownConnectionID) } @@ -1017,7 +1017,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* wa return false } destConnID := s.connIDManager.Get() - if hdr.SrcConnectionID.Equal(destConnID) { + if hdr.SrcConnectionID == destConnID { if s.tracer != nil { s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) } @@ -1072,7 +1072,7 @@ func (s *connection) handleVersionNegotiationPacket(p *receivedPacket) { return } - hdr, supportedVersions, err := wire.ParseVersionNegotiationPacket(bytes.NewReader(p.data)) + src, dest, supportedVersions, err := wire.ParseVersionNegotiationPacket(p.data) if err != nil { if s.tracer != nil { s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, p.Size(), logging.PacketDropHeaderParseError) @@ -1094,7 +1094,7 @@ func (s *connection) handleVersionNegotiationPacket(p *receivedPacket) { s.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", supportedVersions) if s.tracer != nil { - s.tracer.ReceivedVersionNegotiationPacket(hdr, supportedVersions) + s.tracer.ReceivedVersionNegotiationPacket(dest, src, supportedVersions) } newVersion, ok := protocol.ChooseSupportedVersion(s.config.Versions, supportedVersions) if !ok { @@ -1143,7 +1143,7 @@ func (s *connection) handleUnpackedPacket( s.tracer.NegotiatedVersion(s.version, clientVersions, serverVersions) } // The server can change the source connection ID with the first Handshake packet. - if s.perspective == protocol.PerspectiveClient && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { + if s.perspective == protocol.PerspectiveClient && packet.hdr.IsLongHeader && packet.hdr.SrcConnectionID != s.handshakeDestConnID { cid := packet.hdr.SrcConnectionID s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", cid) s.handshakeDestConnID = cid @@ -1155,7 +1155,7 @@ func (s *connection) handleUnpackedPacket( // we might have create a connection with an incorrect source connection ID. // Once we authenticate the first packet, we need to update it. if s.perspective == protocol.PerspectiveServer { - if !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { + if packet.hdr.SrcConnectionID != s.handshakeDestConnID { s.handshakeDestConnID = packet.hdr.SrcConnectionID s.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID) } @@ -1601,7 +1601,7 @@ func (s *connection) checkTransportParameters(params *wire.TransportParameters) } // check the initial_source_connection_id - if !params.InitialSourceConnectionID.Equal(s.handshakeDestConnID) { + if params.InitialSourceConnectionID != s.handshakeDestConnID { return fmt.Errorf("expected initial_source_connection_id to equal %s, is %s", s.handshakeDestConnID, params.InitialSourceConnectionID) } @@ -1609,14 +1609,14 @@ func (s *connection) checkTransportParameters(params *wire.TransportParameters) return nil } // check the original_destination_connection_id - if !params.OriginalDestinationConnectionID.Equal(s.origDestConnID) { + if params.OriginalDestinationConnectionID != s.origDestConnID { return fmt.Errorf("expected original_destination_connection_id to equal %s, is %s", s.origDestConnID, params.OriginalDestinationConnectionID) } if s.retrySrcConnID != nil { // a Retry was performed if params.RetrySourceConnectionID == nil { return errors.New("missing retry_source_connection_id") } - if !(*params.RetrySourceConnectionID).Equal(*s.retrySrcConnID) { + if *params.RetrySourceConnectionID != *s.retrySrcConnID { return fmt.Errorf("expected retry_source_connection_id to equal %s, is %s", s.retrySrcConnID, *params.RetrySourceConnectionID) } } else if params.RetrySourceConnectionID != nil { diff --git a/connection_test.go b/connection_test.go index 0d640dc90ed..3956616db3c 100644 --- a/connection_test.go +++ b/connection_test.go @@ -49,9 +49,9 @@ var _ = Describe("Connection", func() { ) remoteAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7331} - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} - clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + destConnID := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) + clientDestConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) getPacket := func(pn protocol.PacketNumber) *packedPacket { buffer := getPacketBuffer() @@ -91,7 +91,7 @@ var _ = Describe("Connection", func() { conn = newConnection( mconn, connRunner, - nil, + protocol.ConnectionID{}, nil, clientDestConnID, destConnID, @@ -270,11 +270,12 @@ var _ = Describe("Connection", func() { }) It("handles NEW_CONNECTION_ID frames", func() { + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) Expect(conn.handleFrame(&wire.NewConnectionIDFrame{ SequenceNumber: 10, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: connID, }, protocol.Encryption1RTT, protocol.ConnectionID{})).To(Succeed()) - Expect(conn.connIDManager.queue.Back().Value.ConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + Expect(conn.connIDManager.queue.Back().Value.ConnectionID).To(Equal(connID)) }) It("handles PING frames", func() { @@ -663,7 +664,11 @@ var _ = Describe("Connection", func() { }) It("drops Version Negotiation packets", func() { - b := wire.ComposeVersionNegotiation(srcConnID, destConnID, conn.config.Versions) + b := wire.ComposeVersionNegotiation( + protocol.ArbitraryLenConnectionID(srcConnID.Bytes()), + protocol.ArbitraryLenConnectionID(destConnID.Bytes()), + conn.config.Versions, + ) tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(b)), logging.PacketDropUnexpectedPacket) Expect(conn.handlePacketImpl(&receivedPacket{ data: b, @@ -1047,7 +1052,7 @@ var _ = Describe("Connection", func() { IsLongHeader: true, Type: protocol.PacketTypeInitial, DestConnectionID: destConnID, - SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), Length: 1, Version: conn.version, }, @@ -1204,7 +1209,7 @@ var _ = Describe("Connection", func() { }) It("ignores coalesced packet parts if the destination connection IDs don't match", func() { - wrongConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + wrongConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) Expect(srcConnID).ToNot(Equal(wrongConnID)) hdrLen1, packet1 := getPacketWithLength(srcConnID, 456) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, _ time.Time, data []byte) (*unpackedPacket, error) { @@ -2408,8 +2413,8 @@ var _ = Describe("Client Connection", func() { tlsConf *tls.Config quicConf *Config ) - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + destConnID := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) getPacket := func(hdr *wire.ExtendedHeader, data []byte) *receivedPacket { buf := &bytes.Buffer{} @@ -2448,7 +2453,7 @@ var _ = Describe("Client Connection", func() { mconn, connRunner, destConnID, - protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), quicConf, tlsConf, 42, // initial packet number @@ -2481,7 +2486,7 @@ var _ = Describe("Client Connection", func() { cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) conn.run() }() - newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7} + newConnID := protocol.ParseConnectionID([]byte{1, 3, 3, 7, 1, 3, 3, 7}) p := getPacket(&wire.ExtendedHeader{ Header: wire.Header{ IsLongHeader: true, @@ -2513,9 +2518,9 @@ var _ = Describe("Client Connection", func() { conn.connIDManager.SetHandshakeComplete() conn.handleNewConnectionIDFrame(&wire.NewConnectionIDFrame{ SequenceNumber: 1, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}), }) - Expect(conn.connIDManager.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5})) + Expect(conn.connIDManager.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}))) // now receive a packet with the original source connection ID unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, _ []byte) (*unpackedPacket, error) { return &unpackedPacket{ @@ -2593,7 +2598,11 @@ var _ = Describe("Client Connection", func() { Context("handling Version Negotiation", func() { getVNP := func(versions ...protocol.VersionNumber) *receivedPacket { - b := wire.ComposeVersionNegotiation(srcConnID, destConnID, versions) + b := wire.ComposeVersionNegotiation( + protocol.ArbitraryLenConnectionID(srcConnID.Bytes()), + protocol.ArbitraryLenConnectionID(destConnID.Bytes()), + versions, + ) return &receivedPacket{ data: b, buffer: getPacketBuffer(), @@ -2613,8 +2622,7 @@ var _ = Describe("Client Connection", func() { errChan <- conn.run() }() connRunner.EXPECT().Remove(srcConnID) - tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any()).Do(func(hdr *wire.Header, versions []logging.VersionNumber) { - Expect(hdr.Version).To(BeZero()) + tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ArbitraryLenConnectionID, versions []logging.VersionNumber) { Expect(versions).To(And( ContainElement(protocol.VersionNumber(4321)), ContainElement(protocol.VersionNumber(1337)), @@ -2640,7 +2648,7 @@ var _ = Describe("Client Connection", func() { }() connRunner.EXPECT().Remove(srcConnID).MaxTimes(1) gomock.InOrder( - tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any()), + tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()), tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { var vnErr *VersionNegotiationError Expect(errors.As(e, &vnErr)).To(BeTrue()) @@ -2672,7 +2680,7 @@ var _ = Describe("Client Connection", func() { }) Context("handling Retry", func() { - origDestConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + origDestConnID := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) var retryHdr *wire.ExtendedHeader @@ -2681,8 +2689,8 @@ var _ = Describe("Client Connection", func() { Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeRetry, - SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Token: []byte("foobar"), Version: conn.version, }, @@ -2700,7 +2708,7 @@ var _ = Describe("Client Connection", func() { conn.sentPacketHandler = sph sph.EXPECT().ResetForRetry() sph.EXPECT().ReceivedBytes(gomock.Any()) - cryptoSetup.EXPECT().ChangeConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}) + cryptoSetup.EXPECT().ChangeConnectionID(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})) packer.EXPECT().SetToken([]byte("foobar")) tracer.EXPECT().ReceivedRetry(gomock.Any()).Do(func(hdr *wire.Header) { Expect(hdr.DestConnectionID).To(Equal(retryHdr.DestConnectionID)) @@ -2781,7 +2789,7 @@ var _ = Describe("Client Connection", func() { PreferredAddress: &wire.PreferredAddress{ IPv4: net.IPv4(127, 0, 0, 1), IPv6: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, }, } @@ -2794,7 +2802,7 @@ var _ = Describe("Client Connection", func() { cf, _ := conn.framer.AppendControlFrames(nil, protocol.MaxByteCount) Expect(cf).To(BeEmpty()) connRunner.EXPECT().AddResetToken(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, conn) - Expect(conn.connIDManager.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4})) + Expect(conn.connIDManager.Get()).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4}))) // shut down connRunner.EXPECT().RemoveResetToken(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}) expectClose(true) @@ -2816,10 +2824,10 @@ var _ = Describe("Client Connection", func() { }) It("errors if the transport parameters contain a wrong initial_source_connection_id", func() { - conn.handshakeDestConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + conn.handshakeDestConnID = protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) params := &wire.TransportParameters{ OriginalDestinationConnectionID: destConnID, - InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + InitialSourceConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, } expectClose(false) @@ -2832,7 +2840,8 @@ var _ = Describe("Client Connection", func() { }) It("errors if the transport parameters don't contain the retry_source_connection_id, if a Retry was performed", func() { - conn.retrySrcConnID = &protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + rcid := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) + conn.retrySrcConnID = &rcid params := &wire.TransportParameters{ OriginalDestinationConnectionID: destConnID, InitialSourceConnectionID: destConnID, @@ -2848,11 +2857,13 @@ var _ = Describe("Client Connection", func() { }) It("errors if the transport parameters contain the wrong retry_source_connection_id, if a Retry was performed", func() { - conn.retrySrcConnID = &protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + rcid := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) + rcid2 := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}) + conn.retrySrcConnID = &rcid params := &wire.TransportParameters{ OriginalDestinationConnectionID: destConnID, InitialSourceConnectionID: destConnID, - RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + RetrySourceConnectionID: &rcid2, StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, } expectClose(false) @@ -2865,10 +2876,11 @@ var _ = Describe("Client Connection", func() { }) It("errors if the transport parameters contain the retry_source_connection_id, if no Retry was performed", func() { + rcid := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}) params := &wire.TransportParameters{ OriginalDestinationConnectionID: destConnID, InitialSourceConnectionID: destConnID, - RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + RetrySourceConnectionID: &rcid, StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, } expectClose(false) @@ -2881,9 +2893,9 @@ var _ = Describe("Client Connection", func() { }) It("errors if the transport parameters contain a wrong original_destination_connection_id", func() { - conn.origDestConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + conn.origDestConnID = protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) params := &wire.TransportParameters{ - OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), InitialSourceConnectionID: conn.handshakeDestConnID, StatelessResetToken: &protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, } @@ -2941,7 +2953,7 @@ var _ = Describe("Client Connection", func() { IsLongHeader: true, Type: protocol.PacketTypeInitial, DestConnectionID: destConnID, - SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), Length: 1, Version: conn.version, }, @@ -3007,7 +3019,7 @@ var _ = Describe("Client Connection", func() { conn.sentPacketHandler = sph sph.EXPECT().ReceivedBytes(gomock.Any()).Times(2) sph.EXPECT().ResetForRetry() - newSrcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + newSrcConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) cryptoSetup.EXPECT().ChangeConnectionID(newSrcConnID) packer.EXPECT().SetToken([]byte("foobar")) diff --git a/framer_test.go b/framer_test.go index 71f7b2aa73b..13fe65fc891 100644 --- a/framer_test.go +++ b/framer_test.go @@ -89,7 +89,10 @@ var _ = Describe("Framer", func() { It("drops *_BLOCKED frames when 0-RTT is rejected", func() { ping := &wire.PingFrame{} - ncid := &wire.NewConnectionIDFrame{SequenceNumber: 10, ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}} + ncid := &wire.NewConnectionIDFrame{ + SequenceNumber: 10, + ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), + } frames := []wire.Frame{ &wire.DataBlockedFrame{MaximumData: 1337}, &wire.StreamDataBlockedFrame{StreamID: 42, MaximumStreamData: 1337}, diff --git a/fuzzing/frames/cmd/corpus.go b/fuzzing/frames/cmd/corpus.go index eea064260cf..14a8360d211 100644 --- a/fuzzing/frames/cmd/corpus.go +++ b/fuzzing/frames/cmd/corpus.go @@ -224,13 +224,13 @@ func getFrames() []wire.Frame { &wire.NewConnectionIDFrame{ SequenceNumber: seq1, RetirePriorTo: seq1 / 2, - ConnectionID: getRandomData(4), + ConnectionID: protocol.ParseConnectionID(getRandomData(4)), StatelessResetToken: token1, }, &wire.NewConnectionIDFrame{ SequenceNumber: seq2, RetirePriorTo: seq2, - ConnectionID: getRandomData(17), + ConnectionID: protocol.ParseConnectionID(getRandomData(17)), StatelessResetToken: token2, }, }...) diff --git a/fuzzing/header/cmd/corpus.go b/fuzzing/header/cmd/corpus.go index eeb880ce860..0c02b699199 100644 --- a/fuzzing/header/cmd/corpus.go +++ b/fuzzing/header/cmd/corpus.go @@ -19,7 +19,7 @@ func getRandomData(l int) []byte { return b } -func getVNP(src, dest protocol.ConnectionID, numVersions int) []byte { +func getVNP(src, dest protocol.ArbitraryLenConnectionID, numVersions int) []byte { versions := make([]protocol.VersionNumber, numVersions) for i := 0; i < numVersions; i++ { versions[i] = protocol.VersionNumber(rand.Uint32()) @@ -31,23 +31,23 @@ func main() { headers := []wire.Header{ { // Initial without token IsLongHeader: true, - SrcConnectionID: protocol.ConnectionID(getRandomData(3)), - DestConnectionID: protocol.ConnectionID(getRandomData(8)), + SrcConnectionID: protocol.ParseConnectionID(getRandomData(3)), + DestConnectionID: protocol.ParseConnectionID(getRandomData(8)), Type: protocol.PacketTypeInitial, Length: protocol.ByteCount(rand.Intn(1000)), Version: version, }, { // Initial without token, with zero-length src conn id IsLongHeader: true, - DestConnectionID: protocol.ConnectionID(getRandomData(8)), + DestConnectionID: protocol.ParseConnectionID(getRandomData(8)), Type: protocol.PacketTypeInitial, Length: protocol.ByteCount(rand.Intn(1000)), Version: version, }, { // Initial with Token IsLongHeader: true, - SrcConnectionID: protocol.ConnectionID(getRandomData(10)), - DestConnectionID: protocol.ConnectionID(getRandomData(19)), + SrcConnectionID: protocol.ParseConnectionID(getRandomData(10)), + DestConnectionID: protocol.ParseConnectionID(getRandomData(19)), Type: protocol.PacketTypeInitial, Length: protocol.ByteCount(rand.Intn(1000)), Version: version, @@ -55,37 +55,37 @@ func main() { }, { // Handshake packet IsLongHeader: true, - SrcConnectionID: protocol.ConnectionID(getRandomData(5)), - DestConnectionID: protocol.ConnectionID(getRandomData(10)), + SrcConnectionID: protocol.ParseConnectionID(getRandomData(5)), + DestConnectionID: protocol.ParseConnectionID(getRandomData(10)), Type: protocol.PacketTypeHandshake, Length: protocol.ByteCount(rand.Intn(1000)), Version: version, }, { // Handshake packet, with zero-length src conn id IsLongHeader: true, - DestConnectionID: protocol.ConnectionID(getRandomData(12)), + DestConnectionID: protocol.ParseConnectionID(getRandomData(12)), Type: protocol.PacketTypeHandshake, Length: protocol.ByteCount(rand.Intn(1000)), Version: version, }, { // 0-RTT packet IsLongHeader: true, - SrcConnectionID: protocol.ConnectionID(getRandomData(8)), - DestConnectionID: protocol.ConnectionID(getRandomData(9)), + SrcConnectionID: protocol.ParseConnectionID(getRandomData(8)), + DestConnectionID: protocol.ParseConnectionID(getRandomData(9)), Type: protocol.PacketType0RTT, Length: protocol.ByteCount(rand.Intn(1000)), Version: version, }, { // Retry Packet, with empty orig dest conn id IsLongHeader: true, - SrcConnectionID: protocol.ConnectionID(getRandomData(8)), - DestConnectionID: protocol.ConnectionID(getRandomData(9)), + SrcConnectionID: protocol.ParseConnectionID(getRandomData(8)), + DestConnectionID: protocol.ParseConnectionID(getRandomData(9)), Type: protocol.PacketTypeRetry, Token: getRandomData(1000), Version: version, }, { // Short-Header - DestConnectionID: protocol.ConnectionID(getRandomData(8)), + DestConnectionID: protocol.ParseConnectionID(getRandomData(8)), }, } @@ -113,28 +113,28 @@ func main() { vnps := [][]byte{ getVNP( - protocol.ConnectionID(getRandomData(8)), - protocol.ConnectionID(getRandomData(10)), + protocol.ArbitraryLenConnectionID(getRandomData(8)), + protocol.ArbitraryLenConnectionID(getRandomData(10)), 4, ), getVNP( - protocol.ConnectionID(getRandomData(10)), - protocol.ConnectionID(getRandomData(5)), + protocol.ArbitraryLenConnectionID(getRandomData(10)), + protocol.ArbitraryLenConnectionID(getRandomData(5)), 0, ), getVNP( - protocol.ConnectionID(getRandomData(3)), - protocol.ConnectionID(getRandomData(19)), + protocol.ArbitraryLenConnectionID(getRandomData(3)), + protocol.ArbitraryLenConnectionID(getRandomData(19)), 100, ), getVNP( - protocol.ConnectionID(getRandomData(3)), + protocol.ArbitraryLenConnectionID(getRandomData(3)), nil, 20, ), getVNP( nil, - protocol.ConnectionID(getRandomData(10)), + protocol.ArbitraryLenConnectionID(getRandomData(10)), 5, ), } diff --git a/fuzzing/header/fuzz.go b/fuzzing/header/fuzz.go index ba37172e209..ea0054d83bb 100644 --- a/fuzzing/header/fuzz.go +++ b/fuzzing/header/fuzz.go @@ -35,7 +35,7 @@ func Fuzz(data []byte) int { if err != nil { return 0 } - if !hdr.DestConnectionID.Equal(connID) { + if hdr.DestConnectionID != connID { panic(fmt.Sprintf("Expected connection IDs to match: %s vs %s", hdr.DestConnectionID, connID)) } if (hdr.Type == protocol.PacketType0RTT) != is0RTTPacket { @@ -82,16 +82,16 @@ func fuzzVNP(data []byte) int { if err != nil { return 0 } - hdr, versions, err := wire.ParseVersionNegotiationPacket(bytes.NewReader(data)) + dest, src, versions, err := wire.ParseVersionNegotiationPacket(data) if err != nil { return 0 } - if !hdr.DestConnectionID.Equal(connID) { + if !bytes.Equal(dest, connID.Bytes()) { panic("connection IDs don't match") } if len(versions) == 0 { panic("no versions") } - wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, versions) + wire.ComposeVersionNegotiation(src, dest, versions) return 1 } diff --git a/fuzzing/tokens/fuzz.go b/fuzzing/tokens/fuzz.go index 1e1904ba3cd..a753716ad3a 100644 --- a/fuzzing/tokens/fuzz.go +++ b/fuzzing/tokens/fuzz.go @@ -73,7 +73,7 @@ func newToken(tg *handshake.TokenGenerator, data []byte) int { if token.SentTime.Before(start) || token.SentTime.After(time.Now()) { panic("incorrect send time") } - if token.OriginalDestConnectionID != nil || token.RetrySrcConnectionID != nil { + if token.OriginalDestConnectionID.Len() > 0 || token.RetrySrcConnectionID.Len() > 0 { panic("didn't expect connection IDs") } return 1 @@ -89,12 +89,12 @@ func newRetryToken(tg *handshake.TokenGenerator, data []byte) int { if len(data) < origDestConnIDLen { return -1 } - origDestConnID := protocol.ConnectionID(data[:origDestConnIDLen]) + origDestConnID := protocol.ParseConnectionID(data[:origDestConnIDLen]) data = data[origDestConnIDLen:] if len(data) < retrySrcConnIDLen { return -1 } - retrySrcConnID := protocol.ConnectionID(data[:retrySrcConnIDLen]) + retrySrcConnID := protocol.ParseConnectionID(data[:retrySrcConnIDLen]) data = data[retrySrcConnIDLen:] if len(data) < 1 { @@ -132,10 +132,10 @@ func newRetryToken(tg *handshake.TokenGenerator, data []byte) int { if token.SentTime.Before(start) || token.SentTime.After(time.Now()) { panic("incorrect send time") } - if !token.OriginalDestConnectionID.Equal(origDestConnID) { + if token.OriginalDestConnectionID != origDestConnID { panic("orig dest conn ID doesn't match") } - if !token.RetrySrcConnectionID.Equal(retrySrcConnID) { + if token.RetrySrcConnectionID != retrySrcConnID { panic("retry src conn ID doesn't match") } return 1 diff --git a/fuzzing/transportparameters/cmd/corpus.go b/fuzzing/transportparameters/cmd/corpus.go index 8c88a6a87f6..0f4720610fc 100644 --- a/fuzzing/transportparameters/cmd/corpus.go +++ b/fuzzing/transportparameters/cmd/corpus.go @@ -43,13 +43,13 @@ func main() { ActiveConnectionIDLimit: getRandomValue(), } if rand.Int()%2 == 0 { - tp.OriginalDestinationConnectionID = protocol.ConnectionID(getRandomData(rand.Intn(50))) + tp.OriginalDestinationConnectionID = protocol.ParseConnectionID(getRandomData(rand.Intn(21))) } if rand.Int()%2 == 0 { - tp.InitialSourceConnectionID = protocol.ConnectionID(getRandomData(rand.Intn(50))) + tp.InitialSourceConnectionID = protocol.ParseConnectionID(getRandomData(rand.Intn(21))) } if rand.Int()%2 == 0 { - connID := protocol.ConnectionID(getRandomData(rand.Intn(50))) + connID := protocol.ParseConnectionID(getRandomData(rand.Intn(21))) tp.RetrySourceConnectionID = &connID } if rand.Int()%2 == 0 { @@ -65,7 +65,7 @@ func main() { IPv4Port: uint16(rand.Int()), IPv6: net.IP(getRandomData(16)), IPv6Port: uint16(rand.Int()), - ConnectionID: protocol.ConnectionID(getRandomData(rand.Intn(25))), + ConnectionID: protocol.ParseConnectionID(getRandomData(rand.Intn(21))), StatelessResetToken: token, } } diff --git a/integrationtests/self/conn_id_test.go b/integrationtests/self/conn_id_test.go index dc47aa867a2..760d7131625 100644 --- a/integrationtests/self/conn_id_test.go +++ b/integrationtests/self/conn_id_test.go @@ -19,13 +19,12 @@ type connIDGenerator struct { length int } -func (c *connIDGenerator) GenerateConnectionID() ([]byte, error) { +func (c *connIDGenerator) GenerateConnectionID() (quic.ConnectionID, error) { b := make([]byte, c.length) - _, err := rand.Read(b) - if err != nil { + if _, err := rand.Read(b); err != nil { fmt.Fprintf(GinkgoWriter, "generating conn ID failed: %s", err) } - return b, nil + return protocol.ParseConnectionID(b), nil } func (c *connIDGenerator) ConnectionIDLen() int { diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 202097961aa..a594a1d8b22 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -58,6 +58,8 @@ type versionNegotiationTracer struct { clientVersions, serverVersions []logging.VersionNumber } +var _ logging.ConnectionTracer = &versionNegotiationTracer{} + func (t *versionNegotiationTracer) NegotiatedVersion(chosen logging.VersionNumber, clientVersions, serverVersions []logging.VersionNumber) { if t.loggedVersions { Fail("only expected one call to NegotiatedVersions") @@ -68,7 +70,7 @@ func (t *versionNegotiationTracer) NegotiatedVersion(chosen logging.VersionNumbe t.serverVersions = serverVersions } -func (t *versionNegotiationTracer) ReceivedVersionNegotiationPacket(*logging.Header, []logging.VersionNumber) { +func (t *versionNegotiationTracer) ReceivedVersionNegotiationPacket(dest, src logging.ArbitraryLenConnectionID, _ []logging.VersionNumber) { t.receivedVersionNegotiation = true } diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index f7c6e6c7c02..0a08bb2ecfd 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -326,7 +326,11 @@ var _ = Describe("MITM test", func() { // Create fake version negotiation packet with no supported versions versions := []protocol.VersionNumber{} - packet := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, versions) + packet := wire.ComposeVersionNegotiation( + protocol.ArbitraryLenConnectionID(hdr.SrcConnectionID.Bytes()), + protocol.ArbitraryLenConnectionID(hdr.DestConnectionID.Bytes()), + versions, + ) // Send the packet _, err = serverUDPConn.WriteTo(packet, clientUDPConn.LocalAddr()) @@ -363,7 +367,7 @@ var _ = Describe("MITM test", func() { } initialPacketIntercepted = true - fakeSrcConnID := protocol.ConnectionID{0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12} + fakeSrcConnID := protocol.ParseConnectionID([]byte{0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12}) retryPacket := testutils.ComposeRetryPacket(fakeSrcConnID, hdr.SrcConnectionID, hdr.DestConnectionID, []byte("token"), hdr.Version) _, err = serverUDPConn.WriteTo(retryPacket, clientUDPConn.LocalAddr()) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index c39a02dda75..2837512e1c9 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -372,7 +372,7 @@ var _ = Describe("0-RTT", func() { It("retransmits all 0-RTT data when the server performs a Retry", func() { var mutex sync.Mutex - var firstConnID, secondConnID protocol.ConnectionID + var firstConnID, secondConnID *protocol.ConnectionID var firstCounter, secondCounter protocol.ByteCount tlsConf, clientConf := dialAndReceiveSessionTicket(nil) @@ -415,15 +415,15 @@ var _ = Describe("0-RTT", func() { if zeroRTTBytes := countZeroRTTBytes(data); zeroRTTBytes > 0 { if firstConnID == nil { - firstConnID = connID + firstConnID = &connID firstCounter += zeroRTTBytes - } else if firstConnID != nil && firstConnID.Equal(connID) { + } else if firstConnID != nil && *firstConnID == connID { Expect(secondConnID).To(BeNil()) firstCounter += zeroRTTBytes } else if secondConnID == nil { - secondConnID = connID + secondConnID = &connID secondCounter += zeroRTTBytes - } else if secondConnID != nil && secondConnID.Equal(connID) { + } else if secondConnID != nil && *secondConnID == connID { secondCounter += zeroRTTBytes } else { Fail("received 3 connection IDs on 0-RTT packets") diff --git a/integrationtests/tools/proxy/proxy_test.go b/integrationtests/tools/proxy/proxy_test.go index 35a56baa325..3254d3dfffa 100644 --- a/integrationtests/tools/proxy/proxy_test.go +++ b/integrationtests/tools/proxy/proxy_test.go @@ -35,8 +35,8 @@ var _ = Describe("QUIC Proxy", func() { Type: protocol.PacketTypeInitial, Version: protocol.VersionTLS, Length: 4 + protocol.ByteCount(len(payload)), - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37}, - SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37}, + DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37}), + SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0, 0, 0x13, 0x37}), }, PacketNumber: p, PacketNumberLen: protocol.PacketNumberLen4, diff --git a/interface.go b/interface.go index ea94aa4567f..8fbeab1f10d 100644 --- a/interface.go +++ b/interface.go @@ -201,6 +201,11 @@ type EarlyConnection interface { NextConnection() Connection } +// A ConnectionID is a QUIC Connection ID, as defined in RFC 9000. +// It is not able to handle QUIC Connection IDs longer than 20 bytes, +// as they are allowed by RFC 8999. +type ConnectionID = protocol.ConnectionID + // A ConnectionIDGenerator is an interface that allows clients to implement their own format // for the Connection IDs that servers/clients use as SrcConnectionID in QUIC packets. // @@ -208,7 +213,7 @@ type EarlyConnection interface { type ConnectionIDGenerator interface { // GenerateConnectionID generates a new ConnectionID. // Generated ConnectionIDs should be unique and observers should not be able to correlate two ConnectionIDs. - GenerateConnectionID() ([]byte, error) + GenerateConnectionID() (ConnectionID, error) // ConnectionIDLen tells what is the length of the ConnectionIDs generated by the implementation of // this interface. diff --git a/internal/handshake/initial_aead.go b/internal/handshake/initial_aead.go index 00ed243c75f..6128147c55f 100644 --- a/internal/handshake/initial_aead.go +++ b/internal/handshake/initial_aead.go @@ -62,7 +62,7 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v p } func computeSecrets(connID protocol.ConnectionID, v protocol.VersionNumber) (clientSecret, serverSecret []byte) { - initialSecret := hkdf.Extract(crypto.SHA256.New, connID, getSalt(v)) + initialSecret := hkdf.Extract(crypto.SHA256.New, connID.Bytes(), getSalt(v)) clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size()) serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size()) return diff --git a/internal/handshake/initial_aead_test.go b/internal/handshake/initial_aead_test.go index bb8c4a156ac..a3f38ac649e 100644 --- a/internal/handshake/initial_aead_test.go +++ b/internal/handshake/initial_aead_test.go @@ -18,7 +18,7 @@ var _ = Describe("Initial AEAD using AES-GCM", func() { Expect(splitHexString("dead beef")).To(Equal([]byte{0xde, 0xad, 0xbe, 0xef})) }) - connID := protocol.ConnectionID(splitHexString("0x8394c8f03e515708")) + connID := protocol.ParseConnectionID(splitHexString("0x8394c8f03e515708")) DescribeTable("computes the client key and IV", func(v protocol.VersionNumber, expectedClientSecret, expectedKey, expectedIV []byte) { @@ -160,7 +160,7 @@ var _ = Describe("Initial AEAD using AES-GCM", func() { Context(fmt.Sprintf("using version %s", v), func() { It("seals and opens", func() { - connectionID := protocol.ConnectionID{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef} + connectionID := protocol.ParseConnectionID([]byte{0x12, 0x34, 0x56, 0x78, 0x90, 0xab, 0xcd, 0xef}) clientSealer, clientOpener := NewInitialAEAD(connectionID, protocol.PerspectiveClient, v) serverSealer, serverOpener := NewInitialAEAD(connectionID, protocol.PerspectiveServer, v) @@ -175,8 +175,8 @@ var _ = Describe("Initial AEAD using AES-GCM", func() { }) It("doesn't work if initialized with different connection IDs", func() { - c1 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 1} - c2 := protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0, 2} + c1 := protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0, 1}) + c2 := protocol.ParseConnectionID([]byte{0, 0, 0, 0, 0, 0, 0, 2}) clientSealer, _ := NewInitialAEAD(c1, protocol.PerspectiveClient, v) _, serverOpener := NewInitialAEAD(c2, protocol.PerspectiveServer, v) @@ -186,7 +186,7 @@ var _ = Describe("Initial AEAD using AES-GCM", func() { }) It("encrypts und decrypts the header", func() { - connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} + connID := protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}) clientSealer, clientOpener := NewInitialAEAD(connID, protocol.PerspectiveClient, v) serverSealer, serverOpener := NewInitialAEAD(connID, protocol.PerspectiveServer, v) diff --git a/internal/handshake/retry_test.go b/internal/handshake/retry_test.go index e1d3a215d79..017fa4281e3 100644 --- a/internal/handshake/retry_test.go +++ b/internal/handshake/retry_test.go @@ -9,27 +9,30 @@ import ( var _ = Describe("Retry Integrity Check", func() { It("calculates retry integrity tags", func() { - fooTag := GetRetryIntegrityTag([]byte("foo"), protocol.ConnectionID{1, 2, 3, 4}, protocol.VersionDraft29) - barTag := GetRetryIntegrityTag([]byte("bar"), protocol.ConnectionID{1, 2, 3, 4}, protocol.VersionDraft29) + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) + fooTag := GetRetryIntegrityTag([]byte("foo"), connID, protocol.VersionDraft29) + barTag := GetRetryIntegrityTag([]byte("bar"), connID, protocol.VersionDraft29) Expect(fooTag).ToNot(BeNil()) Expect(barTag).ToNot(BeNil()) Expect(*fooTag).ToNot(Equal(*barTag)) }) It("includes the original connection ID in the tag calculation", func() { - t1 := GetRetryIntegrityTag([]byte("foobar"), protocol.ConnectionID{1, 2, 3, 4}, protocol.Version1) - t2 := GetRetryIntegrityTag([]byte("foobar"), protocol.ConnectionID{4, 3, 2, 1}, protocol.Version1) + connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) + connID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) + t1 := GetRetryIntegrityTag([]byte("foobar"), connID1, protocol.Version1) + t2 := GetRetryIntegrityTag([]byte("foobar"), connID2, protocol.Version1) Expect(*t1).ToNot(Equal(*t2)) }) It("uses the test vector from the draft, for old draft versions", func() { - connID := protocol.ConnectionID(splitHexString("0x8394c8f03e515708")) + connID := protocol.ParseConnectionID(splitHexString("0x8394c8f03e515708")) data := splitHexString("ffff00001d0008f067a5502a4262b574 6f6b656ed16926d81f6f9ca2953a8aa4 575e1e49") Expect(GetRetryIntegrityTag(data[:len(data)-16], connID, protocol.VersionDraft29)[:]).To(Equal(data[len(data)-16:])) }) It("uses the test vector from the draft, for version 1", func() { - connID := protocol.ConnectionID(splitHexString("0x8394c8f03e515708")) + connID := protocol.ParseConnectionID(splitHexString("0x8394c8f03e515708")) data := splitHexString("ff000000010008f067a5502a4262b574 6f6b656e04a265ba2eff4d829058fb3f 0f2496ba") Expect(GetRetryIntegrityTag(data[:len(data)-16], connID, protocol.Version1)[:]).To(Equal(data[len(data)-16:])) }) diff --git a/internal/handshake/token_generator.go b/internal/handshake/token_generator.go index a8dda91e6d4..cda49466d75 100644 --- a/internal/handshake/token_generator.go +++ b/internal/handshake/token_generator.go @@ -65,8 +65,8 @@ func (g *TokenGenerator) NewRetryToken( data, err := asn1.Marshal(token{ IsRetryToken: true, RemoteAddr: encodeRemoteAddr(raddr), - OriginalDestConnectionID: origDestConnID, - RetrySrcConnectionID: retrySrcConnID, + OriginalDestConnectionID: origDestConnID.Bytes(), + RetrySrcConnectionID: retrySrcConnID.Bytes(), Timestamp: time.Now().UnixNano(), }) if err != nil { @@ -112,8 +112,8 @@ func (g *TokenGenerator) DecodeToken(encrypted []byte) (*Token, error) { encodedRemoteAddr: t.RemoteAddr, } if t.IsRetryToken { - token.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID) - token.RetrySrcConnectionID = protocol.ConnectionID(t.RetrySrcConnectionID) + token.OriginalDestConnectionID = protocol.ParseConnectionID(t.OriginalDestConnectionID) + token.RetrySrcConnectionID = protocol.ParseConnectionID(t.RetrySrcConnectionID) } return token, nil } diff --git a/internal/handshake/token_generator_test.go b/internal/handshake/token_generator_test.go index 4d4be0f239e..d674e72e53b 100644 --- a/internal/handshake/token_generator_test.go +++ b/internal/handshake/token_generator_test.go @@ -23,7 +23,7 @@ var _ = Describe("Token Generator", func() { It("generates a token", func() { ip := net.IPv4(127, 0, 0, 1) - token, err := tokenGen.NewRetryToken(&net.UDPAddr{IP: ip, Port: 1337}, nil, nil) + token, err := tokenGen.NewRetryToken(&net.UDPAddr{IP: ip, Port: 1337}, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) Expect(token).ToNot(BeEmpty()) }) @@ -36,7 +36,7 @@ var _ = Describe("Token Generator", func() { It("accepts a valid token", func() { addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - tokenEnc, err := tokenGen.NewRetryToken(addr, nil, nil) + tokenEnc, err := tokenGen.NewRetryToken(addr, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) token, err := tokenGen.DecodeToken(tokenEnc) Expect(err).ToNot(HaveOccurred()) @@ -48,16 +48,14 @@ var _ = Describe("Token Generator", func() { }) It("saves the connection ID", func() { - tokenEnc, err := tokenGen.NewRetryToken( - &net.UDPAddr{}, - protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, - ) + connID1 := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) + connID2 := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}) + tokenEnc, err := tokenGen.NewRetryToken(&net.UDPAddr{}, connID1, connID2) Expect(err).ToNot(HaveOccurred()) token, err := tokenGen.DecodeToken(tokenEnc) Expect(err).ToNot(HaveOccurred()) - Expect(token.OriginalDestConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) - Expect(token.RetrySrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})) + Expect(token.OriginalDestConnectionID).To(Equal(connID1)) + Expect(token.RetrySrcConnectionID).To(Equal(connID2)) }) It("rejects invalid tokens", func() { @@ -103,7 +101,7 @@ var _ = Describe("Token Generator", func() { ip := net.ParseIP(addr) Expect(ip).ToNot(BeNil()) raddr := &net.UDPAddr{IP: ip, Port: 1337} - tokenEnc, err := tokenGen.NewRetryToken(raddr, nil, nil) + tokenEnc, err := tokenGen.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) token, err := tokenGen.DecodeToken(tokenEnc) Expect(err).ToNot(HaveOccurred()) @@ -114,7 +112,7 @@ var _ = Describe("Token Generator", func() { It("uses the string representation an address that is not a UDP address", func() { raddr := &net.TCPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337} - tokenEnc, err := tokenGen.NewRetryToken(raddr, nil, nil) + tokenEnc, err := tokenGen.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) token, err := tokenGen.DecodeToken(tokenEnc) Expect(err).ToNot(HaveOccurred()) diff --git a/internal/mocks/logging/connection_tracer.go b/internal/mocks/logging/connection_tracer.go index ed5d268870c..4c037916f6d 100644 --- a/internal/mocks/logging/connection_tracer.go +++ b/internal/mocks/logging/connection_tracer.go @@ -220,15 +220,15 @@ func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 int } // ReceivedVersionNegotiationPacket mocks base method. -func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0 *wire.Header, arg1 []protocol.VersionNumber) { +func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0, arg1 protocol.ArbitraryLenConnectionID, arg2 []protocol.VersionNumber) { m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedVersionNegotiationPacket", arg0, arg1) + m.ctrl.Call(m, "ReceivedVersionNegotiationPacket", arg0, arg1, arg2) } // ReceivedVersionNegotiationPacket indicates an expected call of ReceivedVersionNegotiationPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1, arg2) } // RestoredTransportParameters mocks base method. diff --git a/internal/mocks/logging/tracer.go b/internal/mocks/logging/tracer.go index 04c72623364..762d590004d 100644 --- a/internal/mocks/logging/tracer.go +++ b/internal/mocks/logging/tracer.go @@ -62,6 +62,18 @@ func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) } +// SentVersionNegotiationPacket mocks base method. +func (m *MockTracer) SentVersionNegotiationPacket(arg0 net.Addr, arg1, arg2 protocol.ArbitraryLenConnectionID, arg3 []protocol.VersionNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentVersionNegotiationPacket", arg0, arg1, arg2, arg3) +} + +// SentVersionNegotiationPacket indicates an expected call of SentVersionNegotiationPacket. +func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3) +} + // TracerForConnection mocks base method. func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) logging.ConnectionTracer { m.ctrl.T.Helper() diff --git a/internal/protocol/connection_id.go b/internal/protocol/connection_id.go index 7ae7d9dcfa1..77259b5fa58 100644 --- a/internal/protocol/connection_id.go +++ b/internal/protocol/connection_id.go @@ -1,24 +1,60 @@ package protocol import ( - "bytes" "crypto/rand" + "errors" "fmt" "io" ) -// A ConnectionID in QUIC -type ConnectionID []byte +var ErrInvalidConnectionIDLen = errors.New("invalid Connection ID length") + +// An ArbitraryLenConnectionID is a QUIC Connection ID able to represent Connection IDs according to RFC 8999. +// Future QUIC versions might allow connection ID lengths up to 255 bytes, while QUIC v1 +// restricts the length to 20 bytes. +type ArbitraryLenConnectionID []byte + +func (c ArbitraryLenConnectionID) Len() int { + return len(c) +} + +func (c ArbitraryLenConnectionID) Bytes() []byte { + return c +} + +func (c ArbitraryLenConnectionID) String() string { + if c.Len() == 0 { + return "(empty)" + } + return fmt.Sprintf("%x", c.Bytes()) +} const maxConnectionIDLen = 20 +// A ConnectionID in QUIC +type ConnectionID struct { + b [20]byte + l uint8 +} + // GenerateConnectionID generates a connection ID using cryptographic random -func GenerateConnectionID(len int) (ConnectionID, error) { - b := make([]byte, len) - if _, err := rand.Read(b); err != nil { - return nil, err +func GenerateConnectionID(l int) (ConnectionID, error) { + var c ConnectionID + c.l = uint8(l) + _, err := rand.Read(c.b[:l]) + return c, err +} + +// ParseConnectionID interprets b as a Connection ID. +// It panics if b is longer than 20 bytes. +func ParseConnectionID(b []byte) ConnectionID { + if len(b) > maxConnectionIDLen { + panic("invalid conn id length") } - return ConnectionID(b), nil + var c ConnectionID + c.l = uint8(len(b)) + copy(c.b[:c.l], b) + return c } // GenerateConnectionIDForInitial generates a connection ID for the Initial packet. @@ -26,39 +62,38 @@ func GenerateConnectionID(len int) (ConnectionID, error) { func GenerateConnectionIDForInitial() (ConnectionID, error) { r := make([]byte, 1) if _, err := rand.Read(r); err != nil { - return nil, err + return ConnectionID{}, err } - len := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1) - return GenerateConnectionID(len) + l := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1) + return GenerateConnectionID(l) } // ReadConnectionID reads a connection ID of length len from the given io.Reader. // It returns io.EOF if there are not enough bytes to read. -func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) { - if len == 0 { - return nil, nil +func ReadConnectionID(r io.Reader, l int) (ConnectionID, error) { + var c ConnectionID + if l == 0 { + return c, nil + } + if l > maxConnectionIDLen { + return c, ErrInvalidConnectionIDLen } - c := make(ConnectionID, len) - _, err := io.ReadFull(r, c) + c.l = uint8(l) + _, err := io.ReadFull(r, c.b[:l]) if err == io.ErrUnexpectedEOF { - return nil, io.EOF + return c, io.EOF } return c, err } -// Equal says if two connection IDs are equal -func (c ConnectionID) Equal(other ConnectionID) bool { - return bytes.Equal(c, other) -} - // Len returns the length of the connection ID in bytes func (c ConnectionID) Len() int { - return len(c) + return int(c.l) } // Bytes returns the byte representation func (c ConnectionID) Bytes() []byte { - return []byte(c) + return c.b[:c.l] } func (c ConnectionID) String() string { @@ -72,7 +107,7 @@ type DefaultConnectionIDGenerator struct { ConnLen int } -func (d *DefaultConnectionIDGenerator) GenerateConnectionID() ([]byte, error) { +func (d *DefaultConnectionIDGenerator) GenerateConnectionID() (ConnectionID, error) { return GenerateConnectionID(d.ConnLen) } diff --git a/internal/protocol/connection_id_test.go b/internal/protocol/connection_id_test.go index 345e656c0f4..98abb1d26c8 100644 --- a/internal/protocol/connection_id_test.go +++ b/internal/protocol/connection_id_test.go @@ -2,6 +2,7 @@ package protocol import ( "bytes" + "crypto/rand" "io" . "github.com/onsi/ginkgo" @@ -42,15 +43,6 @@ var _ = Describe("Connection ID generation", func() { Expect(has20ByteConnID).To(BeTrue()) }) - It("says if connection IDs are equal", func() { - c1 := ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - c2 := ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} - Expect(c1.Equal(c1)).To(BeTrue()) - Expect(c2.Equal(c2)).To(BeTrue()) - Expect(c1.Equal(c2)).To(BeFalse()) - Expect(c2.Equal(c1)).To(BeFalse()) - }) - It("reads the connection ID", func() { buf := bytes.NewBuffer([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) c, err := ReadConnectionID(buf, 9) @@ -64,15 +56,21 @@ var _ = Describe("Connection ID generation", func() { Expect(err).To(MatchError(io.EOF)) }) - It("returns nil for a 0 length connection ID", func() { + It("returns a 0 length connection ID", func() { buf := bytes.NewBuffer([]byte{1, 2, 3, 4}) c, err := ReadConnectionID(buf, 0) Expect(err).ToNot(HaveOccurred()) - Expect(c).To(BeNil()) + Expect(c.Len()).To(BeZero()) + }) + + It("errors when trying to read a too long connection ID", func() { + buf := bytes.NewBuffer(make([]byte, 21)) + _, err := ReadConnectionID(buf, 21) + Expect(err).To(MatchError(ErrInvalidConnectionIDLen)) }) It("returns the length", func() { - c := ConnectionID{1, 2, 3, 4, 5, 6, 7} + c := ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7}) Expect(c.Len()).To(Equal(7)) }) @@ -82,22 +80,22 @@ var _ = Describe("Connection ID generation", func() { }) It("returns the bytes", func() { - c := ConnectionID([]byte{1, 2, 3, 4, 5, 6, 7}) + c := ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7}) Expect(c.Bytes()).To(Equal([]byte{1, 2, 3, 4, 5, 6, 7})) }) It("returns a nil byte slice for the default value", func() { var c ConnectionID - Expect(c.Bytes()).To(BeNil()) + Expect(c.Bytes()).To(HaveLen(0)) }) It("has a string representation", func() { - c := ConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42}) + c := ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42}) Expect(c.String()).To(Equal("deadbeef42")) }) It("has a long string representation", func() { - c := ConnectionID{0x13, 0x37, 0, 0, 0xde, 0xca, 0xfb, 0xad} + c := ParseConnectionID([]byte{0x13, 0x37, 0, 0, 0xde, 0xca, 0xfb, 0xad}) Expect(c.String()).To(Equal("13370000decafbad")) }) @@ -105,4 +103,28 @@ var _ = Describe("Connection ID generation", func() { var c ConnectionID Expect(c.String()).To(Equal("(empty)")) }) + + Context("arbitrary length connection IDs", func() { + It("returns the bytes", func() { + b := make([]byte, 30) + rand.Read(b) + c := ArbitraryLenConnectionID(b) + Expect(c.Bytes()).To(Equal(b)) + }) + + It("returns the length", func() { + c := ArbitraryLenConnectionID(make([]byte, 156)) + Expect(c.Len()).To(Equal(156)) + }) + + It("has a string representation", func() { + c := ArbitraryLenConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42}) + Expect(c.String()).To(Equal("deadbeef42")) + }) + + It("has a string representation for the default value", func() { + var c ArbitraryLenConnectionID + Expect(c.String()).To(Equal("(empty)")) + }) + }) }) diff --git a/internal/wire/extended_header_test.go b/internal/wire/extended_header_test.go index 51719e83cd7..a3f25fb3f0a 100644 --- a/internal/wire/extended_header_test.go +++ b/internal/wire/extended_header_test.go @@ -24,15 +24,15 @@ var _ = Describe("Header", func() { }) Context("Long Header", func() { - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) It("writes", func() { Expect((&ExtendedHeader{ Header: Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}, - SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37}, + DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}), + SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad, 0x0, 0x0, 0x13, 0x37}), Version: 0x1020304, Length: protocol.InitialPacketSizeIPv4, }, @@ -52,27 +52,12 @@ var _ = Describe("Header", func() { Expect(buf.Bytes()).To(Equal(expected)) }) - It("refuses to write a header with a too long connection ID", func() { - err := (&ExtendedHeader{ - Header: Header{ - IsLongHeader: true, - SrcConnectionID: srcConnID, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}, // connection IDs must be at most 20 bytes long - Version: 0x1020304, - Type: 0x5, - }, - PacketNumber: 0xdecafbad, - PacketNumberLen: protocol.PacketNumberLen4, - }).Write(buf, versionIETFHeader) - Expect(err).To(MatchError("invalid connection ID length: 21 bytes")) - }) - It("writes a header with a 20 byte connection ID", func() { err := (&ExtendedHeader{ Header: Header{ IsLongHeader: true, SrcConnectionID: srcConnID, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}, // connection IDs must be at most 20 bytes long + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}), // connection IDs must be at most 20 bytes long Version: 0x1020304, Type: 0x5, }, @@ -194,7 +179,7 @@ var _ = Describe("Header", func() { It("writes a header with connection ID", func() { Expect((&ExtendedHeader{ Header: Header{ - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, + DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}), }, PacketNumberLen: protocol.PacketNumberLen1, PacketNumber: 0x42, @@ -271,8 +256,8 @@ var _ = Describe("Header", func() { Header: Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Length: 1, }, PacketNumberLen: protocol.PacketNumberLen1, @@ -288,8 +273,8 @@ var _ = Describe("Header", func() { Header: Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Length: 1500, }, PacketNumberLen: protocol.PacketNumberLen2, @@ -305,8 +290,8 @@ var _ = Describe("Header", func() { Header: Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), Length: 15, }, PacketNumberLen: protocol.PacketNumberLen2, @@ -322,8 +307,8 @@ var _ = Describe("Header", func() { Header: Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), Length: 1500, }, PacketNumberLen: protocol.PacketNumberLen2, @@ -338,8 +323,8 @@ var _ = Describe("Header", func() { h := &ExtendedHeader{ Header: Header{ IsLongHeader: true, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), Type: protocol.PacketTypeInitial, Length: 1500, Token: []byte("foo"), @@ -355,7 +340,7 @@ var _ = Describe("Header", func() { It("has the right length for a Short Header containing a connection ID", func() { h := &ExtendedHeader{ Header: Header{ - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), }, PacketNumberLen: protocol.PacketNumberLen1, } @@ -407,8 +392,8 @@ var _ = Describe("Header", func() { (&ExtendedHeader{ Header: Header{ IsLongHeader: true, - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, - SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x013, 0x37, 0x13, 0x37}, + DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}), + SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad, 0x013, 0x37, 0x13, 0x37}), Type: protocol.PacketTypeHandshake, Length: 54321, Version: 0xfeed, @@ -423,8 +408,8 @@ var _ = Describe("Header", func() { (&ExtendedHeader{ Header: Header{ IsLongHeader: true, - DestConnectionID: protocol.ConnectionID{0xca, 0xfe, 0x13, 0x37}, - SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + DestConnectionID: protocol.ParseConnectionID([]byte{0xca, 0xfe, 0x13, 0x37}), + SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), Type: protocol.PacketTypeInitial, Token: []byte{0xde, 0xad, 0xbe, 0xef}, Length: 100, @@ -440,8 +425,8 @@ var _ = Describe("Header", func() { (&ExtendedHeader{ Header: Header{ IsLongHeader: true, - DestConnectionID: protocol.ConnectionID{0xca, 0xfe, 0x13, 0x37}, - SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + DestConnectionID: protocol.ParseConnectionID([]byte{0xca, 0xfe, 0x13, 0x37}), + SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), Type: protocol.PacketTypeInitial, Length: 100, Version: 0xfeed, @@ -456,8 +441,8 @@ var _ = Describe("Header", func() { (&ExtendedHeader{ Header: Header{ IsLongHeader: true, - DestConnectionID: protocol.ConnectionID{0xca, 0xfe, 0x13, 0x37}, - SrcConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + DestConnectionID: protocol.ParseConnectionID([]byte{0xca, 0xfe, 0x13, 0x37}), + SrcConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), Type: protocol.PacketTypeRetry, Token: []byte{0x12, 0x34, 0x56}, Version: 0xfeed, @@ -469,7 +454,7 @@ var _ = Describe("Header", func() { It("logs Short Headers containing a connection ID", func() { (&ExtendedHeader{ Header: Header{ - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, + DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}), }, KeyPhase: protocol.KeyPhaseOne, PacketNumber: 1337, diff --git a/internal/wire/frame_parser_test.go b/internal/wire/frame_parser_test.go index dcf570e8583..8b88a378676 100644 --- a/internal/wire/frame_parser_test.go +++ b/internal/wire/frame_parser_test.go @@ -210,7 +210,7 @@ var _ = Describe("Frame parsing", func() { It("unpacks NEW_CONNECTION_ID frames", func() { f := &NewConnectionIDFrame{ SequenceNumber: 0x1337, - ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, } b, err := f.Append(nil, protocol.Version1) @@ -330,7 +330,7 @@ var _ = Describe("Frame parsing", func() { &DataBlockedFrame{}, &StreamDataBlockedFrame{}, &StreamsBlockedFrame{}, - &NewConnectionIDFrame{ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}}, + &NewConnectionIDFrame{ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})}, &RetireConnectionIDFrame{}, &PathChallengeFrame{}, &PathResponseFrame{}, diff --git a/internal/wire/header.go b/internal/wire/header.go index a01f40ca45d..e8a08242491 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -17,22 +17,60 @@ import ( // That means that the connection ID must not be used after the packet buffer is released. func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) { if len(data) == 0 { - return nil, io.EOF + return protocol.ConnectionID{}, io.EOF } if !IsLongHeaderPacket(data[0]) { if len(data) < shortHeaderConnIDLen+1 { - return nil, io.EOF + return protocol.ConnectionID{}, io.EOF } - return protocol.ConnectionID(data[1 : 1+shortHeaderConnIDLen]), nil + return protocol.ParseConnectionID(data[1 : 1+shortHeaderConnIDLen]), nil } if len(data) < 6 { - return nil, io.EOF + return protocol.ConnectionID{}, io.EOF } destConnIDLen := int(data[5]) if len(data) < 6+destConnIDLen { - return nil, io.EOF + return protocol.ConnectionID{}, io.EOF } - return protocol.ConnectionID(data[6 : 6+destConnIDLen]), nil + return protocol.ParseConnectionID(data[6 : 6+destConnIDLen]), nil +} + +// ParseArbitraryLenConnectionIDs parses the most general form of a Long Header packet, +// using only the version-independent packet format as described in Section 5.1 of RFC 8999: +// https://datatracker.ietf.org/doc/html/rfc8999#section-5.1. +// This function should only be called on Long Header packets for which we don't support the version. +func ParseArbitraryLenConnectionIDs(data []byte) (bytesParsed int, dest, src protocol.ArbitraryLenConnectionID, _ error) { + r := bytes.NewReader(data) + remaining := r.Len() + src, dest, err := parseArbitraryLenConnectionIDs(r) + return remaining - r.Len(), src, dest, err +} + +func parseArbitraryLenConnectionIDs(r *bytes.Reader) (dest, src protocol.ArbitraryLenConnectionID, _ error) { + r.Seek(5, io.SeekStart) // skip first byte and version field + destConnIDLen, err := r.ReadByte() + if err != nil { + return nil, nil, err + } + destConnID := make(protocol.ArbitraryLenConnectionID, destConnIDLen) + if _, err := io.ReadFull(r, destConnID); err != nil { + if err == io.ErrUnexpectedEOF { + err = io.EOF + } + return nil, nil, err + } + srcConnIDLen, err := r.ReadByte() + if err != nil { + return nil, nil, err + } + srcConnID := make(protocol.ArbitraryLenConnectionID, srcConnIDLen) + if _, err := io.ReadFull(r, srcConnID); err != nil { + if err == io.ErrUnexpectedEOF { + err = io.EOF + } + return nil, nil, err + } + return destConnID, srcConnID, nil } // IsLongHeaderPacket says if this is a Long Header packet @@ -40,6 +78,15 @@ func IsLongHeaderPacket(firstByte byte) bool { return firstByte&0x80 > 0 } +// ParseVersion parses the QUIC version. +// It should only be called for Long Header packets (Short Header packets don't contain a version number). +func ParseVersion(data []byte) (protocol.VersionNumber, error) { + if len(data) < 5 { + return 0, io.EOF + } + return protocol.VersionNumber(binary.BigEndian.Uint32(data[1:5])), nil +} + // IsVersionNegotiationPacket says if this is a version negotiation packet func IsVersionNegotiationPacket(b []byte) bool { if len(b) < 5 { diff --git a/internal/wire/header_test.go b/internal/wire/header_test.go index de7d26730ae..88b045de52f 100644 --- a/internal/wire/header_test.go +++ b/internal/wire/header_test.go @@ -2,8 +2,10 @@ package wire import ( "bytes" + "crypto/rand" "encoding/binary" "io" + mrand "math/rand" "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" @@ -18,36 +20,36 @@ var _ = Describe("Header Parsing", func() { Header: Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6}, + DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), + SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}), Version: protocol.Version1, }, PacketNumberLen: 2, }).Write(buf, protocol.Version1)).To(Succeed()) connID, err := ParseConnectionID(buf.Bytes(), 8) Expect(err).ToNot(HaveOccurred()) - Expect(connID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) + Expect(connID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}))) }) It("parses the connection ID of a short header packet", func() { buf := &bytes.Buffer{} Expect((&ExtendedHeader{ Header: Header{ - DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), }, PacketNumberLen: 2, }).Write(buf, protocol.Version1)).To(Succeed()) buf.Write([]byte("foobar")) connID, err := ParseConnectionID(buf.Bytes(), 4) Expect(err).ToNot(HaveOccurred()) - Expect(connID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) + Expect(connID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}))) }) It("errors on EOF, for short header packets", func() { buf := &bytes.Buffer{} Expect((&ExtendedHeader{ Header: Header{ - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), }, PacketNumberLen: 2, }).Write(buf, protocol.Version1)).To(Succeed()) @@ -68,8 +70,8 @@ var _ = Describe("Header Parsing", func() { Header: Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad, 0x13, 0x37}, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 8, 9}, + DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad, 0x13, 0x37}), + SrcConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 8, 9}), Version: protocol.Version1, }, PacketNumberLen: 2, @@ -111,6 +113,66 @@ var _ = Describe("Header Parsing", func() { Expect(Is0RTTPacket(append(zeroRTTHeader, []byte("foobar")...))).To(BeTrue()) }) }) + Context("parsing the version", func() { + It("parses the version", func() { + b := []byte{0x80, 0xde, 0xad, 0xbe, 0xef} + v, err := ParseVersion(b) + Expect(err).ToNot(HaveOccurred()) + Expect(v).To(Equal(protocol.VersionNumber(0xdeadbeef))) + }) + + It("errors with EOF", func() { + b := []byte{0x80, 0xde, 0xad, 0xbe, 0xef} + _, err := ParseVersion(b) + Expect(err).ToNot(HaveOccurred()) + for i := range b { + _, err := ParseVersion(b[:i]) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) + + Context("parsing arbitrary length connection IDs", func() { + generateConnID := func(l int) protocol.ArbitraryLenConnectionID { + c := make(protocol.ArbitraryLenConnectionID, l) + rand.Read(c) + return c + } + + generatePacket := func(src, dest protocol.ArbitraryLenConnectionID) []byte { + b := []byte{0x80, 1, 2, 3, 4} + b = append(b, uint8(dest.Len())) + b = append(b, dest.Bytes()...) + b = append(b, uint8(src.Len())) + b = append(b, src.Bytes()...) + return b + } + + It("parses arbitrary length connection IDs", func() { + src := generateConnID(mrand.Intn(255) + 1) + dest := generateConnID(mrand.Intn(255) + 1) + b := generatePacket(src, dest) + l := len(b) + b = append(b, []byte("foobar")...) // add some payload + + parsed, d, s, err := ParseArbitraryLenConnectionIDs(b) + Expect(parsed).To(Equal(l)) + Expect(err).ToNot(HaveOccurred()) + Expect(s).To(Equal(src)) + Expect(d).To(Equal(dest)) + }) + + It("errors on EOF", func() { + b := generatePacket(generateConnID(mrand.Intn(255)+1), generateConnID(mrand.Intn(255)+1)) + _, _, _, err := ParseArbitraryLenConnectionIDs(b) + Expect(err).ToNot(HaveOccurred()) + + for i := range b { + _, _, _, err := ParseArbitraryLenConnectionIDs(b[:i]) + Expect(err).To(MatchError(io.EOF)) + } + }) + }) Context("Identifying Version Negotiation Packets", func() { It("identifies version negotiation packets", func() { @@ -132,14 +194,14 @@ var _ = Describe("Header Parsing", func() { Context("Long Headers", func() { It("parses a Long Header", func() { - destConnID := protocol.ConnectionID{9, 8, 7, 6, 5, 4, 3, 2, 1} - srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + destConnID := protocol.ParseConnectionID([]byte{9, 8, 7, 6, 5, 4, 3, 2, 1}) + srcConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) data := []byte{0xc0 ^ 0x3} data = appendVersion(data, protocol.Version1) data = append(data, 0x9) // dest conn id length - data = append(data, destConnID...) + data = append(data, destConnID.Bytes()...) data = append(data, 0x4) // src conn id length - data = append(data, srcConnID...) + data = append(data, srcConnID.Bytes()...) data = append(data, encodeVarInt(6)...) // token length data = append(data, []byte("foobar")...) // token data = append(data, encodeVarInt(10)...) // length @@ -194,38 +256,50 @@ var _ = Describe("Header Parsing", func() { Expect(err).To(MatchError(ErrUnsupportedVersion)) Expect(hdr.IsLongHeader).To(BeTrue()) Expect(hdr.Version).To(Equal(protocol.VersionNumber(0xdeadbeef))) - Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8})) - Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{0x8, 0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1})) + Expect(hdr.DestConnectionID).To(Equal(protocol.ParseConnectionID([]byte{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8}))) + Expect(hdr.SrcConnectionID).To(Equal(protocol.ParseConnectionID([]byte{0x8, 0x7, 0x6, 0x5, 0x4, 0x3, 0x2, 0x1}))) Expect(rest).To(BeEmpty()) }) It("parses a Long Header without a destination connection ID", func() { data := []byte{0xc0 ^ 0x1<<4} data = appendVersion(data, protocol.Version1) - data = append(data, 0x0) // dest conn ID len - data = append(data, 0x4) // src conn ID len + data = append(data, 0) // dest conn ID len + data = append(data, 4) // src conn ID len data = append(data, []byte{0xde, 0xad, 0xbe, 0xef}...) // source connection ID data = append(data, encodeVarInt(0)...) // length data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) hdr, _, _, err := ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) Expect(hdr.Type).To(Equal(protocol.PacketType0RTT)) - Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) - Expect(hdr.DestConnectionID).To(BeEmpty()) + Expect(hdr.SrcConnectionID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}))) + Expect(hdr.DestConnectionID).To(BeZero()) }) It("parses a Long Header without a source connection ID", func() { data := []byte{0xc0 ^ 0x2<<4} data = appendVersion(data, protocol.Version1) - data = append(data, 0xa) // dest conn ID len + data = append(data, 10) // dest conn ID len data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}...) // dest connection ID - data = append(data, 0x0) // src conn ID len + data = append(data, 0) // src conn ID len data = append(data, encodeVarInt(0)...) // length data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) hdr, _, _, err := ParsePacket(data, 0) Expect(err).ToNot(HaveOccurred()) - Expect(hdr.SrcConnectionID).To(BeEmpty()) - Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) + Expect(hdr.SrcConnectionID).To(BeZero()) + Expect(hdr.DestConnectionID).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))) + }) + + It("parses a Long Header without a too long destination connection ID", func() { + data := []byte{0xc0 ^ 0x2<<4} + data = appendVersion(data, protocol.Version1) + data = append(data, 21) // dest conn ID len + data = append(data, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21}...) // dest connection ID + data = append(data, 0x0) // src conn ID len + data = append(data, encodeVarInt(0)...) // length + data = append(data, []byte{0xde, 0xca, 0xfb, 0xad}...) + _, _, _, err := ParsePacket(data, 0) + Expect(err).To(MatchError(protocol.ErrInvalidConnectionIDLen)) }) It("parses a Long Header with a 2 byte packet number", func() { @@ -259,8 +333,8 @@ var _ = Describe("Header Parsing", func() { Expect(err).ToNot(HaveOccurred()) Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) Expect(hdr.Version).To(Equal(protocol.Version1)) - Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{6, 5, 4, 3, 2, 1})) - Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) + Expect(hdr.DestConnectionID).To(Equal(protocol.ParseConnectionID([]byte{6, 5, 4, 3, 2, 1}))) + Expect(hdr.SrcConnectionID).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))) Expect(hdr.Token).To(Equal([]byte("foobar"))) Expect(pdata).To(Equal(data)) Expect(rest).To(BeEmpty()) @@ -279,8 +353,8 @@ var _ = Describe("Header Parsing", func() { Expect(err).ToNot(HaveOccurred()) Expect(hdr.Type).To(Equal(protocol.PacketTypeRetry)) Expect(hdr.Version).To(Equal(protocol.Version2)) - Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{6, 5, 4, 3, 2, 1})) - Expect(hdr.SrcConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) + Expect(hdr.DestConnectionID).To(Equal(protocol.ParseConnectionID([]byte{6, 5, 4, 3, 2, 1}))) + Expect(hdr.SrcConnectionID).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))) Expect(hdr.Token).To(Equal([]byte("foobar"))) Expect(pdata).To(Equal(data)) Expect(rest).To(BeEmpty()) @@ -377,7 +451,7 @@ var _ = Describe("Header Parsing", func() { hdr := Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), Length: 2 + 6, Version: protocol.Version1, } @@ -403,7 +477,7 @@ var _ = Describe("Header Parsing", func() { Header: Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), Length: 3, Version: protocol.Version1, }, @@ -421,7 +495,7 @@ var _ = Describe("Header Parsing", func() { Header: Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), Length: 1000, Version: protocol.Version1, }, @@ -437,8 +511,8 @@ var _ = Describe("Header Parsing", func() { Context("Short Headers", func() { It("reads a Short Header with a 8 byte connection ID", func() { - connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} - data := append([]byte{0x40}, connID...) + connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}) + data := append([]byte{0x40}, connID.Bytes()...) data = append(data, 0x42) // packet number Expect(IsVersionNegotiationPacket(data)).To(BeFalse()) @@ -451,7 +525,7 @@ var _ = Describe("Header Parsing", func() { Expect(err).ToNot(HaveOccurred()) Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseZero)) Expect(extHdr.DestConnectionID).To(Equal(connID)) - Expect(extHdr.SrcConnectionID).To(BeEmpty()) + Expect(extHdr.SrcConnectionID).To(BeZero()) Expect(extHdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) Expect(hdr.ParsedLen()).To(BeEquivalentTo(len(data) - 1)) Expect(extHdr.ParsedLen()).To(Equal(hdr.ParsedLen() + 1)) @@ -460,15 +534,15 @@ var _ = Describe("Header Parsing", func() { }) It("errors if 0x40 is not set", func() { - connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} - data := append([]byte{0x0}, connID...) + connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}) + data := append([]byte{0x0}, connID.Bytes()...) _, _, _, err := ParsePacket(data, 8) Expect(err).To(MatchError("not a QUIC packet")) }) It("errors if the 4th or 5th bit are set", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5} - data := append([]byte{0x40 | 0x10 /* set the 4th bit */}, connID...) + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) + data := append([]byte{0x40 | 0x10 /* set the 4th bit */}, connID.Bytes()...) data = append(data, 0x42) // packet number hdr, _, _, err := ParsePacket(data, 5) Expect(err).ToNot(HaveOccurred()) @@ -480,8 +554,8 @@ var _ = Describe("Header Parsing", func() { }) It("reads a Short Header with a 5 byte connection ID", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5} - data := append([]byte{0x40}, connID...) + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) + data := append([]byte{0x40}, connID.Bytes()...) data = append(data, 0x42) // packet number hdr, pdata, rest, err := ParsePacket(data, 5) Expect(err).ToNot(HaveOccurred()) @@ -493,7 +567,7 @@ var _ = Describe("Header Parsing", func() { Expect(err).ToNot(HaveOccurred()) Expect(extHdr.KeyPhase).To(Equal(protocol.KeyPhaseZero)) Expect(extHdr.DestConnectionID).To(Equal(connID)) - Expect(extHdr.SrcConnectionID).To(BeEmpty()) + Expect(extHdr.SrcConnectionID).To(BeZero()) Expect(rest).To(BeEmpty()) }) diff --git a/internal/wire/log_test.go b/internal/wire/log_test.go index 6e970c1c8ec..c79138046b5 100644 --- a/internal/wire/log_test.go +++ b/internal/wire/log_test.go @@ -153,7 +153,7 @@ var _ = Describe("Frame logging", func() { It("logs NEW_CONNECTION_ID frames", func() { LogFrame(logger, &NewConnectionIDFrame{ SequenceNumber: 42, - ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), StatelessResetToken: protocol.StatelessResetToken{0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x10}, }, false) Expect(buf.String()).To(ContainSubstring("\t<- &wire.NewConnectionIDFrame{SequenceNumber: 42, ConnectionID: deadbeef, StatelessResetToken: 0x0102030405060708090a0b0c0d0e0f10}")) diff --git a/internal/wire/new_connection_id_frame.go b/internal/wire/new_connection_id_frame.go index befc4037032..828cda3bc9d 100644 --- a/internal/wire/new_connection_id_frame.go +++ b/internal/wire/new_connection_id_frame.go @@ -38,9 +38,6 @@ func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewC if err != nil { return nil, err } - if connIDLen > protocol.MaxConnIDLen { - return nil, fmt.Errorf("invalid connection ID length: %d", connIDLen) - } connID, err := protocol.ReadConnectionID(r, int(connIDLen)) if err != nil { return nil, err diff --git a/internal/wire/new_connection_id_frame_test.go b/internal/wire/new_connection_id_frame_test.go index fa9f53aa9cf..f289cb65411 100644 --- a/internal/wire/new_connection_id_frame_test.go +++ b/internal/wire/new_connection_id_frame_test.go @@ -24,7 +24,7 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { Expect(err).ToNot(HaveOccurred()) Expect(frame.SequenceNumber).To(Equal(uint64(0xdeadbeef))) Expect(frame.RetirePriorTo).To(Equal(uint64(0xcafe))) - Expect(frame.ConnectionID).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) + Expect(frame.ConnectionID).To(Equal(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}))) Expect(string(frame.StatelessResetToken[:])).To(Equal("deadbeefdecafbad")) }) @@ -49,7 +49,7 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { data = append(data, []byte("deadbeefdecafbad")...) // stateless reset token b := bytes.NewReader(data) _, err := parseNewConnectionIDFrame(b, protocol.Version1) - Expect(err).To(MatchError("invalid connection ID length: 21")) + Expect(err).To(MatchError(protocol.ErrInvalidConnectionIDLen)) }) It("errors on EOFs", func() { @@ -74,7 +74,7 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { frame := &NewConnectionIDFrame{ SequenceNumber: 0x1337, RetirePriorTo: 0x42, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}), StatelessResetToken: token, } b, err := frame.Append(nil, protocol.Version1) @@ -93,7 +93,7 @@ var _ = Describe("NEW_CONNECTION_ID frame", func() { frame := &NewConnectionIDFrame{ SequenceNumber: 0xdecafbad, RetirePriorTo: 0xdeadbeefcafe, - ConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), StatelessResetToken: token, } b, err := frame.Append(nil, protocol.Version1) diff --git a/internal/wire/transport_parameter_test.go b/internal/wire/transport_parameter_test.go index 81d33c602f7..b5f478fb79d 100644 --- a/internal/wire/transport_parameter_test.go +++ b/internal/wire/transport_parameter_test.go @@ -41,6 +41,7 @@ var _ = Describe("Transport Parameters", func() { } It("has a string representation", func() { + rcid := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}) p := &TransportParameters{ InitialMaxStreamDataBidiLocal: 1234, InitialMaxStreamDataBidiRemote: 2345, @@ -49,9 +50,9 @@ var _ = Describe("Transport Parameters", func() { MaxBidiStreamNum: 1337, MaxUniStreamNum: 7331, MaxIdleTimeout: 42 * time.Second, - OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), + InitialSourceConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), + RetrySourceConnectionID: &rcid, AckDelayExponent: 14, MaxAckDelay: 37 * time.Millisecond, StatelessResetToken: &protocol.StatelessResetToken{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00}, @@ -70,8 +71,8 @@ var _ = Describe("Transport Parameters", func() { MaxBidiStreamNum: 1337, MaxUniStreamNum: 7331, MaxIdleTimeout: 42 * time.Second, - OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - InitialSourceConnectionID: protocol.ConnectionID{}, + OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), + InitialSourceConnectionID: protocol.ParseConnectionID([]byte{}), AckDelayExponent: 14, MaxAckDelay: 37 * time.Second, ActiveConnectionIDLimit: 89, @@ -83,6 +84,7 @@ var _ = Describe("Transport Parameters", func() { It("marshals and unmarshals", func() { var token protocol.StatelessResetToken rand.Read(token[:]) + rcid := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}) params := &TransportParameters{ InitialMaxStreamDataBidiLocal: protocol.ByteCount(getRandomValue()), InitialMaxStreamDataBidiRemote: protocol.ByteCount(getRandomValue()), @@ -93,9 +95,9 @@ var _ = Describe("Transport Parameters", func() { MaxUniStreamNum: protocol.StreamNum(getRandomValueUpTo(int64(protocol.MaxStreamCount))), DisableActiveMigration: true, StatelessResetToken: &token, - OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), + InitialSourceConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), + RetrySourceConnectionID: &rcid, AckDelayExponent: 13, MaxAckDelay: 42 * time.Millisecond, ActiveConnectionIDLimit: getRandomValue(), @@ -114,9 +116,9 @@ var _ = Describe("Transport Parameters", func() { Expect(p.MaxIdleTimeout).To(Equal(params.MaxIdleTimeout)) Expect(p.DisableActiveMigration).To(Equal(params.DisableActiveMigration)) Expect(p.StatelessResetToken).To(Equal(params.StatelessResetToken)) - Expect(p.OriginalDestinationConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef})) - Expect(p.InitialSourceConnectionID).To(Equal(protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) - Expect(p.RetrySourceConnectionID).To(Equal(&protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})) + Expect(p.OriginalDestinationConnectionID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}))) + Expect(p.InitialSourceConnectionID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}))) + Expect(p.RetrySourceConnectionID).To(Equal(&rcid)) Expect(p.AckDelayExponent).To(Equal(uint8(13))) Expect(p.MaxAckDelay).To(Equal(42 * time.Millisecond)) Expect(p.ActiveConnectionIDLimit).To(Equal(params.ActiveConnectionIDLimit)) @@ -133,8 +135,9 @@ var _ = Describe("Transport Parameters", func() { }) It("marshals a zero-length retry_source_connection_id", func() { + rcid := protocol.ParseConnectionID([]byte{}) data := (&TransportParameters{ - RetrySourceConnectionID: &protocol.ConnectionID{}, + RetrySourceConnectionID: &rcid, StatelessResetToken: &protocol.StatelessResetToken{}, }).Marshal(protocol.PerspectiveServer) p := &TransportParameters{} @@ -406,7 +409,7 @@ var _ = Describe("Transport Parameters", func() { IPv4Port: 42, IPv6: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, IPv6Port: 13, - ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}, } }) @@ -439,7 +442,7 @@ var _ = Describe("Transport Parameters", func() { }) It("errors on zero-length connection IDs", func() { - pa.ConnectionID = protocol.ConnectionID{} + pa.ConnectionID = protocol.ParseConnectionID([]byte{}) data := (&TransportParameters{ PreferredAddress: pa, StatelessResetToken: &protocol.StatelessResetToken{}, @@ -451,20 +454,6 @@ var _ = Describe("Transport Parameters", func() { })) }) - It("errors on too long connection IDs", func() { - pa.ConnectionID = protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21} - Expect(pa.ConnectionID.Len()).To(BeNumerically(">", protocol.MaxConnIDLen)) - data := (&TransportParameters{ - PreferredAddress: pa, - StatelessResetToken: &protocol.StatelessResetToken{}, - }).Marshal(protocol.PerspectiveServer) - p := &TransportParameters{} - Expect(p.Unmarshal(data, protocol.PerspectiveServer)).To(MatchError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: "invalid connection ID length: 21", - })) - }) - It("errors on EOF", func() { raw := []byte{ 127, 0, 0, 1, // IPv4 diff --git a/internal/wire/version_negotiation.go b/internal/wire/version_negotiation.go index 196853e0fc4..2cfa2ca3dcb 100644 --- a/internal/wire/version_negotiation.go +++ b/internal/wire/version_negotiation.go @@ -3,6 +3,7 @@ package wire import ( "bytes" "crypto/rand" + "encoding/binary" "errors" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -10,32 +11,30 @@ import ( ) // ParseVersionNegotiationPacket parses a Version Negotiation packet. -func ParseVersionNegotiationPacket(b *bytes.Reader) (*Header, []protocol.VersionNumber, error) { - hdr, err := parseHeader(b, 0) +func ParseVersionNegotiationPacket(b []byte) (dest, src protocol.ArbitraryLenConnectionID, _ []protocol.VersionNumber, _ error) { + n, dest, src, err := ParseArbitraryLenConnectionIDs(b) if err != nil { - return nil, nil, err + return nil, nil, nil, err } - if b.Len() == 0 { + b = b[n:] + if len(b) == 0 { //nolint:stylecheck - return nil, nil, errors.New("Version Negotiation packet has empty version list") + return nil, nil, nil, errors.New("Version Negotiation packet has empty version list") } - if b.Len()%4 != 0 { + if len(b)%4 != 0 { //nolint:stylecheck - return nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length") + return nil, nil, nil, errors.New("Version Negotiation packet has a version list with an invalid length") } - versions := make([]protocol.VersionNumber, b.Len()/4) - for i := 0; b.Len() > 0; i++ { - v, err := utils.BigEndian.ReadUint32(b) - if err != nil { - return nil, nil, err - } - versions[i] = protocol.VersionNumber(v) + versions := make([]protocol.VersionNumber, len(b)/4) + for i := 0; len(b) > 0; i++ { + versions[i] = protocol.VersionNumber(binary.BigEndian.Uint32(b[:4])) + b = b[4:] } - return hdr, versions, nil + return dest, src, versions, nil } // ComposeVersionNegotiation composes a Version Negotiation -func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, versions []protocol.VersionNumber) []byte { +func ComposeVersionNegotiation(destConnID, srcConnID protocol.ArbitraryLenConnectionID, versions []protocol.VersionNumber) []byte { greasedVersions := protocol.GetGreasedVersions(versions) expectedLen := 1 /* type byte */ + 4 /* version field */ + 1 /* dest connection ID length field */ + destConnID.Len() + 1 /* src connection ID length field */ + srcConnID.Len() + len(greasedVersions)*4 buf := bytes.NewBuffer(make([]byte, 0, expectedLen)) @@ -44,9 +43,9 @@ func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, vers buf.WriteByte(r[0] | 0x80) utils.BigEndian.WriteUint32(buf, 0) // version 0 buf.WriteByte(uint8(destConnID.Len())) - buf.Write(destConnID) + buf.Write(destConnID.Bytes()) buf.WriteByte(uint8(srcConnID.Len())) - buf.Write(srcConnID) + buf.Write(srcConnID.Bytes()) for _, v := range greasedVersions { utils.BigEndian.WriteUint32(buf, uint32(v)) } diff --git a/internal/wire/version_negotiation_test.go b/internal/wire/version_negotiation_test.go index 2783cb1765a..cd9bb1c2c32 100644 --- a/internal/wire/version_negotiation_test.go +++ b/internal/wire/version_negotiation_test.go @@ -1,8 +1,10 @@ package wire import ( - "bytes" "encoding/binary" + mrand "math/rand" + + "golang.org/x/exp/rand" "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" @@ -10,9 +12,16 @@ import ( ) var _ = Describe("Version Negotiation Packets", func() { + randConnID := func(l int) protocol.ArbitraryLenConnectionID { + b := make(protocol.ArbitraryLenConnectionID, l) + _, err := mrand.Read(b) + Expect(err).ToNot(HaveOccurred()) + return b + } + It("parses a Version Negotiation packet", func() { - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} - destConnID := protocol.ConnectionID{9, 8, 7, 6, 5, 4, 3, 2, 1} + srcConnID := randConnID(rand.Intn(255) + 1) + destConnID := randConnID(rand.Intn(255) + 1) versions := []protocol.VersionNumber{0x22334455, 0x33445566} data := []byte{0x80, 0, 0, 0, 0} data = append(data, uint8(len(destConnID))) @@ -24,44 +33,44 @@ var _ = Describe("Version Negotiation Packets", func() { binary.BigEndian.PutUint32(data[len(data)-4:], uint32(v)) } Expect(IsVersionNegotiationPacket(data)).To(BeTrue()) - hdr, supportedVersions, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) + dest, src, supportedVersions, err := ParseVersionNegotiationPacket(data) Expect(err).ToNot(HaveOccurred()) - Expect(hdr.DestConnectionID).To(Equal(destConnID)) - Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) - Expect(hdr.IsLongHeader).To(BeTrue()) - Expect(hdr.Version).To(BeZero()) + Expect(dest).To(Equal(destConnID)) + Expect(src).To(Equal(srcConnID)) Expect(supportedVersions).To(Equal(versions)) }) It("errors if it contains versions of the wrong length", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID := protocol.ArbitraryLenConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.VersionNumber{0x22334455, 0x33445566} data := ComposeVersionNegotiation(connID, connID, versions) - _, _, err := ParseVersionNegotiationPacket(bytes.NewReader(data[:len(data)-2])) + _, _, _, err := ParseVersionNegotiationPacket(data[:len(data)-2]) Expect(err).To(MatchError("Version Negotiation packet has a version list with an invalid length")) }) It("errors if the version list is empty", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID := protocol.ArbitraryLenConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.VersionNumber{0x22334455} data := ComposeVersionNegotiation(connID, connID, versions) // remove 8 bytes (two versions), since ComposeVersionNegotiation also added a reserved version number data = data[:len(data)-8] - _, _, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) + _, _, _, err := ParseVersionNegotiationPacket(data) Expect(err).To(MatchError("Version Negotiation packet has empty version list")) }) It("adds a reserved version", func() { - srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} - destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + srcConnID := protocol.ArbitraryLenConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} + destConnID := protocol.ArbitraryLenConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.VersionNumber{1001, 1003} data := ComposeVersionNegotiation(destConnID, srcConnID, versions) Expect(IsLongHeaderPacket(data[0])).To(BeTrue()) - hdr, supportedVersions, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) + v, err := ParseVersion(data) + Expect(err).ToNot(HaveOccurred()) + Expect(v).To(BeZero()) + dest, src, supportedVersions, err := ParseVersionNegotiationPacket(data) Expect(err).ToNot(HaveOccurred()) - Expect(hdr.DestConnectionID).To(Equal(destConnID)) - Expect(hdr.SrcConnectionID).To(Equal(srcConnID)) - Expect(hdr.Version).To(BeZero()) + Expect(dest).To(Equal(destConnID)) + Expect(src).To(Equal(srcConnID)) // the supported versions should include one reserved version number Expect(supportedVersions).To(HaveLen(len(versions) + 1)) for _, v := range versions { diff --git a/logging/interface.go b/logging/interface.go index f71d68f7d60..1f98f29370e 100644 --- a/logging/interface.go +++ b/logging/interface.go @@ -19,6 +19,8 @@ type ( ByteCount = protocol.ByteCount // A ConnectionID is a QUIC Connection ID. ConnectionID = protocol.ConnectionID + // An ArbitraryLenConnectionID is a QUIC Connection ID that can be up to 255 bytes long. + ArbitraryLenConnectionID = protocol.ArbitraryLenConnectionID // The EncryptionLevel is the encryption level of a packet. EncryptionLevel = protocol.EncryptionLevel // The KeyPhase is the key phase of the 1-RTT keys. @@ -99,6 +101,7 @@ type Tracer interface { TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer SentPacket(net.Addr, *Header, ByteCount, []Frame) + SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) } @@ -111,7 +114,7 @@ type ConnectionTracer interface { ReceivedTransportParameters(*TransportParameters) RestoredTransportParameters(parameters *TransportParameters) // for 0-RTT SentPacket(hdr *ExtendedHeader, size ByteCount, ack *AckFrame, frames []Frame) - ReceivedVersionNegotiationPacket(*Header, []VersionNumber) + ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, _ []VersionNumber) ReceivedRetry(*Header) ReceivedPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) BufferedPacket(PacketType) diff --git a/logging/mock_connection_tracer_test.go b/logging/mock_connection_tracer_test.go index 6eacea878e7..56aae88ac18 100644 --- a/logging/mock_connection_tracer_test.go +++ b/logging/mock_connection_tracer_test.go @@ -219,15 +219,15 @@ func (mr *MockConnectionTracerMockRecorder) ReceivedTransportParameters(arg0 int } // ReceivedVersionNegotiationPacket mocks base method. -func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0 *wire.Header, arg1 []protocol.VersionNumber) { +func (m *MockConnectionTracer) ReceivedVersionNegotiationPacket(arg0, arg1 protocol.ArbitraryLenConnectionID, arg2 []protocol.VersionNumber) { m.ctrl.T.Helper() - m.ctrl.Call(m, "ReceivedVersionNegotiationPacket", arg0, arg1) + m.ctrl.Call(m, "ReceivedVersionNegotiationPacket", arg0, arg1, arg2) } // ReceivedVersionNegotiationPacket indicates an expected call of ReceivedVersionNegotiationPacket. -func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockConnectionTracerMockRecorder) ReceivedVersionNegotiationPacket(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedVersionNegotiationPacket", reflect.TypeOf((*MockConnectionTracer)(nil).ReceivedVersionNegotiationPacket), arg0, arg1, arg2) } // RestoredTransportParameters mocks base method. diff --git a/logging/mock_tracer_test.go b/logging/mock_tracer_test.go index e970c09b204..ff6e852543b 100644 --- a/logging/mock_tracer_test.go +++ b/logging/mock_tracer_test.go @@ -61,6 +61,18 @@ func (mr *MockTracerMockRecorder) SentPacket(arg0, arg1, arg2, arg3 interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentPacket", reflect.TypeOf((*MockTracer)(nil).SentPacket), arg0, arg1, arg2, arg3) } +// SentVersionNegotiationPacket mocks base method. +func (m *MockTracer) SentVersionNegotiationPacket(arg0 net.Addr, arg1, arg2 protocol.ArbitraryLenConnectionID, arg3 []protocol.VersionNumber) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SentVersionNegotiationPacket", arg0, arg1, arg2, arg3) +} + +// SentVersionNegotiationPacket indicates an expected call of SentVersionNegotiationPacket. +func (mr *MockTracerMockRecorder) SentVersionNegotiationPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SentVersionNegotiationPacket", reflect.TypeOf((*MockTracer)(nil).SentVersionNegotiationPacket), arg0, arg1, arg2, arg3) +} + // TracerForConnection mocks base method. func (m *MockTracer) TracerForConnection(arg0 context.Context, arg1 protocol.Perspective, arg2 protocol.ConnectionID) ConnectionTracer { m.ctrl.T.Helper() diff --git a/logging/multiplex.go b/logging/multiplex.go index 8280e8cdf49..0f69a9dadea 100644 --- a/logging/multiplex.go +++ b/logging/multiplex.go @@ -39,6 +39,12 @@ func (m *tracerMultiplexer) SentPacket(remote net.Addr, hdr *Header, size ByteCo } } +func (m *tracerMultiplexer) SentVersionNegotiationPacket(remote net.Addr, dest, src ArbitraryLenConnectionID, versions []VersionNumber) { + for _, t := range m.tracers { + t.SentVersionNegotiationPacket(remote, dest, src, versions) + } +} + func (m *tracerMultiplexer) DroppedPacket(remote net.Addr, typ PacketType, size ByteCount, reason PacketDropReason) { for _, t := range m.tracers { t.DroppedPacket(remote, typ, size, reason) @@ -104,9 +110,9 @@ func (m *connTracerMultiplexer) SentPacket(hdr *ExtendedHeader, size ByteCount, } } -func (m *connTracerMultiplexer) ReceivedVersionNegotiationPacket(hdr *Header, versions []VersionNumber) { +func (m *connTracerMultiplexer) ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, versions []VersionNumber) { for _, t := range m.tracers { - t.ReceivedVersionNegotiationPacket(hdr, versions) + t.ReceivedVersionNegotiationPacket(dest, src, versions) } } diff --git a/logging/multiplex_test.go b/logging/multiplex_test.go index 84b44d928f6..f5a888de630 100644 --- a/logging/multiplex_test.go +++ b/logging/multiplex_test.go @@ -6,6 +6,8 @@ import ( "net" "time" + "github.com/lucas-clemente/quic-go/internal/protocol" + "github.com/lucas-clemente/quic-go/internal/wire" . "github.com/onsi/ginkgo" @@ -38,18 +40,20 @@ var _ = Describe("Tracing", func() { It("multiplexes the TracerForConnection call", func() { ctx := context.Background() - tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) - tracer.TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) + connID := protocol.ParseConnectionID([]byte{1, 2, 3}) + tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID) + tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID) + tracer.TracerForConnection(ctx, PerspectiveClient, connID) }) It("uses multiple connection tracers", func() { ctx := context.Background() ctr1 := NewMockConnectionTracer(mockCtrl) ctr2 := NewMockConnectionTracer(mockCtrl) - tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr2) - tr := tracer.TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}) + connID := protocol.ParseConnectionID([]byte{1, 2, 3}) + tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr1) + tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr2) + tr := tracer.TracerForConnection(ctx, PerspectiveServer, connID) ctr1.EXPECT().LossTimerCanceled() ctr2.EXPECT().LossTimerCanceled() tr.LossTimerCanceled() @@ -58,29 +62,41 @@ var _ = Describe("Tracing", func() { It("handles tracers that return a nil ConnectionTracer", func() { ctx := context.Background() ctr1 := NewMockConnectionTracer(mockCtrl) - tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}) - tr := tracer.TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}) + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) + tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID).Return(ctr1) + tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, connID) + tr := tracer.TracerForConnection(ctx, PerspectiveServer, connID) ctr1.EXPECT().LossTimerCanceled() tr.LossTimerCanceled() }) It("returns nil when all tracers return a nil ConnectionTracer", func() { ctx := context.Background() - tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) - tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3}) - Expect(tracer.TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3})).To(BeNil()) + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) + tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID) + tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, connID) + Expect(tracer.TracerForConnection(ctx, PerspectiveClient, connID)).To(BeNil()) }) It("traces the PacketSent event", func() { remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} - hdr := &Header{DestConnectionID: ConnectionID{1, 2, 3}} + hdr := &Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})} f := &MaxDataFrame{MaximumData: 1337} tr1.EXPECT().SentPacket(remote, hdr, ByteCount(1024), []Frame{f}) tr2.EXPECT().SentPacket(remote, hdr, ByteCount(1024), []Frame{f}) tracer.SentPacket(remote, hdr, 1024, []Frame{f}) }) + It("traces the PacketSent event", func() { + remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} + src := ArbitraryLenConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13} + dest := ArbitraryLenConnectionID{1, 2, 3, 4} + versions := []VersionNumber{1, 2, 3} + tr1.EXPECT().SentVersionNegotiationPacket(remote, dest, src, versions) + tr2.EXPECT().SentVersionNegotiationPacket(remote, dest, src, versions) + tracer.SentVersionNegotiationPacket(remote, dest, src, versions) + }) + It("traces the PacketDropped event", func() { remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} tr1.EXPECT().DroppedPacket(remote, PacketTypeRetry, ByteCount(1024), PacketDropDuplicate) @@ -106,9 +122,11 @@ var _ = Describe("Tracing", func() { It("trace the ConnectionStarted event", func() { local := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4)} remote := &net.UDPAddr{IP: net.IPv4(4, 3, 2, 1)} - tr1.EXPECT().StartedConnection(local, remote, ConnectionID{1, 2, 3, 4}, ConnectionID{4, 3, 2, 1}) - tr2.EXPECT().StartedConnection(local, remote, ConnectionID{1, 2, 3, 4}, ConnectionID{4, 3, 2, 1}) - tracer.StartedConnection(local, remote, ConnectionID{1, 2, 3, 4}, ConnectionID{4, 3, 2, 1}) + dest := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) + src := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) + tr1.EXPECT().StartedConnection(local, remote, src, dest) + tr2.EXPECT().StartedConnection(local, remote, src, dest) + tracer.StartedConnection(local, remote, src, dest) }) It("traces the ClosedConnection event", func() { @@ -140,7 +158,7 @@ var _ = Describe("Tracing", func() { }) It("traces the SentPacket event", func() { - hdr := &ExtendedHeader{Header: Header{DestConnectionID: ConnectionID{1, 2, 3}}} + hdr := &ExtendedHeader{Header: Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})}} ack := &AckFrame{AckRanges: []AckRange{{Smallest: 1, Largest: 10}}} ping := &PingFrame{} tr1.EXPECT().SentPacket(hdr, ByteCount(1337), ack, []Frame{ping}) @@ -149,21 +167,22 @@ var _ = Describe("Tracing", func() { }) It("traces the ReceivedVersionNegotiationPacket event", func() { - hdr := &Header{DestConnectionID: ConnectionID{1, 2, 3}} - tr1.EXPECT().ReceivedVersionNegotiationPacket(hdr, []VersionNumber{1337}) - tr2.EXPECT().ReceivedVersionNegotiationPacket(hdr, []VersionNumber{1337}) - tracer.ReceivedVersionNegotiationPacket(hdr, []VersionNumber{1337}) + src := ArbitraryLenConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13} + dest := ArbitraryLenConnectionID{1, 2, 3, 4} + tr1.EXPECT().ReceivedVersionNegotiationPacket(dest, src, []VersionNumber{1337}) + tr2.EXPECT().ReceivedVersionNegotiationPacket(dest, src, []VersionNumber{1337}) + tracer.ReceivedVersionNegotiationPacket(dest, src, []VersionNumber{1337}) }) It("traces the ReceivedRetry event", func() { - hdr := &Header{DestConnectionID: ConnectionID{1, 2, 3}} + hdr := &Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})} tr1.EXPECT().ReceivedRetry(hdr) tr2.EXPECT().ReceivedRetry(hdr) tracer.ReceivedRetry(hdr) }) It("traces the ReceivedPacket event", func() { - hdr := &ExtendedHeader{Header: Header{DestConnectionID: ConnectionID{1, 2, 3}}} + hdr := &ExtendedHeader{Header: Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3})}} ping := &PingFrame{} tr1.EXPECT().ReceivedPacket(hdr, ByteCount(1337), []Frame{ping}) tr2.EXPECT().ReceivedPacket(hdr, ByteCount(1337), []Frame{ping}) diff --git a/logging/null_tracer.go b/logging/null_tracer.go index 4e0bb60bafb..e9d0d4e497e 100644 --- a/logging/null_tracer.go +++ b/logging/null_tracer.go @@ -10,27 +10,34 @@ import ( // It is useful for embedding. type NullTracer struct{} +var _ Tracer = &NullTracer{} + func (n NullTracer) TracerForConnection(context.Context, Perspective, ConnectionID) ConnectionTracer { return NullConnectionTracer{} } -func (n NullTracer) SentPacket(net.Addr, *Header, ByteCount, []Frame) {} +func (n NullTracer) SentPacket(net.Addr, *Header, ByteCount, []Frame) {} +func (n NullTracer) SentVersionNegotiationPacket(_ net.Addr, dest, src ArbitraryLenConnectionID, _ []VersionNumber) { +} func (n NullTracer) DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason) {} // The NullConnectionTracer is a ConnectionTracer that does nothing. // It is useful for embedding. type NullConnectionTracer struct{} +var _ ConnectionTracer = &NullConnectionTracer{} + func (n NullConnectionTracer) StartedConnection(local, remote net.Addr, srcConnID, destConnID ConnectionID) { } func (n NullConnectionTracer) NegotiatedVersion(chosen VersionNumber, clientVersions, serverVersions []VersionNumber) { } -func (n NullConnectionTracer) ClosedConnection(err error) {} -func (n NullConnectionTracer) SentTransportParameters(*TransportParameters) {} -func (n NullConnectionTracer) ReceivedTransportParameters(*TransportParameters) {} -func (n NullConnectionTracer) RestoredTransportParameters(*TransportParameters) {} -func (n NullConnectionTracer) SentPacket(*ExtendedHeader, ByteCount, *AckFrame, []Frame) {} -func (n NullConnectionTracer) ReceivedVersionNegotiationPacket(*Header, []VersionNumber) {} +func (n NullConnectionTracer) ClosedConnection(err error) {} +func (n NullConnectionTracer) SentTransportParameters(*TransportParameters) {} +func (n NullConnectionTracer) ReceivedTransportParameters(*TransportParameters) {} +func (n NullConnectionTracer) RestoredTransportParameters(*TransportParameters) {} +func (n NullConnectionTracer) SentPacket(*ExtendedHeader, ByteCount, *AckFrame, []Frame) {} +func (n NullConnectionTracer) ReceivedVersionNegotiationPacket(dest, src ArbitraryLenConnectionID, _ []VersionNumber) { +} func (n NullConnectionTracer) ReceivedRetry(*Header) {} func (n NullConnectionTracer) ReceivedPacket(hdr *ExtendedHeader, size ByteCount, frames []Frame) {} func (n NullConnectionTracer) BufferedPacket(PacketType) {} diff --git a/packet_handler_map.go b/packet_handler_map.go index 0caa4907557..6018765a4a6 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -48,7 +48,7 @@ type packetHandlerMap struct { closeQueue chan closePacket - handlers map[string] /* string(ConnectionID)*/ packetHandler + handlers map[protocol.ConnectionID]packetHandler resetTokens map[protocol.StatelessResetToken] /* stateless reset token */ packetHandler server unknownPacketHandler numZeroRTTEntries int @@ -127,7 +127,7 @@ func newPacketHandlerMap( conn: conn, connIDLen: connIDLen, listening: make(chan struct{}), - handlers: make(map[string]packetHandler), + handlers: make(map[protocol.ConnectionID]packetHandler), resetTokens: make(map[protocol.StatelessResetToken]packetHandler), deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout, zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration, @@ -176,11 +176,11 @@ func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) h.mutex.Lock() defer h.mutex.Unlock() - if _, ok := h.handlers[string(id)]; ok { + if _, ok := h.handlers[id]; ok { h.logger.Debugf("Not adding connection ID %s, as it already exists.", id) return false } - h.handlers[string(id)] = handler + h.handlers[id] = handler h.logger.Debugf("Adding connection ID %s.", id) return true } @@ -190,7 +190,7 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co defer h.mutex.Unlock() var q *zeroRTTQueue - if handler, ok := h.handlers[string(clientDestConnID)]; ok { + if handler, ok := h.handlers[clientDestConnID]; ok { q, ok = handler.(*zeroRTTQueue) if !ok { h.logger.Debugf("Not adding connection ID %s for a new connection, as it already exists.", clientDestConnID) @@ -206,15 +206,15 @@ func (h *packetHandlerMap) AddWithConnID(clientDestConnID, newConnID protocol.Co if q != nil { q.EnqueueAll(conn) } - h.handlers[string(clientDestConnID)] = conn - h.handlers[string(newConnID)] = conn + h.handlers[clientDestConnID] = conn + h.handlers[newConnID] = conn h.logger.Debugf("Adding connection IDs %s and %s for a new connection.", clientDestConnID, newConnID) return true } func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { h.mutex.Lock() - delete(h.handlers, string(id)) + delete(h.handlers, id) h.mutex.Unlock() h.logger.Debugf("Removing connection ID %s.", id) } @@ -223,7 +223,7 @@ func (h *packetHandlerMap) Retire(id protocol.ConnectionID) { h.logger.Debugf("Retiring connection ID %s in %s.", id, h.deleteRetiredConnsAfter) time.AfterFunc(h.deleteRetiredConnsAfter, func() { h.mutex.Lock() - delete(h.handlers, string(id)) + delete(h.handlers, id) h.mutex.Unlock() h.logger.Debugf("Removing connection ID %s after it has been retired.", id) }) @@ -254,7 +254,7 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p h.mutex.Lock() for _, id := range ids { - h.handlers[string(id)] = handler + h.handlers[id] = handler } h.mutex.Unlock() h.logger.Debugf("Replacing connection for connection IDs %s with a closed connection.", ids) @@ -263,7 +263,7 @@ func (h *packetHandlerMap) ReplaceWithClosed(ids []protocol.ConnectionID, pers p h.mutex.Lock() handler.shutdown() for _, id := range ids { - delete(h.handlers, string(id)) + delete(h.handlers, id) } h.mutex.Unlock() h.logger.Debugf("Removing connection IDs %s for a closed connection after it has been retired.", ids) @@ -394,7 +394,7 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) { return } - if handler, ok := h.handlers[string(connID)]; ok { + if handler, ok := h.handlers[connID]; ok { if ha, ok := handler.(*zeroRTTQueue); ok { // only enqueue 0-RTT packets in the 0-RTT queue if wire.Is0RTTPacket(p.data) { ha.handlePacket(p) @@ -419,15 +419,15 @@ func (h *packetHandlerMap) handlePacket(p *receivedPacket) { } h.numZeroRTTEntries++ queue := &zeroRTTQueue{queue: make([]*receivedPacket, 0, 8)} - h.handlers[string(connID)] = queue + h.handlers[connID] = queue queue.retireTimer = time.AfterFunc(h.zeroRTTQueueDuration, func() { h.mutex.Lock() defer h.mutex.Unlock() // The entry might have been replaced by an actual connection. // Only delete it if it's still a 0-RTT queue. - if handler, ok := h.handlers[string(connID)]; ok { + if handler, ok := h.handlers[connID]; ok { if q, ok := handler.(*zeroRTTQueue); ok { - delete(h.handlers, string(connID)) + delete(h.handlers, connID) h.numZeroRTTEntries-- if h.numZeroRTTEntries < 0 { panic("number of 0-RTT queues < 0") diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 63f8e853d95..cc4b94213fe 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -93,8 +93,8 @@ var _ = Describe("Packet Handler Map", func() { conn1.EXPECT().destroy(testErr) conn2 := NewMockPacketHandler(mockCtrl) conn2.EXPECT().destroy(testErr) - handler.Add(protocol.ConnectionID{1, 1, 1, 1}, conn1) - handler.Add(protocol.ConnectionID{2, 2, 2, 2}, conn2) + handler.Add(protocol.ParseConnectionID([]byte{1, 1, 1, 1}), conn1) + handler.Add(protocol.ParseConnectionID([]byte{2, 2, 2, 2}), conn2) mockMultiplexer.EXPECT().RemoveConn(gomock.Any()) handler.close(testErr) close(packetChan) @@ -123,8 +123,8 @@ var _ = Describe("Packet Handler Map", func() { }) It("handles packets for different packet handlers on the same packet conn", func() { - connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + connID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + connID2 := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) packetHandler1 := NewMockPacketHandler(mockCtrl) packetHandler2 := NewMockPacketHandler(mockCtrl) handledPacket1 := make(chan struct{}) @@ -162,7 +162,7 @@ var _ = Describe("Packet Handler Map", func() { It("deletes removed connections immediately", func() { handler.deleteRetiredConnsAfter = time.Hour - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) handler.Add(connID, NewMockPacketHandler(mockCtrl)) handler.Remove(connID) handler.handlePacket(&receivedPacket{data: getPacket(connID)}) @@ -171,7 +171,7 @@ var _ = Describe("Packet Handler Map", func() { It("deletes retired connection entries after a wait time", func() { handler.deleteRetiredConnsAfter = scaleDuration(10 * time.Millisecond) - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) conn := NewMockPacketHandler(mockCtrl) handler.Add(connID, conn) handler.Retire(connID) @@ -182,7 +182,7 @@ var _ = Describe("Packet Handler Map", func() { It("passes packets arriving late for closed connections to that connection", func() { handler.deleteRetiredConnsAfter = time.Hour - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) packetHandler := NewMockPacketHandler(mockCtrl) handled := make(chan struct{}) packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { @@ -195,7 +195,7 @@ var _ = Describe("Packet Handler Map", func() { }) It("drops packets for unknown receivers", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) handler.handlePacket(&receivedPacket{data: getPacket(connID)}) }) @@ -206,14 +206,14 @@ var _ = Describe("Packet Handler Map", func() { Expect(e).To(HaveOccurred()) close(done) }) - handler.Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler) + handler.Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), packetHandler) packetChan <- packetToRead{err: errors.New("read failed")} Eventually(done).Should(BeClosed()) }) It("continues listening for temporary errors", func() { packetHandler := NewMockPacketHandler(mockCtrl) - handler.Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler) + handler.Add(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), packetHandler) err := deadlineError{} Expect(err.Temporary()).To(BeTrue()) packetChan <- packetToRead{err: err} @@ -222,15 +222,15 @@ var _ = Describe("Packet Handler Map", func() { }) It("says if a connection ID is already taken", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue()) Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse()) }) It("says if a connection ID is already taken, for AddWithConnID", func() { - clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - newConnID1 := protocol.ConnectionID{1, 2, 3, 4} - newConnID2 := protocol.ConnectionID{4, 3, 2, 1} + clientDestConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + newConnID1 := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) + newConnID2 := protocol.ParseConnectionID([]byte{4, 3, 2, 1}) Expect(handler.AddWithConnID(clientDestConnID, newConnID1, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeTrue()) Expect(handler.AddWithConnID(clientDestConnID, newConnID2, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeFalse()) }) @@ -238,7 +238,7 @@ var _ = Describe("Packet Handler Map", func() { Context("running a server", func() { It("adds a server", func() { - connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}) p := getPacket(connID) server := NewMockUnknownPacketHandler(mockCtrl) server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { @@ -258,13 +258,13 @@ var _ = Describe("Packet Handler Map", func() { serverConn.EXPECT().getPerspective().Return(protocol.PerspectiveServer) serverConn.EXPECT().shutdown() - handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientConn) - handler.Add(protocol.ConnectionID{2, 2, 2, 2}, serverConn) + handler.Add(protocol.ParseConnectionID([]byte{1, 1, 1, 1}), clientConn) + handler.Add(protocol.ParseConnectionID([]byte{2, 2, 2, 2}), serverConn) handler.CloseServer() }) It("stops handling packets with unknown connection IDs after the server is closed", func() { - connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}) p := getPacket(connID) server := NewMockUnknownPacketHandler(mockCtrl) // don't EXPECT any calls to server.handlePacket @@ -286,7 +286,7 @@ var _ = Describe("Packet Handler Map", func() { server := NewMockUnknownPacketHandler(mockCtrl) // don't EXPECT any calls to server.handlePacket handler.SetServer(server) - connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}) p1 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)} p2 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 2)} p3 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 3)} @@ -300,14 +300,14 @@ var _ = Describe("Packet Handler Map", func() { conn.EXPECT().handlePacket(p2), conn.EXPECT().handlePacket(p3).Do(func(packet *receivedPacket) { close(done) }), ) - handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn }) + handler.AddWithConnID(connID, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func() packetHandler { return conn }) Eventually(done).Should(BeClosed()) }) It("directs 0-RTT packets to existing connections", func() { - connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}) conn := NewMockPacketHandler(mockCtrl) - handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn }) + handler.AddWithConnID(connID, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func() packetHandler { return conn }) p1 := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)} conn.EXPECT().handlePacket(p1) handler.handlePacket(p1) @@ -315,17 +315,21 @@ var _ = Describe("Packet Handler Map", func() { It("limits the number of 0-RTT queues", func() { for i := 0; i < protocol.Max0RTTQueues; i++ { - connID := make(protocol.ConnectionID, 8) - rand.Read(connID) - p := &receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)} + b := make([]byte, 8) + rand.Read(b) + p := &receivedPacket{data: getPacketWithPacketType( + protocol.ParseConnectionID(b), + protocol.PacketType0RTT, + 1, + )} handler.handlePacket(p) } // We're already storing the maximum number of queues. This packet will be dropped. - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9} + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) handler.handlePacket(&receivedPacket{data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1)}) // Don't EXPECT any handlePacket() calls. conn := NewMockPacketHandler(mockCtrl) - handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn }) + handler.AddWithConnID(connID, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func() packetHandler { return conn }) time.Sleep(20 * time.Millisecond) }) @@ -336,7 +340,7 @@ var _ = Describe("Packet Handler Map", func() { server := NewMockUnknownPacketHandler(mockCtrl) // don't EXPECT any calls to server.handlePacket handler.SetServer(server) - connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + connID := protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}) p1 := &receivedPacket{ data: getPacketWithPacketType(connID, protocol.PacketType0RTT, 1), buffer: getPacketBuffer(), @@ -351,7 +355,7 @@ var _ = Describe("Packet Handler Map", func() { time.Sleep(queueDuration * 3) // Don't EXPECT any handlePacket() calls. conn := NewMockPacketHandler(mockCtrl) - handler.AddWithConnID(connID, protocol.ConnectionID{1, 2, 3, 4}, func() packetHandler { return conn }) + handler.AddWithConnID(connID, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), func() packetHandler { return conn }) time.Sleep(20 * time.Millisecond) }) }) @@ -404,7 +408,7 @@ var _ = Describe("Packet Handler Map", func() { }) It("removes reset tokens", func() { - connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42} + connID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef, 0x42}) packetHandler := NewMockPacketHandler(mockCtrl) handler.Add(connID, packetHandler) token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} @@ -442,8 +446,8 @@ var _ = Describe("Packet Handler Map", func() { }) It("generates stateless reset tokens", func() { - connID1 := []byte{0xde, 0xad, 0xbe, 0xef} - connID2 := []byte{0xde, 0xca, 0xfb, 0xad} + connID1 := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) + connID2 := protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}) Expect(handler.GetStatelessResetToken(connID1)).ToNot(Equal(handler.GetStatelessResetToken(connID2))) }) diff --git a/packet_packer_test.go b/packet_packer_test.go index ac095939fdb..6b50a82cecb 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -38,7 +38,7 @@ var _ = Describe("Packet packer", func() { sealingManager *MockSealingManager pnManager *mockackhandler.MockSentPacketHandler ) - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) parsePacket := func(data []byte) []*wire.ExtendedHeader { var hdrs []*wire.ExtendedHeader @@ -94,7 +94,7 @@ var _ = Describe("Packet packer", func() { datagramQueue = newDatagramQueue(func() {}, utils.DefaultLogger) packer = newPacketPacker( - protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), func() protocol.ConnectionID { return connID }, initialStream, handshakeStream, @@ -141,8 +141,8 @@ var _ = Describe("Packet packer", func() { It("sets source and destination connection ID", func() { pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) + destConnID := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) packer.srcConnID = srcConnID packer.getDestConnID = func() protocol.ConnectionID { return destConnID } h := packer.getLongHeader(protocol.EncryptionHandshake) @@ -616,7 +616,7 @@ var _ = Describe("Packet packer", func() { Expect(packet.packets).To(HaveLen(1)) // cut off the tag that the mock sealer added // packet.buffer.Data = packet.buffer.Data[:packet.buffer.Len()-protocol.ByteCount(sealer.Overhead())] - hdr, _, _, err := wire.ParsePacket(packet.buffer.Data, len(packer.getDestConnID())) + hdr, _, _, err := wire.ParsePacket(packet.buffer.Data, packer.getDestConnID().Len()) Expect(err).ToNot(HaveOccurred()) r := bytes.NewReader(packet.buffer.Data) extHdr, err := hdr.ParseExtended(r, packer.version) @@ -656,7 +656,7 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) // cut off the tag that the mock sealer added packet.buffer.Data = packet.buffer.Data[:packet.buffer.Len()-protocol.ByteCount(sealer.Overhead())] - hdr, _, _, err := wire.ParsePacket(packet.buffer.Data, len(packer.getDestConnID())) + hdr, _, _, err := wire.ParsePacket(packet.buffer.Data, packer.getDestConnID().Len()) Expect(err).ToNot(HaveOccurred()) r := bytes.NewReader(packet.buffer.Data) extHdr, err := hdr.ParseExtended(r, packer.version) @@ -1206,7 +1206,7 @@ var _ = Describe("Packet packer", func() { Expect(packet.packets).To(HaveLen(1)) // cut off the tag that the mock sealer added // packet.buffer.Data = packet.buffer.Data[:packet.buffer.Len()-protocol.ByteCount(sealer.Overhead())] - hdr, _, _, err := wire.ParsePacket(packet.buffer.Data, len(packer.getDestConnID())) + hdr, _, _, err := wire.ParsePacket(packet.buffer.Data, packer.getDestConnID().Len()) Expect(err).ToNot(HaveOccurred()) r := bytes.NewReader(packet.buffer.Data) extHdr, err := hdr.ParseExtended(r, packer.version) diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 16c708e6d44..813cf82a069 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -23,7 +23,7 @@ var _ = Describe("Packet Unpacker", func() { var ( unpacker *packetUnpacker cs *mocks.MockCryptoSetup - connID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + connID = protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}) payload = []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") ) diff --git a/qlog/event.go b/qlog/event.go index 8427b6e22c9..83423aa5e72 100644 --- a/qlog/event.go +++ b/qlog/event.go @@ -80,8 +80,8 @@ func (e eventConnectionStarted) MarshalJSONObject(enc *gojay.Encoder) { enc.IntKey("src_port", e.SrcAddr.Port) enc.StringKey("dst_ip", e.DestAddr.IP.String()) enc.IntKey("dst_port", e.DestAddr.Port) - enc.StringKey("src_cid", connectionID(e.SrcConnectionID).String()) - enc.StringKey("dst_cid", connectionID(e.DestConnectionID).String()) + enc.StringKey("src_cid", e.SrcConnectionID.String()) + enc.StringKey("dst_cid", e.DestConnectionID.String()) } type eventVersionNegotiated struct { @@ -212,7 +212,7 @@ func (e eventRetryReceived) MarshalJSONObject(enc *gojay.Encoder) { } type eventVersionNegotiationReceived struct { - Header packetHeader + Header packetHeaderVersionNegotiation SupportedVersions []versionNumber } @@ -410,15 +410,15 @@ func (e eventTransportParameters) MarshalJSONObject(enc *gojay.Encoder) { if !e.Restore { enc.StringKey("owner", e.Owner.String()) if e.SentBy == protocol.PerspectiveServer { - enc.StringKey("original_destination_connection_id", connectionID(e.OriginalDestinationConnectionID).String()) + enc.StringKey("original_destination_connection_id", e.OriginalDestinationConnectionID.String()) if e.StatelessResetToken != nil { enc.StringKey("stateless_reset_token", fmt.Sprintf("%x", e.StatelessResetToken[:])) } if e.RetrySourceConnectionID != nil { - enc.StringKey("retry_source_connection_id", connectionID(*e.RetrySourceConnectionID).String()) + enc.StringKey("retry_source_connection_id", (*e.RetrySourceConnectionID).String()) } } - enc.StringKey("initial_source_connection_id", connectionID(e.InitialSourceConnectionID).String()) + enc.StringKey("initial_source_connection_id", e.InitialSourceConnectionID.String()) } enc.BoolKey("disable_active_migration", e.DisableActiveMigration) enc.FloatKeyOmitEmpty("max_idle_timeout", milliseconds(e.MaxIdleTimeout)) @@ -457,7 +457,7 @@ func (a preferredAddress) MarshalJSONObject(enc *gojay.Encoder) { enc.Uint16Key("port_v4", a.PortV4) enc.StringKey("ip_v6", a.IPv6.String()) enc.Uint16Key("port_v6", a.PortV6) - enc.StringKey("connection_id", connectionID(a.ConnectionID).String()) + enc.StringKey("connection_id", a.ConnectionID.String()) enc.StringKey("stateless_reset_token", fmt.Sprintf("%x", a.StatelessResetToken)) } diff --git a/qlog/frame.go b/qlog/frame.go index 4530f0fbac1..35761dae8ec 100644 --- a/qlog/frame.go +++ b/qlog/frame.go @@ -182,7 +182,7 @@ func marshalNewConnectionIDFrame(enc *gojay.Encoder, f *logging.NewConnectionIDF enc.Int64Key("sequence_number", int64(f.SequenceNumber)) enc.Int64Key("retire_prior_to", int64(f.RetirePriorTo)) enc.IntKey("length", f.ConnectionID.Len()) - enc.StringKey("connection_id", connectionID(f.ConnectionID).String()) + enc.StringKey("connection_id", f.ConnectionID.String()) enc.StringKey("stateless_reset_token", fmt.Sprintf("%x", f.StatelessResetToken)) } diff --git a/qlog/frame_test.go b/qlog/frame_test.go index b5e553e844f..bb98f9f8c3b 100644 --- a/qlog/frame_test.go +++ b/qlog/frame_test.go @@ -273,7 +273,7 @@ var _ = Describe("Frames", func() { &logging.NewConnectionIDFrame{ SequenceNumber: 42, RetirePriorTo: 24, - ConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + ConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf}, }, map[string]interface{}{ diff --git a/qlog/packet_header.go b/qlog/packet_header.go index cc270f2f564..0b77936d897 100644 --- a/qlog/packet_header.go +++ b/qlog/packet_header.go @@ -72,7 +72,7 @@ func transformExtendedHeader(hdr *wire.ExtendedHeader) *packetHeader { func (h packetHeader) MarshalJSONObject(enc *gojay.Encoder) { enc.StringKey("packet_type", packetType(h.PacketType).String()) - if h.PacketType != logging.PacketTypeRetry && h.PacketType != logging.PacketTypeVersionNegotiation { + if h.PacketType != logging.PacketTypeRetry { enc.Int64Key("packet_number", int64(h.PacketNumber)) } if h.Version != 0 { @@ -81,12 +81,12 @@ func (h packetHeader) MarshalJSONObject(enc *gojay.Encoder) { if h.PacketType != logging.PacketType1RTT { enc.IntKey("scil", h.SrcConnectionID.Len()) if h.SrcConnectionID.Len() > 0 { - enc.StringKey("scid", connectionID(h.SrcConnectionID).String()) + enc.StringKey("scid", h.SrcConnectionID.String()) } } enc.IntKey("dcil", h.DestConnectionID.Len()) if h.DestConnectionID.Len() > 0 { - enc.StringKey("dcid", connectionID(h.DestConnectionID).String()) + enc.StringKey("dcid", h.DestConnectionID.String()) } if h.KeyPhaseBit == logging.KeyPhaseZero || h.KeyPhaseBit == logging.KeyPhaseOne { enc.StringKey("key_phase_bit", h.KeyPhaseBit.String()) @@ -96,6 +96,20 @@ func (h packetHeader) MarshalJSONObject(enc *gojay.Encoder) { } } +type packetHeaderVersionNegotiation struct { + SrcConnectionID logging.ArbitraryLenConnectionID + DestConnectionID logging.ArbitraryLenConnectionID +} + +func (h packetHeaderVersionNegotiation) IsNil() bool { return false } +func (h packetHeaderVersionNegotiation) MarshalJSONObject(enc *gojay.Encoder) { + enc.StringKey("packet_type", "version_negotiation") + enc.IntKey("scil", h.SrcConnectionID.Len()) + enc.StringKey("scid", h.SrcConnectionID.String()) + enc.IntKey("dcil", h.DestConnectionID.Len()) + enc.StringKey("dcid", h.DestConnectionID.String()) +} + // a minimal header that only outputs the packet type type packetHeaderWithType struct { PacketType logging.PacketType diff --git a/qlog/packet_header_test.go b/qlog/packet_header_test.go index 54fe782a3d4..f5b2c0337ac 100644 --- a/qlog/packet_header_test.go +++ b/qlog/packet_header_test.go @@ -97,7 +97,7 @@ var _ = Describe("Packet Header", func() { Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeRetry, - SrcConnectionID: protocol.ConnectionID{0x11, 0x22, 0x33, 0x44}, + SrcConnectionID: protocol.ParseConnectionID([]byte{0x11, 0x22, 0x33, 0x44}), Version: protocol.VersionNumber(0xdecafbad), Token: []byte{0xde, 0xad, 0xbe, 0xef}, }, @@ -140,7 +140,7 @@ var _ = Describe("Packet Header", func() { Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, - SrcConnectionID: protocol.ConnectionID{0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}, + SrcConnectionID: protocol.ParseConnectionID([]byte{0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}), Version: protocol.VersionNumber(0xdecafbad), }, }, @@ -159,7 +159,7 @@ var _ = Describe("Packet Header", func() { check( &wire.ExtendedHeader{ PacketNumber: 42, - Header: wire.Header{DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}}, + Header: wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})}, KeyPhase: protocol.KeyPhaseOne, }, map[string]interface{}{ diff --git a/qlog/qlog.go b/qlog/qlog.go index 5b41174288f..7adfc790ea3 100644 --- a/qlog/qlog.go +++ b/qlog/qlog.go @@ -50,6 +50,8 @@ func init() { const eventChanSize = 50 type tracer struct { + logging.NullTracer + getLogWriter func(p logging.Perspective, connectionID []byte) io.WriteCloser } @@ -67,10 +69,6 @@ func (t *tracer) TracerForConnection(_ context.Context, p logging.Perspective, o return nil } -func (t *tracer) SentPacket(net.Addr, *logging.Header, protocol.ByteCount, []logging.Frame) {} -func (t *tracer) DroppedPacket(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { -} - type connectionTracer struct { mutex sync.Mutex @@ -110,8 +108,8 @@ func (t *connectionTracer) run() { trace: trace{ VantagePoint: vantagePoint{Type: t.perspective}, CommonFields: commonFields{ - ODCID: connectionID(t.odcid), - GroupID: connectionID(t.odcid), + ODCID: t.odcid, + GroupID: t.odcid, ReferenceTime: t.referenceTime, }, }, @@ -323,14 +321,17 @@ func (t *connectionTracer) ReceivedRetry(hdr *wire.Header) { t.mutex.Unlock() } -func (t *connectionTracer) ReceivedVersionNegotiationPacket(hdr *wire.Header, versions []logging.VersionNumber) { +func (t *connectionTracer) ReceivedVersionNegotiationPacket(dest, src logging.ArbitraryLenConnectionID, versions []logging.VersionNumber) { ver := make([]versionNumber, len(versions)) for i, v := range versions { ver[i] = versionNumber(v) } t.mutex.Lock() t.recordEvent(time.Now(), &eventVersionNegotiationReceived{ - Header: *transformHeader(hdr), + Header: packetHeaderVersionNegotiation{ + SrcConnectionID: src, + DestConnectionID: dest, + }, SupportedVersions: ver, }) t.mutex.Unlock() diff --git a/qlog/qlog_test.go b/qlog/qlog_test.go index f43390c4cbe..c37b6da3d17 100644 --- a/qlog/qlog_test.go +++ b/qlog/qlog_test.go @@ -54,7 +54,11 @@ var _ = Describe("Tracing", func() { Context("tracer", func() { It("returns nil when there's no io.WriteCloser", func() { t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nil }) - Expect(t.TracerForConnection(context.Background(), logging.PerspectiveClient, logging.ConnectionID{1, 2, 3, 4})).To(BeNil()) + Expect(t.TracerForConnection( + context.Background(), + logging.PerspectiveClient, + protocol.ParseConnectionID([]byte{1, 2, 3, 4}), + )).To(BeNil()) }) }) @@ -63,7 +67,7 @@ var _ = Describe("Tracing", func() { t := NewConnectionTracer( &limitedWriter{WriteCloser: nopWriteCloser(buf), N: 250}, protocol.PerspectiveServer, - protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), ) for i := uint32(0); i < 1000; i++ { t.UpdatedPTOCount(i) @@ -85,7 +89,11 @@ var _ = Describe("Tracing", func() { BeforeEach(func() { buf = &bytes.Buffer{} t := NewTracer(func(logging.Perspective, []byte) io.WriteCloser { return nopWriteCloser(buf) }) - tracer = t.TracerForConnection(context.Background(), logging.PerspectiveServer, logging.ConnectionID{0xde, 0xad, 0xbe, 0xef}) + tracer = t.TracerForConnection( + context.Background(), + logging.PerspectiveServer, + protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), + ) }) It("exports a trace that has the right metadata", func() { @@ -155,8 +163,8 @@ var _ = Describe("Tracing", func() { tracer.StartedConnection( &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 42}, &net.UDPAddr{IP: net.IPv4(192, 168, 12, 34), Port: 24}, - protocol.ConnectionID{1, 2, 3, 4}, - protocol.ConnectionID{5, 6, 7, 8}, + protocol.ParseConnectionID([]byte{1, 2, 3, 4}), + protocol.ParseConnectionID([]byte{5, 6, 7, 8}), ) entry := exportAndParseSingle() Expect(entry.Time).To(BeTemporally("~", time.Now(), scaleDuration(10*time.Millisecond))) @@ -274,6 +282,7 @@ var _ = Describe("Tracing", func() { }) It("records sent transport parameters", func() { + rcid := protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}) tracer.SentTransportParameters(&logging.TransportParameters{ InitialMaxStreamDataBidiLocal: 1000, InitialMaxStreamDataBidiRemote: 2000, @@ -287,9 +296,9 @@ var _ = Describe("Tracing", func() { MaxUDPPayloadSize: 1234, MaxIdleTimeout: 321 * time.Millisecond, StatelessResetToken: &protocol.StatelessResetToken{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00}, - OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, - InitialSourceConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}), + InitialSourceConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), + RetrySourceConnectionID: &rcid, ActiveConnectionIDLimit: 7, MaxDatagramFrameSize: protocol.InvalidByteCount, }) @@ -318,7 +327,7 @@ var _ = Describe("Tracing", func() { It("records the server's transport parameters, without a stateless reset token", func() { tracer.SentTransportParameters(&logging.TransportParameters{ - OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}), ActiveConnectionIDLimit: 7, }) entry := exportAndParseSingle() @@ -347,7 +356,7 @@ var _ = Describe("Tracing", func() { IPv4Port: 123, IPv6: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, IPv6Port: 456, - ConnectionID: protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, + ConnectionID: protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}), StatelessResetToken: protocol.StatelessResetToken{15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, }, }) @@ -417,8 +426,8 @@ var _ = Describe("Tracing", func() { Header: logging.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{4, 3, 2, 1}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + SrcConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}), Length: 1337, Version: protocol.VersionTLS, }, @@ -454,7 +463,7 @@ var _ = Describe("Tracing", func() { It("records a sent packet, without an ACK", func() { tracer.SentPacket( &logging.ExtendedHeader{ - Header: logging.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}}, + Header: logging.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4})}, PacketNumber: 1337, }, 123, @@ -483,8 +492,8 @@ var _ = Describe("Tracing", func() { Header: logging.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{4, 3, 2, 1}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + SrcConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}), Token: []byte{0xde, 0xad, 0xbe, 0xef}, Length: 1234, Version: protocol.VersionTLS, @@ -522,8 +531,8 @@ var _ = Describe("Tracing", func() { &logging.Header{ IsLongHeader: true, Type: protocol.PacketTypeRetry, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{4, 3, 2, 1}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), + SrcConnectionID: protocol.ParseConnectionID([]byte{4, 3, 2, 1}), Token: []byte{0xde, 0xad, 0xbe, 0xef}, Version: protocol.VersionTLS, }, @@ -548,12 +557,8 @@ var _ = Describe("Tracing", func() { It("records a received Version Negotiation packet", func() { tracer.ReceivedVersionNegotiationPacket( - &logging.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - SrcConnectionID: protocol.ConnectionID{4, 3, 2, 1}, - }, + protocol.ArbitraryLenConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + protocol.ArbitraryLenConnectionID{4, 3, 2, 1}, []protocol.VersionNumber{0xdeadbeef, 0xdecafbad}, ) entry := exportAndParseSingle() @@ -568,8 +573,8 @@ var _ = Describe("Tracing", func() { Expect(header).To(HaveKeyWithValue("packet_type", "version_negotiation")) Expect(header).ToNot(HaveKey("packet_number")) Expect(header).ToNot(HaveKey("version")) - Expect(header).To(HaveKey("dcid")) - Expect(header).To(HaveKey("scid")) + Expect(header).To(HaveKeyWithValue("dcid", "0102030405060708")) + Expect(header).To(HaveKeyWithValue("scid", "04030201")) }) It("records buffered packets", func() { diff --git a/qlog/trace.go b/qlog/trace.go index 4f0b5e64eb8..cf61558af55 100644 --- a/qlog/trace.go +++ b/qlog/trace.go @@ -3,6 +3,8 @@ package qlog import ( "time" + "github.com/lucas-clemente/quic-go/logging" + "github.com/francoispqt/gojay" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -38,8 +40,8 @@ func (p vantagePoint) MarshalJSONObject(enc *gojay.Encoder) { } type commonFields struct { - ODCID connectionID - GroupID connectionID + ODCID logging.ConnectionID + GroupID logging.ConnectionID ProtocolType string ReferenceTime time.Time } diff --git a/qlog/types.go b/qlog/types.go index b485e17dafe..42e562f9b8f 100644 --- a/qlog/types.go +++ b/qlog/types.go @@ -39,12 +39,6 @@ func (s streamType) String() string { } } -type connectionID protocol.ConnectionID - -func (c connectionID) String() string { - return fmt.Sprintf("%x", []byte(c)) -} - // category is the qlog event category. type category uint8 diff --git a/server.go b/server.go index 218d58252eb..ae29ff9b651 100644 --- a/server.go +++ b/server.go @@ -320,20 +320,43 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s } return false } + // Short header packets should never end up here in the first place + if !wire.IsLongHeaderPacket(p.data[0]) { + panic(fmt.Sprintf("misrouted packet: %#v", p.data)) + } + v, err := wire.ParseVersion(p.data) + // send a Version Negotiation Packet if the client is speaking a different protocol version + if err != nil || !protocol.IsSupportedVersion(s.config.Versions, v) { + if err != nil || p.Size() < protocol.MinUnknownVersionPacketSize { + s.logger.Debugf("Dropping a packet with an unknown version that is too small (%d bytes)", p.Size()) + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) + } + return false + } + _, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data) + if err != nil { // should never happen + s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs") + if s.config.Tracer != nil { + s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) + } + return false + } + if !s.config.DisableVersionNegotiationPackets { + go s.sendVersionNegotiationPacket(p.remoteAddr, src, dest, p.info.OOB()) + } + return false + } // If we're creating a new connection, the packet will be passed to the connection. // The header will then be parsed again. hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDGenerator.ConnectionIDLen()) - if err != nil && err != wire.ErrUnsupportedVersion { + if err != nil { if s.config.Tracer != nil { s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) } s.logger.Debugf("Error parsing packet: %s", err) return false } - // Short header packets should never end up here in the first place - if !hdr.IsLongHeader { - panic(fmt.Sprintf("misrouted packet: %#v", hdr)) - } if hdr.Type == protocol.PacketTypeInitial && p.Size() < protocol.MinInitialPacketSize { s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", p.Size()) if s.config.Tracer != nil { @@ -341,20 +364,7 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s } return false } - // send a Version Negotiation Packet if the client is speaking a different protocol version - if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { - if p.Size() < protocol.MinUnknownVersionPacketSize { - s.logger.Debugf("Dropping a packet with an unknown version that is too small (%d bytes)", p.Size()) - if s.config.Tracer != nil { - s.config.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) - } - return false - } - if !s.config.DisableVersionNegotiationPackets { - go s.sendVersionNegotiationPacket(p, hdr) - } - return false - } + if hdr.IsLongHeader && hdr.Type != protocol.PacketTypeInitial { // Drop long header packets. // There's little point in sending a Stateless Reset, since the client @@ -467,7 +477,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro if err != nil { return err } - s.logger.Debugf("Changing connection ID to %s.", protocol.ConnectionID(connID)) + s.logger.Debugf("Changing connection ID to %s.", connID) var conn quicConn tracingID := nextConnTracingID() if added := s.connHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler { @@ -565,7 +575,7 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack replyHdr.DestConnectionID = hdr.SrcConnectionID replyHdr.Token = token if s.logger.Debug() { - s.logger.Debugf("Changing connection ID to %s.", protocol.ConnectionID(srcConnID)) + s.logger.Debugf("Changing connection ID to %s.", srcConnID) s.logger.Debugf("-> Sending Retry") replyHdr.Log(s.logger) } @@ -664,22 +674,14 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han return err } -func (s *baseServer) sendVersionNegotiationPacket(p *receivedPacket, hdr *wire.Header) { - s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version) - data := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) +func (s *baseServer) sendVersionNegotiationPacket(remote net.Addr, src, dest protocol.ArbitraryLenConnectionID, oob []byte) { + s.logger.Debugf("Client offered version %s, sending Version Negotiation") + + data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions) if s.config.Tracer != nil { - s.config.Tracer.SentPacket( - p.remoteAddr, - &wire.Header{ - IsLongHeader: true, - DestConnectionID: hdr.SrcConnectionID, - SrcConnectionID: hdr.DestConnectionID, - }, - protocol.ByteCount(len(data)), - nil, - ) + s.config.Tracer.SentVersionNegotiationPacket(remote, src, dest, s.config.Versions) } - if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { + if _, err := s.conn.WritePacket(data, remote, oob); err != nil { s.logger.Debugf("Error sending Version Negotiation: %s", err) } } diff --git a/server_test.go b/server_test.go index eafba0cc741..1cc68a97b5b 100644 --- a/server_test.go +++ b/server_test.go @@ -71,7 +71,7 @@ var _ = Describe("Server", func() { hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), DestConnectionID: destConnID, Version: protocol.VersionTLS, } @@ -82,11 +82,11 @@ var _ = Describe("Server", func() { } getInitialWithRandomDestConnID := func() *receivedPacket { - destConnID := make([]byte, 10) - _, err := rand.Read(destConnID) + b := make([]byte, 10) + _, err := rand.Read(b) Expect(err).ToNot(HaveOccurred()) - return getInitial(destConnID) + return getInitial(protocol.ParseConnectionID(b)) } parseHeader := func(data []byte) *wire.Header { @@ -204,7 +204,7 @@ var _ = Describe("Server", func() { p := getPacket(&wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}), Version: serv.config.Versions[0], }, nil) tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) @@ -217,7 +217,7 @@ var _ = Describe("Server", func() { p := getPacket(&wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}), Version: serv.config.Versions[0], }, make([]byte, protocol.MinInitialPacketSize-100), ) @@ -244,15 +244,15 @@ var _ = Describe("Server", func() { raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} retryToken, err := serv.tokenGenerator.NewRetryToken( raddr, - protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, - protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}), + protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}), ) Expect(err).ToNot(HaveOccurred()) hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Version: protocol.VersionTLS, Token: retryToken, } @@ -263,7 +263,7 @@ var _ = Describe("Server", func() { rand.Read(token[:]) var newConnID protocol.ConnectionID - phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { newConnID = c phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken { newConnID = c @@ -272,7 +272,7 @@ var _ = Describe("Server", func() { fn() return true }) - tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}) + tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde})) conn := NewMockQuicConn(mockCtrl) serv.newConn = func( _ sendConn, @@ -294,8 +294,8 @@ var _ = Describe("Server", func() { _ protocol.VersionNumber, ) quicConn { Expect(enable0RTT).To(BeFalse()) - Expect(origDestConnID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})) - Expect(retrySrcConnID).To(Equal(&protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) + Expect(origDestConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xc0, 0xde}))) + Expect(*retrySrcConnID).To(Equal(protocol.ParseConnectionID([]byte{0xde, 0xca, 0xfb, 0xad}))) Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) Expect(destConnID).To(Equal(hdr.SrcConnectionID)) // make sure we're using a server-generated connection ID @@ -325,8 +325,8 @@ var _ = Describe("Server", func() { }) It("sends a Version Negotiation Packet for unsupported versions", func() { - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} - destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} + srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) + destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) packet := getPacket(&wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, @@ -336,20 +336,18 @@ var _ = Describe("Server", func() { }, make([]byte, protocol.MinUnknownVersionPacketSize)) raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} packet.remoteAddr = raddr - tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) { - Expect(replyHdr.IsLongHeader).To(BeTrue()) - Expect(replyHdr.Version).To(BeZero()) - Expect(replyHdr.SrcConnectionID).To(Equal(destConnID)) - Expect(replyHdr.DestConnectionID).To(Equal(srcConnID)) + tracer.EXPECT().SentVersionNegotiationPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, src, dest protocol.ArbitraryLenConnectionID, _ []protocol.VersionNumber) { + Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes()))) + Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes()))) }) done := make(chan struct{}) conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { defer close(done) Expect(wire.IsVersionNegotiationPacket(b)).To(BeTrue()) - hdr, versions, err := wire.ParseVersionNegotiationPacket(bytes.NewReader(b)) + dest, src, versions, err := wire.ParseVersionNegotiationPacket(b) Expect(err).ToNot(HaveOccurred()) - Expect(hdr.DestConnectionID).To(Equal(srcConnID)) - Expect(hdr.SrcConnectionID).To(Equal(destConnID)) + Expect(dest).To(Equal(protocol.ArbitraryLenConnectionID(srcConnID.Bytes()))) + Expect(src).To(Equal(protocol.ArbitraryLenConnectionID(destConnID.Bytes()))) Expect(versions).ToNot(ContainElement(protocol.VersionNumber(0x42))) return len(b), nil }) @@ -359,8 +357,8 @@ var _ = Describe("Server", func() { It("doesn't send a Version Negotiation packets if sending them is disabled", func() { serv.config.DisableVersionNegotiationPackets = true - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} - destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} + srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) + destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) packet := getPacket(&wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, @@ -378,8 +376,8 @@ var _ = Describe("Server", func() { It("ignores Version Negotiation packets", func() { data := wire.ComposeVersionNegotiation( - protocol.ConnectionID{1, 2, 3, 4}, - protocol.ConnectionID{4, 3, 2, 1}, + protocol.ArbitraryLenConnectionID{1, 2, 3, 4}, + protocol.ArbitraryLenConnectionID{4, 3, 2, 1}, []protocol.VersionNumber{1, 2, 3}, ) raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} @@ -398,8 +396,8 @@ var _ = Describe("Server", func() { }) It("doesn't send a Version Negotiation Packet for unsupported versions, if the packet is too small", func() { - srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} - destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} + srcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5}) + destConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6}) p := getPacket(&wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, @@ -425,8 +423,8 @@ var _ = Describe("Server", func() { hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Version: protocol.VersionTLS, } packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) @@ -457,8 +455,8 @@ var _ = Describe("Server", func() { hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Version: protocol.VersionTLS, } p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) @@ -467,7 +465,7 @@ var _ = Describe("Server", func() { rand.Read(token[:]) var newConnID protocol.ConnectionID - phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { + phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { newConnID = c phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) protocol.StatelessResetToken { newConnID = c @@ -476,7 +474,7 @@ var _ = Describe("Server", func() { fn() return true }) - tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) conn := NewMockQuicConn(mockCtrl) serv.newConn = func( @@ -568,7 +566,7 @@ var _ = Describe("Server", func() { return conn } - p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}) + p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})) serv.handlePacket(p) tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention).MinTimes(1) var wg sync.WaitGroup @@ -577,7 +575,7 @@ var _ = Describe("Server", func() { go func() { defer GinkgoRecover() defer wg.Done() - serv.handlePacket(getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8})) + serv.handlePacket(getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}))) }() } wg.Wait() @@ -616,8 +614,8 @@ var _ = Describe("Server", func() { return conn } - p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}) - phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9}, gomock.Any(), gomock.Any()).Return(false) + p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})) + phm.EXPECT().AddWithConnID(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}), gomock.Any(), gomock.Any()).Return(false) Expect(serv.handlePacketImpl(p)).To(BeTrue()) Expect(createdConn).To(BeFalse()) }) @@ -690,7 +688,7 @@ var _ = Describe("Server", func() { }) It("doesn't accept new connections if they were closed in the mean time", func() { - p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) ctx, cancel := context.WithCancel(context.Background()) connCreated := make(chan struct{}) conn := NewMockQuicConn(mockCtrl) @@ -774,7 +772,7 @@ var _ = Describe("Server", func() { It("decodes the token from the token field", func() { raddr := &net.UDPAddr{IP: net.IPv4(192, 168, 13, 37), Port: 1337} - token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil) + token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) packet := getPacket(&wire.Header{ IsLongHeader: true, @@ -794,13 +792,13 @@ var _ = Describe("Server", func() { It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() { serv.config.RequireAddressValidation = func(net.Addr) bool { return true } - token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil) + token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Token: token, Version: protocol.VersionTLS, } @@ -832,14 +830,14 @@ var _ = Describe("Server", func() { serv.config.RequireAddressValidation = func(net.Addr) bool { return true } serv.config.MaxRetryTokenAge = time.Millisecond raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - token, err := serv.tokenGenerator.NewRetryToken(raddr, nil, nil) + token, err := serv.tokenGenerator.NewRetryToken(raddr, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) time.Sleep(2 * time.Millisecond) // make sure the token is expired hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Token: token, Version: protocol.VersionTLS, } @@ -872,8 +870,8 @@ var _ = Describe("Server", func() { hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Token: token, Version: protocol.VersionTLS, } @@ -904,8 +902,8 @@ var _ = Describe("Server", func() { hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Token: token, Version: protocol.VersionTLS, } @@ -925,13 +923,13 @@ var _ = Describe("Server", func() { It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() { serv.config.RequireAddressValidation = func(net.Addr) bool { return true } - token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, nil, nil) + token, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{}, protocol.ConnectionID{}) Expect(err).ToNot(HaveOccurred()) hdr := &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + SrcConnectionID: protocol.ParseConnectionID([]byte{5, 4, 3, 2, 1}), + DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), Token: token, Version: protocol.VersionTLS, } @@ -1033,7 +1031,7 @@ var _ = Describe("Server", func() { tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()) serv.handleInitialImpl( &receivedPacket{buffer: getPacketBuffer()}, - &wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}}, + &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, ) Consistently(done).ShouldNot(BeClosed()) cancel() // complete the handshake @@ -1107,7 +1105,7 @@ var _ = Describe("Server", func() { }) serv.handleInitialImpl( &receivedPacket{buffer: getPacketBuffer()}, - &wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}}, + &wire.Header{DestConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})}, ) Consistently(done).ShouldNot(BeClosed()) close(ready) @@ -1176,7 +1174,7 @@ var _ = Describe("Server", func() { }) It("doesn't accept new connections if they were closed in the mean time", func() { - p := getInitial(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + p := getInitial(protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) ctx, cancel := context.WithCancel(context.Background()) connCreated := make(chan struct{}) conn := NewMockQuicConn(mockCtrl)