Skip to content

Commit

Permalink
pass frame / stream type parsing errors to the hijacker callbacks
Browse files Browse the repository at this point in the history
When a stream is reset, we might not have received the frame / stream
type yet. The callback might be able to identify if it was a stream
intended for that application by analyzing the stream reset error.
  • Loading branch information
marten-seemann committed May 23, 2022
1 parent 8185d1b commit 9afa9fb
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 20 deletions.
13 changes: 8 additions & 5 deletions http3/client.go
Expand Up @@ -43,8 +43,8 @@ type roundTripperOpts struct {
EnableDatagram bool
MaxHeaderBytes int64
AdditionalSettings map[uint64]uint64
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)
}

// client is a HTTP3 client doing requests
Expand Down Expand Up @@ -152,8 +152,8 @@ func (c *client) handleBidirectionalStreams() {
}
go func(str quic.Stream) {
for {
_, err := parseNextFrame(str, func(ft FrameType) (processed bool, err error) {
return c.opts.StreamHijacker(ft, c.conn, str)
_, err := parseNextFrame(str, func(ft FrameType, e error) (processed bool, err error) {
return c.opts.StreamHijacker(ft, c.conn, str, e)
})
if err == errHijacked {
return
Expand All @@ -178,6 +178,9 @@ func (c *client) handleUnidirectionalStreams() {
go func(str quic.ReceiveStream) {
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
if err != nil {
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, err) {
return
}
c.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
return
}
Expand All @@ -193,7 +196,7 @@ func (c *client) handleUnidirectionalStreams() {
c.conn.CloseWithError(quic.ApplicationErrorCode(errorIDError), "")
return
default:
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str) {
if c.opts.UniStreamHijacker != nil && c.opts.UniStreamHijacker(StreamType(streamType), c.conn, str, nil) {
return
}
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
Expand Down
32 changes: 30 additions & 2 deletions http3/client_test.go
Expand Up @@ -220,7 +220,8 @@ var _ = Describe("Client", func() {

It("hijacks an unidirectional stream of unknown stream type", func() {
streamTypeChan := make(chan StreamType, 1)
client.opts.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
Expect(err).ToNot(HaveOccurred())
streamTypeChan <- st
return true
}
Expand All @@ -242,9 +243,36 @@ var _ = Describe("Client", func() {
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})

It("handles errors that occur when reading the stream type", func() {
testErr := errors.New("test error")
done := make(chan struct{})
unknownStr := mockquic.NewMockStream(mockCtrl)
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
defer close(done)
Expect(st).To(BeZero())
Expect(str).To(Equal(unknownStr))
Expect(err).To(MatchError(testErr))
return true
}

unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { return 0, testErr })
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return unknownStr, nil
})
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
_, err := client.RoundTripOpt(req, RoundTripOpt{})
Expect(err).To(MatchError("done"))
Eventually(done).Should(BeClosed())
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})

It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
streamTypeChan := make(chan StreamType, 1)
client.opts.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
client.opts.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
Expect(err).ToNot(HaveOccurred())
streamTypeChan <- st
return false
}
Expand Down
13 changes: 11 additions & 2 deletions http3/frames.go
Expand Up @@ -14,7 +14,7 @@ import (
// FrameType is the frame type of a HTTP/3 frame
type FrameType uint64

type unknownFrameHandlerFunc func(FrameType) (processed bool, err error)
type unknownFrameHandlerFunc func(FrameType, error) (processed bool, err error)

type frame interface{}

Expand All @@ -25,11 +25,20 @@ func parseNextFrame(r io.Reader, unknownFrameHandler unknownFrameHandlerFunc) (f
for {
t, err := quicvarint.Read(qr)
if err != nil {
if unknownFrameHandler != nil {
hijacked, err := unknownFrameHandler(0, err)
if err != nil {
return nil, err
}
if hijacked {
return nil, errHijacked
}
}
return nil, err
}
// Call the unknownFrameHandler for frames not defined in the HTTP/3 spec
if t > 0xd && unknownFrameHandler != nil {
hijacked, err := unknownFrameHandler(FrameType(t))
hijacked, err := unknownFrameHandler(FrameType(t), nil)
if err != nil {
return nil, err
}
Expand Down
24 changes: 22 additions & 2 deletions http3/frames_test.go
Expand Up @@ -2,6 +2,7 @@ package http3

import (
"bytes"
"errors"
"fmt"
"io"

Expand All @@ -11,6 +12,10 @@ import (
. "github.com/onsi/gomega"
)

type errReader struct{ err error }

func (e errReader) Read([]byte) (int, error) { return 0, e.err }

var _ = Describe("Frames", func() {
appendVarInt := func(b []byte, val uint64) []byte {
buf := &bytes.Buffer{}
Expand Down Expand Up @@ -189,7 +194,8 @@ var _ = Describe("Frames", func() {
buf.Write(customFrameContents)

var called bool
_, err := parseNextFrame(buf, func(ft FrameType) (hijacked bool, err error) {
_, err := parseNextFrame(buf, func(ft FrameType, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
Expect(ft).To(BeEquivalentTo(1337))
called = true
b := make([]byte, 3)
Expand All @@ -202,6 +208,19 @@ var _ = Describe("Frames", func() {
Expect(called).To(BeTrue())
})

It("passes on errors that occur when reading the frame type", func() {
testErr := errors.New("test error")
var called bool
_, err := parseNextFrame(errReader{err: testErr}, func(ft FrameType, e error) (hijacked bool, err error) {
Expect(e).To(MatchError(testErr))
Expect(ft).To(BeZero())
called = true
return true, nil
})
Expect(err).To(MatchError(errHijacked))
Expect(called).To(BeTrue())
})

It("reads a frame without hijacking the stream", func() {
buf := &bytes.Buffer{}
quicvarint.Write(buf, 1337)
Expand All @@ -211,7 +230,8 @@ var _ = Describe("Frames", func() {
buf.WriteString("foobar")

var called bool
frame, err := parseNextFrame(buf, func(ft FrameType) (hijacked bool, err error) {
frame, err := parseNextFrame(buf, func(ft FrameType, e error) (hijacked bool, err error) {
Expect(e).ToNot(HaveOccurred())
Expect(ft).To(BeEquivalentTo(1337))
called = true
b := make([]byte, len(customFrameContents))
Expand Down
8 changes: 6 additions & 2 deletions http3/roundtrip.go
Expand Up @@ -53,13 +53,17 @@ type RoundTripper struct {

// When set, this callback is called for the first unknown frame parsed on a bidirectional stream.
// It is called right after parsing the frame type.
// If parsing the frame type fails, the error is passed to the callback.
// In that case, the frame type will not be set.
// Callers can either process the frame and return control of the stream back to HTTP/3
// (by returning hijacked false).
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)

// When set, this callback is called for unknown unidirectional stream of unknown stream type.
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)
// If parsing the stream type fails, the error is passed to the callback.
// In that case, the stream type will not be set.
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)

// Dial specifies an optional dial function for creating QUIC
// connections for requests.
Expand Down
17 changes: 12 additions & 5 deletions http3/server.go
Expand Up @@ -178,13 +178,17 @@ type Server struct {

// StreamHijacker, when set, is called for the first unknown frame parsed on a bidirectional stream.
// It is called right after parsing the frame type.
// If parsing the frame type fails, the error is passed to the callback.
// In that case, the frame type will not be set.
// Callers can either process the frame and return control of the stream back to HTTP/3
// (by returning hijacked false).
// Alternatively, callers can take over the QUIC stream (by returning hijacked true).
StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error)
StreamHijacker func(FrameType, quic.Connection, quic.Stream, error) (hijacked bool, err error)

// UniStreamHijacker, when set, is called for unknown unidirectional stream of unknown stream type.
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream) (hijacked bool)
// If parsing the stream type fails, the error is passed to the callback.
// In that case, the stream type will not be set.
UniStreamHijacker func(StreamType, quic.Connection, quic.ReceiveStream, error) (hijacked bool)

mutex sync.RWMutex
listeners map[*quic.EarlyListener]listenerInfo
Expand Down Expand Up @@ -457,6 +461,9 @@ func (s *Server) handleUnidirectionalStreams(conn quic.EarlyConnection) {
go func(str quic.ReceiveStream) {
streamType, err := quicvarint.Read(quicvarint.NewReader(str))
if err != nil {
if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str, err) {
return
}
s.logger.Debugf("reading stream type on stream %d failed: %s", str.StreamID(), err)
return
}
Expand All @@ -471,7 +478,7 @@ func (s *Server) handleUnidirectionalStreams(conn quic.EarlyConnection) {
conn.CloseWithError(quic.ApplicationErrorCode(errorStreamCreationError), "")
return
default:
if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str) {
if s.UniStreamHijacker != nil && s.UniStreamHijacker(StreamType(streamType), conn, str, nil) {
return
}
str.CancelRead(quic.StreamErrorCode(errorStreamCreationError))
Expand Down Expand Up @@ -510,8 +517,8 @@ func (s *Server) maxHeaderBytes() uint64 {
func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *qpack.Decoder, onFrameError func()) requestError {
var ufh unknownFrameHandlerFunc
if s.StreamHijacker != nil {
ufh = func(ft FrameType) (processed bool, err error) {
return s.StreamHijacker(ft, conn, str)
ufh = func(ft FrameType, e error) (processed bool, err error) {
return s.StreamHijacker(ft, conn, str, e)
}
}
frame, err := parseNextFrame(str, ufh)
Expand Down
31 changes: 29 additions & 2 deletions http3/server_test.go
Expand Up @@ -280,7 +280,8 @@ var _ = Describe("Server", func() {

It("hijacks an unidirectional stream of unknown stream type", func() {
streamTypeChan := make(chan StreamType, 1)
s.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
Expect(err).ToNot(HaveOccurred())
streamTypeChan <- st
return true
}
Expand All @@ -301,9 +302,35 @@ var _ = Describe("Server", func() {
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})

It("handles errors that occur when reading the stream type", func() {
testErr := errors.New("test error")
done := make(chan struct{})
unknownStr := mockquic.NewMockStream(mockCtrl)
s.UniStreamHijacker = func(st StreamType, _ quic.Connection, str quic.ReceiveStream, err error) bool {
defer close(done)
Expect(st).To(BeZero())
Expect(str).To(Equal(unknownStr))
Expect(err).To(MatchError(testErr))
return true
}

unknownStr.EXPECT().Read(gomock.Any()).DoAndReturn(func([]byte) (int, error) { return 0, testErr })
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
return unknownStr, nil
})
conn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) {
<-testDone
return nil, errors.New("test done")
})
s.handleConn(conn)
Eventually(done).Should(BeClosed())
time.Sleep(scaleDuration(20 * time.Millisecond)) // don't EXPECT any calls to conn.CloseWithError
})

It("cancels reading when hijacker didn't hijack an unidirectional stream", func() {
streamTypeChan := make(chan StreamType, 1)
s.UniStreamHijacker = func(st StreamType, c quic.Connection, rs quic.ReceiveStream) bool {
s.UniStreamHijacker = func(st StreamType, _ quic.Connection, _ quic.ReceiveStream, err error) bool {
Expect(err).ToNot(HaveOccurred())
streamTypeChan <- st
return false
}
Expand Down

0 comments on commit 9afa9fb

Please sign in to comment.