Skip to content

Commit

Permalink
implement HTTP/3 unidirectional stream hijacking (#3389)
Browse files Browse the repository at this point in the history
* implement HTTP/3 unistream hijacking

* Apply suggestions from code review

Fixed name consistency.

Co-authored-by: Marten Seemann <martenseemann@gmail.com>

* rename unistream to unidirectional stream

Co-authored-by: Marten Seemann <martenseemann@gmail.com>
  • Loading branch information
hareku and marten-seemann committed Apr 21, 2022
1 parent 6d4a694 commit 1a0d577
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 2 deletions.
8 changes: 6 additions & 2 deletions http3/client.go
Expand Up @@ -44,6 +44,7 @@ type roundTripperOpts struct {
MaxHeaderBytes int64
AdditionalSettings map[uint64]uint64
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)
}

// client is a HTTP3 client doing requests
Expand Down Expand Up @@ -174,7 +175,7 @@ func (c *client) handleUnidirectionalStreams() {
return
}

go func() {
go func(str quic.ReceiveStream) {
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
if err != nil {
c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
Expand All @@ -192,6 +193,9 @@ func (c *client) handleUnidirectionalStreams() {
c.conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "")
return
default:
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str) {
return
}
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
return
}
Expand All @@ -214,7 +218,7 @@ func (c *client) handleUnidirectionalStreams() {
if c.opts.EnableDatagram && !c.conn.ConnectionState().SupportsDatagrams {
c.conn.CloseWithError(quic.ApplicationErrorCode(errorSettingsError), "missing QUIC Datagram support")
}
}()
}(str)
}
}

Expand Down
83 changes: 83 additions & 0 deletions http3/client_test.go
Expand Up @@ -185,6 +185,89 @@ var _ = Describe("Client", func() {
})
})

Context("hijacking unidirectional streams", func() {
var (
request *http.Request
conn *mockquic.MockEarlyConnection
settingsFrameWritten chan struct{}
)
testDone := make(chan struct{})

BeforeEach(func() {
testDone = make(chan struct{})
settingsFrameWritten = make(chan struct{})
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Write(gomock.Any()).Do(func(b []byte) {
defer GinkgoRecover()
close(settingsFrameWritten)
})
conn = mockquic.NewMockEarlyConnection(mockCtrl)
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
conn.EXPECT().HandshakeComplete().Return(handshakeCtx)
conn.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlyConnection, error) {
return conn, nil
}
var err error
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
Expect(err).ToNot(HaveOccurred())
})

AfterEach(func() {
testDone <- struct{}{}
Eventually(settingsFrameWritten).Should(BeClosed())
})

It("hijacks an unidirectional stream of unknown stream type", func() {
streamTypeChan := make(chan StreamType, 1)
client.opts.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
streamTypeChan <- st
return true
}

buf := &bytes.Buffer{}
quicvarint.Write(buf, 0x54)
unknownStr := mockquic.NewMockStream(mockCtrl)
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return unknownStr, nil
})
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTrip(request)
Expect(err).To(MatchError("done"))
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})

It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
streamTypeChan := make(chan StreamType, 1)
client.opts.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
streamTypeChan <- st
return false
}

buf := &bytes.Buffer{}
quicvarint.Write(buf, 0x54)
unknownStr := mockquic.NewMockStream(mockCtrl)
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError))
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return unknownStr, nil
})
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTrip(request)
Expect(err).To(MatchError("done"))
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
})

Context("control stream handling", func() {
var (
request *http.Request
Expand Down
4 changes: 4 additions & 0 deletions http3/roundtrip.go
Expand Up @@ -58,6 +58,9 @@ type RoundTripper struct {
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)

// When set, this callback is called for unknown unidirectional stream of unknown stream type.
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)

// Dial specifies an optional dial function for creating QUIC
// connections for requests.
// If Dial is nil, quic.DialAddrEarlyContext will be used.
Expand Down Expand Up @@ -154,6 +157,7 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr
DisableCompression: r.DisableCompression,
MaxHeaderBytes: r.MaxResponseHeaderBytes,
StreamHijacker: r.StreamHijacker,
UniStreamHijacker: r.UniStreamHijacker,
},
r.QuicConfig,
r.Dial,
Expand Down
9 changes: 9 additions & 0 deletions http3/server.go
Expand Up @@ -33,6 +33,9 @@ const (
nextProtoH3 = "h3"
)

// StreamType is the stream type of a unidirectional stream.
type StreamType uint64

const (
streamTypeControlStream = 0
streamTypePushStream = 1
Expand Down Expand Up @@ -151,6 +154,9 @@ type Server struct {
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)

// When set, this callback is called for unknown unidirectional stream of unknown stream type.
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)

mutex sync.RWMutex
listeners map[*quic.EarlyListener]listenerInfo

Expand Down Expand Up @@ -421,6 +427,9 @@ func (s *Server) handleUnidirectionalStreams(conn quic.EarlyConnection) {
conn.CloseWithError(quic.ApplicationErrorCode(errorStreamCreationError), "")
return
default:
if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str) {
return
}
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
return
}
Expand Down
66 changes: 66 additions & 0 deletions http3/server_test.go
Expand Up @@ -238,6 +238,72 @@ var _ = Describe("Server", func() {
Expect(serr.err).ToNot(HaveOccurred())
})

Context("hijacking unidirectional streams", func() {
var conn *mockquic.MockEarlyConnection
testDone := make(chan struct{})

BeforeEach(func() {
testDone = make(chan struct{})
conn = mockquic.NewMockEarlyConnection(mockCtrl)
controlStr := mockquic.NewMockStream(mockCtrl)
controlStr.EXPECT().Write(gomock.Any())
conn.EXPECT().OpenUniStream().Return(controlStr, nil)
conn.EXPECT().AcceptStream(gomock.Any()).Return(nil, errors.New("done"))
conn.EXPECT().RemoteAddr().Return(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}).AnyTimes()
conn.EXPECT().LocalAddr().AnyTimes()
})

AfterEach(func() { testDone <- struct{}{} })

It("hijacks an unidirectional stream of unknown stream type", func() {
streamTypeChan := make(chan StreamType, 1)
s.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
streamTypeChan <- st
return true
}

buf := &bytes.Buffer{}
quicvarint.Write(buf, 0x54)
unknownStr := mockquic.NewMockStream(mockCtrl)
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return unknownStr, nil
})
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
s.handleConn(conn)
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})

It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
streamTypeChan := make(chan StreamType, 1)
s.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
streamTypeChan <- st
return false
}

buf := &bytes.Buffer{}
quicvarint.Write(buf, 0x54)
unknownStr := mockquic.NewMockStream(mockCtrl)
unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
unknownStr.EXPECT().CancelRead(quic.StreamErrorCode(errorStreamCreationError))

conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return unknownStr, nil
})
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
s.handleConn(conn)
Eventually(streamTypeChan).Should(Receive(BeEquivalentTo(0x54)))
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
})

Context("control stream handling", func() {
var conn *mockquic.MockEarlyConnection
testDone := make(chan struct{})
Expand Down

0 comments on commit 1a0d577

Please sign in to comment.