From 69f4b244212346954612af56e51a3284c5622ff1 Mon Sep 17 00:00:00 2001 From: Benjamin Rewis Date: Wed, 10 Aug 2022 14:50:53 -0400 Subject: [PATCH 1/4] Pre-write context expiration is not network error --- mongo/session.go | 11 +++------- x/mongo/driver/operation.go | 12 ++++++----- x/mongo/driver/topology/connection.go | 13 ------------ .../driver/topology/connection_errors_test.go | 14 ------------- x/mongo/driver/topology/connection_test.go | 20 ------------------- 5 files changed, 10 insertions(+), 60 deletions(-) diff --git a/mongo/session.go b/mongo/session.go index a4f18baf01..6556b93c6a 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -200,10 +200,6 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(sessCtx Sessi default: } - // End if context has timed out or been canceled, as retrying has no chance of success. - if ctx.Err() != nil { - return res, err - } if errorHasLabel(err, driver.TransientTransactionError) { continue } @@ -218,10 +214,9 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(sessCtx Sessi CommitLoop: for { err = s.CommitTransaction(ctx) - // End when error is nil (transaction has been committed), or when context has timed out or been - // canceled, as retrying has no chance of success. - if err == nil || ctx.Err() != nil { - return res, err + // End when error is nil, as transaction has been committed. + if err == nil { + return res, nil } select { diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 699c63cfb9..2b80fc8d52 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -531,15 +531,17 @@ func (op Operation) Execute(ctx context.Context, scratch []byte) error { serviceID: startedInfo.serviceID, } - // Check if there's enough time to perform a round trip before the Context deadline. If ctx is - // a Timeout Context, use the 90th percentile RTT as a threshold. Otherwise, use the minimum observed - // RTT. - if deadline, ok := ctx.Deadline(); ok { + // Check for possible context error. If no context error, check if there's enough time to perform a + // round trip before the Context deadline. If ctx is a Timeout Context, use the 90th percentile RTT + // as a threshold. Otherwise, use the minimum observed RTT. + if ctx.Err() != nil { + err = ctx.Err() + } else if deadline, ok := ctx.Deadline(); ok { if internal.IsTimeoutContext(ctx) && time.Now().Add(srvr.RTTMonitor().P90()).After(deadline) { err = internal.WrapErrorf(ErrDeadlineWouldBeExceeded, "remaining time %v until context deadline is less than 90th percentile RTT\n%v", time.Until(deadline), srvr.RTTMonitor().Stats()) } else if time.Now().Add(srvr.RTTMonitor().Min()).After(deadline) { - err = op.networkError(context.DeadlineExceeded) + err = context.DeadlineExceeded } } diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 02d22c7504..34cb6c9579 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -331,11 +331,6 @@ func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error { if atomic.LoadInt64(&c.state) != connConnected { return ConnectionError{ConnectionID: c.id, message: "connection is closed"} } - select { - case <-ctx.Done(): - return ConnectionError{ConnectionID: c.id, Wrapped: ctx.Err(), message: "failed to write"} - default: - } var deadline time.Time if c.writeTimeout != 0 { @@ -388,14 +383,6 @@ func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, e return dst, ConnectionError{ConnectionID: c.id, message: "connection is closed"} } - select { - case <-ctx.Done(): - // We closeConnection the connection because we don't know if there is an unread message on the wire. - c.close() - return nil, ConnectionError{ConnectionID: c.id, Wrapped: ctx.Err(), message: "failed to read"} - default: - } - var deadline time.Time if c.readTimeout != 0 { deadline = time.Now().Add(c.readTimeout) diff --git a/x/mongo/driver/topology/connection_errors_test.go b/x/mongo/driver/topology/connection_errors_test.go index 27b4597518..66ebbcb8e1 100644 --- a/x/mongo/driver/topology/connection_errors_test.go +++ b/x/mongo/driver/topology/connection_errors_test.go @@ -50,19 +50,5 @@ func TestConnectionErrors(t *testing.T) { err := conn.connect(ctx) assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err) }) - t.Run("write error", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected} - err := conn.writeWireMessage(ctx, []byte{}) - assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err) - }) - t.Run("read error", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected} - _, err := conn.readWireMessage(ctx, []byte{}) - assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err) - }) }) } diff --git a/x/mongo/driver/topology/connection_test.go b/x/mongo/driver/topology/connection_test.go index 43c486a0ce..0989c6af0c 100644 --- a/x/mongo/driver/topology/connection_test.go +++ b/x/mongo/driver/topology/connection_test.go @@ -351,16 +351,6 @@ func TestConnection(t *testing.T) { t.Errorf("errors do not match. got %v; want %v", got, want) } }) - t.Run("completed context", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected} - want := ConnectionError{ConnectionID: "foobar", Wrapped: ctx.Err(), message: "failed to write"} - got := conn.writeWireMessage(ctx, []byte{}) - if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { - t.Errorf("errors do not match. got %v; want %v", got, want) - } - }) t.Run("deadlines", func(t *testing.T) { testCases := []struct { name string @@ -490,16 +480,6 @@ func TestConnection(t *testing.T) { t.Errorf("errors do not match. got %v; want %v", got, want) } }) - t.Run("completed context", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected} - want := ConnectionError{ConnectionID: "foobar", Wrapped: ctx.Err(), message: "failed to read"} - _, got := conn.readWireMessage(ctx, []byte{}) - if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { - t.Errorf("errors do not match. got %v; want %v", got, want) - } - }) t.Run("deadlines", func(t *testing.T) { testCases := []struct { name string From 818236757bde935c21167ec31c64665cc0fd015a Mon Sep 17 00:00:00 2001 From: Benjamin Rewis Date: Wed, 10 Aug 2022 15:44:54 -0400 Subject: [PATCH 2/4] Add Execute tests --- x/mongo/driver/operation_test.go | 46 ++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index 2734888328..cb6690a780 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -675,6 +675,52 @@ func TestOperation(t *testing.T) { assert.Nil(t, err, "ExecuteExhaust error: %v", err) assert.True(t, conn.CurrentlyStreaming(), "expected CurrentlyStreaming to be true") }) + t.Run("context deadline exceeded not marked as TransientTransactionError", func(t *testing.T) { + conn := new(mockConnection) + // Create a context with a very short timeout. + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) + defer cancel() + + op := Operation{ + Database: "foobar", + Deployment: SingleConnectionDeployment{C: conn}, + CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) { + dst = bsoncore.AppendInt32Element(dst, "ping", 1) + return dst, nil + }, + } + + err := op.Execute(ctx, nil) + assert.NotNil(t, err, "expected an error from Execute(), got nil") + // Assert that error is just context deadline exceeded and is therefore not a driver.Error marked + // with the TransientTransactionError label. + if !cmp.Equal(err, context.DeadlineExceeded, cmp.Comparer(compareErrors)) { + t.Errorf("err is not equal to expected error. got %v; want %v", err, context.DeadlineExceeded) + } + }) + t.Run("canceled context not marked as TransientTransactionError", func(t *testing.T) { + conn := new(mockConnection) + // Create a context and cancel it immediately. + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + op := Operation{ + Database: "foobar", + Deployment: SingleConnectionDeployment{C: conn}, + CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) { + dst = bsoncore.AppendInt32Element(dst, "ping", 1) + return dst, nil + }, + } + + err := op.Execute(ctx, nil) + assert.NotNil(t, err, "expected an error from Execute(), got nil") + // Assert that error is just context canceled and is therefore not a driver.Error marked with + // the TransientTransactionError label. + if !cmp.Equal(err, context.Canceled, cmp.Comparer(compareErrors)) { + t.Errorf("err is not equal to expected error. got %v; want %v", err, context.Canceled) + } + }) } func createExhaustServerResponse(response bsoncore.Document, moreToCome bool) []byte { From 144772d0711fccf52426f0a2ac121dd3298d93e2 Mon Sep 17 00:00:00 2001 From: Benjamin Rewis Date: Thu, 11 Aug 2022 10:57:25 -0400 Subject: [PATCH 3/4] Use already timed out context instead of short timeout --- x/mongo/driver/operation_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index cb6690a780..632c20ff80 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -677,8 +677,8 @@ func TestOperation(t *testing.T) { }) t.Run("context deadline exceeded not marked as TransientTransactionError", func(t *testing.T) { conn := new(mockConnection) - // Create a context with a very short timeout. - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) + // Create a context that's already timed out. + ctx, cancel := context.WithDeadline(context.Background(), time.Unix(893934480, 0)) defer cancel() op := Operation{ From 65fc24a2184c9f1fb082d632451d1999bd22261c Mon Sep 17 00:00:00 2001 From: Benjamin Rewis Date: Wed, 17 Aug 2022 15:13:48 -0400 Subject: [PATCH 4/4] Use assert.Equal --- x/mongo/driver/operation_test.go | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index 632c20ff80..9b8609af9a 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -694,9 +694,7 @@ func TestOperation(t *testing.T) { assert.NotNil(t, err, "expected an error from Execute(), got nil") // Assert that error is just context deadline exceeded and is therefore not a driver.Error marked // with the TransientTransactionError label. - if !cmp.Equal(err, context.DeadlineExceeded, cmp.Comparer(compareErrors)) { - t.Errorf("err is not equal to expected error. got %v; want %v", err, context.DeadlineExceeded) - } + assert.Equal(t, err, context.DeadlineExceeded, "expected context.DeadlineExceeded error, got %v", err) }) t.Run("canceled context not marked as TransientTransactionError", func(t *testing.T) { conn := new(mockConnection) @@ -717,9 +715,7 @@ func TestOperation(t *testing.T) { assert.NotNil(t, err, "expected an error from Execute(), got nil") // Assert that error is just context canceled and is therefore not a driver.Error marked with // the TransientTransactionError label. - if !cmp.Equal(err, context.Canceled, cmp.Comparer(compareErrors)) { - t.Errorf("err is not equal to expected error. got %v; want %v", err, context.Canceled) - } + assert.Equal(t, err, context.Canceled, "expected context.Canceled error, got %v", err) }) }