Skip to content

Commit

Permalink
Parse ARRAY_AGG for Bigquery and Snowflake
Browse files Browse the repository at this point in the history
  • Loading branch information
SuperBo committed Oct 8, 2022
1 parent a3194dd commit c1af40b
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 5 deletions.
38 changes: 38 additions & 0 deletions src/ast/mod.rs
Expand Up @@ -410,6 +410,8 @@ pub enum Expr {
ArraySubquery(Box<Query>),
/// The `LISTAGG` function `SELECT LISTAGG(...) WITHIN GROUP (ORDER BY ...)`
ListAgg(ListAgg),
/// The `ARRAY_AGG` function `SELECT ARRAY_AGG(... ORDER BY ...)`
ArrayAgg(ArrayAgg),
/// The `GROUPING SETS` expr.
GroupingSets(Vec<Vec<Expr>>),
/// The `CUBE` expr.
Expand Down Expand Up @@ -649,6 +651,7 @@ impl fmt::Display for Expr {
Expr::Subquery(s) => write!(f, "({})", s),
Expr::ArraySubquery(s) => write!(f, "ARRAY({})", s),
Expr::ListAgg(listagg) => write!(f, "{}", listagg),
Expr::ArrayAgg(arrayagg) => write!(f, "{}", arrayagg),
Expr::GroupingSets(sets) => {
write!(f, "GROUPING SETS (")?;
let mut sep = "";
Expand Down Expand Up @@ -2844,6 +2847,41 @@ impl fmt::Display for ListAggOnOverflow {
}
}

/// An `ARRAY_AGG` invocation `ARRAY_AGG( [ DISTINCT ] <expr> [ ORDER BY <expr> ] [ LIMIT <n> ] )
/// [ WITHIN GROUP (ORDER BY <within_group>) ]`
///
/// The position of `ORDER BY` is defined differently for BigQuery, Postgres and Snowflake:
/// it is written either inside the parentheses (`order_by`) or in a trailing
/// `WITHIN GROUP (...)` clause (`within_group`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ArrayAgg {
/// Whether the `DISTINCT` keyword is rendered before the aggregated expression.
pub distinct: bool,
/// The expression being aggregated.
pub expr: Box<Expr>,
/// Optional `ORDER BY` written inside the parentheses, after `expr`.
pub order_by: Option<Box<OrderByExpr>>,
/// Optional `LIMIT` written inside the parentheses, after any `ORDER BY`.
pub limit: Option<Box<Expr>>,
/// Optional ordering written after the closing parenthesis as
/// `WITHIN GROUP (ORDER BY ...)`.
pub within_group: Option<Box<OrderByExpr>>,
}

impl fmt::Display for ArrayAgg {
    /// Writes the aggregate back out as SQL text.
    ///
    /// `DISTINCT`, the inner `ORDER BY` and `LIMIT` are emitted inside the
    /// parentheses; the `WITHIN GROUP (ORDER BY ...)` clause, when present,
    /// follows the closing parenthesis.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let distinct_prefix = match self.distinct {
            true => "DISTINCT ",
            false => "",
        };
        write!(f, "ARRAY_AGG({}{}", distinct_prefix, self.expr)?;

        // Clauses that live inside the parentheses.
        if let Some(inner_ordering) = self.order_by.as_ref() {
            write!(f, " ORDER BY {}", inner_ordering)?;
        }
        if let Some(max_rows) = self.limit.as_ref() {
            write!(f, " LIMIT {}", max_rows)?;
        }
        write!(f, ")")?;

        // Clause that trails the call.
        match self.within_group.as_ref() {
            Some(outer_ordering) => write!(f, " WITHIN GROUP (ORDER BY {})", outer_ordering),
            None => Ok(()),
        }
    }
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ObjectType {
Expand Down
4 changes: 4 additions & 0 deletions src/dialect/mod.rs
Expand Up @@ -71,6 +71,10 @@ pub trait Dialect: Debug + Any {
fn supports_filter_during_aggregation(&self) -> bool {
false
}
/// Does the dialect support `ARRAY_AGG(...) [WITHIN GROUP (ORDER BY ...)]` (return `true`)
/// rather than `ARRAY_AGG(... [ORDER BY ...])` (return `false`, the default)?
fn supports_within_after_array_aggregation(&self) -> bool {
false
}
/// Dialect-specific prefix parser override
fn parse_prefix(&self, _parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
// return None to fall back to the default behavior
Expand Down
4 changes: 4 additions & 0 deletions src/dialect/snowflake.rs
Expand Up @@ -28,4 +28,8 @@ impl Dialect for SnowflakeDialect {
|| ch == '$'
|| ch == '_'
}

/// Snowflake defines the `ORDER BY` of `ARRAY_AGG` in a trailing `WITHIN GROUP (...)`
/// clause instead of inside the function's parentheses.
fn supports_within_after_array_aggregation(&self) -> bool {
true
}
}
52 changes: 52 additions & 0 deletions src/parser.rs
Expand Up @@ -471,6 +471,7 @@ impl<'a> Parser<'a> {
self.expect_token(&Token::LParen)?;
self.parse_array_subquery()
}
Keyword::ARRAY_AGG => self.parse_array_agg_expr(),
Keyword::NOT => self.parse_not(),
// Here `w` is a word, check if it's a part of a multi-part
// identifier, a function call, or a simple identifier:
Expand Down Expand Up @@ -1067,6 +1068,57 @@ impl<'a> Parser<'a> {
}))
}

