Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GODRIVER-2800 Convert Session Interface to a Struct #1592

Merged
merged 11 commits into from
Apr 30, 2024
80 changes: 44 additions & 36 deletions internal/docexamples/examples.go
Expand Up @@ -1758,30 +1758,32 @@ func UpdateEmployeeInfo(ctx context.Context, client *mongo.Client) error {
employees := client.Database("hr").Collection("employees")
events := client.Database("reporting").Collection("events")

return client.UseSession(ctx, func(sctx mongo.SessionContext) error {
err := sctx.StartTransaction(options.Transaction().
return client.UseSession(ctx, func(ctx context.Context) error {
sess := mongo.SessionFromContext(ctx)

err := sess.StartTransaction(options.Transaction().
SetReadConcern(readconcern.Snapshot()).
SetWriteConcern(writeconcern.Majority()),
)
if err != nil {
return err
}

_, err = employees.UpdateOne(sctx, bson.D{{"employee", 3}}, bson.D{{"$set", bson.D{{"status", "Inactive"}}}})
_, err = employees.UpdateOne(ctx, bson.D{{"employee", 3}}, bson.D{{"$set", bson.D{{"status", "Inactive"}}}})
if err != nil {
sctx.AbortTransaction(sctx)
sess.AbortTransaction(ctx)
log.Println("caught exception during transaction, aborting.")
return err
}
_, err = events.InsertOne(sctx, bson.D{{"employee", 3}, {"status", bson.D{{"new", "Inactive"}, {"old", "Active"}}}})
_, err = events.InsertOne(ctx, bson.D{{"employee", 3}, {"status", bson.D{{"new", "Inactive"}, {"old", "Active"}}}})
if err != nil {
sctx.AbortTransaction(sctx)
sess.AbortTransaction(ctx)
log.Println("caught exception during transaction, aborting.")
return err
}

for {
err = sctx.CommitTransaction(sctx)
err = sess.CommitTransaction(ctx)
switch e := err.(type) {
case nil:
return nil
Expand All @@ -1805,9 +1807,9 @@ func UpdateEmployeeInfo(ctx context.Context, client *mongo.Client) error {
// Start Transactions Retry Example 1

// RunTransactionWithRetry is an example function demonstrating transaction retry logic.
func RunTransactionWithRetry(sctx mongo.SessionContext, txnFn func(mongo.SessionContext) error) error {
func RunTransactionWithRetry(ctx context.Context, txnFn func(context.Context) error) error {
for {
err := txnFn(sctx) // Performs transaction.
err := txnFn(ctx) // Performs transaction.
if err == nil {
return nil
}
Expand All @@ -1828,9 +1830,11 @@ func RunTransactionWithRetry(sctx mongo.SessionContext, txnFn func(mongo.Session
// Start Transactions Retry Example 2

// CommitWithRetry is an example function demonstrating transaction commit with retry logic.
func CommitWithRetry(sctx mongo.SessionContext) error {
func CommitWithRetry(ctx context.Context) error {
sess := mongo.SessionFromContext(ctx)

for {
err := sctx.CommitTransaction(sctx)
err := sess.CommitTransaction(ctx)
switch e := err.(type) {
case nil:
log.Println("Transaction committed.")
Expand Down Expand Up @@ -1872,9 +1876,9 @@ func TransactionsExamples(ctx context.Context, client *mongo.Client) error {
}
// Start Transactions Retry Example 3

runTransactionWithRetry := func(sctx mongo.SessionContext, txnFn func(mongo.SessionContext) error) error {
runTransactionWithRetry := func(ctx context.Context, txnFn func(context.Context) error) error {
for {
err := txnFn(sctx) // Performs transaction.
err := txnFn(ctx) // Performs transaction.
if err == nil {
return nil
}
Expand All @@ -1890,9 +1894,11 @@ func TransactionsExamples(ctx context.Context, client *mongo.Client) error {
}
}

commitWithRetry := func(sctx mongo.SessionContext) error {
commitWithRetry := func(ctx context.Context) error {
sess := mongo.SessionFromContext(ctx)

for {
err := sctx.CommitTransaction(sctx)
err := sess.CommitTransaction(ctx)
switch e := err.(type) {
case nil:
log.Println("Transaction committed.")
Expand All @@ -1913,38 +1919,40 @@ func TransactionsExamples(ctx context.Context, client *mongo.Client) error {
}

// Updates two collections in a transaction.
updateEmployeeInfo := func(sctx mongo.SessionContext) error {
updateEmployeeInfo := func(ctx context.Context) error {
employees := client.Database("hr").Collection("employees")
events := client.Database("reporting").Collection("events")

err := sctx.StartTransaction(options.Transaction().
sess := mongo.SessionFromContext(ctx)

err := sess.StartTransaction(options.Transaction().
SetReadConcern(readconcern.Snapshot()).
SetWriteConcern(writeconcern.Majority()),
)
if err != nil {
return err
}

_, err = employees.UpdateOne(sctx, bson.D{{"employee", 3}}, bson.D{{"$set", bson.D{{"status", "Inactive"}}}})
_, err = employees.UpdateOne(ctx, bson.D{{"employee", 3}}, bson.D{{"$set", bson.D{{"status", "Inactive"}}}})
if err != nil {
sctx.AbortTransaction(sctx)
sess.AbortTransaction(ctx)
log.Println("caught exception during transaction, aborting.")
return err
}
_, err = events.InsertOne(sctx, bson.D{{"employee", 3}, {"status", bson.D{{"new", "Inactive"}, {"old", "Active"}}}})
_, err = events.InsertOne(ctx, bson.D{{"employee", 3}, {"status", bson.D{{"new", "Inactive"}, {"old", "Active"}}}})
if err != nil {
sctx.AbortTransaction(sctx)
sess.AbortTransaction(ctx)
log.Println("caught exception during transaction, aborting.")
return err
}

return commitWithRetry(sctx)
return commitWithRetry(ctx)
}

return client.UseSessionWithOptions(
ctx, options.Session().SetDefaultReadPreference(readpref.Primary()),
func(sctx mongo.SessionContext) error {
return runTransactionWithRetry(sctx, updateEmployeeInfo)
func(ctx context.Context) error {
return runTransactionWithRetry(ctx, updateEmployeeInfo)
},
)
}
Expand Down Expand Up @@ -1976,13 +1984,13 @@ func WithTransactionExample(ctx context.Context) error {
barColl := client.Database("mydb1").Collection("bar", wcMajorityCollectionOpts)

// Step 1: Define the callback that specifies the sequence of operations to perform inside the transaction.
callback := func(sessCtx mongo.SessionContext) (interface{}, error) {
// Important: You must pass sessCtx as the Context parameter to the operations for them to be executed in the
callback := func(sesctx context.Context) (interface{}, error) {
// Important: You must pass sesctx as the Context parameter to the operations for them to be executed in the
// transaction.
if _, err := fooColl.InsertOne(sessCtx, bson.D{{"abc", 1}}); err != nil {
if _, err := fooColl.InsertOne(sesctx, bson.D{{"abc", 1}}); err != nil {
return nil, err
}
if _, err := barColl.InsertOne(sessCtx, bson.D{{"xyz", 999}}); err != nil {
if _, err := barColl.InsertOne(sesctx, bson.D{{"xyz", 999}}); err != nil {
return nil, err
}

Expand Down Expand Up @@ -2560,15 +2568,15 @@ func CausalConsistencyExamples(client *mongo.Client) error {
}
defer session1.EndSession(context.TODO())

err = client.UseSessionWithOptions(context.TODO(), opts, func(sctx mongo.SessionContext) error {
err = client.UseSessionWithOptions(context.TODO(), opts, func(ctx context.Context) error {
// Run an update with our causally-consistent session
_, err = coll.UpdateOne(sctx, bson.D{{"sku", 111}}, bson.D{{"$set", bson.D{{"end", currentDate}}}})
_, err = coll.UpdateOne(ctx, bson.D{{"sku", 111}}, bson.D{{"$set", bson.D{{"end", currentDate}}}})
if err != nil {
return err
}

// Run an insert with our causally-consistent session
_, err = coll.InsertOne(sctx, bson.D{{"sku", "nuts-111"}, {"name", "Pecans"}, {"start", currentDate}})
_, err = coll.InsertOne(ctx, bson.D{{"sku", "nuts-111"}, {"name", "Pecans"}, {"start", currentDate}})
if err != nil {
return err
}
Expand All @@ -2593,7 +2601,7 @@ func CausalConsistencyExamples(client *mongo.Client) error {
}
defer session2.EndSession(context.TODO())

err = client.UseSessionWithOptions(context.TODO(), opts, func(sctx mongo.SessionContext) error {
err = client.UseSessionWithOptions(context.TODO(), opts, func(ctx context.Context) error {
// Set cluster time of session2 to session1's cluster time
clusterTime := session1.ClusterTime()
session2.AdvanceClusterTime(clusterTime)
Expand All @@ -2602,13 +2610,13 @@ func CausalConsistencyExamples(client *mongo.Client) error {
operationTime := session1.OperationTime()
session2.AdvanceOperationTime(operationTime)
// Run a find on session2, which should find all the writes from session1
cursor, err := coll.Find(sctx, bson.D{{"end", nil}})
cursor, err := coll.Find(ctx, bson.D{{"end", nil}})

if err != nil {
return err
}

for cursor.Next(sctx) {
for cursor.Next(ctx) {
doc := cursor.Current
fmt.Printf("Document: %v\n", doc.String())
}
Expand Down Expand Up @@ -2984,7 +2992,7 @@ func snapshotQueryPetExample(mt *mtest.T) error {
defer sess.EndSession(ctx)

var adoptablePetsCount int32
err = mongo.WithSession(ctx, sess, func(ctx mongo.SessionContext) error {
err = mongo.WithSession(ctx, sess, func(ctx context.Context) error {
// Count the adoptable cats
const adoptableCatsOutput = "adoptableCatsCount"
cursor, err := db.Collection("cats").Aggregate(ctx, mongo.Pipeline{
Expand Down Expand Up @@ -3048,7 +3056,7 @@ func snapshotQueryRetailExample(mt *mtest.T) error {
defer sess.EndSession(ctx)

var totalDailySales int32
err = mongo.WithSession(ctx, sess, func(ctx mongo.SessionContext) error {
err = mongo.WithSession(ctx, sess, func(ctx context.Context) error {
// Count the total daily sales
const totalDailySalesOutput = "totalDailySales"
cursor, err := db.Collection("sales").Aggregate(ctx, mongo.Pipeline{
Expand Down
42 changes: 21 additions & 21 deletions internal/integration/causal_consistency_test.go
Expand Up @@ -41,8 +41,8 @@ func TestCausalConsistency_Supported(t *testing.T) {
// first read in a causally consistent session must not send afterClusterTime to the server

ccOpts := options.Session().SetCausalConsistency(true)
_ = mt.Client.UseSessionWithOptions(context.Background(), ccOpts, func(sc mongo.SessionContext) error {
_, _ = mt.Coll.Find(sc, bson.D{})
_ = mt.Client.UseSessionWithOptions(context.Background(), ccOpts, func(ctx context.Context) error {
_, _ = mt.Coll.Find(ctx, bson.D{})
return nil
})

Expand All @@ -57,8 +57,8 @@ func TestCausalConsistency_Supported(t *testing.T) {
assert.Nil(mt, err, "StartSession error: %v", err)
defer sess.EndSession(context.Background())

_ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error {
_, _ = mt.Coll.Find(sc, bson.D{})
_ = mongo.WithSession(context.Background(), sess, func(ctx context.Context) error {
_, _ = mt.Coll.Find(ctx, bson.D{})
return nil
})

Expand All @@ -85,8 +85,8 @@ func TestCausalConsistency_Supported(t *testing.T) {
assert.Nil(mt, err, "StartSession error: %v", err)
defer sess.EndSession(context.Background())

_ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error {
_ = mt.Coll.FindOne(sc, bson.D{})
_ = mongo.WithSession(context.Background(), sess, func(ctx context.Context) error {
_ = mt.Coll.FindOne(ctx, bson.D{})
return nil
})
currOptime := sess.OperationTime()
Expand Down Expand Up @@ -120,8 +120,8 @@ func TestCausalConsistency_Supported(t *testing.T) {
assert.NotNil(mt, currOptime, "expected session operation time, got nil")

mt.ClearEvents()
_ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error {
_ = mt.Coll.FindOne(sc, bson.D{})
_ = mongo.WithSession(context.Background(), sess, func(ctx context.Context) error {
_ = mt.Coll.FindOne(ctx, bson.D{})
return nil
})
_, sentOptime := getReadConcernFields(mt, mt.GetStartedEvent().Command)
Expand All @@ -134,10 +134,10 @@ func TestCausalConsistency_Supported(t *testing.T) {
// a read operation in a non causally-consistent session should not include afterClusterTime

sessOpts := options.Session().SetCausalConsistency(false)
_ = mt.Client.UseSessionWithOptions(context.Background(), sessOpts, func(sc mongo.SessionContext) error {
_, _ = mt.Coll.Find(sc, bson.D{})
_ = mt.Client.UseSessionWithOptions(context.Background(), sessOpts, func(ctx context.Context) error {
_, _ = mt.Coll.Find(ctx, bson.D{})
mt.ClearEvents()
_, _ = mt.Coll.Find(sc, bson.D{})
_, _ = mt.Coll.Find(ctx, bson.D{})
return nil
})
evt := mt.GetStartedEvent()
Expand All @@ -152,14 +152,14 @@ func TestCausalConsistency_Supported(t *testing.T) {
assert.Nil(mt, err, "StartSession error: %v", err)
defer sess.EndSession(context.Background())

_ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error {
_ = mt.Coll.FindOne(sc, bson.D{})
_ = mongo.WithSession(context.Background(), sess, func(ctx context.Context) error {
_ = mt.Coll.FindOne(ctx, bson.D{})
return nil
})
currOptime := sess.OperationTime()
mt.ClearEvents()
_ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error {
_ = mt.Coll.FindOne(sc, bson.D{})
_ = mongo.WithSession(context.Background(), sess, func(ctx context.Context) error {
_ = mt.Coll.FindOne(ctx, bson.D{})
return nil
})

Expand All @@ -174,14 +174,14 @@ func TestCausalConsistency_Supported(t *testing.T) {
assert.Nil(mt, err, "StartSession error: %v", err)
defer sess.EndSession(context.Background())

_ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error {
_ = mt.Coll.FindOne(sc, bson.D{})
_ = mongo.WithSession(context.Background(), sess, func(ctx context.Context) error {
_ = mt.Coll.FindOne(ctx, bson.D{})
return nil
})
currOptime := sess.OperationTime()
mt.ClearEvents()
_ = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error {
_ = mt.Coll.FindOne(sc, bson.D{})
_ = mongo.WithSession(context.Background(), sess, func(ctx context.Context) error {
_ = mt.Coll.FindOne(ctx, bson.D{})
return nil
})

Expand Down Expand Up @@ -215,8 +215,8 @@ func TestCausalConsistency_NotSupported(t *testing.T) {
// support cluster times

sessOpts := options.Session().SetCausalConsistency(true)
_ = mt.Client.UseSessionWithOptions(context.Background(), sessOpts, func(sc mongo.SessionContext) error {
_, _ = mt.Coll.Find(sc, bson.D{})
_ = mt.Client.UseSessionWithOptions(context.Background(), sessOpts, func(ctx context.Context) error {
_, _ = mt.Coll.Find(ctx, bson.D{})
return nil
})

Expand Down
3 changes: 1 addition & 2 deletions internal/integration/client_test.go
Expand Up @@ -371,8 +371,7 @@ func TestClient(t *testing.T) {
sess, err := mt.Client.StartSession(tc.opts)
assert.Nil(mt, err, "StartSession error: %v", err)
defer sess.EndSession(context.Background())
xs := sess.(mongo.XSession)
consistent := xs.ClientSession().Consistent
consistent := sess.ClientSession().Consistent
assert.Equal(mt, tc.consistent, consistent, "expected consistent to be %v, got %v", tc.consistent, consistent)
})
}
Expand Down