Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GODRIVER-1879 Apply connectTimeoutMS to TLS handshake #594

Merged
merged 3 commits into from Mar 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
28 changes: 23 additions & 5 deletions x/mongo/driver/topology/connection.go
Expand Up @@ -115,10 +115,28 @@ func (c *connection) connect(ctx context.Context) {
}
defer close(c.connectDone)

// Create separate contexts for dialing a connection and doing the MongoDB/auth handshakes.
//
// handshakeCtx is simply a cancellable version of ctx because there's no default timeout that needs to be applied
// to the full handshake. The cancellation allows consumers to bail out early when dialing a connection if it's no
// longer required. This is done in lock because it accesses the shared cancelConnectContext field.
//
// dialCtx is equal to handshakeCtx if connectTimeoutMS=0. Otherwise, it is derived from handshakeCtx so the
// cancellation still applies but with an added timeout to ensure the connectTimeoutMS option is applied to socket
// establishment and the TLS handshake as a whole. This is created outside of the connectContextMutex lock to avoid
// holding the lock longer than necessary.
divjotarora marked this conversation as resolved.
Show resolved Hide resolved
c.connectContextMutex.Lock()
ctx, c.cancelConnectContext = context.WithCancel(ctx)
var handshakeCtx context.Context
handshakeCtx, c.cancelConnectContext = context.WithCancel(ctx)
c.connectContextMutex.Unlock()

dialCtx := handshakeCtx
var dialCancel context.CancelFunc
if c.config.connectTimeout != 0 {
dialCtx, dialCancel = context.WithTimeout(handshakeCtx, c.config.connectTimeout)
defer dialCancel()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to change. But IIUC, since this derived from handshakeCtx, it is not strictly necessary to cancel the dialCtx. Cancelling handshakeCtx would cancel any derived contexts. Is that right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have the same understanding. However, AFAIK it's good practice to defer all context.CancelFunc variables and this is happening once per connection handshake, so I don't think it's going to be prohibitive to do that here.

}

defer func() {
var cancelFn context.CancelFunc

Expand All @@ -137,7 +155,7 @@ func (c *connection) connect(ctx context.Context) {
// Assign the result of DialContext to a temporary net.Conn to ensure that c.nc is not set in an error case.
var err error
var tempNc net.Conn
tempNc, err = c.config.dialer.DialContext(ctx, c.addr.Network(), c.addr.String())
tempNc, err = c.config.dialer.DialContext(dialCtx, c.addr.Network(), c.addr.String())
divjotarora marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
c.processInitializationError(err)
return
Expand All @@ -153,7 +171,7 @@ func (c *connection) connect(ctx context.Context) {
Cache: c.config.ocspCache,
DisableEndpointChecking: c.config.disableOCSPEndpointCheck,
}
tlsNc, err := configureTLS(ctx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts)
tlsNc, err := configureTLS(dialCtx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts)
if err != nil {
c.processInitializationError(err)
return
Expand All @@ -179,13 +197,13 @@ func (c *connection) connect(ctx context.Context) {
var handshakeInfo driver.HandshakeInformation
handshakeStartTime := time.Now()
handshakeConn := initConnection{c}
handshakeInfo, err = handshaker.GetHandshakeInformation(ctx, c.addr, handshakeConn)
handshakeInfo, err = handshaker.GetHandshakeInformation(handshakeCtx, c.addr, handshakeConn)
if err == nil {
// We only need to retain the Description field as the connection's description. The authentication-related
// fields in handshakeInfo are tracked by the handshaker if necessary.
c.desc = handshakeInfo.Description
c.isMasterRTT = time.Since(handshakeStartTime)
err = handshaker.FinishHandshake(ctx, handshakeConn)
err = handshaker.FinishHandshake(handshakeCtx, handshakeConn)
}

// We have a failed handshake here
Expand Down
2 changes: 1 addition & 1 deletion x/mongo/driver/topology/connection_options.go
Expand Up @@ -70,7 +70,7 @@ func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) {
}

if cfg.dialer == nil {
cfg.dialer = &net.Dialer{Timeout: cfg.connectTimeout}
cfg.dialer = &net.Dialer{}
benjirewis marked this conversation as resolved.
Show resolved Hide resolved
}

return cfg, nil
Expand Down
160 changes: 159 additions & 1 deletion x/mongo/driver/topology/connection_test.go
Expand Up @@ -219,7 +219,7 @@ func TestConnection(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var sentCfg *tls.Config
var testTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) *tls.Conn {
var testTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) tlsConn {
sentCfg = cfg
return tls.Client(nc, cfg)
}
Expand Down Expand Up @@ -252,6 +252,143 @@ func TestConnection(t *testing.T) {
}
})
})
t.Run("connectTimeout is applied correctly", func(t *testing.T) {
testCases := []struct {
name string
contextTimeout time.Duration
connectTimeout time.Duration
maxConnectTime time.Duration
}{
// The timeout to dial a connection should be min(context timeout, connectTimeoutMS), so 1ms for
// both of the tests declared below. Both tests also specify a 10ms max connect time to provide
// a large buffer for lag and avoid test flakiness.

{"context timeout is lower", 1 * time.Millisecond, 100 * time.Millisecond, 10 * time.Millisecond},
{"connect timeout is lower", 100 * time.Millisecond, 1 * time.Millisecond, 10 * time.Millisecond},
}

for _, tc := range testCases {
t.Run("timeout applied to socket establishment: "+tc.name, func(t *testing.T) {
// Ensure the initial connection dial can be timed out and the connection propagates the error
// from the dialer in this case.

connOpts := []ConnectionOption{
WithDialer(func(Dialer) Dialer {
return DialerFunc(func(ctx context.Context, _, _ string) (net.Conn, error) {
<-ctx.Done()
return nil, ctx.Err()
})
}),
WithConnectTimeout(func(time.Duration) time.Duration {
return tc.connectTimeout
}),
}
conn, err := newConnection("", connOpts...)
assert.Nil(t, err, "newConnection error: %v", err)

ctx, cancel := context.WithTimeout(context.Background(), tc.contextTimeout)
defer cancel()
var connectErr error
callback := func() {
conn.connect(ctx)
connectErr = conn.wait()
}
assert.Soon(t, callback, tc.maxConnectTime)

ce, ok := connectErr.(ConnectionError)
assert.True(t, ok, "expected error %v to be of type %T", connectErr, ConnectionError{})
assert.Equal(t, context.DeadlineExceeded, ce.Unwrap(), "expected wrapped error to be %v, got %v",
context.DeadlineExceeded, ce.Unwrap())
})
t.Run("timeout applied to TLS handshake: "+tc.name, func(t *testing.T) {
// Ensure the TLS handshake can be timed out and the connection propagates the error from the
// tlsConn in this case.

var hangingTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) tlsConn {
tlsConn := tls.Client(nc, cfg)
return newHangingTLSConn(tlsConn, tc.maxConnectTime)
}

connOpts := []ConnectionOption{
WithConnectTimeout(func(time.Duration) time.Duration {
return tc.connectTimeout
}),
WithDialer(func(Dialer) Dialer {
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
return &net.TCPConn{}, nil
})
}),
WithTLSConfig(func(*tls.Config) *tls.Config {
return &tls.Config{}
}),
withTLSConnectionSource(func(tlsConnectionSource) tlsConnectionSource {
return hangingTLSConnectionSource
}),
}
conn, err := newConnection("", connOpts...)
assert.Nil(t, err, "newConnection error: %v", err)

ctx, cancel := context.WithTimeout(context.Background(), tc.contextTimeout)
defer cancel()
var connectErr error
callback := func() {
conn.connect(ctx)
connectErr = conn.wait()
}
assert.Soon(t, callback, tc.maxConnectTime)

ce, ok := connectErr.(ConnectionError)
assert.True(t, ok, "expected error %v to be of type %T", connectErr, ConnectionError{})
assert.Equal(t, context.DeadlineExceeded, ce.Unwrap(), "expected wrapped error to be %v, got %v",
context.DeadlineExceeded, ce.Unwrap())
})
t.Run("timeout is not applied to handshaker: "+tc.name, func(t *testing.T) {
iwysiu marked this conversation as resolved.
Show resolved Hide resolved
// Ensure that no additional timeout is applied to the handshake after the connection has been
// established.

var getInfoCtx, finishCtx context.Context
handshaker := &testHandshaker{
getHandshakeInformation: func(ctx context.Context, _ address.Address, _ driver.Connection) (driver.HandshakeInformation, error) {
getInfoCtx = ctx
return driver.HandshakeInformation{}, nil
},
finishHandshake: func(ctx context.Context, _ driver.Connection) error {
finishCtx = ctx
return nil
},
}

connOpts := []ConnectionOption{
WithConnectTimeout(func(time.Duration) time.Duration {
return tc.connectTimeout
}),
WithDialer(func(Dialer) Dialer {
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
return &net.TCPConn{}, nil
})
}),
WithHandshaker(func(Handshaker) Handshaker {
return handshaker
}),
}
conn, err := newConnection("", connOpts...)
assert.Nil(t, err, "newConnection error: %v", err)

