diff --git a/src/ast/mod.rs b/src/ast/mod.rs index b91aae43c..d551737d7 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -533,8 +533,10 @@ impl fmt::Display for Expr { Expr::UnaryOp { op, expr } => { if op == &UnaryOperator::PGPostfixFactorial { write!(f, "{}{}", expr, op) - } else { + } else if op == &UnaryOperator::Not { write!(f, "{} {}", op, expr) + } else { + write!(f, "{}{}", op, expr) } } Expr::Cast { expr, data_type } => write!(f, "CAST({} AS {})", expr, data_type), @@ -1088,7 +1090,7 @@ pub enum Statement { local: bool, hivevar: bool, variable: ObjectName, - value: Vec, + value: Vec, }, /// SET NAMES 'charset_name' [COLLATE 'collation_name'] /// @@ -2733,23 +2735,6 @@ impl fmt::Display for ShowStatementFilter { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub enum SetVariableValue { - Ident(Ident), - Literal(Value), -} - -impl fmt::Display for SetVariableValue { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - use SetVariableValue::*; - match self { - Ident(ident) => write!(f, "{}", ident), - Literal(literal) => write!(f, "{}", literal), - } - } -} - /// Sqlite specific syntax /// /// https://sqlite.org/lang_conflict.html diff --git a/src/dialect/mysql.rs b/src/dialect/mysql.rs index 6581195b8..d6095262c 100644 --- a/src/dialect/mysql.rs +++ b/src/dialect/mysql.rs @@ -24,6 +24,7 @@ impl Dialect for MySqlDialect { || ('A'..='Z').contains(&ch) || ch == '_' || ch == '$' + || ch == '@' || ('\u{0080}'..='\u{ffff}').contains(&ch) } diff --git a/src/parser.rs b/src/parser.rs index 3a90b3ccb..596cf888c 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -3737,22 +3737,12 @@ impl<'a> Parser<'a> { } else if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) { let mut values = vec![]; loop { - let token = self.peek_token(); - let value = match (self.parse_value(), token) { - (Ok(value), _) => SetVariableValue::Literal(value), - (Err(_), Token::Word(ident)) => SetVariableValue::Ident(ident.to_ident()), - (Err(_), Token::Minus) => { - let next_token = self.next_token(); - match next_token { - Token::Word(ident) => SetVariableValue::Ident(Ident { - quote_style: ident.quote_style, - value: format!("-{}", ident.value), - }), - _ => self.expected("word", next_token)?, - } - } - (Err(_), unexpected) => self.expected("variable value", unexpected)?, + let value = if let Ok(expr) = self.parse_expr() { + expr + } else { + self.expected("variable value", self.peek_token())? }; + values.push(value); if self.consume_token(&Token::Comma) { continue; diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index c87b0c5e8..4e3887f0e 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -580,7 +580,7 @@ fn parse_select_count_wildcard() { #[test] fn parse_select_count_distinct() { - let sql = "SELECT COUNT(DISTINCT + x) FROM customer"; + let sql = "SELECT COUNT(DISTINCT +x) FROM customer"; let select = verified_only_select(sql); assert_eq!( &Expr::Function(Function { @@ -597,8 +597,8 @@ fn parse_select_count_distinct() { ); one_statement_parses_to( - "SELECT COUNT(ALL + x) FROM customer", - "SELECT COUNT(+ x) FROM customer", + "SELECT COUNT(ALL +x) FROM customer", + "SELECT COUNT(+x) FROM customer", ); let sql = "SELECT COUNT(ALL DISTINCT + x) FROM customer"; @@ -754,7 +754,7 @@ fn parse_compound_expr_2() { #[test] fn parse_unary_math() { use self::Expr::*; - let sql = "- a + - b"; + let sql = "-a + -b"; assert_eq!( BinaryOp { left: Box::new(UnaryOp { diff --git a/tests/sqlparser_hive.rs b/tests/sqlparser_hive.rs index fa2486120..4223ad5fa 100644 --- a/tests/sqlparser_hive.rs +++ b/tests/sqlparser_hive.rs @@ -15,7 +15,7 @@ //! Test SQL syntax specific to Hive. The parser based on the generic dialect //! is also tested (on the inputs it can handle). -use sqlparser::ast::{CreateFunctionUsing, Ident, ObjectName, SetVariableValue, Statement}; +use sqlparser::ast::{CreateFunctionUsing, Expr, Ident, ObjectName, Statement, UnaryOperator}; use sqlparser::dialect::{GenericDialect, HiveDialect}; use sqlparser::parser::ParserError; use sqlparser::test_utils::*; @@ -220,14 +220,17 @@ fn set_statement_with_minus() { Ident::new("java"), Ident::new("opts") ]), - value: vec![SetVariableValue::Ident("-Xmx4g".into())], + value: vec![Expr::UnaryOp { + op: UnaryOperator::Minus, + expr: Box::new(Expr::Identifier(Ident::new("Xmx4g"))) + }], } ); assert_eq!( hive().parse_sql_statements("SET hive.tez.java.opts = -"), Err(ParserError::ParserError( - "Expected word, found: EOF".to_string() + "Expected variable value, found: EOF".to_string() )) ) } diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index c1cfa2876..f46d5d23e 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -251,6 +251,26 @@ fn parse_use() { ); } +#[test] +fn parse_set_variables() { + mysql_and_generic().verified_stmt("SET sql_mode = CONCAT(@@sql_mode, ',STRICT_TRANS_TABLES')"); + assert_eq!( + mysql_and_generic().verified_stmt("SET LOCAL autocommit = 1"), + Statement::SetVariable { + local: true, + hivevar: false, + variable: ObjectName(vec!["autocommit".into()]), + value: vec![Expr::Value(Value::Number( + #[cfg(not(feature = "bigdecimal"))] + "1".to_string(), + #[cfg(feature = "bigdecimal")] + bigdecimal::BigDecimal::from(1), + false + ))], + } + ); +} + #[test] fn parse_create_table_auto_increment() { let sql = "CREATE TABLE foo (bar INT PRIMARY KEY AUTO_INCREMENT)"; diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index e3c9332dc..3aaabc9e3 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -18,7 +18,6 @@ mod test_utils; use test_utils::*; -use sqlparser::ast::Value::Boolean; use sqlparser::ast::*; use sqlparser::dialect::{GenericDialect, PostgreSqlDialect}; use sqlparser::parser::ParserError; @@ -782,7 +781,10 @@ fn parse_set() { local: false, hivevar: false, variable: ObjectName(vec![Ident::new("a")]), - value: vec![SetVariableValue::Ident("b".into())], + value: vec![Expr::Identifier(Ident { + value: "b".into(), + quote_style: None + })], } ); @@ -793,9 +795,7 @@ fn parse_set() { local: false, hivevar: false, variable: ObjectName(vec![Ident::new("a")]), - value: vec![SetVariableValue::Literal(Value::SingleQuotedString( - "b".into() - ))], + value: vec![Expr::Value(Value::SingleQuotedString("b".into()))], } ); @@ -806,7 +806,13 @@ fn parse_set() { local: false, hivevar: false, variable: ObjectName(vec![Ident::new("a")]), - value: vec![SetVariableValue::Literal(number("0"))], + value: vec![Expr::Value(Value::Number( + #[cfg(not(feature = "bigdecimal"))] + "0".to_string(), + #[cfg(feature = "bigdecimal")] + bigdecimal::BigDecimal::from(0), + false, + ))], } ); @@ -817,7 +823,10 @@ fn parse_set() { local: false, hivevar: false, variable: ObjectName(vec![Ident::new("a")]), - value: vec![SetVariableValue::Ident("DEFAULT".into())], + value: vec![Expr::Identifier(Ident { + value: "DEFAULT".into(), + quote_style: None + })], } ); @@ -828,7 +837,7 @@ fn parse_set() { local: true, hivevar: false, variable: ObjectName(vec![Ident::new("a")]), - value: vec![SetVariableValue::Ident("b".into())], + value: vec![Expr::Identifier("b".into())], } ); @@ -839,7 +848,10 @@ fn parse_set() { local: false, hivevar: false, variable: ObjectName(vec![Ident::new("a"), Ident::new("b"), Ident::new("c")]), - value: vec![SetVariableValue::Ident("b".into())], + value: vec![Expr::Identifier(Ident { + value: "b".into(), + quote_style: None + })], } ); @@ -859,7 +871,7 @@ fn parse_set() { Ident::new("reducer"), Ident::new("parallelism") ]), - value: vec![SetVariableValue::Literal(Boolean(false))], + value: vec![Expr::Value(Value::Boolean(false))], } ); @@ -1107,7 +1119,7 @@ fn parse_pg_unary_ops() { ]; for (str_op, op) in pg_unary_ops { - let select = pg().verified_only_select(&format!("SELECT {} a", &str_op)); + let select = pg().verified_only_select(&format!("SELECT {}a", &str_op)); assert_eq!( SelectItem::UnnamedExpr(Expr::UnaryOp { op: op.clone(),