Skip to content

Commit

Permalink
add simplify method for aggregate function
Browse files Browse the repository at this point in the history
  • Loading branch information
milenkovicm committed May 9, 2024
1 parent 96487ea commit de51434
Show file tree
Hide file tree
Showing 3 changed files with 329 additions and 1 deletion.
181 changes: 181 additions & 0 deletions datafusion-examples/examples/simplify_udaf_expression.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow_schema::{Field, Schema};
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion_common::tree_node::Transformed;
use datafusion_expr::simplify::SimplifyInfo;

use std::{any::Any, sync::Arc};

use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch};
use datafusion::error::Result;
use datafusion::{assert_batches_eq, prelude::*};
use datafusion_common::cast::as_float64_array;
use datafusion_expr::{
expr::{AggregateFunction, AggregateFunctionDefinition},
function::AccumulatorArgs,
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
};

/// This example shows how to use the `AggregateUDFImpl::simplify` API to replace a
/// user defined aggregate function call with a different expression, which is
/// produced by the `simplify` method.
///
/// `BetterAvgUdaf` is the example aggregate: its `simplify` implementation swaps
/// every call for the built-in `Avg` aggregate.
#[derive(Debug, Clone)]
struct BetterAvgUdaf {
/// Signature handed out by `AggregateUDFImpl::signature`.
signature: Signature,
}

impl BetterAvgUdaf {
    /// Create a new instance of the `BetterAvgUdaf` struct.
    ///
    /// The signature accepts exactly one `Float64` argument and is declared
    /// immutable.
    fn new() -> Self {
        let signature = Signature::exact(vec![DataType::Float64], Volatility::Immutable);
        Self { signature }
    }
}

impl AggregateUDFImpl for BetterAvgUdaf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "better_avg"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Float64)
    }

    // The execution-related methods below panic on purpose: this example
    // relies on `simplify` replacing every call to this function, so they
    // are not expected to run.
    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        unimplemented!("should not be invoked")
    }

    fn state_fields(
        &self,
        _name: &str,
        _value_type: DataType,
        _ordering_fields: Vec<arrow_schema::Field>,
    ) -> Result<Vec<arrow_schema::Field>> {
        unimplemented!("should not be invoked")
    }

    fn groups_accumulator_supported(&self) -> bool {
        true
    }

    fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
        unimplemented!("should not get here")
    }

    /// Overridden to return the expression that substitutes the call to this
    /// user defined function.
    ///
    /// To illustrate the API, the call is replaced by the built-in `Avg`
    /// aggregate. Yes, it is the same Avg; `BetterAvgUdaf` was just a
    /// marketing pitch :)
    fn simplify(
        &self,
        aggregate_function: AggregateFunction,
        _info: &dyn SimplifyInfo,
    ) -> Result<Transformed<Expr>> {
        let builtin_avg = AggregateFunctionDefinition::BuiltIn(
            datafusion_expr::aggregate_function::AggregateFunction::Avg,
        );

        // Only the function definition changes; arguments, DISTINCT flag,
        // filter, ordering and null treatment are carried over as they are.
        let replacement = AggregateFunction {
            func_def: builtin_avg,
            ..aggregate_function
        };

        Ok(Transformed::yes(Expr::AggregateFunction(replacement)))
    }
}

/// Creates a local session context with an in-memory table `t`.
///
/// The table has two non-nullable `Float32` columns, `a` and `b`, and its
/// data is split over two partitions.
fn create_context() -> Result<SessionContext> {
    use datafusion::datasource::MemTable;

    // Schema shared by both record batches and the table itself.
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Float32, false),
        Field::new("b", DataType::Float32, false),
    ]));

    // Builds one record batch from the values of columns `a` and `b`.
    let make_batch = |a: Vec<f32>, b: Vec<f32>| {
        RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Float32Array::from(a)),
                Arc::new(Float32Array::from(b)),
            ],
        )
    };

    // One batch per partition.
    let first_partition = make_batch(vec![2.0, 4.0, 8.0], vec![2.0, 2.0, 2.0])?;
    let second_partition = make_batch(vec![16.0], vec![2.0])?;

    let ctx = SessionContext::new();

    // Declare a table in memory. In spark API, this corresponds to createDataFrame(...).
    let partitions = vec![vec![first_partition], vec![second_partition]];
    let provider = MemTable::try_new(schema, partitions)?;
    ctx.register_table("t", Arc::new(provider))?;

    Ok(ctx)
}

