Skip to content

Commit

Permalink
Add ability for dialects to override prefix, infix, and statement par…
Browse files Browse the repository at this point in the history
…sing (#581)
  • Loading branch information
andygrove committed Aug 19, 2022
1 parent 7c02477 commit 72559e9
Show file tree
Hide file tree
Showing 5 changed files with 239 additions and 37 deletions.
27 changes: 27 additions & 0 deletions src/dialect/mod.rs
Expand Up @@ -22,6 +22,7 @@ mod redshift;
mod snowflake;
mod sqlite;

use crate::ast::{Expr, Statement};
use core::any::{Any, TypeId};
use core::fmt::Debug;
use core::iter::Peekable;
Expand All @@ -39,6 +40,7 @@ pub use self::redshift::RedshiftSqlDialect;
pub use self::snowflake::SnowflakeDialect;
pub use self::sqlite::SQLiteDialect;
pub use crate::keywords;
use crate::parser::{Parser, ParserError};

/// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates
/// to `true` if `parser.dialect` is one of the `Dialect`s specified.
Expand All @@ -65,6 +67,31 @@ pub trait Dialect: Debug + Any {
fn is_identifier_start(&self, ch: char) -> bool;
/// Determine if a character is a valid unquoted identifier character
fn is_identifier_part(&self, ch: char) -> bool;
/// Dialect-specific prefix parser override
fn parse_prefix(&self, _parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
// return None to fall back to the default behavior
None
}
/// Dialect-specific infix parser override
fn parse_infix(
&self,
_parser: &mut Parser,
_expr: &Expr,
_precendence: u8,
) -> Option<Result<Expr, ParserError>> {
// return None to fall back to the default behavior
None
}
/// Dialect-specific precedence override
fn get_next_precedence(&self, _parser: &Parser) -> Option<Result<u8, ParserError>> {
// return None to fall back to the default behavior
None
}
/// Dialect-specific statement parser override
fn parse_statement(&self, _parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
// return None to fall back to the default behavior
None
}
}

impl dyn Dialect {
Expand Down
41 changes: 41 additions & 0 deletions src/dialect/postgresql.rs
Expand Up @@ -10,7 +10,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::ast::{CommentObject, Statement};
use crate::dialect::Dialect;
use crate::keywords::Keyword;
use crate::parser::{Parser, ParserError};
use crate::tokenizer::Token;

#[derive(Debug)]
pub struct PostgreSqlDialect {}
Expand All @@ -30,4 +34,41 @@ impl Dialect for PostgreSqlDialect {
|| ch == '$'
|| ch == '_'
}

fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
if parser.parse_keyword(Keyword::COMMENT) {
Some(parse_comment(parser))
} else {
None
}
}
}

pub fn parse_comment(parser: &mut Parser) -> Result<Statement, ParserError> {
parser.expect_keyword(Keyword::ON)?;
let token = parser.next_token();

let (object_type, object_name) = match token {
Token::Word(w) if w.keyword == Keyword::COLUMN => {
let object_name = parser.parse_object_name()?;
(CommentObject::Column, object_name)
}
Token::Word(w) if w.keyword == Keyword::TABLE => {
let object_name = parser.parse_object_name()?;
(CommentObject::Table, object_name)
}
_ => parser.expected("comment object_type", token)?,
};

parser.expect_keyword(Keyword::IS)?;
let comment = if parser.parse_keyword(Keyword::NULL) {
None
} else {
Some(parser.parse_literal_string()?)
};
Ok(Statement::Comment {
object_type,
object_name,
comment,
})
}
12 changes: 12 additions & 0 deletions src/dialect/sqlite.rs
Expand Up @@ -10,7 +10,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::ast::Statement;
use crate::dialect::Dialect;
use crate::keywords::Keyword;
use crate::parser::{Parser, ParserError};

#[derive(Debug)]
pub struct SQLiteDialect {}
Expand All @@ -35,4 +38,13 @@ impl Dialect for SQLiteDialect {
fn is_identifier_part(&self, ch: char) -> bool {
self.is_identifier_start(ch) || ('0'..='9').contains(&ch)
}

fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
if parser.parse_keyword(Keyword::REPLACE) {
parser.prev_token();
Some(parser.parse_insert())
} else {
None
}
}
}
58 changes: 21 additions & 37 deletions src/parser.rs
Expand Up @@ -152,6 +152,11 @@ impl<'a> Parser<'a> {
/// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.),
/// stopping before the statement separator, if any.
pub fn parse_statement(&mut self) -> Result<Statement, ParserError> {
// allow the dialect to override statement parsing
if let Some(statement) = self.dialect.parse_statement(self) {
return statement;
}

match self.next_token() {
Token::Word(w) => match w.keyword {
Keyword::KILL => Ok(self.parse_kill()?),
Expand Down Expand Up @@ -195,13 +200,6 @@ impl<'a> Parser<'a> {
Keyword::EXECUTE => Ok(self.parse_execute()?),
Keyword::PREPARE => Ok(self.parse_prepare()?),
Keyword::MERGE => Ok(self.parse_merge()?),
Keyword::REPLACE if dialect_of!(self is SQLiteDialect ) => {
self.prev_token();
Ok(self.parse_insert()?)
}
Keyword::COMMENT if dialect_of!(self is PostgreSqlDialect) => {
Ok(self.parse_comment()?)
}
_ => self.expected("an SQL statement", Token::Word(w)),
},
Token::LParen => {
Expand Down Expand Up @@ -381,6 +379,11 @@ impl<'a> Parser<'a> {

/// Parse an expression prefix
pub fn parse_prefix(&mut self) -> Result<Expr, ParserError> {
// allow the dialect to override prefix parsing
if let Some(prefix) = self.dialect.parse_prefix(self) {
return prefix;
}

// PostgreSQL allows any string literal to be preceded by a type name, indicating that the
// string literal represents a literal of that type. Some examples:
//
Expand Down Expand Up @@ -1164,6 +1167,11 @@ impl<'a> Parser<'a> {

/// Parse an operator following an expression
pub fn parse_infix(&mut self, expr: Expr, precedence: u8) -> Result<Expr, ParserError> {
// allow the dialect to override infix parsing
if let Some(infix) = self.dialect.parse_infix(self, &expr, precedence) {
return infix;
}

let tok = self.next_token();

let regular_binary_operator = match &tok {
Expand Down Expand Up @@ -1491,6 +1499,11 @@ impl<'a> Parser<'a> {

/// Get the precedence of the next token
pub fn get_next_precedence(&self) -> Result<u8, ParserError> {
// allow the dialect to override precedence logic
if let Some(precedence) = self.dialect.get_next_precedence(self) {
return precedence;
}

let token = self.peek_token();
debug!("get_next_precedence() {:?}", token);
let token_0 = self.peek_nth_token(0);
Expand Down Expand Up @@ -1618,7 +1631,7 @@ impl<'a> Parser<'a> {
}

/// Report unexpected token
fn expected<T>(&self, expected: &str, found: Token) -> Result<T, ParserError> {
pub fn expected<T>(&self, expected: &str, found: Token) -> Result<T, ParserError> {
parser_err!(format!("Expected {}, found: {}", expected, found))
}

Expand Down Expand Up @@ -4735,35 +4748,6 @@ impl<'a> Parser<'a> {
})
}

pub fn parse_comment(&mut self) -> Result<Statement, ParserError> {
self.expect_keyword(Keyword::ON)?;
let token = self.next_token();

let (object_type, object_name) = match token {
Token::Word(w) if w.keyword == Keyword::COLUMN => {
let object_name = self.parse_object_name()?;
(CommentObject::Column, object_name)
}
Token::Word(w) if w.keyword == Keyword::TABLE => {
let object_name = self.parse_object_name()?;
(CommentObject::Table, object_name)
}
_ => self.expected("comment object_type", token)?,
};

self.expect_keyword(Keyword::IS)?;
let comment = if self.parse_keyword(Keyword::NULL) {
None
} else {
Some(self.parse_literal_string()?)
};
Ok(Statement::Comment {
object_type,
object_name,
comment,
})
}

pub fn parse_merge_clauses(&mut self) -> Result<Vec<MergeClause>, ParserError> {
let mut clauses: Vec<MergeClause> = vec![];
loop {
Expand Down
138 changes: 138 additions & 0 deletions tests/sqlparser_custom_dialect.rs
@@ -0,0 +1,138 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Test the ability for dialects to override parsing

use sqlparser::{
ast::{BinaryOperator, Expr, Statement, Value},
dialect::Dialect,
keywords::Keyword,
parser::{Parser, ParserError},
tokenizer::Token,
};

#[test]
fn custom_prefix_parser() -> Result<(), ParserError> {
#[derive(Debug)]
struct MyDialect {}

impl Dialect for MyDialect {
fn is_identifier_start(&self, ch: char) -> bool {
is_identifier_start(ch)
}

fn is_identifier_part(&self, ch: char) -> bool {
is_identifier_part(ch)
}

fn parse_prefix(&self, parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
if parser.consume_token(&Token::Number("1".to_string(), false)) {
Some(Ok(Expr::Value(Value::Null)))
} else {
None
}
}
}

let dialect = MyDialect {};
let sql = "SELECT 1 + 2";
let ast = Parser::parse_sql(&dialect, sql)?;
let query = &ast[0];
assert_eq!("SELECT NULL + 2", &format!("{}", query));
Ok(())
}

#[test]
fn custom_infix_parser() -> Result<(), ParserError> {
#[derive(Debug)]
struct MyDialect {}

impl Dialect for MyDialect {
fn is_identifier_start(&self, ch: char) -> bool {
is_identifier_start(ch)
}

fn is_identifier_part(&self, ch: char) -> bool {
is_identifier_part(ch)
}

fn parse_infix(
&self,
parser: &mut Parser,
expr: &Expr,
_precendence: u8,
) -> Option<Result<Expr, ParserError>> {
if parser.consume_token(&Token::Plus) {
Some(Ok(Expr::BinaryOp {
left: Box::new(expr.clone()),
op: BinaryOperator::Multiply, // translate Plus to Multiply
right: Box::new(parser.parse_expr().unwrap()),
}))
} else {
None
}
}
}

let dialect = MyDialect {};
let sql = "SELECT 1 + 2";
let ast = Parser::parse_sql(&dialect, sql)?;
let query = &ast[0];
assert_eq!("SELECT 1 * 2", &format!("{}", query));
Ok(())
}

#[test]
fn custom_statement_parser() -> Result<(), ParserError> {
#[derive(Debug)]
struct MyDialect {}

impl Dialect for MyDialect {
fn is_identifier_start(&self, ch: char) -> bool {
is_identifier_start(ch)
}

fn is_identifier_part(&self, ch: char) -> bool {
is_identifier_part(ch)
}

fn parse_statement(&self, parser: &mut Parser) -> Option<Result<Statement, ParserError>> {
if parser.parse_keyword(Keyword::SELECT) {
for _ in 0..3 {
let _ = parser.next_token();
}
Some(Ok(Statement::Commit { chain: false }))
} else {
None
}
}
}

let dialect = MyDialect {};
let sql = "SELECT 1 + 2";
let ast = Parser::parse_sql(&dialect, sql)?;
let query = &ast[0];
assert_eq!("COMMIT", &format!("{}", query));
Ok(())
}

fn is_identifier_start(ch: char) -> bool {
('a'..='z').contains(&ch) || ('A'..='Z').contains(&ch) || ch == '_'
}

fn is_identifier_part(ch: char) -> bool {
('a'..='z').contains(&ch)
|| ('A'..='Z').contains(&ch)
|| ('0'..='9').contains(&ch)
|| ch == '$'
|| ch == '_'
}

0 comments on commit 72559e9

Please sign in to comment.