Skip to content

Commit

Permalink
GODRIVER-1879 Apply connectTimeoutMS to TLS handshake (#594)
Browse files Browse the repository at this point in the history
  • Loading branch information
Divjot Arora committed Mar 8, 2021
1 parent 47f87bd commit 3cf67b9
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 11 deletions.
28 changes: 23 additions & 5 deletions x/mongo/driver/topology/connection.go
Expand Up @@ -104,10 +104,28 @@ func (c *connection) connect(ctx context.Context) {
}
defer close(c.connectDone)

// Create separate contexts for dialing a connection and doing the MongoDB/auth handshakes.
//
// handshakeCtx is simply a cancellable version of ctx because there's no default timeout that needs to be applied
// to the full handshake. The cancellation allows consumers to bail out early when dialing a connection if it's no
// longer required. This is done in lock because it accesses the shared cancelConnectContext field.
//
// dialCtx is equal to handshakeCtx if connectTimeoutMS=0. Otherwise, it is derived from handshakeCtx so the
// cancellation still applies but with an added timeout to ensure the connectTimeoutMS option is applied to socket
// establishment and the TLS handshake as a whole. This is created outside of the connectContextMutex lock to avoid
// holding the lock longer than necessary.
c.connectContextMutex.Lock()
ctx, c.cancelConnectContext = context.WithCancel(ctx)
var handshakeCtx context.Context
handshakeCtx, c.cancelConnectContext = context.WithCancel(ctx)
c.connectContextMutex.Unlock()

dialCtx := handshakeCtx
var dialCancel context.CancelFunc
if c.config.connectTimeout != 0 {
dialCtx, dialCancel = context.WithTimeout(handshakeCtx, c.config.connectTimeout)
defer dialCancel()
}

defer func() {
var cancelFn context.CancelFunc

Expand All @@ -126,7 +144,7 @@ func (c *connection) connect(ctx context.Context) {
// Assign the result of DialContext to a temporary net.Conn to ensure that c.nc is not set in an error case.
var err error
var tempNc net.Conn
tempNc, err = c.config.dialer.DialContext(ctx, c.addr.Network(), c.addr.String())
tempNc, err = c.config.dialer.DialContext(dialCtx, c.addr.Network(), c.addr.String())
if err != nil {
c.processInitializationError(err)
return
Expand All @@ -142,7 +160,7 @@ func (c *connection) connect(ctx context.Context) {
Cache: c.config.ocspCache,
DisableEndpointChecking: c.config.disableOCSPEndpointCheck,
}
tlsNc, err := configureTLS(ctx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts)
tlsNc, err := configureTLS(dialCtx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts)
if err != nil {
c.processInitializationError(err)
return
Expand All @@ -160,10 +178,10 @@ func (c *connection) connect(ctx context.Context) {

handshakeStartTime := time.Now()
handshakeConn := initConnection{c}
c.desc, err = handshaker.GetDescription(ctx, c.addr, handshakeConn)
c.desc, err = handshaker.GetDescription(handshakeCtx, c.addr, handshakeConn)
if err == nil {
c.isMasterRTT = time.Since(handshakeStartTime)
err = handshaker.FinishHandshake(ctx, handshakeConn)
err = handshaker.FinishHandshake(handshakeCtx, handshakeConn)
}
if err != nil {
c.processInitializationError(err)
Expand Down
2 changes: 1 addition & 1 deletion x/mongo/driver/topology/connection_options.go
Expand Up @@ -69,7 +69,7 @@ func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) {
}

if cfg.dialer == nil {
cfg.dialer = &net.Dialer{Timeout: cfg.connectTimeout}
cfg.dialer = &net.Dialer{}
}

return cfg, nil
Expand Down
160 changes: 159 additions & 1 deletion x/mongo/driver/topology/connection_test.go
Expand Up @@ -195,7 +195,7 @@ func TestConnection(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var sentCfg *tls.Config
var testTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) *tls.Conn {
var testTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) tlsConn {
sentCfg = cfg
return tls.Client(nc, cfg)
}
Expand Down Expand Up @@ -228,6 +228,143 @@ func TestConnection(t *testing.T) {
}
})
})
t.Run("connectTimeout is applied correctly", func(t *testing.T) {
testCases := []struct {
name string
contextTimeout time.Duration
connectTimeout time.Duration
maxConnectTime time.Duration
}{
// The timeout to dial a connection should be min(context timeout, connectTimeoutMS), so 1ms for
// both of the tests declared below. Both tests also specify a 10ms max connect time to provide
// a large buffer for lag and avoid test flakiness.

{"context timeout is lower", 1 * time.Millisecond, 100 * time.Millisecond, 10 * time.Millisecond},
{"connect timeout is lower", 100 * time.Millisecond, 1 * time.Millisecond, 10 * time.Millisecond},
}

for _, tc := range testCases {
t.Run("timeout applied to socket establishment: "+tc.name, func(t *testing.T) {
// Ensure the initial connection dial can be timed out and the connection propagates the error
// from the dialer in this case.

connOpts := []ConnectionOption{
WithDialer(func(Dialer) Dialer {
return DialerFunc(func(ctx context.Context, _, _ string) (net.Conn, error) {
<-ctx.Done()
return nil, ctx.Err()
})
}),
WithConnectTimeout(func(time.Duration) time.Duration {
return tc.connectTimeout
}),
}
conn, err := newConnection("", connOpts...)
assert.Nil(t, err, "newConnection error: %v", err)

ctx, cancel := context.WithTimeout(context.Background(), tc.contextTimeout)
defer cancel()
var connectErr error
callback := func() {
conn.connect(ctx)
connectErr = conn.wait()
}
assert.Soon(t, callback, tc.maxConnectTime)

ce, ok := connectErr.(ConnectionError)
assert.True(t, ok, "expected error %v to be of type %T", connectErr, ConnectionError{})
assert.Equal(t, context.DeadlineExceeded, ce.Unwrap(), "expected wrapped error to be %v, got %v",
context.DeadlineExceeded, ce.Unwrap())
})
t.Run("timeout applied to TLS handshake: "+tc.name, func(t *testing.T) {
// Ensure the TLS handshake can be timed out and the connection propagates the error from the
// tlsConn in this case.

var hangingTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) tlsConn {
tlsConn := tls.Client(nc, cfg)
return newHangingTLSConn(tlsConn, tc.maxConnectTime)
}

connOpts := []ConnectionOption{
WithConnectTimeout(func(time.Duration) time.Duration {
return tc.connectTimeout
}),
WithDialer(func(Dialer) Dialer {
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
return &net.TCPConn{}, nil
})
}),
WithTLSConfig(func(*tls.Config) *tls.Config {
return &tls.Config{}
}),
withTLSConnectionSource(func(tlsConnectionSource) tlsConnectionSource {
return hangingTLSConnectionSource
}),
}
conn, err := newConnection("", connOpts...)
assert.Nil(t, err, "newConnection error: %v", err)

ctx, cancel := context.WithTimeout(context.Background(), tc.contextTimeout)
defer cancel()
var connectErr error
callback := func() {
conn.connect(ctx)
connectErr = conn.wait()
}
assert.Soon(t, callback, tc.maxConnectTime)

ce, ok := connectErr.(ConnectionError)
assert.True(t, ok, "expected error %v to be of type %T", connectErr, ConnectionError{})
assert.Equal(t, context.DeadlineExceeded, ce.Unwrap(), "expected wrapped error to be %v, got %v",
context.DeadlineExceeded, ce.Unwrap())
})
t.Run("timeout is not applied to handshaker: "+tc.name, func(t *testing.T) {
// Ensure that no additional timeout is applied to the handshake after the connection has been
// established.

var getInfoCtx, finishCtx context.Context
handshaker := &testHandshaker{
getDescription: func(ctx context.Context, _ address.Address, _ driver.Connection) (description.Server, error) {
getInfoCtx = ctx
return description.Server{}, nil
},
finishHandshake: func(ctx context.Context, _ driver.Connection) error {
finishCtx = ctx
return nil
},
}

connOpts := []ConnectionOption{
WithConnectTimeout(func(time.Duration) time.Duration {
return tc.connectTimeout
}),
WithDialer(func(Dialer) Dialer {
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
return &net.TCPConn{}, nil
})
}),
WithHandshaker(func(Handshaker) Handshaker {
return handshaker
}),
}
conn, err := newConnection("", connOpts...)
assert.Nil(t, err, "newConnection error: %v", err)

bgCtx := context.Background()
conn.connect(bgCtx)
err = conn.wait()
assert.Nil(t, err, "connect error: %v", err)

assertNoContextTimeout := func(t *testing.T, ctx context.Context) {
t.Helper()
dl, ok := ctx.Deadline()
assert.False(t, ok, "expected context to have no deadline, but got deadline %v", dl)
}
assertNoContextTimeout(t, getInfoCtx)
assertNoContextTimeout(t, finishCtx)
})
}
})
})
t.Run("writeWireMessage", func(t *testing.T) {
t.Run("closed connection", func(t *testing.T) {
Expand Down Expand Up @@ -689,3 +826,24 @@ func (d *dialer) lenclosed() int {
defer d.Unlock()
return len(d.closed)
}

// hangingTLSConn is an implementation of tlsConn that wraps the tls.Conn type and overrides the Handshake function to
// sleep for a fixed amount of time.
type hangingTLSConn struct {
*tls.Conn
sleepTime time.Duration
}

var _ tlsConn = (*hangingTLSConn)(nil)

func newHangingTLSConn(conn *tls.Conn, sleepTime time.Duration) *hangingTLSConn {
return &hangingTLSConn{
Conn: conn,
sleepTime: sleepTime,
}
}

func (h *hangingTLSConn) Handshake() error {
time.Sleep(h.sleepTime)
return h.Conn.Handshake()
}
18 changes: 14 additions & 4 deletions x/mongo/driver/topology/tls_connection_source.go
Expand Up @@ -11,16 +11,26 @@ import (
"net"
)

type tlsConn interface {
net.Conn
Handshake() error
ConnectionState() tls.ConnectionState
}

var _ tlsConn = (*tls.Conn)(nil)

type tlsConnectionSource interface {
Client(net.Conn, *tls.Config) *tls.Conn
Client(net.Conn, *tls.Config) tlsConn
}

type tlsConnectionSourceFn func(net.Conn, *tls.Config) *tls.Conn
type tlsConnectionSourceFn func(net.Conn, *tls.Config) tlsConn

var _ tlsConnectionSource = (tlsConnectionSourceFn)(nil)

func (t tlsConnectionSourceFn) Client(nc net.Conn, cfg *tls.Config) *tls.Conn {
func (t tlsConnectionSourceFn) Client(nc net.Conn, cfg *tls.Config) tlsConn {
return t(nc, cfg)
}

var defaultTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) *tls.Conn {
var defaultTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) tlsConn {
return tls.Client(nc, cfg)
}

0 comments on commit 3cf67b9

Please sign in to comment.