Skip to content

Commit

Permalink
Merge pull request #129 from SpringQL/fix/aggregation-without-group
Browse files Browse the repository at this point in the history
fix: panic on aggregate without GROUP BY clause
  • Loading branch information
laysakura committed May 11, 2022
2 parents cccad2d + 37b2536 commit 9c5a9cc
Show file tree
Hide file tree
Showing 11 changed files with 275 additions and 73 deletions.
Expand Up @@ -3,11 +3,11 @@
pub(crate) mod aggregate;
pub(crate) mod join_parameter;

use self::{aggregate::GroupAggregateParameter, join_parameter::JoinParameter};
use self::{aggregate::AggregateParameter, join_parameter::JoinParameter};

/// Window operation parameters
#[derive(Clone, PartialEq, Debug)]
pub(crate) enum WindowOperationParameter {
GroupAggregation(GroupAggregateParameter),
Aggregate(AggregateParameter),
Join(JoinParameter),
}
Expand Up @@ -2,21 +2,32 @@

use crate::expr_resolver::expr_label::{AggrExprLabel, ValueExprLabel};

/// [GROUP BY c1, c2, c3...]
#[derive(Clone, PartialEq, Debug, Default, new)]
pub(crate) struct GroupByLabels(
/// Empty when GROUP BY clause is not supplied.
Vec<ValueExprLabel>,
);
impl GroupByLabels {
pub(crate) fn as_labels(&self) -> &[ValueExprLabel] {
&self.0
}
}