/// Registers `better_avg` and evaluates it through both the SQL and the
/// DataFrame API.
///
/// `BetterAvgUdaf::accumulator` is unimplemented, so both queries only succeed
/// because the call is rewritten by `simplify` before execution.
#[tokio::main]
async fn main() -> Result<()> {
    let ctx = create_context()?;

    let better_avg = AggregateUDF::from(BetterAvgUdaf::new());
    ctx.register_udaf(better_avg.clone());

    // SQL API: every row has b = 2.0, so the GROUP BY yields a single group.
    let result = ctx
        .sql("SELECT better_avg(a) FROM t group by b")
        .await?
        .collect()
        .await?;

    // Each row must be padded to the full column width, exactly as the
    // pretty-printer renders it; otherwise the line-by-line comparison fails.
    let expected = [
        "+-----------------+",
        "| better_avg(t.a) |",
        "+-----------------+",
        "| 7.5             |",
        "+-----------------+",
    ];

    assert_batches_eq!(expected, &result);

    // DataFrame API: aggregate over the whole table, without grouping.
    let df = ctx.table("t").await?;
    let df = df.aggregate(vec![], vec![better_avg.call(vec![col("a")])])?;

    let results = df.collect().await?;
    let result = as_float64_array(results[0].column(0))?;

    assert!((result.value(0) - 7.5).abs() < f64::EPSILON);
    println!("The average of [2,4,8,16] is {}", result.value(0));

    Ok(())
}
42 changes: 42 additions & 0 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@

//! [`AggregateUDF`]: User Defined Aggregate Functions

use crate::expr::AggregateFunction;
use crate::function::AccumulatorArgs;
use crate::groups_accumulator::GroupsAccumulator;
use crate::simplify::SimplifyInfo;
use crate::utils::format_state_name;
use crate::{Accumulator, Expr};
use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
use arrow::datatypes::{DataType, Field};
use datafusion_common::tree_node::Transformed;
use datafusion_common::{not_impl_err, Result};
use std::any::Any;
use std::fmt::{self, Debug, Formatter};
Expand Down Expand Up @@ -195,6 +198,16 @@ impl AggregateUDF {
pub fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
// Pure delegation: whatever `self.inner` returns is passed through unchanged.
self.inner.create_groups_accumulator()
}
/// Applies the function specific rewrite to `aggregate_function`.
///
/// This is a thin wrapper that forwards both arguments to `self.inner` and
/// returns its result unchanged.
///
/// See [`AggregateUDFImpl::simplify`] for more details.
pub fn simplify(
&self,
aggregate_function: AggregateFunction,
info: &dyn SimplifyInfo,
) -> Result<Transformed<Expr>> {
self.inner.simplify(aggregate_function, info)
}
}

impl<F> From<F> for AggregateUDF
Expand Down Expand Up @@ -354,6 +367,35 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
fn aliases(&self) -> &[String] {
// Default implementation: an empty slice, i.e. no aliases.
&[]
}

/// Optionally apply per-UDAF simplification / rewrite rules.
///
/// This can be used to apply function specific simplification rules during
/// optimization (e.g. replacing a user defined aggregate with an equivalent
/// built-in aggregate). The default implementation does nothing: it hands
/// the call back unmodified and reports it as not transformed.
///
/// Note that DataFusion handles simplifying arguments and "constant
/// folding" (replacing a function call with constant arguments such as
/// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
/// optimizations manually for specific UDFs.
///
/// # Arguments
/// * `aggregate_function`: Aggregate function to be simplified
/// * `info`: Simplification information
///
/// # Returns
/// [`Transformed`] indicating the result of the simplification.
///
/// NOTE: if the function cannot be simplified, [`Expr::AggregateFunction`]
/// wrapping the unmodified [`AggregateFunction`] should be returned through
/// [`Transformed::no`], so that callers can tell that nothing changed.
fn simplify(
    &self,
    aggregate_function: AggregateFunction,
    _info: &dyn SimplifyInfo,
) -> Result<Transformed<Expr>> {
    // Nothing was rewritten here, so the expression must not be flagged as
    // transformed.
    Ok(Transformed::no(Expr::AggregateFunction(aggregate_function)))
}
}

