Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

respect the request context when dialing #3359

Merged
merged 2 commits into from Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
42 changes: 19 additions & 23 deletions http3/client.go
Expand Up @@ -34,7 +34,9 @@ var defaultQuicConfig = &quic.Config{
Versions: []protocol.VersionNumber{protocol.VersionTLS},
}

var dialAddr = quic.DialAddrEarly
type dialFunc func(ctx context.Context, network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)

var dialAddr = quic.DialAddrEarlyContext

type roundTripperOpts struct {
DisableCompression bool
Expand All @@ -49,7 +51,7 @@ type client struct {
opts *roundTripperOpts

dialOnce sync.Once
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
dialer dialFunc
handshakeErr error

requestWriter *requestWriter
Expand All @@ -62,24 +64,18 @@ type client struct {
logger utils.Logger
}

func newClient(
hostname string,
tlsConf *tls.Config,
opts *roundTripperOpts,
quicConfig *quic.Config,
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error),
) (*client, error) {
if quicConfig == nil {
quicConfig = defaultQuicConfig.Clone()
} else if len(quicConfig.Versions) == 0 {
quicConfig = quicConfig.Clone()
quicConfig.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]}
func newClient(hostname string, tlsConf *tls.Config, opts *roundTripperOpts, conf *quic.Config, dialer dialFunc) (*client, error) {
if conf == nil {
conf = defaultQuicConfig.Clone()
} else if len(conf.Versions) == 0 {
conf = conf.Clone()
conf.Versions = []quic.VersionNumber{defaultQuicConfig.Versions[0]}
}
if len(quicConfig.Versions) != 1 {
if len(conf.Versions) != 1 {
return nil, errors.New("can only use a single QUIC version for dialing a HTTP/3 connection")
}
quicConfig.MaxIncomingStreams = -1 // don't allow any bidirectional streams
quicConfig.EnableDatagrams = opts.EnableDatagram
conf.MaxIncomingStreams = -1 // don't allow any bidirectional streams
conf.EnableDatagrams = opts.EnableDatagram
logger := utils.DefaultLogger.WithPrefix("h3 client")

if tlsConf == nil {
Expand All @@ -88,26 +84,26 @@ func newClient(
tlsConf = tlsConf.Clone()
}
// Replace existing ALPNs by H3
tlsConf.NextProtos = []string{versionToALPN(quicConfig.Versions[0])}
tlsConf.NextProtos = []string{versionToALPN(conf.Versions[0])}

return &client{
hostname: authorityAddr("https", hostname),
tlsConf: tlsConf,
requestWriter: newRequestWriter(logger),
decoder: qpack.NewDecoder(func(hf qpack.HeaderField) {}),
config: quicConfig,
config: conf,
opts: opts,
dialer: dialer,
logger: logger,
}, nil
}

func (c *client) dial() error {
func (c *client) dial(ctx context.Context) error {
var err error
if c.dialer != nil {
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
c.session, err = c.dialer(ctx, "udp", c.hostname, c.tlsConf, c.config)
} else {
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
c.session, err = dialAddr(ctx, c.hostname, c.tlsConf, c.config)
}
if err != nil {
return err
Expand Down Expand Up @@ -212,7 +208,7 @@ func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
}

c.dialOnce.Do(func() {
c.handshakeErr = c.dial()
c.handshakeErr = c.dial(req.Context())
})

if c.handshakeErr != nil {
Expand Down
35 changes: 17 additions & 18 deletions http3/client_test.go
Expand Up @@ -12,13 +12,13 @@ import (
"net/http"
"time"

"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go"
mockquic "github.com/lucas-clemente/quic-go/internal/mocks/quic"
"github.com/lucas-clemente/quic-go/quicvarint"

"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/quicvarint"

"github.com/golang/mock/gomock"
"github.com/marten-seemann/qpack"

. "github.com/onsi/ginkgo"
Expand Down Expand Up @@ -65,7 +65,7 @@ var _ = Describe("Client", func() {
client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
Expect(err).ToNot(HaveOccurred())
var dialAddrCalled bool
dialAddr = func(_ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
dialAddr = func(_ context.Context, _ string, tlsConf *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
Expect(quicConf).To(Equal(defaultQuicConfig))
Expect(tlsConf.NextProtos).To(Equal([]string{nextProtoH3}))
Expect(quicConf.Versions).To(Equal([]protocol.VersionNumber{protocol.Version1}))
Expand All @@ -80,7 +80,7 @@ var _ = Describe("Client", func() {
client, err := newClient("quic.clemente.io", nil, &roundTripperOpts{}, nil, nil)
Expect(err).ToNot(HaveOccurred())
var dialAddrCalled bool
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
dialAddr = func(_ context.Context, hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
Expect(hostname).To(Equal("quic.clemente.io:443"))
dialAddrCalled = true
return nil, errors.New("test done")
Expand All @@ -100,12 +100,8 @@ var _ = Describe("Client", func() {
client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, nil)
Expect(err).ToNot(HaveOccurred())
var dialAddrCalled bool
dialAddr = func(
hostname string,
tlsConfP *tls.Config,
quicConfP *quic.Config,
) (quic.EarlySession, error) {
Expect(hostname).To(Equal("localhost:1337"))
dialAddr = func(_ context.Context, host string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) {
Expect(host).To(Equal("localhost:1337"))
Expect(tlsConfP.ServerName).To(Equal(tlsConf.ServerName))
Expect(tlsConfP.NextProtos).To(Equal([]string{nextProtoH3}))
Expect(quicConfP.MaxIdleTimeout).To(Equal(quicConf.MaxIdleTimeout))
Expand All @@ -122,8 +118,11 @@ var _ = Describe("Client", func() {
testErr := errors.New("test done")
tlsConf := &tls.Config{ServerName: "foo.bar"}
quicConf := &quic.Config{MaxIdleTimeout: 1337 * time.Second}
ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
defer cancel()
var dialerCalled bool
dialer := func(network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) {
dialer := func(ctxP context.Context, network, address string, tlsConfP *tls.Config, quicConfP *quic.Config) (quic.EarlySession, error) {
Expect(ctxP).To(Equal(ctx))
Expect(network).To(Equal("udp"))
Expect(address).To(Equal("localhost:1337"))
Expect(tlsConfP.ServerName).To(Equal("foo.bar"))
Expand All @@ -133,7 +132,7 @@ var _ = Describe("Client", func() {
}
client, err := newClient("localhost:1337", tlsConf, &roundTripperOpts{}, quicConf, dialer)
Expect(err).ToNot(HaveOccurred())
_, err = client.RoundTrip(req)
_, err = client.RoundTrip(req.WithContext(ctx))
Expect(err).To(MatchError(testErr))
Expect(dialerCalled).To(BeTrue())
})
Expand All @@ -142,7 +141,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("handshake error")
client, err := newClient("localhost:1337", nil, &roundTripperOpts{EnableDatagram: true}, nil, nil)
Expect(err).ToNot(HaveOccurred())
dialAddr = func(hostname string, _ *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
dialAddr = func(_ context.Context, _ string, _ *tls.Config, quicConf *quic.Config) (quic.EarlySession, error) {
Expect(quicConf.EnableDatagrams).To(BeTrue())
return nil, testErr
}
Expand All @@ -154,7 +153,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("handshake error")
client, err := newClient("localhost:1337", nil, &roundTripperOpts{}, nil, nil)
Expect(err).ToNot(HaveOccurred())
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) {
return nil, testErr
}
_, err = client.RoundTrip(req)
Expand All @@ -179,7 +178,7 @@ var _ = Describe("Client", func() {
testErr := errors.New("handshake error")
req, err := http.NewRequest("masque", "masque://quic.clemente.io:1337/foobar.html", nil)
Expect(err).ToNot(HaveOccurred())
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) {
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) {
return nil, testErr
}
_, err = client.RoundTrip(req)
Expand All @@ -206,7 +205,7 @@ var _ = Describe("Client", func() {
sess.EXPECT().OpenUniStream().Return(controlStr, nil)
sess.EXPECT().HandshakeComplete().Return(handshakeCtx)
sess.EXPECT().OpenStreamSync(gomock.Any()).Return(nil, errors.New("done"))
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { return sess, nil }
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) { return sess, nil }
var err error
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
Expect(err).ToNot(HaveOccurred())
Expand Down Expand Up @@ -453,7 +452,7 @@ var _ = Describe("Client", func() {
<-testDone
return nil, errors.New("test done")
})
dialAddr = func(hostname string, _ *tls.Config, _ *quic.Config) (quic.EarlySession, error) { return sess, nil }
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) { return sess, nil }
var err error
request, err = http.NewRequest("GET", "https://quic.clemente.io:1337/file1.dat", nil)
Expect(err).ToNot(HaveOccurred())
Expand Down
7 changes: 4 additions & 3 deletions http3/roundtrip.go
@@ -1,6 +1,7 @@
package http3

import (
"context"
"crypto/tls"
"errors"
"fmt"
Expand All @@ -9,7 +10,7 @@ import (
"strings"
"sync"

quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go"

"golang.org/x/net/http/httpguts"
)
Expand Down Expand Up @@ -48,8 +49,8 @@ type RoundTripper struct {

// Dial specifies an optional dial function for creating QUIC
// connections for requests.
// If Dial is nil, quic.DialAddrEarly will be used.
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)
// If Dial is nil, quic.DialAddrEarlyContext will be used.
Dial func(ctx context.Context, network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error)

// MaxResponseHeaderBytes specifies a limit on how many response bytes are
// allowed in the server's response header.
Expand Down
6 changes: 3 additions & 3 deletions http3/roundtrip_test.go
Expand Up @@ -82,7 +82,7 @@ var _ = Describe("RoundTripper", func() {
BeforeEach(func() {
session = mockquic.NewMockEarlySession(mockCtrl)
origDialAddr = dialAddr
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlySession, error) {
dialAddr = func(context.Context, string, *tls.Config, *quic.Config) (quic.EarlySession, error) {
// return an error when trying to open a stream
// we don't want to test all the dial logic here, just that dialing happens at all
return session, nil
Expand Down Expand Up @@ -115,7 +115,7 @@ var _ = Describe("RoundTripper", func() {
It("uses the quic.Config, if provided", func() {
config := &quic.Config{HandshakeIdleTimeout: time.Millisecond}
var receivedConfig *quic.Config
dialAddr = func(addr string, tlsConf *tls.Config, config *quic.Config) (quic.EarlySession, error) {
dialAddr = func(_ context.Context, _ string, _ *tls.Config, config *quic.Config) (quic.EarlySession, error) {
receivedConfig = config
return nil, errors.New("handshake error")
}
Expand All @@ -127,7 +127,7 @@ var _ = Describe("RoundTripper", func() {

It("uses the custom dialer, if provided", func() {
var dialed bool
dialer := func(_, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
dialer := func(_ context.Context, _, _ string, tlsCfgP *tls.Config, cfg *quic.Config) (quic.EarlySession, error) {
dialed = true
return nil, errors.New("handshake error")
}
Expand Down