Skip to content

Commit

Permalink
Implement ON CONFLICT and RETURNING (#666)
Browse files Browse the repository at this point in the history
* Implement RETURNING on INSERT/UPDATE/DELETE

* Implement INSERT ... ON CONFLICT

* Fix tests

* cargo fmt

* tests: on conflict and returning

Co-authored-by: gamife <gamife9886@gmail.com>
  • Loading branch information
main-- and gamife committed Nov 11, 2022
1 parent ae1c690 commit 814367a
Show file tree
Hide file tree
Showing 6 changed files with 250 additions and 8 deletions.
58 changes: 55 additions & 3 deletions src/ast/mod.rs
Expand Up @@ -1049,6 +1049,8 @@ pub enum Statement {
/// whether the insert has the table keyword (Hive)
table: bool,
on: Option<OnInsert>,
/// RETURNING
returning: Option<Vec<SelectItem>>,
},
// TODO: Support ROW FORMAT
Directory {
Expand Down Expand Up @@ -1089,6 +1091,8 @@ pub enum Statement {
from: Option<TableWithJoins>,
/// WHERE
selection: Option<Expr>,
/// RETURNING
returning: Option<Vec<SelectItem>>,
},
/// DELETE
Delete {
Expand All @@ -1098,6 +1102,8 @@ pub enum Statement {
using: Option<TableFactor>,
/// WHERE
selection: Option<Expr>,
/// RETURNING
returning: Option<Vec<SelectItem>>,
},
/// CREATE VIEW
CreateView {
Expand Down Expand Up @@ -1679,6 +1685,7 @@ impl fmt::Display for Statement {
source,
table,
on,
returning,
} => {
if let Some(action) = or {
write!(f, "INSERT OR {} INTO {} ", action, table_name)?;
Expand Down Expand Up @@ -1706,10 +1713,14 @@ impl fmt::Display for Statement {
write!(f, "{}", source)?;

if let Some(on) = on {
write!(f, "{}", on)
} else {
Ok(())
write!(f, "{}", on)?;
}

if let Some(returning) = returning {
write!(f, " RETURNING {}", display_comma_separated(returning))?;
}

Ok(())
}

Statement::Copy {
Expand Down Expand Up @@ -1753,6 +1764,7 @@ impl fmt::Display for Statement {
assignments,
from,
selection,
returning,
} => {
write!(f, "UPDATE {}", table)?;
if !assignments.is_empty() {
Expand All @@ -1764,12 +1776,16 @@ impl fmt::Display for Statement {
if let Some(selection) = selection {
write!(f, " WHERE {}", selection)?;
}
if let Some(returning) = returning {
write!(f, " RETURNING {}", display_comma_separated(returning))?;
}
Ok(())
}
Statement::Delete {
table_name,
using,
selection,
returning,
} => {
write!(f, "DELETE FROM {}", table_name)?;
if let Some(using) = using {
Expand All @@ -1778,6 +1794,9 @@ impl fmt::Display for Statement {
if let Some(selection) = selection {
write!(f, " WHERE {}", selection)?;
}
if let Some(returning) = returning {
write!(f, " RETURNING {}", display_comma_separated(returning))?;
}
Ok(())
}
Statement::Close { cursor } => {
Expand Down Expand Up @@ -2610,6 +2629,21 @@ pub enum MinMaxValue {
pub enum OnInsert {
/// ON DUPLICATE KEY UPDATE (MySQL when the key already exists, then execute an update instead)
DuplicateKeyUpdate(Vec<Assignment>),
/// ON CONFLICT is a PostgreSQL and Sqlite extension
OnConflict(OnConflict),
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct OnConflict {
pub conflict_target: Vec<Ident>,
pub action: OnConflictAction,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum OnConflictAction {
DoNothing,
DoUpdate(Vec<Assignment>),
}

impl fmt::Display for OnInsert {
Expand All @@ -2620,6 +2654,24 @@ impl fmt::Display for OnInsert {
" ON DUPLICATE KEY UPDATE {}",
display_comma_separated(expr)
),
Self::OnConflict(o) => write!(f, " {o}"),
}
}
}
impl fmt::Display for OnConflict {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, " ON CONFLICT")?;
if !self.conflict_target.is_empty() {
write!(f, "({})", display_comma_separated(&self.conflict_target))?;
}
write!(f, " {}", self.action)
}
}
impl fmt::Display for OnConflictAction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::DoNothing => write!(f, "DO NOTHING"),
Self::DoUpdate(a) => write!(f, "DO UPDATE SET {}", display_comma_separated(a)),
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions src/keywords.rs
Expand Up @@ -144,6 +144,7 @@ define_keywords!(
COMMITTED,
COMPUTE,
CONDITION,
CONFLICT,
CONNECT,
CONNECTION,
CONSTRAINT,
Expand Down Expand Up @@ -200,6 +201,7 @@ define_keywords!(
DISCONNECT,
DISTINCT,
DISTRIBUTE,
DO,
DOUBLE,
DOW,
DOY,
Expand Down Expand Up @@ -370,6 +372,7 @@ define_keywords!(
NOSCAN,
NOSUPERUSER,
NOT,
NOTHING,
NTH_VALUE,
NTILE,
NULL,
Expand Down Expand Up @@ -464,6 +467,7 @@ define_keywords!(
RESTRICT,
RESULT,
RETURN,
RETURNING,
RETURNS,
REVOKE,
RIGHT,
Expand Down
50 changes: 45 additions & 5 deletions src/parser.rs
Expand Up @@ -4070,10 +4070,17 @@ impl<'a> Parser<'a> {
None
};

let returning = if self.parse_keyword(Keyword::RETURNING) {
Some(self.parse_comma_separated(Parser::parse_select_item)?)
} else {
None
};

Ok(Statement::Delete {
table_name,
using,
selection,
returning,
})
}

Expand Down Expand Up @@ -5191,12 +5198,38 @@ impl<'a> Parser<'a> {

let source = Box::new(self.parse_query()?);
let on = if self.parse_keyword(Keyword::ON) {
self.expect_keyword(Keyword::DUPLICATE)?;
self.expect_keyword(Keyword::KEY)?;
self.expect_keyword(Keyword::UPDATE)?;
let l = self.parse_comma_separated(Parser::parse_assignment)?;
if self.parse_keyword(Keyword::CONFLICT) {
let conflict_target =
self.parse_parenthesized_column_list(IsOptional::Optional)?;

Some(OnInsert::DuplicateKeyUpdate(l))
self.expect_keyword(Keyword::DO)?;
let action = if self.parse_keyword(Keyword::NOTHING) {
OnConflictAction::DoNothing
} else {
self.expect_keyword(Keyword::UPDATE)?;
self.expect_keyword(Keyword::SET)?;
let l = self.parse_comma_separated(Parser::parse_assignment)?;
OnConflictAction::DoUpdate(l)
};

Some(OnInsert::OnConflict(OnConflict {
conflict_target,
action,
}))
} else {
self.expect_keyword(Keyword::DUPLICATE)?;
self.expect_keyword(Keyword::KEY)?;
self.expect_keyword(Keyword::UPDATE)?;
let l = self.parse_comma_separated(Parser::parse_assignment)?;

Some(OnInsert::DuplicateKeyUpdate(l))
}
} else {
None
};

let returning = if self.parse_keyword(Keyword::RETURNING) {
Some(self.parse_comma_separated(Parser::parse_select_item)?)
} else {
None
};
Expand All @@ -5212,6 +5245,7 @@ impl<'a> Parser<'a> {
source,
table,
on,
returning,
})
}
}
Expand All @@ -5230,11 +5264,17 @@ impl<'a> Parser<'a> {
} else {
None
};
let returning = if self.parse_keyword(Keyword::RETURNING) {
Some(self.parse_comma_separated(Parser::parse_select_item)?)
} else {
None
};
Ok(Statement::Update {
table,
assignments,
from,
selection,
returning,
})
}

Expand Down
6 changes: 6 additions & 0 deletions tests/sqlparser_common.rs
Expand Up @@ -195,6 +195,7 @@ fn parse_update_with_table_alias() {
assignments,
from: _from,
selection,
returning,
} => {
assert_eq!(
TableWithJoins {
Expand Down Expand Up @@ -231,6 +232,7 @@ fn parse_update_with_table_alias() {
}),
selection
);
assert_eq!(None, returning);
}
_ => unreachable!(),
}
Expand Down Expand Up @@ -278,6 +280,7 @@ fn parse_where_delete_statement() {
table_name,
using,
selection,
returning,
} => {
assert_eq!(
TableFactor::Table {
Expand All @@ -298,6 +301,7 @@ fn parse_where_delete_statement() {
},
selection.unwrap(),
);
assert_eq!(None, returning);
}
_ => unreachable!(),
}
Expand All @@ -313,6 +317,7 @@ fn parse_where_delete_with_alias_statement() {
table_name,
using,
selection,
returning,
} => {
assert_eq!(
TableFactor::Table {
Expand Down Expand Up @@ -353,6 +358,7 @@ fn parse_where_delete_with_alias_statement() {
},
selection.unwrap(),
);
assert_eq!(None, returning);
}
_ => unreachable!(),
}
Expand Down
2 changes: 2 additions & 0 deletions tests/sqlparser_mysql.rs
Expand Up @@ -814,6 +814,7 @@ fn parse_update_with_joins() {
assignments,
from: _from,
selection,
returning,
} => {
assert_eq!(
TableWithJoins {
Expand Down Expand Up @@ -869,6 +870,7 @@ fn parse_update_with_joins() {
}),
selection
);
assert_eq!(None, returning);
}
_ => unreachable!(),
}
Expand Down

0 comments on commit 814367a

Please sign in to comment.