Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement HTTP/3 unidirectional stream hijacking #3389

Merged
merged 3 commits into from Apr 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 6 additions & 2 deletions http3/client.go
Expand Up @@ -44,6 +44,7 @@ type roundTripperOpts struct {
MaxHeaderBytes int64
AdditionalSettings map[uint64]uint64
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)
}

// client is a HTTP3 client doing requests
Expand Down Expand Up @@ -174,7 +175,7 @@ func (c *client) handleUnidirectionalStreams() {
return
}

go func() {
go func(str quic.ReceiveStream) {
marten-seemann marked this conversation as resolved.
Show resolved Hide resolved
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
if err != nil {
c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
Expand All @@ -192,6 +193,9 @@ func (c *client) handleUnidirectionalStreams() {
c.conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "")
return
default:
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str) {
return
}
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
return
}
Expand All @@ -214,7 +218,7 @@ func (c *client) handleUnidirectionalStreams() {
if c.opts.EnableDatagram && !c.conn.ConnectionState().SupportsDatagrams {
c.conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support")
}
}()
}(str)
}
}

Expand Down
83 changes: 83 additions & 0 deletions http3/client_test.go
Expand Up @@ -185,6 +185,89 @@ var _ = Describe("Client", func() {
})
})

Context("hijacking unidirectional streams", func() {
var (
request *http.Request
conn *mockquic.MockEarlyConnection
settingsFrameWritten chan struct{}
)
testDone := make(chan struct{})

BeforeEach(func() {
testDone = make(chan struct{})
settingsFrameWritten = make(chan struct{})
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) {
defer GinkgoRecover()
close(settingsFrameWritten)
})
conn = mockquic.NewMockEarlyConnection(mockCtrl)
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
return conn, nil
}
var err error
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
Expect(err).ToNot(HaveOccurred())
})

AfterEach(func() {
testDone <- struct{}{}
Eventually(settingsFrameWritten).Should(BeClosed())
})

It("hijacks an unidirectional stream of unknown stream type", func() {
streamTypeChan := make(chan StreamType, 1)
client.opts.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
streamTypeChan <- st
return true
}

buf := &bytes.Buffer{}
quicvarint.Write(buf, 0x54)
unknownStr := mockquic.NewMockStream(mockCtrl)
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return unknownStr, nil
})
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTrip(request)
Expect(err).To(MatchError("done"))
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})

It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
streamTypeChan := make(chan StreamType, 1)
client.opts.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
streamTypeChan <- st
return false
}

buf := &bytes.Buffer{}
quicvarint.Write(buf, 0x54)
unknownStr := mockquic.NewMockStream(mockCtrl)
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError))
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return unknownStr, nil
})
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTrip(request)
Expect(err).To(MatchError("done"))
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
})

Context("control stream handling", func() {
var (
request *http.Request
Expand Down
4 changes: 4 additions & 0 deletions http3/roundtrip.go
Expand Up @@ -58,6 +58,9 @@ type RoundTripper struct {
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)

// When set, this callback is called for unknown unidirectional stream of unknown stream type.
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)

// Dial specifies an optional dial function for creating QUIC
// connections for requests.
// If Dial is nil, quic.DialAddrEarlyContext will be used.
Expand Down Expand Up @@ -154,6 +157,7 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr
DisableCompression: r.DisableCompression,
MaxHeaderBytes: r.MaxResponseHeaderBytes,
StreamHijacker: r.StreamHijacker,
UniStreamHijacker: r.UniStreamHijacker,
},
r.QuicConfig,
r.Dial,
Expand Down
9 changes: 9 additions & 0 deletions http3/server.go
Expand Up @@ -33,6 +33,9 @@ const (
nextProtoH3 = "h3"
)

// StreamType is the stream type of a unidirectional stream.
type StreamType uint64
hareku marked this conversation as resolved.
Show resolved Hide resolved

const (
streamTypeControlStream = 0
streamTypePushStream = 1
Expand Down Expand Up @@ -151,6 +154,9 @@ type Server struct {
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)

// When set, this callback is called for unknown unidirectional stream of unknown stream type.
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)

mutex sync.RWMutex
listeners map[*quic.EarlyListener]listenerInfo

Expand Down Expand Up @@ -421,6 +427,9 @@ func (s *Server) handleUnidirectionalStreams(conn quic.EarlyConnection) {
conn.CloseWithError(quic.ApplicationErrorCode(errorStreamCreationError), "")
return
default:
if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str) {
return
}
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
return
}
Expand Down
66 changes: 66 additions & 0 deletions http3/server_test.go
Expand Up @@ -238,6 +238,72 @@ var _ = Describe("Server", func() {
Expect(serr.err).ToNot(HaveOccurred())
})

Context("hijacking unidirectional streams", func() {
var conn *mockquic.MockEarlyConnection
testDone := make(chan struct{})

BeforeEach(func() {
testDone = make(chan struct{})
conn = mockquic.NewMockEarlyConnection(mockCtrl)
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Write(gomock.Any())
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
conn.EXPECT().LocalAddr().AnyTimes()
})

AfterEach(func() { testDone <- struct{}{} })

It("hijacks an unidirectional stream of unknown stream type", func() {
streamTypeChan := make(chan StreamType, 1)
s.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
streamTypeChan <- st
return true
}

buf := &bytes.Buffer{}
quicvarint.Write(buf, 0x54)
unknownStr := mockquic.NewMockStream(mockCtrl)
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return unknownStr, nil
})
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
s.handleConn(conn)
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})

It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
streamTypeChan := make(chan StreamType, 1)
s.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
streamTypeChan <- st
return false
}

buf := &bytes.Buffer{}
quicvarint.Write(buf, 0x54)
unknownStr := mockquic.NewMockStream(mockCtrl)
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError))

conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return unknownStr, nil
})
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
s.handleConn(conn)
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
})

Context("control stream handling", func() {
var conn *mockquic.MockEarlyConnection
testDone := make(chan struct{})
Expand Down