diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index aca5a77e42..a015a64755 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -115,10 +115,28 @@ func (c *connection) connect(ctx context.Context) { } defer close(c.connectDone) + // Create separate contexts for dialing a connection and doing the MongoDB/auth handshakes. + // + // handshakeCtx is simply a cancellable version of ctx because there's no default timeout that needs to be applied + // to the full handshake. The cancellation allows consumers to bail out early when dialing a connection if it's no + // longer required. This is done in lock because it accesses the shared cancelConnectContext field. + // + // dialCtx is equal to handshakeCtx if connectTimeoutMS=0. Otherwise, it is derived from handshakeCtx so the + // cancellation still applies but with an added timeout to ensure the connectTimeoutMS option is applied to socket + // establishment and the TLS handshake as a whole. This is created outside of the connectContextMutex lock to avoid + // holding the lock longer than necessary. c.connectContextMutex.Lock() - ctx, c.cancelConnectContext = context.WithCancel(ctx) + var handshakeCtx context.Context + handshakeCtx, c.cancelConnectContext = context.WithCancel(ctx) c.connectContextMutex.Unlock() + dialCtx := handshakeCtx + var dialCancel context.CancelFunc + if c.config.connectTimeout != 0 { + dialCtx, dialCancel = context.WithTimeout(handshakeCtx, c.config.connectTimeout) + defer dialCancel() + } + defer func() { var cancelFn context.CancelFunc @@ -137,7 +155,7 @@ func (c *connection) connect(ctx context.Context) { // Assign the result of DialContext to a temporary net.Conn to ensure that c.nc is not set in an error case. var err error var tempNc net.Conn - tempNc, err = c.config.dialer.DialContext(ctx, c.addr.Network(), c.addr.String()) + tempNc, err = c.config.dialer.DialContext(dialCtx, c.addr.Network(), c.addr.String()) if err != nil { c.processInitializationError(err) return @@ -153,7 +171,7 @@ func (c *connection) connect(ctx context.Context) { Cache: c.config.ocspCache, DisableEndpointChecking: c.config.disableOCSPEndpointCheck, } - tlsNc, err := configureTLS(ctx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts) + tlsNc, err := configureTLS(dialCtx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts) if err != nil { c.processInitializationError(err) return @@ -179,13 +197,13 @@ func (c *connection) connect(ctx context.Context) { var handshakeInfo driver.HandshakeInformation handshakeStartTime := time.Now() handshakeConn := initConnection{c} - handshakeInfo, err = handshaker.GetHandshakeInformation(ctx, c.addr, handshakeConn) + handshakeInfo, err = handshaker.GetHandshakeInformation(handshakeCtx, c.addr, handshakeConn) if err == nil { // We only need to retain the Description field as the connection's description. The authentication-related // fields in handshakeInfo are tracked by the handshaker if necessary. c.desc = handshakeInfo.Description c.isMasterRTT = time.Since(handshakeStartTime) - err = handshaker.FinishHandshake(ctx, handshakeConn) + err = handshaker.FinishHandshake(handshakeCtx, handshakeConn) } // We have a failed handshake here diff --git a/x/mongo/driver/topology/connection_options.go b/x/mongo/driver/topology/connection_options.go index 69b2f5e4c6..0895e56508 100644 --- a/x/mongo/driver/topology/connection_options.go +++ b/x/mongo/driver/topology/connection_options.go @@ -71,7 +71,7 @@ func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) { } if cfg.dialer == nil { - cfg.dialer = &net.Dialer{Timeout: cfg.connectTimeout} + cfg.dialer = &net.Dialer{} } return cfg, nil diff --git a/x/mongo/driver/topology/connection_test.go b/x/mongo/driver/topology/connection_test.go index 1e194a0cb9..3684bd748b 100644 --- a/x/mongo/driver/topology/connection_test.go +++ b/x/mongo/driver/topology/connection_test.go @@ -219,7 +219,7 @@ func TestConnection(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { var sentCfg *tls.Config - var testTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) *tls.Conn { + var testTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) tlsConn { sentCfg = cfg return tls.Client(nc, cfg) } @@ -252,6 +252,143 @@ func TestConnection(t *testing.T) { } }) }) + t.Run("connectTimeout is applied correctly", func(t *testing.T) { + testCases := []struct { + name string + contextTimeout time.Duration + connectTimeout time.Duration + maxConnectTime time.Duration + }{ + // The timeout to dial a connection should be min(context timeout, connectTimeoutMS), so 1ms for + // both of the tests declared below. Both tests also specify a 10ms max connect time to provide + // a large buffer for lag and avoid test flakiness. + + {"context timeout is lower", 1 * time.Millisecond, 100 * time.Millisecond, 10 * time.Millisecond}, + {"connect timeout is lower", 100 * time.Millisecond, 1 * time.Millisecond, 10 * time.Millisecond}, + } + + for _, tc := range testCases { + t.Run("timeout applied to socket establishment: "+tc.name, func(t *testing.T) { + // Ensure the initial connection dial can be timed out and the connection propagates the error + // from the dialer in this case. + + connOpts := []ConnectionOption{ + WithDialer(func(Dialer) Dialer { + return DialerFunc(func(ctx context.Context, _, _ string) (net.Conn, error) { + <-ctx.Done() + return nil, ctx.Err() + }) + }), + WithConnectTimeout(func(time.Duration) time.Duration { + return tc.connectTimeout + }), + } + conn, err := newConnection("", connOpts...) + assert.Nil(t, err, "newConnection error: %v", err) + + ctx, cancel := context.WithTimeout(context.Background(), tc.contextTimeout) + defer cancel() + var connectErr error + callback := func() { + conn.connect(ctx) + connectErr = conn.wait() + } + assert.Soon(t, callback, tc.maxConnectTime) + + ce, ok := connectErr.(ConnectionError) + assert.True(t, ok, "expected error %v to be of type %T", connectErr, ConnectionError{}) + assert.Equal(t, context.DeadlineExceeded, ce.Unwrap(), "expected wrapped error to be %v, got %v", + context.DeadlineExceeded, ce.Unwrap()) + }) + t.Run("timeout applied to TLS handshake: "+tc.name, func(t *testing.T) { + // Ensure the TLS handshake can be timed out and the connection propagates the error from the + // tlsConn in this case. + + var hangingTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) tlsConn { + tlsConn := tls.Client(nc, cfg) + return newHangingTLSConn(tlsConn, tc.maxConnectTime) + } + + connOpts := []ConnectionOption{ + WithConnectTimeout(func(time.Duration) time.Duration { + return tc.connectTimeout + }), + WithDialer(func(Dialer) Dialer { + return DialerFunc(func(context.Context, string, string) (net.Conn, error) { + return &net.TCPConn{}, nil + }) + }), + WithTLSConfig(func(*tls.Config) *tls.Config { + return &tls.Config{} + }), + withTLSConnectionSource(func(tlsConnectionSource) tlsConnectionSource { + return hangingTLSConnectionSource + }), + } + conn, err := newConnection("", connOpts...) + assert.Nil(t, err, "newConnection error: %v", err) + + ctx, cancel := context.WithTimeout(context.Background(), tc.contextTimeout) + defer cancel() + var connectErr error + callback := func() { + conn.connect(ctx) + connectErr = conn.wait() + } + assert.Soon(t, callback, tc.maxConnectTime) + + ce, ok := connectErr.(ConnectionError) + assert.True(t, ok, "expected error %v to be of type %T", connectErr, ConnectionError{}) + assert.Equal(t, context.DeadlineExceeded, ce.Unwrap(), "expected wrapped error to be %v, got %v", + context.DeadlineExceeded, ce.Unwrap()) + }) + t.Run("timeout is not applied to handshaker: "+tc.name, func(t *testing.T) { + // Ensure that no additional timeout is applied to the handshake after the connection has been + // established. + + var getInfoCtx, finishCtx context.Context + handshaker := &testHandshaker{ + getHandshakeInformation: func(ctx context.Context, _ address.Address, _ driver.Connection) (driver.HandshakeInformation, error) { + getInfoCtx = ctx + return driver.HandshakeInformation{}, nil + }, + finishHandshake: func(ctx context.Context, _ driver.Connection) error { + finishCtx = ctx + return nil + }, + } + + connOpts := []ConnectionOption{ + WithConnectTimeout(func(time.Duration) time.Duration { + return tc.connectTimeout + }), + WithDialer(func(Dialer) Dialer { + return DialerFunc(func(context.Context, string, string) (net.Conn, error) { + return &net.TCPConn{}, nil + }) + }), + WithHandshaker(func(Handshaker) Handshaker { + return handshaker + }), + } + conn, err := newConnection("", connOpts...) + assert.Nil(t, err, "newConnection error: %v", err) + + bgCtx := context.Background() + conn.connect(bgCtx) + err = conn.wait() + assert.Nil(t, err, "connect error: %v", err) + + assertNoContextTimeout := func(t *testing.T, ctx context.Context) { + t.Helper() + dl, ok := ctx.Deadline() + assert.False(t, ok, "expected context to have no deadline, but got deadline %v", dl) + } + assertNoContextTimeout(t, getInfoCtx) + assertNoContextTimeout(t, finishCtx) + }) + } + }) }) t.Run("writeWireMessage", func(t *testing.T) { t.Run("closed connection", func(t *testing.T) { @@ -993,3 +1130,24 @@ func (t *testCancellationListener) assertMethodsCalled(testingT *testing.T, numL assert.Equal(testingT, numStopListening, t.numStopListening, "expected StopListening to be called %d times, got %d", numListen, t.numListen) } + +// hangingTLSConn is an implementation of tlsConn that wraps the tls.Conn type and overrides the Handshake function to +// sleep for a fixed amount of time. +type hangingTLSConn struct { + *tls.Conn + sleepTime time.Duration +} + +var _ tlsConn = (*hangingTLSConn)(nil) + +func newHangingTLSConn(conn *tls.Conn, sleepTime time.Duration) *hangingTLSConn { + return &hangingTLSConn{ + Conn: conn, + sleepTime: sleepTime, + } +} + +func (h *hangingTLSConn) Handshake() error { + time.Sleep(h.sleepTime) + return h.Conn.Handshake() +} diff --git a/x/mongo/driver/topology/tls_connection_source.go b/x/mongo/driver/topology/tls_connection_source.go index e67a049307..718a9abbde 100644 --- a/x/mongo/driver/topology/tls_connection_source.go +++ b/x/mongo/driver/topology/tls_connection_source.go @@ -11,16 +11,26 @@ import ( "net" ) +type tlsConn interface { + net.Conn + Handshake() error + ConnectionState() tls.ConnectionState +} + +var _ tlsConn = (*tls.Conn)(nil) + type tlsConnectionSource interface { - Client(net.Conn, *tls.Config) *tls.Conn + Client(net.Conn, *tls.Config) tlsConn } -type tlsConnectionSourceFn func(net.Conn, *tls.Config) *tls.Conn +type tlsConnectionSourceFn func(net.Conn, *tls.Config) tlsConn + +var _ tlsConnectionSource = (tlsConnectionSourceFn)(nil) -func (t tlsConnectionSourceFn) Client(nc net.Conn, cfg *tls.Config) *tls.Conn { +func (t tlsConnectionSourceFn) Client(nc net.Conn, cfg *tls.Config) tlsConn { return t(nc, cfg) } -var defaultTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) *tls.Conn { +var defaultTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) tlsConn { return tls.Client(nc, cfg) }