/// AggregateUDF that adds an alias to the underlying function. It is better to
Expand Down
107 changes: 106 additions & 1 deletion datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
};
use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::{InList, InSubquery};
use datafusion_expr::expr::{AggregateFunctionDefinition, InList, InSubquery};
use datafusion_expr::simplify::ExprSimplifyResult;
use datafusion_expr::{
and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility,
Expand Down Expand Up @@ -1382,6 +1382,18 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
}
}

Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction {
func_def: AggregateFunctionDefinition::UDF(ref udaf),
..
}) => {
let udaf = udaf.clone();
if let Expr::AggregateFunction(aggregate_function) = expr {
udaf.simplify(aggregate_function, info)?
} else {
unreachable!("this branch should be unreachable")
}
}

//
// Rules for Between
//
Expand Down Expand Up @@ -3698,4 +3710,97 @@ mod tests {
assert_eq!(expr, expected);
assert_eq!(num_iter, 2);
}
/// Checks both outcomes of UDAF simplification: a mock that rewrites the call
/// and a mock that returns it unchanged.
#[test]
fn test_simplify_udaf() {
    // Wraps a mock implementation into an argument-less UDAF call expression.
    let udaf_call = |mock: SimplifyMockUdaf| {
        Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
            AggregateUDF::new_from_impl(mock).into(),
            vec![],
            false,
            None,
            None,
            None,
        ))
    };

    // The rewriting mock: the whole call is replaced by a column reference.
    let rewriting_call = udaf_call(SimplifyMockUdaf::new_with_simplify());
    assert_eq!(simplify(rewriting_call), col("result_column"));

    // The non-rewriting mock: the expression comes back unchanged.
    let untouched_call = udaf_call(SimplifyMockUdaf::new_without_simplify());
    let expected = untouched_call.clone();
    assert_eq!(simplify(untouched_call), expected);
}

/// A mock UDAF which defines `simplify`, to be used in tests
/// related to UDAF simplification
#[derive(Debug, Clone)]
struct SimplifyMockUdaf {
/// When `true`, `simplify` returns a replacement expression; when `false`,
/// it returns the aggregate call unchanged.
simplify: bool,
}

impl SimplifyMockUdaf {
    /// Builds a mock whose `simplify` method returns a new expression.
    fn new_with_simplify() -> Self {
        let simplify = true;
        Self { simplify }
    }

    /// Builds a mock whose `simplify` method reports no change.
    fn new_without_simplify() -> Self {
        let simplify = false;
        Self { simplify }
    }
}

impl AggregateUDFImpl for SimplifyMockUdaf {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        "mock_simplify"
    }

    // Apart from `simplify`, the remaining methods are left unimplemented
    // and panic if called.
    fn signature(&self) -> &Signature {
        unimplemented!()
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        unimplemented!("not needed for tests")
    }

    fn accumulator(
        &self,
        _acc_args: function::AccumulatorArgs,
    ) -> Result<Box<dyn Accumulator>> {
        unimplemented!("not needed for tests")
    }

    fn groups_accumulator_supported(&self) -> bool {
        unimplemented!("not needed for testing")
    }

    fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
        unimplemented!("not needed for testing")
    }

    /// Returns the column `result_column` as a transformed expression when the
    /// mock was built to simplify; otherwise returns the aggregate call as it
    /// was, marked as not transformed.
    fn simplify(
        &self,
        aggregate_function: datafusion_expr::expr::AggregateFunction,
        _info: &dyn SimplifyInfo,
    ) -> Result<Transformed<Expr>> {
        if !self.simplify {
            let unchanged = Expr::AggregateFunction(aggregate_function);
            return Ok(Transformed::no(unchanged));
        }

        Ok(Transformed::yes(col("result_column")))
    }
}
}

0 comments on commit de51434

Please sign in to comment.