diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 6a718702d..cf936d3c9 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -2728,11 +2728,17 @@ pub enum OnInsert { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct OnConflict { - pub conflict_target: Vec, + pub conflict_target: Option, pub action: OnConflictAction, } #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum ConflictTarget { + Columns(Vec), + OnConstraint(ObjectName), +} +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum OnConflictAction { DoNothing, DoUpdate(DoUpdate), @@ -2762,12 +2768,20 @@ impl fmt::Display for OnInsert { impl fmt::Display for OnConflict { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, " ON CONFLICT")?; - if !self.conflict_target.is_empty() { - write!(f, "({})", display_comma_separated(&self.conflict_target))?; + if let Some(target) = &self.conflict_target { + write!(f, "{}", target)?; } write!(f, " {}", self.action) } } +impl fmt::Display for ConflictTarget { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ConflictTarget::Columns(cols) => write!(f, "({})", display_comma_separated(cols)), + ConflictTarget::OnConstraint(name) => write!(f, " ON CONSTRAINT {}", name), + } + } +} impl fmt::Display for OnConflictAction { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { diff --git a/src/parser.rs b/src/parser.rs index 26aa499e5..528b8c6fe 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -5630,7 +5630,15 @@ impl<'a> Parser<'a> { let on = if self.parse_keyword(Keyword::ON) { if self.parse_keyword(Keyword::CONFLICT) { let conflict_target = - self.parse_parenthesized_column_list(IsOptional::Optional)?; + if self.parse_keywords(&[Keyword::ON, Keyword::CONSTRAINT]) { + Some(ConflictTarget::OnConstraint(self.parse_object_name()?)) + } else if self.peek_token() == Token::LParen { + Some(ConflictTarget::Columns( + self.parse_parenthesized_column_list(IsOptional::Mandatory)?, + )) + } else { + None + }; self.expect_keyword(Keyword::DO)?; let action = if self.parse_keyword(Keyword::NOTHING) { diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 5f178a91e..6b26fe48b 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -1112,12 +1112,12 @@ fn parse_pg_on_conflict() { Statement::Insert { on: Some(OnInsert::OnConflict(OnConflict { - conflict_target, + conflict_target: Some(ConflictTarget::Columns(cols)), action, })), .. } => { - assert_eq!(vec![Ident::from("did")], conflict_target); + assert_eq!(vec![Ident::from("did")], cols); assert_eq!( OnConflictAction::DoUpdate(DoUpdate { assignments: vec![Assignment { @@ -1142,15 +1142,12 @@ fn parse_pg_on_conflict() { Statement::Insert { on: Some(OnInsert::OnConflict(OnConflict { - conflict_target, + conflict_target: Some(ConflictTarget::Columns(cols)), action, })), .. } => { - assert_eq!( - vec![Ident::from("did"), Ident::from("area"),], - conflict_target - ); + assert_eq!(vec![Ident::from("did"), Ident::from("area"),], cols); assert_eq!( OnConflictAction::DoUpdate(DoUpdate { assignments: vec![ @@ -1183,12 +1180,11 @@ fn parse_pg_on_conflict() { Statement::Insert { on: Some(OnInsert::OnConflict(OnConflict { - conflict_target, + conflict_target: None, action, })), .. } => { - assert_eq!(Vec::::new(), conflict_target); assert_eq!(OnConflictAction::DoNothing, action); } _ => unreachable!(), @@ -1204,12 +1200,49 @@ fn parse_pg_on_conflict() { Statement::Insert { on: Some(OnInsert::OnConflict(OnConflict { - conflict_target, + conflict_target: Some(ConflictTarget::Columns(cols)), + action, + })), + .. + } => { + assert_eq!(vec![Ident::from("did")], cols); + assert_eq!( + OnConflictAction::DoUpdate(DoUpdate { + assignments: vec![Assignment { + id: vec!["dname".into()], + value: Expr::Value(Value::Placeholder("$1".to_string())) + },], + selection: Some(Expr::BinaryOp { + left: Box::new(Expr::Identifier(Ident { + value: "dsize".to_string(), + quote_style: None + })), + op: BinaryOperator::Gt, + right: Box::new(Expr::Value(Value::Placeholder("$2".to_string()))) + }) + }), + action + ); + } + _ => unreachable!(), + }; + + let stmt = pg_and_generic().verified_stmt( + "INSERT INTO distributors (did, dname, dsize) \ + VALUES (5, 'Gizmo Transglobal', 1000), (6, 'Associated Computing, Inc', 1010) \ + ON CONFLICT ON CONSTRAINT distributors_did_pkey \ + DO UPDATE SET dname = $1 WHERE dsize > $2", + ); + match stmt { + Statement::Insert { + on: + Some(OnInsert::OnConflict(OnConflict { + conflict_target: Some(ConflictTarget::OnConstraint(cname)), action, })), .. } => { - assert_eq!(vec![Ident::from("did")], conflict_target); + assert_eq!(vec![Ident::from("distributors_did_pkey")], cname.0); assert_eq!( OnConflictAction::DoUpdate(DoUpdate { assignments: vec![Assignment {