diff --git a/mongo/session.go b/mongo/session.go index a4f18baf01..6556b93c6a 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -200,10 +200,6 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(sessCtx Sessi default: } - // End if context has timed out or been canceled, as retrying has no chance of success. - if ctx.Err() != nil { - return res, err - } if errorHasLabel(err, driver.TransientTransactionError) { continue } @@ -218,10 +214,9 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(sessCtx Sessi CommitLoop: for { err = s.CommitTransaction(ctx) - // End when error is nil (transaction has been committed), or when context has timed out or been - // canceled, as retrying has no chance of success. - if err == nil || ctx.Err() != nil { - return res, err + // End when error is nil, as transaction has been committed. + if err == nil { + return res, nil } select { diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 699c63cfb9..2b80fc8d52 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -531,15 +531,17 @@ func (op Operation) Execute(ctx context.Context, scratch []byte) error { serviceID: startedInfo.serviceID, } - // Check if there's enough time to perform a round trip before the Context deadline. If ctx is - // a Timeout Context, use the 90th percentile RTT as a threshold. Otherwise, use the minimum observed - // RTT. - if deadline, ok := ctx.Deadline(); ok { + // Check for possible context error. If no context error, check if there's enough time to perform a + // round trip before the Context deadline. If ctx is a Timeout Context, use the 90th percentile RTT + // as a threshold. Otherwise, use the minimum observed RTT. + if ctx.Err() != nil { + err = ctx.Err() + } else if deadline, ok := ctx.Deadline(); ok { if internal.IsTimeoutContext(ctx) && time.Now().Add(srvr.RTTMonitor().P90()).After(deadline) { err = internal.WrapErrorf(ErrDeadlineWouldBeExceeded, "remaining time %v until context deadline is less than 90th percentile RTT\n%v", time.Until(deadline), srvr.RTTMonitor().Stats()) } else if time.Now().Add(srvr.RTTMonitor().Min()).After(deadline) { - err = op.networkError(context.DeadlineExceeded) + err = context.DeadlineExceeded } } diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index 2734888328..9b8609af9a 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -675,6 +675,48 @@ func TestOperation(t *testing.T) { assert.Nil(t, err, "ExecuteExhaust error: %v", err) assert.True(t, conn.CurrentlyStreaming(), "expected CurrentlyStreaming to be true") }) + t.Run("context deadline exceeded not marked as TransientTransactionError", func(t *testing.T) { + conn := new(mockConnection) + // Create a context that's already timed out. + ctx, cancel := context.WithDeadline(context.Background(), time.Unix(893934480, 0)) + defer cancel() + + op := Operation{ + Database: "foobar", + Deployment: SingleConnectionDeployment{C: conn}, + CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) { + dst = bsoncore.AppendInt32Element(dst, "ping", 1) + return dst, nil + }, + } + + err := op.Execute(ctx, nil) + assert.NotNil(t, err, "expected an error from Execute(), got nil") + // Assert that error is just context deadline exceeded and is therefore not a driver.Error marked + // with the TransientTransactionError label. + assert.Equal(t, err, context.DeadlineExceeded, "expected context.DeadlineExceeded error, got %v", err) + }) + t.Run("canceled context not marked as TransientTransactionError", func(t *testing.T) { + conn := new(mockConnection) + // Create a context and cancel it immediately. + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + op := Operation{ + Database: "foobar", + Deployment: SingleConnectionDeployment{C: conn}, + CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) { + dst = bsoncore.AppendInt32Element(dst, "ping", 1) + return dst, nil + }, + } + + err := op.Execute(ctx, nil) + assert.NotNil(t, err, "expected an error from Execute(), got nil") + // Assert that error is just context canceled and is therefore not a driver.Error marked with + // the TransientTransactionError label. + assert.Equal(t, err, context.Canceled, "expected context.Canceled error, got %v", err) + }) } func createExhaustServerResponse(response bsoncore.Document, moreToCome bool) []byte { diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 02d22c7504..34cb6c9579 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -331,11 +331,6 @@ func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error { if atomic.LoadInt64(&c.state) != connConnected { return ConnectionError{ConnectionID: c.id, message: "connection is closed"} } - select { - case <-ctx.Done(): - return ConnectionError{ConnectionID: c.id, Wrapped: ctx.Err(), message: "failed to write"} - default: - } var deadline time.Time if c.writeTimeout != 0 { @@ -388,14 +383,6 @@ func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, e return dst, ConnectionError{ConnectionID: c.id, message: "connection is closed"} } - select { - case <-ctx.Done(): - // We closeConnection the connection because we don't know if there is an unread message on the wire. - c.close() - return nil, ConnectionError{ConnectionID: c.id, Wrapped: ctx.Err(), message: "failed to read"} - default: - } - var deadline time.Time if c.readTimeout != 0 { deadline = time.Now().Add(c.readTimeout) diff --git a/x/mongo/driver/topology/connection_errors_test.go b/x/mongo/driver/topology/connection_errors_test.go index 27b4597518..66ebbcb8e1 100644 --- a/x/mongo/driver/topology/connection_errors_test.go +++ b/x/mongo/driver/topology/connection_errors_test.go @@ -50,19 +50,5 @@ func TestConnectionErrors(t *testing.T) { err := conn.connect(ctx) assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err) }) - t.Run("write error", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected} - err := conn.writeWireMessage(ctx, []byte{}) - assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err) - }) - t.Run("read error", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected} - _, err := conn.readWireMessage(ctx, []byte{}) - assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err) - }) }) } diff --git a/x/mongo/driver/topology/connection_test.go b/x/mongo/driver/topology/connection_test.go index 43c486a0ce..0989c6af0c 100644 --- a/x/mongo/driver/topology/connection_test.go +++ b/x/mongo/driver/topology/connection_test.go @@ -351,16 +351,6 @@ func TestConnection(t *testing.T) { t.Errorf("errors do not match. got %v; want %v", got, want) } }) - t.Run("completed context", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected} - want := ConnectionError{ConnectionID: "foobar", Wrapped: ctx.Err(), message: "failed to write"} - got := conn.writeWireMessage(ctx, []byte{}) - if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { - t.Errorf("errors do not match. got %v; want %v", got, want) - } - }) t.Run("deadlines", func(t *testing.T) { testCases := []struct { name string @@ -490,16 +480,6 @@ func TestConnection(t *testing.T) { t.Errorf("errors do not match. got %v; want %v", got, want) } }) - t.Run("completed context", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected} - want := ConnectionError{ConnectionID: "foobar", Wrapped: ctx.Err(), message: "failed to read"} - _, got := conn.readWireMessage(ctx, []byte{}) - if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { - t.Errorf("errors do not match. got %v; want %v", got, want) - } - }) t.Run("deadlines", func(t *testing.T) { testCases := []struct { name string