From 042effdf1406ae0a81eb72e67c107396b49539fb Mon Sep 17 00:00:00 2001 From: zidaye <44500963+zidaye@users.noreply.github.com> Date: Thu, 1 Dec 2022 02:33:33 +0800 Subject: [PATCH] update on conflict method (#735) --- src/ast/mod.rs | 26 ++++++++++++- src/parser.rs | 12 +++++- tests/sqlparser_postgres.rs | 74 ++++++++++++++++++++++++++++++------- 3 files changed, 94 insertions(+), 18 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 159249c17..f727de27e 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -2702,7 +2702,16 @@ pub struct OnConflict { #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum OnConflictAction { DoNothing, - DoUpdate(Vec), + DoUpdate(DoUpdate), +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct DoUpdate { + /// Column assignments + pub assignments: Vec, + /// WHERE + pub selection: Option, } impl fmt::Display for OnInsert { @@ -2730,7 +2739,20 @@ impl fmt::Display for OnConflictAction { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Self::DoNothing => write!(f, "DO NOTHING"), - Self::DoUpdate(a) => write!(f, "DO UPDATE SET {}", display_comma_separated(a)), + Self::DoUpdate(do_update) => { + write!(f, "DO UPDATE")?; + if !do_update.assignments.is_empty() { + write!( + f, + " SET {}", + display_comma_separated(&do_update.assignments) + )?; + } + if let Some(selection) = &do_update.selection { + write!(f, " WHERE {}", selection)?; + } + Ok(()) + } } } } diff --git a/src/parser.rs b/src/parser.rs index cb78cf5a3..260dbfbeb 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -5351,8 +5351,16 @@ impl<'a> Parser<'a> { } else { self.expect_keyword(Keyword::UPDATE)?; self.expect_keyword(Keyword::SET)?; - let l = self.parse_comma_separated(Parser::parse_assignment)?; - OnConflictAction::DoUpdate(l) + let assignments = self.parse_comma_separated(Parser::parse_assignment)?; + let selection = if self.parse_keyword(Keyword::WHERE) { + Some(self.parse_expr()?) + } else { + None + }; + OnConflictAction::DoUpdate(DoUpdate { + assignments, + selection, + }) }; Some(OnInsert::OnConflict(OnConflict { diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 2a871a320..e32a7b027 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -1117,10 +1117,13 @@ fn parse_pg_on_conflict() { } => { assert_eq!(vec![Ident::from("did")], conflict_target); assert_eq!( - OnConflictAction::DoUpdate(vec![Assignment { - id: vec!["dname".into()], - value: Expr::CompoundIdentifier(vec!["EXCLUDED".into(), "dname".into()]) - },]), + OnConflictAction::DoUpdate(DoUpdate { + assignments: vec![Assignment { + id: vec!["dname".into()], + value: Expr::CompoundIdentifier(vec!["EXCLUDED".into(), "dname".into()]) + },], + selection: None + }), action ); } @@ -1147,16 +1150,22 @@ fn parse_pg_on_conflict() { conflict_target ); assert_eq!( - OnConflictAction::DoUpdate(vec![ - Assignment { - id: vec!["dname".into()], - value: Expr::CompoundIdentifier(vec!["EXCLUDED".into(), "dname".into()]) - }, - Assignment { - id: vec!["area".into()], - value: Expr::CompoundIdentifier(vec!["EXCLUDED".into(), "area".into()]) - }, - ]), + OnConflictAction::DoUpdate(DoUpdate { + assignments: vec![ + Assignment { + id: vec!["dname".into()], + value: Expr::CompoundIdentifier(vec![ + "EXCLUDED".into(), + "dname".into() + ]) + }, + Assignment { + id: vec!["area".into()], + value: Expr::CompoundIdentifier(vec!["EXCLUDED".into(), "area".into()]) + }, + ], + selection: None + }), action ); } @@ -1182,6 +1191,43 @@ fn parse_pg_on_conflict() { } _ => unreachable!(), }; + + let stmt = pg_and_generic().verified_stmt( + "INSERT INTO distributors (did, dname, dsize) \ + VALUES (5, 'Gizmo Transglobal', 1000), (6, 'Associated Computing, Inc', 1010) \ + ON CONFLICT(did) \ + DO UPDATE SET dname = $1 WHERE dsize > $2", + ); + match stmt { + Statement::Insert { + on: + Some(OnInsert::OnConflict(OnConflict { + conflict_target, + action, + })), + .. + } => { + assert_eq!(vec![Ident::from("did")], conflict_target); + assert_eq!( + OnConflictAction::DoUpdate(DoUpdate { + assignments: vec![Assignment { + id: vec!["dname".into()], + value: Expr::Value(Value::Placeholder("$1".to_string())) + },], + selection: Some(Expr::BinaryOp { + left: Box::new(Expr::Identifier(Ident { + value: "dsize".to_string(), + quote_style: None + })), + op: BinaryOperator::Gt, + right: Box::new(Expr::Value(Value::Placeholder("$2".to_string()))) + }) + }), + action + ); + } + _ => unreachable!(), + }; } #[test]