diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AbstractAggregationExpression.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AbstractAggregationExpression.java index f5795bd164..f84e427bd4 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AbstractAggregationExpression.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AbstractAggregationExpression.java @@ -19,6 +19,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -76,11 +77,11 @@ private Object unpack(Object value, AggregationOperationContext context) { return context.getReference(field).toString(); } - if(value instanceof Fields fields) { + if (value instanceof Fields fields) { return fields.asList().stream().map(it -> unpack(it, context)).collect(Collectors.toList()); } - if(value instanceof Sort sort) { + if (value instanceof Sort sort) { Document sortDoc = new Document(); for (Order order : sort) { @@ -154,9 +155,40 @@ protected Map append(String key, Object value) { Assert.isInstanceOf(Map.class, this.value, "Value must be a type of Map"); - Map clone = new LinkedHashMap<>((java.util.Map) this.value); + return append((Map) this.value, key, value); + } + + private Map append(Map existing, String key, Object value) { + + Map clone = new LinkedHashMap<>(existing); clone.put(key, value); return clone; + } + + protected Map appendTo(String key, Object value) { + + Assert.isInstanceOf(Map.class, this.value, "Value must be a type of Map"); + + if (this.value instanceof Map map) { + + Map target = new HashMap<>(map); + if (!target.containsKey(key)) { + target.put(key, value); + return target; + } + target.computeIfPresent(key, (k, v) -> { + + if (v instanceof List list) { + List targetList = new ArrayList<>(list); + targetList.add(value); + return targetList; + } + return Arrays.asList(v, value); + }); + return target; + } + throw new IllegalStateException( + String.format("Cannot append value to %s type", ObjectUtils.nullSafeClassName(this.value))); } @@ -247,6 +279,10 @@ protected T get(Object key) { return (T) ((Map) this.value).get(key); } + protected boolean isArgumentMap() { + return this.value instanceof Map; + } + /** * Get the argument map. * diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperators.java index 51fa0459fd..042ff90326 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperators.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperators.java @@ -112,6 +112,17 @@ public Max max() { return usesFieldRef() ? Max.maxOf(fieldReference) : Max.maxOf(expression); } + /** + * Creates new {@link AggregationExpression} that takes the associated numeric value expression and returns the + * requested number of maximum values. + * + * @return new instance of {@link Max}. + * @since 4.0 + */ + public Max max(int numberOfResults) { + return max().limit(numberOfResults); + } + /** * Creates new {@link AggregationExpression} that takes the associated numeric value expression and returns the * minimum value. @@ -441,7 +452,7 @@ private Max(Object value) { @Override protected String getMongoMethod() { - return "$max"; + return contains("n") ? "$maxN" : "$max"; } /** @@ -453,7 +464,7 @@ protected String getMongoMethod() { public static Max maxOf(String fieldReference) { Assert.notNull(fieldReference, "FieldReference must not be null"); - return new Max(asFields(fieldReference)); + return new Max(Collections.singletonMap("input", Fields.field(fieldReference))); } /** @@ -465,7 +476,7 @@ public static Max maxOf(String fieldReference) { public static Max maxOf(AggregationExpression expression) { Assert.notNull(expression, "Expression must not be null"); - return new Max(Collections.singletonList(expression)); + return new Max(Collections.singletonMap("input", expression)); } /** @@ -478,7 +489,7 @@ public static Max maxOf(AggregationExpression expression) { public Max and(String fieldReference) { Assert.notNull(fieldReference, "FieldReference must not be null"); - return new Max(append(Fields.field(fieldReference))); + return new Max(appendTo("input", Fields.field(fieldReference))); } /** @@ -491,7 +502,26 @@ public Max and(String fieldReference) { public Max and(AggregationExpression expression) { Assert.notNull(expression, "Expression must not be null"); - return new Max(append(expression)); + return new Max(appendTo("input", expression)); + } + + /** + * Creates new {@link Max} that returns the given number of maxmimum values ({@literal $maxN}). + * NOTE: Cannot be used with more than one {@literal input} value. + * + * @param numberOfResults + * @return new instance of {@link Max}. + */ + public Max limit(int numberOfResults) { + return new Max(append("n", numberOfResults)); + } + + @Override + public Document toDocument(AggregationOperationContext context) { + if (get("n") == null) { + return toDocument(get("input"), context); + } + return super.toDocument(context); } @Override diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/spel/MethodReferenceNode.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/spel/MethodReferenceNode.java index 2f99aafb59..1f33d18cde 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/spel/MethodReferenceNode.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/spel/MethodReferenceNode.java @@ -224,6 +224,8 @@ public class MethodReferenceNode extends ExpressionNode { .mappingParametersTo("output", "sortBy")); map.put("topN", mapArgRef().forOperator("$topN") // .mappingParametersTo("n", "output", "sortBy")); + map.put("maxN", mapArgRef().forOperator("$maxN") // + .mappingParametersTo("n", "input")); // CONVERT OPERATORS map.put("convert", mapArgRef().forOperator("$convert") // diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperatorsUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperatorsUnitTests.java index 32a772950b..871f60db48 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperatorsUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperatorsUnitTests.java @@ -80,6 +80,20 @@ void rendersExpMovingAvgWithAlpha() { .isEqualTo(Document.parse("{ $expMovingAvg: { input: \"$price\", alpha: 0.75 } }")); } + @Test // GH-4139 + void rendersMax() { + + assertThat(valueOf("price").max().toDocument(Aggregation.DEFAULT_CONTEXT)) + .isEqualTo(Document.parse("{ $max: \"$price\" }")); + } + + @Test // GH-4139 + void rendersMaxN() { + + assertThat(valueOf("price").max(3).toDocument(Aggregation.DEFAULT_CONTEXT)) + .isEqualTo(Document.parse("{ $maxN: { n: 3, input : \"$price\" } }")); + } + static class Jedi { String name; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/SpelExpressionTransformerUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/SpelExpressionTransformerUnitTests.java index 1c2c4b5725..5474616977 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/SpelExpressionTransformerUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/SpelExpressionTransformerUnitTests.java @@ -1204,6 +1204,11 @@ void shouldRenderLastN() { assertThat(transform("lastN(3, \"$score\")")).isEqualTo("{ $lastN : { n : 3, input : \"$score\" }}"); } + @Test // GH-4139 + void shouldRenderMaxN() { + assertThat(transform("maxN(3, \"$score\")")).isEqualTo("{ $maxN : { n : 3, input : \"$score\" }}"); + } + private Document transform(String expression, Object... params) { return (Document) transformer.transform(expression, Aggregation.DEFAULT_CONTEXT, params); }