/// Parses the remainder of an `ARRAY_AGG` call, starting at its opening parenthesis
/// (the `ARRAY_AGG` keyword itself has already been consumed by the caller).
///
/// Two layouts are recognised, selected by the dialect's
/// `supports_within_after_array_aggregation()`:
///
/// * `false`: `( <expr> [ORDER BY <order_by_expr>] [LIMIT <n>] )`
/// * `true`: `( <expr> ) [WITHIN GROUP (ORDER BY <order_by_expr>)]`
///
/// In both layouts the quantifier accepted before `<expr>` is whatever
/// `parse_all_or_distinct` recognises.
///
/// # Errors
///
/// Propagates any `ParserError` raised by the token expectations or by the nested
/// expression / `ORDER BY` / `LIMIT` parsers.
pub fn parse_array_agg_expr(&mut self) -> Result<Expr, ParserError> {
    self.expect_token(&Token::LParen)?;
    let distinct = self.parse_all_or_distinct()?;
    let expr = Box::new(self.parse_expr()?);

    // ANSI SQL and BigQuery define ORDER BY inside the function call.
    if !self.dialect.supports_within_after_array_aggregation() {
        let order_by = if self.parse_keywords(&[Keyword::ORDER, Keyword::BY]) {
            Some(Box::new(self.parse_order_by_expr()?))
        } else {
            None
        };
        let limit = if self.parse_keyword(Keyword::LIMIT) {
            // `parse_limit` may itself yield `None`; box the expression only when present.
            self.parse_limit()?.map(Box::new)
        } else {
            None
        };
        self.expect_token(&Token::RParen)?;
        return Ok(Expr::ArrayAgg(ArrayAgg {
            distinct,
            expr,
            order_by,
            limit,
            within_group: None,
        }));
    }

    // Snowflake defines ORDER BY in a trailing WITHIN GROUP clause instead of inside
    // the function like ANSI SQL, so the argument list must close right after `expr`.
    self.expect_token(&Token::RParen)?;
    let within_group = if self.parse_keywords(&[Keyword::WITHIN, Keyword::GROUP]) {
        self.expect_token(&Token::LParen)?;
        self.expect_keywords(&[Keyword::ORDER, Keyword::BY])?;
        let order_by_expr = self.parse_order_by_expr()?;
        self.expect_token(&Token::RParen)?;
        Some(Box::new(order_by_expr))
    } else {
        None
    };

    Ok(Expr::ArrayAgg(ArrayAgg {
        distinct,
        expr,
        order_by: None,
        limit: None,
        within_group,
    }))
}

// This function parses date/time fields for the EXTRACT function-like
// operator, interval qualifiers, and the ceil/floor operations.
// EXTRACT supports a wider set of date/time fields than interval qualifiers,
Expand Down
20 changes: 20 additions & 0 deletions tests/sqlparser_bigquery.rs
Expand Up @@ -115,6 +115,26 @@ fn parse_cast_type() {
bigquery().verified_only_select(sql);
}

#[test]
fn parse_array_agg_func() {
    // Each case pairs the input SQL with the canonical text it must round-trip to.
    let cases = [
        (
            "select array_agg(x order by x) as a from T",
            "SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T",
        ),
        (
            "select array_agg(x order by x LIMIT 2) from tbl",
            "SELECT ARRAY_AGG(x ORDER BY x LIMIT 2) FROM tbl",
        ),
        (
            "select array_agg(distinct x order by x LIMIT 2) from tbl",
            "SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl",
        ),
    ];

    for (input, expected) in cases {
        bigquery().one_statement_parses_to(input, expected);
    }
}

