From 7735f66548aaa38dd9789051de695fc7ec45fa74 Mon Sep 17 00:00:00 2001 From: Divjot Arora Date: Wed, 24 Mar 2021 15:59:33 -0400 Subject: [PATCH] GODRIVER-1898 SDAM error handling changes for LB mode (#611) --- event/monitoring.go | 3 + x/mongo/driver/topology/CMAP_spec_test.go | 2 +- x/mongo/driver/topology/connection.go | 35 ++++- x/mongo/driver/topology/connection_options.go | 16 ++- x/mongo/driver/topology/connection_test.go | 3 +- x/mongo/driver/topology/pool.go | 46 +++--- .../topology/pool_generation_counter.go | 133 ++++++++++++++++++ x/mongo/driver/topology/pool_test.go | 16 ++- x/mongo/driver/topology/sdam_spec_test.go | 7 +- x/mongo/driver/topology/server.go | 25 ++-- x/mongo/driver/topology/server_test.go | 132 ++++++++++++++++- 11 files changed, 371 insertions(+), 47 deletions(-) create mode 100644 x/mongo/driver/topology/pool_generation_counter.go diff --git a/event/monitoring.go b/event/monitoring.go index a8a3f94391..acebc0f950 100644 --- a/event/monitoring.go +++ b/event/monitoring.go @@ -89,6 +89,9 @@ type PoolEvent struct { ConnectionID uint64 `json:"connectionId"` PoolOptions *MonitorPoolOptions `json:"options"` Reason string `json:"reason"` + // ServerID is only set if the Type is PoolCleared and the server is deployed behind a load balancer. This field + // can be used to distinguish between individual servers in a load balanced deployment. + ServerID *primitive.ObjectID `json:"serverId"` } // PoolMonitor is a function that allows the user to gain access to events occurring in the pool diff --git a/x/mongo/driver/topology/CMAP_spec_test.go b/x/mongo/driver/topology/CMAP_spec_test.go index 59045dcaaa..36a7de82b3 100644 --- a/x/mongo/driver/topology/CMAP_spec_test.go +++ b/x/mongo/driver/topology/CMAP_spec_test.go @@ -419,7 +419,7 @@ func runOperationInThread(t *testing.T, operation map[string]interface{}, testIn } return c.Close() case "clear": - s.pool.clear() + s.pool.clear(nil) case "close": return s.pool.disconnect(context.Background()) default: diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index d1da3f570e..f4348bf541 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -91,6 +91,11 @@ func newConnection(addr address.Address, opts ...ConnectionOption) (*connection, cancellationListener: internal.NewCancellationListener(), poolMonitor: cfg.poolMonitor, } + // Connections to non-load balanced deployments should eagerly set the generation numbers so errors encountered + // at any point during connection establishment can be processed without the connection being considered stale. + if !c.config.loadBalanced { + c.setGenerationNumber() + } atomic.StoreInt32(&c.connected, initialized) return c, nil @@ -104,8 +109,29 @@ func (c *connection) processInitializationError(err error) { c.connectErr = ConnectionError{Wrapped: err, init: true} if c.config.errorHandlingCallback != nil { - c.config.errorHandlingCallback(c.connectErr, c.generation) + c.config.errorHandlingCallback(c.connectErr, c.generation, c.desc.ServerID) + } +} + +// setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection +// configuration. +func (c *connection) setGenerationNumber() { + if c.config.getGenerationFn != nil { + c.generation = c.config.getGenerationFn(c.desc.ServerID) + } +} + +// hasGenerationNumber returns true if the connection has set its generation number. 
If so, this indicates that the +// generationNumberFn provided via the connection options has been called exactly once. +func (c *connection) hasGenerationNumber() bool { + if !c.config.loadBalanced { + // The generation is known for all non-LB clusters once the connection object has been created. + return true } + + // For LB clusters, we set the generation after the initial handshake, so we know it's set if the connection + // description has been updated to reflect that it's behind an LB. + return c.desc.LoadBalanced() } // connect handles the I/O for a connection. It will dial, configure TLS, and perform @@ -212,6 +238,13 @@ func (c *connection) connect(ctx context.Context) { } } if err == nil { + // For load-balanced connections, the generation number depends on the server ID, which isn't known until the + // initial MongoDB handshake is done. To account for this, we don't attempt to set the connection's generation + // number unless GetHandshakeInformation succeeds. + if c.config.loadBalanced { + c.setGenerationNumber() + } + // If we successfully finished the first part of the handshake and verified LB state, continue with the rest of // the handshake. err = handshaker.FinishHandshake(handshakeCtx, handshakeConn) diff --git a/x/mongo/driver/topology/connection_options.go b/x/mongo/driver/topology/connection_options.go index 0895e56508..34aee5424b 100644 --- a/x/mongo/driver/topology/connection_options.go +++ b/x/mongo/driver/topology/connection_options.go @@ -6,6 +6,7 @@ import ( "net" "time" + "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/ocsp" @@ -35,6 +36,9 @@ var DefaultDialer Dialer = &net.Dialer{} // initialization. Implementations must be goroutine safe. type Handshaker = driver.Handshaker +// generationNumberFn is a callback type used by a connection to fetch its generation number given its server ID. 
+type generationNumberFn func(serverID *primitive.ObjectID) uint64 + type connectionConfig struct { appName string connectTimeout time.Duration @@ -51,9 +55,10 @@ type connectionConfig struct { zstdLevel *int ocspCache ocsp.Cache disableOCSPEndpointCheck bool - errorHandlingCallback func(error, uint64) + errorHandlingCallback func(error, uint64, *primitive.ObjectID) tlsConnectionSource tlsConnectionSource loadBalanced bool + getGenerationFn generationNumberFn } func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) { @@ -87,7 +92,7 @@ func withTLSConnectionSource(fn func(tlsConnectionSource) tlsConnectionSource) C } } -func withErrorHandlingCallback(fn func(error, uint64)) ConnectionOption { +func withErrorHandlingCallback(fn func(error, uint64, *primitive.ObjectID)) ConnectionOption { return func(c *connectionConfig) error { c.errorHandlingCallback = fn return nil @@ -217,3 +222,10 @@ func WithConnectionLoadBalanced(fn func(bool) bool) ConnectionOption { return nil } } + +func withGenerationNumberFn(fn func(generationNumberFn) generationNumberFn) ConnectionOption { + return func(c *connectionConfig) error { + c.getGenerationFn = fn(c.getGenerationFn) + return nil + } +} diff --git a/x/mongo/driver/topology/connection_test.go b/x/mongo/driver/topology/connection_test.go index 3684bd748b..c1a0b55a5d 100644 --- a/x/mongo/driver/topology/connection_test.go +++ b/x/mongo/driver/topology/connection_test.go @@ -17,6 +17,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/internal" "go.mongodb.org/mongo-driver/internal/testutil/assert" "go.mongodb.org/mongo-driver/mongo/address" @@ -128,7 +129,7 @@ func TestConnection(t *testing.T) { return &net.TCPConn{}, nil }) }), - withErrorHandlingCallback(func(err error, _ uint64) { + withErrorHandlingCallback(func(err error, _ uint64, _ *primitive.ObjectID) { got = err }), ) diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go index 7ce544f46c..c2abb7046f 100644 --- a/x/mongo/driver/topology/pool.go +++ b/x/mongo/driver/topology/pool.go @@ -13,6 +13,7 @@ import ( "sync/atomic" "time" + "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/mongo/address" "golang.org/x/sync/semaphore" @@ -60,7 +61,7 @@ type pool struct { address address.Address opts []ConnectionOption conns *resourcePool // pool for non-checked out connections - generation uint64 // must be accessed using atomic package + generation *poolGenerationMap monitor *event.PoolMonitor connected int32 // Must be accessed using the sync/atomic package. 
@@ -148,13 +149,15 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) (*pool, error) { } pool := &pool{ - address: config.Address, - monitor: config.PoolMonitor, - connected: disconnected, - opened: make(map[uint64]*connection), - opts: opts, - sem: semaphore.NewWeighted(int64(maxConns)), + address: config.Address, + monitor: config.PoolMonitor, + connected: disconnected, + opened: make(map[uint64]*connection), + opts: opts, + sem: semaphore.NewWeighted(int64(maxConns)), + generation: newPoolGenerationMap(), } + pool.opts = append(pool.opts, withGenerationNumberFn(func(_ generationNumberFn) generationNumberFn { return pool.getGenerationForNewConnection })) // we do not pass in config.MaxPoolSize because we manage the max size at this level rather than the resource pool level rpc := resourcePoolConfig{ @@ -189,7 +192,7 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) (*pool, error) { // stale checks if a given connection's generation is below the generation of the pool func (p *pool) stale(c *connection) bool { - return c == nil || c.generation < atomic.LoadUint64(&p.generation) + return c == nil || p.generation.stale(c.desc.ServerID, c.generation) } // connect puts the pool into the connected state, allowing it to be used and will allow items to begin being processed from the wait queue @@ -197,6 +200,7 @@ func (p *pool) connect() error { if !atomic.CompareAndSwapInt32(&p.connected, disconnected, connected) { return ErrPoolConnected } + p.generation.connect() p.conns.initialize() return nil } @@ -212,7 +216,7 @@ func (p *pool) disconnect(ctx context.Context) error { } p.conns.Close() - atomic.AddUint64(&p.generation, 1) + p.generation.disconnect() var err error if dl, ok := ctx.Deadline(); ok { @@ -277,7 +281,6 @@ func (p *pool) makeNewConnection() (*connection, string, error) { c.pool = p c.poolID = atomic.AddUint64(&p.nextid, 1) - c.generation = atomic.LoadUint64(&p.generation) if p.monitor != nil { p.monitor.Event(&event.PoolEvent{ @@ -310,10 +313,6 @@ func (p *pool) makeNewConnection() (*connection, string, error) { } -func (p *pool) getGeneration() uint64 { - return atomic.LoadUint64(&p.generation) -} - // Checkout returns a connection from the pool func (p *pool) get(ctx context.Context) (*connection, error) { if ctx == nil { @@ -487,6 +486,10 @@ func (p *pool) closeConnection(c *connection) error { return nil } +func (p *pool) getGenerationForNewConnection(serverID *primitive.ObjectID) uint64 { + return p.generation.addConnection(serverID) +} + // removeConnection removes a connection from the pool. func (p *pool) removeConnection(c *connection, reason string) error { if c.pool != p { @@ -501,6 +504,12 @@ func (p *pool) removeConnection(c *connection, reason string) error { } p.Unlock() + // Only update the generation numbers map if the connection has retrieved its generation number. Otherwise, we'd + // decrement the count for the generation even though it had never been incremented. 
+ if c.hasGenerationNumber() { + p.generation.removeConnection(c.desc.ServerID) + } + if publishEvent && p.monitor != nil { c.pool.monitor.Event(&event.PoolEvent{ Type: event.ConnectionClosed, @@ -545,12 +554,13 @@ func (p *pool) put(c *connection) error { } // clear clears the pool by incrementing the generation -func (p *pool) clear() { +func (p *pool) clear(serverID *primitive.ObjectID) { if p.monitor != nil { p.monitor.Event(&event.PoolEvent{ - Type: event.PoolCleared, - Address: p.address.String(), + Type: event.PoolCleared, + Address: p.address.String(), + ServerID: serverID, }) } - atomic.AddUint64(&p.generation, 1) + p.generation.clear(serverID) } diff --git a/x/mongo/driver/topology/pool_generation_counter.go b/x/mongo/driver/topology/pool_generation_counter.go new file mode 100644 index 0000000000..a141cd31a8 --- /dev/null +++ b/x/mongo/driver/topology/pool_generation_counter.go @@ -0,0 +1,133 @@ +// Copyright (C) MongoDB, Inc. 2017-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package topology + +import ( + "sync" + "sync/atomic" + + "go.mongodb.org/mongo-driver/bson/primitive" +) + +// generationStats represents the version of a pool. It tracks the generation number as well as the number of +// connections that have been created in the generation. +type generationStats struct { + generation uint64 + numConns uint64 +} + +// poolGenerationMap tracks the version for each server ID present in a pool. For deployments that are not behind a load +// balancer, there is only one server ID: primitive.NilObjectID. For load-balanced deployments, each server behind the +// load balancer will have a unique server ID. +type poolGenerationMap struct { + // state must be accessed using the atomic package. + state int32 + generationMap map[primitive.ObjectID]*generationStats + + sync.Mutex +} + +func newPoolGenerationMap() *poolGenerationMap { + pgm := &poolGenerationMap{ + generationMap: make(map[primitive.ObjectID]*generationStats), + } + pgm.generationMap[primitive.NilObjectID] = &generationStats{} + return pgm +} + +func (p *poolGenerationMap) connect() { + atomic.StoreInt32(&p.state, connected) +} + +func (p *poolGenerationMap) disconnect() { + atomic.StoreInt32(&p.state, disconnected) +} + +// addConnection increments the connection count for the generation associated with the given server ID and returns the +// generation number for the connection. +func (p *poolGenerationMap) addConnection(serverIDPtr *primitive.ObjectID) uint64 { + serverID := getServerID(serverIDPtr) + p.Lock() + defer p.Unlock() + + stats, ok := p.generationMap[serverID] + if ok { + // If the serverID is already being tracked, we only need to increment the connection count. + stats.numConns++ + return stats.generation + } + + // If the serverID is untracked, create a new entry with a starting generation number of 0. + stats = &generationStats{ + numConns: 1, + } + p.generationMap[serverID] = stats + return 0 +} + +func (p *poolGenerationMap) removeConnection(serverIDPtr *primitive.ObjectID) { + serverID := getServerID(serverIDPtr) + p.Lock() + defer p.Unlock() + + stats, ok := p.generationMap[serverID] + if !ok { + return + } + + // If the serverID is being tracked, decrement the connection count and delete this serverID to prevent the map + // from growing unboundedly. 
This case would happen if a server behind a load-balancer was permanently removed + // and its connections were pruned after a network error or idle timeout. + stats.numConns-- + if stats.numConns == 0 { + delete(p.generationMap, serverID) + } +} + +func (p *poolGenerationMap) clear(serverIDPtr *primitive.ObjectID) { + serverID := getServerID(serverIDPtr) + p.Lock() + defer p.Unlock() + + if stats, ok := p.generationMap[serverID]; ok { + stats.generation++ + } +} + +func (p *poolGenerationMap) stale(serverIDPtr *primitive.ObjectID, knownGeneration uint64) bool { + // If the map has been disconnected, all connections should be considered stale to ensure that they're closed. + if atomic.LoadInt32(&p.state) == disconnected { + return true + } + + serverID := getServerID(serverIDPtr) + p.Lock() + defer p.Unlock() + + if stats, ok := p.generationMap[serverID]; ok { + return knownGeneration < stats.generation + } + return false +} + +func (p *poolGenerationMap) getGeneration(serverIDPtr *primitive.ObjectID) uint64 { + serverID := getServerID(serverIDPtr) + p.Lock() + defer p.Unlock() + + if stats, ok := p.generationMap[serverID]; ok { + return stats.generation + } + return 0 +} + +func getServerID(oid *primitive.ObjectID) primitive.ObjectID { + if oid == nil { + return primitive.NilObjectID + } + return *oid +} diff --git a/x/mongo/driver/topology/pool_test.go b/x/mongo/driver/topology/pool_test.go index 1e45100022..06911dbedb 100644 --- a/x/mongo/driver/topology/pool_test.go +++ b/x/mongo/driver/topology/pool_test.go @@ -250,6 +250,13 @@ func TestPool(t *testing.T) { }) t.Run("connect", func(t *testing.T) { t.Run("can reconnect a disconnected pool", func(t *testing.T) { + assertGenerationMapState := func(t *testing.T, p *pool, expectedState int32) { + t.Helper() + + actualState := atomic.LoadInt32(&p.generation.state) + assert.Equal(t, expectedState, actualState, "expected generation map state %d, got %d", expectedState, actualState) + } + cleanup := make(chan struct{}) addr := bootstrapConnections(t, 3, func(nc net.Conn) { <-cleanup @@ -263,6 +270,7 @@ func TestPool(t *testing.T) { noerr(t, err) err = p.connect() noerr(t, err) + assertGenerationMapState(t, p, connected) c, err := p.get(context.Background()) noerr(t, err) gen := c.generation @@ -281,6 +289,7 @@ func TestPool(t *testing.T) { defer cancel() err = p.disconnect(ctx) noerr(t, err) + assertGenerationMapState(t, p, disconnected) assertConnectionsClosed(t, d, 1) if p.conns.totalSize != 0 { @@ -293,13 +302,10 @@ func TestPool(t *testing.T) { } err = p.connect() noerr(t, err) + assertGenerationMapState(t, p, connected) c, err = p.get(context.Background()) noerr(t, err) - gen = atomic.LoadUint64(&c.generation) - if gen != 1 { - t.Errorf("Connection should have a newer generation. got %d; want %d", gen, 1) - } err = p.put(c) noerr(t, err) if d.lenopened() != 2 { @@ -687,7 +693,7 @@ func TestPool(t *testing.T) { // Increment the pool's generation number so the connection will be considered stale and will be closed by // get(). 
- p.clear() + p.clear(nil) _, err = p.get(context.Background()) noerr(t, err) }) diff --git a/x/mongo/driver/topology/sdam_spec_test.go b/x/mongo/driver/topology/sdam_spec_test.go index b74fc54f41..729d84d749 100644 --- a/x/mongo/driver/topology/sdam_spec_test.go +++ b/x/mongo/driver/topology/sdam_spec_test.go @@ -13,7 +13,6 @@ import ( "net" "path" "sync" - "sync/atomic" "testing" "time" @@ -465,7 +464,7 @@ func applyErrors(t *testing.T, topo *Topology, errors []applicationError) { versionRange := description.NewVersionRange(0, *appErr.MaxWireVersion) desc.WireVersion = &versionRange - generation := atomic.LoadUint64(&server.pool.generation) + generation := server.pool.generation.getGeneration(nil) if appErr.Generation != nil { generation = uint64(*appErr.Generation) } @@ -479,7 +478,7 @@ func applyErrors(t *testing.T, topo *Topology, errors []applicationError) { switch appErr.When { case "beforeHandshakeCompletes": - server.ProcessHandshakeError(currError, generation) + server.ProcessHandshakeError(currError, generation, nil) case "afterHandshakeCompletes": _ = server.ProcessError(currError, &conn) default: @@ -693,7 +692,7 @@ func runTest(t *testing.T, directory string, filename string) { topo.serversLock.Lock() actualServer := topo.servers[address.Address(addr)] topo.serversLock.Unlock() - actualGeneration := atomic.LoadUint64(&actualServer.pool.generation) + actualGeneration := actualServer.pool.generation.getGeneration(nil) assert.Equal(t, server.Pool.Generation, actualGeneration, "expected server pool generation to be %v, got %v", server.Pool.Generation, actualGeneration) } diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index 1d8f981dba..947d9ffedf 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -273,9 +273,15 @@ func (s *Server) Connection(ctx context.Context) (driver.Connection, error) { } // ProcessHandshakeError implements SDAM error handling for errors that occur before a connection finishes handshaking. -func (s *Server) ProcessHandshakeError(err error, startingGenerationNumber uint64) { - // ignore nil or stale error - if err == nil || startingGenerationNumber < atomic.LoadUint64(&s.pool.generation) { +func (s *Server) ProcessHandshakeError(err error, startingGenerationNumber uint64, serverID *primitive.ObjectID) { + // Ignore the error if the server is behind a load balancer but the server ID is unknown. This indicates that the + // error happened when dialing the connection or during the MongoDB handshake, so we don't know the server ID to use + // for clearing the pool. + if err == nil || s.cfg.loadBalanced && serverID == nil { + return + } + // Ignore the error if the connection is stale. + if startingGenerationNumber < s.pool.generation.getGeneration(serverID) { return } @@ -288,7 +294,7 @@ func (s *Server) ProcessHandshakeError(err error, startingGenerationNumber uint6 // the description.Server appropriately. The description should not have a TopologyVersion because the staleness // checking logic above has already determined that this description is not stale. 
s.updateDescription(description.NewServerFromError(s.address, wrappedConnErr, nil)) - s.pool.clear() + s.pool.clear(serverID) s.cancelCheck() } @@ -392,7 +398,7 @@ func (s *Server) ProcessError(err error, conn driver.Connection) driver.ProcessE // If the node is shutting down or is older than 4.2, we synchronously clear the pool if cerr.NodeIsShuttingDown() || desc.WireVersion == nil || desc.WireVersion.Max < 8 { res = driver.ConnectionPoolCleared - s.pool.clear() + s.pool.clear(desc.ServerID) } return res } @@ -410,7 +416,7 @@ func (s *Server) ProcessError(err error, conn driver.Connection) driver.ProcessE // If the node is shutting down or is older than 4.2, we synchronously clear the pool if wcerr.NodeIsShuttingDown() || desc.WireVersion == nil || desc.WireVersion.Max < 8 { res = driver.ConnectionPoolCleared - s.pool.clear() + s.pool.clear(desc.ServerID) } return res } @@ -432,7 +438,7 @@ func (s *Server) ProcessError(err error, conn driver.Connection) driver.ProcessE // monitoring check. The check is cancelled last to avoid a post-cancellation reconnect racing with // updateDescription. s.updateDescription(description.NewServerFromError(s.address, err, nil)) - s.pool.clear() + s.pool.clear(desc.ServerID) s.cancelCheck() return driver.ConnectionPoolCleared } @@ -522,8 +528,9 @@ func (s *Server) update() { s.updateDescription(desc) if desc.LastError != nil { - // Clear the pool once the description has been updated to Unknown. - s.pool.clear() + // Clear the pool once the description has been updated to Unknown. Pass in a nil server ID to clear because + // the monitoring routine only runs for non-load balanced deployments in which servers don't return IDs. + s.pool.clear(nil) } // If the server supports streaming or we're already streaming, we want to move to streaming the next response diff --git a/x/mongo/driver/topology/server_test.go b/x/mongo/driver/topology/server_test.go index 4b23a27049..392e2342a8 100644 --- a/x/mongo/driver/topology/server_test.go +++ b/x/mongo/driver/topology/server_test.go @@ -71,6 +71,7 @@ func TestServer(t *testing.T) { netErr := ConnectionError{Wrapped: &net.AddrError{}, init: true} for _, tt := range serverTestTable { t.Run(tt.name, func(t *testing.T) { + var returnConnectionError bool s, err := NewServer( address.Address("localhost"), primitive.NewObjectID(), @@ -80,7 +81,7 @@ func TestServer(t *testing.T) { return &testHandshaker{ finishHandshake: func(context.Context, driver.Connection) error { var err error - if tt.connectionError { + if tt.connectionError && returnConnectionError { err = authErr.Wrapped } return err @@ -90,7 +91,7 @@ func TestServer(t *testing.T) { WithDialer(func(Dialer) Dialer { return DialerFunc(func(context.Context, string, string) (net.Conn, error) { var err error - if tt.networkError { + if tt.networkError && returnConnectionError { err = netErr.Wrapped } return &net.TCPConn{}, err @@ -111,6 +112,13 @@ func TestServer(t *testing.T) { require.NoError(t, err, "unable to connect to pool") s.connectionstate = connected + // The internal connection pool resets the generation number once the number of connections in a generation + // reaches zero, which will cause some of these tests to fail because they assert that the generation + // number after a connection failure is 1. To workaround this, we call Connection() twice: once to + // successfully establish a connection and once to actually do the behavior described in the test case. 
+ _, err = s.Connection(context.Background()) + assert.Nil(t, err, "error getting initial connection: %v", err) + returnConnectionError = true _, err = s.Connection(context.Background()) switch { @@ -127,12 +135,123 @@ func TestServer(t *testing.T) { require.NotNil(t, s.Description().LastError) } - if (tt.connectionError || tt.networkError) && atomic.LoadUint64(&s.pool.generation) != 1 { - t.Errorf("Expected pool to be drained once on connection or network error. got %d; want %d", s.pool.generation, 1) + generation := s.pool.generation.getGeneration(nil) + if (tt.connectionError || tt.networkError) && generation != 1 { + t.Errorf("Expected pool to be drained once on connection or network error. got %d; want %d", generation, 1) } }) } + t.Run("multiple connection initialization errors are processed correctly", func(t *testing.T) { + assertGenerationStats := func(t *testing.T, server *Server, serverID primitive.ObjectID, generation, numConns uint64) { + t.Helper() + + stats, ok := server.pool.generation.generationMap[serverID] + assert.True(t, ok, "entry for serverID not found") + assert.Equal(t, generation, stats.generation, "expected generation number %d, got %d", generation, stats.generation) + assert.Equal(t, numConns, stats.numConns, "expected connection count %d, got %d", numConns, stats.numConns) + } + + testCases := []struct { + name string + loadBalanced bool + dialErr error + getInfoErr error + finishHandshakeErr error + finalGeneration uint64 + numNewConns uint64 + }{ + // For LB clusters, errors for dialing and the initial handshake are ignored. + {"dial errors are ignored for load balancers", true, netErr.Wrapped, nil, nil, 0, 0}, + {"initial handshake errors are ignored for load balancers", true, nil, netErr.Wrapped, nil, 0, 0}, + {"post-handshake errors are not ignored for load balancers", true, nil, nil, netErr.Wrapped, 2, 0}, + + // For non-LB clusters, all errors are processed. 
+ {"dial errors are not ignored for non-lb clusters", false, netErr.Wrapped, nil, nil, 2, 0}, + {"initial handshake errors are not ignored for non-lb clusters", false, nil, netErr.Wrapped, nil, 2, 0}, + {"post-handshake errors are not ignored for non-lb clusters", false, nil, nil, netErr.Wrapped, 2, 0}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var returnConnectionError bool + var serverID primitive.ObjectID + if tc.loadBalanced { + serverID = primitive.NewObjectID() + } + + handshaker := &testHandshaker{ + getHandshakeInformation: func(_ context.Context, addr address.Address, _ driver.Connection) (driver.HandshakeInformation, error) { + if tc.getInfoErr != nil && returnConnectionError { + return driver.HandshakeInformation{}, tc.getInfoErr + } + + desc := description.NewDefaultServer(addr) + if tc.loadBalanced { + desc.ServerID = &serverID + } + return driver.HandshakeInformation{Description: desc}, nil + }, + finishHandshake: func(context.Context, driver.Connection) error { + if tc.finishHandshakeErr != nil && returnConnectionError { + return tc.finishHandshakeErr + } + return nil + }, + } + connOpts := []ConnectionOption{ + WithDialer(func(Dialer) Dialer { + return DialerFunc(func(context.Context, string, string) (net.Conn, error) { + var err error + if returnConnectionError && tc.dialErr != nil { + err = tc.dialErr + } + return &net.TCPConn{}, err + }) + }), + WithHandshaker(func(Handshaker) Handshaker { + return handshaker + }), + WithConnectionLoadBalanced(func(bool) bool { + return tc.loadBalanced + }), + } + serverOpts := []ServerOption{ + WithServerLoadBalanced(func(bool) bool { + return tc.loadBalanced + }), + WithConnectionOptions(func(...ConnectionOption) []ConnectionOption { + return connOpts + }), + // Disable the monitoring routine because we're only testing pooled connections and we don't want + // errors in monitoring to clear the pool and make this test flaky. + withMonitoringDisabled(func(bool) bool { + return true + }), + } + + server, err := ConnectServer(address.Address("localhost:27017"), nil, primitive.NewObjectID(), serverOpts...) + assert.Nil(t, err, "ConnectServer error: %v", err) + + _, err = server.Connection(context.Background()) + assert.Nil(t, err, "Connection error: %v", err) + assertGenerationStats(t, server, serverID, 0, 1) + + returnConnectionError = true + for i := 0; i < 2; i++ { + _, err = server.Connection(context.Background()) + switch { + case tc.dialErr != nil || tc.getInfoErr != nil || tc.finishHandshakeErr != nil: + assert.NotNil(t, err, "expected Connection error at iteration %d, got nil", i) + default: + assert.Nil(t, err, "Connection error at iteration %d: %v", i, err) + } + } + // The final number of connections should be numNewConns+1 to account for the extra one we create above. 
+ assertGenerationStats(t, server, serverID, tc.finalGeneration, tc.numNewConns+1) + }) + } + }) + t.Run("Cannot starve connection request", func(t *testing.T) { cleanup := make(chan struct{}) addr := bootstrapConnections(t, 3, func(nc net.Conn) { @@ -322,8 +441,9 @@ func TestServer(t *testing.T) { "expected server kind %q, got %q", expectedKind, desc.Kind) assert.Equal(t, expectedError, desc.LastError, "expected last error %v, got %v", expectedError, desc.LastError) - assert.Equal(t, expectedPoolGeneration, server.pool.generation, - "expected pool generation %d, got %d", expectedPoolGeneration, server.pool.generation) + generation := server.pool.generation.getGeneration(nil) + assert.Equal(t, expectedPoolGeneration, generation, + "expected pool generation %d, got %d", expectedPoolGeneration, generation) }) } })
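
To make the intent of the patch easier to follow, here is a minimal, standalone sketch of the per-server-ID generation bookkeeping that pool_generation_counter.go introduces and that pool.clear(serverID) plumbs through. This is not code from the driver: the demo* names are invented for illustration, plain strings stand in for the *primitive.ObjectID server IDs used by the real (unexported) poolGenerationMap and generationStats types, and locking/state handling is simplified.

// A sketch of per-server-ID generation tracking, assuming string keys in place of
// *primitive.ObjectID server IDs. Method names mirror the patch; everything else is hypothetical.
package main

import (
	"fmt"
	"sync"
)

// demoGenerationStats mirrors generationStats: the current generation for one server ID
// and the number of connections created against that entry.
type demoGenerationStats struct {
	generation uint64
	numConns   uint64
}

// demoGenerationMap mirrors poolGenerationMap: one entry per server ID, so clearing the
// pool for one server behind a load balancer leaves the other servers' generations alone.
type demoGenerationMap struct {
	mu    sync.Mutex
	stats map[string]*demoGenerationStats
}

func newDemoGenerationMap() *demoGenerationMap {
	return &demoGenerationMap{stats: make(map[string]*demoGenerationStats)}
}

// addConnection records a new connection for serverID and returns the generation number
// to stamp onto it, analogous to pool.getGenerationForNewConnection in the patch.
func (m *demoGenerationMap) addConnection(serverID string) uint64 {
	m.mu.Lock()
	defer m.mu.Unlock()
	s, ok := m.stats[serverID]
	if !ok {
		s = &demoGenerationStats{}
		m.stats[serverID] = s
	}
	s.numConns++
	return s.generation
}

// removeConnection decrements the connection count and drops the entry once it reaches
// zero so the map does not grow without bound.
func (m *demoGenerationMap) removeConnection(serverID string) {
	m.mu.Lock()
	defer m.mu.Unlock()
	if s, ok := m.stats[serverID]; ok {
		s.numConns--
		if s.numConns == 0 {
			delete(m.stats, serverID)
		}
	}
}

// clear bumps the generation for a single server ID, analogous to pool.clear(serverID).
func (m *demoGenerationMap) clear(serverID string) {
	m.mu.Lock()
	defer m.mu.Unlock()
	if s, ok := m.stats[serverID]; ok {
		s.generation++
	}
}

// stale reports whether a connection created at knownGeneration should be discarded.
func (m *demoGenerationMap) stale(serverID string, knownGeneration uint64) bool {
	m.mu.Lock()
	defer m.mu.Unlock()
	if s, ok := m.stats[serverID]; ok {
		return knownGeneration < s.generation
	}
	return false
}

func main() {
	gm := newDemoGenerationMap()
	genA := gm.addConnection("serverA") // connection to serverA is stamped with generation 0
	genB := gm.addConnection("serverB") // connection to serverB is stamped with generation 0

	gm.clear("serverA") // an error on serverA bumps only serverA's generation

	fmt.Println(gm.stale("serverA", genA)) // true: serverA connections are now stale
	fmt.Println(gm.stale("serverB", genB)) // false: serverB connections remain usable
}

Non-load-balanced pools keep using a single entry keyed by the nil server ID, which is why clear(nil) and getGeneration(nil) still appear in the monitoring and SDAM test changes above.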