diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 29839b630..1bf8c9cd0 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -32,9 +32,9 @@ pub use self::ddl::{ pub use self::operator::{BinaryOperator, UnaryOperator}; pub use self::query::{ Cte, ExceptSelectItem, ExcludeSelectItem, Fetch, Join, JoinConstraint, JoinOperator, - LateralView, LockType, Offset, OffsetRows, OrderByExpr, Query, Select, SelectInto, SelectItem, - SetExpr, SetOperator, SetQuantifier, Table, TableAlias, TableFactor, TableWithJoins, Top, - Values, WildcardAdditionalOptions, With, + LateralView, LockClause, LockType, NonBlock, Offset, OffsetRows, OrderByExpr, Query, Select, + SelectInto, SelectItem, SetExpr, SetOperator, SetQuantifier, Table, TableAlias, TableFactor, + TableWithJoins, Top, Values, WildcardAdditionalOptions, With, }; pub use self::value::{escape_quoted_string, DateTimeField, TrimWhereField, Value}; diff --git a/src/ast/query.rs b/src/ast/query.rs index f813f44dd..9d23b1ae7 100644 --- a/src/ast/query.rs +++ b/src/ast/query.rs @@ -35,8 +35,8 @@ pub struct Query { pub offset: Option, /// `FETCH { FIRST | NEXT } [ PERCENT ] { ROW | ROWS } | { ONLY | WITH TIES }` pub fetch: Option, - /// `FOR { UPDATE | SHARE }` - pub lock: Option, + /// `FOR { UPDATE | SHARE } [ OF table_name ] [ SKIP LOCKED | NOWAIT ]` + pub locks: Vec, } impl fmt::Display for Query { @@ -57,8 +57,8 @@ impl fmt::Display for Query { if let Some(ref fetch) = self.fetch { write!(f, " {}", fetch)?; } - if let Some(ref lock) = self.lock { - write!(f, " {}", lock)?; + if !self.locks.is_empty() { + write!(f, " {}", display_separated(&self.locks, " "))?; } Ok(()) } @@ -833,6 +833,27 @@ impl fmt::Display for Fetch { } } +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct LockClause { + pub lock_type: LockType, + pub of: Option, + pub nonblock: Option, +} + +impl fmt::Display for LockClause { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "FOR {}", &self.lock_type)?; + if let Some(ref of) = self.of { + write!(f, " OF {}", of)?; + } + if let Some(ref nb) = self.nonblock { + write!(f, " {}", nb)?; + } + Ok(()) + } +} + #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum LockType { @@ -843,13 +864,30 @@ pub enum LockType { impl fmt::Display for LockType { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let select_lock = match self { - LockType::Share => "FOR SHARE", - LockType::Update => "FOR UPDATE", + LockType::Share => "SHARE", + LockType::Update => "UPDATE", }; write!(f, "{}", select_lock) } } +#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum NonBlock { + Nowait, + SkipLocked, +} + +impl fmt::Display for NonBlock { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let nonblock = match self { + NonBlock::Nowait => "NOWAIT", + NonBlock::SkipLocked => "SKIP LOCKED", + }; + write!(f, "{}", nonblock) + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Top { diff --git a/src/keywords.rs b/src/keywords.rs index e87062846..d41463e9e 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -331,6 +331,7 @@ define_keywords!( LOCALTIME, LOCALTIMESTAMP, LOCATION, + LOCKED, LOGIN, LOWER, MANAGEDLOCATION, @@ -382,6 +383,7 @@ define_keywords!( NOSUPERUSER, NOT, NOTHING, + NOWAIT, NTH_VALUE, NTILE, NULL, @@ -509,6 +511,7 @@ define_keywords!( SHARE, SHOW, SIMILAR, + SKIP, SMALLINT, SNAPSHOT, SOME, diff --git a/src/parser.rs b/src/parser.rs index 551961b8f..1e3d71b34 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -4505,11 +4505,10 @@ impl<'a> Parser<'a> { None }; - let lock = if self.parse_keyword(Keyword::FOR) { - Some(self.parse_lock()?) - } else { - None - }; + let mut locks = Vec::new(); + while self.parse_keyword(Keyword::FOR) { + locks.push(self.parse_lock()?); + } Ok(Query { with, @@ -4518,7 +4517,7 @@ impl<'a> Parser<'a> { limit, offset, fetch, - lock, + locks, }) } else { let insert = self.parse_insert()?; @@ -4530,7 +4529,7 @@ impl<'a> Parser<'a> { order_by: vec![], offset: None, fetch: None, - lock: None, + locks: vec![], }) } } @@ -5945,12 +5944,29 @@ impl<'a> Parser<'a> { } /// Parse a FOR UPDATE/FOR SHARE clause - pub fn parse_lock(&mut self) -> Result { - match self.expect_one_of_keywords(&[Keyword::UPDATE, Keyword::SHARE])? { - Keyword::UPDATE => Ok(LockType::Update), - Keyword::SHARE => Ok(LockType::Share), + pub fn parse_lock(&mut self) -> Result { + let lock_type = match self.expect_one_of_keywords(&[Keyword::UPDATE, Keyword::SHARE])? { + Keyword::UPDATE => LockType::Update, + Keyword::SHARE => LockType::Share, _ => unreachable!(), - } + }; + let of = if self.parse_keyword(Keyword::OF) { + Some(self.parse_object_name()?) + } else { + None + }; + let nonblock = if self.parse_keyword(Keyword::NOWAIT) { + Some(NonBlock::Nowait) + } else if self.parse_keywords(&[Keyword::SKIP, Keyword::LOCKED]) { + Some(NonBlock::SkipLocked) + } else { + None + }; + Ok(LockClause { + lock_type, + of, + nonblock, + }) } pub fn parse_values(&mut self) -> Result { diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 66557f6a6..e5ed0bb80 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -253,7 +253,7 @@ fn parse_update_set_from() { limit: None, offset: None, fetch: None, - lock: None, + locks: vec![], }), alias: Some(TableAlias { name: Ident::new("t2"), @@ -2296,7 +2296,7 @@ fn parse_create_table_as_table() { limit: None, offset: None, fetch: None, - lock: None, + locks: vec![], }); match verified_stmt(sql1) { @@ -2319,7 +2319,7 @@ fn parse_create_table_as_table() { limit: None, offset: None, fetch: None, - lock: None, + locks: vec![], }); match verified_stmt(sql2) { @@ -3456,7 +3456,7 @@ fn parse_interval_and_or_xor() { limit: None, offset: None, fetch: None, - lock: None, + locks: vec![], }))]; assert_eq!(actual_ast, expected_ast); @@ -5604,7 +5604,7 @@ fn parse_merge() { limit: None, offset: None, fetch: None, - lock: None, + locks: vec![], }), alias: Some(TableAlias { name: Ident { @@ -5729,12 +5729,106 @@ fn test_merge_with_delimiter() { #[test] fn test_lock() { let sql = "SELECT * FROM student WHERE id = '1' FOR UPDATE"; - let ast = verified_query(sql); - assert_eq!(ast.lock.unwrap(), LockType::Update); + let mut ast = verified_query(sql); + assert_eq!(ast.locks.len(), 1); + let lock = ast.locks.pop().unwrap(); + assert_eq!(lock.lock_type, LockType::Update); + assert!(lock.of.is_none()); + assert!(lock.nonblock.is_none()); let sql = "SELECT * FROM student WHERE id = '1' FOR SHARE"; - let ast = verified_query(sql); - assert_eq!(ast.lock.unwrap(), LockType::Share); + let mut ast = verified_query(sql); + assert_eq!(ast.locks.len(), 1); + let lock = ast.locks.pop().unwrap(); + assert_eq!(lock.lock_type, LockType::Share); + assert!(lock.of.is_none()); + assert!(lock.nonblock.is_none()); +} + +#[test] +fn test_lock_table() { + let sql = "SELECT * FROM student WHERE id = '1' FOR UPDATE OF school"; + let mut ast = verified_query(sql); + assert_eq!(ast.locks.len(), 1); + let lock = ast.locks.pop().unwrap(); + assert_eq!(lock.lock_type, LockType::Update); + assert_eq!( + lock.of.unwrap().0, + vec![Ident { + value: "school".to_string(), + quote_style: None + }] + ); + assert!(lock.nonblock.is_none()); + + let sql = "SELECT * FROM student WHERE id = '1' FOR SHARE OF school"; + let mut ast = verified_query(sql); + assert_eq!(ast.locks.len(), 1); + let lock = ast.locks.pop().unwrap(); + assert_eq!(lock.lock_type, LockType::Share); + assert_eq!( + lock.of.unwrap().0, + vec![Ident { + value: "school".to_string(), + quote_style: None + }] + ); + assert!(lock.nonblock.is_none()); + + let sql = "SELECT * FROM student WHERE id = '1' FOR SHARE OF school FOR UPDATE OF student"; + let mut ast = verified_query(sql); + assert_eq!(ast.locks.len(), 2); + let lock = ast.locks.remove(0); + assert_eq!(lock.lock_type, LockType::Share); + assert_eq!( + lock.of.unwrap().0, + vec![Ident { + value: "school".to_string(), + quote_style: None + }] + ); + assert!(lock.nonblock.is_none()); + let lock = ast.locks.remove(0); + assert_eq!(lock.lock_type, LockType::Update); + assert_eq!( + lock.of.unwrap().0, + vec![Ident { + value: "student".to_string(), + quote_style: None + }] + ); + assert!(lock.nonblock.is_none()); +} + +#[test] +fn test_lock_nonblock() { + let sql = "SELECT * FROM student WHERE id = '1' FOR UPDATE OF school SKIP LOCKED"; + let mut ast = verified_query(sql); + assert_eq!(ast.locks.len(), 1); + let lock = ast.locks.pop().unwrap(); + assert_eq!(lock.lock_type, LockType::Update); + assert_eq!( + lock.of.unwrap().0, + vec![Ident { + value: "school".to_string(), + quote_style: None + }] + ); + assert_eq!(lock.nonblock.unwrap(), NonBlock::SkipLocked); + + let sql = "SELECT * FROM student WHERE id = '1' FOR SHARE OF school NOWAIT"; + let mut ast = verified_query(sql); + assert_eq!(ast.locks.len(), 1); + let lock = ast.locks.pop().unwrap(); + assert_eq!(lock.lock_type, LockType::Share); + assert_eq!( + lock.of.unwrap().0, + vec![Ident { + value: "school".to_string(), + quote_style: None + }] + ); + assert_eq!(lock.nonblock.unwrap(), NonBlock::Nowait); } #[test] diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index 18f554f9e..67850b1b9 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -466,7 +466,7 @@ fn parse_quote_identifiers_2() { limit: None, offset: None, fetch: None, - lock: None, + locks: vec![], })) ); } @@ -500,7 +500,7 @@ fn parse_quote_identifiers_3() { limit: None, offset: None, fetch: None, - lock: None, + locks: vec![], })) ); } @@ -683,7 +683,7 @@ fn parse_simple_insert() { limit: None, offset: None, fetch: None, - lock: None, + locks: vec![], }), source ); @@ -741,7 +741,7 @@ fn parse_insert_with_on_duplicate_update() { limit: None, offset: None, fetch: None, - lock: None, + locks: vec![], }), source ); @@ -983,7 +983,7 @@ fn parse_substring_in_select() { limit: None, offset: None, fetch: None, - lock: None, + locks: vec![], }), query ); diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 53be5e465..d8144a716 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -1512,7 +1512,7 @@ fn parse_array_subquery_expr() { limit: None, offset: None, fetch: None, - lock: None, + locks: vec![], })), expr_from_projection(only(&select.projection)), );