diff --git a/conn_id_generator.go b/conn_id_generator.go index 0a6aa855bad..c56e8a4c16f 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -84,7 +84,7 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect if !ok { return nil } - if connID.Equal(sentWithDestConnID) { + if connID == sentWithDestConnID { return &qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID), diff --git a/conn_id_manager.go b/conn_id_manager.go index c1bb42bef2e..b878b027c03 100644 --- a/conn_id_manager.go +++ b/conn_id_manager.go @@ -121,7 +121,7 @@ func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID // insert a new element somewhere in the middle for el := h.queue.Front(); el != nil; el = el.Next() { if el.Value.SequenceNumber == seq { - if !el.Value.ConnectionID.Equal(connID) { + if el.Value.ConnectionID != connID { return fmt.Errorf("received conflicting connection IDs for sequence number %d", seq) } if el.Value.StatelessResetToken != resetToken { diff --git a/conn_id_manager_test.go b/conn_id_manager_test.go index 72d7d5af0e9..f0e24e86f6c 100644 --- a/conn_id_manager_test.go +++ b/conn_id_manager_test.go @@ -278,7 +278,7 @@ var _ = Describe("Connection ID Manager", func() { m.SentPacket() connID := m.Get() - if !connID.Equal(lastConnID) { + if connID != lastConnID { counter++ lastConnID = connID Expect(removedTokens).To(HaveLen(1)) @@ -306,7 +306,7 @@ var _ = Describe("Connection ID Manager", func() { Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{10, 10, 10, 10}))) for { m.SentPacket() - if m.Get().Equal(protocol.ParseConnectionID([]byte{11, 11, 11, 11})) { + if m.Get() == protocol.ParseConnectionID([]byte{11, 11, 11, 11}) { break } } diff --git a/connection.go b/connection.go index 72c5e65273d..39dd0621f86 100644 --- a/connection.go +++ b/connection.go @@ -881,7 +881,7 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool { break } - if counter > 0 && !hdr.DestConnectionID.Equal(lastConnID) { + if counter > 0 && hdr.DestConnectionID != lastConnID { if s.tracer != nil { s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID) } @@ -925,7 +925,7 @@ func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) boo // The server can change the source connection ID with the first Handshake packet. // After this, all packets with a different source connection have to be ignored. - if s.receivedFirstPacket && hdr.IsLongHeader && hdr.Type == protocol.PacketTypeInitial && !hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { + if s.receivedFirstPacket && hdr.IsLongHeader && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != s.handshakeDestConnID { if s.tracer != nil { s.tracer.DroppedPacket(logging.PacketTypeInitial, p.Size(), logging.PacketDropUnknownConnectionID) } @@ -1017,7 +1017,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* wa return false } destConnID := s.connIDManager.Get() - if hdr.SrcConnectionID.Equal(destConnID) { + if hdr.SrcConnectionID == destConnID { if s.tracer != nil { s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket) } @@ -1143,7 +1143,7 @@ func (s *connection) handleUnpackedPacket( s.tracer.NegotiatedVersion(s.version, clientVersions, serverVersions) } // The server can change the source connection ID with the first Handshake packet. - if s.perspective == protocol.PerspectiveClient && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { + if s.perspective == protocol.PerspectiveClient && packet.hdr.IsLongHeader && packet.hdr.SrcConnectionID != s.handshakeDestConnID { cid := packet.hdr.SrcConnectionID s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", cid) s.handshakeDestConnID = cid @@ -1155,7 +1155,7 @@ func (s *connection) handleUnpackedPacket( // we might have create a connection with an incorrect source connection ID. // Once we authenticate the first packet, we need to update it. if s.perspective == protocol.PerspectiveServer { - if !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { + if packet.hdr.SrcConnectionID != s.handshakeDestConnID { s.handshakeDestConnID = packet.hdr.SrcConnectionID s.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID) } @@ -1601,7 +1601,7 @@ func (s *connection) checkTransportParameters(params *wire.TransportParameters) } // check the initial_source_connection_id - if !params.InitialSourceConnectionID.Equal(s.handshakeDestConnID) { + if params.InitialSourceConnectionID != s.handshakeDestConnID { return fmt.Errorf("expected initial_source_connection_id to equal %s, is %s", s.handshakeDestConnID, params.InitialSourceConnectionID) } @@ -1609,14 +1609,14 @@ func (s *connection) checkTransportParameters(params *wire.TransportParameters) return nil } // check the original_destination_connection_id - if !params.OriginalDestinationConnectionID.Equal(s.origDestConnID) { + if params.OriginalDestinationConnectionID != s.origDestConnID { return fmt.Errorf("expected original_destination_connection_id to equal %s, is %s", s.origDestConnID, params.OriginalDestinationConnectionID) } if s.retrySrcConnID != nil { // a Retry was performed if params.RetrySourceConnectionID == nil { return errors.New("missing retry_source_connection_id") } - if !(*params.RetrySourceConnectionID).Equal(*s.retrySrcConnID) { + if *params.RetrySourceConnectionID != *s.retrySrcConnID { return fmt.Errorf("expected retry_source_connection_id to equal %s, is %s", s.retrySrcConnID, *params.RetrySourceConnectionID) } } else if params.RetrySourceConnectionID != nil { diff --git a/fuzzing/header/fuzz.go b/fuzzing/header/fuzz.go index 87c72e6cd85..ea0054d83bb 100644 --- a/fuzzing/header/fuzz.go +++ b/fuzzing/header/fuzz.go @@ -35,7 +35,7 @@ func Fuzz(data []byte) int { if err != nil { return 0 } - if !hdr.DestConnectionID.Equal(connID) { + if hdr.DestConnectionID != connID { panic(fmt.Sprintf("Expected connection IDs to match: %s vs %s", hdr.DestConnectionID, connID)) } if (hdr.Type == protocol.PacketType0RTT) != is0RTTPacket { diff --git a/fuzzing/tokens/fuzz.go b/fuzzing/tokens/fuzz.go index 64d4576fea9..a753716ad3a 100644 --- a/fuzzing/tokens/fuzz.go +++ b/fuzzing/tokens/fuzz.go @@ -132,10 +132,10 @@ func newRetryToken(tg *handshake.TokenGenerator, data []byte) int { if token.SentTime.Before(start) || token.SentTime.After(time.Now()) { panic("incorrect send time") } - if !token.OriginalDestConnectionID.Equal(origDestConnID) { + if token.OriginalDestConnectionID != origDestConnID { panic("orig dest conn ID doesn't match") } - if !token.RetrySrcConnectionID.Equal(retrySrcConnID) { + if token.RetrySrcConnectionID != retrySrcConnID { panic("retry src conn ID doesn't match") } return 1 diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index e63c85d0371..2837512e1c9 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -417,13 +417,13 @@ var _ = Describe("0-RTT", func() { if firstConnID == nil { firstConnID = &connID firstCounter += zeroRTTBytes - } else if firstConnID != nil && firstConnID.Equal(connID) { + } else if firstConnID != nil && *firstConnID == connID { Expect(secondConnID).To(BeNil()) firstCounter += zeroRTTBytes } else if secondConnID == nil { secondConnID = &connID secondCounter += zeroRTTBytes - } else if secondConnID != nil && secondConnID.Equal(connID) { + } else if secondConnID != nil && *secondConnID == connID { secondCounter += zeroRTTBytes } else { Fail("received 3 connection IDs on 0-RTT packets") diff --git a/internal/protocol/connection_id.go b/internal/protocol/connection_id.go index c2e2900ec10..70ac06fb477 100644 --- a/internal/protocol/connection_id.go +++ b/internal/protocol/connection_id.go @@ -86,11 +86,6 @@ func ReadConnectionID(r io.Reader, l int) (ConnectionID, error) { return c, err } -// Equal says if two connection IDs are equal -func (c ConnectionID) Equal(other ConnectionID) bool { - return c == other -} - // Len returns the length of the connection ID in bytes func (c ConnectionID) Len() int { return int(c.l) diff --git a/internal/protocol/connection_id_test.go b/internal/protocol/connection_id_test.go index 6b4cda61643..98abb1d26c8 100644 --- a/internal/protocol/connection_id_test.go +++ b/internal/protocol/connection_id_test.go @@ -43,18 +43,6 @@ var _ = Describe("Connection ID generation", func() { Expect(has20ByteConnID).To(BeTrue()) }) - It("says if connection IDs are equal", func() { - c1 := ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - c2 := ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) - c3 := ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8}) - Expect(c1.Equal(c1)).To(BeTrue()) - Expect(c1.Equal(c3)).To(BeTrue()) - Expect(c2.Equal(c2)).To(BeTrue()) - Expect(c2.Equal(c3)).To(BeFalse()) - Expect(c1.Equal(c2)).To(BeFalse()) - Expect(c2.Equal(c1)).To(BeFalse()) - }) - It("reads the connection ID", func() { buf := bytes.NewBuffer([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) c, err := ReadConnectionID(buf, 9)