diff --git a/http3/client.go b/http3/client.go index 27c4947b1d3..43d65b327e4 100644 --- a/http3/client.go +++ b/http3/client.go @@ -44,6 +44,7 @@ type roundTripperOpts struct { MaxHeaderBytes int64 AdditionalSettings map[uint64]uint64 StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error) + UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool) } // client is a HTTP3 client doing requests @@ -174,7 +175,7 @@ func (c *client) handleUnidirectionalStreams() { return } - go func() { + go func(str quic.ReceiveStream) { streamType, err := quicvarint.Read(quicvarint.NewReader(str)) if err != nil { c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err) @@ -192,6 +193,9 @@ func (c *client) handleUnidirectionalStreams() { c.conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "") return default: + if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str) { + return + } str.CancelRead(quic.StreamErrorCode(errorStreamCreationError)) return } @@ -214,7 +218,7 @@ func (c *client) handleUnidirectionalStreams() { if c.opts.EnableDatagram && !c.conn.ConnectionState().SupportsDatagrams { c.conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support") } - }() + }(str) } } diff --git a/http3/client_test.go b/http3/client_test.go index b027d8ef522..b13993f5697 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -185,6 +185,89 @@ var _ = Describe("Client", func() { }) }) + Context("hijacking unidirectional streams", func() { + var ( + request *http.Request + conn *mockquic.MockEarlyConnection + settingsFrameWritten chan struct{} + ) + testDone := make(chan struct{}) + + BeforeEach(func() { + testDone = make(chan struct{}) + settingsFrameWritten = make(chan struct{}) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) { + defer GinkgoRecover() + close(settingsFrameWritten) + }) + conn = mockquic.NewMockEarlyConnection(mockCtrl) + conn.EXPECT().OpenUniStream().Return(controlStr, nil) + conn.EXPECT().HandshakeComplete().Return(handshakeCtx) + conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) + dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) { + return conn, nil + } + var err error + request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + testDone <- struct{}{} + Eventually(settingsFrameWritten).Should(BeClosed()) + }) + + It("hijacks an unidirectional stream of unknown stream type", func() { + streamTypeChan := make(chan StreamType, 1) + client.opts.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool { + streamTypeChan <- st + return true + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x54) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return unknownStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + _, err := client.RoundTrip(request) + Expect(err).To(MatchError("done")) + Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + + It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { + streamTypeChan := make(chan StreamType, 1) + client.opts.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool { + streamTypeChan <- st + return false + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x54) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError)) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return unknownStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + _, err := client.RoundTrip(request) + Expect(err).To(MatchError("done")) + Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + }) + Context("control stream handling", func() { var ( request *http.Request diff --git a/http3/roundtrip.go b/http3/roundtrip.go index 743ff0341af..6ba251cbb43 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -58,6 +58,9 @@ type RoundTripper struct { // Alternatively, callers can take over the QUIC stream (by returning hijacked true). StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error) + // When set, this callback is called for unknown unidirectional stream of unknown stream type. + UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool) + // Dial specifies an optional dial function for creating QUIC // connections for requests. // If Dial is nil, quic.DialAddrEarlyContext will be used. @@ -154,6 +157,7 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr DisableCompression: r.DisableCompression, MaxHeaderBytes: r.MaxResponseHeaderBytes, StreamHijacker: r.StreamHijacker, + UniStreamHijacker: r.UniStreamHijacker, }, r.QuicConfig, r.Dial, diff --git a/http3/server.go b/http3/server.go index e1d818acc67..3897bbb9ffb 100644 --- a/http3/server.go +++ b/http3/server.go @@ -33,6 +33,9 @@ const ( nextProtoH3 = "h3" ) +// StreamType is the stream type of a unidirectional stream. +type StreamType uint64 + const ( streamTypeControlStream = 0 streamTypePushStream = 1 @@ -151,6 +154,9 @@ type Server struct { // Alternatively, callers can take over the QUIC stream (by returning hijacked true). StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error) + // When set, this callback is called for unknown unidirectional stream of unknown stream type. + UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool) + mutex sync.RWMutex listeners map[*quic.EarlyListener]listenerInfo @@ -421,6 +427,9 @@ func (s *Server) handleUnidirectionalStreams(conn quic.EarlyConnection) { conn.CloseWithError(quic.ApplicationErrorCode(errorStreamCreationError), "") return default: + if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str) { + return + } str.CancelRead(quic.StreamErrorCode(errorStreamCreationError)) return } diff --git a/http3/server_test.go b/http3/server_test.go index fd6091df687..f64669722b9 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -238,6 +238,72 @@ var _ = Describe("Server", func() { Expect(serr.err).ToNot(HaveOccurred()) }) + Context("hijacking unidirectional streams", func() { + var conn *mockquic.MockEarlyConnection + testDone := make(chan struct{}) + + BeforeEach(func() { + testDone = make(chan struct{}) + conn = mockquic.NewMockEarlyConnection(mockCtrl) + controlStr := mockquic.NewMockStream(mockCtrl) + controlStr.EXPECT().Write(gomock.Any()) + conn.EXPECT().OpenUniStream().Return(controlStr, nil) + conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done")) + conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes() + conn.EXPECT().LocalAddr().AnyTimes() + }) + + AfterEach(func() { testDone <- struct{}{} }) + + It("hijacks an unidirectional stream of unknown stream type", func() { + streamTypeChan := make(chan StreamType, 1) + s.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool { + streamTypeChan <- st + return true + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x54) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return unknownStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + s.handleConn(conn) + Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + + It("cancels reading when hijacker didn't hijack an unidirectional stream", func() { + streamTypeChan := make(chan StreamType, 1) + s.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool { + streamTypeChan <- st + return false + } + + buf := &bytes.Buffer{} + quicvarint.Write(buf, 0x54) + unknownStr := mockquic.NewMockStream(mockCtrl) + unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError)) + + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + return unknownStr, nil + }) + conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + s.handleConn(conn) + Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54))) + time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError + }) + }) + Context("control stream handling", func() { var conn *mockquic.MockEarlyConnection testDone := make(chan struct{})