Skip to content

Commit

Permalink
GODRIVER-2565 Abort transaction before CommitLoop if context errored. (
Browse files Browse the repository at this point in the history
  • Loading branch information
benjirewis committed Oct 31, 2022
1 parent e4853fb commit edfc51c
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 5 deletions.
16 changes: 16 additions & 0 deletions mongo/session.go
Expand Up @@ -206,11 +206,27 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(sessCtx Sessi
return res, err
}

// Check if callback intentionally aborted and, if so, return immediately
// with no error.
err = s.clientSession.CheckAbortTransaction()
if err != nil {
return res, nil
}

// If context has errored, run AbortTransaction and return, as the CommitLoop
// has no chance of succeeding.
//
// Aborting after a failed CommitTransaction is dangerous. Failed transaction
// commits may unpin the session server-side, and subsequent transaction aborts
// may run on a new mongos which could end up with commit and abort being executed
// simultaneously.
if ctx.Err() != nil {
// Wrap the user-provided Context in a new one that behaves like context.Background() for deadlines and
// cancellations, but forwards Value requests to the original one.
_ = s.AbortTransaction(internal.NewBackgroundContext(ctx))
return nil, ctx.Err()
}

CommitLoop:
for {
err = s.CommitTransaction(ctx)
Expand Down
44 changes: 39 additions & 5 deletions mongo/with_transactions_test.go
Expand Up @@ -348,9 +348,36 @@ func TestConvenientTransactions(t *testing.T) {
return nil
})
})
t.Run("commitTransaction timeout does not retry", func(t *testing.T) {
t.Run("context error before commitTransaction does not retry and aborts", func(t *testing.T) {
withTransactionTimeout = 2 * time.Second

// Create a special CommandMonitor that only records information about abortTransaction events.
var abortStarted []*event.CommandStartedEvent
var abortSucceeded []*event.CommandSucceededEvent
var abortFailed []*event.CommandFailedEvent
monitor := &event.CommandMonitor{
Started: func(ctx context.Context, evt *event.CommandStartedEvent) {
if evt.CommandName == "abortTransaction" {
abortStarted = append(abortStarted, evt)
}
},
Succeeded: func(_ context.Context, evt *event.CommandSucceededEvent) {
if evt.CommandName == "abortTransaction" {
abortSucceeded = append(abortSucceeded, evt)
}
},
Failed: func(_ context.Context, evt *event.CommandFailedEvent) {
if evt.CommandName == "abortTransaction" {
abortFailed = append(abortFailed, evt)
}
},
}

// Set up a new Client using the command monitor defined above get a handle to a collection. The collection
// needs to be explicitly created on the server because implicit collection creation is not allowed in
// transactions for server versions <= 4.2.
client := setupConvenientTransactions(t, options.Client().SetMonitor(monitor))
db := client.Database("foo")
coll := db.Collection("test")
// Explicitly create the collection on server because implicit collection creation is not allowed in
// transactions for server versions <= 4.2.
Expand All @@ -377,7 +404,8 @@ func TestConvenientTransactions(t *testing.T) {
}
}()

// Insert a document within a session and manually cancel context.
// Insert a document within a session and manually cancel context before
// "commitTransaction" can be sent.
callback := func(ctx context.Context) {
transactionCtx, cancel := context.WithCancel(ctx)

Expand All @@ -391,6 +419,12 @@ func TestConvenientTransactions(t *testing.T) {

// Assert that transaction is canceled within 500ms and not 2 seconds.
assert.Soon(t, callback, 500*time.Millisecond)

// Assert that AbortTransaction was started once and succeeded.
assert.Equal(t, 1, len(abortStarted), "expected 1 abortTransaction started event, got %d", len(abortStarted))
assert.Equal(t, 1, len(abortSucceeded), "expected 1 abortTransaction succeeded event, got %d",
len(abortSucceeded))
assert.Equal(t, 0, len(abortFailed), "expected 0 abortTransaction failed events, got %d", len(abortFailed))
})
t.Run("wrapped transient transaction error retried", func(t *testing.T) {
sess, err := client.StartSession()
Expand All @@ -416,7 +450,7 @@ func TestConvenientTransactions(t *testing.T) {
assert.True(t, ok, "expected result type %T, got %T", false, res)
assert.False(t, resBool, "expected result false, got %v", resBool)
})
t.Run("expired context before commitTransaction does not retry", func(t *testing.T) {
t.Run("expired context before callback does not retry", func(t *testing.T) {
withTransactionTimeout = 2 * time.Second

coll := db.Collection("test")
Expand Down Expand Up @@ -446,7 +480,7 @@ func TestConvenientTransactions(t *testing.T) {
// Assert that transaction fails within 500ms and not 2 seconds.
assert.Soon(t, callback, 500*time.Millisecond)
})
t.Run("canceled context before commitTransaction does not retry", func(t *testing.T) {
t.Run("canceled context before callback does not retry", func(t *testing.T) {
withTransactionTimeout = 2 * time.Second

coll := db.Collection("test")
Expand Down Expand Up @@ -476,7 +510,7 @@ func TestConvenientTransactions(t *testing.T) {
// Assert that transaction fails within 500ms and not 2 seconds.
assert.Soon(t, callback, 500*time.Millisecond)
})
t.Run("slow operation before commitTransaction retries", func(t *testing.T) {
t.Run("slow operation in callback retries", func(t *testing.T) {
withTransactionTimeout = 2 * time.Second

coll := db.Collection("test")
Expand Down

0 comments on commit edfc51c

Please sign in to comment.