/// TODO [support complex expression with aggregations](https://gh01.base.toyota-tokyo.tech/SpringQL-internal/SpringQL/issues/152)
///
/// ```sql
/// SELECT group_by, aggr_expr.func(aggr_expr.aggregated)
/// FROM s
/// GROUP BY group_by
/// [GROUP BY group_by]
/// SLIDING WINDOW ...;
/// ```
#[derive(Copy, Clone, PartialEq, Debug, new)]
pub(crate) struct GroupAggregateParameter {
#[derive(Clone, PartialEq, Debug, new)]
pub(crate) struct AggregateParameter {
// TODO multiple aggr_expr
pub(crate) aggr_func: AggregateFunctionParameter,
pub(crate) aggr_expr: AggrExprLabel,

pub(crate) group_by: ValueExprLabel, // TODO multiple group by expression
pub(crate) group_by: GroupByLabels,
}

#[derive(Copy, Clone, Eq, PartialEq, Debug)]
Expand Down
42 changes: 23 additions & 19 deletions springql-core/src/sql_processor/query_planner.rs
Expand Up @@ -63,7 +63,8 @@ use crate::{
pipeline::{
pump_model::{
window_operation_parameter::{
aggregate::GroupAggregateParameter, WindowOperationParameter,
aggregate::{AggregateParameter, GroupByLabels},
WindowOperationParameter,
},
window_parameter::WindowParameter,
},
Expand Down Expand Up @@ -123,7 +124,7 @@ impl QueryPlanner {
match (window_param, group_aggr_param) {
(Some(window_param), Some(group_aggr_param)) => Ok(Some(GroupAggregateWindowOp {
window_param,
op_param: WindowOperationParameter::GroupAggregation(group_aggr_param),
op_param: WindowOperationParameter::Aggregate(group_aggr_param),
})),
_ => Ok(None),
}
Expand All @@ -137,8 +138,8 @@ impl QueryPlanner {
&self,
expr_resolver: &mut ExprResolver,
projection_op: &ProjectionOp,
) -> Result<Option<GroupAggregateParameter>> {
let opt_grouping_elem = self.analyzer.grouping_element();
) -> Result<Option<AggregateParameter>> {
let grouping_elements = self.analyzer.grouping_elements();
let aggr_labels = projection_op
.expr_labels
.iter()
Expand All @@ -151,28 +152,31 @@ impl QueryPlanner {
})
.collect::<Vec<_>>();

match (opt_grouping_elem, aggr_labels.len()) {
(Some(grouping_elem), 1) => {
match aggr_labels.len() {
1 => {
let aggr_label = aggr_labels.get(0).expect("len checked");
let aggr_func = expr_resolver.resolve_aggr_expr(*aggr_label).func;

let group_by_label = match grouping_elem {
GroupingElementSyntax::ValueExpr(expr) => {
expr_resolver.register_value_expr(expr)
}
GroupingElementSyntax::ValueAlias(alias) => {
expr_resolver.resolve_value_alias(alias)?
}
};

Ok(Some(GroupAggregateParameter::new(
let group_by_labels = grouping_elements
.iter()
.map(|grouping_elem| match grouping_elem {
GroupingElementSyntax::ValueExpr(expr) => {
Ok(expr_resolver.register_value_expr(expr.clone()))
}
GroupingElementSyntax::ValueAlias(alias) => {
expr_resolver.resolve_value_alias(alias.clone())
}
})
.collect::<Result<Vec<_>>>()?;

Ok(Some(AggregateParameter::new(
aggr_func,
*aggr_label,
group_by_label,
GroupByLabels::new(group_by_labels),
)))
}
(None, 0) => Ok(None),
_ => unimplemented!(),
0 => Ok(None),
_ => unimplemented!("2 or more aggregate expressions"),
}
}

Expand Down
Expand Up @@ -4,8 +4,7 @@ use super::SelectSyntaxAnalyzer;
use crate::sql_processor::sql_parser::syntax::GroupingElementSyntax;

impl SelectSyntaxAnalyzer {
/// TODO multiple GROUP BY
pub(in super::super) fn grouping_element(&self) -> Option<GroupingElementSyntax> {
self.select_syntax.grouping_element.clone()
pub(in super::super) fn grouping_elements(&self) -> Vec<GroupingElementSyntax> {
self.select_syntax.grouping_elements.clone()
}
}
Expand Up @@ -479,7 +479,7 @@ select_stream_command = {
^"SELECT" ~ "STREAM"
~ select_field ~ ("," ~ select_field)*
~ (^"FROM" ~ from_item)
~ (^"GROUP" ~ "BY" ~ grouping_element)? // TODO multiple grouping elements
~ group_by_clause?
~ window_clause?
}

Expand All @@ -500,6 +500,10 @@ join_type = {
^"LEFT" ~ ^"OUTER" ~ ^"JOIN"
}

group_by_clause = {
^"GROUP" ~ "BY" ~ grouping_element ~ ("," ~ grouping_element)*
}

grouping_element = {
value_expr
| value_alias
Expand Down
29 changes: 19 additions & 10 deletions springql-core/src/sql_processor/sql_parser/pest_parser_impl.rs
Expand Up @@ -210,12 +210,12 @@ impl PestParserImpl {
)?;

let event_duration = match duration_function {
DurationFunction::Millis => {
Ok(SpringEventDuration::from_millis(integer_constant.to_i64()? as u64))
}
DurationFunction::Secs => {
Ok(SpringEventDuration::from_secs(integer_constant.to_i64()? as u64))
}
DurationFunction::Millis => Ok(SpringEventDuration::from_millis(
integer_constant.to_i64()? as u64,
)),
DurationFunction::Secs => Ok(SpringEventDuration::from_secs(
integer_constant.to_i64()? as u64,
)),
}?;

Ok(SqlValue::NotNull(NnSqlValue::Duration(event_duration)))
Expand Down Expand Up @@ -549,10 +549,10 @@ impl PestParserImpl {
Self::parse_from_item,
identity,
)?;
let grouping_element = try_parse_child(
let grouping_elements = try_parse_child(
&mut params,
Rule::grouping_element,
Self::parse_grouping_element,
Rule::group_by_clause,
Self::parse_group_by_clause,
identity,
)?;
let window_clause = try_parse_child(
Expand All @@ -565,7 +565,7 @@ impl PestParserImpl {
Ok(SelectStreamSyntax {
fields,
from_item,
grouping_element,
grouping_elements: grouping_elements.unwrap_or_default(),
window_clause,
})
}
Expand Down Expand Up @@ -680,6 +680,15 @@ impl PestParserImpl {
}
}

fn parse_group_by_clause(mut params: FnParseParams) -> Result<Vec<GroupingElementSyntax>> {
parse_child_seq(
&mut params,
Rule::grouping_element,
&Self::parse_grouping_element,
&identity,
)
}

fn parse_grouping_element(mut params: FnParseParams) -> Result<GroupingElementSyntax> {
try_parse_child(
&mut params,
Expand Down
5 changes: 4 additions & 1 deletion springql-core/src/sql_processor/sql_parser/syntax.rs
Expand Up @@ -26,7 +26,10 @@ pub(in crate::sql_processor) struct OptionSyntax {
pub(in crate::sql_processor) struct SelectStreamSyntax {
pub(in crate::sql_processor) fields: Vec<SelectFieldSyntax>,
pub(in crate::sql_processor) from_item: FromItemSyntax,
pub(in crate::sql_processor) grouping_element: Option<GroupingElementSyntax>,

/// Empty when no GROUP BY clause is supplied.
pub(in crate::sql_processor) grouping_elements: Vec<GroupingElementSyntax>,

pub(in crate::sql_processor) window_clause: Option<WindowParameter>,
}

Expand Down
Expand Up @@ -123,7 +123,7 @@ mod tests {
pipeline::{
name::{AggrAlias, ColumnName, StreamName},
pump_model::window_operation_parameter::aggregate::{
AggregateFunctionParameter, GroupAggregateParameter,
AggregateFunctionParameter, AggregateParameter, GroupByLabels,
},
},
sql_processor::sql_parser::syntax::SelectFieldSyntax,
Expand Down Expand Up @@ -218,10 +218,10 @@ mod tests {
period: SpringEventDuration::from_secs(5),
allowed_delay: SpringEventDuration::from_secs(1),
},
WindowOperationParameter::GroupAggregation(GroupAggregateParameter {
WindowOperationParameter::Aggregate(AggregateParameter {
aggr_func: AggregateFunctionParameter::Avg,
aggr_expr: aggr_label,
group_by: group_by_label,
group_by: GroupByLabels::new(vec![group_by_label]),
}),
);

Expand Down Expand Up @@ -449,10 +449,10 @@ mod tests {
length: SpringEventDuration::from_secs(10),
allowed_delay: SpringEventDuration::from_secs(1),
},
WindowOperationParameter::GroupAggregation(GroupAggregateParameter {
WindowOperationParameter::Aggregate(AggregateParameter {
aggr_func: AggregateFunctionParameter::Avg,
aggr_expr: aggr_label,
group_by: group_by_label,
group_by: GroupByLabels::new(vec![group_by_label]),
}),
);

Expand Down
Expand Up @@ -143,7 +143,7 @@ mod tests {
expr_resolver::{expr_label::ExprLabel, ExprResolver},
expression::{AggrExpr, ValueExpr},
pipeline::pump_model::window_operation_parameter::aggregate::{
AggregateFunctionParameter, GroupAggregateParameter,
AggregateFunctionParameter, AggregateParameter, GroupByLabels,
},
sql_processor::sql_parser::syntax::SelectFieldSyntax,
stream_engine::{
Expand All @@ -168,16 +168,17 @@ mod tests {
}];
let (mut expr_resolver, labels) = ExprResolver::new(select_list);

let group_by_label = expr_resolver.register_value_expr(group_by_expr);
let group_by_labels =
GroupByLabels::new(vec![expr_resolver.register_value_expr(group_by_expr)]);

WindowOperationParameter::GroupAggregation(GroupAggregateParameter {
WindowOperationParameter::Aggregate(AggregateParameter {
aggr_func: AggregateFunctionParameter::Avg,
aggr_expr: if let ExprLabel::Aggr(l) = labels[0] {
l
} else {
unreachable!()
},
group_by: group_by_label,
group_by: group_by_labels,
})
}

Expand Down

0 comments on commit 9c5a9cc

Please sign in to comment.