From 4b1dc1abf7138cb1576707c194ec89ad89049d27 Mon Sep 17 00:00:00 2001 From: unvalley <38400669+unvalley@users.noreply.github.com> Date: Sat, 12 Nov 2022 06:37:09 +0900 Subject: [PATCH] Support `UPDATE ... FROM ( subquery )` in some dialects (#694) * Apply UPDATE SET FROM statement for some dialects * Add GenericDialect to support * Test SnowflakeDialect Co-authored-by: Andrew Lamb --- src/parser.rs | 4 +- tests/sqlparser_common.rs | 92 ++++++++++++++++++++++++++++++++++++- tests/sqlparser_postgres.rs | 80 -------------------------------- 3 files changed, 94 insertions(+), 82 deletions(-) diff --git a/src/parser.rs b/src/parser.rs index 0dbc4aa36..f91223c2f 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -5254,7 +5254,9 @@ impl<'a> Parser<'a> { let table = self.parse_table_and_joins()?; self.expect_keyword(Keyword::SET)?; let assignments = self.parse_comma_separated(Parser::parse_assignment)?; - let from = if self.parse_keyword(Keyword::FROM) && dialect_of!(self is PostgreSqlDialect) { + let from = if self.parse_keyword(Keyword::FROM) + && dialect_of!(self is GenericDialect | PostgreSqlDialect | BigQueryDialect | SnowflakeDialect | RedshiftSqlDialect | MsSqlDialect) + { Some(self.parse_table_and_joins()?) } else { None diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 93439cafe..83a7a4ca6 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -24,7 +24,7 @@ use sqlparser::ast::SelectItem::UnnamedExpr; use sqlparser::ast::*; use sqlparser::dialect::{ AnsiDialect, BigQueryDialect, ClickHouseDialect, GenericDialect, HiveDialect, MsSqlDialect, - MySqlDialect, PostgreSqlDialect, SQLiteDialect, SnowflakeDialect, + MySqlDialect, PostgreSqlDialect, RedshiftSqlDialect, SQLiteDialect, SnowflakeDialect, }; use sqlparser::keywords::ALL_KEYWORDS; use sqlparser::parser::{Parser, ParserError}; @@ -186,6 +186,96 @@ fn parse_update() { ); } +#[test] +fn parse_update_set_from() { + let sql = "UPDATE t1 SET name = t2.name FROM (SELECT name, id FROM t1 GROUP BY id) AS t2 WHERE t1.id = t2.id"; + let dialects = TestedDialects { + dialects: vec![ + Box::new(GenericDialect {}), + Box::new(PostgreSqlDialect {}), + Box::new(BigQueryDialect {}), + Box::new(SnowflakeDialect {}), + Box::new(RedshiftSqlDialect {}), + Box::new(MsSqlDialect {}), + ], + }; + let stmt = dialects.verified_stmt(sql); + assert_eq!( + stmt, + Statement::Update { + table: TableWithJoins { + relation: TableFactor::Table { + name: ObjectName(vec![Ident::new("t1")]), + alias: None, + args: None, + with_hints: vec![], + }, + joins: vec![], + }, + assignments: vec![Assignment { + id: vec![Ident::new("name")], + value: Expr::CompoundIdentifier(vec![Ident::new("t2"), Ident::new("name")]) + }], + from: Some(TableWithJoins { + relation: TableFactor::Derived { + lateral: false, + subquery: Box::new(Query { + with: None, + body: Box::new(SetExpr::Select(Box::new(Select { + distinct: false, + top: None, + projection: vec![ + SelectItem::UnnamedExpr(Expr::Identifier(Ident::new("name"))), + SelectItem::UnnamedExpr(Expr::Identifier(Ident::new("id"))), + ], + into: None, + from: vec![TableWithJoins { + relation: TableFactor::Table { + name: ObjectName(vec![Ident::new("t1")]), + alias: None, + args: None, + with_hints: vec![], + }, + joins: vec![], + }], + lateral_views: vec![], + selection: None, + group_by: vec![Expr::Identifier(Ident::new("id"))], + cluster_by: vec![], + distribute_by: vec![], + sort_by: vec![], + having: None, + qualify: None + }))), + order_by: vec![], + limit: None, + offset: None, + fetch: None, + lock: None, + }), + alias: Some(TableAlias { + name: Ident::new("t2"), + columns: vec![], + }) + }, + joins: vec![], + }), + selection: Some(Expr::BinaryOp { + left: Box::new(Expr::CompoundIdentifier(vec![ + Ident::new("t1"), + Ident::new("id") + ])), + op: BinaryOperator::Eq, + right: Box::new(Expr::CompoundIdentifier(vec![ + Ident::new("t2"), + Ident::new("id") + ])), + }), + returning: None, + } + ); +} + #[test] fn parse_update_with_table_alias() { let sql = "UPDATE users AS u SET u.username = 'new_user' WHERE u.username = 'old_user'"; diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 6cce4fdb9..5cc333935 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -489,86 +489,6 @@ PHP ₱ USD $ //assert_eq!(sql, ast.to_string()); } -#[test] -fn parse_update_set_from() { - let sql = "UPDATE t1 SET name = t2.name FROM (SELECT name, id FROM t1 GROUP BY id) AS t2 WHERE t1.id = t2.id"; - let stmt = pg().verified_stmt(sql); - assert_eq!( - stmt, - Statement::Update { - table: TableWithJoins { - relation: TableFactor::Table { - name: ObjectName(vec![Ident::new("t1")]), - alias: None, - args: None, - with_hints: vec![], - }, - joins: vec![], - }, - assignments: vec![Assignment { - id: vec![Ident::new("name")], - value: Expr::CompoundIdentifier(vec![Ident::new("t2"), Ident::new("name")]) - }], - from: Some(TableWithJoins { - relation: TableFactor::Derived { - lateral: false, - subquery: Box::new(Query { - with: None, - body: Box::new(SetExpr::Select(Box::new(Select { - distinct: false, - top: None, - projection: vec![ - SelectItem::UnnamedExpr(Expr::Identifier(Ident::new("name"))), - SelectItem::UnnamedExpr(Expr::Identifier(Ident::new("id"))), - ], - into: None, - from: vec![TableWithJoins { - relation: TableFactor::Table { - name: ObjectName(vec![Ident::new("t1")]), - alias: None, - args: None, - with_hints: vec![], - }, - joins: vec![], - }], - lateral_views: vec![], - selection: None, - group_by: vec![Expr::Identifier(Ident::new("id"))], - cluster_by: vec![], - distribute_by: vec![], - sort_by: vec![], - having: None, - qualify: None - }))), - order_by: vec![], - limit: None, - offset: None, - fetch: None, - lock: None, - }), - alias: Some(TableAlias { - name: Ident::new("t2"), - columns: vec![], - }) - }, - joins: vec![], - }), - selection: Some(Expr::BinaryOp { - left: Box::new(Expr::CompoundIdentifier(vec![ - Ident::new("t1"), - Ident::new("id") - ])), - op: BinaryOperator::Eq, - right: Box::new(Expr::CompoundIdentifier(vec![ - Ident::new("t2"), - Ident::new("id") - ])), - }), - returning: None, - } - ); -} - #[test] fn test_copy_from() { let stmt = pg().verified_stmt("COPY users FROM 'data.csv'");