Skip to content

Commit

Permalink
rename FuncArgs back to AggArgs
Browse files Browse the repository at this point in the history
Signed-off-by: Richard Chien <stdrc@outlook.com>
  • Loading branch information
stdrc committed Apr 18, 2023
1 parent bf69e62 commit be50b16
Show file tree
Hide file tree
Showing 13 changed files with 66 additions and 66 deletions.
8 changes: 4 additions & 4 deletions src/expr/src/function/aggregate/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use risingwave_common::types::DataType;

/// An aggregation function may accept 0, 1 or 2 arguments.
#[derive(Clone, Debug)]
pub enum FuncArgs {
pub enum AggArgs {
/// `None` is used for function calls that accept 0 argument, e.g. `count(*)`.
None,
/// `Unary` is used for function calls that accept 1 argument, e.g. `sum(x)`.
Expand All @@ -27,10 +27,10 @@ pub enum FuncArgs {
Binary([DataType; 2], [usize; 2]),
}

impl FuncArgs {
impl AggArgs {
/// return the types of arguments.
pub fn arg_types(&self) -> &[DataType] {
use FuncArgs::*;
use AggArgs::*;
match self {
None => Default::default(),
Unary(typ, _) => slice::from_ref(typ),
Expand All @@ -40,7 +40,7 @@ impl FuncArgs {

/// return the indices of the arguments in [`risingwave_common::array::StreamChunk`].
pub fn val_indices(&self) -> &[usize] {
use FuncArgs::*;
use AggArgs::*;
match self {
None => Default::default(),
Unary(_, val_idx) => slice::from_ref(val_idx),
Expand Down
10 changes: 5 additions & 5 deletions src/expr/src/function/aggregate/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use risingwave_pb::expr::PbAggCall;

use super::AggKind;
use crate::expr::{build_from_prost, ExpressionRef};
use crate::function::aggregate::FuncArgs;
use crate::function::aggregate::AggArgs;
use crate::Result;

/// Represents an aggregation function.
Expand All @@ -30,7 +30,7 @@ pub struct AggCall {
/// Aggregation kind for constructing agg state.
pub kind: AggKind,
/// Arguments of aggregation function input.
pub args: FuncArgs,
pub args: AggArgs,
/// The return type of aggregation function.
pub return_type: DataType,

Expand All @@ -48,11 +48,11 @@ impl AggCall {
pub fn from_protobuf(agg_call: &PbAggCall) -> Result<Self> {
let agg_kind = AggKind::from_protobuf(agg_call.get_type()?)?;
let args = match &agg_call.get_args()[..] {
[] => FuncArgs::None,
[] => AggArgs::None,
[arg] if agg_kind != AggKind::StringAgg => {
FuncArgs::Unary(DataType::from(arg.get_type()?), arg.get_index() as usize)
AggArgs::Unary(DataType::from(arg.get_type()?), arg.get_index() as usize)
}
[agg_arg, extra_arg] if agg_kind == AggKind::StringAgg => FuncArgs::Binary(
[agg_arg, extra_arg] if agg_kind == AggKind::StringAgg => AggArgs::Binary(
[
DataType::from(agg_arg.get_type()?),
DataType::from(extra_arg.get_type()?),
Expand Down
4 changes: 2 additions & 2 deletions src/expr/src/function/window/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
use risingwave_common::types::DataType;

use super::WindowFuncKind;
use crate::function::aggregate::FuncArgs;
use crate::function::aggregate::AggArgs;

#[derive(Clone)]
pub enum Frame {
Expand All @@ -29,7 +29,7 @@ pub enum Frame {
#[derive(Clone)]
pub struct WindowFuncCall {
pub kind: WindowFuncKind,
pub args: FuncArgs,
pub args: AggArgs,
pub return_type: DataType,
pub frame: Frame,
}
14 changes: 7 additions & 7 deletions src/expr/src/vector_op/agg/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use risingwave_common::array::*;
use risingwave_common::bail;
use risingwave_common::types::*;

use crate::function::aggregate::{AggCall, AggKind, FuncArgs};
use crate::function::aggregate::{AggArgs, AggCall, AggKind};
use crate::vector_op::agg::approx_count_distinct::ApproxCountDistinct;
use crate::vector_op::agg::array_agg::create_array_agg_state;
use crate::vector_op::agg::count_star::CountStar;
Expand Down Expand Up @@ -66,21 +66,21 @@ impl AggStateFactory {
// NOTE: The function signature is checked by `AggCall::infer_return_type` in the frontend.

let initial_agg_state: BoxedAggState = match (agg_call.kind, agg_call.args) {
(AggKind::Count, FuncArgs::None) => {
(AggKind::Count, AggArgs::None) => {
Box::new(CountStar::new(agg_call.return_type.clone()))
}
(AggKind::ApproxCountDistinct, FuncArgs::Unary(_, arg_idx)) => Box::new(
(AggKind::ApproxCountDistinct, AggArgs::Unary(_, arg_idx)) => Box::new(
ApproxCountDistinct::new(agg_call.return_type.clone(), arg_idx),
),
(
AggKind::StringAgg,
FuncArgs::Binary([value_type, delim_type], [value_idx, delim_idx]),
AggArgs::Binary([value_type, delim_type], [value_idx, delim_idx]),
) => {
assert_eq!(value_type, DataType::Varchar);
assert_eq!(delim_type, DataType::Varchar);
create_string_agg_state(value_idx, delim_idx, agg_call.column_orders.clone())
}
(AggKind::Sum, FuncArgs::Unary(arg_type, arg_idx))
(AggKind::Sum, AggArgs::Unary(arg_type, arg_idx))
if matches!(arg_type, DataType::Int256) =>
{
// Special handling of the `sum` function for `Int256`, when the
Expand All @@ -92,12 +92,12 @@ impl AggStateFactory {
// when refactoring the code related to aggregation in the future.
Box::new(Int256Sum::new(arg_idx, agg_call.distinct))
}
(AggKind::ArrayAgg, FuncArgs::Unary(_, arg_idx)) => create_array_agg_state(
(AggKind::ArrayAgg, AggArgs::Unary(_, arg_idx)) => create_array_agg_state(
agg_call.return_type.clone(),
arg_idx,
agg_call.column_orders.clone(),
),
(agg_kind, FuncArgs::Unary(arg_type, arg_idx)) => {
(agg_kind, AggArgs::Unary(arg_type, arg_idx)) => {
// other unary agg call
create_agg_state_unary(
arg_type,
Expand Down
16 changes: 8 additions & 8 deletions src/stream/benches/hash_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use futures::StreamExt;
use risingwave_common::catalog::{Field, Schema};
use risingwave_common::types::DataType;
use risingwave_expr::expr::*;
use risingwave_expr::function::aggregate::{AggCall, AggKind, FuncArgs};
use risingwave_expr::function::aggregate::{AggArgs, AggCall, AggKind};
use risingwave_storage::memory::MemoryStateStore;
use risingwave_storage::StateStore;
use risingwave_stream::executor::test_utils::agg_executor::new_boxed_hash_agg_executor;
Expand Down Expand Up @@ -84,31 +84,31 @@ fn setup_bench_hash_agg<S: StateStore>(store: S) -> BoxedExecutor {
let agg_calls = vec![
AggCall {
kind: AggKind::Count,
args: FuncArgs::None,
args: AggArgs::None,
return_type: DataType::Int64,
column_orders: vec![],
filter: None,
distinct: false,
},
AggCall {
kind: AggKind::Count,
args: FuncArgs::None,
args: AggArgs::None,
return_type: DataType::Int64,
column_orders: vec![],
filter: Some(build_from_pretty("(less_than:boolean $2:int8 10000:int8)").into()),
distinct: false,
},
AggCall {
kind: AggKind::Count,
args: FuncArgs::None,
args: AggArgs::None,
return_type: DataType::Int64,
column_orders: vec![],
filter: Some(build_from_pretty("(and:boolean (greater_than_or_equal:boolean $2:int8 10000:int8) (less_than:boolean $2:int8 100000:int8))").into()),
distinct: false,
},
AggCall {
kind: AggKind::Count,
args: FuncArgs::None,
args: AggArgs::None,
return_type: DataType::Int64,
column_orders: vec![],
filter: Some(build_from_pretty("(greater_than_or_equal:boolean $2:int8 100000:int8)").into()),
Expand Down Expand Up @@ -144,7 +144,7 @@ fn setup_bench_hash_agg<S: StateStore>(store: S) -> BoxedExecutor {
// avg (sum)
AggCall {
kind: AggKind::Sum,
args: FuncArgs::Unary(DataType::Int64, 2),
args: AggArgs::Unary(DataType::Int64, 2),
return_type: DataType::Int64,
column_orders: vec![],
filter: None,
Expand All @@ -153,15 +153,15 @@ fn setup_bench_hash_agg<S: StateStore>(store: S) -> BoxedExecutor {
// avg (count)
AggCall {
kind: AggKind::Count,
args: FuncArgs::Unary(DataType::Int64, 2),
args: AggArgs::Unary(DataType::Int64, 2),
return_type: DataType::Int64,
column_orders: vec![],
filter: None,
distinct: false,
},
AggCall {
kind: AggKind::Sum,
args: FuncArgs::Unary(DataType::Int64, 2),
args: AggArgs::Unary(DataType::Int64, 2),
return_type: DataType::Int64,
column_orders: vec![],
filter: None,
Expand Down
4 changes: 2 additions & 2 deletions src/stream/src/executor/aggregation/distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,15 @@ mod tests {
use risingwave_common::types::DataType;
use risingwave_common::util::epoch::EpochPair;
use risingwave_common::util::sort_util::OrderType;
use risingwave_expr::function::aggregate::{AggKind, FuncArgs};
use risingwave_expr::function::aggregate::{AggArgs, AggKind};
use risingwave_storage::memory::MemoryStateStore;

use super::*;

fn count_agg_call(kind: AggKind, col_idx: usize, distinct: bool) -> AggCall {
AggCall {
kind,
args: FuncArgs::Unary(DataType::Int64, col_idx),
args: AggArgs::Unary(DataType::Int64, col_idx),
return_type: DataType::Int64,
distinct,

Expand Down
8 changes: 4 additions & 4 deletions src/stream/src/executor/aggregation/minput.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ mod tests {
use risingwave_common::util::epoch::EpochPair;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
use risingwave_expr::function::aggregate::{AggCall, AggKind, FuncArgs};
use risingwave_expr::function::aggregate::{AggArgs, AggCall, AggKind};
use risingwave_storage::memory::MemoryStateStore;
use risingwave_storage::StateStore;

Expand Down Expand Up @@ -287,7 +287,7 @@ mod tests {
fn create_extreme_agg_call(kind: AggKind, arg_type: DataType, arg_idx: usize) -> AggCall {
AggCall {
kind,
args: FuncArgs::Unary(arg_type.clone(), arg_idx),
args: AggArgs::Unary(arg_type.clone(), arg_idx),
return_type: arg_type,
column_orders: vec![],
filter: None,
Expand Down Expand Up @@ -980,7 +980,7 @@ mod tests {

let agg_call = AggCall {
kind: AggKind::StringAgg,
args: FuncArgs::Binary([DataType::Varchar, DataType::Varchar], [0, 1]),
args: AggArgs::Binary([DataType::Varchar, DataType::Varchar], [0, 1]),
return_type: DataType::Varchar,
column_orders: vec![
ColumnOrder::new(2, OrderType::ascending()), // b ASC
Expand Down Expand Up @@ -1081,7 +1081,7 @@ mod tests {

let agg_call = AggCall {
kind: AggKind::ArrayAgg,
args: FuncArgs::Unary(DataType::Int32, 1), // array_agg(b)
args: AggArgs::Unary(DataType::Int32, 1), // array_agg(b)
return_type: DataType::Int32,
column_orders: vec![
ColumnOrder::new(2, OrderType::ascending()), // c ASC
Expand Down
6 changes: 3 additions & 3 deletions src/stream/src/executor/aggregation/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,14 @@ impl ValueState {
mod tests {
use risingwave_common::array::{I64Array, Op};
use risingwave_common::types::{DataType, ScalarImpl};
use risingwave_expr::function::aggregate::{AggKind, FuncArgs};
use risingwave_expr::function::aggregate::{AggArgs, AggKind};

use super::*;

fn create_test_count_agg() -> AggCall {
AggCall {
kind: AggKind::Count,
args: FuncArgs::Unary(DataType::Int64, 0),
args: AggArgs::Unary(DataType::Int64, 0),
return_type: DataType::Int64,
column_orders: vec![],
filter: None,
Expand Down Expand Up @@ -130,7 +130,7 @@ mod tests {
fn create_test_max_agg_append_only() -> AggCall {
AggCall {
kind: AggKind::Max,
args: FuncArgs::Unary(DataType::Int64, 0),
args: AggArgs::Unary(DataType::Int64, 0),
return_type: DataType::Int64,
column_orders: vec![],
filter: None,
Expand Down
10 changes: 5 additions & 5 deletions src/stream/src/executor/global_simple_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ mod tests {
use risingwave_common::array::stream_chunk::StreamChunkTestExt;
use risingwave_common::catalog::Field;
use risingwave_common::types::*;
use risingwave_expr::function::aggregate::{AggCall, AggKind, FuncArgs};
use risingwave_expr::function::aggregate::{AggArgs, AggCall, AggKind};
use risingwave_storage::memory::MemoryStateStore;
use risingwave_storage::StateStore;

Expand Down Expand Up @@ -381,31 +381,31 @@ mod tests {
let agg_calls = vec![
AggCall {
kind: AggKind::Count, // as row count, index: 0
args: FuncArgs::None,
args: AggArgs::None,
return_type: DataType::Int64,
column_orders: vec![],
filter: None,
distinct: false,
},
AggCall {
kind: AggKind::Sum,
args: FuncArgs::Unary(DataType::Int64, 0),
args: AggArgs::Unary(DataType::Int64, 0),
return_type: DataType::Int64,
column_orders: vec![],
filter: None,
distinct: false,
},
AggCall {
kind: AggKind::Sum,
args: FuncArgs::Unary(DataType::Int64, 1),
args: AggArgs::Unary(DataType::Int64, 1),
return_type: DataType::Int64,
column_orders: vec![],
filter: None,
distinct: false,
},
AggCall {
kind: AggKind::Min,
args: FuncArgs::Unary(DataType::Int64, 0),
args: AggArgs::Unary(DataType::Int64, 0),
return_type: DataType::Int64,
column_orders: vec![],
filter: None,
Expand Down

0 comments on commit be50b16

Please sign in to comment.