diff --git a/mongo/session.go b/mongo/session.go index 6556b93c6a..37d5b75761 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -206,11 +206,27 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(sessCtx Sessi return res, err } + // Check if callback intentionally aborted and, if so, return immediately + // with no error. err = s.clientSession.CheckAbortTransaction() if err != nil { return res, nil } + // If context has errored, run AbortTransaction and return, as the CommitLoop + // has no chance of succeeding. + // + // Aborting after a failed CommitTransaction is dangerous. Failed transaction + // commits may unpin the session server-side, and subsequent transaction aborts + // may run on a new mongos which could end up with commit and abort being executed + // simultaneously. + if ctx.Err() != nil { + // Wrap the user-provided Context in a new one that behaves like context.Background() for deadlines and + // cancellations, but forwards Value requests to the original one. + _ = s.AbortTransaction(internal.NewBackgroundContext(ctx)) + return nil, ctx.Err() + } + CommitLoop: for { err = s.CommitTransaction(ctx) diff --git a/mongo/with_transactions_test.go b/mongo/with_transactions_test.go index 356a840fa9..6d3acfff4b 100644 --- a/mongo/with_transactions_test.go +++ b/mongo/with_transactions_test.go @@ -348,9 +348,36 @@ func TestConvenientTransactions(t *testing.T) { return nil }) }) - t.Run("commitTransaction timeout does not retry", func(t *testing.T) { + t.Run("context error before commitTransaction does not retry and aborts", func(t *testing.T) { withTransactionTimeout = 2 * time.Second + // Create a special CommandMonitor that only records information about abortTransaction events. + var abortStarted []*event.CommandStartedEvent + var abortSucceeded []*event.CommandSucceededEvent + var abortFailed []*event.CommandFailedEvent + monitor := &event.CommandMonitor{ + Started: func(ctx context.Context, evt *event.CommandStartedEvent) { + if evt.CommandName == "abortTransaction" { + abortStarted = append(abortStarted, evt) + } + }, + Succeeded: func(_ context.Context, evt *event.CommandSucceededEvent) { + if evt.CommandName == "abortTransaction" { + abortSucceeded = append(abortSucceeded, evt) + } + }, + Failed: func(_ context.Context, evt *event.CommandFailedEvent) { + if evt.CommandName == "abortTransaction" { + abortFailed = append(abortFailed, evt) + } + }, + } + + // Set up a new Client using the command monitor defined above get a handle to a collection. The collection + // needs to be explicitly created on the server because implicit collection creation is not allowed in + // transactions for server versions <= 4.2. + client := setupConvenientTransactions(t, options.Client().SetMonitor(monitor)) + db := client.Database("foo") coll := db.Collection("test") // Explicitly create the collection on server because implicit collection creation is not allowed in // transactions for server versions <= 4.2. @@ -377,7 +404,8 @@ func TestConvenientTransactions(t *testing.T) { } }() - // Insert a document within a session and manually cancel context. + // Insert a document within a session and manually cancel context before + // "commitTransaction" can be sent. callback := func(ctx context.Context) { transactionCtx, cancel := context.WithCancel(ctx) @@ -391,6 +419,12 @@ func TestConvenientTransactions(t *testing.T) { // Assert that transaction is canceled within 500ms and not 2 seconds. assert.Soon(t, callback, 500*time.Millisecond) + + // Assert that AbortTransaction was started once and succeeded. + assert.Equal(t, 1, len(abortStarted), "expected 1 abortTransaction started event, got %d", len(abortStarted)) + assert.Equal(t, 1, len(abortSucceeded), "expected 1 abortTransaction succeeded event, got %d", + len(abortSucceeded)) + assert.Equal(t, 0, len(abortFailed), "expected 0 abortTransaction failed events, got %d", len(abortFailed)) }) t.Run("wrapped transient transaction error retried", func(t *testing.T) { sess, err := client.StartSession() @@ -416,7 +450,7 @@ func TestConvenientTransactions(t *testing.T) { assert.True(t, ok, "expected result type %T, got %T", false, res) assert.False(t, resBool, "expected result false, got %v", resBool) }) - t.Run("expired context before commitTransaction does not retry", func(t *testing.T) { + t.Run("expired context before callback does not retry", func(t *testing.T) { withTransactionTimeout = 2 * time.Second coll := db.Collection("test") @@ -446,7 +480,7 @@ func TestConvenientTransactions(t *testing.T) { // Assert that transaction fails within 500ms and not 2 seconds. assert.Soon(t, callback, 500*time.Millisecond) }) - t.Run("canceled context before commitTransaction does not retry", func(t *testing.T) { + t.Run("canceled context before callback does not retry", func(t *testing.T) { withTransactionTimeout = 2 * time.Second coll := db.Collection("test") @@ -476,7 +510,7 @@ func TestConvenientTransactions(t *testing.T) { // Assert that transaction fails within 500ms and not 2 seconds. assert.Soon(t, callback, 500*time.Millisecond) }) - t.Run("slow operation before commitTransaction retries", func(t *testing.T) { + t.Run("slow operation in callback retries", func(t *testing.T) { withTransactionTimeout = 2 * time.Second coll := db.Collection("test")