Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability for dialects to override prefix, infix, and statement parsing #581

Merged
merged 16 commits into from Aug 19, 2022
27 changes: 27 additions & 0 deletions src/dialect/mod.rs
Expand Up @@ -22,6 +22,7 @@ mod redshift;
mod snowflake;
mod sqlite;

use crate::ast::{Expr, Statement};
use core::any::{Any, TypeId};
use core::fmt::Debug;
use core::iter::Peekable;
Expand All @@ -39,6 +40,7 @@ pub use self::redshift::RedshiftSqlDialect;
pub use self::snowflake::SnowflakeDialect;
pub use self::sqlite::SQLiteDialect;
pub use crate::keywords;
use crate::parser::{Parser, ParserError};

/// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates
/// to `true` if `parser.dialect` is one of the `Dialect`s specified.
Expand All @@ -65,6 +67,31 @@ pub trait Dialect: Debug + Any {
fn is_identifier_start(&self, ch: char) -> bool;
/// Determine if a character is a valid unquoted identifier character
fn is_identifier_part(&self, ch: char) -> bool;
/// Dialect-specific prefix parser override
fn parse_prefix(&self, _parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
// return None to fall back to the default behavior
None
}
/// Dialect-specific infix parser override
fn parse_infix(
&self,
_parser: &mut Parser,
_expr: &Expr,
_precendence: u8,
) -> Option<Result<Expr, ParserError>> {
// return None to fall back to the default behavior
None
}
/// Dialect-specific precedence override
fn get_next_precedence(&self, _parser: &Parser) -> Option<Result<u8, ParserError>> {
// return None to fall back to the default behavior
None
}
/// Dialect-specific statement parser override
fn parse_statement(&self, _parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
// return None to fall back to the default behavior
None
}
}

impl dyn Dialect {
Expand Down
41 changes: 41 additions & 0 deletions src/dialect/postgresql.rs
Expand Up @@ -10,7 +10,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::ast::{CommentObject, Statement};
use crate::dialect::Dialect;
use crate::keywords::Keyword;
use crate::parser::{Parser, ParserError};
use crate::tokenizer::Token;

#[derive(Debug)]
pub struct PostgreSqlDialect {}
Expand All @@ -30,4 +34,41 @@ impl Dialect for PostgreSqlDialect {
|| ch == '$'
|| ch == '_'
}

fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
if parser.parse_keyword(Keyword::COMMENT) {
Some(parse_comment(parser))
} else {
None
}
}
}

pub fn parse_comment(parser: &mut Parser) -> Result<Statement, ParserError> {
parser.expect_keyword(Keyword::ON)?;
let token = parser.next_token();

let (object_type, object_name) = match token {
Token::Word(w) if w.keyword == Keyword::COLUMN => {
let object_name = parser.parse_object_name()?;
(CommentObject::Column, object_name)
}
Token::Word(w) if w.keyword == Keyword::TABLE => {
let object_name = parser.parse_object_name()?;
(CommentObject::Table, object_name)
}
_ => parser.expected("comment object_type", token)?,
};

parser.expect_keyword(Keyword::IS)?;
let comment = if parser.parse_keyword(Keyword::NULL) {
None
} else {
Some(parser.parse_literal_string()?)
};
Ok(Statement::Comment {
object_type,
object_name,
comment,
})
}
12 changes: 12 additions & 0 deletions src/dialect/sqlite.rs
Expand Up @@ -10,7 +10,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::ast::Statement;
use crate::dialect::Dialect;
use crate::keywords::Keyword;
use crate::parser::{Parser, ParserError};

