Skip to content

Commit

Permalink
GODRIVER-1898 SDAM error handling changes for LB mode (mongodb#611)
Browse files Browse the repository at this point in the history
  • Loading branch information
Divjot Arora authored and Mohammad Fahim Abrar committed Mar 17, 2022
1 parent 36a25fc commit 7735f66
Show file tree
Hide file tree
Showing 11 changed files with 371 additions and 47 deletions.
3 changes: 3 additions & 0 deletions event/monitoring.go
Expand Up @@ -89,6 +89,9 @@ type PoolEvent struct {
ConnectionID uint64 `json:"connectionId"`
PoolOptions *MonitorPoolOptions `json:"options"`
Reason string `json:"reason"`
// ServerID is only set if the Type is PoolCleared and the server is deployed behind a load balancer. This field
// can be used to distinguish between individual servers in a load balanced deployment.
ServerID *primitive.ObjectID `json:"serverId"`
}

// PoolMonitor is a function that allows the user to gain access to events occurring in the pool
Expand Down
2 changes: 1 addition & 1 deletion x/mongo/driver/topology/CMAP_spec_test.go
Expand Up @@ -419,7 +419,7 @@ func runOperationInThread(t *testing.T, operation map[string]interface{}, testIn
}
return c.Close()
case "clear":
s.pool.clear()
s.pool.clear(nil)
case "close":
return s.pool.disconnect(context.Background())
default:
Expand Down
35 changes: 34 additions & 1 deletion x/mongo/driver/topology/connection.go
Expand Up @@ -91,6 +91,11 @@ func newConnection(addr address.Address, opts ...ConnectionOption) (*connection,
cancellationListener: internal.NewCancellationListener(),
poolMonitor: cfg.poolMonitor,
}
// Connections to non-load balanced deployments should eagerly set the generation numbers so errors encountered
// at any point during connection establishment can be processed without the connection being considered stale.
if !c.config.loadBalanced {
c.setGenerationNumber()
}
atomic.StoreInt32(&c.connected, initialized)

return c, nil
Expand All @@ -104,8 +109,29 @@ func (c *connection) processInitializationError(err error) {

c.connectErr = ConnectionError{Wrapped: err, init: true}
if c.config.errorHandlingCallback != nil {
c.config.errorHandlingCallback(c.connectErr, c.generation)
c.config.errorHandlingCallback(c.connectErr, c.generation, c.desc.ServerID)
}
}

// setGenerationNumber asks the generation callback from the connection configuration, if one was provided, for
// this connection's generation number and stores the result on the connection. The callback is given the server
// ID from the connection description. When no callback is configured the generation is left untouched.
func (c *connection) setGenerationNumber() {
	getGeneration := c.config.getGenerationFn
	if getGeneration == nil {
		return
	}

	c.generation = getGeneration(c.desc.ServerID)
}

// hasGenerationNumber reports whether the connection's generation number has been set, which indicates that the
// generationNumberFn supplied through the connection options has been invoked exactly once.
func (c *connection) hasGenerationNumber() bool {
	if c.config.loadBalanced {
		// In LB mode the generation is only assigned after the initial handshake, so it is known to be set once
		// the connection description has been updated to report that the connection is behind a load balancer.
		return c.desc.LoadBalanced()
	}

	// Outside of LB mode the generation is known as soon as the connection object has been created.
	return true
}

// connect handles the I/O for a connection. It will dial, configure TLS, and perform
Expand Down Expand Up @@ -212,6 +238,13 @@ func (c *connection) connect(ctx context.Context) {
}
}
if err == nil {
// For load-balanced connections, the generation number depends on the server ID, which isn't known until the
// initial MongoDB handshake is done. To account for this, we don't attempt to set the connection's generation
// number unless GetHandshakeInformation succeeds.
if c.config.loadBalanced {
c.setGenerationNumber()
}

// If we successfully finished the first part of the handshake and verified LB state, continue with the rest of
// the handshake.
err = handshaker.FinishHandshake(handshakeCtx, handshakeConn)
Expand Down
16 changes: 14 additions & 2 deletions x/mongo/driver/topology/connection_options.go
Expand Up @@ -6,6 +6,7 @@ import (
"net"
"time"

"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/ocsp"
Expand Down Expand Up @@ -35,6 +36,9 @@ var DefaultDialer Dialer = &net.Dialer{}
// initialization. Implementations must be goroutine safe.
type Handshaker = driver.Handshaker

// generationNumberFn is a callback type used by a connection to fetch its generation number given its server ID.
// A connection invokes it with the ServerID from its description and stores the returned value as its generation.
type generationNumberFn func(serverID *primitive.ObjectID) uint64

type connectionConfig struct {
appName string
connectTimeout time.Duration
Expand All @@ -51,9 +55,10 @@ type connectionConfig struct {
zstdLevel *int
ocspCache ocsp.Cache
disableOCSPEndpointCheck bool
errorHandlingCallback func(error, uint64)
errorHandlingCallback func(error, uint64, *primitive.ObjectID)
tlsConnectionSource tlsConnectionSource
loadBalanced bool
getGenerationFn generationNumberFn
}

func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) {
Expand Down Expand Up @@ -87,7 +92,7 @@ func withTLSConnectionSource(fn func(tlsConnectionSource) tlsConnectionSource) C
}
}

func withErrorHandlingCallback(fn func(error, uint64)) ConnectionOption {
func withErrorHandlingCallback(fn func(error, uint64, *primitive.ObjectID)) ConnectionOption {
return func(c *connectionConfig) error {
c.errorHandlingCallback = fn
return nil
Expand Down Expand Up @@ -217,3 +222,10 @@ func WithConnectionLoadBalanced(fn func(bool) bool) ConnectionOption {
return nil
}
}

// withGenerationNumberFn returns a ConnectionOption that replaces the configuration's generation-number callback
// with the result of applying fn to the callback that is currently configured (which may be nil).
func withGenerationNumberFn(fn func(generationNumberFn) generationNumberFn) ConnectionOption {
	return func(cfg *connectionConfig) error {
		current := cfg.getGenerationFn
		cfg.getGenerationFn = fn(current)
		return nil
	}
}
3 changes: 2 additions & 1 deletion x/mongo/driver/topology/connection_test.go
Expand Up @@ -17,6 +17,7 @@ import (
"time"

"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/internal"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
"go.mongodb.org/mongo-driver/mongo/address"
Expand Down Expand Up @@ -128,7 +129,7 @@ func TestConnection(t *testing.T) {
return &net.TCPConn{}, nil
})
}),
withErrorHandlingCallback(func(err error, _ uint64) {
withErrorHandlingCallback(func(err error, _ uint64, _ *primitive.ObjectID) {
got = err
}),
)
Expand Down
46 changes: 28 additions & 18 deletions x/mongo/driver/topology/pool.go
Expand Up @@ -13,6 +13,7 @@ import (
"sync/atomic"
"time"

"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/mongo/address"
"golang.org/x/sync/semaphore"
Expand Down Expand Up @@ -60,7 +61,7 @@ type pool struct {
address address.Address
opts []ConnectionOption
conns *resourcePool // pool for non-checked out connections
generation uint64 // must be accessed using atomic package
generation *poolGenerationMap
monitor *event.PoolMonitor

connected int32 // Must be accessed using the sync/atomic package.
Expand Down Expand Up @@ -148,13 +149,15 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) (*pool, error) {
}

pool := &pool{
address: config.Address,
monitor: config.PoolMonitor,
connected: disconnected,
opened: make(map[uint64]*connection),
opts: opts,
sem: semaphore.NewWeighted(int64(maxConns)),
address: config.Address,
monitor: config.PoolMonitor,
connected: disconnected,
opened: make(map[uint64]*connection),
opts: opts,
sem: semaphore.NewWeighted(int64(maxConns)),
generation: newPoolGenerationMap(),
}
pool.opts = append(pool.opts, withGenerationNumberFn(func(_ generationNumberFn) generationNumberFn { return pool.getGenerationForNewConnection }))

// we do not pass in config.MaxPoolSize because we manage the max size at this level rather than the resource pool level
rpc := resourcePoolConfig{
Expand Down Expand Up @@ -189,14 +192,15 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) (*pool, error) {

// stale checks if a given connection's generation is below the generation of the pool
func (p *pool) stale(c *connection) bool {
return c == nil || c.generation < atomic.LoadUint64(&p.generation)
return c == nil || p.generation.stale(c.desc.ServerID, c.generation)
}

// connect moves the pool from the disconnected state to the connected state. If the pool is not currently
// disconnected, ErrPoolConnected is returned and nothing else happens. Otherwise the generation map is notified
// of the connect and the resource pool holding non-checked out connections is initialized.
func (p *pool) connect() error {
	swapped := atomic.CompareAndSwapInt32(&p.connected, disconnected, connected)
	if !swapped {
		return ErrPoolConnected
	}

	p.generation.connect()
	p.conns.initialize()
	return nil
}
Expand All @@ -212,7 +216,7 @@ func (p *pool) disconnect(ctx context.Context) error {
}

p.conns.Close()
atomic.AddUint64(&p.generation, 1)
p.generation.disconnect()

var err error
if dl, ok := ctx.Deadline(); ok {
Expand Down Expand Up @@ -277,7 +281,6 @@ func (p *pool) makeNewConnection() (*connection, string, error) {

c.pool = p
c.poolID = atomic.AddUint64(&p.nextid, 1)
c.generation = atomic.LoadUint64(&p.generation)

if p.monitor != nil {
p.monitor.Event(&event.PoolEvent{
Expand Down Expand Up @@ -310,10 +313,6 @@ func (p *pool) makeNewConnection() (*connection, string, error) {

}

func (p *pool) getGeneration() uint64 {
return atomic.LoadUint64(&p.generation)
}

// get checks out and returns a connection from the pool
func (p *pool) get(ctx context.Context) (*connection, error) {
if ctx == nil {
Expand Down Expand Up @@ -487,6 +486,10 @@ func (p *pool) closeConnection(c *connection) error {
return nil
}

// getGenerationForNewConnection returns the generation number that a new connection to the server identified by
// serverID should use, as reported by the pool's generation map.
// NOTE(review): addConnection presumably also records the new connection in the per-server count that
// removeConnection later decrements — confirm against the generation map implementation.
func (p *pool) getGenerationForNewConnection(serverID *primitive.ObjectID) uint64 {
	gen := p.generation.addConnection(serverID)
	return gen
}

// removeConnection removes a connection from the pool.
func (p *pool) removeConnection(c *connection, reason string) error {
if c.pool != p {
Expand All @@ -501,6 +504,12 @@ func (p *pool) removeConnection(c *connection, reason string) error {
}
p.Unlock()

// Only update the generation numbers map if the connection has retrieved its generation number. Otherwise, we'd
// decrement the count for the generation even though it had never been incremented.
if c.hasGenerationNumber() {
p.generation.removeConnection(c.desc.ServerID)
}

if publishEvent && p.monitor != nil {
c.pool.monitor.Event(&event.PoolEvent{
Type: event.ConnectionClosed,
Expand Down Expand Up @@ -545,12 +554,13 @@ func (p *pool) put(c *connection) error {
}

// clear clears the pool by incrementing the generation
func (p *pool) clear() {
func (p *pool) clear(serverID *primitive.ObjectID) {
if p.monitor != nil {
p.monitor.Event(&event.PoolEvent{
Type: event.PoolCleared,
Address: p.address.String(),
Type: event.PoolCleared,
Address: p.address.String(),
ServerID: serverID,
})
}
atomic.AddUint64(&p.generation, 1)
p.generation.clear(serverID)
}

0 comments on commit 7735f66

Please sign in to comment.