bgCtx := context.Background()
conn.connect(bgCtx)
err = conn.wait()
assert.Nil(t, err, "connect error: %v", err)

assertNoContextTimeout := func(t *testing.T, ctx context.Context) {
t.Helper()
dl, ok := ctx.Deadline()
assert.False(t, ok, "expected context to have no deadline, but got deadline %v", dl)
}
assertNoContextTimeout(t, getInfoCtx)
assertNoContextTimeout(t, finishCtx)
})
}
})
})
t.Run("writeWireMessage", func(t *testing.T) {
t.Run("closed connection", func(t *testing.T) {
Expand Down Expand Up @@ -993,3 +1130,24 @@ func (t *testCancellationListener) assertMethodsCalled(testingT *testing.T, numL
assert.Equal(testingT, numStopListening, t.numStopListening, "expected StopListening to be called %d times, got %d",
numListen, t.numListen)
}

// hangingTLSConn is an implementation of tlsConn that wraps the tls.Conn type and overrides the Handshake function to
divjotarora marked this conversation as resolved.
Show resolved Hide resolved
// sleep for a fixed amount of time.
type hangingTLSConn struct {
*tls.Conn
sleepTime time.Duration
}

var _ tlsConn = (*hangingTLSConn)(nil)
benjirewis marked this conversation as resolved.
Show resolved Hide resolved

func newHangingTLSConn(conn *tls.Conn, sleepTime time.Duration) *hangingTLSConn {
return &hangingTLSConn{
Conn: conn,
sleepTime: sleepTime,
}
}

func (h *hangingTLSConn) Handshake() error {
time.Sleep(h.sleepTime)
return h.Conn.Handshake()
}
18 changes: 14 additions & 4 deletions x/mongo/driver/topology/tls_connection_source.go
Expand Up @@ -11,16 +11,26 @@ import (
"net"
)

type tlsConn interface {
net.Conn
Handshake() error
ConnectionState() tls.ConnectionState
}

var _ tlsConn = (*tls.Conn)(nil)

type tlsConnectionSource interface {
Client(net.Conn, *tls.Config) *tls.Conn
Client(net.Conn, *tls.Config) tlsConn
}

type tlsConnectionSourceFn func(net.Conn, *tls.Config) *tls.Conn
type tlsConnectionSourceFn func(net.Conn, *tls.Config) tlsConn

var _ tlsConnectionSource = (tlsConnectionSourceFn)(nil)

func (t tlsConnectionSourceFn) Client(nc net.Conn, cfg *tls.Config) *tls.Conn {
func (t tlsConnectionSourceFn) Client(nc net.Conn, cfg *tls.Config) tlsConn {
return t(nc, cfg)
}

var defaultTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) *tls.Conn {
var defaultTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) tlsConn {
return tls.Client(nc, cfg)
}