#[derive(Debug)]
pub struct SQLiteDialect {}
Expand All @@ -35,4 +38,13 @@ impl Dialect for SQLiteDialect {
fn is_identifier_part(&self, ch: char) -> bool {
self.is_identifier_start(ch) || ('0'..='9').contains(&ch)
}

fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
if parser.parse_keyword(Keyword::REPLACE) {
parser.prev_token();
Some(parser.parse_insert())
} else {
None
}
}
}
58 changes: 21 additions & 37 deletions src/parser.rs
Expand Up @@ -152,6 +152,11 @@ impl<'a> Parser<'a> {
/// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.),
/// stopping before the statement separator, if any.
pub fn parse_statement(&mut self) -> Result<Statement, ParserError> {
// allow the dialect to override statement parsing
if let Some(statement) = self.dialect.parse_statement(self) {
return statement;
}

match self.next_token() {
Token::Word(w) => match w.keyword {
Keyword::KILL => Ok(self.parse_kill()?),
Expand Down Expand Up @@ -195,13 +200,6 @@ impl<'a> Parser<'a> {
Keyword::EXECUTE => Ok(self.parse_execute()?),
Keyword::PREPARE => Ok(self.parse_prepare()?),
Keyword::MERGE => Ok(self.parse_merge()?),
Keyword::REPLACE if dialect_of!(self is SQLiteDialect ) => {
self.prev_token();
Ok(self.parse_insert()?)
}
Keyword::COMMENT if dialect_of!(self is PostgreSqlDialect) => {
Ok(self.parse_comment()?)
}
Comment on lines -198 to -204
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This dialect-specific code has now moved to the appropriate dialects.

_ => self.expected("an SQL statement", Token::Word(w)),
},
Token::LParen => {
Expand Down Expand Up @@ -381,6 +379,11 @@ impl<'a> Parser<'a> {

/// Parse an expression prefix
pub fn parse_prefix(&mut self) -> Result<Expr, ParserError> {
// allow the dialect to override prefix parsing
if let Some(prefix) = self.dialect.parse_prefix(self) {
return prefix;
}

// PostgreSQL allows any string literal to be preceded by a type name, indicating that the
// string literal represents a literal of that type. Some examples:
//
Expand Down Expand Up @@ -1164,6 +1167,11 @@ impl<'a> Parser<'a> {

/// Parse an operator following an expression
pub fn parse_infix(&mut self, expr: Expr, precedence: u8) -> Result<Expr, ParserError> {
// allow the dialect to override infix parsing
if let Some(infix) = self.dialect.parse_infix(self, &expr, precedence) {
return infix;
}

let tok = self.next_token();

let regular_binary_operator = match &tok {
Expand Down Expand Up @@ -1477,6 +1485,11 @@ impl<'a> Parser<'a> {

/// Get the precedence of the next token
pub fn get_next_precedence(&self) -> Result<u8, ParserError> {
// allow the dialect to override precedence logic
if let Some(precedence) = self.dialect.get_next_precedence(self) {
return precedence;
}

let token = self.peek_token();
debug!("get_next_precedence() {:?}", token);
let token_0 = self.peek_nth_token(0);
Expand Down Expand Up @@ -1604,7 +1617,7 @@ impl<'a> Parser<'a> {
}

/// Report unexpected token
fn expected<T>(&self, expected: &str, found: Token) -> Result<T, ParserError> {
pub fn expected<T>(&self, expected: &str, found: Token) -> Result<T, ParserError> {
parser_err!(format!("Expected {}, found: {}", expected, found))
}

Expand Down Expand Up @@ -4731,35 +4744,6 @@ impl<'a> Parser<'a> {
})
}

pub fn parse_comment(&mut self) -> Result<Statement, ParserError> {
self.expect_keyword(Keyword::ON)?;
let token = self.next_token();

let (object_type, object_name) = match token {
Token::Word(w) if w.keyword == Keyword::COLUMN => {
let object_name = self.parse_object_name()?;
(CommentObject::Column, object_name)
}
Token::Word(w) if w.keyword == Keyword::TABLE => {
let object_name = self.parse_object_name()?;
(CommentObject::Table, object_name)
}
_ => self.expected("comment object_type", token)?,
};

self.expect_keyword(Keyword::IS)?;
let comment = if self.parse_keyword(Keyword::NULL) {
None
} else {
Some(self.parse_literal_string()?)
};
Ok(Statement::Comment {
object_type,
object_name,
comment,
})
}

pub fn parse_merge_clauses(&mut self) -> Result<Vec<MergeClause>, ParserError> {
let mut clauses: Vec<MergeClause> = vec![];
loop {
Expand Down
138 changes: 138 additions & 0 deletions tests/sqlparser_custom_dialect.rs
@@ -0,0 +1,138 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Test the ability for dialects to override parsing

use sqlparser::{
ast::{BinaryOperator, Expr, Statement, Value},
dialect::Dialect,
keywords::Keyword,
parser::{Parser, ParserError},
tokenizer::Token,
};

#[test]
fn custom_prefix_parser() -> Result<(), ParserError> {
#[derive(Debug)]
struct MyDialect {}

impl Dialect for MyDialect {
fn is_identifier_start(&self, ch: char) -> bool {
is_identifier_start(ch)
}

fn is_identifier_part(&self, ch: char) -> bool {
is_identifier_part(ch)
}

fn parse_prefix(&self, parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
if parser.consume_token(&Token::Number("1".to_string(), false)) {
Some(Ok(Expr::Value(Value::Null)))
} else {
None
}
}
}

let dialect = MyDialect {};
let sql = "SELECT 1 + 2";
let ast = Parser::parse_sql(&dialect, sql)?;
let query = &ast[0];
assert_eq!("SELECT NULL + 2", &format!("{}", query));
Ok(())
}

#[test]
fn custom_infix_parser() -> Result<(), ParserError> {
#[derive(Debug)]
struct MyDialect {}

impl Dialect for MyDialect {
fn is_identifier_start(&self, ch: char) -> bool {
is_identifier_start(ch)
}

fn is_identifier_part(&self, ch: char) -> bool {
is_identifier_part(ch)
}

fn parse_infix(
&self,
parser: &mut Parser,
expr: &Expr,
_precendence: u8,
) -> Option<Result<Expr, ParserError>> {
if parser.consume_token(&Token::Plus) {
Some(Ok(Expr::BinaryOp {
left: Box::new(expr.clone()),
op: BinaryOperator::Multiply, // translate Plus to Multiply
right: Box::new(parser.parse_expr().unwrap()),
}))
} else {
None
}
}
}

let dialect = MyDialect {};
let sql = "SELECT 1 + 2";
let ast = Parser::parse_sql(&dialect, sql)?;
let query = &ast[0];
assert_eq!("SELECT 1 * 2", &format!("{}", query));
Ok(())
}

#[test]
fn custom_statement_parser() -> Result<(), ParserError> {
#[derive(Debug)]
struct MyDialect {}

impl Dialect for MyDialect {
fn is_identifier_start(&self, ch: char) -> bool {
is_identifier_start(ch)
}

fn is_identifier_part(&self, ch: char) -> bool {
is_identifier_part(ch)
}

fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
if parser.parse_keyword(Keyword::SELECT) {
for _ in 0..3 {
let _ = parser.next_token();
}
Some(Ok(Statement::Commit { chain: false }))
} else {
None
}
}
}

let dialect = MyDialect {};
let sql = "SELECT 1 + 2";
let ast = Parser::parse_sql(&dialect, sql)?;
let query = &ast[0];
assert_eq!("COMMIT", &format!("{}", query));
Ok(())
}

fn is_identifier_start(ch: char) -> bool {
('a'..='z').contains(&ch) || ('A'..='Z').contains(&ch) || ch == '_'
}

fn is_identifier_part(ch: char) -> bool {
('a'..='z').contains(&ch)
|| ('A'..='Z').contains(&ch)
|| ('0'..='9').contains(&ch)
|| ch == '$'
|| ch == '_'
}