Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pass frame / stream type parsing errors to the hijacker callbacks #3429

Merged
merged 2 commits into from May 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 15 additions & 14 deletions http3/client.go
Expand Up @@ -43,8 +43,8 @@ type roundTripperOpts struct {
EnableDatagram bool
MaxHeaderBytes int64
AdditionalSettings map[uint64]uint64
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
}

// client is a HTTP3 client doing requests
Expand Down Expand Up @@ -151,18 +151,16 @@ func (c *client) handleBidirectionalStreams() {
return
}
go func(str quic.Stream) {
for {
_, err := parseNextFrame(str, func(ft FrameType) (processed bool, err error) {
return c.opts.StreamHijacker(ft, c.conn, str)
})
if err == errHijacked {
return
}
if err != nil {
c.logger.Debugf("error handling stream: %s", err)
}
c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
_, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) {
return c.opts.StreamHijacker(ft, c.conn, str, e)
})
if err == errHijacked {
return
}
if err != nil {
c.logger.Debugf("error handling stream: %s", err)
}
c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), "received HTTP/3 frame on bidirectional stream")
}(str)
}
}
Expand All @@ -178,6 +176,9 @@ func (c *client) handleUnidirectionalStreams() {
go func(str quic.ReceiveStream) {
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
if err != nil {
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, err) {
return
}
c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
return
}
Expand All @@ -193,7 +194,7 @@ func (c *client) handleUnidirectionalStreams() {
c.conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "")
return
default:
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str) {
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, nil) {
return
}
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
Expand Down
64 changes: 59 additions & 5 deletions http3/client_test.go
Expand Up @@ -221,7 +221,8 @@ var _ = Describe("Client", func() {

It("hijacks a bidirectional stream of unknown frame type", func() {
frameTypeChan := make(chan FrameType, 1)
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream) (hijacked bool, err error) {
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft
return true, nil
}
Expand All @@ -243,7 +244,8 @@ var _ = Describe("Client", func() {

It("closes the connection when hijacker didn't hijack a bidirectional stream", func() {
frameTypeChan := make(chan FrameType, 1)
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream) (hijacked bool, err error) {
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft
return false, nil
}
Expand All @@ -265,7 +267,8 @@ var _ = Describe("Client", func() {

It("closes the connection when hijacker returned error", func() {
frameTypeChan := make(chan FrameType, 1)
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream) (hijacked bool, err error) {
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, s quic.Stream, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
frameTypeChan <- ft
return false, errors.New("error in hijacker")
}
Expand All @@ -284,6 +287,31 @@ var _ = Describe("Client", func() {
Expect(err).To(MatchError("done"))
Eventually(frameTypeChan).Should(Receive(BeEquivalentTo(0x41)))
})

It("handles errors that occur when reading the frame type", func() {
testErr := errors.New("test error")
unknownStr := mockquic.NewMockStream(mockCtrl)
done := make(chan struct{})
client.opts.StreamHijacker = func(ft FrameType, c quic.Connection, str quic.Stream, e error) (hijacked bool, err error) {
defer close(done)
Expect(e).To(MatchError(testErr))
Expect(ft).To(BeZero())
Expect(str).To(Equal(unknownStr))
return false, nil
}

unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr).AnyTimes()
conn.EXPECT().AcceptStream(gomock.Any()).Return(unknownStr, nil)
conn.EXPECT().AcceptStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.Stream, error) {
<-testDone
return nil, errors.New("test done")
})
conn.EXPECT().CloseWithError(quic.ApplicationErrorCode(errorFrameUnexpected), gomock.Any()).Return(nil).AnyTimes()
_, err := client.RoundTripOpt(request, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})
})

Context("hijacking unidirectional streams", func() {
Expand Down Expand Up @@ -321,7 +349,8 @@ var _ = Describe("Client", func() {

It("hijacks an unidirectional stream of unknown stream type", func() {
streamTypeChan := make(chan StreamType, 1)
client.opts.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
Expect(err).ToNot(HaveOccurred())
streamTypeChan <- st
return true
}
Expand All @@ -343,9 +372,34 @@ var _ = Describe("Client", func() {
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})

It("handles errors that occur when reading the stream type", func() {
testErr := errors.New("test error")
done := make(chan struct{})
unknownStr := mockquic.NewMockStream(mockCtrl)
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
defer close(done)
Expect(st).To(BeZero())
Expect(str).To(Equal(unknownStr))
Expect(err).To(MatchError(testErr))
return true
}

unknownStr.EXPECT().Read(gomock.Any()).Return(0, testErr)
conn.EXPECT().AcceptUniStream(gomock.Any()).Return(unknownStr, nil)
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})

It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
streamTypeChan := make(chan StreamType, 1)
client.opts.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
Expect(err).ToNot(HaveOccurred())
streamTypeChan <- st
return false
}
Expand Down
16 changes: 12 additions & 4 deletions http3/frames.go
Expand Up @@ -14,7 +14,7 @@ import (
// FrameType is the frame type of a HTTP/3 frame
type FrameType uint64

type unknownFrameHandlerFunc func(FrameType) (processed bool, err error)
type unknownFrameHandlerFunc func(FrameType, error) (processed bool, err error)

type frame interface{}

Expand All @@ -25,19 +25,27 @@ func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (f
for {
t, err := quicvarint.Read(qr)
if err != nil {
if unknownFrameHandler != nil {
hijacked, err := unknownFrameHandler(0, err)
if err != nil {
return nil, err
}
if hijacked {
return nil, errHijacked
}
}
return nil, err
}
// Call the unknownFrameHandler for frames not defined in the HTTP/3 spec
if t > 0xd && unknownFrameHandler != nil {
hijacked, err := unknownFrameHandler(FrameType(t))
hijacked, err := unknownFrameHandler(FrameType(t), nil)
if err != nil {
return nil, err
}
// If the unknownFrameHandler didn't process the frame, it is our responsibility to skip it.
if hijacked {
return nil, errHijacked
}
continue
// If the unknownFrameHandler didn't process the frame, it is our responsibility to skip it.
}
l, err := quicvarint.Read(qr)
if err != nil {
Expand Down
29 changes: 23 additions & 6 deletions http3/frames_test.go
Expand Up @@ -2,6 +2,7 @@ package http3

import (
"bytes"
"errors"
"fmt"
"io"

Expand All @@ -11,6 +12,10 @@ import (
. "github.com/onsi/gomega"
)

type errReader struct{ err error }

func (e errReader) Read([]byte) (int, error) { return 0, e.err }

var _ = Describe("Frames", func() {
appendVarInt := func(b []byte, val uint64) []byte {
buf := &bytes.Buffer{}
Expand Down Expand Up @@ -189,7 +194,8 @@ var _ = Describe("Frames", func() {
buf.Write(customFrameContents)

var called bool
_, err := parseNextFrame(buf, func(ft FrameType) (hijacked bool, err error) {
_, err := parseNextFrame(buf, func(ft FrameType, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
Expect(ft).To(BeEquivalentTo(1337))
called = true
b := make([]byte, 3)
Expand All @@ -202,22 +208,33 @@ var _ = Describe("Frames", func() {
Expect(called).To(BeTrue())
})

It("passes on errors that occur when reading the frame type", func() {
testErr := errors.New("test error")
var called bool
_, err := parseNextFrame(errReader{err: testErr}, func(ft FrameType, e error) (hijacked bool, err error) {
Expect(e).To(MatchError(testErr))
Expect(ft).To(BeZero())
called = true
return true, nil
})
Expect(err).To(MatchError(errHijacked))
Expect(called).To(BeTrue())
})

It("reads a frame without hijacking the stream", func() {
buf := &bytes.Buffer{}
quicvarint.Write(buf, 1337)
customFrameContents := []byte("custom frame")
quicvarint.Write(buf, uint64(len(customFrameContents)))
buf.Write(customFrameContents)
(&dataFrame{Length: 6}).Write(buf)
buf.WriteString("foobar")

var called bool
frame, err := parseNextFrame(buf, func(ft FrameType) (hijacked bool, err error) {
frame, err := parseNextFrame(buf, func(ft FrameType, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
Expect(ft).To(BeEquivalentTo(1337))
called = true
b := make([]byte, len(customFrameContents))
_, err = io.ReadFull(buf, b)
Expect(err).ToNot(HaveOccurred())
Expect(string(b)).To(Equal(string(customFrameContents)))
return false, nil
})
Expect(err).ToNot(HaveOccurred())
Expand Down
10 changes: 7 additions & 3 deletions http3/roundtrip.go
Expand Up @@ -53,13 +53,17 @@ type RoundTripper struct {

// When set, this callback is called for the first unknown frame parsed on a bidirectional stream.
// It is called right after parsing the frame type.
// Callers can either process the frame and return control of the stream back to HTTP/3
// If parsing the frame type fails, the error is passed to the callback.
// In that case, the frame type will not be set.
// Callers can either ignore the frame and return control of the stream back to HTTP/3
// (by returning hijacked false).
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)

// When set, this callback is called for unknown unidirectional stream of unknown stream type.
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)
// If parsing the stream type fails, the error is passed to the callback.
// In that case, the stream type will not be set.
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)

// Dial specifies an optional dial function for creating QUIC
// connections for requests.
Expand Down
19 changes: 12 additions & 7 deletions http3/server.go
Expand Up @@ -178,13 +178,17 @@ type Server struct {

// StreamHijacker, when set, is called for the first unknown frame parsed on a bidirectional stream.
// It is called right after parsing the frame type.
// Callers can either process the frame and return control of the stream back to HTTP/3
// If parsing the frame type fails, the error is passed to the callback.
// In that case, the frame type will not be set.
// Callers can either ignore the frame and return control of the stream back to HTTP/3
// (by returning hijacked false).
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)

// UniStreamHijacker, when set, is called for unknown unidirectional stream of unknown stream type.
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)
// If parsing the stream type fails, the error is passed to the callback.
// In that case, the stream type will not be set.
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)

mutex sync.RWMutex
listeners map[*quic.EarlyListener]listenerInfo
Expand Down Expand Up @@ -457,6 +461,9 @@ func (s *Server) handleUnidirectionalStreams(conn quic.EarlyConnection) {
go func(str quic.ReceiveStream) {
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
if err != nil {
if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str, err) {
return
}
s.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
return
}
Expand All @@ -471,7 +478,7 @@ func (s *Server) handleUnidirectionalStreams(conn quic.EarlyConnection) {
conn.CloseWithError(quic.ApplicationErrorCode(errorStreamCreationError), "")
return
default:
if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str) {
if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str, nil) {
return
}
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
Expand Down Expand Up @@ -510,9 +517,7 @@ func (s *Server) maxHeaderBytes() uint64 {
func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError {
var ufh unknownFrameHandlerFunc
if s.StreamHijacker != nil {
ufh = func(ft FrameType) (processed bool, err error) {
return s.StreamHijacker(ft, conn, str)
}
ufh = func(ft FrameType, e error) (processed bool, err error) { return s.StreamHijacker(ft, conn, str, e) }
}
frame, err := parseNextFrame(str, ufh)
if err != nil {
Expand Down