Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stop treating context errors as network errors where possible. #1045

Merged
merged 4 commits into from Aug 17, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 3 additions & 8 deletions mongo/session.go
Expand Up @@ -200,10 +200,6 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(sessCtx Sessi
default:
}

// End if context has timed out or been canceled, as retrying has no chance of success.
if ctx.Err() != nil {
return res, err
}
if errorHasLabel(err, driver.TransientTransactionError) {
continue
}
Expand All @@ -218,10 +214,9 @@ func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(sessCtx Sessi
CommitLoop:
for {
err = s.CommitTransaction(ctx)
// End when error is nil (transaction has been committed), or when context has timed out or been
// canceled, as retrying has no chance of success.
if err == nil || ctx.Err() != nil {
return res, err
// End when error is nil, as transaction has been committed.
if err == nil {
return res, nil
}

select {
Expand Down
12 changes: 7 additions & 5 deletions x/mongo/driver/operation.go
Expand Up @@ -531,15 +531,17 @@ func (op Operation) Execute(ctx context.Context, scratch []byte) error {
serviceID: startedInfo.serviceID,
}

// Check if there's enough time to perform a round trip before the Context deadline. If ctx is
// a Timeout Context, use the 90th percentile RTT as a threshold. Otherwise, use the minimum observed
// RTT.
if deadline, ok := ctx.Deadline(); ok {
// Check for possible context error. If no context error, check if there's enough time to perform a
// round trip before the Context deadline. If ctx is a Timeout Context, use the 90th percentile RTT
// as a threshold. Otherwise, use the minimum observed RTT.
if ctx.Err() != nil {
err = ctx.Err()
} else if deadline, ok := ctx.Deadline(); ok {
if internal.IsTimeoutContext(ctx) && time.Now().Add(srvr.RTTMonitor().P90()).After(deadline) {
err = internal.WrapErrorf(ErrDeadlineWouldBeExceeded,
"remaining time %v until context deadline is less than 90th percentile RTT\n%v", time.Until(deadline), srvr.RTTMonitor().Stats())
} else if time.Now().Add(srvr.RTTMonitor().Min()).After(deadline) {
err = op.networkError(context.DeadlineExceeded)
err = context.DeadlineExceeded
}
}
benjirewis marked this conversation as resolved.
Show resolved Hide resolved

Expand Down
13 changes: 0 additions & 13 deletions x/mongo/driver/topology/connection.go
Expand Up @@ -331,11 +331,6 @@ func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
if atomic.LoadInt64(&c.state) != connConnected {
return ConnectionError{ConnectionID: c.id, message: "connection is closed"}
}
select {
case <-ctx.Done():
return ConnectionError{ConnectionID: c.id, Wrapped: ctx.Err(), message: "failed to write"}
default:
}

var deadline time.Time
if c.writeTimeout != 0 {
Expand Down Expand Up @@ -388,14 +383,6 @@ func (c *connection) readWireMessage(ctx context.Context, dst []byte) ([]byte, e
return dst, ConnectionError{ConnectionID: c.id, message: "connection is closed"}
}

select {
case <-ctx.Done():
// We closeConnection the connection because we don't know if there is an unread message on the wire.
c.close()
return nil, ConnectionError{ConnectionID: c.id, Wrapped: ctx.Err(), message: "failed to read"}
default:
}

var deadline time.Time
if c.readTimeout != 0 {
deadline = time.Now().Add(c.readTimeout)
Expand Down
14 changes: 0 additions & 14 deletions x/mongo/driver/topology/connection_errors_test.go
Expand Up @@ -50,19 +50,5 @@ func TestConnectionErrors(t *testing.T) {
err := conn.connect(ctx)
assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err)
})
t.Run("write error", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected}
err := conn.writeWireMessage(ctx, []byte{})
assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err)
})
t.Run("read error", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected}
_, err := conn.readWireMessage(ctx, []byte{})
assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err)
})
})
}
20 changes: 0 additions & 20 deletions x/mongo/driver/topology/connection_test.go
Expand Up @@ -351,16 +351,6 @@ func TestConnection(t *testing.T) {
t.Errorf("errors do not match. got %v; want %v", got, want)
}
})
t.Run("completed context", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected}
want := ConnectionError{ConnectionID: "foobar", Wrapped: ctx.Err(), message: "failed to write"}
got := conn.writeWireMessage(ctx, []byte{})
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
t.Errorf("errors do not match. got %v; want %v", got, want)
}
})
t.Run("deadlines", func(t *testing.T) {
testCases := []struct {
name string
Expand Down Expand Up @@ -490,16 +480,6 @@ func TestConnection(t *testing.T) {
t.Errorf("errors do not match. got %v; want %v", got, want)
}
})
t.Run("completed context", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
conn := &connection{id: "foobar", nc: &net.TCPConn{}, state: connConnected}
want := ConnectionError{ConnectionID: "foobar", Wrapped: ctx.Err(), message: "failed to read"}
_, got := conn.readWireMessage(ctx, []byte{})
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
t.Errorf("errors do not match. got %v; want %v", got, want)
}
})
t.Run("deadlines", func(t *testing.T) {
testCases := []struct {
name string
Expand Down