Skip to content

Commit

Permalink
GODRIVER-1897 Pin transactions to a connection in LB mode (#608)
Browse files Browse the repository at this point in the history
  • Loading branch information
Divjot Arora committed Mar 11, 2021
1 parent 32aee05 commit 0af4510
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 16 deletions.
6 changes: 6 additions & 0 deletions x/mongo/driver/driver.go
Expand Up @@ -6,6 +6,7 @@ import (
"go.mongodb.org/mongo-driver/mongo/address"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)

// Deployment is implemented by types that can select a server from a deployment.
Expand Down Expand Up @@ -73,6 +74,11 @@ type PinnedConnection interface {
UnpinFromTransaction() error
}

// The session.LoadBalancedTransactionConnection type is a copy of PinnedConnection that was introduced to avoid
// import cycles. This compile-time assertion ensures that these types remain in sync if the PinnedConnection interface
// is changed in the future.
var _ PinnedConnection = (session.LoadBalancedTransactionConnection)(nil)

// LocalAddresser is a type that is able to supply its local address
type LocalAddresser interface {
LocalAddress() address.Address
Expand Down
77 changes: 61 additions & 16 deletions x/mongo/driver/operation.go
Expand Up @@ -224,6 +224,44 @@ func (op Operation) selectServer(ctx context.Context) (Server, error) {
return op.Deployment.SelectServer(ctx, selector)
}

// getServerAndConnection should be used to retrieve a Server and Connection to execute an operation.
func (op Operation) getServerAndConnection(ctx context.Context) (Server, Connection, error) {
server, err := op.selectServer(ctx)
if err != nil {
return nil, nil, err
}

// If the provided client session has a pinned connection, it should be used for the operation because this
// indicates that we're in a transaction and the target server is behind a load balancer.
if op.Client != nil && op.Client.PinnedConnection != nil {
return server, op.Client.PinnedConnection, nil
}

// Otherwise, default to checking out a connection from the server's pool.
conn, err := server.Connection(ctx)
if err != nil {
return nil, nil, err
}

// If we're in load balanced mode and this is the first operation in a transaction, pin the session to a connection.
if conn.Description().LoadBalanced() && op.Client != nil && op.Client.TransactionStarting() {
pinnedConn, ok := conn.(PinnedConnection)
if !ok {
// Close the original connection to avoid a leak.
_ = conn.Close()
return nil, nil, fmt.Errorf("expected Connection used to start a transaction to be a PinnedConnection, but got %T", conn)
}
if err := pinnedConn.PinToTransaction(); err != nil {
// Close the original connection to avoid a leak.
_ = conn.Close()
return nil, nil, fmt.Errorf("error incrementing connection reference count when starting a transaction: %v", err)
}
op.Client.PinnedConnection = pinnedConn
}

return server, conn, nil
}

// Validate validates this operation, ensuring the fields are set properly.
func (op Operation) Validate() error {
if op.CommandFn == nil {
Expand All @@ -249,12 +287,7 @@ func (op Operation) Execute(ctx context.Context, scratch []byte) error {
return err
}

srvr, err := op.selectServer(ctx)
if err != nil {
return err
}

conn, err := srvr.Connection(ctx)
srvr, conn, err := op.getServerAndConnection(ctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -386,6 +419,26 @@ func (op Operation) Execute(ctx context.Context, scratch []byte) error {
_ = ep.ProcessError(err, conn)
}

// If we're executing a load-balanced transaction and encounter a network error, the pinned connection should
// be unpinned. We call ExpirePinnedConnection to ensure that the connection is closed and returned to the
// pool for bookkeeping. Future AbortTransaction calls will check out a new connection, which is desired. We
// do this before any other checks to make sure we release the invalidated connection even though other
// resources may be holding references to it.
if op.Client != nil && op.Client.PinnedConnection != nil {
if driverErr, ok := err.(Error); ok && driverErr.NetworkError() {
_ = op.Client.ExpirePinnedConnection()
}
}

// If we're executing a load-balanced transaction and are committing/aborting, unpin the session's connection.
// This has to be done before entering the retryability logic because commit/abort attempts are retryable on a
// different mongos, so we want to allow for the possibility of checking out a new connection for the retry.
if op.Client != nil && (op.Client.Committing || op.Client.Aborting) && op.Client.PinnedConnection != nil {
if err := op.Client.UnpinConnection(); err != nil {
return err
}
}

finishedInfo.response = res
finishedInfo.cmdErr = err
op.publishFinishedEvent(ctx, finishedInfo)
Expand All @@ -412,11 +465,7 @@ func (op Operation) Execute(ctx context.Context, scratch []byte) error {
retries--
original, err = err, nil
conn.Close() // Avoid leaking the connection.
srvr, err = op.selectServer(ctx)
if err != nil {
return original
}
conn, err = srvr.Connection(ctx)
srvr, conn, err = op.getServerAndConnection(ctx)
if err != nil || conn == nil || !op.retryable(conn.Description()) {
if conn != nil {
conn.Close()
Expand Down Expand Up @@ -505,11 +554,7 @@ func (op Operation) Execute(ctx context.Context, scratch []byte) error {
retries--
original, err = err, nil
conn.Close() // Avoid leaking the connection.
srvr, err = op.selectServer(ctx)
if err != nil {
return original
}
conn, err = srvr.Connection(ctx)
srvr, conn, err = op.getServerAndConnection(ctx)
if err != nil || conn == nil || !op.retryable(conn.Description()) {
if conn != nil {
conn.Close()
Expand Down
55 changes: 55 additions & 0 deletions x/mongo/driver/session/client_session.go
Expand Up @@ -7,11 +7,13 @@
package session // import "go.mongodb.org/mongo-driver/x/mongo/driver/session"

import (
"context"
"errors"
"time"

"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo/address"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/mongo/readconcern"
"go.mongodb.org/mongo-driver/mongo/readpref"
Expand Down Expand Up @@ -79,6 +81,30 @@ func (s TransactionState) String() string {
}
}

// LoadBalancedTransactionConnection represents a connection that's pinned by a ClientSession because it's being used
// to execute a transaction when running against a load balancer. This interface is a copy of driver.PinnedConnection
// and exists to be able to pin transactions to a connection without causing an import cycle.
type LoadBalancedTransactionConnection interface {
// Functions copied over from driver.Connection.
WriteWireMessage(context.Context, []byte) error
ReadWireMessage(ctx context.Context, dst []byte) ([]byte, error)
Description() description.Server
Close() error
ID() string
Address() address.Address
Stale() bool

// Functions copied over from driver.Expirable.
Alive() bool
Expire() error

// Functions copied over from driver.PinnedConnection that are not part of Connection or Expirable.
PinToCursor() error
PinToTransaction() error
UnpinFromCursor() error
UnpinFromTransaction() error
}

// Client is a session for clients to run commands.
type Client struct {
*Server
Expand Down Expand Up @@ -111,6 +137,7 @@ type Client struct {
TransactionState TransactionState
PinnedServer *description.Server
RecoveryToken bson.Raw
PinnedConnection LoadBalancedTransactionConnection
}

func getClusterTime(clusterTime bson.Raw) (uint32, uint32) {
Expand Down Expand Up @@ -246,6 +273,34 @@ func (c *Client) ClearPinnedServer() {
}
}

// UnpinConnection gracefully unpins the connection associated with the session if there is one. This is done via
// the pinned connection's UnpinFromTransaction function.
func (c *Client) UnpinConnection() error {
if c == nil || c.PinnedConnection == nil {
return nil
}

err := c.PinnedConnection.UnpinFromTransaction()
closeErr := c.PinnedConnection.Close()
if err == nil && closeErr != nil {
err = closeErr
}
c.PinnedConnection = nil
return err
}

// ExpirePinnedConnection forcefully unpins the connection assocated with the session if there is one. This is done via
// the pinned connection's Expire function.
func (c *Client) ExpirePinnedConnection() error {
if c == nil || c.PinnedConnection == nil {
return nil
}

err := c.PinnedConnection.Expire()
c.PinnedConnection = nil
return err
}

// EndSession ends the session.
func (c *Client) EndSession() {
if c.Terminated {
Expand Down

0 comments on commit 0af4510

Please sign in to comment.