From 72559e9b6298f39c9ab2a084b5e1cd11d3fb6c6d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 19 Aug 2022 05:44:14 -0600 Subject: [PATCH] Add ability for dialects to override prefix, infix, and statement parsing (#581) --- src/dialect/mod.rs | 27 ++++++ src/dialect/postgresql.rs | 41 +++++++++ src/dialect/sqlite.rs | 12 +++ src/parser.rs | 58 +++++-------- tests/sqlparser_custom_dialect.rs | 138 ++++++++++++++++++++++++++++++ 5 files changed, 239 insertions(+), 37 deletions(-) create mode 100644 tests/sqlparser_custom_dialect.rs diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 63821dd74..46e8dda2c 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -22,6 +22,7 @@ mod redshift; mod snowflake; mod sqlite; +use crate::ast::{Expr, Statement}; use core::any::{Any, TypeId}; use core::fmt::Debug; use core::iter::Peekable; @@ -39,6 +40,7 @@ pub use self::redshift::RedshiftSqlDialect; pub use self::snowflake::SnowflakeDialect; pub use self::sqlite::SQLiteDialect; pub use crate::keywords; +use crate::parser::{Parser, ParserError}; /// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates /// to `true` if `parser.dialect` is one of the `Dialect`s specified. @@ -65,6 +67,31 @@ pub trait Dialect: Debug + Any { fn is_identifier_start(&self, ch: char) -> bool; /// Determine if a character is a valid unquoted identifier character fn is_identifier_part(&self, ch: char) -> bool; + /// Dialect-specific prefix parser override + fn parse_prefix(&self, _parser: &mut Parser) -> Option> { + // return None to fall back to the default behavior + None + } + /// Dialect-specific infix parser override + fn parse_infix( + &self, + _parser: &mut Parser, + _expr: &Expr, + _precendence: u8, + ) -> Option> { + // return None to fall back to the default behavior + None + } + /// Dialect-specific precedence override + fn get_next_precedence(&self, _parser: &Parser) -> Option> { + // return None to fall back to the default behavior + None + } + /// Dialect-specific statement parser override + fn parse_statement(&self, _parser: &mut Parser) -> Option> { + // return None to fall back to the default behavior + None + } } impl dyn Dialect { diff --git a/src/dialect/postgresql.rs b/src/dialect/postgresql.rs index 0c2eb99f0..04d64b9bf 100644 --- a/src/dialect/postgresql.rs +++ b/src/dialect/postgresql.rs @@ -10,7 +10,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +use crate::ast::{CommentObject, Statement}; use crate::dialect::Dialect; +use crate::keywords::Keyword; +use crate::parser::{Parser, ParserError}; +use crate::tokenizer::Token; #[derive(Debug)] pub struct PostgreSqlDialect {} @@ -30,4 +34,41 @@ impl Dialect for PostgreSqlDialect { || ch == '$' || ch == '_' } + + fn parse_statement(&self, parser: &mut Parser) -> Option> { + if parser.parse_keyword(Keyword::COMMENT) { + Some(parse_comment(parser)) + } else { + None + } + } +} + +pub fn parse_comment(parser: &mut Parser) -> Result { + parser.expect_keyword(Keyword::ON)?; + let token = parser.next_token(); + + let (object_type, object_name) = match token { + Token::Word(w) if w.keyword == Keyword::COLUMN => { + let object_name = parser.parse_object_name()?; + (CommentObject::Column, object_name) + } + Token::Word(w) if w.keyword == Keyword::TABLE => { + let object_name = parser.parse_object_name()?; + (CommentObject::Table, object_name) + } + _ => parser.expected("comment object_type", token)?, + }; + + parser.expect_keyword(Keyword::IS)?; + let comment = if parser.parse_keyword(Keyword::NULL) { + None + } else { + Some(parser.parse_literal_string()?) + }; + Ok(Statement::Comment { + object_type, + object_name, + comment, + }) } diff --git a/src/dialect/sqlite.rs b/src/dialect/sqlite.rs index 4ce2f834b..64d7f62fd 100644 --- a/src/dialect/sqlite.rs +++ b/src/dialect/sqlite.rs @@ -10,7 +10,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +use crate::ast::Statement; use crate::dialect::Dialect; +use crate::keywords::Keyword; +use crate::parser::{Parser, ParserError}; #[derive(Debug)] pub struct SQLiteDialect {} @@ -35,4 +38,13 @@ impl Dialect for SQLiteDialect { fn is_identifier_part(&self, ch: char) -> bool { self.is_identifier_start(ch) || ('0'..='9').contains(&ch) } + + fn parse_statement(&self, parser: &mut Parser) -> Option> { + if parser.parse_keyword(Keyword::REPLACE) { + parser.prev_token(); + Some(parser.parse_insert()) + } else { + None + } + } } diff --git a/src/parser.rs b/src/parser.rs index 715843e9d..4cfbdd23b 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -152,6 +152,11 @@ impl<'a> Parser<'a> { /// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.), /// stopping before the statement separator, if any. pub fn parse_statement(&mut self) -> Result { + // allow the dialect to override statement parsing + if let Some(statement) = self.dialect.parse_statement(self) { + return statement; + } + match self.next_token() { Token::Word(w) => match w.keyword { Keyword::KILL => Ok(self.parse_kill()?), @@ -195,13 +200,6 @@ impl<'a> Parser<'a> { Keyword::EXECUTE => Ok(self.parse_execute()?), Keyword::PREPARE => Ok(self.parse_prepare()?), Keyword::MERGE => Ok(self.parse_merge()?), - Keyword::REPLACE if dialect_of!(self is SQLiteDialect ) => { - self.prev_token(); - Ok(self.parse_insert()?) - } - Keyword::COMMENT if dialect_of!(self is PostgreSqlDialect) => { - Ok(self.parse_comment()?) - } _ => self.expected("an SQL statement", Token::Word(w)), }, Token::LParen => { @@ -381,6 +379,11 @@ impl<'a> Parser<'a> { /// Parse an expression prefix pub fn parse_prefix(&mut self) -> Result { + // allow the dialect to override prefix parsing + if let Some(prefix) = self.dialect.parse_prefix(self) { + return prefix; + } + // PostgreSQL allows any string literal to be preceded by a type name, indicating that the // string literal represents a literal of that type. Some examples: // @@ -1164,6 +1167,11 @@ impl<'a> Parser<'a> { /// Parse an operator following an expression pub fn parse_infix(&mut self, expr: Expr, precedence: u8) -> Result { + // allow the dialect to override infix parsing + if let Some(infix) = self.dialect.parse_infix(self, &expr, precedence) { + return infix; + } + let tok = self.next_token(); let regular_binary_operator = match &tok { @@ -1491,6 +1499,11 @@ impl<'a> Parser<'a> { /// Get the precedence of the next token pub fn get_next_precedence(&self) -> Result { + // allow the dialect to override precedence logic + if let Some(precedence) = self.dialect.get_next_precedence(self) { + return precedence; + } + let token = self.peek_token(); debug!("get_next_precedence() {:?}", token); let token_0 = self.peek_nth_token(0); @@ -1618,7 +1631,7 @@ impl<'a> Parser<'a> { } /// Report unexpected token - fn expected(&self, expected: &str, found: Token) -> Result { + pub fn expected(&self, expected: &str, found: Token) -> Result { parser_err!(format!("Expected {}, found: {}", expected, found)) } @@ -4735,35 +4748,6 @@ impl<'a> Parser<'a> { }) } - pub fn parse_comment(&mut self) -> Result { - self.expect_keyword(Keyword::ON)?; - let token = self.next_token(); - - let (object_type, object_name) = match token { - Token::Word(w) if w.keyword == Keyword::COLUMN => { - let object_name = self.parse_object_name()?; - (CommentObject::Column, object_name) - } - Token::Word(w) if w.keyword == Keyword::TABLE => { - let object_name = self.parse_object_name()?; - (CommentObject::Table, object_name) - } - _ => self.expected("comment object_type", token)?, - }; - - self.expect_keyword(Keyword::IS)?; - let comment = if self.parse_keyword(Keyword::NULL) { - None - } else { - Some(self.parse_literal_string()?) - }; - Ok(Statement::Comment { - object_type, - object_name, - comment, - }) - } - pub fn parse_merge_clauses(&mut self) -> Result, ParserError> { let mut clauses: Vec = vec![]; loop { diff --git a/tests/sqlparser_custom_dialect.rs b/tests/sqlparser_custom_dialect.rs new file mode 100644 index 000000000..c0fe4c1dd --- /dev/null +++ b/tests/sqlparser_custom_dialect.rs @@ -0,0 +1,138 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Test the ability for dialects to override parsing + +use sqlparser::{ + ast::{BinaryOperator, Expr, Statement, Value}, + dialect::Dialect, + keywords::Keyword, + parser::{Parser, ParserError}, + tokenizer::Token, +}; + +#[test] +fn custom_prefix_parser() -> Result<(), ParserError> { + #[derive(Debug)] + struct MyDialect {} + + impl Dialect for MyDialect { + fn is_identifier_start(&self, ch: char) -> bool { + is_identifier_start(ch) + } + + fn is_identifier_part(&self, ch: char) -> bool { + is_identifier_part(ch) + } + + fn parse_prefix(&self, parser: &mut Parser) -> Option> { + if parser.consume_token(&Token::Number("1".to_string(), false)) { + Some(Ok(Expr::Value(Value::Null))) + } else { + None + } + } + } + + let dialect = MyDialect {}; + let sql = "SELECT 1 + 2"; + let ast = Parser::parse_sql(&dialect, sql)?; + let query = &ast[0]; + assert_eq!("SELECT NULL + 2", &format!("{}", query)); + Ok(()) +} + +#[test] +fn custom_infix_parser() -> Result<(), ParserError> { + #[derive(Debug)] + struct MyDialect {} + + impl Dialect for MyDialect { + fn is_identifier_start(&self, ch: char) -> bool { + is_identifier_start(ch) + } + + fn is_identifier_part(&self, ch: char) -> bool { + is_identifier_part(ch) + } + + fn parse_infix( + &self, + parser: &mut Parser, + expr: &Expr, + _precendence: u8, + ) -> Option> { + if parser.consume_token(&Token::Plus) { + Some(Ok(Expr::BinaryOp { + left: Box::new(expr.clone()), + op: BinaryOperator::Multiply, // translate Plus to Multiply + right: Box::new(parser.parse_expr().unwrap()), + })) + } else { + None + } + } + } + + let dialect = MyDialect {}; + let sql = "SELECT 1 + 2"; + let ast = Parser::parse_sql(&dialect, sql)?; + let query = &ast[0]; + assert_eq!("SELECT 1 * 2", &format!("{}", query)); + Ok(()) +} + +#[test] +fn custom_statement_parser() -> Result<(), ParserError> { + #[derive(Debug)] + struct MyDialect {} + + impl Dialect for MyDialect { + fn is_identifier_start(&self, ch: char) -> bool { + is_identifier_start(ch) + } + + fn is_identifier_part(&self, ch: char) -> bool { + is_identifier_part(ch) + } + + fn parse_statement(&self, parser: &mut Parser) -> Option> { + if parser.parse_keyword(Keyword::SELECT) { + for _ in 0..3 { + let _ = parser.next_token(); + } + Some(Ok(Statement::Commit { chain: false })) + } else { + None + } + } + } + + let dialect = MyDialect {}; + let sql = "SELECT 1 + 2"; + let ast = Parser::parse_sql(&dialect, sql)?; + let query = &ast[0]; + assert_eq!("COMMIT", &format!("{}", query)); + Ok(()) +} + +fn is_identifier_start(ch: char) -> bool { + ('a'..='z').contains(&ch) || ('A'..='Z').contains(&ch) || ch == '_' +} + +fn is_identifier_part(ch: char) -> bool { + ('a'..='z').contains(&ch) + || ('A'..='Z').contains(&ch) + || ('0'..='9').contains(&ch) + || ch == '$' + || ch == '_' +}