diff --git a/http3/client.go b/http3/client.go index 861eaf0ab70..ee2db7ab2e8 100644 --- a/http3/client.go +++ b/http3/client.go @@ -34,7 +34,9 @@ var defaultQuicConfig = &quic.Config{ Versions: []protocol.VersionNumber{protocol.VersionTLS}, } -var dialAddr = quic.DialAddrEarly +type dialFunc func(ctx context.Context, network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) + +var dialAddr = quic.DialAddrEarlyContext type roundTripperOpts struct { DisableCompression bool @@ -49,7 +51,7 @@ type client struct { opts *roundTripperOpts dialOnce sync.Once - dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) + dialer dialFunc handshakeErr error requestWriter *requestWriter @@ -62,24 +64,18 @@ type client struct { logger utils.Logger } -func newClient( - hostname string, - tlsConf *tls.Config, - opts *roundTripperOpts, - quicConfig *quic.Config, - dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error), -) (*client, error) { - if quicConfig == nil { - quicConfig = defaultQuicConfig.Clone() - } else if len(quicConfig.Versions) == 0 { - quicConfig = quicConfig.Clone() - quicConfig.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]} +func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (*client, error) { + if conf == nil { + conf = defaultQuicConfig.Clone() + } else if len(conf.Versions) == 0 { + conf = conf.Clone() + conf.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]} } - if len(quicConfig.Versions) != 1 { + if len(conf.Versions) != 1 { return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection") } - quicConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams - quicConfig.EnableDatagrams = opts.EnableDatagram + conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams + conf.EnableDatagrams = opts.EnableDatagram logger := utils.DefaultLogger.WithPrefix("h3 client") if tlsConf == nil { @@ -88,26 +84,26 @@ func newClient( tlsConf = tlsConf.Clone() } // Replace existing ALPNs by H3 - tlsConf.NextProtos = []string{versionToALPN(quicConfig.Versions[0])} + tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])} return &client{ hostname: authorityAddr("https", hostname), tlsConf: tlsConf, requestWriter: newRequestWriter(logger), decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}), - config: quicConfig, + config: conf, opts: opts, dialer: dialer, logger: logger, }, nil } -func (c *client) dial() error { +func (c *client) dial(ctx context.Context) error { var err error if c.dialer != nil { - c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config) + c.session, err = c.dialer(ctx, "udp", c.hostname, c.tlsConf, c.config) } else { - c.session, err = dialAddr(c.hostname, c.tlsConf, c.config) + c.session, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config) } if err != nil { return err @@ -212,7 +208,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { } c.dialOnce.Do(func() { - c.handshakeErr = c.dial() + c.handshakeErr = c.dial(req.Context()) }) if c.handshakeErr != nil { diff --git a/http3/client_test.go b/http3/client_test.go index 3f3b4a4b1f9..ed4ae6670b0 100644 --- a/http3/client_test.go +++ b/http3/client_test.go @@ -12,13 +12,13 @@ import ( "net/http" "time" - "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go" mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic" - "github.com/lucas-clemente/quic-go/quicvarint" - "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" + "github.com/lucas-clemente/quic-go/quicvarint" + + "github.com/golang/mock/gomock" "github.com/marten-seemann/qpack" . "github.com/onsi/ginkgo" @@ -65,7 +65,7 @@ var _ = Describe("Client", func() { client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) Expect(err).ToNot(HaveOccurred()) var dialAddrCalled bool - dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) { + dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) { Expect(quicConf).To(Equal(defaultQuicConfig)) Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3})) Expect(quicConf.Versions).To(Equal([]protocol.VersionNumber{protocol.Version1})) @@ -80,7 +80,7 @@ var _ = Describe("Client", func() { client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil) Expect(err).ToNot(HaveOccurred()) var dialAddrCalled bool - dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { + dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { Expect(hostname).To(Equal("quic.clemente.io:443")) dialAddrCalled = true return nil, errors.New("test done") @@ -100,12 +100,8 @@ var _ = Describe("Client", func() { client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil) Expect(err).ToNot(HaveOccurred()) var dialAddrCalled bool - dialAddr = func( - hostname string, - tlsConfP *tls.Config, - quicConfP *quic.Config, - ) (quic.EarlySession, error) { - Expect(hostname).To(Equal("localhost:1337")) + dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) { + Expect(host).To(Equal("localhost:1337")) Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName)) Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3})) Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout)) @@ -122,8 +118,11 @@ var _ = Describe("Client", func() { testErr := errors.New("test done") tlsConf := &tls.Config{ServerName: "foo.bar"} quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second} + ctx, cancel := context.WithTimeout(context.Background(), time.Hour) + defer cancel() var dialerCalled bool - dialer := func(network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) { + dialer := func(ctxP context.Context, network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) { + Expect(ctxP).To(Equal(ctx)) Expect(network).To(Equal("udp")) Expect(address).To(Equal("localhost:1337")) Expect(tlsConfP.ServerName).To(Equal("foo.bar")) @@ -133,7 +132,7 @@ var _ = Describe("Client", func() { } client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer) Expect(err).ToNot(HaveOccurred()) - _, err = client.RoundTrip(req) + _, err = client.RoundTrip(req.WithContext(ctx)) Expect(err).To(MatchError(testErr)) Expect(dialerCalled).To(BeTrue()) }) @@ -142,7 +141,7 @@ var _ = Describe("Client", func() { testErr := errors.New("handshake error") client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil) Expect(err).ToNot(HaveOccurred()) - dialAddr = func(hostname string, _ *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) { + dialAddr = func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) { Expect(quicConf.EnableDatagrams).To(BeTrue()) return nil, testErr } @@ -154,7 +153,7 @@ var _ = Describe("Client", func() { testErr := errors.New("handshake error") client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil) Expect(err).ToNot(HaveOccurred()) - dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { + dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) { return nil, testErr } _, err = client.RoundTrip(req) @@ -179,7 +178,7 @@ var _ = Describe("Client", func() { testErr := errors.New("handshake error") req, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil) Expect(err).ToNot(HaveOccurred()) - dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { + dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) { return nil, testErr } _, err = client.RoundTrip(req) @@ -206,7 +205,7 @@ var _ = Describe("Client", func() { sess.EXPECT().OpenUniStream().Return(controlStr, nil) sess.EXPECT().HandshakeComplete().Return(handshakeCtx) sess.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done")) - dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { return sess, nil } + dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) { return sess, nil } var err error request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) Expect(err).ToNot(HaveOccurred()) @@ -453,7 +452,7 @@ var _ = Describe("Client", func() { <-testDone return nil, errors.New("test done") }) - dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { return sess, nil } + dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) { return sess, nil } var err error request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil) Expect(err).ToNot(HaveOccurred()) diff --git a/http3/roundtrip.go b/http3/roundtrip.go index 8e6f943e93b..d301045a3c0 100644 --- a/http3/roundtrip.go +++ b/http3/roundtrip.go @@ -1,6 +1,7 @@ package http3 import ( + "context" "crypto/tls" "errors" "fmt" @@ -9,7 +10,7 @@ import ( "strings" "sync" - quic "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go" "golang.org/x/net/http/httpguts" ) @@ -48,8 +49,8 @@ type RoundTripper struct { // Dial specifies an optional dial function for creating QUIC // connections for requests. - // If Dial is nil, quic.DialAddrEarly will be used. - Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) + // If Dial is nil, quic.DialAddrEarlyContext will be used. + Dial func(ctx context.Context, network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) // MaxResponseHeaderBytes specifies a limit on how many response bytes are // allowed in the server's response header. diff --git a/http3/roundtrip_test.go b/http3/roundtrip_test.go index 184889f1c35..ff5874c5fd2 100644 --- a/http3/roundtrip_test.go +++ b/http3/roundtrip_test.go @@ -82,7 +82,7 @@ var _ = Describe("RoundTripper", func() { BeforeEach(func() { session = mockquic.NewMockEarlySession(mockCtrl) origDialAddr = dialAddr - dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlySession, error) { + dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) { // return an error when trying to open a stream // we don't want to test all the dial logic here, just that dialing happens at all return session, nil @@ -115,7 +115,7 @@ var _ = Describe("RoundTripper", func() { It("uses the quic.Config, if provided", func() { config := &quic.Config{HandshakeIdleTimeout: time.Millisecond} var receivedConfig *quic.Config - dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlySession, error) { + dialAddr = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlySession, error) { receivedConfig = config return nil, errors.New("handshake error") } @@ -127,7 +127,7 @@ var _ = Describe("RoundTripper", func() { It("uses the custom dialer, if provided", func() { var dialed bool - dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlySession, error) { + dialer := func(_ context.Context, _, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlySession, error) { dialed = true return nil, errors.New("handshake error") }