diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index f4348bf541..0f016307d3 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -605,7 +605,8 @@ func (c initConnection) SupportsStreaming() bool { // messages and the driver.Expirable interface to allow expiring. type Connection struct { *connection - refCount int + refCount int + cleanupPoolFn func() mu sync.RWMutex } @@ -687,9 +688,7 @@ func (c *Connection) Close() error { return nil } - err := c.pool.put(c.connection) - c.connection = nil - return err + return c.cleanupReferences() } // Expire closes this connection and will closeConnection the underlying socket. @@ -701,7 +700,15 @@ func (c *Connection) Expire() error { } _ = c.close() + return c.cleanupReferences() +} + +func (c *Connection) cleanupReferences() error { err := c.pool.put(c.connection) + if c.cleanupPoolFn != nil { + c.cleanupPoolFn() + c.cleanupPoolFn = nil + } c.connection = nil return err } @@ -750,21 +757,27 @@ func (c *Connection) LocalAddress() address.Address { // PinToCursor updates this connection to reflect that it is pinned to a cursor. func (c *Connection) PinToCursor() error { - return c.pin("cursor") + return c.pin("cursor", c.pool.pinConnectionToCursor, c.pool.unpinConnectionFromCursor) } // PinToTransaction updates this connection to reflect that it is pinned to a transaction. func (c *Connection) PinToTransaction() error { - return c.pin("transaction") + return c.pin("transaction", c.pool.pinConnectionToTransaction, c.pool.unpinConnectionFromTransaction) } -func (c *Connection) pin(reason string) error { +func (c *Connection) pin(reason string, updatePoolFn, cleanupPoolFn func()) error { c.mu.Lock() defer c.mu.Unlock() if c.connection == nil { return fmt.Errorf("attempted to pin a connection for a %s, but the connection has already been returned to the pool", reason) } + // Only use the provided callbacks for the first reference to avoid double-counting pinned connection statistics + // in the pool. + if c.refCount == 0 { + updatePoolFn() + c.cleanupPoolFn = cleanupPoolFn + } c.refCount++ return nil } diff --git a/x/mongo/driver/topology/connection_test.go b/x/mongo/driver/topology/connection_test.go index c1a0b55a5d..6249322d04 100644 --- a/x/mongo/driver/topology/connection_test.go +++ b/x/mongo/driver/topology/connection_test.go @@ -834,6 +834,134 @@ func TestConnection(t *testing.T) { t.Errorf("LocalAddresses do not match. got %v; want %v", got, want) } }) + + t.Run("pinning", func(t *testing.T) { + makeMultipleConnections := func(t *testing.T, numConns int) (*pool, []*Connection) { + t.Helper() + + addr := address.Address("") + pool, err := newPool(poolConfig{Address: addr}) + assert.Nil(t, err, "newPool error: %v", err) + + err = pool.sem.Acquire(context.Background(), int64(numConns)) + assert.Nil(t, err, "error acquiring semaphore: %v", err) + + conns := make([]*Connection, 0, numConns) + for i := 0; i < numConns; i++ { + conn, err := newConnection(addr) + assert.Nil(t, err, "newConnection error: %v", err) + conn.pool = pool + conns = append(conns, &Connection{connection: conn}) + } + return pool, conns + } + makeOneConnection := func(t *testing.T) (*pool, *Connection) { + t.Helper() + + pool, conns := makeMultipleConnections(t, 1) + return pool, conns[0] + } + + assertPoolPinnedStats := func(t *testing.T, p *pool, cursorConns, txnConns uint64) { + t.Helper() + + assert.Equal(t, cursorConns, p.pinnedCursorConnections, "expected %d connections to be pinned to cursors, got %d", + cursorConns, p.pinnedCursorConnections) + assert.Equal(t, txnConns, p.pinnedTransactionConnections, "expected %d connections to be pinned to transactions, got %d", + txnConns, p.pinnedTransactionConnections) + } + + t.Run("cursors", func(t *testing.T) { + pool, conn := makeOneConnection(t) + err := conn.PinToCursor() + assert.Nil(t, err, "PinToCursor error: %v", err) + assertPoolPinnedStats(t, pool, 1, 0) + + err = conn.UnpinFromCursor() + assert.Nil(t, err, "UnpinFromCursor error: %v", err) + + err = conn.Close() + assert.Nil(t, err, "Close error: %v", err) + assertPoolPinnedStats(t, pool, 0, 0) + }) + t.Run("transactions", func(t *testing.T) { + pool, conn := makeOneConnection(t) + err := conn.PinToTransaction() + assert.Nil(t, err, "PinToTransaction error: %v", err) + assertPoolPinnedStats(t, pool, 0, 1) + + err = conn.UnpinFromTransaction() + assert.Nil(t, err, "UnpinFromTransaction error: %v", err) + + err = conn.Close() + assert.Nil(t, err, "Close error: %v", err) + assertPoolPinnedStats(t, pool, 0, 0) + }) + t.Run("pool is only updated for first reference", func(t *testing.T) { + pool, conn := makeOneConnection(t) + err := conn.PinToTransaction() + assert.Nil(t, err, "PinToTransaction error: %v", err) + assertPoolPinnedStats(t, pool, 0, 1) + + err = conn.PinToCursor() + assert.Nil(t, err, "PinToCursor error: %v", err) + assertPoolPinnedStats(t, pool, 0, 1) + + err = conn.UnpinFromCursor() + assert.Nil(t, err, "UnpinFromCursor error: %v", err) + assertPoolPinnedStats(t, pool, 0, 1) + + err = conn.UnpinFromTransaction() + assert.Nil(t, err, "UnpinFromTransaction error: %v", err) + assertPoolPinnedStats(t, pool, 0, 1) + + err = conn.Close() + assert.Nil(t, err, "Close error: %v", err) + assertPoolPinnedStats(t, pool, 0, 0) + }) + t.Run("multiple connections from a pool", func(t *testing.T) { + pool, conns := makeMultipleConnections(t, 2) + first, second := conns[0], conns[1] + + err := first.PinToTransaction() + assert.Nil(t, err, "PinToTransaction error: %v", err) + err = second.PinToCursor() + assert.Nil(t, err, "PinToCursor error: %v", err) + assertPoolPinnedStats(t, pool, 1, 1) + + err = first.UnpinFromTransaction() + assert.Nil(t, err, "UnpinFromTransaction error: %v", err) + err = first.Close() + assert.Nil(t, err, "Close error: %v", err) + assertPoolPinnedStats(t, pool, 1, 0) + + err = second.UnpinFromCursor() + assert.Nil(t, err, "UnpinFromCursor error: %v", err) + err = second.Close() + assert.Nil(t, err, "Close error: %v", err) + assertPoolPinnedStats(t, pool, 0, 0) + }) + t.Run("close is ignored if connection is pinned", func(t *testing.T) { + pool, conn := makeOneConnection(t) + err := conn.PinToCursor() + assert.Nil(t, err, "PinToCursor error: %v", err) + + err = conn.Close() + assert.Nil(t, err, "Close error") + assert.NotNil(t, conn.connection, "expected connection to be pinned but it was released to the pool") + assertPoolPinnedStats(t, pool, 1, 0) + }) + t.Run("expire forcefully returns connection to pool", func(t *testing.T) { + pool, conn := makeOneConnection(t) + err := conn.PinToCursor() + assert.Nil(t, err, "PinToCursor error: %v", err) + + err = conn.Expire() + assert.Nil(t, err, "Expire error") + assert.Nil(t, conn.connection, "expected connection to be released to the pool but was not") + assertPoolPinnedStats(t, pool, 0, 0) + }) + }) }) } diff --git a/x/mongo/driver/topology/errors.go b/x/mongo/driver/topology/errors.go index 30274ee91e..f3e2e3a7cb 100644 --- a/x/mongo/driver/topology/errors.go +++ b/x/mongo/driver/topology/errors.go @@ -62,16 +62,23 @@ func (e ServerSelectionError) Unwrap() error { // WaitQueueTimeoutError represents a timeout when requesting a connection from the pool type WaitQueueTimeoutError struct { - Wrapped error + Wrapped error + PinnedCursorConnections uint64 + PinnedTransactionConnections uint64 + maxPoolSize uint64 } // Error implements the error interface. func (w WaitQueueTimeoutError) Error() string { errorMsg := "timed out while checking out a connection from connection pool" if w.Wrapped != nil { - return fmt.Sprintf("%s: %s", errorMsg, w.Wrapped.Error()) + errorMsg = fmt.Sprintf("%s: %s", errorMsg, w.Wrapped.Error()) } - return errorMsg + + errorMsg = fmt.Sprintf("%s; maxPoolSize: %d, connections in use by cursors: %d, connections in use by transactions: %d", + errorMsg, w.maxPoolSize, w.PinnedCursorConnections, w.PinnedTransactionConnections) + return fmt.Sprintf("%s, connections in use by other operations: %d", errorMsg, + w.maxPoolSize-(w.PinnedCursorConnections+w.PinnedTransactionConnections)) } // Unwrap returns the underlying error. diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go index c2abb7046f..5775ec3cc3 100644 --- a/x/mongo/driver/topology/pool.go +++ b/x/mongo/driver/topology/pool.go @@ -64,10 +64,14 @@ type pool struct { generation *poolGenerationMap monitor *event.PoolMonitor - connected int32 // Must be accessed using the sync/atomic package. - nextid uint64 - opened map[uint64]*connection // opened holds all of the currently open connections. - sem *semaphore.Weighted + // Must be accessed using the atomic package. + connected int32 + pinnedCursorConnections uint64 + pinnedTransactionConnections uint64 + + nextid uint64 + opened map[uint64]*connection // opened holds all of the currently open connections. + sem *semaphore.Weighted sync.Mutex } @@ -313,6 +317,24 @@ func (p *pool) makeNewConnection() (*connection, string, error) { } +func (p *pool) pinConnectionToCursor() { + atomic.AddUint64(&p.pinnedCursorConnections, 1) +} + +func (p *pool) unpinConnectionFromCursor() { + // See https://golang.org/pkg/sync/atomic/#AddUint64 for an explanation of the ^uint64(0) syntax. + atomic.AddUint64(&p.pinnedCursorConnections, ^uint64(0)) +} + +func (p *pool) pinConnectionToTransaction() { + atomic.AddUint64(&p.pinnedTransactionConnections, 1) +} + +func (p *pool) unpinConnectionFromTransaction() { + // See https://golang.org/pkg/sync/atomic/#AddUint64 for an explanation of the ^uint64(0) syntax. + atomic.AddUint64(&p.pinnedTransactionConnections, ^uint64(0)) +} + // Checkout returns a connection from the pool func (p *pool) get(ctx context.Context) (*connection, error) { if ctx == nil { @@ -340,7 +362,10 @@ func (p *pool) get(ctx context.Context) (*connection, error) { }) } errWaitQueueTimeout := WaitQueueTimeoutError{ - Wrapped: ctx.Err(), + Wrapped: ctx.Err(), + PinnedCursorConnections: atomic.LoadUint64(&p.pinnedCursorConnections), + PinnedTransactionConnections: atomic.LoadUint64(&p.pinnedTransactionConnections), + maxPoolSize: p.conns.maxSize, } return nil, errWaitQueueTimeout }