Skip to content

Commit

Permalink
remove ConnectionID.Equal function
Browse files Browse the repository at this point in the history
Connection IDs can now be compared with ==.
  • Loading branch information
marten-seemann committed Aug 29, 2022
1 parent 1aced95 commit 4cbb4f8
Show file tree
Hide file tree
Showing 9 changed files with 17 additions and 34 deletions.
2 changes: 1 addition & 1 deletion conn_id_generator.go
Expand Up @@ -84,7 +84,7 @@ func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.Connect
if !ok {
return nil
}
if connID.Equal(sentWithDestConnID) {
if connID == sentWithDestConnID {
return &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID),
Expand Down
2 changes: 1 addition & 1 deletion conn_id_manager.go
Expand Up @@ -121,7 +121,7 @@ func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID
// insert a new element somewhere in the middle
for el := h.queue.Front(); el != nil; el = el.Next() {
if el.Value.SequenceNumber == seq {
if !el.Value.ConnectionID.Equal(connID) {
if el.Value.ConnectionID != connID {
return fmt.Errorf("received conflicting connection IDs for sequence number %d", seq)
}
if el.Value.StatelessResetToken != resetToken {
Expand Down
4 changes: 2 additions & 2 deletions conn_id_manager_test.go
Expand Up @@ -278,7 +278,7 @@ var _ = Describe("Connection ID Manager", func() {
m.SentPacket()

connID := m.Get()
if !connID.Equal(lastConnID) {
if connID != lastConnID {
counter++
lastConnID = connID
Expect(removedTokens).To(HaveLen(1))
Expand Down Expand Up @@ -306,7 +306,7 @@ var _ = Describe("Connection ID Manager", func() {
Expect(m.Get()).To(Equal(protocol.ParseConnectionID([]byte{10, 10, 10, 10})))
for {
m.SentPacket()
if m.Get().Equal(protocol.ParseConnectionID([]byte{11, 11, 11, 11})) {
if m.Get() == protocol.ParseConnectionID([]byte{11, 11, 11, 11}) {
break
}
}
Expand Down
16 changes: 8 additions & 8 deletions connection.go
Expand Up @@ -881,7 +881,7 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
break
}

if counter > 0 && !hdr.DestConnectionID.Equal(lastConnID) {
if counter > 0 && hdr.DestConnectionID != lastConnID {
if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID)
}
Expand Down Expand Up @@ -925,7 +925,7 @@ func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) boo

// The server can change the source connection ID with the first Handshake packet.
// After this, all packets with a different source connection have to be ignored.
if s.receivedFirstPacket && hdr.IsLongHeader && hdr.Type == protocol.PacketTypeInitial && !hdr.SrcConnectionID.Equal(s.handshakeDestConnID) {
if s.receivedFirstPacket && hdr.IsLongHeader && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != s.handshakeDestConnID {
if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketTypeInitial, p.Size(), logging.PacketDropUnknownConnectionID)
}
Expand Down Expand Up @@ -1017,7 +1017,7 @@ func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* wa
return false
}
destConnID := s.connIDManager.Get()
if hdr.SrcConnectionID.Equal(destConnID) {
if hdr.SrcConnectionID == destConnID {
if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket)
}
Expand Down Expand Up @@ -1143,7 +1143,7 @@ func (s *connection) handleUnpackedPacket(
s.tracer.NegotiatedVersion(s.version, clientVersions, serverVersions)
}
// The server can change the source connection ID with the first Handshake packet.
if s.perspective == protocol.PerspectiveClient && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) {
if s.perspective == protocol.PerspectiveClient && packet.hdr.IsLongHeader && packet.hdr.SrcConnectionID != s.handshakeDestConnID {
cid := packet.hdr.SrcConnectionID
s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", cid)
s.handshakeDestConnID = cid
Expand All @@ -1155,7 +1155,7 @@ func (s *connection) handleUnpackedPacket(
// we might have create a connection with an incorrect source connection ID.
// Once we authenticate the first packet, we need to update it.
if s.perspective == protocol.PerspectiveServer {
if !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) {
if packet.hdr.SrcConnectionID != s.handshakeDestConnID {
s.handshakeDestConnID = packet.hdr.SrcConnectionID
s.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID)
}
Expand Down Expand Up @@ -1601,22 +1601,22 @@ func (s *connection) checkTransportParameters(params *wire.TransportParameters)
}

// check the initial_source_connection_id
if !params.InitialSourceConnectionID.Equal(s.handshakeDestConnID) {
if params.InitialSourceConnectionID != s.handshakeDestConnID {
return fmt.Errorf("expected initial_source_connection_id to equal %s, is %s", s.handshakeDestConnID, params.InitialSourceConnectionID)
}

if s.perspective == protocol.PerspectiveServer {
return nil
}
// check the original_destination_connection_id
if !params.OriginalDestinationConnectionID.Equal(s.origDestConnID) {
if params.OriginalDestinationConnectionID != s.origDestConnID {
return fmt.Errorf("expected original_destination_connection_id to equal %s, is %s", s.origDestConnID, params.OriginalDestinationConnectionID)
}
if s.retrySrcConnID != nil { // a Retry was performed
if params.RetrySourceConnectionID == nil {
return errors.New("missing retry_source_connection_id")
}
if !(*params.RetrySourceConnectionID).Equal(*s.retrySrcConnID) {
if *params.RetrySourceConnectionID != *s.retrySrcConnID {
return fmt.Errorf("expected retry_source_connection_id to equal %s, is %s", s.retrySrcConnID, *params.RetrySourceConnectionID)
}
} else if params.RetrySourceConnectionID != nil {
Expand Down
2 changes: 1 addition & 1 deletion fuzzing/header/fuzz.go
Expand Up @@ -35,7 +35,7 @@ func Fuzz(data []byte) int {
if err != nil {
return 0
}
if !hdr.DestConnectionID.Equal(connID) {
if hdr.DestConnectionID != connID {
panic(fmt.Sprintf("Expected connection IDs to match: %s vs %s", hdr.DestConnectionID, connID))
}
if (hdr.Type == protocol.PacketType0RTT) != is0RTTPacket {
Expand Down
4 changes: 2 additions & 2 deletions fuzzing/tokens/fuzz.go
Expand Up @@ -132,10 +132,10 @@ func newRetryToken(tg *handshake.TokenGenerator, data []byte) int {
if token.SentTime.Before(start) || token.SentTime.After(time.Now()) {
panic("incorrect send time")
}
if !token.OriginalDestConnectionID.Equal(origDestConnID) {
if token.OriginalDestConnectionID != origDestConnID {
panic("orig dest conn ID doesn't match")
}
if !token.RetrySrcConnectionID.Equal(retrySrcConnID) {
if token.RetrySrcConnectionID != retrySrcConnID {
panic("retry src conn ID doesn't match")
}
return 1
Expand Down
4 changes: 2 additions & 2 deletions integrationtests/self/zero_rtt_test.go
Expand Up @@ -417,13 +417,13 @@ var _ = Describe("0-RTT", func() {
if firstConnID == nil {
firstConnID = &connID
firstCounter += zeroRTTBytes
} else if firstConnID != nil && firstConnID.Equal(connID) {
} else if firstConnID != nil && *firstConnID == connID {
Expect(secondConnID).To(BeNil())
firstCounter += zeroRTTBytes
} else if secondConnID == nil {
secondConnID = &connID
secondCounter += zeroRTTBytes
} else if secondConnID != nil && secondConnID.Equal(connID) {
} else if secondConnID != nil && *secondConnID == connID {
secondCounter += zeroRTTBytes
} else {
Fail("received 3 connection IDs on 0-RTT packets")
Expand Down
5 changes: 0 additions & 5 deletions internal/protocol/connection_id.go
Expand Up @@ -86,11 +86,6 @@ func ReadConnectionID(r io.Reader, l int) (ConnectionID, error) {
return c, err
}

// Equal says if two connection IDs are equal
func (c ConnectionID) Equal(other ConnectionID) bool {
return c == other
}

// Len returns the length of the connection ID in bytes
func (c ConnectionID) Len() int {
return int(c.l)
Expand Down
12 changes: 0 additions & 12 deletions internal/protocol/connection_id_test.go
Expand Up @@ -43,18 +43,6 @@ var _ = Describe("Connection ID generation", func() {
Expect(has20ByteConnID).To(BeTrue())
})

It("says if connection IDs are equal", func() {
c1 := ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
c2 := ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1})
c3 := ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8})
Expect(c1.Equal(c1)).To(BeTrue())
Expect(c1.Equal(c3)).To(BeTrue())
Expect(c2.Equal(c2)).To(BeTrue())
Expect(c2.Equal(c3)).To(BeFalse())
Expect(c1.Equal(c2)).To(BeFalse())
Expect(c2.Equal(c1)).To(BeFalse())
})

It("reads the connection ID", func() {
buf := bytes.NewBuffer([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9})
c, err := ReadConnectionID(buf, 9)
Expand Down

0 comments on commit 4cbb4f8

Please sign in to comment.