From 6dcd185afbb61edfb94fa7566f8928a52c7981e8 Mon Sep 17 00:00:00 2001 From: hareku Date: Wed, 20 Apr 2022 13:00:33 +0900 Subject: [PATCH] implement HTTP/3 unistream hijacking --- http3/client.go | 13 ++++++++++++- http3/roundtrip.go | 4 ++++ http3/server.go | 15 ++++++++++++++- 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/http3/client.go b/http3/client.go index 27c4947b1d3..34bb1287fa6 100644 --- a/http3/client.go +++ b/http3/client.go @@ -44,6 +44,7 @@ type roundTripperOpts struct { MaxHeaderBytes int64 AdditionalSettings map[uint64]uint64 StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error) + UniStreamHijacker func(FrameType, quic.Connection, quic.ReceiveStream) (hijacked bool, err error) } // client is a HTTP3 client doing requests @@ -195,8 +196,18 @@ func (c *client) handleUnidirectionalStreams() { str.CancelRead(quic.StreamErrorCode(errorStreamCreationError)) return } - f, err := parseNextFrame(str, nil) + + var ufh unknownFrameHandlerFunc + if c.opts.UniStreamHijacker != nil { + ufh = func(ft FrameType) (processed bool, err error) { + return c.opts.UniStreamHijacker(ft, c.conn, str) + } + } + f, err := parseNextFrame(str, ufh) if err != nil { + if err == errHijacked { + return + } c.conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "") return } diff --git a/http3/roundtrip.go b/http3/roundtrip.go index 743ff0341af..09072182634 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -58,6 +58,9 @@ type RoundTripper struct { // Alternatively, callers can take over the QUIC stream (by returning hijacked true). StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error) + // When set, this callback is called for the first unknown frame parsed on a unidirectional stream. + UniStreamHijacker func(FrameType, quic.Connection, quic.ReceiveStream) (hijacked bool, err error) + // Dial specifies an optional dial function for creating QUIC // connections for requests. // If Dial is nil, quic.DialAddrEarlyContext will be used. @@ -154,6 +157,7 @@ func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTr DisableCompression: r.DisableCompression, MaxHeaderBytes: r.MaxResponseHeaderBytes, StreamHijacker: r.StreamHijacker, + UniStreamHijacker: r.UniStreamHijacker, }, r.QuicConfig, r.Dial, diff --git a/http3/server.go b/http3/server.go index e1d818acc67..47f7e1af98c 100644 --- a/http3/server.go +++ b/http3/server.go @@ -151,6 +151,9 @@ type Server struct { // Alternatively, callers can take over the QUIC stream (by returning hijacked true). StreamHijacker func(FrameType, quic.Connection, quic.Stream) (hijacked bool, err error) + // When set, this callback is called for the first unknown frame parsed on a unidirectional receive stream. + UniStreamHijacker func(FrameType, quic.Connection, quic.ReceiveStream) (hijacked bool, err error) + mutex sync.RWMutex listeners map[*quic.EarlyListener]listenerInfo @@ -424,8 +427,18 @@ func (s *Server) handleUnidirectionalStreams(conn quic.EarlyConnection) { str.CancelRead(quic.StreamErrorCode(errorStreamCreationError)) return } - f, err := parseNextFrame(str, nil) + + var ufh unknownFrameHandlerFunc + if s.UniStreamHijacker != nil { + ufh = func(ft FrameType) (processed bool, err error) { + return s.UniStreamHijacker(ft, conn, str) + } + } + f, err := parseNextFrame(str, ufh) if err != nil { + if err == errHijacked { + return + } conn.CloseWithError(quic.ApplicationErrorCode(errorFrameError), "") return }