diff --git a/mongo/client.go b/mongo/client.go index 769549e41c..b2b45c1cf3 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -792,7 +792,10 @@ func (c *Client) Watch(ctx context.Context, pipeline interface{}, // NumberSessionsInProgress returns the number of sessions that have been started for this client but have not been // closed (i.e. EndSession has not been called). func (c *Client) NumberSessionsInProgress() int { - return c.sessionPool.CheckedOut() + // The underlying session pool uses an int64 for checkedOut to allow atomic + // access. We convert to an int here to maintain backward compatibility with + // older versions of the driver that did not atomically access checkedOut. + return int(c.sessionPool.CheckedOut()) } // Timeout returns the timeout set for this client. diff --git a/mongo/integration/sessions_test.go b/mongo/integration/sessions_test.go index e6c483d809..5f80537f2e 100644 --- a/mongo/integration/sessions_test.go +++ b/mongo/integration/sessions_test.go @@ -11,6 +11,7 @@ import ( "context" "fmt" "reflect" + "sync" "testing" "time" @@ -475,6 +476,35 @@ func TestSessions(t *testing.T) { assert.True(mt, limitedSessionUse, limitedSessMsg, len(ops)) }) + + // Regression test for GODRIVER-2533. Note that this test assumes the race + // detector is enabled (GODRIVER-2072). + mt.Run("NumberSessionsInProgress data race", func(mt *mtest.T) { + // Use two goroutines to execute a few simultaneous runs of NumberSessionsInProgress + // and a basic collection operation (CountDocuments). + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + + for i := 0; i < 100; i++ { + time.Sleep(100 * time.Microsecond) + _ = mt.Client.NumberSessionsInProgress() + } + }() + go func() { + defer wg.Done() + + for i := 0; i < 100; i++ { + time.Sleep(100 * time.Microsecond) + _, err := mt.Coll.CountDocuments(context.Background(), bson.D{}) + assert.Nil(mt, err, "CountDocument error: %v", err) + } + }() + + wg.Wait() + }) } func assertCollectionCount(mt *mtest.T, expectedCount int64) { diff --git a/x/mongo/driver/session/session_pool.go b/x/mongo/driver/session/session_pool.go index 27db7c476f..34b863c111 100644 --- a/x/mongo/driver/session/session_pool.go +++ b/x/mongo/driver/session/session_pool.go @@ -8,6 +8,7 @@ package session import ( "sync" + "sync/atomic" "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" @@ -29,13 +30,14 @@ type topologyDescription struct { // Pool is a pool of server sessions that can be reused. type Pool struct { + // number of sessions checked out of pool (accessed atomically) + checkedOut int64 + descChan <-chan description.Topology head *Node tail *Node latestTopology topologyDescription mutex sync.Mutex // mutex to protect list and sessionTimeout - - checkedOut int // number of sessions checked out of pool } func (p *Pool) createServerSession() (*Server, error) { @@ -44,7 +46,7 @@ func (p *Pool) createServerSession() (*Server, error) { return nil, err } - p.checkedOut++ + atomic.AddInt64(&p.checkedOut, 1) return s, nil } @@ -100,7 +102,7 @@ func (p *Pool) GetSession() (*Server, error) { p.head = p.head.next } - p.checkedOut++ + atomic.AddInt64(&p.checkedOut, 1) return session, nil } @@ -118,7 +120,7 @@ func (p *Pool) ReturnSession(ss *Server) { p.mutex.Lock() defer p.mutex.Unlock() - p.checkedOut-- + atomic.AddInt64(&p.checkedOut, -1) p.updateTimeout() // check sessions at end of queue for expired // stop checking after hitting the first valid session @@ -185,6 +187,6 @@ func (p *Pool) String() string { } // CheckedOut returns number of sessions checked out from pool. -func (p *Pool) CheckedOut() int { - return p.checkedOut +func (p *Pool) CheckedOut() int64 { + return atomic.LoadInt64(&p.checkedOut) }