Skip to content

Commit

Permalink
http3: use the connection, not the stream context, on the server side
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed May 13, 2024
1 parent 60d4e96 commit 7a90dd7
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 19 deletions.
8 changes: 7 additions & 1 deletion http3/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,13 @@ func (c *SingleDestinationRoundTripper) Start() Connection {
func (c *SingleDestinationRoundTripper) init() {
c.decoder = qpack.NewDecoder(func(hf qpack.HeaderField) {})
c.requestWriter = newRequestWriter()
c.hconn = newConnection(c.Connection, c.EnableDatagrams, protocol.PerspectiveClient, c.Logger)
c.hconn = newConnection(
c.Connection.Context(),
c.Connection,
c.EnableDatagrams,
protocol.PerspectiveClient,
c.Logger,
)
// send the SETTINGs frame, using 0-RTT data, if possible
go func() {
if err := c.setupConn(c.hconn); err != nil {
Expand Down
6 changes: 6 additions & 0 deletions http3/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type Connection interface {

type connection struct {
quic.Connection
ctx context.Context

perspective protocol.Perspective
logger *slog.Logger
Expand All @@ -53,20 +54,23 @@ type connection struct {
}

func newConnection(
ctx context.Context,
quicConn quic.Connection,
enableDatagrams bool,
perspective protocol.Perspective,
logger *slog.Logger,
) *connection {
c := &connection{
Connection: quicConn,
ctx: ctx,
perspective: perspective,
logger: logger,
enableDatagrams: enableDatagrams,
decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}),
receivedSettings: make(chan struct{}),
streams: make(map[protocol.StreamID]*datagrammer),
}

return c
}

Expand Down Expand Up @@ -280,3 +284,5 @@ func (c *connection) ReceivedSettings() <-chan struct{} { return c.receivedSetti
// Settings returns the settings received on this connection.
// It is only valid to call this function after the channel returned by ReceivedSettings was closed.
func (c *connection) Settings() *Settings { return c.settings }

func (c *connection) Context() context.Context { return c.ctx }
31 changes: 17 additions & 14 deletions http3/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,8 @@ type Server struct {
// In that case, the stream type will not be set.
UniStreamHijacker func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)

// ConnContext optionally specifies a function that modifies
// the context used for a new connection c. The provided ctx
// has a ServerContextKey value.
// ConnContext optionally specifies a function that modifies the context used for a new connection c.
// The provided ctx has a ServerContextKey value.
ConnContext func(ctx context.Context, c quic.Connection) context.Context

Logger *slog.Logger
Expand Down Expand Up @@ -436,7 +435,19 @@ func (s *Server) handleConn(conn quic.Connection) error {
}).Append(b)
str.Write(b)

ctx := conn.Context()
ctx = context.WithValue(ctx, ServerContextKey, s)
ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr())
ctx = context.WithValue(ctx, RemoteAddrContextKey, conn.RemoteAddr())
if s.ConnContext != nil {
ctx = s.ConnContext(ctx, conn)
if ctx == nil {
panic("http3: ConnContext returned nil")
}
}

hconn := newConnection(
ctx,
conn,
s.EnableDatagrams,
protocol.PerspectiveServer,
Expand Down Expand Up @@ -533,17 +544,9 @@ func (s *Server) handleRequest(conn *connection, str quic.Stream, datagrams *dat
s.Logger.Debug("handling request", "method", req.Method, "host", req.Host, "uri", req.RequestURI)
}

ctx := str.Context()
ctx = context.WithValue(ctx, ServerContextKey, s)
ctx = context.WithValue(ctx, http.LocalAddrContextKey, conn.LocalAddr())
ctx = context.WithValue(ctx, RemoteAddrContextKey, conn.RemoteAddr())
if s.ConnContext != nil {
ctx = s.ConnContext(ctx, conn.Connection)
if ctx == nil {
panic("http3: ConnContext returned nil")
}
}
req = req.WithContext(ctx)
// TODO(4508): this context needs to be cancelled when the client cancels the request
req = req.WithContext(conn.Context())

r := newResponseWriter(hstr, conn, req.Method == http.MethodHead, s.Logger)
handler := s.Handler
if handler == nil {
Expand Down
44 changes: 40 additions & 4 deletions integrationtests/self/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,9 @@ var _ = Describe("HTTP tests", func() {
mux.HandleFunc("/cancel", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
defer close(handlerCalled)
// TODO(4508): check for request context cancellations
for {
if _, err := w.Write([]byte("foobar")); err != nil {
Expect(r.Context().Done()).To(BeClosed())
var http3Err *http3.Error
Expect(errors.As(err, &http3Err)).To(BeTrue())
Expect(http3Err.ErrorCode).To(Equal(http3.ErrCode(0x10c)))
Expand Down Expand Up @@ -570,7 +570,7 @@ var _ = Describe("HTTP tests", func() {
tracingID = c.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
return ctx
}
mux.HandleFunc("/conn-context", func(w http.ResponseWriter, r *http.Request) {
mux.HandleFunc("/http3-conn-context", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
v, ok := r.Context().Value(ctxKey(0)).(string)
Expect(ok).To(BeTrue())
Expand All @@ -589,9 +589,45 @@ var _ = Describe("HTTP tests", func() {
Expect(id).To(Equal(tracingID))
})

resp, err := client.Get(fmt.Sprintf("https://localhost:%d/conn-context", port))
resp, err := client.Get(fmt.Sprintf("https://localhost:%d/http3-conn-context", port))
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(200))
Expect(resp.StatusCode).To(Equal(http.StatusOK))
})

It("uses the QUIC connection context", func() {
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{
Conn: conn,
ConnContext: func() context.Context {
//nolint:staticcheck
return context.WithValue(context.Background(), "foo", "bar")
},
}
defer tr.Close()
tlsConf := getTLSConfig()
tlsConf.NextProtos = []string{http3.NextProtoH3}
ln, err := tr.Listen(tlsConf, getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
defer ln.Close()

mux.HandleFunc("/quic-conn-context", func(w http.ResponseWriter, r *http.Request) {
defer GinkgoRecover()
v, ok := r.Context().Value("foo").(string)
Expect(ok).To(BeTrue())
Expect(v).To(Equal("bar"))
})
go func() {
defer GinkgoRecover()
c, err := ln.Accept(context.Background())
Expect(err).ToNot(HaveOccurred())
server.ServeQUICConn(c)
}()

resp, err := client.Get(fmt.Sprintf("https://localhost:%d/quic-conn-context", conn.LocalAddr().(*net.UDPAddr).Port))
Expect(err).ToNot(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
})

It("checks the server's settings", func() {
Expand Down

0 comments on commit 7a90dd7

Please sign in to comment.