diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 4f5fdb2eb..6c279518e 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -375,6 +375,8 @@ pub enum Expr { MapAccess { column: Box, keys: Vec }, /// Scalar function call e.g. `LEFT(foo, 5)` Function(Function), + /// Aggregate function with filter + AggregateExpressionWithFilter { expr: Box, filter: Box }, /// `CASE [] WHEN THEN ... [ELSE ] END` /// /// Note we only recognize a complete single expression as ``, @@ -571,6 +573,9 @@ impl fmt::Display for Expr { write!(f, " '{}'", &value::escape_single_quote_string(value)) } Expr::Function(fun) => write!(f, "{}", fun), + Expr::AggregateExpressionWithFilter { expr, filter } => { + write!(f, "{} FILTER (WHERE {})", expr, filter) + } Expr::Case { operand, conditions, diff --git a/src/dialect/hive.rs b/src/dialect/hive.rs index 9b42857ec..ceb5488ef 100644 --- a/src/dialect/hive.rs +++ b/src/dialect/hive.rs @@ -36,4 +36,8 @@ impl Dialect for HiveDialect { || ch == '{' || ch == '}' } + + fn supports_filter_during_aggregation(&self) -> bool { + true + } } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 46e8dda2c..1d3c9cf5f 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -67,6 +67,10 @@ pub trait Dialect: Debug + Any { fn is_identifier_start(&self, ch: char) -> bool; /// Determine if a character is a valid unquoted identifier character fn is_identifier_part(&self, ch: char) -> bool; + /// Does the dialect support `FILTER (WHERE expr)` for aggregate queries? + fn supports_filter_during_aggregation(&self) -> bool { + false + } /// Dialect-specific prefix parser override fn parse_prefix(&self, _parser: &mut Parser) -> Option> { // return None to fall back to the default behavior diff --git a/src/dialect/postgresql.rs b/src/dialect/postgresql.rs index 04d64b9bf..b1f261b2e 100644 --- a/src/dialect/postgresql.rs +++ b/src/dialect/postgresql.rs @@ -42,6 +42,10 @@ impl Dialect for PostgreSqlDialect { None } } + + fn supports_filter_during_aggregation(&self) -> bool { + true + } } pub fn parse_comment(parser: &mut Parser) -> Result { diff --git a/src/parser.rs b/src/parser.rs index 877c47303..894bb84f1 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -4542,12 +4542,31 @@ impl<'a> Parser<'a> { /// Parse a comma-delimited list of projections after SELECT pub fn parse_select_item(&mut self) -> Result { match self.parse_wildcard_expr()? { - WildcardExpr::Expr(expr) => self - .parse_optional_alias(keywords::RESERVED_FOR_COLUMN_ALIAS) - .map(|alias| match alias { - Some(alias) => SelectItem::ExprWithAlias { expr, alias }, - None => SelectItem::UnnamedExpr(expr), - }), + WildcardExpr::Expr(expr) => { + let expr: Expr = if self.dialect.supports_filter_during_aggregation() + && self.parse_keyword(Keyword::FILTER) + { + let i = self.index - 1; + if self.consume_token(&Token::LParen) && self.parse_keyword(Keyword::WHERE) { + let filter = self.parse_expr()?; + self.expect_token(&Token::RParen)?; + Expr::AggregateExpressionWithFilter { + expr: Box::new(expr), + filter: Box::new(filter), + } + } else { + self.index = i; + expr + } + } else { + expr + }; + self.parse_optional_alias(keywords::RESERVED_FOR_COLUMN_ALIAS) + .map(|alias| match alias { + Some(alias) => SelectItem::ExprWithAlias { expr, alias }, + None => SelectItem::UnnamedExpr(expr), + }) + } WildcardExpr::QualifiedWildcard(prefix) => Ok(SelectItem::QualifiedWildcard(prefix)), WildcardExpr::Wildcard => Ok(SelectItem::Wildcard), } diff --git a/tests/sqlparser_hive.rs b/tests/sqlparser_hive.rs index 4223ad5fa..8839cea2b 100644 --- a/tests/sqlparser_hive.rs +++ b/tests/sqlparser_hive.rs @@ -276,6 +276,31 @@ fn parse_create_function() { ); } +#[test] +fn filtering_during_aggregation() { + let rename = "SELECT \ + array_agg(name) FILTER (WHERE name IS NOT NULL), \ + array_agg(name) FILTER (WHERE name LIKE 'a%') \ + FROM region"; + println!("{}", hive().verified_stmt(rename)); +} + +#[test] +fn filtering_during_aggregation_aliased() { + let rename = "SELECT \ + array_agg(name) FILTER (WHERE name IS NOT NULL) AS agg1, \ + array_agg(name) FILTER (WHERE name LIKE 'a%') AS agg2 \ + FROM region"; + println!("{}", hive().verified_stmt(rename)); +} + +#[test] +fn filter_as_alias() { + let sql = "SELECT name filter FROM region"; + let expected = "SELECT name AS filter FROM region"; + println!("{}", hive().one_statement_parses_to(sql, expected)); +} + fn hive() -> TestedDialects { TestedDialects { dialects: vec![Box::new(HiveDialect {})],