diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 7f9d42f05..62e7c15aa 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1037,6 +1037,8 @@ pub enum Statement { /// whether the insert has the table keyword (Hive) table: bool, on: Option, + /// RETURNING + returning: Option>, }, // TODO: Support ROW FORMAT Directory { @@ -1077,6 +1079,8 @@ pub enum Statement { from: Option, /// WHERE selection: Option, + /// RETURNING + returning: Option>, }, /// DELETE Delete { @@ -1086,6 +1090,8 @@ pub enum Statement { using: Option, /// WHERE selection: Option, + /// RETURNING + returning: Option>, }, /// CREATE VIEW CreateView { @@ -1633,6 +1639,7 @@ impl fmt::Display for Statement { source, table, on, + returning, } => { if let Some(action) = or { write!(f, "INSERT OR {} INTO {} ", action, table_name)?; @@ -1660,10 +1667,14 @@ impl fmt::Display for Statement { write!(f, "{}", source)?; if let Some(on) = on { - write!(f, "{}", on) - } else { - Ok(()) + write!(f, "{}", on)?; + } + + if let Some(returning) = returning { + write!(f, " RETURNING {}", display_comma_separated(returning))?; } + + Ok(()) } Statement::Copy { @@ -1707,6 +1718,7 @@ impl fmt::Display for Statement { assignments, from, selection, + returning, } => { write!(f, "UPDATE {}", table)?; if !assignments.is_empty() { @@ -1718,12 +1730,16 @@ impl fmt::Display for Statement { if let Some(selection) = selection { write!(f, " WHERE {}", selection)?; } + if let Some(returning) = returning { + write!(f, " RETURNING {}", display_comma_separated(returning))?; + } Ok(()) } Statement::Delete { table_name, using, selection, + returning, } => { write!(f, "DELETE FROM {}", table_name)?; if let Some(using) = using { @@ -1732,6 +1748,9 @@ impl fmt::Display for Statement { if let Some(selection) = selection { write!(f, " WHERE {}", selection)?; } + if let Some(returning) = returning { + write!(f, " RETURNING {}", display_comma_separated(returning))?; + } Ok(()) } Statement::Close { cursor } => { @@ -2416,6 +2435,21 @@ impl fmt::Display for Statement { pub enum OnInsert { /// ON DUPLICATE KEY UPDATE (MySQL when the key already exists, then execute an update instead) DuplicateKeyUpdate(Vec), + /// ON CONFLICT is a PostgreSQL and Sqlite extension + OnConflict(OnConflict), +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct OnConflict { + pub conflict_target: Vec, + pub action: OnConflictAction, +} +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum OnConflictAction { + DoNothing, + DoUpdate(Vec), } impl fmt::Display for OnInsert { @@ -2426,6 +2460,24 @@ impl fmt::Display for OnInsert { " ON DUPLICATE KEY UPDATE {}", display_comma_separated(expr) ), + Self::OnConflict(o) => write!(f, " {o}"), + } + } +} +impl fmt::Display for OnConflict { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, " ON CONFLICT")?; + if !self.conflict_target.is_empty() { + write!(f, "({})", display_comma_separated(&self.conflict_target))?; + } + write!(f, " {}", self.action) + } +} +impl fmt::Display for OnConflictAction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::DoNothing => write!(f, "DO NOTHING"), + Self::DoUpdate(a) => write!(f, "DO UPDATE SET {}", display_comma_separated(a)), } } } diff --git a/src/keywords.rs b/src/keywords.rs index b84b4cf9d..3a3b6011a 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -142,6 +142,7 @@ define_keywords!( COMMITTED, COMPUTE, CONDITION, + CONFLICT, CONNECT, CONNECTION, CONSTRAINT, @@ -198,6 +199,7 @@ define_keywords!( DISCONNECT, DISTINCT, DISTRIBUTE, + DO, DOUBLE, DOW, DOY, @@ -363,6 +365,7 @@ define_keywords!( NOSCAN, NOSUPERUSER, NOT, + NOTHING, NTH_VALUE, NTILE, NULL, @@ -454,6 +457,7 @@ define_keywords!( RESTRICT, RESULT, RETURN, + RETURNING, RETURNS, REVOKE, RIGHT, diff --git a/src/parser.rs b/src/parser.rs index cb261e183..241fe016f 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -3733,10 +3733,17 @@ impl<'a> Parser<'a> { None }; + let returning = if self.parse_keyword(Keyword::RETURNING) { + Some(self.parse_comma_separated(Parser::parse_select_item)?) + } else { + None + }; + Ok(Statement::Delete { table_name, using, selection, + returning, }) } @@ -4824,12 +4831,38 @@ impl<'a> Parser<'a> { let source = Box::new(self.parse_query()?); let on = if self.parse_keyword(Keyword::ON) { - self.expect_keyword(Keyword::DUPLICATE)?; - self.expect_keyword(Keyword::KEY)?; - self.expect_keyword(Keyword::UPDATE)?; - let l = self.parse_comma_separated(Parser::parse_assignment)?; + if self.parse_keyword(Keyword::CONFLICT) { + let conflict_target = + self.parse_parenthesized_column_list(IsOptional::Optional)?; - Some(OnInsert::DuplicateKeyUpdate(l)) + self.expect_keyword(Keyword::DO)?; + let action = if self.parse_keyword(Keyword::NOTHING) { + OnConflictAction::DoNothing + } else { + self.expect_keyword(Keyword::UPDATE)?; + self.expect_keyword(Keyword::SET)?; + let l = self.parse_comma_separated(Parser::parse_assignment)?; + OnConflictAction::DoUpdate(l) + }; + + Some(OnInsert::OnConflict(OnConflict { + conflict_target, + action, + })) + } else { + self.expect_keyword(Keyword::DUPLICATE)?; + self.expect_keyword(Keyword::KEY)?; + self.expect_keyword(Keyword::UPDATE)?; + let l = self.parse_comma_separated(Parser::parse_assignment)?; + + Some(OnInsert::DuplicateKeyUpdate(l)) + } + } else { + None + }; + + let returning = if self.parse_keyword(Keyword::RETURNING) { + Some(self.parse_comma_separated(Parser::parse_select_item)?) } else { None }; @@ -4845,6 +4878,7 @@ impl<'a> Parser<'a> { source, table, on, + returning, }) } } @@ -4863,11 +4897,17 @@ impl<'a> Parser<'a> { } else { None }; + let returning = if self.parse_keyword(Keyword::RETURNING) { + Some(self.parse_comma_separated(Parser::parse_select_item)?) + } else { + None + }; Ok(Statement::Update { table, assignments, from, selection, + returning, }) } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 7654d677e..7847051b2 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -195,6 +195,7 @@ fn parse_update_with_table_alias() { assignments, from: _from, selection, + returning, } => { assert_eq!( TableWithJoins { @@ -231,6 +232,7 @@ fn parse_update_with_table_alias() { }), selection ); + assert_eq!(None, returning); } _ => unreachable!(), } @@ -278,6 +280,7 @@ fn parse_where_delete_statement() { table_name, using, selection, + returning, } => { assert_eq!( TableFactor::Table { @@ -298,6 +301,7 @@ fn parse_where_delete_statement() { }, selection.unwrap(), ); + assert_eq!(None, returning); } _ => unreachable!(), } @@ -313,6 +317,7 @@ fn parse_where_delete_with_alias_statement() { table_name, using, selection, + returning, } => { assert_eq!( TableFactor::Table { @@ -353,6 +358,7 @@ fn parse_where_delete_with_alias_statement() { }, selection.unwrap(), ); + assert_eq!(None, returning); } _ => unreachable!(), } diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index 8b8754db4..c9944f440 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -815,6 +815,7 @@ fn parse_update_with_joins() { assignments, from: _from, selection, + returning, } => { assert_eq!( TableWithJoins { @@ -870,6 +871,7 @@ fn parse_update_with_joins() { }), selection ); + assert_eq!(None, returning); } _ => unreachable!(), } diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index c589feec5..796b8d45f 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -476,6 +476,7 @@ fn parse_update_set_from() { Ident::new("id") ])), }), + returning: None, } ); } @@ -1086,6 +1087,143 @@ fn parse_prepare() { ); } +#[test] +fn parse_pg_on_conflict() { + let stmt = pg_and_generic().verified_stmt( + "INSERT INTO distributors (did, dname) \ + VALUES (5, 'Gizmo Transglobal'), (6, 'Associated Computing, Inc') \ + ON CONFLICT(did) \ + DO UPDATE SET dname = EXCLUDED.dname", + ); + match stmt { + Statement::Insert { + on: + Some(OnInsert::OnConflict(OnConflict { + conflict_target, + action, + })), + .. + } => { + assert_eq!(vec![Ident::from("did")], conflict_target); + assert_eq!( + OnConflictAction::DoUpdate(vec![Assignment { + id: vec!["dname".into()], + value: Expr::CompoundIdentifier(vec!["EXCLUDED".into(), "dname".into()]) + },]), + action + ); + } + _ => unreachable!(), + }; + + let stmt = pg_and_generic().verified_stmt( + "INSERT INTO distributors (did, dname, area) \ + VALUES (5, 'Gizmo Transglobal', 'Mars'), (6, 'Associated Computing, Inc', 'Venus') \ + ON CONFLICT(did, area) \ + DO UPDATE SET dname = EXCLUDED.dname, area = EXCLUDED.area", + ); + match stmt { + Statement::Insert { + on: + Some(OnInsert::OnConflict(OnConflict { + conflict_target, + action, + })), + .. + } => { + assert_eq!( + vec![Ident::from("did"), Ident::from("area"),], + conflict_target + ); + assert_eq!( + OnConflictAction::DoUpdate(vec![ + Assignment { + id: vec!["dname".into()], + value: Expr::CompoundIdentifier(vec!["EXCLUDED".into(), "dname".into()]) + }, + Assignment { + id: vec!["area".into()], + value: Expr::CompoundIdentifier(vec!["EXCLUDED".into(), "area".into()]) + }, + ]), + action + ); + } + _ => unreachable!(), + }; + + let stmt = pg_and_generic().verified_stmt( + "INSERT INTO distributors (did, dname) \ + VALUES (5, 'Gizmo Transglobal'), (6, 'Associated Computing, Inc') \ + ON CONFLICT DO NOTHING", + ); + match stmt { + Statement::Insert { + on: + Some(OnInsert::OnConflict(OnConflict { + conflict_target, + action, + })), + .. + } => { + assert_eq!(Vec::::new(), conflict_target); + assert_eq!(OnConflictAction::DoNothing, action); + } + _ => unreachable!(), + }; +} + +#[test] +fn parse_pg_returning() { + let stmt = pg_and_generic().verified_stmt( + "INSERT INTO distributors (did, dname) VALUES (DEFAULT, 'XYZ Widgets') RETURNING did", + ); + match stmt { + Statement::Insert { returning, .. } => { + assert_eq!( + Some(vec![SelectItem::UnnamedExpr(Expr::Identifier( + "did".into() + )),]), + returning + ); + } + _ => unreachable!(), + }; + + let stmt = pg_and_generic().verified_stmt( + "UPDATE weather SET temp_lo = temp_lo + 1, temp_hi = temp_lo + 15, prcp = DEFAULT \ + WHERE city = 'San Francisco' AND date = '2003-07-03' \ + RETURNING temp_lo AS lo, temp_hi AS hi, prcp", + ); + match stmt { + Statement::Update { returning, .. } => { + assert_eq!( + Some(vec![ + SelectItem::ExprWithAlias { + expr: Expr::Identifier("temp_lo".into()), + alias: "lo".into() + }, + SelectItem::ExprWithAlias { + expr: Expr::Identifier("temp_hi".into()), + alias: "hi".into() + }, + SelectItem::UnnamedExpr(Expr::Identifier("prcp".into())), + ]), + returning + ); + } + _ => unreachable!(), + }; + let stmt = + pg_and_generic().verified_stmt("DELETE FROM tasks WHERE status = 'DONE' RETURNING *"); + match stmt { + Statement::Delete { returning, .. } => { + assert_eq!(Some(vec![SelectItem::Wildcard,]), returning); + } + _ => unreachable!(), + }; +} + #[test] fn parse_pg_bitwise_binary_ops() { let bitwise_ops = &[