From 44407ffd370dd4b994b1719ca19304c5163c4ee8 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 25 May 2023 13:31:57 +0300 Subject: [PATCH] re-add support for Go 1.20 --- .github/workflows/gotip.yml | 2 + connection_test.go | 86 +- crypto_stream_manager_test.go | 51 +- go.mod | 2 +- go.sum | 4 +- http3/server_test.go | 10 +- integrationtests/gomodvendor/go.sum | 4 +- integrationtests/self/handshake_test.go | 9 +- integrationtests/self/zero_rtt_oldgo_test.go | 804 ++++++++++++++++++ integrationtests/self/zero_rtt_test.go | 2 + internal/handshake/crypto_setup.go | 156 ++-- internal/handshake/crypto_setup_test.go | 303 ++----- internal/handshake/handshake_suite_test.go | 10 +- internal/handshake/header_protector.go | 2 +- internal/protocol/encryption_level.go | 35 - internal/qerr/error_codes.go | 5 +- .../cipher_suite_go121.go} | 4 +- .../client_session_cache.go | 4 +- internal/qtls/go119.go | 145 ++++ internal/qtls/go120.go | 147 ++++ internal/qtls/go121.go | 142 ++++ internal/qtls/go_oldversion.go | 5 + internal/qtls/qtls_suite_test.go | 25 + internal/testdata/cert.go | 1 + 24 files changed, 1473 insertions(+), 485 deletions(-) create mode 100644 integrationtests/self/zero_rtt_oldgo_test.go rename internal/{handshake/cipher_suite_unsafe.go => qtls/cipher_suite_go121.go} (98%) rename internal/{handshake => qtls}/client_session_cache.go (97%) create mode 100644 internal/qtls/go119.go create mode 100644 internal/qtls/go120.go create mode 100644 internal/qtls/go121.go create mode 100644 internal/qtls/go_oldversion.go create mode 100644 internal/qtls/qtls_suite_test.go diff --git a/.github/workflows/gotip.yml b/.github/workflows/gotip.yml index e7d78b907fe..2eb48f7f65b 100644 --- a/.github/workflows/gotip.yml +++ b/.github/workflows/gotip.yml @@ -26,3 +26,5 @@ jobs: run: echo "QLOGFLAG= -qlog" >> $GITHUB_ENV - name: Run self tests, using QUIC v1 run: go run github.com/onsi/ginkgo/v2/ginkgo -r -v -randomize-all -randomize-suites -trace integrationtests/self -- -version=1 ${{ env.QLOGFLAG }} + - name: Run self tests, using QUIC v1, with race detector + run: go run github.com/onsi/ginkgo/v2/ginkgo -race -r -v -randomize-all -randomize-suites -trace integrationtests/self -- -version=1 ${{ env.QLOGFLAG }} diff --git a/connection_test.go b/connection_test.go index 10f0c202b6b..cba1d6a5b41 100644 --- a/connection_test.go +++ b/connection_test.go @@ -116,7 +116,7 @@ var _ = Describe("Connection", func() { &protocol.DefaultConnectionIDGenerator{}, protocol.StatelessResetToken{}, populateServerConfig(&Config{DisablePathMTUDiscovery: true}), - nil, // tls.Config + &tls.Config{}, tokenGenerator, false, tracer, @@ -353,7 +353,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) Expect(conn.run()).To(MatchError(expectedErr)) }() Expect(conn.handleFrame(&wire.ConnectionCloseFrame{ @@ -381,7 +381,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) Expect(conn.run()).To(MatchError(testErr)) }() ccf := &wire.ConnectionCloseFrame{ @@ -428,7 +428,7 @@ var _ = Describe("Connection", func() { runConn := func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) runErr <- conn.run() }() Eventually(areConnsRunning).Should(BeTrue()) @@ -808,7 +808,7 @@ var _ = Describe("Connection", func() { packer.EXPECT().PackConnectionClose(gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() expectReplaceWithClosed() @@ -850,7 +850,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() Consistently(conn.Context().Done()).ShouldNot(BeClosed()) @@ -885,7 +885,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() Consistently(conn.Context().Done()).ShouldNot(BeClosed()) @@ -910,7 +910,7 @@ var _ = Describe("Connection", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) err := conn.run() Expect(err).To(HaveOccurred()) Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) @@ -934,7 +934,7 @@ var _ = Describe("Connection", func() { runErr := make(chan error) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) runErr <- conn.run() }() expectReplaceWithClosed() @@ -958,7 +958,7 @@ var _ = Describe("Connection", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) err := conn.run() Expect(err).To(HaveOccurred()) Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) @@ -1194,7 +1194,7 @@ var _ = Describe("Connection", func() { runConn := func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() close(connDone) }() @@ -1409,7 +1409,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().Send(gomock.Any()).Times(2) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() conn.scheduleSending() @@ -1427,7 +1427,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().Send(gomock.Any()) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() conn.scheduleSending() @@ -1446,7 +1446,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().Send(gomock.Any()) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() conn.scheduleSending() @@ -1466,7 +1466,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().Send(gomock.Any()) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() conn.scheduleSending() @@ -1495,7 +1495,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }).Times(2) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() conn.scheduleSending() @@ -1519,7 +1519,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }).Times(3) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() conn.scheduleSending() @@ -1532,7 +1532,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().Available().Return(available) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() conn.scheduleSending() @@ -1555,7 +1555,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().WouldBlock().AnyTimes() go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() @@ -1590,7 +1590,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().Send(gomock.Any()).DoAndReturn(func(p *packetBuffer) { written <- struct{}{} }) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() available := make(chan struct{}, 1) @@ -1624,7 +1624,7 @@ var _ = Describe("Connection", func() { // don't EXPECT any calls to mconn.Write() go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() conn.scheduleSending() // no packet will get sent @@ -1649,7 +1649,7 @@ var _ = Describe("Connection", func() { packer.EXPECT().PackMTUProbePacket(ping, protocol.ByteCount(1234), gomock.Any(), conn.version).Return(p, buffer, nil) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() conn.scheduleSending() @@ -1696,7 +1696,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() // don't EXPECT any calls to mconn.Write() @@ -1732,7 +1732,7 @@ var _ = Describe("Connection", func() { tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() Eventually(written).Should(BeClosed()) @@ -1796,7 +1796,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() @@ -1828,7 +1828,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() <-finishHandshake - cryptoSetup.EXPECT().RunHandshake() + cryptoSetup.EXPECT().StartHandshake() cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket() close(conn.handshakeCompleteChan) @@ -1858,7 +1858,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() <-finishHandshake - cryptoSetup.EXPECT().RunHandshake() + cryptoSetup.EXPECT().StartHandshake() cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil) close(conn.handshakeCompleteChan) @@ -1905,7 +1905,7 @@ var _ = Describe("Connection", func() { tracer.EXPECT().Close() go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake() + cryptoSetup.EXPECT().StartHandshake() conn.run() }() handshakeCtx := conn.HandshakeComplete() @@ -1939,7 +1939,7 @@ var _ = Describe("Connection", func() { packer.EXPECT().PackPacket(false, gomock.Any(), conn.version).Return(shortHeaderPacket{}, nil, errNothingToPack).AnyTimes() go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake() + cryptoSetup.EXPECT().StartHandshake() cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket() mconn.EXPECT().Write(gomock.Any()) @@ -1962,7 +1962,7 @@ var _ = Describe("Connection", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) Expect(conn.run()).To(Succeed()) close(done) }() @@ -1982,7 +1982,7 @@ var _ = Describe("Connection", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) err := conn.run() Expect(err).To(MatchError(&qerr.ApplicationError{ ErrorCode: 0x1337, @@ -2036,7 +2036,7 @@ var _ = Describe("Connection", func() { runConn := func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() } @@ -2124,7 +2124,7 @@ var _ = Describe("Connection", func() { ) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) err := conn.run() nerr, ok := err.(net.Error) Expect(ok).To(BeTrue()) @@ -2149,7 +2149,7 @@ var _ = Describe("Connection", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) err := conn.run() nerr, ok := err.(net.Error) Expect(ok).To(BeTrue()) @@ -2182,7 +2182,7 @@ var _ = Describe("Connection", func() { // and not on the last network activity go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() Consistently(conn.Context().Done()).ShouldNot(BeClosed()) @@ -2209,7 +2209,7 @@ var _ = Describe("Connection", func() { conn.handshakeComplete = false go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1) err := conn.run() nerr, ok := err.(net.Error) @@ -2238,7 +2238,7 @@ var _ = Describe("Connection", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1) cryptoSetup.EXPECT().SetHandshakeConfirmed().MaxTimes(1) close(conn.handshakeCompleteChan) @@ -2258,7 +2258,7 @@ var _ = Describe("Connection", func() { conn.idleTimeout = 30 * time.Second go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() Consistently(conn.Context().Done()).ShouldNot(BeClosed()) @@ -2431,7 +2431,7 @@ var _ = Describe("Client Connection", func() { conn.unpacker = unpacker go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() newConnID := protocol.ParseConnectionID([]byte{1, 3, 3, 7, 1, 3, 3, 7}) @@ -2511,7 +2511,7 @@ var _ = Describe("Client Connection", func() { tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() running := make(chan struct{}) - cryptoSetup.EXPECT().RunHandshake().Do(func() { + cryptoSetup.EXPECT().StartHandshake().Do(func() { close(running) conn.closeLocal(errors.New("early error")) }) @@ -2564,7 +2564,7 @@ var _ = Describe("Client Connection", func() { errChan := make(chan error, 1) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) errChan <- conn.run() }() connRunner.EXPECT().Remove(srcConnID) @@ -2589,7 +2589,7 @@ var _ = Describe("Client Connection", func() { errChan := make(chan error, 1) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) errChan <- conn.run() }() connRunner.EXPECT().Remove(srcConnID).MaxTimes(1) @@ -2697,7 +2697,7 @@ var _ = Describe("Client Connection", func() { closed = false go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) errChan <- conn.run() close(errChan) }() diff --git a/crypto_stream_manager_test.go b/crypto_stream_manager_test.go index e2f46c8cc97..c5b59b92780 100644 --- a/crypto_stream_manager_test.go +++ b/crypto_stream_manager_test.go @@ -1,12 +1,9 @@ package quic import ( - "errors" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) @@ -35,9 +32,7 @@ var _ = Describe("Crypto Stream Manager", func() { initialStream.EXPECT().GetCryptoData().Return([]byte("foobar")) initialStream.EXPECT().GetCryptoData() cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionInitial) - encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionInitial) - Expect(err).ToNot(HaveOccurred()) - Expect(encLevelChanged).To(BeFalse()) + Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionInitial)).To(Succeed()) }) It("passes messages to the handshake stream", func() { @@ -46,9 +41,7 @@ var _ = Describe("Crypto Stream Manager", func() { handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")) handshakeStream.EXPECT().GetCryptoData() cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake) - encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) - Expect(err).ToNot(HaveOccurred()) - Expect(encLevelChanged).To(BeFalse()) + Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) }) It("passes messages to the 1-RTT stream", func() { @@ -57,9 +50,7 @@ var _ = Describe("Crypto Stream Manager", func() { oneRTTStream.EXPECT().GetCryptoData().Return([]byte("foobar")) oneRTTStream.EXPECT().GetCryptoData() cs.EXPECT().HandleMessage([]byte("foobar"), protocol.Encryption1RTT) - encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(encLevelChanged).To(BeFalse()) + Expect(csm.HandleCryptoFrame(cf, protocol.Encryption1RTT)).To(Succeed()) }) It("doesn't call the message handler, if there's no message", func() { @@ -67,9 +58,7 @@ var _ = Describe("Crypto Stream Manager", func() { handshakeStream.EXPECT().HandleCryptoFrame(cf) handshakeStream.EXPECT().GetCryptoData() // don't return any data to handle // don't EXPECT any calls to HandleMessage() - encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) - Expect(err).ToNot(HaveOccurred()) - Expect(encLevelChanged).To(BeFalse()) + Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) }) It("processes all messages", func() { @@ -80,39 +69,11 @@ var _ = Describe("Crypto Stream Manager", func() { handshakeStream.EXPECT().GetCryptoData() cs.EXPECT().HandleMessage([]byte("foo"), protocol.EncryptionHandshake) cs.EXPECT().HandleMessage([]byte("bar"), protocol.EncryptionHandshake) - encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) - Expect(err).ToNot(HaveOccurred()) - Expect(encLevelChanged).To(BeFalse()) - }) - - It("finishes the crypto stream, when the crypto setup is done with this encryption level", func() { - cf := &wire.CryptoFrame{Data: []byte("foobar")} - gomock.InOrder( - handshakeStream.EXPECT().HandleCryptoFrame(cf), - handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")), - cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true), - handshakeStream.EXPECT().Finish(), - ) - encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) - Expect(err).ToNot(HaveOccurred()) - Expect(encLevelChanged).To(BeTrue()) - }) - - It("returns errors that occur when finishing a stream", func() { - testErr := errors.New("test error") - cf := &wire.CryptoFrame{Data: []byte("foobar")} - gomock.InOrder( - handshakeStream.EXPECT().HandleCryptoFrame(cf), - handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")), - cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true), - handshakeStream.EXPECT().Finish().Return(testErr), - ) - _, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) - Expect(err).To(MatchError(err)) + Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) }) It("errors for unknown encryption levels", func() { - _, err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, 42) + err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, 42) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("received CRYPTO frame with unexpected encryption level")) }) diff --git a/go.mod b/go.mod index bdb8183f932..8c8fca83bad 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/onsi/gomega v1.20.1 github.com/quic-go/qpack v0.4.0 github.com/quic-go/qtls-go1-19 v0.3.2 - github.com/quic-go/qtls-go1-20 v0.2.2 + github.com/quic-go/qtls-go1-20 v0.0.0-20230529182851-8b69d972a82a golang.org/x/crypto v0.4.0 golang.org/x/exp v0.0.0-20221205204356-47842c84f3db golang.org/x/net v0.7.0 diff --git a/go.sum b/go.sum index a2375b032b6..cc094086cdf 100644 --- a/go.sum +++ b/go.sum @@ -90,8 +90,8 @@ github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U= github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= -github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E= -github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= +github.com/quic-go/qtls-go1-20 v0.0.0-20230529182851-8b69d972a82a h1:29YotFwKdXvoi6tI6C2XJi4gXG+GmZOzG4xHibYyBhw= +github.com/quic-go/qtls-go1-20 v0.0.0-20230529182851-8b69d972a82a/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= diff --git a/http3/server_test.go b/http3/server_test.go index 572446eaa77..fb8484225c9 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -926,7 +926,7 @@ var _ = Describe("Server", func() { c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) Expect(err).ToNot(HaveOccurred()) defer c.CloseWithError(0, "") - Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3)) + Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) }) It("advertises h3-29 for draft-29", func() { @@ -937,7 +937,7 @@ var _ = Describe("Server", func() { c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3Draft29}}, nil) Expect(err).ToNot(HaveOccurred()) defer c.CloseWithError(0, "") - Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3Draft29)) + Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3Draft29)) }) It("sets the GetConfigForClient callback if no tls.Config is given", func() { @@ -965,7 +965,7 @@ var _ = Describe("Server", func() { c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) Expect(err).ToNot(HaveOccurred()) defer c.CloseWithError(0, "") - Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3)) + Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) }) It("works if GetConfigForClient returns a nil tls.Config", func() { @@ -978,7 +978,7 @@ var _ = Describe("Server", func() { c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) Expect(err).ToNot(HaveOccurred()) defer c.CloseWithError(0, "") - Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3)) + Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) }) It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() { @@ -996,7 +996,7 @@ var _ = Describe("Server", func() { c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) Expect(err).ToNot(HaveOccurred()) defer c.CloseWithError(0, "") - Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3)) + Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) // check that the original config was not modified Expect(tlsClientConf.NextProtos).To(Equal([]string{"foo", "bar"})) }) diff --git a/integrationtests/gomodvendor/go.sum b/integrationtests/gomodvendor/go.sum index 9ba8661bc3c..98fab665ba1 100644 --- a/integrationtests/gomodvendor/go.sum +++ b/integrationtests/gomodvendor/go.sum @@ -113,8 +113,8 @@ github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U= github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= -github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E= -github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= +github.com/quic-go/qtls-go1-20 v0.0.0-20230529182851-8b69d972a82a h1:29YotFwKdXvoi6tI6C2XJi4gXG+GmZOzG4xHibYyBhw= +github.com/quic-go/qtls-go1-20 v0.0.0-20230529182851-8b69d972a82a/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 9fe44809407..2ef4dd20768 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -10,9 +10,9 @@ import ( "time" "github.com/quic-go/quic-go" - "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/qtls" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -91,7 +91,7 @@ var _ = Describe("Handshake tests", func() { suiteID := id It(fmt.Sprintf("using %s", name), func() { - reset := handshake.SetCipherSuite(suiteID) + reset := qtls.SetCipherSuite(suiteID) defer reset() tlsConf := getTLSConfig() @@ -198,7 +198,10 @@ var _ = Describe("Handshake tests", func() { var transportErr *quic.TransportError Expect(errors.As(err, &transportErr)).To(BeTrue()) Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) - Expect(transportErr.Error()).To(ContainSubstring("tls: certificate required")) + Expect(transportErr.Error()).To(Or( + ContainSubstring("tls: certificate required"), + ContainSubstring("tls: bad certificate"), + )) }) It("uses the ServerName in the tls.Config", func() { diff --git a/integrationtests/self/zero_rtt_oldgo_test.go b/integrationtests/self/zero_rtt_oldgo_test.go new file mode 100644 index 00000000000..beaf351e249 --- /dev/null +++ b/integrationtests/self/zero_rtt_oldgo_test.go @@ -0,0 +1,804 @@ +//go:build !go1.21 + +package self_test + +import ( + "context" + "crypto/tls" + "fmt" + "io" + mrand "math/rand" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/quic-go/quic-go" + quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/wire" + "github.com/quic-go/quic-go/logging" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("0-RTT", func() { + rtt := scaleDuration(5 * time.Millisecond) + + runCountingProxy := func(serverPort int) (*quicproxy.QuicProxy, *uint32) { + var num0RTTPackets uint32 // to be used as an atomic + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), + DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { + for len(data) > 0 { + if !wire.IsLongHeaderPacket(data[0]) { + break + } + hdr, _, rest, err := wire.ParsePacket(data) + Expect(err).ToNot(HaveOccurred()) + if hdr.Type == protocol.PacketType0RTT { + atomic.AddUint32(&num0RTTPackets, 1) + break + } + data = rest + } + return rtt / 2 + }, + }) + Expect(err).ToNot(HaveOccurred()) + + return proxy, &num0RTTPackets + } + + dialAndReceiveSessionTicket := func(serverConf *quic.Config) (*tls.Config, *tls.Config) { + tlsConf := getTLSConfig() + if serverConf == nil { + serverConf = getQuicConfig(nil) + } + serverConf.Allow0RTT = true + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + serverConf, + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), + DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { return rtt / 2 }, + }) + Expect(err).ToNot(HaveOccurred()) + defer proxy.Close() + + // dial the first connection in order to receive a session ticket + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + <-conn.Context().Done() + }() + + clientConf := getTLSClientConfig() + gets := make(chan string, 100) + puts := make(chan string, 100) + clientConf.ClientSessionCache = newClientSessionCache(tls.NewLRUClientSessionCache(100), gets, puts) + conn, err := quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + Eventually(puts).Should(Receive()) + // received the session ticket. We're done here. + Expect(conn.CloseWithError(0, "")).To(Succeed()) + Eventually(done).Should(BeClosed()) + return tlsConf, clientConf + } + + transfer0RTTData := func( + ln *quic.EarlyListener, + proxyPort int, + connIDLen int, + clientTLSConf *tls.Config, + clientConf *quic.Config, + testdata []byte, // data to transfer + ) { + // accept the second connection, and receive the data sent in 0-RTT + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.AcceptStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal(testdata)) + Expect(str.Close()).To(Succeed()) + Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) + <-conn.Context().Done() + close(done) + }() + + if clientConf == nil { + clientConf = getQuicConfig(nil) + } + var conn quic.EarlyConnection + if connIDLen == 0 { + var err error + conn, err = quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxyPort), + clientTLSConf, + clientConf, + ) + Expect(err).ToNot(HaveOccurred()) + } else { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + defer udpConn.Close() + tr := &quic.Transport{ + Conn: udpConn, + ConnectionIDLength: connIDLen, + } + defer tr.Close() + conn, err = tr.DialEarly( + context.Background(), + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxyPort}, + clientTLSConf, + clientConf, + ) + Expect(err).ToNot(HaveOccurred()) + } + defer conn.CloseWithError(0, "") + str, err := conn.OpenStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write(testdata) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + <-conn.HandshakeComplete() + Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) + io.ReadAll(str) // wait for the EOF from the server to arrive before closing the conn + conn.CloseWithError(0, "") + Eventually(done).Should(BeClosed()) + Eventually(conn.Context().Done()).Should(BeClosed()) + } + + check0RTTRejected := func( + ln *quic.EarlyListener, + proxyPort int, + clientConf *tls.Config, + ) { + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxyPort), + clientConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write(make([]byte, 3000)) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) + + // make sure the server doesn't process the data + ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond)) + defer cancel() + serverConn, err := ln.Accept(ctx) + Expect(err).ToNot(HaveOccurred()) + Expect(serverConn.ConnectionState().Used0RTT).To(BeFalse()) + _, err = serverConn.AcceptUniStream(ctx) + Expect(err).To(Equal(context.DeadlineExceeded)) + Expect(serverConn.CloseWithError(0, "")).To(Succeed()) + Eventually(conn.Context().Done()).Should(BeClosed()) + } + + // can be used to extract 0-RTT from a packetTracer + get0RTTPackets := func(packets []packet) []protocol.PacketNumber { + var zeroRTTPackets []protocol.PacketNumber + for _, p := range packets { + if p.hdr.Type == protocol.PacketType0RTT { + zeroRTTPackets = append(zeroRTTPackets, p.hdr.PacketNumber) + } + } + return zeroRTTPackets + } + + for _, l := range []int{0, 15} { + connIDLen := l + + It(fmt.Sprintf("transfers 0-RTT data, with %d byte connection IDs", connIDLen), func() { + tlsConf, clientTLSConf := dialAndReceiveSessionTicket(nil) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + transfer0RTTData( + ln, + proxy.LocalPort(), + connIDLen, + clientTLSConf, + getQuicConfig(nil), + PRData, + ) + + var numNewConnIDs int + for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, f := range p.frames { + if _, ok := f.(*logging.NewConnectionIDFrame); ok { + numNewConnIDs++ + } + } + } + if connIDLen == 0 { + Expect(numNewConnIDs).To(BeZero()) + } else { + Expect(numNewConnIDs).ToNot(BeZero()) + } + + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) + Expect(zeroRTTPackets).To(ContainElement(protocol.PacketNumber(0))) + }) + } + + // Test that data intended to be sent with 1-RTT protection is not sent in 0-RTT packets. + It("waits for a connection until the handshake is done", func() { + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + + zeroRTTData := GeneratePRData(5 << 10) + oneRTTData := PRData + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + // now accept the second connection, and receive the 0-RTT data + go func() { + defer GinkgoRecover() + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal(zeroRTTData)) + str, err = conn.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err = io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal(oneRTTData)) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + }() + + proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + firstStr, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = firstStr.Write(zeroRTTData) + Expect(err).ToNot(HaveOccurred()) + Expect(firstStr.Close()).To(Succeed()) + + // wait for the handshake to complete + Eventually(conn.HandshakeComplete()).Should(BeClosed()) + str, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write(PRData) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + <-conn.Context().Done() + + // check that 0-RTT packets only contain STREAM frames for the first stream + var num0RTT int + for _, p := range tracer.getRcvdLongHeaderPackets() { + if p.hdr.Header.Type != protocol.PacketType0RTT { + continue + } + for _, f := range p.frames { + sf, ok := f.(*logging.StreamFrame) + if !ok { + continue + } + num0RTT++ + Expect(sf.StreamID).To(Equal(firstStr.StreamID())) + } + } + fmt.Fprintf(GinkgoWriter, "received %d STREAM frames in 0-RTT packets\n", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + }) + + It("transfers 0-RTT data, when 0-RTT packets are lost", func() { + var ( + num0RTTPackets uint32 // to be used as an atomic + num0RTTDropped uint32 + ) + + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), + DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { + if wire.IsLongHeaderPacket(data[0]) { + hdr, _, _, err := wire.ParsePacket(data) + Expect(err).ToNot(HaveOccurred()) + if hdr.Type == protocol.PacketType0RTT { + atomic.AddUint32(&num0RTTPackets, 1) + } + } + return rtt / 2 + }, + DropPacket: func(_ quicproxy.Direction, data []byte) bool { + if !wire.IsLongHeaderPacket(data[0]) { + return false + } + hdr, _, _, err := wire.ParsePacket(data) + Expect(err).ToNot(HaveOccurred()) + if hdr.Type == protocol.PacketType0RTT { + // drop 25% of the 0-RTT packets + drop := mrand.Intn(4) == 0 + if drop { + atomic.AddUint32(&num0RTTDropped, 1) + } + return drop + } + return false + }, + }) + Expect(err).ToNot(HaveOccurred()) + defer proxy.Close() + + transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData) + + num0RTT := atomic.LoadUint32(&num0RTTPackets) + numDropped := atomic.LoadUint32(&num0RTTDropped) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped) + Expect(numDropped).ToNot(BeZero()) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).ToNot(BeEmpty()) + }) + + It("retransmits all 0-RTT data when the server performs a Retry", func() { + var mutex sync.Mutex + var firstConnID, secondConnID *protocol.ConnectionID + var firstCounter, secondCounter protocol.ByteCount + + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + + countZeroRTTBytes := func(data []byte) (n protocol.ByteCount) { + for len(data) > 0 { + hdr, _, rest, err := wire.ParsePacket(data) + if err != nil { + return + } + data = rest + if hdr.Type == protocol.PacketType0RTT { + n += hdr.Length - 16 /* AEAD tag */ + } + } + return + } + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + RequireAddressValidation: func(net.Addr) bool { return true }, + Allow0RTT: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), + DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration { + connID, err := wire.ParseConnectionID(data, 0) + Expect(err).ToNot(HaveOccurred()) + + mutex.Lock() + defer mutex.Unlock() + + if zeroRTTBytes := countZeroRTTBytes(data); zeroRTTBytes > 0 { + if firstConnID == nil { + firstConnID = &connID + firstCounter += zeroRTTBytes + } else if firstConnID != nil && *firstConnID == connID { + Expect(secondConnID).To(BeNil()) + firstCounter += zeroRTTBytes + } else if secondConnID == nil { + secondConnID = &connID + secondCounter += zeroRTTBytes + } else if secondConnID != nil && *secondConnID == connID { + secondCounter += zeroRTTBytes + } else { + Fail("received 3 connection IDs on 0-RTT packets") + } + } + return rtt / 2 + }, + }) + Expect(err).ToNot(HaveOccurred()) + defer proxy.Close() + + transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, GeneratePRData(5000)) // ~5 packets + + mutex.Lock() + defer mutex.Unlock() + Expect(firstCounter).To(BeNumerically("~", 5000+100 /* framing overhead */, 100)) // the FIN bit might be sent extra + Expect(secondCounter).To(BeNumerically("~", firstCounter, 20)) + zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + Expect(len(zeroRTTPackets)).To(BeNumerically(">=", 5)) + Expect(zeroRTTPackets[0]).To(BeNumerically(">=", protocol.PacketNumber(5))) + }) + + It("doesn't reject 0-RTT when the server's transport stream limit increased", func() { + const maxStreams = 1 + tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ + MaxIncomingUniStreams: maxStreams, + })) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + MaxIncomingUniStreams: maxStreams + 1, + Allow0RTT: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + // The client remembers the old limit and refuses to open a new stream. + _, err = conn.OpenUniStream() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("too many open streams")) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _, err = conn.OpenUniStreamSync(ctx) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + }) + + It("rejects 0-RTT when the server's stream limit decreased", func() { + const maxStreams = 42 + tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ + MaxIncomingStreams: maxStreams, + })) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + MaxIncomingStreams: maxStreams - 1, + Allow0RTT: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + check0RTTRejected(ln, proxy.LocalPort(), clientConf) + + // The client should send 0-RTT packets, but the server doesn't process them. + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + }) + + It("rejects 0-RTT when the ALPN changed", func() { + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + + // now close the listener and dial new connection with a different ALPN + clientConf.NextProtos = []string{"new-alpn"} + tlsConf.NextProtos = []string{"new-alpn"} + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + check0RTTRejected(ln, proxy.LocalPort(), clientConf) + + // The client should send 0-RTT packets, but the server doesn't process them. + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + }) + + It("rejects 0-RTT when the application doesn't allow it", func() { + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + + // now close the listener and dial new connection with a different ALPN + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: false, // application rejects 0-RTT + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + check0RTTRejected(ln, proxy.LocalPort(), clientConf) + + // The client should send 0-RTT packets, but the server doesn't process them. + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + }) + + DescribeTable("flow control limits", + func(addFlowControlLimit func(*quic.Config, uint64)) { + tracer := newPacketTracer() + firstConf := getQuicConfig(&quic.Config{Allow0RTT: true}) + addFlowControlLimit(firstConf, 3) + tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf) + + secondConf := getQuicConfig(&quic.Config{ + Allow0RTT: true, + Tracer: newTracer(tracer), + }) + addFlowControlLimit(secondConf, 100) + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + secondConf, + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + written := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(written) + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + }() + + Eventually(written).Should(BeClosed()) + + serverConn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + rstr, err := serverConn.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := io.ReadAll(rstr) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foobar"))) + Expect(serverConn.ConnectionState().Used0RTT).To(BeTrue()) + Expect(serverConn.CloseWithError(0, "")).To(Succeed()) + Eventually(conn.Context().Done()).Should(BeClosed()) + + var processedFirst bool + for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, f := range p.frames { + if sf, ok := f.(*logging.StreamFrame); ok { + if !processedFirst { + // The first STREAM should have been sent in a 0-RTT packet. + // Due to the flow control limit, the STREAM frame was limit to the first 3 bytes. + Expect(p.hdr.Type).To(Equal(protocol.PacketType0RTT)) + Expect(sf.Length).To(BeEquivalentTo(3)) + processedFirst = true + } else { + Fail("STREAM was shouldn't have been sent in 0-RTT") + } + } + } + } + }, + Entry("doesn't reject 0-RTT when the server's transport stream flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialStreamReceiveWindow = limit }), + Entry("doesn't reject 0-RTT when the server's transport connection flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialConnectionReceiveWindow = limit }), + ) + + for _, l := range []int{0, 15} { + connIDLen := l + + It(fmt.Sprintf("correctly deals with 0-RTT rejections, for %d byte connection IDs", connIDLen), func() { + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + // now dial new connection with different transport parameters + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + MaxIncomingUniStreams: 1, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + // The client remembers that it was allowed to open 2 uni-directional streams. + firstStr, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + written := make(chan struct{}, 2) + go func() { + defer GinkgoRecover() + defer func() { written <- struct{}{} }() + _, err := firstStr.Write([]byte("first flight")) + Expect(err).ToNot(HaveOccurred()) + }() + secondStr, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + go func() { + defer GinkgoRecover() + defer func() { written <- struct{}{} }() + _, err := secondStr.Write([]byte("first flight")) + Expect(err).ToNot(HaveOccurred()) + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _, err = conn.AcceptStream(ctx) + Expect(err).To(MatchError(quic.Err0RTTRejected)) + Eventually(written).Should(Receive()) + Eventually(written).Should(Receive()) + _, err = firstStr.Write([]byte("foobar")) + Expect(err).To(MatchError(quic.Err0RTTRejected)) + _, err = conn.OpenUniStream() + Expect(err).To(MatchError(quic.Err0RTTRejected)) + + _, err = conn.AcceptStream(ctx) + Expect(err).To(Equal(quic.Err0RTTRejected)) + + newConn := conn.NextConnection() + str, err := newConn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = newConn.OpenUniStream() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("too many open streams")) + _, err = str.Write([]byte("second flight")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + + // The client should send 0-RTT packets, but the server doesn't process them. + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + }) + } + + It("queues 0-RTT packets, if the Initial is delayed", func() { + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: ln.Addr().String(), + DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration { + if dir == quicproxy.DirectionIncoming && wire.IsLongHeaderPacket(data[0]) && data[0]&0x30>>4 == 0 { // Initial packet from client + return rtt/2 + rtt + } + return rtt / 2 + }, + }) + Expect(err).ToNot(HaveOccurred()) + defer proxy.Close() + + transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData) + + Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial)) + zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) + Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) + }) +}) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 460a9c75970..b3bb32fe84b 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -1,3 +1,5 @@ +//go:build go1.21 + package self_test import ( diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index d6ed68981f7..114c9e03ff7 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -8,10 +8,12 @@ import ( "fmt" "io" "sync" + "sync/atomic" "time" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" + "github.com/quic-go/quic-go/internal/qtls" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/logging" @@ -26,7 +28,7 @@ const clientSessionStateRevision = 3 type cryptoSetup struct { tlsConf *tls.Config - conn *tls.QUICConn + conn *qtls.QUICConn version protocol.VersionNumber @@ -61,7 +63,7 @@ type cryptoSetup struct { handshakeOpener LongHeaderOpener handshakeSealer LongHeaderSealer - used0RTT bool + used0RTT atomic.Bool oneRTTStream io.Writer aead *updatableAEAD @@ -100,17 +102,11 @@ func NewCryptoSetupClient( tlsConf = tlsConf.Clone() tlsConf.MinVersion = tls.VersionTLS13 - if tlsConf.ClientSessionCache != nil { - origCache := tlsConf.ClientSessionCache - tlsConf.ClientSessionCache = &clientSessionCache{ - wrapped: origCache, - getData: cs.marshalDataForSessionState, - setData: cs.handleDataFromSessionState, - } - } + quicConf := &qtls.QUICConfig{TLSConfig: tlsConf} + qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState) cs.tlsConf = tlsConf - cs.conn = tls.QUICClient(&tls.QUICConfig{TLSConfig: cs.tlsConf}) + cs.conn = qtls.QUICClient(quicConf) cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient)) return cs, clientHelloWritten @@ -144,58 +140,11 @@ func NewCryptoSetupServer( ) cs.allow0RTT = allow0RTT - // TODO: this is a hack to initialize the session ticket keys - tlsConf.DecryptTicket([]byte("foobar"), tls.ConnectionState{}) - tlsConf = tlsConf.Clone() - tlsConf.MinVersion = tls.VersionTLS13 - // add callbacks to save transport parameters into the session ticket - origWrapSession := tlsConf.WrapSession - tlsConf.WrapSession = func(cs tls.ConnectionState, state *tls.SessionState) ([]byte, error) { - // Add QUIC transport parameters if this is a 0-RTT packet. - // TODO(#3853): also save the RTT for non-0-RTT tickets - if state.EarlyData { - // At this point, crypto/tls has just called the WrapSession callback. - // state.Extra is guaranteed to be empty. - state.Extra = (&sessionTicket{ - Parameters: tp, - RTT: rttStats.SmoothedRTT(), - }).Marshal() - } - if origWrapSession != nil { - return origWrapSession(cs, state) - } - b, err := tlsConf.EncryptTicket(cs, state) - return b, err - } - origUnwrapSession := tlsConf.UnwrapSession - // UnwrapSession might be called multiple times, as the client can use multiple session tickets. - // However, using 0-RTT is only possible with the first session ticket. - // crypto/tls guarantees that this callback is called in the same order as the session ticket in the ClientHello. - var unwrapCount int - tlsConf.UnwrapSession = func(identity []byte, connState tls.ConnectionState) (*tls.SessionState, error) { - unwrapCount++ - var state *tls.SessionState - var err error - if origUnwrapSession != nil { - state, err = origUnwrapSession(identity, connState) - } else { - state, err = tlsConf.DecryptTicket(identity, connState) - } - if err != nil || state == nil { - return nil, err - } - if state.EarlyData { - if unwrapCount == 1 { // first session ticket - state.EarlyData = cs.accept0RTT(state.Extra) - } else { // subsequent session ticket, can't be used for 0-RTT - state.EarlyData = false - } - } - return state, nil - } + quicConf := &qtls.QUICConfig{TLSConfig: tlsConf} + qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.accept0RTT) - cs.tlsConf = tlsConf - cs.conn = tls.QUICServer(&tls.QUICConfig{TLSConfig: cs.tlsConf}) + cs.tlsConf = quicConf.TLSConfig + cs.conn = qtls.QUICServer(quicConf) return cs } @@ -256,7 +205,7 @@ func (h *cryptoSetup) StartHandshake() error { } for { ev := h.conn.NextEvent() - if ev.Kind == tls.QUICNoEvent { + if ev.Kind == qtls.QUICNoEvent { break } if err := h.handleEvent(ev); err != nil { @@ -277,7 +226,9 @@ func (h *cryptoSetup) StartHandshake() error { // Close closes the crypto setup. // It aborts the handshake, if it is still running. -func (h *cryptoSetup) Close() error { return h.conn.Close() } +func (h *cryptoSetup) Close() error { + return h.conn.Close() +} // HandleMessage handles a TLS handshake message. // It is called by the crypto streams when a new message is available. @@ -289,12 +240,12 @@ func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLev } func (h *cryptoSetup) handleMessage(data []byte, encLevel protocol.EncryptionLevel) error { - if err := h.conn.HandleData(encLevel.ToTLSEncryptionLevel(), data); err != nil { + if err := h.conn.HandleData(qtls.ToTLSEncryptionLevel(encLevel), data); err != nil { return err } for { ev := h.conn.NextEvent() - if ev.Kind == tls.QUICNoEvent { + if ev.Kind == qtls.QUICNoEvent { return nil } if err := h.handleEvent(ev); err != nil { @@ -303,25 +254,25 @@ func (h *cryptoSetup) handleMessage(data []byte, encLevel protocol.EncryptionLev } } -func (h *cryptoSetup) handleEvent(ev tls.QUICEvent) error { +func (h *cryptoSetup) handleEvent(ev qtls.QUICEvent) error { switch ev.Kind { - case tls.QUICSetReadSecret: + case qtls.QUICSetReadSecret: h.SetReadKey(ev.Level, ev.Suite, ev.Data) return nil - case tls.QUICSetWriteSecret: + case qtls.QUICSetWriteSecret: h.SetWriteKey(ev.Level, ev.Suite, ev.Data) return nil - case tls.QUICTransportParameters: + case qtls.QUICTransportParameters: return h.handleTransportParameters(ev.Data) - case tls.QUICTransportParametersRequired: + case qtls.QUICTransportParametersRequired: h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective)) return nil - case tls.QUICRejectedEarlyData: + case qtls.QUICRejectedEarlyData: h.rejected0RTT() return nil - case tls.QUICWriteData: + case qtls.QUICWriteData: return h.WriteRecord(ev.Level, ev.Data) - case tls.QUICHandshakeDone: + case qtls.QUICHandshakeDone: h.handshakeComplete() return nil default: @@ -377,6 +328,13 @@ func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.Transpo return &tp, nil } +func (h *cryptoSetup) getDataForSessionTicket() []byte { + return (&sessionTicket{ + Parameters: h.ourParams, + RTT: h.rttStats.SmoothedRTT(), + }).Marshal() +} + // GetSessionTicket generates a new session ticket. // Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection. // It is only valid for the server. @@ -388,11 +346,11 @@ func (h *cryptoSetup) GetSessionTicket() ([]byte, error) { return nil, err } ev := h.conn.NextEvent() - if ev.Kind != tls.QUICWriteData || ev.Level != tls.QUICEncryptionLevelApplication { + if ev.Kind != qtls.QUICWriteData || ev.Level != qtls.QUICEncryptionLevelApplication { panic("crypto/tls bug: where's my session ticket?") } ticket := ev.Data - if ev := h.conn.NextEvent(); ev.Kind != tls.QUICNoEvent { + if ev := h.conn.NextEvent(); ev.Kind != qtls.QUICNoEvent { panic("crypto/tls bug: why more than one ticket?") } return ticket, nil @@ -434,12 +392,11 @@ func (h *cryptoSetup) rejected0RTT() { } } -func (h *cryptoSetup) SetReadKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { - encLevel := protocol.FromTLSEncryptionLevel(el) +func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { suite := getCipherSuite(suiteID) h.mutex.Lock() - switch encLevel { - case protocol.Encryption0RTT: + switch el { + case qtls.QUICEncryptionLevelEarly: if h.perspective == protocol.PerspectiveClient { panic("Received 0-RTT read key for the client") } @@ -447,11 +404,11 @@ func (h *cryptoSetup) SetReadKey(el tls.QUICEncryptionLevel, suiteID uint16, tra createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), ) - h.used0RTT = true + h.used0RTT.Store(true) if h.logger.Debug() { h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) } - case protocol.EncryptionHandshake: + case qtls.QUICEncryptionLevelHandshake: h.handshakeOpener = newHandshakeOpener( createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), @@ -461,7 +418,7 @@ func (h *cryptoSetup) SetReadKey(el tls.QUICEncryptionLevel, suiteID uint16, tra if h.logger.Debug() { h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID)) } - case protocol.Encryption1RTT: + case qtls.QUICEncryptionLevelApplication: h.aead.SetReadKey(suite, trafficSecret) h.has1RTTOpener = true if h.logger.Debug() { @@ -473,16 +430,15 @@ func (h *cryptoSetup) SetReadKey(el tls.QUICEncryptionLevel, suiteID uint16, tra h.mutex.Unlock() h.runner.OnReceivedReadKeys() if h.tracer != nil { - h.tracer.UpdatedKeyFromTLS(encLevel, h.perspective.Opposite()) + h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite()) } } -func (h *cryptoSetup) SetWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { - encLevel := protocol.FromTLSEncryptionLevel(el) +func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { suite := getCipherSuite(suiteID) h.mutex.Lock() - switch encLevel { - case protocol.Encryption0RTT: + switch el { + case qtls.QUICEncryptionLevelEarly: if h.perspective == protocol.PerspectiveServer { panic("Received 0-RTT write key for the server") } @@ -499,7 +455,7 @@ func (h *cryptoSetup) SetWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, tr } // don't set used0RTT here. 0-RTT might still get rejected. return - case protocol.EncryptionHandshake: + case qtls.QUICEncryptionLevelHandshake: h.handshakeSealer = newHandshakeSealer( createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), @@ -509,17 +465,15 @@ func (h *cryptoSetup) SetWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, tr if h.logger.Debug() { h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID)) } - if h.zeroRTTSealer != nil { - // Once we receive handshake keys, we know that 0-RTT was not rejected. - h.used0RTT = true - } - case protocol.Encryption1RTT: + case qtls.QUICEncryptionLevelApplication: h.aead.SetWriteKey(suite, trafficSecret) h.has1RTTSealer = true if h.logger.Debug() { h.logger.Debugf("Installed 1-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) } if h.zeroRTTSealer != nil { + // Once we receive handshake keys, we know that 0-RTT was not rejected. + h.used0RTT.Store(true) h.zeroRTTSealer = nil h.logger.Debugf("Dropping 0-RTT keys.") if h.tracer != nil { @@ -531,24 +485,24 @@ func (h *cryptoSetup) SetWriteKey(el tls.QUICEncryptionLevel, suiteID uint16, tr } h.mutex.Unlock() if h.tracer != nil { - h.tracer.UpdatedKeyFromTLS(encLevel, h.perspective) + h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective) } } // WriteRecord is called when TLS writes data -func (h *cryptoSetup) WriteRecord(encLevel tls.QUICEncryptionLevel, p []byte) error { +func (h *cryptoSetup) WriteRecord(encLevel qtls.QUICEncryptionLevel, p []byte) error { h.mutex.Lock() defer h.mutex.Unlock() var str io.Writer //nolint:exhaustive // handshake records can only be written for Initial and Handshake. switch encLevel { - case tls.QUICEncryptionLevelInitial: + case qtls.QUICEncryptionLevelInitial: // assume that the first WriteRecord call contains the ClientHello str = h.initialStream - case tls.QUICEncryptionLevelHandshake: + case qtls.QUICEncryptionLevelHandshake: str = h.handshakeStream - case tls.QUICEncryptionLevelApplication: + case qtls.QUICEncryptionLevelApplication: str = h.oneRTTStream default: panic(fmt.Sprintf("unexpected write encryption level: %s", encLevel)) @@ -691,12 +645,12 @@ func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) { func (h *cryptoSetup) ConnectionState() ConnectionState { return ConnectionState{ ConnectionState: h.conn.ConnectionState(), - Used0RTT: h.used0RTT, + Used0RTT: h.used0RTT.Load(), } } func wrapError(err error) error { - if alertErr := tls.AlertError(0); errors.As(err, &alertErr) && alertErr != 80 { + if alertErr := qtls.AlertError(0); errors.As(err, &alertErr) && alertErr != 80 { return qerr.NewLocalCryptoError(uint8(alertErr), err.Error()) } return &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: err.Error()} diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index 51d1980a01e..5a685e929f4 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -1,7 +1,6 @@ package handshake import ( - "bytes" "crypto/rand" "crypto/rsa" "crypto/tls" @@ -30,6 +29,13 @@ var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3. 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, } +const ( + typeClientHello = 1 + typeNewSessionTicket = 4 +) + +const alertUnexpectedMessage = 10 + type chunk struct { data []byte encLevel protocol.EncryptionLevel @@ -80,54 +86,7 @@ var _ = Describe("Crypto Setup TLS", func() { } }) - It("returns Handshake() when an error occurs in qtls", func() { - sErrChan := make(chan error, 1) - runner := NewMockHandshakeRunner(mockCtrl) - runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) - _, sInitialStream, sHandshakeStream := initStreams() - var token protocol.StatelessResetToken - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{StatelessResetToken: &token}, - runner, - testdata.GetTLSConfig(), - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.Version1, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - server.RunHandshake() - Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{ - ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage), - ErrorMessage: "local error: tls: unexpected message", - }))) - close(done) - }() - - fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) - handledMessage := make(chan struct{}) - go func() { - defer GinkgoRecover() - server.HandleMessage(fakeCH, protocol.EncryptionInitial) - close(handledMessage) - }() - Eventually(handledMessage).Should(BeClosed()) - Eventually(done).Should(BeClosed()) - }) - It("handles qtls errors occurring before during ClientHello generation", func() { - sErrChan := make(chan error, 1) - runner := NewMockHandshakeRunner(mockCtrl) - runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) _, sInitialStream, sHandshakeStream := initStreams() tlsConf := testdata.GetTLSConfig() tlsConf.InsecureSkipVerify = true @@ -135,11 +94,10 @@ var _ = Describe("Crypto Setup TLS", func() { cl, _ := NewCryptoSetupClient( sInitialStream, sHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, &wire.TransportParameters{}, - runner, + NewMockHandshakeRunner(mockCtrl), tlsConf, false, &utils.RTTStats{}, @@ -148,32 +106,21 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.Version1, ) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - cl.RunHandshake() - close(done) - }() - - Eventually(done).Should(BeClosed()) - Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{ + Expect(cl.StartHandshake()).To(MatchError(&qerr.TransportError{ ErrorCode: qerr.InternalError, ErrorMessage: "tls: invalid NextProtos value", - }))) + })) }) It("errors when a message is received at the wrong encryption level", func() { - sErrChan := make(chan error, 1) _, sInitialStream, sHandshakeStream := initStreams() runner := NewMockHandshakeRunner(mockCtrl) - runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) var token protocol.StatelessResetToken server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, &wire.TransportParameters{StatelessResetToken: &token}, runner, testdata.GetTLSConfig(), @@ -184,90 +131,13 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.Version1, ) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - server.RunHandshake() - close(done) - }() - - fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) - server.HandleMessage(fakeCH, protocol.EncryptionHandshake) // wrong encryption level - Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{ - ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage), - ErrorMessage: "expected handshake message ClientHello to have encryption level Initial, has Handshake", - }))) - - // make the go routine return - Expect(server.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) - - It("returns Handshake() when handling a message fails", func() { - sErrChan := make(chan error, 1) - _, sInitialStream, sHandshakeStream := initStreams() - runner := NewMockHandshakeRunner(mockCtrl) - runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) - var token protocol.StatelessResetToken - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{StatelessResetToken: &token}, - runner, - serverConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.Version1, - ) + Expect(server.StartHandshake()).To(Succeed()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - server.RunHandshake() - var err error - Expect(sErrChan).To(Receive(&err)) - Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) - Expect(err.(*qerr.TransportError).ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage))) - close(done) - }() - - fakeCH := append([]byte{byte(typeServerHello), 0, 0, 6}, []byte("foobar")...) - server.HandleMessage(fakeCH, protocol.EncryptionInitial) // wrong encryption level - Eventually(done).Should(BeClosed()) - }) - - It("returns Handshake() when it is closed", func() { - _, sInitialStream, sHandshakeStream := initStreams() - var token protocol.StatelessResetToken - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{StatelessResetToken: &token}, - NewMockHandshakeRunner(mockCtrl), - serverConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.Version1, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - server.RunHandshake() - close(done) - }() - Expect(server.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) + fakeCH := append([]byte{typeClientHello, 0, 0, 6}, []byte("foobar")...) + // wrong encryption level + err := server.HandleMessage(fakeCH, protocol.EncryptionHandshake) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level")) }) Context("doing the handshake", func() { @@ -297,55 +167,32 @@ var _ = Describe("Crypto Setup TLS", func() { return rttStats } - handshake := func(client CryptoSetup, cChunkChan <-chan chunk, - server CryptoSetup, sChunkChan <-chan chunk, - ) { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - for { - select { - case c := <-cChunkChan: - msgType := messageType(c.data[0]) - finished := server.HandleMessage(c.data, c.encLevel) - if msgType == typeFinished { - Expect(finished).To(BeTrue()) - } else if msgType == typeClientHello { - // If this ClientHello didn't elicit a HelloRetryRequest, we're done with Initial keys. - _, err := server.GetHandshakeOpener() - Expect(finished).To(Equal(err == nil)) - } else { - Expect(finished).To(BeFalse()) - } - case c := <-sChunkChan: - msgType := messageType(c.data[0]) - finished := client.HandleMessage(c.data, c.encLevel) - if msgType == typeFinished { - Expect(finished).To(BeTrue()) - } else if msgType == typeServerHello { - Expect(finished).To(Equal(!bytes.Equal(c.data[6:6+32], helloRetryRequestRandom))) - } else { - Expect(finished).To(BeFalse()) - } - case <-done: // handshake complete - return - } - } - }() + handshake := func(client CryptoSetup, cChunkChan <-chan chunk, server CryptoSetup, sChunkChan <-chan chunk) { + Expect(client.StartHandshake()).To(Succeed()) + Expect(server.StartHandshake()).To(Succeed()) - go func() { - defer GinkgoRecover() - defer close(done) - server.RunHandshake() - ticket, err := server.GetSessionTicket() - Expect(err).ToNot(HaveOccurred()) - if ticket != nil { - client.HandleMessage(ticket, protocol.Encryption1RTT) + for { + select { + case c := <-cChunkChan: + Expect(server.HandleMessage(c.data, c.encLevel)).To(Succeed()) + continue + default: } - }() + select { + case c := <-sChunkChan: + Expect(client.HandleMessage(c.data, c.encLevel)).To(Succeed()) + continue + default: + } + // no more messages to send from client and server. Handshake complete? + break + } - client.RunHandshake() - Eventually(done).Should(BeClosed()) + ticket, err := server.GetSessionTicket() + Expect(err).ToNot(HaveOccurred()) + if ticket != nil { + Expect(client.HandleMessage(ticket, protocol.Encryption1RTT)).To(Succeed()) + } } handshakeWithTLSConf := func( @@ -359,15 +206,14 @@ var _ = Describe("Crypto Setup TLS", func() { cErrChan := make(chan error, 1) cRunner := NewMockHandshakeRunner(mockCtrl) cRunner.EXPECT().OnReceivedParams(gomock.Any()) - cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1) + cRunner.EXPECT().OnReceivedReadKeys().MinTimes(2).MaxTimes(3) // 3 if using 0-RTT, 2 otherwise cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1) cRunner.EXPECT().DropKeys(gomock.Any()).MaxTimes(1) client, clientHelloWrittenChan := NewCryptoSetupClient( cInitialStream, cHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, clientTransportParameters, cRunner, clientConf, @@ -383,7 +229,7 @@ var _ = Describe("Crypto Setup TLS", func() { sErrChan := make(chan error, 1) sRunner := NewMockHandshakeRunner(mockCtrl) sRunner.EXPECT().OnReceivedParams(gomock.Any()) - sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1) + sRunner.EXPECT().OnReceivedReadKeys().MinTimes(2).MaxTimes(3) // 3 if using 0-RTT, 2 otherwise sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1) if serverTransportParameters.StatelessResetToken == nil { var token protocol.StatelessResetToken @@ -392,9 +238,8 @@ var _ = Describe("Crypto Setup TLS", func() { server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, serverTransportParameters, sRunner, serverConf, @@ -462,9 +307,8 @@ var _ = Describe("Crypto Setup TLS", func() { client, chChan := NewCryptoSetupClient( cInitialStream, cHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, &wire.TransportParameters{}, runner, &tls.Config{InsecureSkipVerify: true}, @@ -475,24 +319,15 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.Version1, ) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - client.RunHandshake() - close(done) - }() + Expect(client.StartHandshake()).To(Succeed()) var ch chunk Eventually(cChunkChan).Should(Receive(&ch)) Eventually(chChan).Should(Receive(BeNil())) // make sure the whole ClientHello was written Expect(len(ch.data)).To(BeNumerically(">=", 4)) - Expect(messageType(ch.data[0])).To(Equal(typeClientHello)) + Expect(ch.data[0]).To(BeEquivalentTo(typeClientHello)) length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3]) Expect(len(ch.data) - 4).To(Equal(length)) - - // make the go routine return - Expect(client.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) }) It("receives transport parameters", func() { @@ -500,14 +335,14 @@ var _ = Describe("Crypto Setup TLS", func() { cChunkChan, cInitialStream, cHandshakeStream := initStreams() cTransportParameters := &wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 0x42 * time.Second} cRunner := NewMockHandshakeRunner(mockCtrl) + cRunner.EXPECT().OnReceivedReadKeys().Times(2) cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { sTransportParametersRcvd = tp }) cRunner.EXPECT().OnHandshakeComplete() client, _ := NewCryptoSetupClient( cInitialStream, cHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, cTransportParameters, cRunner, clientConf, @@ -521,6 +356,7 @@ var _ = Describe("Crypto Setup TLS", func() { sChunkChan, sInitialStream, sHandshakeStream := initStreams() var token protocol.StatelessResetToken sRunner := NewMockHandshakeRunner(mockCtrl) + sRunner.EXPECT().OnReceivedReadKeys().Times(2) sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { cTransportParametersRcvd = tp }) sRunner.EXPECT().OnHandshakeComplete() sTransportParameters := &wire.TransportParameters{ @@ -531,9 +367,8 @@ var _ = Describe("Crypto Setup TLS", func() { server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, sTransportParameters, sRunner, serverConf, @@ -561,13 +396,13 @@ var _ = Describe("Crypto Setup TLS", func() { cChunkChan, cInitialStream, cHandshakeStream := initStreams() cRunner := NewMockHandshakeRunner(mockCtrl) cRunner.EXPECT().OnReceivedParams(gomock.Any()) + cRunner.EXPECT().OnReceivedReadKeys().Times(2) cRunner.EXPECT().OnHandshakeComplete() client, _ := NewCryptoSetupClient( cInitialStream, cHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, cRunner, clientConf, @@ -581,14 +416,14 @@ var _ = Describe("Crypto Setup TLS", func() { sChunkChan, sInitialStream, sHandshakeStream := initStreams() sRunner := NewMockHandshakeRunner(mockCtrl) sRunner.EXPECT().OnReceivedParams(gomock.Any()) + sRunner.EXPECT().OnReceivedReadKeys().Times(2) sRunner.EXPECT().OnHandshakeComplete() var token protocol.StatelessResetToken server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, &wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token}, sRunner, serverConf, @@ -608,25 +443,23 @@ var _ = Describe("Crypto Setup TLS", func() { Eventually(done).Should(BeClosed()) // inject an invalid session ticket - cRunner.EXPECT().OnError(&qerr.TransportError{ - ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage), - ErrorMessage: "expected handshake message NewSessionTicket to have encryption level 1-RTT, has Handshake", - }) b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) - client.HandleMessage(b, protocol.EncryptionHandshake) + err := client.HandleMessage(b, protocol.EncryptionHandshake) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level")) }) It("errors when handling the NewSessionTicket fails", func() { cChunkChan, cInitialStream, cHandshakeStream := initStreams() cRunner := NewMockHandshakeRunner(mockCtrl) cRunner.EXPECT().OnReceivedParams(gomock.Any()) + cRunner.EXPECT().OnReceivedReadKeys().Times(2) cRunner.EXPECT().OnHandshakeComplete() client, _ := NewCryptoSetupClient( cInitialStream, cHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, cRunner, clientConf, @@ -640,14 +473,14 @@ var _ = Describe("Crypto Setup TLS", func() { sChunkChan, sInitialStream, sHandshakeStream := initStreams() sRunner := NewMockHandshakeRunner(mockCtrl) sRunner.EXPECT().OnReceivedParams(gomock.Any()) + sRunner.EXPECT().OnReceivedReadKeys().Times(2) sRunner.EXPECT().OnHandshakeComplete() var token protocol.StatelessResetToken server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, &wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token}, sRunner, serverConf, @@ -667,12 +500,10 @@ var _ = Describe("Crypto Setup TLS", func() { Eventually(done).Should(BeClosed()) // inject an invalid session ticket - cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) { - Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) - Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue()) - }) b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) - client.HandleMessage(b, protocol.Encryption1RTT) + err := client.HandleMessage(b, protocol.Encryption1RTT) + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue()) }) It("uses session resumption", func() { @@ -785,7 +616,6 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(clientHelloWrittenChan).To(Receive(BeNil())) csc.EXPECT().Get(gomock.Any()).Return(state, true) - csc.EXPECT().Put(gomock.Any(), nil) csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) clientRTTStats := &utils.RTTStats{} @@ -840,7 +670,6 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(clientHelloWrittenChan).To(Receive(BeNil())) csc.EXPECT().Get(gomock.Any()).Return(state, true) - csc.EXPECT().Put(gomock.Any(), nil) csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) clientRTTStats := &utils.RTTStats{} diff --git a/internal/handshake/handshake_suite_test.go b/internal/handshake/handshake_suite_test.go index a68f2b9589a..3289928ea07 100644 --- a/internal/handshake/handshake_suite_test.go +++ b/internal/handshake/handshake_suite_test.go @@ -6,8 +6,6 @@ import ( "strings" "testing" - "github.com/quic-go/quic-go/internal/qtls" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" @@ -41,8 +39,8 @@ func splitHexString(s string) (slice []byte) { return } -var cipherSuites = []*qtls.CipherSuiteTLS13{ - qtls.CipherSuiteTLS13ByID(tls.TLS_AES_128_GCM_SHA256), - qtls.CipherSuiteTLS13ByID(tls.TLS_AES_256_GCM_SHA384), - qtls.CipherSuiteTLS13ByID(tls.TLS_CHACHA20_POLY1305_SHA256), +var cipherSuites = []*cipherSuite{ + getCipherSuite(tls.TLS_AES_128_GCM_SHA256), + getCipherSuite(tls.TLS_AES_256_GCM_SHA384), + getCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256), } diff --git a/internal/handshake/header_protector.go b/internal/handshake/header_protector.go index f2c9c634f52..fb6092e040a 100644 --- a/internal/handshake/header_protector.go +++ b/internal/handshake/header_protector.go @@ -32,7 +32,7 @@ func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader b case tls.TLS_CHACHA20_POLY1305_SHA256: return newChaChaHeaderProtector(suite, trafficSecret, isLongHeader, hkdfLabel) default: - panic(fmt.Sprintf("Invalid cipher suite id: %d", suite)) + panic(fmt.Sprintf("Invalid cipher suite id: %d", suite.ID)) } } diff --git a/internal/protocol/encryption_level.go b/internal/protocol/encryption_level.go index 40aa331aaa6..32d38ab1e88 100644 --- a/internal/protocol/encryption_level.go +++ b/internal/protocol/encryption_level.go @@ -1,10 +1,5 @@ package protocol -import ( - "crypto/tls" - "fmt" -) - // EncryptionLevel is the encryption level // Default value is Unencrypted type EncryptionLevel uint8 @@ -33,33 +28,3 @@ func (e EncryptionLevel) String() string { } return "unknown" } - -func (e EncryptionLevel) ToTLSEncryptionLevel() tls.QUICEncryptionLevel { - switch e { - case EncryptionInitial: - return tls.QUICEncryptionLevelInitial - case EncryptionHandshake: - return tls.QUICEncryptionLevelHandshake - case Encryption1RTT: - return tls.QUICEncryptionLevelApplication - case Encryption0RTT: - return tls.QUICEncryptionLevelEarly - default: - panic(fmt.Sprintf("unexpected encryption level: %s", e)) - } -} - -func FromTLSEncryptionLevel(e tls.QUICEncryptionLevel) EncryptionLevel { - switch e { - case tls.QUICEncryptionLevelInitial: - return EncryptionInitial - case tls.QUICEncryptionLevelHandshake: - return EncryptionHandshake - case tls.QUICEncryptionLevelApplication: - return Encryption1RTT - case tls.QUICEncryptionLevelEarly: - return Encryption0RTT - default: - panic(fmt.Sprintf("unexpect encryption level: %s", e)) - } -} diff --git a/internal/qerr/error_codes.go b/internal/qerr/error_codes.go index 00361308e70..a037acd22e6 100644 --- a/internal/qerr/error_codes.go +++ b/internal/qerr/error_codes.go @@ -1,8 +1,9 @@ package qerr import ( - "crypto/tls" "fmt" + + "github.com/quic-go/quic-go/internal/qtls" ) // TransportErrorCode is a QUIC transport error. @@ -39,7 +40,7 @@ func (e TransportErrorCode) Message() string { if !e.IsCryptoError() { return "" } - return tls.AlertError(e - 0x100).Error() + return qtls.AlertError(e - 0x100).Error() } func (e TransportErrorCode) String() string { diff --git a/internal/handshake/cipher_suite_unsafe.go b/internal/qtls/cipher_suite_go121.go similarity index 98% rename from internal/handshake/cipher_suite_unsafe.go rename to internal/qtls/cipher_suite_go121.go index bd08359f8e8..aa8c768fd25 100644 --- a/internal/handshake/cipher_suite_unsafe.go +++ b/internal/qtls/cipher_suite_go121.go @@ -1,4 +1,6 @@ -package handshake +//go:build go1.21 + +package qtls import ( "crypto" diff --git a/internal/handshake/client_session_cache.go b/internal/qtls/client_session_cache.go similarity index 97% rename from internal/handshake/client_session_cache.go rename to internal/qtls/client_session_cache.go index 6636a3a11aa..b98dd4f038d 100644 --- a/internal/handshake/client_session_cache.go +++ b/internal/qtls/client_session_cache.go @@ -1,4 +1,6 @@ -package handshake +//go:build go1.21 + +package qtls import ( "crypto/tls" diff --git a/internal/qtls/go119.go b/internal/qtls/go119.go new file mode 100644 index 00000000000..f040b859c6e --- /dev/null +++ b/internal/qtls/go119.go @@ -0,0 +1,145 @@ +//go:build go1.19 && !go1.20 + +package qtls + +import ( + "crypto" + "crypto/cipher" + "crypto/tls" + "fmt" + "net" + "unsafe" + + "github.com/quic-go/qtls-go1-19" +) + +type ( + // Alert is a TLS alert + Alert = qtls.Alert + // A Certificate is qtls.Certificate. + Certificate = qtls.Certificate + // CertificateRequestInfo contains information about a certificate request. + CertificateRequestInfo = qtls.CertificateRequestInfo + // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 + CipherSuiteTLS13 = qtls.CipherSuiteTLS13 + // ClientHelloInfo contains information about a ClientHello. + ClientHelloInfo = qtls.ClientHelloInfo + // ClientSessionCache is a cache used for session resumption. + ClientSessionCache = qtls.ClientSessionCache + // ClientSessionState is a state needed for session resumption. + ClientSessionState = qtls.ClientSessionState + // A Config is a qtls.Config. + Config = qtls.Config + // A Conn is a qtls.Conn. + Conn = qtls.Conn + // ConnectionState contains information about the state of the connection. + ConnectionState = qtls.ConnectionStateWith0RTT + // EncryptionLevel is the encryption level of a message. + EncryptionLevel = qtls.EncryptionLevel + // Extension is a TLS extension + Extension = qtls.Extension + // ExtraConfig is the qtls.ExtraConfig + ExtraConfig = qtls.ExtraConfig + // RecordLayer is a qtls RecordLayer. + RecordLayer = qtls.RecordLayer +) + +const ( + // EncryptionHandshake is the Handshake encryption level + EncryptionHandshake = qtls.EncryptionHandshake + // Encryption0RTT is the 0-RTT encryption level + Encryption0RTT = qtls.Encryption0RTT + // EncryptionApplication is the application data encryption level + EncryptionApplication = qtls.EncryptionApplication +) + +// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 +func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { + return qtls.AEADAESGCMTLS13(key, fixedNonce) +} + +// Client returns a new TLS client side connection. +func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Client(conn, config, extraConfig) +} + +// Server returns a new TLS server side connection. +func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { + return qtls.Server(conn, config, extraConfig) +} + +func GetConnectionState(conn *Conn) ConnectionState { + return conn.ConnectionStateWith0RTT() +} + +// ToTLSConnectionState extracts the tls.ConnectionState +func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { + return cs.ConnectionState +} + +type cipherSuiteTLS13 struct { + ID uint16 + KeyLen int + AEAD func(key, fixedNonce []byte) cipher.AEAD + Hash crypto.Hash +} + +//go:linkname cipherSuiteTLS13ByID github.com/quic-go/qtls-go1-19.cipherSuiteTLS13ByID +func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 + +// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. +func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { + val := cipherSuiteTLS13ByID(id) + cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) + return &qtls.CipherSuiteTLS13{ + ID: cs.ID, + KeyLen: cs.KeyLen, + AEAD: cs.AEAD, + Hash: cs.Hash, + } +} + +//go:linkname cipherSuitesTLS13 github.com/quic-go/qtls-go1-19.cipherSuitesTLS13 +var cipherSuitesTLS13 []unsafe.Pointer + +//go:linkname defaultCipherSuitesTLS13 github.com/quic-go/qtls-go1-19.defaultCipherSuitesTLS13 +var defaultCipherSuitesTLS13 []uint16 + +//go:linkname defaultCipherSuitesTLS13NoAES github.com/quic-go/qtls-go1-19.defaultCipherSuitesTLS13NoAES +var defaultCipherSuitesTLS13NoAES []uint16 + +var cipherSuitesModified bool + +// SetCipherSuite modifies the cipherSuiteTLS13 slice of cipher suites inside qtls +// such that it only contains the cipher suite with the chosen id. +// The reset function returned resets them back to the original value. +func SetCipherSuite(id uint16) (reset func()) { + if cipherSuitesModified { + panic("cipher suites modified multiple times without resetting") + } + cipherSuitesModified = true + + origCipherSuitesTLS13 := append([]unsafe.Pointer{}, cipherSuitesTLS13...) + origDefaultCipherSuitesTLS13 := append([]uint16{}, defaultCipherSuitesTLS13...) + origDefaultCipherSuitesTLS13NoAES := append([]uint16{}, defaultCipherSuitesTLS13NoAES...) + // The order is given by the order of the slice elements in cipherSuitesTLS13 in qtls. + switch id { + case tls.TLS_AES_128_GCM_SHA256: + cipherSuitesTLS13 = cipherSuitesTLS13[:1] + case tls.TLS_CHACHA20_POLY1305_SHA256: + cipherSuitesTLS13 = cipherSuitesTLS13[1:2] + case tls.TLS_AES_256_GCM_SHA384: + cipherSuitesTLS13 = cipherSuitesTLS13[2:] + default: + panic(fmt.Sprintf("unexpected cipher suite: %d", id)) + } + defaultCipherSuitesTLS13 = []uint16{id} + defaultCipherSuitesTLS13NoAES = []uint16{id} + + return func() { + cipherSuitesTLS13 = origCipherSuitesTLS13 + defaultCipherSuitesTLS13 = origDefaultCipherSuitesTLS13 + defaultCipherSuitesTLS13NoAES = origDefaultCipherSuitesTLS13NoAES + cipherSuitesModified = false + } +} diff --git a/internal/qtls/go120.go b/internal/qtls/go120.go new file mode 100644 index 00000000000..494d8d29b8e --- /dev/null +++ b/internal/qtls/go120.go @@ -0,0 +1,147 @@ +//go:build go1.20 && !go1.21 + +package qtls + +import ( + "crypto" + "crypto/cipher" + "crypto/tls" + "fmt" + "unsafe" + + "github.com/quic-go/quic-go/internal/protocol" + + "github.com/quic-go/qtls-go1-20" +) + +type ( + QUICConn = qtls.QUICConn + QUICConfig = qtls.QUICConfig + QUICEvent = qtls.QUICEvent + QUICEventKind = qtls.QUICEventKind + QUICEncryptionLevel = qtls.QUICEncryptionLevel + AlertError = qtls.AlertError +) + +const ( + QUICEncryptionLevelInitial = qtls.QUICEncryptionLevelInitial + QUICEncryptionLevelEarly = qtls.QUICEncryptionLevelEarly + QUICEncryptionLevelHandshake = qtls.QUICEncryptionLevelHandshake + QUICEncryptionLevelApplication = qtls.QUICEncryptionLevelApplication +) + +const ( + QUICNoEvent = qtls.QUICNoEvent + QUICSetReadSecret = qtls.QUICSetReadSecret + QUICSetWriteSecret = qtls.QUICSetWriteSecret + QUICWriteData = qtls.QUICWriteData + QUICTransportParameters = qtls.QUICTransportParameters + QUICTransportParametersRequired = qtls.QUICTransportParametersRequired + QUICRejectedEarlyData = qtls.QUICRejectedEarlyData + QUICHandshakeDone = qtls.QUICHandshakeDone +) + +func SetupConfigForServer(conf *QUICConfig, enable0RTT bool, getDataForSessionTicket func() []byte, accept0RTT func([]byte) bool) { + conf.ExtraConfig = &qtls.ExtraConfig{ + Enable0RTT: enable0RTT, + Accept0RTT: accept0RTT, + GetAppDataForSessionTicket: getDataForSessionTicket, + } +} + +func SetupConfigForClient(conf *QUICConfig, getDataForSessionState func() []byte, setDataFromSessionState func([]byte)) { + conf.ExtraConfig = &qtls.ExtraConfig{ + GetAppDataForSessionState: getDataForSessionState, + SetAppDataFromSessionState: setDataFromSessionState, + } +} + +func QUICServer(config *QUICConfig) *QUICConn { + return qtls.QUICServer(config) +} + +func QUICClient(config *QUICConfig) *QUICConn { + return qtls.QUICClient(config) +} + +func ToTLSEncryptionLevel(e protocol.EncryptionLevel) qtls.QUICEncryptionLevel { + switch e { + case protocol.EncryptionInitial: + return qtls.QUICEncryptionLevelInitial + case protocol.EncryptionHandshake: + return qtls.QUICEncryptionLevelHandshake + case protocol.Encryption1RTT: + return qtls.QUICEncryptionLevelApplication + case protocol.Encryption0RTT: + return qtls.QUICEncryptionLevelEarly + default: + panic(fmt.Sprintf("unexpected encryption level: %s", e)) + } +} + +func FromTLSEncryptionLevel(e qtls.QUICEncryptionLevel) protocol.EncryptionLevel { + switch e { + case qtls.QUICEncryptionLevelInitial: + return protocol.EncryptionInitial + case qtls.QUICEncryptionLevelHandshake: + return protocol.EncryptionHandshake + case qtls.QUICEncryptionLevelApplication: + return protocol.Encryption1RTT + case qtls.QUICEncryptionLevelEarly: + return protocol.Encryption0RTT + default: + panic(fmt.Sprintf("unexpect encryption level: %s", e)) + } +} + +type cipherSuiteTLS13 struct { + ID uint16 + KeyLen int + AEAD func(key, fixedNonce []byte) cipher.AEAD + Hash crypto.Hash +} + +//go:linkname cipherSuitesTLS13 github.com/quic-go/qtls-go1-20.cipherSuitesTLS13 +var cipherSuitesTLS13 []unsafe.Pointer + +//go:linkname defaultCipherSuitesTLS13 github.com/quic-go/qtls-go1-20.defaultCipherSuitesTLS13 +var defaultCipherSuitesTLS13 []uint16 + +//go:linkname defaultCipherSuitesTLS13NoAES github.com/quic-go/qtls-go1-20.defaultCipherSuitesTLS13NoAES +var defaultCipherSuitesTLS13NoAES []uint16 + +var cipherSuitesModified bool + +// SetCipherSuite modifies the cipherSuiteTLS13 slice of cipher suites inside qtls +// such that it only contains the cipher suite with the chosen id. +// The reset function returned resets them back to the original value. +func SetCipherSuite(id uint16) (reset func()) { + if cipherSuitesModified { + panic("cipher suites modified multiple times without resetting") + } + cipherSuitesModified = true + + origCipherSuitesTLS13 := append([]unsafe.Pointer{}, cipherSuitesTLS13...) + origDefaultCipherSuitesTLS13 := append([]uint16{}, defaultCipherSuitesTLS13...) + origDefaultCipherSuitesTLS13NoAES := append([]uint16{}, defaultCipherSuitesTLS13NoAES...) + // The order is given by the order of the slice elements in cipherSuitesTLS13 in qtls. + switch id { + case tls.TLS_AES_128_GCM_SHA256: + cipherSuitesTLS13 = cipherSuitesTLS13[:1] + case tls.TLS_CHACHA20_POLY1305_SHA256: + cipherSuitesTLS13 = cipherSuitesTLS13[1:2] + case tls.TLS_AES_256_GCM_SHA384: + cipherSuitesTLS13 = cipherSuitesTLS13[2:] + default: + panic(fmt.Sprintf("unexpected cipher suite: %d", id)) + } + defaultCipherSuitesTLS13 = []uint16{id} + defaultCipherSuitesTLS13NoAES = []uint16{id} + + return func() { + cipherSuitesTLS13 = origCipherSuitesTLS13 + defaultCipherSuitesTLS13 = origDefaultCipherSuitesTLS13 + defaultCipherSuitesTLS13NoAES = origDefaultCipherSuitesTLS13NoAES + cipherSuitesModified = false + } +} diff --git a/internal/qtls/go121.go b/internal/qtls/go121.go new file mode 100644 index 00000000000..b63ef88596d --- /dev/null +++ b/internal/qtls/go121.go @@ -0,0 +1,142 @@ +//go:build go1.21 + +package qtls + +import ( + "crypto/tls" + "fmt" + + "github.com/quic-go/quic-go/internal/protocol" +) + +type ( + QUICConn = tls.QUICConn + QUICConfig = tls.QUICConfig + QUICEvent = tls.QUICEvent + QUICEventKind = tls.QUICEventKind + QUICEncryptionLevel = tls.QUICEncryptionLevel + AlertError = tls.AlertError +) + +const ( + QUICEncryptionLevelInitial = tls.QUICEncryptionLevelInitial + QUICEncryptionLevelEarly = tls.QUICEncryptionLevelEarly + QUICEncryptionLevelHandshake = tls.QUICEncryptionLevelHandshake + QUICEncryptionLevelApplication = tls.QUICEncryptionLevelApplication +) + +const ( + QUICNoEvent = tls.QUICNoEvent + QUICSetReadSecret = tls.QUICSetReadSecret + QUICSetWriteSecret = tls.QUICSetWriteSecret + QUICWriteData = tls.QUICWriteData + QUICTransportParameters = tls.QUICTransportParameters + QUICTransportParametersRequired = tls.QUICTransportParametersRequired + QUICRejectedEarlyData = tls.QUICRejectedEarlyData + QUICHandshakeDone = tls.QUICHandshakeDone +) + +// ExtraConfig is not used in this Go version +type ExtraConfig struct{} + +func QUICServer(config *QUICConfig) *QUICConn { return tls.QUICServer(config) } +func QUICClient(config *QUICConfig) *QUICConn { return tls.QUICClient(config) } + +// InitSessionTickets is a hack to initialize the session ticket keys +// TODO: find a better way. +func InitSessionTickets(conf *tls.Config) { + conf.DecryptTicket([]byte("foobar"), tls.ConnectionState{}) +} + +func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, accept0RTT func([]byte) bool) { + conf := qconf.TLSConfig + InitSessionTickets(conf) + conf = conf.Clone() + conf.MinVersion = tls.VersionTLS13 + qconf.TLSConfig = conf + + // add callbacks to save transport parameters into the session ticket + origWrapSession := conf.WrapSession + conf.WrapSession = func(cs tls.ConnectionState, state *tls.SessionState) ([]byte, error) { + // Add QUIC transport parameters if this is a 0-RTT packet. + // TODO(#3853): also save the RTT for non-0-RTT tickets + if state.EarlyData { + // At this point, crypto/tls has just called the WrapSession callback. + // state.Extra is guaranteed to be empty. + state.Extra = getData() + } + if origWrapSession != nil { + return origWrapSession(cs, state) + } + b, err := conf.EncryptTicket(cs, state) + return b, err + } + origUnwrapSession := conf.UnwrapSession + // UnwrapSession might be called multiple times, as the client can use multiple session tickets. + // However, using 0-RTT is only possible with the first session ticket. + // crypto/tls guarantees that this callback is called in the same order as the session ticket in the ClientHello. + var unwrapCount int + conf.UnwrapSession = func(identity []byte, connState tls.ConnectionState) (*tls.SessionState, error) { + unwrapCount++ + var state *tls.SessionState + var err error + if origUnwrapSession != nil { + state, err = origUnwrapSession(identity, connState) + } else { + state, err = conf.DecryptTicket(identity, connState) + } + if err != nil || state == nil { + return nil, err + } + if state.EarlyData { + if unwrapCount == 1 { // first session ticket + state.EarlyData = accept0RTT(state.Extra) + } else { // subsequent session ticket, can't be used for 0-RTT + state.EarlyData = false + } + } + return state, nil + } +} + +func SetupConfigForClient(qconf *QUICConfig, getData func() []byte, setData func([]byte)) { + conf := qconf.TLSConfig + if conf.ClientSessionCache != nil { + origCache := conf.ClientSessionCache + conf.ClientSessionCache = &clientSessionCache{ + wrapped: origCache, + getData: getData, + setData: setData, + } + } +} + +func ToTLSEncryptionLevel(e protocol.EncryptionLevel) tls.QUICEncryptionLevel { + switch e { + case protocol.EncryptionInitial: + return tls.QUICEncryptionLevelInitial + case protocol.EncryptionHandshake: + return tls.QUICEncryptionLevelHandshake + case protocol.Encryption1RTT: + return tls.QUICEncryptionLevelApplication + case protocol.Encryption0RTT: + return tls.QUICEncryptionLevelEarly + default: + panic(fmt.Sprintf("unexpected encryption level: %s", e)) + } +} + +func FromTLSEncryptionLevel(e tls.QUICEncryptionLevel) protocol.EncryptionLevel { + switch e { + case tls.QUICEncryptionLevelInitial: + return protocol.EncryptionInitial + case tls.QUICEncryptionLevelHandshake: + return protocol.EncryptionHandshake + case tls.QUICEncryptionLevelApplication: + return protocol.Encryption1RTT + case tls.QUICEncryptionLevelEarly: + return protocol.Encryption0RTT + default: + panic(fmt.Sprintf("unexpect encryption level: %s", e)) + } +} diff --git a/internal/qtls/go_oldversion.go b/internal/qtls/go_oldversion.go new file mode 100644 index 00000000000..e15f03629a6 --- /dev/null +++ b/internal/qtls/go_oldversion.go @@ -0,0 +1,5 @@ +//go:build !go1.19 + +package qtls + +var _ int = "The version of quic-go you're using can't be built using outdated Go versions. For more details, please see https://github.com/quic-go/quic-go/wiki/quic-go-and-Go-versions." diff --git a/internal/qtls/qtls_suite_test.go b/internal/qtls/qtls_suite_test.go new file mode 100644 index 00000000000..e8ce652a1b9 --- /dev/null +++ b/internal/qtls/qtls_suite_test.go @@ -0,0 +1,25 @@ +package qtls + +import ( + "testing" + + gomock "github.com/golang/mock/gomock" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestQTLS(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "qtls Suite") +} + +var mockCtrl *gomock.Controller + +var _ = BeforeEach(func() { + mockCtrl = gomock.NewController(GinkgoT()) +}) + +var _ = AfterEach(func() { + mockCtrl.Finish() +}) diff --git a/internal/testdata/cert.go b/internal/testdata/cert.go index 6cc7a091e51..f77a7b2ddbe 100644 --- a/internal/testdata/cert.go +++ b/internal/testdata/cert.go @@ -31,6 +31,7 @@ func GetTLSConfig() *tls.Config { panic(err) } return &tls.Config{ + MinVersion: tls.VersionTLS13, Certificates: []tls.Certificate{cert}, } }