Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement HTTP/3 unidirectional stream hijacking #3389

Merged
merged 3 commits into from Apr 21, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 6 additions & 2 deletions http3/client.go
Expand Up @@ -44,6 +44,7 @@ type roundTripperOpts struct {
MaxHeaderBytes int64
AdditionalSettings map[uint64]uint64
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)
}

// client is a HTTP3 client doing requests
Expand Down Expand Up @@ -174,7 +175,7 @@ func (c *client) handleUnidirectionalStreams() {
return
}

go func() {
go func(str quic.ReceiveStream) {
marten-seemann marked this conversation as resolved.
Show resolved Hide resolved
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
if err != nil {
c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
Expand All @@ -192,6 +193,9 @@ func (c *client) handleUnidirectionalStreams() {
c.conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "")
return
default:
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str) {
return
}
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
return
}
Expand All @@ -214,7 +218,7 @@ func (c *client) handleUnidirectionalStreams() {
if c.opts.EnableDatagram && !c.conn.ConnectionState().SupportsDatagrams {
c.conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support")
}
}()
}(str)
}
}

Expand Down
83 changes: 83 additions & 0 deletions http3/client_test.go
Expand Up @@ -185,6 +185,89 @@ var _ = Describe("Client", func() {
})
})

Context("hijacking unistreams", func() {
hareku marked this conversation as resolved.
Show resolved Hide resolved
var (
request *http.Request
conn *mockquic.MockEarlyConnection
settingsFrameWritten chan struct{}
)
testDone := make(chan struct{})

BeforeEach(func() {
testDone = make(chan struct{})
settingsFrameWritten = make(chan struct{})
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) {
defer GinkgoRecover()
close(settingsFrameWritten)
})
conn = mockquic.NewMockEarlyConnection(mockCtrl)
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
return conn, nil
}
var err error
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
Expect(err).ToNot(HaveOccurred())
})

AfterEach(func() {
testDone <- struct{}{}
Eventually(settingsFrameWritten).Should(BeClosed())
})

It("hijacks an unistream of unknown stream type", func() {
streamTypeChan := make(chan StreamType, 1)
client.opts.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
streamTypeChan <- st
return true
}

buf := &bytes.Buffer{}
quicvarint.Write(buf, 0x54)
unknownStr := mockquic.NewMockStream(mockCtrl)
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return unknownStr, nil
})
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTrip(request)
Expect(err).To(MatchError("done"))
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})

It("cancels reading when hijacker didn't hijack an unistream", func() {
streamTypeChan := make(chan StreamType, 1)
client.opts.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
streamTypeChan <- st
return false
}

buf := &bytes.Buffer{}
quicvarint.Write(buf, 0x54)
unknownStr := mockquic.NewMockStream(mockCtrl)
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError))
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return unknownStr, nil
})
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTrip(request)
Expect(err).To(MatchError("done"))
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
})

Context("control stream handling", func() {
var (
request *http.Request
Expand Down
4 changes: 4 additions & 0 deletions http3/roundtrip.go
Expand Up @@ -58,6 +58,9 @@ type RoundTripper struct {
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)

// When set, this callback is called for the first unknown stream type parsed on a unidirectional receive stream.
hareku marked this conversation as resolved.
Show resolved Hide resolved
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)

// Dial specifies an optional dial function for creating QUIC
// connections for requests.
// If Dial is nil, quic.DialAddrEarlyContext will be used.
Expand Down Expand Up @@ -154,6 +157,7 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr
DisableCompression: r.DisableCompression,
MaxHeaderBytes: r.MaxResponseHeaderBytes,
StreamHijacker: r.StreamHijacker,
UniStreamHijacker: r.UniStreamHijacker,
},
r.QuicConfig,
r.Dial,
Expand Down
8 changes: 8 additions & 0 deletions http3/server.go
Expand Up @@ -33,6 +33,8 @@ const (
nextProtoH3 = "h3"
)

type StreamType uint64
hareku marked this conversation as resolved.
Show resolved Hide resolved

const (
streamTypeControlStream = 0
streamTypePushStream = 1
Expand Down Expand Up @@ -151,6 +153,9 @@ type Server struct {
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)

// When set, this callback is called for the first unknown stream type parsed on a unidirectional receive stream.
hareku marked this conversation as resolved.
Show resolved Hide resolved
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)

mutex sync.RWMutex
listeners map[*quic.EarlyListener]listenerInfo

Expand Down Expand Up @@ -421,6 +426,9 @@ func (s *Server) handleUnidirectionalStreams(conn quic.EarlyConnection) {
conn.CloseWithError(quic.ApplicationErrorCode(errorStreamCreationError), "")
return
default:
if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str) {
return
}
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
return
}
Expand Down
66 changes: 66 additions & 0 deletions http3/server_test.go
Expand Up @@ -238,6 +238,72 @@ var _ = Describe("Server", func() {
Expect(serr.err).ToNot(HaveOccurred())
})

Context("hijacking unistreams", func() {
hareku marked this conversation as resolved.
Show resolved Hide resolved
var conn *mockquic.MockEarlyConnection
testDone := make(chan struct{})

BeforeEach(func() {
testDone = make(chan struct{})
conn = mockquic.NewMockEarlyConnection(mockCtrl)
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Write(gomock.Any())
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
conn.EXPECT().LocalAddr().AnyTimes()
})

AfterEach(func() { testDone <- struct{}{} })

It("hijacks an unistream of unknown stream type", func() {
streamTypeChan := make(chan StreamType, 1)
s.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
streamTypeChan <- st
return true
}

buf := &bytes.Buffer{}
quicvarint.Write(buf, 0x54)
unknownStr := mockquic.NewMockStream(mockCtrl)
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return unknownStr, nil
})
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
s.handleConn(conn)
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})

It("cancels reading when hijacker didn't hijack an unistream", func() {
streamTypeChan := make(chan StreamType, 1)
s.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
streamTypeChan <- st
return false
}

buf := &bytes.Buffer{}
quicvarint.Write(buf, 0x54)
unknownStr := mockquic.NewMockStream(mockCtrl)
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError))

conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return unknownStr, nil
})
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
s.handleConn(conn)
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
})

Context("control stream handling", func() {
var conn *mockquic.MockEarlyConnection
testDone := make(chan struct{})
Expand Down