Skip to content

Commit

Permalink
ESQL: Add aggregates node level reduction (#107876)
Browse files Browse the repository at this point in the history
* Add an intermediate reduction level for aggregations and a computed estimatedRowSize value
  • Loading branch information
astefan committed May 9, 2024
1 parent 0a8c6d2 commit 1b7cad1
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 32 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/107876.yaml
@@ -0,0 +1,5 @@
pr: 107876
summary: "ESQL: Add aggregates node level reduction"
area: ES|QL
type: enhancement
issues: []
Expand Up @@ -81,6 +81,7 @@ public class EsqlActionTaskIT extends AbstractPausableIntegTestCase {
@Before
public void setup() {
assumeTrue("requires query pragmas", canUseQueryPragmas());
nodeLevelReduction = randomBoolean();
READ_DESCRIPTION = """
\\_LuceneSourceOperator[dataPartitioning = SHARD, maxPageSize = pageSize(), limit = 2147483647]
\\_ValuesSourceReaderOperator[fields = [pause_me]]
Expand All @@ -92,10 +93,10 @@ public void setup() {
\\_ProjectOperator[projection = [0]]
\\_LimitOperator[limit = 1000]
\\_OutputOperator[columns = [sum(pause_me)]]""";
REDUCE_DESCRIPTION = """
\\_ExchangeSourceOperator[]
\\_ExchangeSinkOperator""";
nodeLevelReduction = randomBoolean();
REDUCE_DESCRIPTION = "\\_ExchangeSourceOperator[]\n"
+ (nodeLevelReduction ? "\\_AggregationOperator[mode = INTERMEDIATE, aggs = sum of longs]\n" : "")
+ "\\_ExchangeSinkOperator";

}

public void testTaskContents() throws Exception {
Expand Down Expand Up @@ -480,6 +481,37 @@ public void testTaskContentsForLimitQuery() throws Exception {
}
}

/**
 * Runs {@code from test | stats max(foo) by pause_me} after setting the expected operator
 * descriptions for the read, merge and reduce stages of a grouping aggregation, then verifies
 * the single row the query returns.
 * <p>
 * NOTE(review): the *_DESCRIPTION fields are presumably compared against the live task
 * descriptions inside getTasksStarting()/getTasksRunning(); those helpers are not visible in
 * this hunk, so confirm before relying on that.
 */
public void testTaskContentsForGroupingStatsQuery() throws Exception {
// The literal "pageSize()" inside the template is a placeholder; it is substituted with the actual page size.
READ_DESCRIPTION = """
\\_LuceneSourceOperator[dataPartitioning = SHARD, maxPageSize = pageSize(), limit = 2147483647]
\\_ValuesSourceReaderOperator[fields = [foo]]
\\_OrdinalsGroupingOperator(aggs = max of longs)
\\_ExchangeSinkOperator""".replace("pageSize()", Integer.toString(pageSize()));
MERGE_DESCRIPTION = """
\\_ExchangeSourceOperator[]
\\_HashAggregationOperator[mode = <not-needed>, aggs = max of longs]
\\_ProjectOperator[projection = [1, 0]]
\\_LimitOperator[limit = 1000]
\\_OutputOperator[columns = [max(foo), pause_me]]""";
// The reduce stage only contains the hash aggregation operator when nodeLevelReduction is true;
// otherwise it is a plain exchange source -> exchange sink pipeline.
REDUCE_DESCRIPTION = "\\_ExchangeSourceOperator[]\n"
+ (nodeLevelReduction ? "\\_HashAggregationOperator[mode = <not-needed>, aggs = max of longs]\n" : "")
+ "\\_ExchangeSinkOperator";

ActionFuture<EsqlQueryResponse> response = startEsql("from test | stats max(foo) by pause_me");
try {
getTasksStarting();
// Release one page worth of permits between the "starting" and "running" task checks.
scriptPermits.release(pageSize());
getTasksRunning();
} finally {
// Always release numberOfDocs() permits before waiting on the response, even if a check above failed
// (presumably so the paused query can run to completion - confirm against the pausable script).
scriptPermits.release(numberOfDocs());
try (EsqlQueryResponse esqlResponse = response.get()) {
// Flatten the rows into one iterator of cells: first max(foo), then the pause_me group key.
var it = Iterators.flatMap(esqlResponse.values(), i -> i);
assertThat(it.next(), equalTo(numberOfDocs() - 1L)); // max of numberOfDocs() generated int values
assertThat(it.next(), equalTo(1L)); // pause_me always emits 1
}
}
}

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return CollectionUtils.appendToCopy(super.nodePlugins(), MockTransportService.TestPlugin.class);
Expand Down
Expand Up @@ -68,6 +68,10 @@ public List<? extends NamedExpression> aggregates() {
return aggregates;
}

/**
 * Creates a new {@link AggregateExec} that differs from this one only in its mode.
 * The source, child, groupings, aggregates and estimated row size are carried over as they are.
 *
 * @param newMode the mode to give the new node
 * @return a new node built from this node's properties and {@code newMode}
 */
public AggregateExec withMode(Mode newMode) {
    var nodeSource = source();
    var input = child();
    var copy = new AggregateExec(nodeSource, input, groupings, aggregates, newMode, estimatedRowSize);
    return copy;
}

/**
* Estimate of the number of bytes that'll be loaded per position before
* the stream of pages is consumed.
Expand Down
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.esql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.ExchangeSourceExec;
import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.LocalExecutionPlannerContext;
import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.PhysicalOperation;
import org.elasticsearch.xpack.ql.InvalidArgumentException;
Expand Down Expand Up @@ -54,6 +55,20 @@ public final PhysicalOperation groupingPhysicalOperation(
var aggregates = aggregateExec.aggregates();

var sourceLayout = source.layout;
AggregatorMode aggregatorMode;

if (mode == AggregateExec.Mode.FINAL) {
aggregatorMode = AggregatorMode.FINAL;
} else if (mode == AggregateExec.Mode.PARTIAL) {
if (aggregateExec.child() instanceof ExchangeSourceExec) {// the reducer step at data node (local) level
aggregatorMode = AggregatorMode.INTERMEDIATE;
} else {
aggregatorMode = AggregatorMode.INITIAL;
}
} else {
assert false : "Invalid aggregator mode [" + mode + "]";
aggregatorMode = AggregatorMode.SINGLE;
}

if (aggregateExec.groupings().isEmpty()) {
// not grouping
Expand All @@ -65,20 +80,18 @@ public final PhysicalOperation groupingPhysicalOperation(
} else {
layout.append(aggregateMapper.mapNonGrouping(aggregates));
}

// create the agg factories
aggregatesToFactory(
aggregates,
mode,
aggregatorMode,
sourceLayout,
false, // non-grouping
s -> aggregatorFactories.add(s.supplier.aggregatorFactory(s.mode))
);

if (aggregatorFactories.isEmpty() == false) {
operatorFactory = new AggregationOperator.AggregationOperatorFactory(
aggregatorFactories,
mode == AggregateExec.Mode.FINAL ? AggregatorMode.FINAL : AggregatorMode.INITIAL
);
operatorFactory = new AggregationOperator.AggregationOperatorFactory(aggregatorFactories, aggregatorMode);
}
} else {
// grouping
Expand Down Expand Up @@ -136,7 +149,7 @@ else if (mode == AggregateExec.Mode.PARTIAL) {
// create the agg factories
aggregatesToFactory(
aggregates,
mode,
aggregatorMode,
sourceLayout,
true, // grouping
s -> aggregatorFactories.add(s.supplier.groupingAggregatorFactory(s.mode))
Expand Down Expand Up @@ -219,7 +232,7 @@ private record AggFunctionSupplierContext(AggregatorFunctionSupplier supplier, A

private void aggregatesToFactory(
List<? extends NamedExpression> aggregates,
AggregateExec.Mode mode,
AggregatorMode mode,
Layout layout,
boolean grouping,
Consumer<AggFunctionSupplierContext> consumer
Expand All @@ -228,11 +241,9 @@ private void aggregatesToFactory(
if (ne instanceof Alias alias) {
var child = alias.child();
if (child instanceof AggregateFunction aggregateFunction) {
AggregatorMode aggMode = null;
List<? extends NamedExpression> sourceAttr;

if (mode == AggregateExec.Mode.PARTIAL) {
aggMode = AggregatorMode.INITIAL;
if (mode == AggregatorMode.INITIAL) {
// TODO: this needs to be made more reliable - use casting to blow up when dealing with expressions (e+1)
Expression field = aggregateFunction.field();
// Only count can now support literals - all the other aggs should be optimized away
Expand All @@ -257,9 +268,7 @@ private void aggregatesToFactory(
}
sourceAttr = List.of(attr);
}

} else if (mode == AggregateExec.Mode.FINAL) {
aggMode = AggregatorMode.FINAL;
} else if (mode == AggregatorMode.FINAL || mode == AggregatorMode.INTERMEDIATE) {
if (grouping) {
sourceAttr = aggregateMapper.mapGrouping(aggregateFunction);
} else {
Expand All @@ -279,7 +288,7 @@ private void aggregatesToFactory(
assert inputChannels.size() > 0 && inputChannels.stream().allMatch(i -> i >= 0);
}
if (aggregateFunction instanceof ToAggregator agg) {
consumer.accept(new AggFunctionSupplierContext(agg.supplier(inputChannels), aggMode));
consumer.accept(new AggFunctionSupplierContext(agg.supplier(inputChannels), mode));
} else {
throw new EsqlIllegalArgumentException("aggregate functions must extend ToAggregator");
}
Expand Down
Expand Up @@ -54,7 +54,7 @@
public class Mapper {

private final FunctionRegistry functionRegistry;
private final boolean localMode;
private final boolean localMode; // non-coordinator (data node) mode

public Mapper(FunctionRegistry functionRegistry) {
this.functionRegistry = functionRegistry;
Expand Down
Expand Up @@ -24,6 +24,7 @@
import org.elasticsearch.xpack.esql.plan.logical.Aggregate;
import org.elasticsearch.xpack.esql.plan.logical.EsRelation;
import org.elasticsearch.xpack.esql.plan.logical.TopN;
import org.elasticsearch.xpack.esql.plan.physical.AggregateExec;
import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec;
import org.elasticsearch.xpack.esql.plan.physical.EsSourceExec;
import org.elasticsearch.xpack.esql.plan.physical.EstimatesRowSize;
Expand Down Expand Up @@ -87,23 +88,19 @@ public static PhysicalPlan dataNodeReductionPlan(LogicalPlan plan, PhysicalPlan

if (pipelineBreakers.isEmpty() == false) {
UnaryPlan pipelineBreaker = (UnaryPlan) pipelineBreakers.get(0);
if (pipelineBreaker instanceof TopN topN) {
return new TopNExec(topN.source(), unused, topN.order(), topN.limit(), 2000);
if (pipelineBreaker instanceof TopN) {
Mapper mapper = new Mapper(true);
var physicalPlan = EstimatesRowSize.estimateRowSize(0, mapper.map(plan));
return physicalPlan.collectFirstChildren(TopNExec.class::isInstance).get(0);
} else if (pipelineBreaker instanceof Limit limit) {
return new LimitExec(limit.source(), unused, limit.limit());
} else if (pipelineBreaker instanceof OrderBy order) {
return new OrderExec(order.source(), unused, order.order());
} else if (pipelineBreaker instanceof Aggregate aggregate) {
// TODO handle this as a special PARTIAL step (intermediate)
/*return new AggregateExec(
aggregate.source(),
unused,
aggregate.groupings(),
aggregate.aggregates(),
AggregateExec.Mode.PARTIAL,
0
);*/
return null;
} else if (pipelineBreaker instanceof Aggregate) {
Mapper mapper = new Mapper(true);
var physicalPlan = EstimatesRowSize.estimateRowSize(0, mapper.map(plan));
var aggregate = (AggregateExec) physicalPlan.collectFirstChildren(AggregateExec.class::isInstance).get(0);
return aggregate.withMode(AggregateExec.Mode.PARTIAL);
} else {
throw new EsqlIllegalArgumentException("unsupported unary physical plan node [" + pipelineBreaker.nodeName() + "]");
}
Expand Down

0 comments on commit 1b7cad1

Please sign in to comment.