Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix regression in timeout handling. #1373

Merged
merged 4 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
47 changes: 26 additions & 21 deletions driver-core/src/main/com/mongodb/internal/TimeoutContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -185,25 +185,27 @@ public long getMaxAwaitTimeMS() {
return timeoutSettings.getMaxAwaitTimeMS();
}

public void runMaxTimeMSTimeout(final Runnable onInfinite, final LongConsumer onRemaining,
final Runnable onExpired) {
public void runMaxTimeMS(final LongConsumer onRemaining) {
if (maxTimeSupplier != null) {
runWithFixedTimout(maxTimeSupplier.get(), onInfinite, onRemaining);
runWithFixedTimeout(maxTimeSupplier.get(), onRemaining);
return;
}

if (timeout != null) {
timeout.shortenBy(minRoundTripTimeMS, MILLISECONDS)
.run(MILLISECONDS, onInfinite, onRemaining, onExpired);
} else {
runWithFixedTimout(timeoutSettings.getMaxTimeMS(), onInfinite, onRemaining);
if (timeout == null) {
runWithFixedTimeout(timeoutSettings.getMaxTimeMS(), onRemaining);
return;
}
timeout.shortenBy(minRoundTripTimeMS, MILLISECONDS)
.run(MILLISECONDS,
() -> {},
onRemaining,
() -> {
throw createMongoRoundTripTimeoutException();
});

}

private static void runWithFixedTimout(final long ms, final Runnable onInfinite, final LongConsumer onRemaining) {
if (ms == 0) {
onInfinite.run();
} else {
private static void runWithFixedTimeout(final long ms, final LongConsumer onRemaining) {
if (ms != 0) {
onRemaining.accept(ms);
}
}
Expand All @@ -214,15 +216,18 @@ public void resetToDefaultMaxTime() {

/**
* The override will be provided as the remaining value in
* {@link #runMaxTimeMSTimeout}, where 0 will invoke the onExpired path
* {@link #runMaxTimeMS}, where 0 is ignored.
* <p>
* NOTE: Suitable for static user-defined values only (i.e MaxAwaitTimeMS),
* not for running timeouts that adjust dynamically.
*/
public void setMaxTimeOverride(final long maxTimeMS) {
this.maxTimeSupplier = () -> maxTimeMS;
}

/**
* The override will be provided as the remaining value in
* {@link #runMaxTimeMSTimeout}, where 0 will invoke the onExpired path
* {@link #runMaxTimeMS}, where 0 is ignored.
*/
public void setMaxTimeOverrideToMaxCommitTime() {
this.maxTimeSupplier = () -> getMaxCommitTimeMS();
Expand All @@ -242,12 +247,12 @@ public long getWriteTimeoutMS() {
return timeoutOrAlternative(0);
}

public Timeout createConnectTimeoutMs() {
// null timeout treated as infinite will be later than the other

return Timeout.earliest(
Timeout.expiresIn(getTimeoutSettings().getConnectTimeoutMS(), MILLISECONDS, ZERO_DURATION_MEANS_INFINITE),
Timeout.nullAsInfinite(timeout));
public int getConnectTimeoutMs() {
final long connectTimeoutMS = getTimeoutSettings().getConnectTimeoutMS();
return Math.toIntExact(Timeout.nullAsInfinite(timeout).call(MILLISECONDS,
() -> connectTimeoutMS,
(ms) -> connectTimeoutMS == 0 ? ms : Math.min(ms, connectTimeoutMS),
() -> throwMongoTimeoutException("The operation timeout has expired.")));
}

public void resetTimeout() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,9 @@ private List<BsonElement> getExtraElements(final OperationContext operationConte

List<BsonElement> extraElements = new ArrayList<>();
if (!getSettings().isCryptd()) {
timeoutContext.runMaxTimeMSTimeout(
() -> {},
(ms) -> extraElements.add(new BsonElement("maxTimeMS", new BsonInt64(ms))),
() -> {
throw TimeoutContext.createMongoRoundTripTimeoutException();
});
timeoutContext.runMaxTimeMS(maxTimeMS ->
extraElements.add(new BsonElement("maxTimeMS", new BsonInt64(maxTimeMS)))
);
}
extraElements.add(new BsonElement("$db", new BsonString(new MongoNamespace(getCollectionName()).getDatabaseName())));
if (sessionContext.getClusterTime() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import static com.mongodb.internal.connection.SocketStreamHelper.configureSocket;
import static com.mongodb.internal.connection.SslHelper.configureSslSocket;
import static com.mongodb.internal.thread.InterruptionUtil.translateInterruptedException;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

/**
* <p>This class is not part of the public API and may be removed or changed at any time</p>
Expand Down Expand Up @@ -122,10 +121,7 @@ private SSLSocket initializeSslSocketOverSocksProxy(final OperationContext opera
SocksSocket socksProxy = new SocksSocket(settings.getProxySettings());
configureSocket(socksProxy, operationContext, settings);
InetSocketAddress inetSocketAddress = toSocketAddress(serverHost, serverPort);
operationContext.getTimeoutContext().createConnectTimeoutMs().checkedRun(MILLISECONDS,
() -> socksProxy.connect(inetSocketAddress, 0),
(ms) -> socksProxy.connect(inetSocketAddress, Math.toIntExact(ms)),
() -> throwMongoTimeoutException("The operation timeout has expired."));
socksProxy.connect(inetSocketAddress, operationContext.getTimeoutContext().getConnectTimeoutMs());

SSLSocket sslSocket = (SSLSocket) sslSocketFactory.createSocket(socksProxy, serverHost, serverPort, true);
//Even though Socks proxy connection is already established, TLS handshake has not been performed yet.
Expand Down Expand Up @@ -153,11 +149,8 @@ private Socket initializeSocketOverSocksProxy(final OperationContext operationCo
*/
SocksSocket socksProxy = new SocksSocket(createdSocket, settings.getProxySettings());

InetSocketAddress inetSocketAddress = toSocketAddress(address.getHost(), address.getPort());
operationContext.getTimeoutContext().createConnectTimeoutMs().checkedRun(MILLISECONDS,
() -> socksProxy.connect(inetSocketAddress, 0),
(ms) -> socksProxy.connect(inetSocketAddress, Math.toIntExact(ms)),
() -> throwMongoTimeoutException("The operation timeout has expired."));
socksProxy.connect(toSocketAddress(address.getHost(), address.getPort()),
operationContext.getTimeoutContext().getConnectTimeoutMs());
return socksProxy;
}

Expand Down Expand Up @@ -185,9 +178,7 @@ public ByteBuf read(final int numBytes, final OperationContext operationContext)
byte[] bytes = buffer.array();
while (totalBytesRead < buffer.limit()) {
int readTimeoutMS = (int) operationContext.getTimeoutContext().getReadTimeoutMS();
if (readTimeoutMS > 0) {
socket.setSoTimeout(readTimeoutMS);
}
socket.setSoTimeout(readTimeoutMS);
int bytesRead = inputStream.read(bytes, totalBytesRead, buffer.limit() - totalBytesRead);
if (bytesRead == -1) {
throw new MongoSocketReadException("Prematurely reached end of stream", getAddress());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@
import java.net.SocketException;
import java.net.SocketOption;

import static com.mongodb.internal.TimeoutContext.throwMongoTimeoutException;
import static com.mongodb.internal.connection.SslHelper.configureSslSocket;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

@SuppressWarnings({"unchecked", "rawtypes"})
final class SocketStreamHelper {
Expand Down Expand Up @@ -75,10 +73,7 @@ static void initialize(final OperationContext operationContext, final Socket soc
final SslSettings sslSettings) throws IOException {
configureSocket(socket, operationContext, settings);
configureSslSocket(socket, sslSettings, inetSocketAddress);
operationContext.getTimeoutContext().createConnectTimeoutMs().checkedRun(MILLISECONDS,
() -> socket.connect(inetSocketAddress, 0),
(ms) -> socket.connect(inetSocketAddress, Math.toIntExact(ms)),
() -> throwMongoTimeoutException("The operation timeout has expired."));
socket.connect(inetSocketAddress, operationContext.getTimeoutContext().getConnectTimeoutMs());
}

static void configureSocket(final Socket socket, final OperationContext operationContext, final SocketSettings settings) throws SocketException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@

import static com.mongodb.assertions.Assertions.assertNotNull;
import static com.mongodb.internal.Locks.withLock;
import static com.mongodb.internal.TimeoutContext.throwMongoTimeoutException;
import static com.mongodb.internal.connection.ServerAddressHelper.getSocketAddresses;
import static com.mongodb.internal.connection.SslHelper.enableHostNameVerification;
import static com.mongodb.internal.connection.SslHelper.enableSni;
Expand Down Expand Up @@ -192,10 +191,8 @@ private void initializeChannel(final OperationContext operationContext, final As
Bootstrap bootstrap = new Bootstrap();
bootstrap.group(workerGroup);
bootstrap.channel(socketChannelClass);
operationContext.getTimeoutContext().createConnectTimeoutMs().checkedRun(MILLISECONDS,
() -> bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 0),
(ms) -> bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, Math.toIntExact(ms)),
() -> throwMongoTimeoutException("The operation timeout has expired."));
bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS,
operationContext.getTimeoutContext().getConnectTimeoutMs());
bootstrap.option(ChannelOption.TCP_NODELAY, true);
bootstrap.option(ChannelOption.SO_KEEPALIVE, true);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,13 @@
import com.mongodb.session.ClientSession;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mockito;

import java.util.function.Supplier;
import java.util.stream.Stream;

import static com.mongodb.ClusterFixture.TIMEOUT_SETTINGS;
import static com.mongodb.ClusterFixture.TIMEOUT_SETTINGS_WITH_INFINITE_TIMEOUT;
Expand All @@ -42,12 +46,7 @@ final class TimeoutContextTest {

public static long getMaxTimeMS(final TimeoutContext timeoutContext) {
long[] result = {0L};
timeoutContext.runMaxTimeMSTimeout(
() -> {},
(ms) -> result[0] = ms,
() -> {
throw TimeoutContext.createMongoRoundTripTimeoutException();
});
timeoutContext.runMaxTimeMS((ms) -> result[0] = ms);
return result[0];
}

Expand Down Expand Up @@ -198,16 +197,19 @@ void testThrowsWhenExpired() {

assertThrows(MongoOperationTimeoutException.class, smallTimeout::getReadTimeoutMS);
assertThrows(MongoOperationTimeoutException.class, smallTimeout::getWriteTimeoutMS);
assertThrows(MongoOperationTimeoutException.class, smallTimeout::getConnectTimeoutMs);
assertThrows(MongoOperationTimeoutException.class, () -> getMaxTimeMS(smallTimeout));
assertThrows(MongoOperationTimeoutException.class, smallTimeout::getMaxCommitTimeMS);
assertThrows(MongoOperationTimeoutException.class, () -> smallTimeout.timeoutOrAlternative(1));
assertDoesNotThrow(longTimeout::getReadTimeoutMS);
assertDoesNotThrow(longTimeout::getWriteTimeoutMS);
assertDoesNotThrow(longTimeout::getConnectTimeoutMs);
assertDoesNotThrow(() -> getMaxTimeMS(longTimeout));
assertDoesNotThrow(longTimeout::getMaxCommitTimeMS);
assertDoesNotThrow(() -> longTimeout.timeoutOrAlternative(1));
assertDoesNotThrow(noTimeout::getReadTimeoutMS);
assertDoesNotThrow(noTimeout::getWriteTimeoutMS);
assertDoesNotThrow(noTimeout::getConnectTimeoutMs);
assertDoesNotThrow(() -> getMaxTimeMS(noTimeout));
assertDoesNotThrow(noTimeout::getMaxCommitTimeMS);
assertDoesNotThrow(() -> noTimeout.timeoutOrAlternative(1));
Expand Down Expand Up @@ -284,6 +286,61 @@ void shouldResetMaximeMS() {
assertTrue(getMaxTimeMS(timeoutContext) > 1);
}

static Stream<Arguments> shouldChooseConnectTimeoutWhenItIsLessThenTimeoutMs() {
return Stream.of(
//connectTimeoutMS, timeoutMS, expected
Arguments.of(500L, 1000L, 500L),
Arguments.of(0L, null, 0L),
Arguments.of(1000L, null, 1000L),
Arguments.of(1000L, 0L, 1000L),
Arguments.of(0L, 0L, 0L)
);
}

@ParameterizedTest
@MethodSource
@DisplayName("should choose connectTimeoutMS when connectTimeoutMS is less than timeoutMS")
void shouldChooseConnectTimeoutWhenItIsLessThenTimeoutMs(final Long connectTimeoutMS,
final Long timeoutMS,
final long expected) {
TimeoutContext timeoutContext = new TimeoutContext(
new TimeoutSettings(0,
connectTimeoutMS,
0,
timeoutMS,
0));

long calculatedTimeoutMS = timeoutContext.getConnectTimeoutMs();
assertEquals(expected, calculatedTimeoutMS);
}


static Stream<Arguments> shouldChooseTimeoutMsWhenItIsLessThenConnectTimeoutMS() {
return Stream.of(
//connectTimeoutMS, timeoutMS, expected
Arguments.of(1000L, 1000L, 999),
Arguments.of(1000L, 500L, 499L),
Arguments.of(0L, 1000L, 999L)
);
}

@ParameterizedTest
@MethodSource
@DisplayName("should choose timeoutMS when timeoutMS is less than connectTimeoutMS")
void shouldChooseTimeoutMsWhenItIsLessThenConnectTimeoutMS(final Long connectTimeoutMS,
final Long timeoutMS,
final long expected) {
TimeoutContext timeoutContext = new TimeoutContext(
new TimeoutSettings(0,
connectTimeoutMS,
0,
timeoutMS,
0));

long calculatedTimeoutMS = timeoutContext.getConnectTimeoutMs();
assertTrue(expected - calculatedTimeoutMS <= 1);
}

private TimeoutContextTest() {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ void encodeShouldThrowTimeoutExceptionWhenTimeoutContextIsCalled() {
BasicOutputBuffer bsonOutput = new BasicOutputBuffer();
SessionContext sessionContext = mock(SessionContext.class);
TimeoutContext timeoutContext = mock(TimeoutContext.class, mock -> {
doThrow(new MongoOperationTimeoutException("test")).when(mock).runMaxTimeMSTimeout(any(), any(), any());
doThrow(new MongoOperationTimeoutException("test")).when(mock).runMaxTimeMS(any());
});
OperationContext operationContext = mock(OperationContext.class, mock -> {
when(mock.getSessionContext()).thenReturn(sessionContext);
Expand Down