From 8ab852f0e159e4848ba53e82275a31ca1468d6e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alja=C5=BE=20Mur=20Er=C5=BEen?= Date: Wed, 30 Nov 2022 14:52:59 +0100 Subject: [PATCH 1/3] Adapt VALUES to MySQL dialect --- src/ast/query.rs | 10 ++++-- src/dialect/mod.rs | 4 +++ src/dialect/mysql.rs | 4 +++ src/parser.rs | 37 ++++++++++++++++------ tests/sqlparser_common.rs | 40 ++++++++++++++++++------ tests/sqlparser_mysql.rs | 62 +++++++++++++++++++++++-------------- tests/sqlparser_postgres.rs | 4 ++- 7 files changed, 115 insertions(+), 46 deletions(-) diff --git a/src/ast/query.rs b/src/ast/query.rs index 4f3d79cdf..da764c39a 100644 --- a/src/ast/query.rs +++ b/src/ast/query.rs @@ -709,16 +709,20 @@ impl fmt::Display for Top { #[derive(Debug, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct Values(pub Vec>); +pub struct Values { + pub explicit_row: bool, + pub rows: Vec>, +} impl fmt::Display for Values { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "VALUES ")?; + let prefix = if self.explicit_row { "ROW" } else { "" }; let mut delim = ""; - for row in &self.0 { + for row in &self.rows { write!(f, "{}", delim)?; delim = ", "; - write!(f, "({})", display_comma_separated(row))?; + write!(f, "{prefix}({})", display_comma_separated(row))?; } Ok(()) } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 1eaa41aa7..7772afb87 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -102,6 +102,10 @@ pub trait Dialect: Debug + Any { // return None to fall back to the default behavior None } + /// Returns true if VALUES requires ROW keywords in SELECT. + fn values_require_row_in_select(&self) -> bool { + false + } } impl dyn Dialect { diff --git a/src/dialect/mysql.rs b/src/dialect/mysql.rs index d6095262c..e9458311e 100644 --- a/src/dialect/mysql.rs +++ b/src/dialect/mysql.rs @@ -35,4 +35,8 @@ impl Dialect for MySqlDialect { fn is_delimited_identifier_start(&self, ch: char) -> bool { ch == '`' } + + fn values_require_row_in_select(&self) -> bool { + true + } } diff --git a/src/parser.rs b/src/parser.rs index e537eefae..a0dcff244 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -4162,6 +4162,10 @@ impl<'a> Parser<'a> { /// by `ORDER BY`. Unlike some other parse_... methods, this one doesn't /// expect the initial keyword to be already consumed pub fn parse_query(&mut self) -> Result { + self.parse_query_impl(false) + } + + pub fn parse_query_impl(&mut self, within_insert: bool) -> Result { let with = if self.parse_keyword(Keyword::WITH) { Some(With { recursive: self.parse_keyword(Keyword::RECURSIVE), @@ -4172,7 +4176,7 @@ impl<'a> Parser<'a> { }; if !self.parse_keyword(Keyword::INSERT) { - let body = Box::new(self.parse_query_body(0)?); + let body = Box::new(self.parse_query_body(0, within_insert)?); let order_by = if self.parse_keywords(&[Keyword::ORDER, Keyword::BY]) { self.parse_comma_separated(Parser::parse_order_by_expr)? @@ -4287,7 +4291,11 @@ impl<'a> Parser<'a> { /// subquery ::= query_body [ order_by_limit ] /// set_operation ::= query_body { 'UNION' | 'EXCEPT' | 'INTERSECT' } [ 'ALL' ] query_body /// ``` - pub fn parse_query_body(&mut self, precedence: u8) -> Result { + pub fn parse_query_body( + &mut self, + precedence: u8, + within_insert: bool, + ) -> Result { // We parse the expression using a Pratt parser, as in `parse_expr()`. // Start by parsing a restricted SELECT or a `(subquery)`: let mut expr = if self.parse_keyword(Keyword::SELECT) { @@ -4298,7 +4306,7 @@ impl<'a> Parser<'a> { self.expect_token(&Token::RParen)?; SetExpr::Query(Box::new(subquery)) } else if self.parse_keyword(Keyword::VALUES) { - SetExpr::Values(self.parse_values()?) + SetExpr::Values(self.parse_values(within_insert)?) } else { return self.expected( "SELECT, VALUES, or a subquery in the query body", @@ -4326,7 +4334,7 @@ impl<'a> Parser<'a> { left: Box::new(expr), op: op.unwrap(), set_quantifier, - right: Box::new(self.parse_query_body(next_precedence)?), + right: Box::new(self.parse_query_body(next_precedence, within_insert)?), }; } @@ -5226,7 +5234,7 @@ impl<'a> Parser<'a> { // Hive allows you to specify columns after partitions as well if you want. let after_columns = self.parse_parenthesized_column_list(Optional)?; - let source = Box::new(self.parse_query()?); + let source = Box::new(self.parse_query_impl(true)?); let on = if self.parse_keyword(Keyword::ON) { if self.parse_keyword(Keyword::CONFLICT) { let conflict_target = @@ -5482,14 +5490,25 @@ impl<'a> Parser<'a> { } } - pub fn parse_values(&mut self) -> Result { - let values = self.parse_comma_separated(|parser| { + pub fn parse_values(&mut self, within_insert: bool) -> Result { + let mut explicit_row = false; + + let rows = self.parse_comma_separated(|parser| { + if parser.parse_keyword(Keyword::ROW) { + explicit_row = true; + } else { + if !within_insert && parser.dialect.values_require_row_in_select() { + parser + .expected(format!("{:?}", &Keyword::ROW).as_str(), parser.peek_token())?; + } + } + parser.expect_token(&Token::LParen)?; let exprs = parser.parse_comma_separated(Parser::parse_expr)?; parser.expect_token(&Token::RParen)?; Ok(exprs) })?; - Ok(Values(values)) + Ok(Values { explicit_row, rows }) } pub fn parse_start_transaction(&mut self) -> Result { @@ -5655,7 +5674,7 @@ impl<'a> Parser<'a> { } let columns = self.parse_parenthesized_column_list(Optional)?; self.expect_keyword(Keyword::VALUES)?; - let values = self.parse_values()?; + let values = self.parse_values(true)?; MergeClause::NotMatched { predicate, columns, diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 6790bbff6..85ef60e83 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -88,7 +88,9 @@ fn parse_insert_values() { assert_eq!(column, &Ident::new(expected_columns[index].clone())); } match &*source.body { - SetExpr::Values(Values(values)) => assert_eq!(values.as_slice(), expected_rows), + SetExpr::Values(Values { rows, .. }) => { + assert_eq!(rows.as_slice(), expected_rows) + } _ => unreachable!(), } } @@ -459,7 +461,7 @@ fn parse_top_level() { verified_stmt("SELECT 1"); verified_stmt("(SELECT 1)"); verified_stmt("((SELECT 1))"); - verified_stmt("VALUES (1)"); + all_but_mysql().verified_stmt("VALUES (1)"); } #[test] @@ -4233,9 +4235,9 @@ fn parse_union_except_intersect() { #[test] fn parse_values() { - verified_stmt("SELECT * FROM (VALUES (1), (2), (3))"); - verified_stmt("SELECT * FROM (VALUES (1), (2), (3)), (VALUES (1, 2, 3))"); - verified_stmt("SELECT * FROM (VALUES (1)) UNION VALUES (1)"); + all_but_mysql().verified_stmt("SELECT * FROM (VALUES (1), (2), (3))"); + all_but_mysql().verified_stmt("SELECT * FROM (VALUES (1), (2), (3)), (VALUES (1, 2, 3))"); + all_but_mysql().verified_stmt("SELECT * FROM (VALUES (1)) UNION VALUES (1)"); } #[test] @@ -5505,11 +5507,14 @@ fn parse_merge() { MergeClause::NotMatched { predicate: None, columns: vec![Ident::new("A"), Ident::new("B"), Ident::new("C")], - values: Values(vec![vec![ - Expr::CompoundIdentifier(vec![Ident::new("stg"), Ident::new("A")]), - Expr::CompoundIdentifier(vec![Ident::new("stg"), Ident::new("B")]), - Expr::CompoundIdentifier(vec![Ident::new("stg"), Ident::new("C")]), - ]]), + values: Values { + explicit_row: false, + rows: vec![vec![ + Expr::CompoundIdentifier(vec![Ident::new("stg"), Ident::new("A")]), + Expr::CompoundIdentifier(vec![Ident::new("stg"), Ident::new("B")]), + Expr::CompoundIdentifier(vec![Ident::new("stg"), Ident::new("C")]), + ]] + }, }, MergeClause::MatchedUpdate { predicate: Some(Expr::BinaryOp { @@ -5680,6 +5685,21 @@ fn verified_expr(query: &str) -> Expr { all_dialects().verified_expr(query) } +fn all_but_mysql() -> TestedDialects { + TestedDialects { + dialects: vec![ + Box::new(GenericDialect {}), + Box::new(PostgreSqlDialect {}), + Box::new(MsSqlDialect {}), + Box::new(AnsiDialect {}), + Box::new(SnowflakeDialect {}), + Box::new(HiveDialect {}), + Box::new(RedshiftSqlDialect {}), + Box::new(BigQueryDialect {}), + ], + } +} + #[test] fn parse_offset_and_limit() { let sql = "SELECT foo FROM bar LIMIT 2 OFFSET 2"; diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index e91cea0ef..947f889e8 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -660,20 +660,25 @@ fn parse_simple_insert() { assert_eq!( Box::new(Query { with: None, - body: Box::new(SetExpr::Values(Values(vec![ - vec![ - Expr::Value(Value::SingleQuotedString("Test Some Inserts".to_string())), - Expr::Value(Value::Number("1".to_string(), false)) - ], - vec![ - Expr::Value(Value::SingleQuotedString("Test Entry 2".to_string())), - Expr::Value(Value::Number("2".to_string(), false)) - ], - vec![ - Expr::Value(Value::SingleQuotedString("Test Entry 3".to_string())), - Expr::Value(Value::Number("3".to_string(), false)) + body: Box::new(SetExpr::Values(Values { + explicit_row: false, + rows: vec![ + vec![ + Expr::Value(Value::SingleQuotedString( + "Test Some Inserts".to_string() + )), + Expr::Value(Value::Number("1".to_string(), false)) + ], + vec![ + Expr::Value(Value::SingleQuotedString("Test Entry 2".to_string())), + Expr::Value(Value::Number("2".to_string(), false)) + ], + vec![ + Expr::Value(Value::SingleQuotedString("Test Entry 3".to_string())), + Expr::Value(Value::Number("3".to_string(), false)) + ] ] - ]))), + })), order_by: vec![], limit: None, offset: None, @@ -717,16 +722,21 @@ fn parse_insert_with_on_duplicate_update() { assert_eq!( Box::new(Query { with: None, - body: Box::new(SetExpr::Values(Values(vec![vec![ - Expr::Value(Value::SingleQuotedString("accounting_manager".to_string())), - Expr::Value(Value::SingleQuotedString( - "Some description about the group".to_string() - )), - Expr::Value(Value::Boolean(true)), - Expr::Value(Value::Boolean(true)), - Expr::Value(Value::Boolean(true)), - Expr::Value(Value::Boolean(true)), - ]]))), + body: Box::new(SetExpr::Values(Values { + explicit_row: false, + rows: vec![vec![ + Expr::Value(Value::SingleQuotedString( + "accounting_manager".to_string() + )), + Expr::Value(Value::SingleQuotedString( + "Some description about the group".to_string() + )), + Expr::Value(Value::Boolean(true)), + Expr::Value(Value::Boolean(true)), + Expr::Value(Value::Boolean(true)), + Expr::Value(Value::Boolean(true)), + ]] + })), order_by: vec![], limit: None, offset: None, @@ -1183,3 +1193,9 @@ fn mysql_and_generic() -> TestedDialects { dialects: vec![Box::new(MySqlDialect {}), Box::new(GenericDialect {})], } } + +#[test] +fn parse_values() { + mysql().verified_stmt("VALUES ROW(1, true, 'a')"); + mysql().verified_stmt("SELECT a, c FROM (VALUES ROW(1, true, 'a'), ROW(2, false, 'b'), ROW(3, false, 'c')) AS t (a, b, c)"); +} diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 5cc333935..05c601f0c 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -1067,7 +1067,9 @@ fn parse_prepare() { Expr::Identifier("a3".into()), ]]; match &*source.body { - SetExpr::Values(Values(values)) => assert_eq!(values.as_slice(), &expected_values), + SetExpr::Values(Values { rows, .. }) => { + assert_eq!(rows.as_slice(), &expected_values) + } _ => unreachable!(), } } From 4478808ea725e3846672009d387e3d2a210d4721 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alja=C5=BE=20Mur=20Er=C5=BEen?= Date: Wed, 30 Nov 2022 19:51:11 +0100 Subject: [PATCH 2/3] Update src/ast/query.rs Co-authored-by: Andrew Lamb --- src/ast/query.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/ast/query.rs b/src/ast/query.rs index da764c39a..689bfae57 100644 --- a/src/ast/query.rs +++ b/src/ast/query.rs @@ -710,6 +710,8 @@ impl fmt::Display for Top { #[derive(Debug, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Values { + /// Was there an explict ROWs keyword (MySQL)? + /// pub explicit_row: bool, pub rows: Vec>, } From 7fbd21ba56164e35246fe0266c17d95a854e6f60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alja=C5=BE=20Mur=20Er=C5=BEen?= Date: Thu, 1 Dec 2022 09:12:22 +0100 Subject: [PATCH 3/3] remove *requirement* for ROW --- src/dialect/mod.rs | 4 ---- src/dialect/mysql.rs | 4 ---- src/parser.rs | 27 +++++++-------------------- tests/sqlparser_common.rs | 25 ++++++------------------- 4 files changed, 13 insertions(+), 47 deletions(-) diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 7772afb87..1eaa41aa7 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -102,10 +102,6 @@ pub trait Dialect: Debug + Any { // return None to fall back to the default behavior None } - /// Returns true if VALUES requires ROW keywords in SELECT. - fn values_require_row_in_select(&self) -> bool { - false - } } impl dyn Dialect { diff --git a/src/dialect/mysql.rs b/src/dialect/mysql.rs index e9458311e..d6095262c 100644 --- a/src/dialect/mysql.rs +++ b/src/dialect/mysql.rs @@ -35,8 +35,4 @@ impl Dialect for MySqlDialect { fn is_delimited_identifier_start(&self, ch: char) -> bool { ch == '`' } - - fn values_require_row_in_select(&self) -> bool { - true - } } diff --git a/src/parser.rs b/src/parser.rs index a0dcff244..52fe57bcf 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -4162,10 +4162,6 @@ impl<'a> Parser<'a> { /// by `ORDER BY`. Unlike some other parse_... methods, this one doesn't /// expect the initial keyword to be already consumed pub fn parse_query(&mut self) -> Result { - self.parse_query_impl(false) - } - - pub fn parse_query_impl(&mut self, within_insert: bool) -> Result { let with = if self.parse_keyword(Keyword::WITH) { Some(With { recursive: self.parse_keyword(Keyword::RECURSIVE), @@ -4176,7 +4172,7 @@ impl<'a> Parser<'a> { }; if !self.parse_keyword(Keyword::INSERT) { - let body = Box::new(self.parse_query_body(0, within_insert)?); + let body = Box::new(self.parse_query_body(0)?); let order_by = if self.parse_keywords(&[Keyword::ORDER, Keyword::BY]) { self.parse_comma_separated(Parser::parse_order_by_expr)? @@ -4291,11 +4287,7 @@ impl<'a> Parser<'a> { /// subquery ::= query_body [ order_by_limit ] /// set_operation ::= query_body { 'UNION' | 'EXCEPT' | 'INTERSECT' } [ 'ALL' ] query_body /// ``` - pub fn parse_query_body( - &mut self, - precedence: u8, - within_insert: bool, - ) -> Result { + pub fn parse_query_body(&mut self, precedence: u8) -> Result { // We parse the expression using a Pratt parser, as in `parse_expr()`. // Start by parsing a restricted SELECT or a `(subquery)`: let mut expr = if self.parse_keyword(Keyword::SELECT) { @@ -4306,7 +4298,7 @@ impl<'a> Parser<'a> { self.expect_token(&Token::RParen)?; SetExpr::Query(Box::new(subquery)) } else if self.parse_keyword(Keyword::VALUES) { - SetExpr::Values(self.parse_values(within_insert)?) + SetExpr::Values(self.parse_values()?) } else { return self.expected( "SELECT, VALUES, or a subquery in the query body", @@ -4334,7 +4326,7 @@ impl<'a> Parser<'a> { left: Box::new(expr), op: op.unwrap(), set_quantifier, - right: Box::new(self.parse_query_body(next_precedence, within_insert)?), + right: Box::new(self.parse_query_body(next_precedence)?), }; } @@ -5234,7 +5226,7 @@ impl<'a> Parser<'a> { // Hive allows you to specify columns after partitions as well if you want. let after_columns = self.parse_parenthesized_column_list(Optional)?; - let source = Box::new(self.parse_query_impl(true)?); + let source = Box::new(self.parse_query()?); let on = if self.parse_keyword(Keyword::ON) { if self.parse_keyword(Keyword::CONFLICT) { let conflict_target = @@ -5490,17 +5482,12 @@ impl<'a> Parser<'a> { } } - pub fn parse_values(&mut self, within_insert: bool) -> Result { + pub fn parse_values(&mut self) -> Result { let mut explicit_row = false; let rows = self.parse_comma_separated(|parser| { if parser.parse_keyword(Keyword::ROW) { explicit_row = true; - } else { - if !within_insert && parser.dialect.values_require_row_in_select() { - parser - .expected(format!("{:?}", &Keyword::ROW).as_str(), parser.peek_token())?; - } } parser.expect_token(&Token::LParen)?; @@ -5674,7 +5661,7 @@ impl<'a> Parser<'a> { } let columns = self.parse_parenthesized_column_list(Optional)?; self.expect_keyword(Keyword::VALUES)?; - let values = self.parse_values(true)?; + let values = self.parse_values()?; MergeClause::NotMatched { predicate, columns, diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 85ef60e83..056c26b97 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -461,7 +461,8 @@ fn parse_top_level() { verified_stmt("SELECT 1"); verified_stmt("(SELECT 1)"); verified_stmt("((SELECT 1))"); - all_but_mysql().verified_stmt("VALUES (1)"); + verified_stmt("VALUES (1)"); + verified_stmt("VALUES ROW(1, true, 'a'), ROW(2, false, 'b')"); } #[test] @@ -4235,9 +4236,10 @@ fn parse_union_except_intersect() { #[test] fn parse_values() { - all_but_mysql().verified_stmt("SELECT * FROM (VALUES (1), (2), (3))"); - all_but_mysql().verified_stmt("SELECT * FROM (VALUES (1), (2), (3)), (VALUES (1, 2, 3))"); - all_but_mysql().verified_stmt("SELECT * FROM (VALUES (1)) UNION VALUES (1)"); + verified_stmt("SELECT * FROM (VALUES (1), (2), (3))"); + verified_stmt("SELECT * FROM (VALUES (1), (2), (3)), (VALUES (1, 2, 3))"); + verified_stmt("SELECT * FROM (VALUES (1)) UNION VALUES (1)"); + verified_stmt("SELECT * FROM (VALUES ROW(1, true, 'a'), ROW(2, false, 'b')) AS t (a, b, c)"); } #[test] @@ -5685,21 +5687,6 @@ fn verified_expr(query: &str) -> Expr { all_dialects().verified_expr(query) } -fn all_but_mysql() -> TestedDialects { - TestedDialects { - dialects: vec![ - Box::new(GenericDialect {}), - Box::new(PostgreSqlDialect {}), - Box::new(MsSqlDialect {}), - Box::new(AnsiDialect {}), - Box::new(SnowflakeDialect {}), - Box::new(HiveDialect {}), - Box::new(RedshiftSqlDialect {}), - Box::new(BigQueryDialect {}), - ], - } -} - #[test] fn parse_offset_and_limit() { let sql = "SELECT foo FROM bar LIMIT 2 OFFSET 2";