New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ability for dialects to override prefix, infix, and statement parsing #581
Changes from 15 commits
41c9969
af39f1f
cec3108
eadd846
b299687
1445886
3c65284
28568de
ab5e683
f70ff47
f7b57fa
9dd2ea4
c2e096e
f518c30
8640d1b
ed14940
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -152,6 +152,11 @@ impl<'a> Parser<'a> { | |
/// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.), | ||
/// stopping before the statement separator, if any. | ||
pub fn parse_statement(&mut self) -> Result<Statement, ParserError> { | ||
// allow the dialect to override statement parsing | ||
if let Some(statement) = self.dialect.parse_statement(self) { | ||
return statement; | ||
} | ||
|
||
match self.next_token() { | ||
Token::Word(w) => match w.keyword { | ||
Keyword::KILL => Ok(self.parse_kill()?), | ||
|
@@ -195,13 +200,6 @@ impl<'a> Parser<'a> { | |
Keyword::EXECUTE => Ok(self.parse_execute()?), | ||
Keyword::PREPARE => Ok(self.parse_prepare()?), | ||
Keyword::MERGE => Ok(self.parse_merge()?), | ||
Keyword::REPLACE if dialect_of!(self is SQLiteDialect ) => { | ||
self.prev_token(); | ||
Ok(self.parse_insert()?) | ||
} | ||
Keyword::COMMENT if dialect_of!(self is PostgreSqlDialect) => { | ||
Ok(self.parse_comment()?) | ||
} | ||
_ => self.expected("an SQL statement", Token::Word(w)), | ||
}, | ||
Token::LParen => { | ||
|
@@ -381,6 +379,11 @@ impl<'a> Parser<'a> { | |
|
||
/// Parse an expression prefix | ||
pub fn parse_prefix(&mut self) -> Result<Expr, ParserError> { | ||
// allow the dialect to override prefix parsing | ||
if let Some(prefix) = self.dialect.parse_prefix(self) { | ||
return prefix; | ||
} | ||
|
||
// PostgreSQL allows any string literal to be preceded by a type name, indicating that the | ||
// string literal represents a literal of that type. Some examples: | ||
// | ||
|
@@ -1164,6 +1167,11 @@ impl<'a> Parser<'a> { | |
|
||
/// Parse an operator following an expression | ||
pub fn parse_infix(&mut self, expr: Expr, precedence: u8) -> Result<Expr, ParserError> { | ||
// allow the dialect to override infix parsing | ||
if let Some(infix) = self.dialect.parse_infix(self, &expr, precedence) { | ||
return infix; | ||
} | ||
|
||
let tok = self.next_token(); | ||
|
||
let regular_binary_operator = match &tok { | ||
|
@@ -1477,6 +1485,11 @@ impl<'a> Parser<'a> { | |
|
||
/// Get the precedence of the next token | ||
pub fn get_next_precedence(&self) -> Result<u8, ParserError> { | ||
// allow the dialect to override precedence logic | ||
if let Some(precedence) = self.dialect.get_next_precedence(self) { | ||
return precedence; | ||
} | ||
|
||
let token = self.peek_token(); | ||
debug!("get_next_precedence() {:?}", token); | ||
let token_0 = self.peek_nth_token(0); | ||
|
@@ -1604,7 +1617,7 @@ impl<'a> Parser<'a> { | |
} | ||
|
||
/// Report unexpected token | ||
fn expected<T>(&self, expected: &str, found: Token) -> Result<T, ParserError> { | ||
pub fn expected<T>(&self, expected: &str, found: Token) -> Result<T, ParserError> { | ||
parser_err!(format!("Expected {}, found: {}", expected, found)) | ||
} | ||
|
||
|
@@ -4731,35 +4744,6 @@ impl<'a> Parser<'a> { | |
}) | ||
} | ||
|
||
pub fn parse_comment(&mut self) -> Result<Statement, ParserError> { | ||
self.expect_keyword(Keyword::ON)?; | ||
let token = self.next_token(); | ||
|
||
let (object_type, object_name) = match token { | ||
Token::Word(w) if w.keyword == Keyword::COLUMN => { | ||
let object_name = self.parse_object_name()?; | ||
(CommentObject::Column, object_name) | ||
} | ||
Token::Word(w) if w.keyword == Keyword::TABLE => { | ||
let object_name = self.parse_object_name()?; | ||
(CommentObject::Table, object_name) | ||
} | ||
_ => self.expected("comment object_type", token)?, | ||
}; | ||
|
||
self.expect_keyword(Keyword::IS)?; | ||
let comment = if self.parse_keyword(Keyword::NULL) { | ||
None | ||
} else { | ||
Some(self.parse_literal_string()?) | ||
}; | ||
Ok(Statement::Comment { | ||
object_type, | ||
object_name, | ||
comment, | ||
}) | ||
} | ||
|
||
pub fn parse_merge_clauses(&mut self) -> Result<Vec<MergeClause>, ParserError> { | ||
let mut clauses: Vec<MergeClause> = vec![]; | ||
loop { | ||
|
@@ -4926,4 +4910,124 @@ mod tests { | |
assert_eq!(ast.to_string(), sql.to_string()); | ||
}); | ||
} | ||
|
||
#[test] | ||
fn custom_prefix_parser() -> Result<(), ParserError> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I recommend putting these tests into Perhaps There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good idea. I have moved them. |
||
#[derive(Debug)] | ||
struct MyDialect {} | ||
|
||
impl Dialect for MyDialect { | ||
fn is_identifier_start(&self, ch: char) -> bool { | ||
is_identifier_start(ch) | ||
} | ||
|
||
fn is_identifier_part(&self, ch: char) -> bool { | ||
is_identifier_part(ch) | ||
} | ||
|
||
fn parse_prefix(&self, parser: &mut Parser) -> Option<Result<Expr, ParserError>> { | ||
if parser.consume_token(&Token::Number("1".to_string(), false)) { | ||
Some(Ok(Expr::Value(Value::Null))) | ||
} else { | ||
None | ||
} | ||
} | ||
} | ||
|
||
let dialect = MyDialect {}; | ||
let sql = "SELECT 1 + 2"; | ||
let ast = Parser::parse_sql(&dialect, sql)?; | ||
let query = &ast[0]; | ||
assert_eq!("SELECT NULL + 2", &format!("{}", query)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 😆 |
||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn custom_infix_parser() -> Result<(), ParserError> { | ||
#[derive(Debug)] | ||
struct MyDialect {} | ||
|
||
impl Dialect for MyDialect { | ||
fn is_identifier_start(&self, ch: char) -> bool { | ||
is_identifier_start(ch) | ||
} | ||
|
||
fn is_identifier_part(&self, ch: char) -> bool { | ||
is_identifier_part(ch) | ||
} | ||
|
||
fn parse_infix( | ||
&self, | ||
parser: &mut Parser, | ||
expr: &Expr, | ||
_precendence: u8, | ||
) -> Option<Result<Expr, ParserError>> { | ||
if parser.consume_token(&Token::Plus) { | ||
Some(Ok(Expr::BinaryOp { | ||
left: Box::new(expr.clone()), | ||
op: BinaryOperator::Multiply, // translate Plus to Multiply | ||
right: Box::new(parser.parse_expr().unwrap()), | ||
})) | ||
} else { | ||
None | ||
} | ||
} | ||
} | ||
|
||
let dialect = MyDialect {}; | ||
let sql = "SELECT 1 + 2"; | ||
let ast = Parser::parse_sql(&dialect, sql)?; | ||
let query = &ast[0]; | ||
assert_eq!("SELECT 1 * 2", &format!("{}", query)); | ||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn custom_statement_parser() -> Result<(), ParserError> { | ||
#[derive(Debug)] | ||
struct MyDialect {} | ||
|
||
impl Dialect for MyDialect { | ||
fn is_identifier_start(&self, ch: char) -> bool { | ||
is_identifier_start(ch) | ||
} | ||
|
||
fn is_identifier_part(&self, ch: char) -> bool { | ||
is_identifier_part(ch) | ||
} | ||
|
||
fn parse_statement( | ||
&self, | ||
parser: &mut Parser, | ||
) -> Option<Result<Statement, ParserError>> { | ||
if parser.parse_keyword(Keyword::SELECT) { | ||
for _ in 0..3 { | ||
let _ = parser.next_token(); | ||
} | ||
Some(Ok(Statement::Commit { chain: false })) | ||
} else { | ||
None | ||
} | ||
} | ||
} | ||
|
||
let dialect = MyDialect {}; | ||
let sql = "SELECT 1 + 2"; | ||
let ast = Parser::parse_sql(&dialect, sql)?; | ||
let query = &ast[0]; | ||
assert_eq!("COMMIT", &format!("{}", query)); | ||
Ok(()) | ||
} | ||
|
||
fn is_identifier_start(ch: char) -> bool { | ||
('a'..='z').contains(&ch) || ('A'..='Z').contains(&ch) || ch == '_' | ||
} | ||
|
||
fn is_identifier_part(ch: char) -> bool { | ||
('a'..='z').contains(&ch) | ||
|| ('A'..='Z').contains(&ch) | ||
|| ('0'..='9').contains(&ch) | ||
|| ch == '$' | ||
|| ch == '_' | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This dialect-specific code has now moved to the appropriate dialects.