diff --git a/driver-async/src/main/com/mongodb/async/client/AggregateIterableImpl.java b/driver-async/src/main/com/mongodb/async/client/AggregateIterableImpl.java index 59f140acee3..83c446337fc 100644 --- a/driver-async/src/main/com/mongodb/async/client/AggregateIterableImpl.java +++ b/driver-async/src/main/com/mongodb/async/client/AggregateIterableImpl.java @@ -186,13 +186,19 @@ private MongoNamespace getOutNamespace() { if (lastStageDocument.containsKey("$out")) { return new MongoNamespace(namespace.getDatabaseName(), lastStageDocument.getString("$out").getValue()); } else if (lastStageDocument.containsKey("$merge")) { - BsonDocument mergeDocument = lastStageDocument.getDocument("$merge"); - if (mergeDocument.isDocument("into")) { - BsonDocument intoDocument = mergeDocument.getDocument("into"); - return new MongoNamespace(intoDocument.getString("db", new BsonString(namespace.getDatabaseName())).getValue(), - intoDocument.getString("coll").getValue()); - } else if (mergeDocument.isString("into")) { - return new MongoNamespace(namespace.getDatabaseName(), mergeDocument.getString("into").getValue()); + if (lastStageDocument.isString("$merge")) { + return new MongoNamespace(namespace.getDatabaseName(), lastStageDocument.getString("$merge").getValue()); + } else if (lastStageDocument.isDocument("$merge")) { + BsonDocument mergeDocument = lastStageDocument.getDocument("$merge"); + if (mergeDocument.isDocument("into")) { + BsonDocument intoDocument = mergeDocument.getDocument("into"); + return new MongoNamespace(intoDocument.getString("db", new BsonString(namespace.getDatabaseName())).getValue(), + intoDocument.getString("coll").getValue()); + } else if (mergeDocument.isString("into")) { + return new MongoNamespace(namespace.getDatabaseName(), mergeDocument.getString("into").getValue()); + } + } else { + throw new IllegalStateException("Cannot return a cursor when the value for $merge stage is not a string or a document"); } } diff --git a/driver-async/src/test/unit/com/mongodb/async/client/AggregateIterableSpecification.groovy b/driver-async/src/test/unit/com/mongodb/async/client/AggregateIterableSpecification.groovy index 5e51916a453..deb99e1a3a2 100644 --- a/driver-async/src/test/unit/com/mongodb/async/client/AggregateIterableSpecification.groovy +++ b/driver-async/src/test/unit/com/mongodb/async/client/AggregateIterableSpecification.groovy @@ -209,7 +209,7 @@ class AggregateIterableSpecification extends Specification { .comment('this is a comment')) } - def 'should build the expected AggregateToCollectionOperation for $merge'() { + def 'should build the expected AggregateToCollectionOperation for $merge document'() { given: def cursor = Stub(AsyncBatchCursor) { next(_) >> { @@ -358,6 +358,38 @@ class AggregateIterableSpecification extends Specification { .comment('this is a comment')) } + def 'should build the expected AggregateToCollectionOperation for $merge string'() { + given: + def cursor = Stub(AsyncBatchCursor) { + next(_) >> { + it[0].onResult(null, null) + } + } + def executor = new TestOperationExecutor([cursor, cursor, cursor, cursor, cursor, cursor, cursor]); + def collectionName = 'collectionName' + def collectionNamespace = new MongoNamespace(namespace.getDatabaseName(), collectionName) + def pipeline = [new Document('$match', 1), new Document('$merge', new Document('into', collectionName))] + + when: 'aggregation includes $merge' + new AggregateIterableImpl(null, namespace, Document, Document, codecRegistry, readPreference, readConcern, writeConcern, executor, + pipeline, AggregationLevel.COLLECTION, true) + .into([]) { result, t -> } + + def operation = executor.getReadOperation() as WriteOperationThenCursorReadOperation + + then: + expect operation.getAggregateToCollectionOperation(), isTheSameAs(new AggregateToCollectionOperation(namespace, + [new BsonDocument('$match', new BsonInt32(1)), + new BsonDocument('$merge', new BsonDocument('into', new BsonString(collectionName)))], + readConcern, writeConcern)) + + when: + operation = operation.getReadOperation() as FindOperation + + then: + operation.getNamespace() == collectionNamespace + } + def 'should handle exceptions correctly'() { given: def codecRegistry = fromProviders([new ValueCodecProvider(), new BsonValueCodecProvider()]) diff --git a/driver-sync/src/main/com/mongodb/client/internal/AggregateIterableImpl.java b/driver-sync/src/main/com/mongodb/client/internal/AggregateIterableImpl.java index 56086cda801..629128eded8 100644 --- a/driver-sync/src/main/com/mongodb/client/internal/AggregateIterableImpl.java +++ b/driver-sync/src/main/com/mongodb/client/internal/AggregateIterableImpl.java @@ -187,13 +187,19 @@ private MongoNamespace getOutNamespace() { if (lastStageDocument.containsKey("$out")) { return new MongoNamespace(namespace.getDatabaseName(), lastStageDocument.getString("$out").getValue()); } else if (lastStageDocument.containsKey("$merge")) { - BsonDocument mergeDocument = lastStageDocument.getDocument("$merge"); - if (mergeDocument.isDocument("into")) { - BsonDocument intoDocument = mergeDocument.getDocument("into"); - return new MongoNamespace(intoDocument.getString("db", new BsonString(namespace.getDatabaseName())).getValue(), - intoDocument.getString("coll").getValue()); - } else if (mergeDocument.isString("into")) { - return new MongoNamespace(namespace.getDatabaseName(), mergeDocument.getString("into").getValue()); + if (lastStageDocument.isString("$merge")) { + return new MongoNamespace(namespace.getDatabaseName(), lastStageDocument.getString("$merge").getValue()); + } else if (lastStageDocument.isDocument("$merge")) { + BsonDocument mergeDocument = lastStageDocument.getDocument("$merge"); + if (mergeDocument.isDocument("into")) { + BsonDocument intoDocument = mergeDocument.getDocument("into"); + return new MongoNamespace(intoDocument.getString("db", new BsonString(namespace.getDatabaseName())).getValue(), + intoDocument.getString("coll").getValue()); + } else if (mergeDocument.isString("into")) { + return new MongoNamespace(namespace.getDatabaseName(), mergeDocument.getString("into").getValue()); + } + } else { + throw new IllegalStateException("Cannot return a cursor when the value for $merge stage is not a string or a document"); } } diff --git a/driver-sync/src/test/unit/com/mongodb/client/internal/AggregateIterableSpecification.groovy b/driver-sync/src/test/unit/com/mongodb/client/internal/AggregateIterableSpecification.groovy index d9d9adaa289..ca3030bfa7d 100644 --- a/driver-sync/src/test/unit/com/mongodb/client/internal/AggregateIterableSpecification.groovy +++ b/driver-sync/src/test/unit/com/mongodb/client/internal/AggregateIterableSpecification.groovy @@ -195,7 +195,7 @@ class AggregateIterableSpecification extends Specification { .comment('this is a comment')) } - def 'should build the expected AggregateToCollectionOperation for $merge'() { + def 'should build the expected AggregateToCollectionOperation for $merge document'() { given: def executor = new TestOperationExecutor([null, null, null, null, null, null, null]) def collectionName = 'collectionName' @@ -335,6 +335,30 @@ class AggregateIterableSpecification extends Specification { .comment('this is a comment')) } + def 'should build the expected AggregateToCollectionOperation for $merge string'() { + given: + def executor = new TestOperationExecutor([null, null, null, null, null, null, null]) + def collectionName = 'collectionName' + def collectionNamespace = new MongoNamespace(namespace.getDatabaseName(), collectionName) + def pipeline = [new BsonDocument('$match', new BsonDocument()), new BsonDocument('$merge', new BsonString(collectionName))] + + when: + new AggregateIterableImpl(null, namespace, Document, Document, codecRegistry, readPreference, readConcern, writeConcern, executor, + pipeline, AggregationLevel.COLLECTION, false) + .iterator() + + def operation = executor.getWriteOperation() as AggregateToCollectionOperation + + then: + expect operation, isTheSameAs(new AggregateToCollectionOperation(namespace, pipeline, readConcern, writeConcern, + AggregationLevel.COLLECTION)) + + when: + operation = executor.getReadOperation() as FindOperation + + then: + operation.getNamespace() == collectionNamespace + } def 'should use ClientSession for AggregationOperation'() { given: