diff --git a/http3/client_test.go b/http3/client_test.go index 3cdf3a881d2..eb1383fbb7f 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -185,6 +185,107 @@ var _ = Describe("Client", func() { }) }) + Context("hijacking bidirectional streams", func() { + var ( + request *http.Request + conn *mockquic.MockEarlyConnection + settingsFrameWritten chan struct{} + ) + testDone := make(chan struct{}) + + BeforeEach(func() { + testDone = make(chan struct{}) + settingsFrameWritten = make(chan struct{}) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { + defer GinkgoRecover() + close(settingsFrameWritten) + }) + conn = mockquic.NewMockEarlyConnection(mockCtrl) + conn.EXPECT().OpenUniStream().Return(controlStr, nil) + conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) + conn.EXPECT().AcceptUniStream(gomock.Any()).Return(nil, errors.New("done")).AnyTimes() + dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { + return conn, nil + } + var err error + request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + testDone <- struct{}{} + Eventually(settingsFrameWritten).Should(BeClosed()) + }) + + It("hijacks a bidirectional stream of unknown frame type", func() { + frameTypeChan := make(chan FrameType, 1) + client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream) (hijacked bool, err error) { + frameTypeChan <- ft + return true, nil + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x41) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) + conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { + <-testDone + return nil, errors.New("test done") + }) + _, err := client.RoundTripOpt(request, RoundTripOpt{}) + Expect(err).To(MatchError("done")) + Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + + It("closes the connection when hijacker didn't hijack a bidirectional stream", func() { + frameTypeChan := make(chan FrameType, 1) + client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream) (hijacked bool, err error) { + frameTypeChan <- ft + return false, nil + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x41) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) + conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { + <-testDone + return nil, errors.New("test done") + }) + conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() + _, err := client.RoundTripOpt(request, RoundTripOpt{}) + Expect(err).To(MatchError("done")) + Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) + }) + + It("closes the connection when hijacker returned error", func() { + frameTypeChan := make(chan FrameType, 1) + client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream) (hijacked bool, err error) { + frameTypeChan <- ft + return false, errors.New("error in hijacker") + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x41) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) + conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) { + <-testDone + return nil, errors.New("test done") + }) + conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes() + _, err := client.RoundTripOpt(request, RoundTripOpt{}) + Expect(err).To(MatchError("done")) + Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) + }) + }) + Context("hijacking unidirectional streams", func() { var ( req *http.Request diff --git a/http3/server_test.go b/http3/server_test.go index b5e23f775aa..7567f7e50d6 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -261,6 +261,91 @@ var _ = Describe("Server", func() { Eventually(done).Should(BeClosed()) }) + Context("hijacking bidirectional streams", func() { + var conn *mockquic.MockEarlyConnection + testDone := make(chan struct{}) + + BeforeEach(func() { + testDone = make(chan struct{}) + conn = mockquic.NewMockEarlyConnection(mockCtrl) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Write(gomock.Any()) + conn.EXPECT().OpenUniStream().Return(controlStr, nil) + conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() + conn.EXPECT().LocalAddr().AnyTimes() + }) + + AfterEach(func() { testDone <- struct{}{} }) + + It("hijacks a bidirectional stream of unknown frame type", func() { + frameTypeChan := make(chan FrameType, 1) + s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream) (hijacked bool, err error) { + frameTypeChan <- ft + return true, nil + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x41) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) + conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + s.handleConn(conn) + Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + + It("cancels writing when hijacker didn't hijack a bidirectional stream", func() { + frameTypeChan := make(chan FrameType, 1) + s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream) (hijacked bool, err error) { + frameTypeChan <- ft + return false, nil + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x41) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestIncomplete)) + conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) + conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + s.handleConn(conn) + Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + + It("cancels writing when hijacker returned error", func() { + frameTypeChan := make(chan FrameType, 1) + s.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream) (hijacked bool, err error) { + frameTypeChan <- ft + return false, errors.New("error in hijacker") + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x41) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + unknownStr.EXPECT().CancelWrite(quic.StreamErrorCode(errorRequestIncomplete)) + conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil) + conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + s.handleConn(conn) + Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41))) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + }) + Context("hijacking unidirectional streams", func() { var conn *mockquic.MockEarlyConnection testDone := make(chan struct{})