From c6edf5f2d7a3dba6be4ae87a32383c23a8d5d64d Mon Sep 17 00:00:00 2001 From: Alex Qyoun-ae <4062971+MazterQyou@users.noreply.github.com> Date: Wed, 10 Aug 2022 04:53:20 +0400 Subject: [PATCH] Support PostgreSQL array subquery constructor --- src/ast/mod.rs | 3 ++ src/parser.rs | 16 +++++++++- tests/sqlparser_common.rs | 2 +- tests/sqlparser_postgres.rs | 63 +++++++++++++++++++++++++++++++++++++ 4 files changed, 82 insertions(+), 2 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 312a86813..6f44b2839 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -355,6 +355,8 @@ pub enum Expr { /// A parenthesized subquery `(SELECT ...)`, used in expression like /// `SELECT (subquery) AS x` or `WHERE (subquery) = x` Subquery(Box), + /// An array subquery constructor, e.g. `SELECT ARRAY(SELECT 1 UNION SELECT 2)` + ArraySubquery(Box), /// The `LISTAGG` function `SELECT LISTAGG(...) WITHIN GROUP (ORDER BY ...)` ListAgg(ListAgg), /// The `GROUPING SETS` expr. @@ -486,6 +488,7 @@ impl fmt::Display for Expr { subquery ), Expr::Subquery(s) => write!(f, "({})", s), + Expr::ArraySubquery(s) => write!(f, "ARRAY({})", s), Expr::ListAgg(listagg) => write!(f, "{}", listagg), Expr::GroupingSets(sets) => { write!(f, "GROUPING SETS (")?; diff --git a/src/parser.rs b/src/parser.rs index 76190feb9..441db524d 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -434,11 +434,18 @@ impl<'a> Parser<'a> { Keyword::TRIM => self.parse_trim_expr(), Keyword::INTERVAL => self.parse_literal_interval(), Keyword::LISTAGG => self.parse_listagg_expr(), - // Treat ARRAY[1,2,3] as an array [1,2,3], otherwise try as function call + // Treat ARRAY[1,2,3] as an array [1,2,3], otherwise try as subquery or a function call Keyword::ARRAY if self.peek_token() == Token::LBracket => { self.expect_token(&Token::LBracket)?; self.parse_array_expr(true) } + Keyword::ARRAY + if dialect_of!(self is PostgreSqlDialect | GenericDialect) + && self.peek_token() == Token::LParen => + { + self.expect_token(&Token::LParen)?; + self.parse_array_subquery() + } Keyword::NOT => self.parse_not(), // Here `w` is a word, check if it's a part of a multi-part // identifier, a function call, or a simple identifier: @@ -910,6 +917,13 @@ impl<'a> Parser<'a> { } } + // Parses an array constructed from a subquery + pub fn parse_array_subquery(&mut self) -> Result { + let query = self.parse_query()?; + self.expect_token(&Token::RParen)?; + Ok(Expr::ArraySubquery(Box::new(query))) + } + /// Parse a SQL LISTAGG expression, e.g. `LISTAGG(...) WITHIN GROUP (ORDER BY ...)`. pub fn parse_listagg_expr(&mut self) -> Result { self.expect_token(&Token::LParen)?; diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index ef6011841..fe11eace5 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -2518,7 +2518,7 @@ fn parse_bad_constraint() { #[test] fn parse_scalar_function_in_projection() { - let names = vec!["sqrt", "array", "foo"]; + let names = vec!["sqrt", "foo"]; for function_name in names { // like SELECT sqrt(id) FROM foo diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 632a8bf34..fc20e0254 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -1241,6 +1241,69 @@ fn parse_array_index_expr() { ); } +#[test] +fn parse_array_subquery_expr() { + let sql = "SELECT ARRAY(SELECT 1 UNION SELECT 2)"; + let select = pg().verified_only_select(sql); + assert_eq!( + &Expr::ArraySubquery(Box::new(Query { + with: None, + body: Box::new(SetExpr::SetOperation { + op: SetOperator::Union, + all: false, + left: Box::new(SetExpr::Select(Box::new(Select { + distinct: false, + top: None, + projection: vec![SelectItem::UnnamedExpr(Expr::Value(Value::Number( + #[cfg(not(feature = "bigdecimal"))] + "1".to_string(), + #[cfg(feature = "bigdecimal")] + bigdecimal::BigDecimal::from("1"), + false, + )))], + into: None, + from: vec![], + lateral_views: vec![], + selection: None, + group_by: vec![], + cluster_by: vec![], + distribute_by: vec![], + sort_by: vec![], + having: None, + qualify: None, + }))), + right: Box::new(SetExpr::Select(Box::new(Select { + distinct: false, + top: None, + projection: vec![SelectItem::UnnamedExpr(Expr::Value(Value::Number( + #[cfg(not(feature = "bigdecimal"))] + "2".to_string(), + #[cfg(feature = "bigdecimal")] + bigdecimal::BigDecimal::from("2"), + false, + )))], + into: None, + from: vec![], + lateral_views: vec![], + selection: None, + group_by: vec![], + cluster_by: vec![], + distribute_by: vec![], + sort_by: vec![], + having: None, + qualify: None, + }))), + }), + order_by: vec![], + limit: None, + offset: None, + fetch: None, + lock: None, + })), + expr_from_projection(only(&select.projection)), + ); +} + #[test] fn test_transaction_statement() { let statement = pg().verified_stmt("SET TRANSACTION SNAPSHOT '000003A1-1'");