Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support generic expressions in SET statement #574

Merged
merged 1 commit into from Aug 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 4 additions & 19 deletions src/ast/mod.rs
Expand Up @@ -533,8 +533,10 @@ impl fmt::Display for Expr {
Expr::UnaryOp { op, expr } => {
if op == &UnaryOperator::PGPostfixFactorial {
write!(f, "{}{}", expr, op)
} else {
} else if op == &UnaryOperator::Not {
write!(f, "{} {}", op, expr)
} else {
write!(f, "{}{}", op, expr)
}
}
Expr::Cast { expr, data_type } => write!(f, "CAST({} AS {})", expr, data_type),
Expand Down Expand Up @@ -1088,7 +1090,7 @@ pub enum Statement {
local: bool,
hivevar: bool,
variable: ObjectName,
value: Vec<SetVariableValue>,
value: Vec<Expr>,
},
/// SET NAMES 'charset_name' [COLLATE 'collation_name']
///
Expand Down Expand Up @@ -2733,23 +2735,6 @@ impl fmt::Display for ShowStatementFilter {
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum SetVariableValue {
Ident(Ident),
Literal(Value),
}

impl fmt::Display for SetVariableValue {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
use SetVariableValue::*;
match self {
Ident(ident) => write!(f, "{}", ident),
Literal(literal) => write!(f, "{}", literal),
}
}
}

/// Sqlite specific syntax
///
/// https://sqlite.org/lang_conflict.html
Expand Down
1 change: 1 addition & 0 deletions src/dialect/mysql.rs
Expand Up @@ -24,6 +24,7 @@ impl Dialect for MySqlDialect {
|| ('A'..='Z').contains(&ch)
|| ch == '_'
|| ch == '$'
|| ch == '@'
|| ('\u{0080}'..='\u{ffff}').contains(&ch)
}

Expand Down
20 changes: 5 additions & 15 deletions src/parser.rs
Expand Up @@ -3737,22 +3737,12 @@ impl<'a> Parser<'a> {
} else if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
let mut values = vec![];
loop {
let token = self.peek_token();
let value = match (self.parse_value(), token) {
(Ok(value), _) => SetVariableValue::Literal(value),
(Err(_), Token::Word(ident)) => SetVariableValue::Ident(ident.to_ident()),
(Err(_), Token::Minus) => {
let next_token = self.next_token();
match next_token {
Token::Word(ident) => SetVariableValue::Ident(Ident {
quote_style: ident.quote_style,
value: format!("-{}", ident.value),
}),
_ => self.expected("word", next_token)?,
}
}
(Err(_), unexpected) => self.expected("variable value", unexpected)?,
let value = if let Ok(expr) = self.parse_expr() {
expr
} else {
self.expected("variable value", self.peek_token())?
};

values.push(value);
if self.consume_token(&Token::Comma) {
continue;
Expand Down
8 changes: 4 additions & 4 deletions tests/sqlparser_common.rs
Expand Up @@ -580,7 +580,7 @@ fn parse_select_count_wildcard() {

#[test]
fn parse_select_count_distinct() {
let sql = "SELECT COUNT(DISTINCT + x) FROM customer";
let sql = "SELECT COUNT(DISTINCT +x) FROM customer";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I was confused about this for a while, but then I see that +x is actually a unary op, so it makes sense to display it as +x rather than + x 👍

let select = verified_only_select(sql);
assert_eq!(
&Expr::Function(Function {
Expand All @@ -597,8 +597,8 @@ fn parse_select_count_distinct() {
);

one_statement_parses_to(
"SELECT COUNT(ALL + x) FROM customer",
"SELECT COUNT(+ x) FROM customer",
"SELECT COUNT(ALL +x) FROM customer",
"SELECT COUNT(+x) FROM customer",
);

let sql = "SELECT COUNT(ALL DISTINCT + x) FROM customer";
Expand Down Expand Up @@ -754,7 +754,7 @@ fn parse_compound_expr_2() {
#[test]
fn parse_unary_math() {
use self::Expr::*;
let sql = "- a + - b";
let sql = "-a + -b";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

assert_eq!(
BinaryOp {
left: Box::new(UnaryOp {
Expand Down
9 changes: 6 additions & 3 deletions tests/sqlparser_hive.rs
Expand Up @@ -15,7 +15,7 @@
//! Test SQL syntax specific to Hive. The parser based on the generic dialect
//! is also tested (on the inputs it can handle).

use sqlparser::ast::{CreateFunctionUsing, Ident, ObjectName, SetVariableValue, Statement};
use sqlparser::ast::{CreateFunctionUsing, Expr, Ident, ObjectName, Statement, UnaryOperator};
use sqlparser::dialect::{GenericDialect, HiveDialect};
use sqlparser::parser::ParserError;
use sqlparser::test_utils::*;
Expand Down Expand Up @@ -220,14 +220,17 @@ fn set_statement_with_minus() {
Ident::new("java"),
Ident::new("opts")
]),
value: vec![SetVariableValue::Ident("-Xmx4g".into())],
value: vec![Expr::UnaryOp {
op: UnaryOperator::Minus,
expr: Box::new(Expr::Identifier(Ident::new("Xmx4g")))
}],
}
);

assert_eq!(
hive().parse_sql_statements("SET hive.tez.java.opts = -"),
Err(ParserError::ParserError(
"Expected word, found: EOF".to_string()
"Expected variable value, found: EOF".to_string()
))
)
}
Expand Down
20 changes: 20 additions & 0 deletions tests/sqlparser_mysql.rs
Expand Up @@ -251,6 +251,26 @@ fn parse_use() {
);
}

#[test]
fn parse_set_variables() {
mysql_and_generic().verified_stmt("SET sql_mode = CONCAT(@@sql_mode, ',STRICT_TRANS_TABLES')");
assert_eq!(
mysql_and_generic().verified_stmt("SET LOCAL autocommit = 1"),
Statement::SetVariable {
local: true,
hivevar: false,
variable: ObjectName(vec!["autocommit".into()]),
value: vec![Expr::Value(Value::Number(
#[cfg(not(feature = "bigdecimal"))]
"1".to_string(),
#[cfg(feature = "bigdecimal")]
bigdecimal::BigDecimal::from(1),
false
))],
}
);
}

#[test]
fn parse_create_table_auto_increment() {
let sql = "CREATE TABLE foo (bar INT PRIMARY KEY AUTO_INCREMENT)";
Expand Down
34 changes: 23 additions & 11 deletions tests/sqlparser_postgres.rs
Expand Up @@ -18,7 +18,6 @@
mod test_utils;
use test_utils::*;

use sqlparser::ast::Value::Boolean;
use sqlparser::ast::*;
use sqlparser::dialect::{GenericDialect, PostgreSqlDialect};
use sqlparser::parser::ParserError;
Expand Down Expand Up @@ -782,7 +781,10 @@ fn parse_set() {
local: false,
hivevar: false,
variable: ObjectName(vec![Ident::new("a")]),
value: vec![SetVariableValue::Ident("b".into())],
value: vec![Expr::Identifier(Ident {
value: "b".into(),
quote_style: None
})],
}
);

Expand All @@ -793,9 +795,7 @@ fn parse_set() {
local: false,
hivevar: false,
variable: ObjectName(vec![Ident::new("a")]),
value: vec![SetVariableValue::Literal(Value::SingleQuotedString(
"b".into()
))],
value: vec![Expr::Value(Value::SingleQuotedString("b".into()))],
}
);

Expand All @@ -806,7 +806,13 @@ fn parse_set() {
local: false,
hivevar: false,
variable: ObjectName(vec![Ident::new("a")]),
value: vec![SetVariableValue::Literal(number("0"))],
value: vec![Expr::Value(Value::Number(
#[cfg(not(feature = "bigdecimal"))]
"0".to_string(),
#[cfg(feature = "bigdecimal")]
bigdecimal::BigDecimal::from(0),
false,
))],
}
);

Expand All @@ -817,7 +823,10 @@ fn parse_set() {
local: false,
hivevar: false,
variable: ObjectName(vec![Ident::new("a")]),
value: vec![SetVariableValue::Ident("DEFAULT".into())],
value: vec![Expr::Identifier(Ident {
value: "DEFAULT".into(),
quote_style: None
})],
}
);

Expand All @@ -828,7 +837,7 @@ fn parse_set() {
local: true,
hivevar: false,
variable: ObjectName(vec![Ident::new("a")]),
value: vec![SetVariableValue::Ident("b".into())],
value: vec![Expr::Identifier("b".into())],
}
);

Expand All @@ -839,7 +848,10 @@ fn parse_set() {
local: false,
hivevar: false,
variable: ObjectName(vec![Ident::new("a"), Ident::new("b"), Ident::new("c")]),
value: vec![SetVariableValue::Ident("b".into())],
value: vec![Expr::Identifier(Ident {
value: "b".into(),
quote_style: None
})],
}
);

Expand All @@ -859,7 +871,7 @@ fn parse_set() {
Ident::new("reducer"),
Ident::new("parallelism")
]),
value: vec![SetVariableValue::Literal(Boolean(false))],
value: vec![Expr::Value(Value::Boolean(false))],
}
);

Expand Down Expand Up @@ -1107,7 +1119,7 @@ fn parse_pg_unary_ops() {
];

for (str_op, op) in pg_unary_ops {
let select = pg().verified_only_select(&format!("SELECT {} a", &str_op));
let select = pg().verified_only_select(&format!("SELECT {}a", &str_op));
assert_eq!(
SelectItem::UnnamedExpr(Expr::UnaryOp {
op: op.clone(),
Expand Down