diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 342bd28cf..14f298e5a 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -410,6 +410,8 @@ pub enum Expr { ArraySubquery(Box), /// The `LISTAGG` function `SELECT LISTAGG(...) WITHIN GROUP (ORDER BY ...)` ListAgg(ListAgg), + /// The `ARRAY_AGG` function `SELECT ARRAY_AGG(... ORDER BY ...)` + ArrayAgg(ArrayAgg), /// The `GROUPING SETS` expr. GroupingSets(Vec>), /// The `CUBE` expr. @@ -649,6 +651,7 @@ impl fmt::Display for Expr { Expr::Subquery(s) => write!(f, "({})", s), Expr::ArraySubquery(s) => write!(f, "ARRAY({})", s), Expr::ListAgg(listagg) => write!(f, "{}", listagg), + Expr::ArrayAgg(arrayagg) => write!(f, "{}", arrayagg), Expr::GroupingSets(sets) => { write!(f, "GROUPING SETS (")?; let mut sep = ""; @@ -2844,6 +2847,41 @@ impl fmt::Display for ListAggOnOverflow { } } +/// An `ARRAY_AGG` invocation `ARRAY_AGG( [ DISTINCT ] [ORDER BY ] [LIMIT ] ) +/// [ WITHIN GROUP (ORDER BY ]` +/// ORDERY BY position is defined differently for BigQuery, Postgres and Snowflake +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct ArrayAgg { + pub distinct: bool, + pub expr: Box, + pub order_by: Option>, + pub limit: Option>, + pub within_group: Option>, +} + +impl fmt::Display for ArrayAgg { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "ARRAY_AGG({}{}", + if self.distinct { "DISTINCT " } else { "" }, + self.expr + )?; + if let Some(order_by) = &self.order_by { + write!(f, " ORDER BY {}", order_by)?; + } + if let Some(limit) = &self.limit { + write!(f, " LIMIT {}", limit)?; + } + write!(f, ")")?; + if let Some(order_by) = &self.within_group { + write!(f, " WITHIN GROUP (ORDER BY {})", order_by)?; + } + Ok(()) + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum ObjectType { diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 1d3c9cf5f..2a0621b0f 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -71,6 +71,10 @@ pub trait Dialect: Debug + Any { fn supports_filter_during_aggregation(&self) -> bool { false } + /// Does the dialect supports ARRAY_AGG() [WITHIN GROUP (ORDER BY)] or ARRAY_AGG([ORDER BY]) + fn supports_within_after_array_aggregation(&self) -> bool { + false + } /// Dialect-specific prefix parser override fn parse_prefix(&self, _parser: &mut Parser) -> Option> { // return None to fall back to the default behavior diff --git a/src/dialect/snowflake.rs b/src/dialect/snowflake.rs index 93db95692..11108e973 100644 --- a/src/dialect/snowflake.rs +++ b/src/dialect/snowflake.rs @@ -28,4 +28,8 @@ impl Dialect for SnowflakeDialect { || ch == '$' || ch == '_' } + + fn supports_within_after_array_aggregation(&self) -> bool { + true + } } diff --git a/src/parser.rs b/src/parser.rs index fb473b74f..68cd3704d 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -471,6 +471,7 @@ impl<'a> Parser<'a> { self.expect_token(&Token::LParen)?; self.parse_array_subquery() } + Keyword::ARRAY_AGG => self.parse_array_agg_expr(), Keyword::NOT => self.parse_not(), // Here `w` is a word, check if it's a part of a multi-part // identifier, a function call, or a simple identifier: @@ -1067,6 +1068,57 @@ impl<'a> Parser<'a> { })) } + pub fn parse_array_agg_expr(&mut self) -> Result { + self.expect_token(&Token::LParen)?; + let distinct = self.parse_all_or_distinct()?; + let expr = Box::new(self.parse_expr()?); + // ANSI SQL and BigQuery define ORDER BY inside function. + if !self.dialect.supports_within_after_array_aggregation() { + let order_by = if self.parse_keywords(&[Keyword:: ORDER, Keyword::BY]) { + let order_by_expr = self.parse_order_by_expr()?; + Some(Box::new(order_by_expr)) + } else { + None + }; + let limit = if self.parse_keyword(Keyword::LIMIT) { + match self.parse_limit()? { + Some(expr) => Some(Box::new(expr)), + None => None + } + } else { + None + }; + self.expect_token(&Token::RParen)?; + return Ok(Expr::ArrayAgg(ArrayAgg { + distinct, + expr, + order_by, + limit, + within_group: None + })); + } + // Snowflake defines ORDERY BY in within group instead of inside the function like + // ANSI SQL. + self.expect_token(&Token::RParen)?; + let within_group = if self.parse_keywords(&[Keyword::WITHIN, Keyword::GROUP]) { + self.expect_token(&Token::LParen)?; + self.expect_keywords(&[Keyword::ORDER, Keyword::BY])?; + let order_by_expr = self.parse_order_by_expr()?; + self.expect_token(&Token::RParen)?; + Some(Box::new(order_by_expr)) + } else { + None + }; + + Ok(Expr::ArrayAgg(ArrayAgg { + distinct, + expr, + order_by: None, + limit: None, + within_group, + })) + } + // This function parses date/time fields for the EXTRACT function-like // operator, interval qualifiers, and the ceil/floor operations. // EXTRACT supports a wider set of date/time fields than interval qualifiers, diff --git a/tests/sqlparser_bigquery.rs b/tests/sqlparser_bigquery.rs index 0a606c3ec..5d43f342d 100644 --- a/tests/sqlparser_bigquery.rs +++ b/tests/sqlparser_bigquery.rs @@ -115,6 +115,26 @@ fn parse_cast_type() { bigquery().verified_only_select(sql); } +#[test] +fn parse_array_agg_func() { + for (sql, canonical) in [ + ( + "select array_agg(x order by x) as a from T", + "SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T" + ), + ( + "select array_agg(x order by x LIMIT 2) from tbl", + "SELECT ARRAY_AGG(x ORDER BY x LIMIT 2) FROM tbl" + ), + ( + "select array_agg(distinct x order by x LIMIT 2) from tbl", + "SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl" + ), + ] { + bigquery().one_statement_parses_to(sql, canonical); + } +} + fn bigquery() -> TestedDialects { TestedDialects { dialects: vec![Box::new(BigQueryDialect {})], diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 7654d677e..59b839b59 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -21,9 +21,11 @@ #[macro_use] mod test_utils; +use std::ops::Deref; + use matches::assert_matches; use sqlparser::ast::SelectItem::UnnamedExpr; -use sqlparser::ast::*; +use sqlparser::{ast::*, dialect}; use sqlparser::dialect::{ AnsiDialect, BigQueryDialect, ClickHouseDialect, GenericDialect, HiveDialect, MsSqlDialect, PostgreSqlDialect, SQLiteDialect, SnowflakeDialect, @@ -1899,6 +1901,36 @@ fn parse_listagg() { ); } +#[test] +fn parse_array_agg_func() { + let supported_dialects = TestedDialects { + dialects: vec![ + Box::new(GenericDialect {}), + Box::new(PostgreSqlDialect {}), + Box::new(MsSqlDialect {}), + Box::new(AnsiDialect {}), + Box::new(HiveDialect {}), + ] + }; + + for (sql, canonical) in [ + ( + "select array_agg(x order by x) as a from T", + "SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T" + ), + ( + "select array_agg(x order by x LIMIT 2) from tbl", + "SELECT ARRAY_AGG(x ORDER BY x LIMIT 2) FROM tbl" + ), + ( + "select array_agg(distinct x order by x LIMIT 2) from tbl", + "SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl" + ), + ] { + supported_dialects.one_statement_parses_to(sql, canonical); + } +} + #[test] fn parse_create_table() { let sql = "CREATE TABLE uk_cities (\ diff --git a/tests/sqlparser_hive.rs b/tests/sqlparser_hive.rs index 8839cea2b..8d93d7827 100644 --- a/tests/sqlparser_hive.rs +++ b/tests/sqlparser_hive.rs @@ -279,8 +279,8 @@ fn parse_create_function() { #[test] fn filtering_during_aggregation() { let rename = "SELECT \ - array_agg(name) FILTER (WHERE name IS NOT NULL), \ - array_agg(name) FILTER (WHERE name LIKE 'a%') \ + ARRAY_AGG(name) FILTER (WHERE name IS NOT NULL), \ + ARRAY_AGG(name) FILTER (WHERE name LIKE 'a%') \ FROM region"; println!("{}", hive().verified_stmt(rename)); } @@ -288,8 +288,8 @@ fn filtering_during_aggregation() { #[test] fn filtering_during_aggregation_aliased() { let rename = "SELECT \ - array_agg(name) FILTER (WHERE name IS NOT NULL) AS agg1, \ - array_agg(name) FILTER (WHERE name LIKE 'a%') AS agg2 \ + ARRAY_AGG(name) FILTER (WHERE name IS NOT NULL) AS agg1, \ + ARRAY_AGG(name) FILTER (WHERE name LIKE 'a%') AS agg2 \ FROM region"; println!("{}", hive().verified_stmt(rename)); } diff --git a/tests/sqlparser_snowflake.rs b/tests/sqlparser_snowflake.rs index 7c089a935..3dbbd2f3d 100644 --- a/tests/sqlparser_snowflake.rs +++ b/tests/sqlparser_snowflake.rs @@ -143,6 +143,26 @@ fn test_single_table_in_parenthesis_with_alias() { ); } +#[test] +fn test_array_agg_func() { + for (sql, canonical) in [ + ( + "select array_agg(x) within group (order by x) as a from T", + "SELECT ARRAY_AGG(x) WITHIN GROUP (ORDER BY x) AS a FROM T" + ), + ( + "select array_agg(distinct x) within group (order by x asc) from tbl", + "SELECT ARRAY_AGG(DISTINCT x) WITHIN GROUP (ORDER BY x ASC) FROM tbl" + ), + ] { + snowflake().one_statement_parses_to(sql, canonical); + } + + let sql = "select array_agg(x order by x) as a from T"; + let result = snowflake().parse_sql_statements(&sql); + assert_eq!(result, Err(ParserError::ParserError(String::from("Expected ), found: order")))) +} + fn snowflake() -> TestedDialects { TestedDialects { dialects: vec![Box::new(SnowflakeDialect {})],