Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for aggregate expressions with filters #585

Merged
merged 9 commits into from Sep 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/ast/mod.rs
Expand Up @@ -375,6 +375,8 @@ pub enum Expr {
MapAccess { column: Box<Expr>, keys: Vec<Expr> },
/// Scalar function call e.g. `LEFT(foo, 5)`
Function(Function),
/// Aggregate function with filter
AggregateExpressionWithFilter { expr: Box<Expr>, filter: Box<Expr> },
/// `CASE [<operand>] WHEN <condition> THEN <result> ... [ELSE <result>] END`
///
/// Note we only recognize a complete single expression as `<condition>`,
Expand Down Expand Up @@ -571,6 +573,9 @@ impl fmt::Display for Expr {
write!(f, " '{}'", &value::escape_single_quote_string(value))
}
Expr::Function(fun) => write!(f, "{}", fun),
Expr::AggregateExpressionWithFilter { expr, filter } => {
write!(f, "{} FILTER (WHERE {})", expr, filter)
}
Expr::Case {
operand,
conditions,
Expand Down
4 changes: 4 additions & 0 deletions src/dialect/hive.rs
Expand Up @@ -36,4 +36,8 @@ impl Dialect for HiveDialect {
|| ch == '{'
|| ch == '}'
}

fn supports_filter_during_aggregation(&self) -> bool {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe we would want to add this to src/dialect/postgres.rs as well since postgresql supports FILTER also https://www.postgresql.org/docs/current/sql-expressions.html

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

true
}
}
4 changes: 4 additions & 0 deletions src/dialect/mod.rs
Expand Up @@ -67,6 +67,10 @@ pub trait Dialect: Debug + Any {
fn is_identifier_start(&self, ch: char) -> bool;
/// Determine if a character is a valid unquoted identifier character
fn is_identifier_part(&self, ch: char) -> bool;
/// Does the dialect support `FILTER (WHERE expr)` for aggregate queries?
fn supports_filter_during_aggregation(&self) -> bool {
false
}
/// Dialect-specific prefix parser override
fn parse_prefix(&self, _parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
// return None to fall back to the default behavior
Expand Down
4 changes: 4 additions & 0 deletions src/dialect/postgresql.rs
Expand Up @@ -42,6 +42,10 @@ impl Dialect for PostgreSqlDialect {
None
}
}

fn supports_filter_during_aggregation(&self) -> bool {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For what it is worth, I think the DataFusion actually uses the Generic Dialect, https://github.com/apache/arrow-datafusion/blob/0084aeb686b318cbdb49cab00cb8f15c9f520d1e/datafusion/sql/src/parser.rs#L102

Thus we may want to add return true from supports_filter_during_aggregation in https://github.com/sqlparser-rs/sqlparser-rs/blob/main/src/dialect/generic.rs as well

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Dask SQL has their own dialect where they want to use this, but it makes sense to add to the dialect that DataFusion is using as well.

true
}
}

pub fn parse_comment(parser: &mut Parser) -> Result<Statement, ParserError> {
Expand Down
31 changes: 25 additions & 6 deletions src/parser.rs
Expand Up @@ -4542,12 +4542,31 @@ impl<'a> Parser<'a> {
/// Parse a comma-delimited list of projections after SELECT
pub fn parse_select_item(&mut self) -> Result<SelectItem, ParserError> {
match self.parse_wildcard_expr()? {
WildcardExpr::Expr(expr) => self
.parse_optional_alias(keywords::RESERVED_FOR_COLUMN_ALIAS)
.map(|alias| match alias {
Some(alias) => SelectItem::ExprWithAlias { expr, alias },
None => SelectItem::UnnamedExpr(expr),
}),
WildcardExpr::Expr(expr) => {
let expr: Expr = if self.dialect.supports_filter_during_aggregation()
&& self.parse_keyword(Keyword::FILTER)
{
let i = self.index - 1;
if self.consume_token(&Token::LParen) && self.parse_keyword(Keyword::WHERE) {
let filter = self.parse_expr()?;
self.expect_token(&Token::RParen)?;
Expr::AggregateExpressionWithFilter {
expr: Box::new(expr),
filter: Box::new(filter),
}
} else {
self.index = i;
expr
}
} else {
expr
};
self.parse_optional_alias(keywords::RESERVED_FOR_COLUMN_ALIAS)
.map(|alias| match alias {
Some(alias) => SelectItem::ExprWithAlias { expr, alias },
None => SelectItem::UnnamedExpr(expr),
})
}
WildcardExpr::QualifiedWildcard(prefix) => Ok(SelectItem::QualifiedWildcard(prefix)),
WildcardExpr::Wildcard => Ok(SelectItem::Wildcard),
}
Expand Down
25 changes: 25 additions & 0 deletions tests/sqlparser_hive.rs
Expand Up @@ -276,6 +276,31 @@ fn parse_create_function() {
);
}

#[test]
fn filtering_during_aggregation() {
let rename = "SELECT \
array_agg(name) FILTER (WHERE name IS NOT NULL), \
array_agg(name) FILTER (WHERE name LIKE 'a%') \
FROM region";
println!("{}", hive().verified_stmt(rename));
}

#[test]
fn filtering_during_aggregation_aliased() {
let rename = "SELECT \
array_agg(name) FILTER (WHERE name IS NOT NULL) AS agg1, \
array_agg(name) FILTER (WHERE name LIKE 'a%') AS agg2 \
FROM region";
println!("{}", hive().verified_stmt(rename));
}

#[test]
fn filter_as_alias() {
let sql = "SELECT name filter FROM region";
let expected = "SELECT name AS filter FROM region";
println!("{}", hive().one_statement_parses_to(sql, expected));
}

fn hive() -> TestedDialects {
TestedDialects {
dialects: vec![Box::new(HiveDialect {})],
Expand Down