From f20a07275079db944b0f29c3531ea88ab361c170 Mon Sep 17 00:00:00 2001 From: dchaofei Date: Thu, 12 May 2022 21:07:27 +0800 Subject: [PATCH] compose version negotiation function does not need to return the error type --- connection_test.go | 6 ++---- fuzzing/header/cmd/corpus.go | 6 +----- fuzzing/header/fuzz.go | 4 +--- integrationtests/self/mitm_test.go | 2 +- internal/wire/version_negotiation.go | 4 ++-- internal/wire/version_negotiation_test.go | 13 +++++-------- server.go | 6 +----- server_test.go | 3 +-- 8 files changed, 14 insertions(+), 30 deletions(-) diff --git a/connection_test.go b/connection_test.go index 665feed72fa..cab29632596 100644 --- a/connection_test.go +++ b/connection_test.go @@ -674,8 +674,7 @@ var _ = Describe("Connection", func() { }) It("drops Version Negotiation packets", func() { - b, err := wire.ComposeVersionNegotiation(srcConnID, destConnID, conn.config.Versions) - Expect(err).ToNot(HaveOccurred()) + b := wire.ComposeVersionNegotiation(srcConnID, destConnID, conn.config.Versions) tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(b)), logging.PacketDropUnexpectedPacket) Expect(conn.handlePacketImpl(&receivedPacket{ data: b, @@ -2592,8 +2591,7 @@ var _ = Describe("Client Connection", func() { Context("handling Version Negotiation", func() { getVNP := func(versions ...protocol.VersionNumber) *receivedPacket { - b, err := wire.ComposeVersionNegotiation(srcConnID, destConnID, versions) - Expect(err).ToNot(HaveOccurred()) + b := wire.ComposeVersionNegotiation(srcConnID, destConnID, versions) return &receivedPacket{ data: b, buffer: getPacketBuffer(), diff --git a/fuzzing/header/cmd/corpus.go b/fuzzing/header/cmd/corpus.go index 2422afcc82d..eeb880ce860 100644 --- a/fuzzing/header/cmd/corpus.go +++ b/fuzzing/header/cmd/corpus.go @@ -24,11 +24,7 @@ func getVNP(src, dest protocol.ConnectionID, numVersions int) []byte { for i := 0; i < numVersions; i++ { versions[i] = protocol.VersionNumber(rand.Uint32()) } - data, err := wire.ComposeVersionNegotiation(src, dest, versions) - if err != nil { - log.Fatal(err) - } - return data + return wire.ComposeVersionNegotiation(src, dest, versions) } func main() { diff --git a/fuzzing/header/fuzz.go b/fuzzing/header/fuzz.go index 46bdfa89523..be7564b9e48 100644 --- a/fuzzing/header/fuzz.go +++ b/fuzzing/header/fuzz.go @@ -91,8 +91,6 @@ func fuzzVNP(data []byte) int { if len(versions) == 0 { panic("no versions") } - if _, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, versions); err != nil { - panic(err) - } + wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, versions) return 1 } diff --git a/integrationtests/self/mitm_test.go b/integrationtests/self/mitm_test.go index 19769f3b8b7..34bb14c6a3f 100644 --- a/integrationtests/self/mitm_test.go +++ b/integrationtests/self/mitm_test.go @@ -297,7 +297,7 @@ var _ = Describe("MITM test", func() { sendForgedVersionNegotationPacket := func(conn net.PacketConn, remoteAddr net.Addr, hdr *wire.Header) { // Create fake version negotiation packet with no supported versions versions := []protocol.VersionNumber{} - packet, _ := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, versions) + packet := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, versions) // Send the packet _, err := conn.WriteTo(packet, remoteAddr) diff --git a/internal/wire/version_negotiation.go b/internal/wire/version_negotiation.go index bcae87d17f7..196853e0fc4 100644 --- a/internal/wire/version_negotiation.go +++ b/internal/wire/version_negotiation.go @@ -35,7 +35,7 @@ func ParseVersionNegotiationPacket(b *bytes.Reader) (*Header, []protocol.Version } // ComposeVersionNegotiation composes a Version Negotiation -func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, versions []protocol.VersionNumber) ([]byte, error) { +func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, versions []protocol.VersionNumber) []byte { greasedVersions := protocol.GetGreasedVersions(versions) expectedLen := 1 /* type byte */ + 4 /* version field */ + 1 /* dest connection ID length field */ + destConnID.Len() + 1 /* src connection ID length field */ + srcConnID.Len() + len(greasedVersions)*4 buf := bytes.NewBuffer(make([]byte, 0, expectedLen)) @@ -50,5 +50,5 @@ func ComposeVersionNegotiation(destConnID, srcConnID protocol.ConnectionID, vers for _, v := range greasedVersions { utils.BigEndian.WriteUint32(buf, uint32(v)) } - return buf.Bytes(), nil + return buf.Bytes() } diff --git a/internal/wire/version_negotiation_test.go b/internal/wire/version_negotiation_test.go index d3a0b7e6215..31ad5d93f86 100644 --- a/internal/wire/version_negotiation_test.go +++ b/internal/wire/version_negotiation_test.go @@ -36,20 +36,18 @@ var _ = Describe("Version Negotiation Packets", func() { It("errors if it contains versions of the wrong length", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.VersionNumber{0x22334455, 0x33445566} - data, err := ComposeVersionNegotiation(connID, connID, versions) - Expect(err).ToNot(HaveOccurred()) - _, _, err = ParseVersionNegotiationPacket(bytes.NewReader(data[:len(data)-2])) + data := ComposeVersionNegotiation(connID, connID, versions) + _, _, err := ParseVersionNegotiationPacket(bytes.NewReader(data[:len(data)-2])) Expect(err).To(MatchError("Version Negotiation packet has a version list with an invalid length")) }) It("errors if the version list is empty", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.VersionNumber{0x22334455} - data, err := ComposeVersionNegotiation(connID, connID, versions) - Expect(err).ToNot(HaveOccurred()) + data := ComposeVersionNegotiation(connID, connID, versions) // remove 8 bytes (two versions), since ComposeVersionNegotiation also added a reserved version number data = data[:len(data)-8] - _, _, err = ParseVersionNegotiationPacket(bytes.NewReader(data)) + _, _, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) Expect(err).To(MatchError("Version Negotiation packet has empty version list")) }) @@ -57,8 +55,7 @@ var _ = Describe("Version Negotiation Packets", func() { srcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37} destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} versions := []protocol.VersionNumber{1001, 1003} - data, err := ComposeVersionNegotiation(destConnID, srcConnID, versions) - Expect(err).ToNot(HaveOccurred()) + data := ComposeVersionNegotiation(destConnID, srcConnID, versions) Expect(data[0] & 0x80).ToNot(BeZero()) hdr, supportedVersions, err := ParseVersionNegotiationPacket(bytes.NewReader(data)) Expect(err).ToNot(HaveOccurred()) diff --git a/server.go b/server.go index 2e987bcb188..0e64297050c 100644 --- a/server.go +++ b/server.go @@ -651,11 +651,7 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han func (s *baseServer) sendVersionNegotiationPacket(p *receivedPacket, hdr *wire.Header) { s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version) - data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) - if err != nil { - s.logger.Debugf("Error composing Version Negotiation: %s", err) - return - } + data := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) if s.config.Tracer != nil { s.config.Tracer.SentPacket( p.remoteAddr, diff --git a/server_test.go b/server_test.go index b1d4f73c92e..509af9e9812 100644 --- a/server_test.go +++ b/server_test.go @@ -426,12 +426,11 @@ var _ = Describe("Server", func() { }) It("ignores Version Negotiation packets", func() { - data, err := wire.ComposeVersionNegotiation( + data := wire.ComposeVersionNegotiation( protocol.ConnectionID{1, 2, 3, 4}, protocol.ConnectionID{4, 3, 2, 1}, []protocol.VersionNumber{1, 2, 3}, ) - Expect(err).ToNot(HaveOccurred()) raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} done := make(chan struct{}) tracer.EXPECT().DroppedPacket(raddr, logging.PacketTypeVersionNegotiation, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {