diff --git a/driver-core/src/main/com/mongodb/internal/operation/ChangeStreamBatchCursor.java b/driver-core/src/main/com/mongodb/internal/operation/ChangeStreamBatchCursor.java index af302afde06..8a76b8eb732 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/ChangeStreamBatchCursor.java +++ b/driver-core/src/main/com/mongodb/internal/operation/ChangeStreamBatchCursor.java @@ -64,7 +64,11 @@ public boolean hasNext() { return resumeableOperation(new Function, Boolean>() { @Override public Boolean apply(final AggregateResponseBatchCursor queryBatchCursor) { - return queryBatchCursor.hasNext(); + try { + return queryBatchCursor.hasNext(); + } finally { + cachePostBatchResumeToken(queryBatchCursor); + } } }); } @@ -74,9 +78,11 @@ public List next() { return resumeableOperation(new Function, List>() { @Override public List apply(final AggregateResponseBatchCursor queryBatchCursor) { - List results = convertResults(queryBatchCursor.next()); - cachePostBatchResumeToken(queryBatchCursor); - return results; + try { + return convertResults(queryBatchCursor.next()); + } finally { + cachePostBatchResumeToken(queryBatchCursor); + } } }); } @@ -86,9 +92,11 @@ public List tryNext() { return resumeableOperation(new Function, List>() { @Override public List apply(final AggregateResponseBatchCursor queryBatchCursor) { - List results = convertResults(queryBatchCursor.tryNext()); - cachePostBatchResumeToken(queryBatchCursor); - return results; + try { + return convertResults(queryBatchCursor.tryNext()); + } finally { + cachePostBatchResumeToken(queryBatchCursor); + } } }); } diff --git a/driver-core/src/main/com/mongodb/internal/operation/QueryBatchCursor.java b/driver-core/src/main/com/mongodb/internal/operation/QueryBatchCursor.java index 451364b70c2..23379e4d579 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/QueryBatchCursor.java +++ b/driver-core/src/main/com/mongodb/internal/operation/QueryBatchCursor.java @@ -100,6 +100,7 @@ class QueryBatchCursor implements AggregateResponseBatchCursor { this.decoder = notNull("decoder", decoder); if (result != null) { this.operationTime = result.getTimestamp(OPERATION_TIME, null); + this.postBatchResumeToken = getPostBatchResumeTokenFromResponse(result); } if (firstQueryResult.getCursor() != null) { notNull("connectionSource", connectionSource); diff --git a/driver-core/src/test/functional/com/mongodb/OperationFunctionalSpecification.groovy b/driver-core/src/test/functional/com/mongodb/OperationFunctionalSpecification.groovy index e1fd5373b60..a79f81b0ff7 100644 --- a/driver-core/src/test/functional/com/mongodb/OperationFunctionalSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/OperationFunctionalSpecification.groovy @@ -157,21 +157,32 @@ class OperationFunctionalSpecification extends Specification { } def next(cursor, boolean async, int minimumCount) { + next(cursor, async, false, minimumCount) + } + + def next(cursor, boolean async, boolean callHasNextBeforeNext, int minimumCount) { List retVal = [] while (retVal.size() < minimumCount) { - retVal.addAll(next(cursor, async)) + retVal.addAll(doNext(cursor, async, callHasNextBeforeNext)) } retVal } def next(cursor, boolean async) { + doNext(cursor, async, false) + } + + def doNext(cursor, boolean async, boolean callHasNextBeforeNext) { if (async) { def futureResultCallback = new FutureResultCallback>() cursor.next(futureResultCallback) futureResultCallback.get(TIMEOUT, TimeUnit.SECONDS) } else { + if (callHasNextBeforeNext) { + cursor.hasNext() + } cursor.next() } } diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/ChangeStreamOperationProseTestSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/operation/ChangeStreamOperationProseTestSpecification.groovy index dac2d6633a9..778d3b4e46c 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/ChangeStreamOperationProseTestSpecification.groovy +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/ChangeStreamOperationProseTestSpecification.groovy @@ -96,7 +96,7 @@ class ChangeStreamOperationProseTestSpecification extends OperationFunctionalSpe setFailPoint(failPointDocument) then: - def result = next(cursor, async, 2) + def result = next(cursor, async, callHasNext, 2) then: result.size() == 2 @@ -107,7 +107,10 @@ class ChangeStreamOperationProseTestSpecification extends OperationFunctionalSpe waitForLastRelease(async ? getAsyncCluster() : getCluster()) where: - async << [true, false] + async | callHasNext + true | false + false | false + false | true } // diff --git a/driver-sync/src/test/functional/com/mongodb/client/ChangeStreamProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/ChangeStreamProseTest.java index c372ba556e5..dca1cdf019b 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/ChangeStreamProseTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/ChangeStreamProseTest.java @@ -370,8 +370,10 @@ public void testGetResumeTokenReturnsPostBatchResumeTokenAfterGetMore() // use reflection to access the postBatchResumeToken AggregateResponseBatchCursor batchCursor = getBatchCursor(cursor); - // check equality in the case where the batch has not been iterated at all - assertEquals(cursor.getResumeToken(), batchCursor.getPostBatchResumeToken()); + assertNotNull(batchCursor.getPostBatchResumeToken()); + + // resume token should be null before iteration + assertNull(cursor.getResumeToken()); cursor.next(); assertEquals(cursor.getResumeToken(), batchCursor.getPostBatchResumeToken()); diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTestValidator.java b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTestValidator.java index 3ecd6ef79e9..c01d18c8af0 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTestValidator.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTestValidator.java @@ -33,7 +33,6 @@ import java.util.Collection; import java.util.List; -import static org.junit.Assume.assumeTrue; import static util.JsonPoweredTestHelper.getTestDocument; import static util.JsonPoweredTestHelper.getTestFiles; @@ -51,8 +50,6 @@ public UnifiedTestValidator(final String fileDescription, final String testDescr @Before public void setUp() { - // TODO: remove after https://jira.mongodb.org/browse/JAVA-3871 is fixed - assumeTrue(!(fileDescription.equals("poc-change-streams") && testDescription.equals("Test consecutive resume"))); super.setUp(); }