fn bigquery() -> TestedDialects {
TestedDialects {
dialects: vec![Box::new(BigQueryDialect {})],
Expand Down
34 changes: 33 additions & 1 deletion tests/sqlparser_common.rs
Expand Up @@ -21,9 +21,11 @@
#[macro_use]
mod test_utils;

use std::ops::Deref;

use matches::assert_matches;
use sqlparser::ast::SelectItem::UnnamedExpr;
use sqlparser::ast::*;
use sqlparser::{ast::*, dialect};
use sqlparser::dialect::{
AnsiDialect, BigQueryDialect, ClickHouseDialect, GenericDialect, HiveDialect, MsSqlDialect,
PostgreSqlDialect, SQLiteDialect, SnowflakeDialect,
Expand Down Expand Up @@ -1899,6 +1901,36 @@ fn parse_listagg() {
);
}

#[test]
fn parse_array_agg_func() {
    // Dialects that accept ORDER BY / LIMIT inside the ARRAY_AGG parentheses.
    let inline_order_dialects = TestedDialects {
        dialects: vec![
            Box::new(GenericDialect {}),
            Box::new(PostgreSqlDialect {}),
            Box::new(MsSqlDialect {}),
            Box::new(AnsiDialect {}),
            Box::new(HiveDialect {}),
        ],
    };

    // Each case pairs the input SQL with the canonical text it must round-trip to.
    let cases = [
        (
            "select array_agg(x order by x) as a from T",
            "SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T",
        ),
        (
            "select array_agg(x order by x LIMIT 2) from tbl",
            "SELECT ARRAY_AGG(x ORDER BY x LIMIT 2) FROM tbl",
        ),
        (
            "select array_agg(distinct x order by x LIMIT 2) from tbl",
            "SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl",
        ),
    ];

    for (input, expected) in cases {
        inline_order_dialects.one_statement_parses_to(input, expected);
    }
}

#[test]
fn parse_create_table() {
let sql = "CREATE TABLE uk_cities (\
Expand Down
8 changes: 4 additions & 4 deletions tests/sqlparser_hive.rs
Expand Up @@ -279,17 +279,17 @@ fn parse_create_function() {
#[test]
fn filtering_during_aggregation() {
let rename = "SELECT \
array_agg(name) FILTER (WHERE name IS NOT NULL), \
array_agg(name) FILTER (WHERE name LIKE 'a%') \
ARRAY_AGG(name) FILTER (WHERE name IS NOT NULL), \
ARRAY_AGG(name) FILTER (WHERE name LIKE 'a%') \
FROM region";
println!("{}", hive().verified_stmt(rename));
}

#[test]
fn filtering_during_aggregation_aliased() {
let rename = "SELECT \
array_agg(name) FILTER (WHERE name IS NOT NULL) AS agg1, \
array_agg(name) FILTER (WHERE name LIKE 'a%') AS agg2 \
ARRAY_AGG(name) FILTER (WHERE name IS NOT NULL) AS agg1, \
ARRAY_AGG(name) FILTER (WHERE name LIKE 'a%') AS agg2 \
FROM region";
println!("{}", hive().verified_stmt(rename));
}
Expand Down
20 changes: 20 additions & 0 deletions tests/sqlparser_snowflake.rs
Expand Up @@ -143,6 +143,26 @@ fn test_single_table_in_parenthesis_with_alias() {
);
}

#[test]
fn test_array_agg_func() {
    // Snowflake places the ordering in a trailing WITHIN GROUP clause.
    for (sql, canonical) in [
        (
            "select array_agg(x) within group (order by x) as a from T",
            "SELECT ARRAY_AGG(x) WITHIN GROUP (ORDER BY x) AS a FROM T",
        ),
        (
            "select array_agg(distinct x) within group (order by x asc) from tbl",
            "SELECT ARRAY_AGG(DISTINCT x) WITHIN GROUP (ORDER BY x ASC) FROM tbl",
        ),
    ] {
        snowflake().one_statement_parses_to(sql, canonical);
    }

    // ORDER BY inside the parentheses must be rejected for Snowflake.
    let sql = "select array_agg(x order by x) as a from T";
    // `sql` is already a `&str`; pass it directly instead of borrowing it again.
    let result = snowflake().parse_sql_statements(sql);
    assert_eq!(
        result,
        Err(ParserError::ParserError(String::from(
            "Expected ), found: order"
        )))
    );
}

fn snowflake() -> TestedDialects {
TestedDialects {
dialects: vec![Box::new(SnowflakeDialect {})],
Expand Down

0 comments on commit c1af40b

Please sign in to comment.