diff --git a/src/lib.rs b/src/lib.rs
index 880c6d6e1..4a8c3c51d 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -12,11 +12,15 @@
 //! SQL Parser for Rust
 //!
-//! Example code:
-//!
 //! This crate provides an ANSI:SQL 2011 lexer and parser that can parse SQL
 //! into an Abstract Syntax Tree (AST).
 //!
+//! See [`Parser::parse_sql`](crate::parser::Parser::parse_sql) and
+//! [`Parser::new`](crate::parser::Parser::new) for the Parsing API
+//! and the [`ast`](crate::ast) module for the AST structure.
+//!
+//! Example:
+//!
 //! ```
 //! use sqlparser::dialect::GenericDialect;
 //! use sqlparser::parser::Parser;
diff --git a/src/parser.rs b/src/parser.rs
index 1e3d71b34..c48895da9 100644
--- a/src/parser.rs
+++ b/src/parser.rs
@@ -37,6 +37,7 @@ use crate::tokenizer::*;
 pub enum ParserError {
     TokenizerError(String),
     ParserError(String),
+    RecursionLimitExceeded,
 }
 
 // Use `Parser::expected` instead, if possible
@@ -55,6 +56,92 @@ macro_rules! return_ok_if_some {
     }};
 }
 
+#[cfg(feature = "std")]
+/// Implementation of [`RecursionCounter`] if std is available
+mod recursion {
+    use core::sync::atomic::{AtomicUsize, Ordering};
+    use std::rc::Rc;
+
+    use super::ParserError;
+
+    /// Tracks remaining recursion depth. This value is decremented on
+    /// each call to `try_decrease()`, when it reaches 0 an error will
+    /// be returned.
+    ///
+    /// Note: Uses an Rc and AtomicUsize in order to satisfy the Rust
+    /// borrow checker so the automatic DepthGuard decrements a
+    /// reference to the counter. The actual value is not modified
+    /// concurrently.
+    pub(crate) struct RecursionCounter {
+        remaining_depth: Rc<AtomicUsize>,
+    }
+
+    impl RecursionCounter {
+        /// Creates a [`RecursionCounter`] with the specified maximum
+        /// depth
+        pub fn new(remaining_depth: usize) -> Self {
+            Self {
+                remaining_depth: Rc::new(remaining_depth.into()),
+            }
+        }
+
+        /// Decreases the remaining depth by 1.
+        ///
+        /// Returns `Err` if the remaining depth falls to 0.
+        ///
+        /// Returns a [`DepthGuard`] which will add 1 to the
+        /// remaining depth upon drop.
+        pub fn try_decrease(&self) -> Result<DepthGuard, ParserError> {
+            let old_value = self.remaining_depth.fetch_sub(1, Ordering::SeqCst);
+            // ran out of space
+            if old_value == 0 {
+                Err(ParserError::RecursionLimitExceeded)
+            } else {
+                Ok(DepthGuard::new(Rc::clone(&self.remaining_depth)))
+            }
+        }
+    }
+
+    /// Guard that increases the remaining depth by 1 on drop
+    pub struct DepthGuard {
+        remaining_depth: Rc<AtomicUsize>,
+    }
+
+    impl DepthGuard {
+        fn new(remaining_depth: Rc<AtomicUsize>) -> Self {
+            Self { remaining_depth }
+        }
+    }
+    impl Drop for DepthGuard {
+        fn drop(&mut self) {
+            self.remaining_depth.fetch_add(1, Ordering::SeqCst);
+        }
+    }
+}
+
+#[cfg(not(feature = "std"))]
+mod recursion {
+    /// Implementation of [`RecursionCounter`] if std is NOT available (and does not
+    /// guard against stack overflow).
+    ///
+    /// Has the same API as the std RecursionCounter implementation
+    /// but does not actually limit stack depth.
+    pub(crate) struct RecursionCounter {}
+
+    impl RecursionCounter {
+        pub fn new(_remaining_depth: usize) -> Self {
+            Self {}
+        }
+        pub fn try_decrease(&self) -> Result<DepthGuard, super::ParserError> {
+            Ok(DepthGuard {})
+        }
+    }
+
+    pub struct DepthGuard {}
+}
+
+use recursion::RecursionCounter;
+
 #[derive(PartialEq, Eq)]
 pub enum IsOptional {
     Optional,
@@ -96,6 +183,7 @@ impl fmt::Display for ParserError {
             match self {
                 ParserError::TokenizerError(s) => s,
                 ParserError::ParserError(s) => s,
+                ParserError::RecursionLimitExceeded => "recursion limit exceeded",
             }
         )
     }
@@ -104,22 +192,78 @@ impl fmt::Display for ParserError {
 #[cfg(feature = "std")]
 impl std::error::Error for ParserError {}
 
+// By default, allow expressions up to this deep before erroring
+const DEFAULT_REMAINING_DEPTH: usize = 50;
+
 pub struct Parser<'a> {
     tokens: Vec<TokenWithLocation>,
     /// The index of the first unprocessed token in `self.tokens`
     index: usize,
+    /// The current dialect to use
     dialect: &'a dyn Dialect,
+    /// ensure the stack does not overflow by limiting recursion
depth
+    recursion_counter: RecursionCounter,
 }
 
 impl<'a> Parser<'a> {
-    /// Parse the specified tokens
-    /// To avoid breaking backwards compatibility, this function accepts
-    /// bare tokens.
-    pub fn new(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
-        Parser::new_without_locations(tokens, dialect)
+    /// Create a parser for a [`Dialect`]
+    ///
+    /// See also [`Parser::parse_sql`]
+    ///
+    /// Example:
+    /// ```
+    /// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
+    /// # fn main() -> Result<(), ParserError> {
+    /// let dialect = GenericDialect{};
+    /// let statements = Parser::new(&dialect)
+    ///   .try_with_sql("SELECT * FROM foo")?
+    ///   .parse_statements()?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn new(dialect: &'a dyn Dialect) -> Self {
+        Self {
+            tokens: vec![],
+            index: 0,
+            dialect,
+            recursion_counter: RecursionCounter::new(DEFAULT_REMAINING_DEPTH),
+        }
+    }
+
+    /// Specify the maximum recursion limit while parsing.
+    ///
+    /// [`Parser`] prevents stack overflows by returning
+    /// [`ParserError::RecursionLimitExceeded`] if the parser exceeds
+    /// this depth while processing the query.
+    ///
+    /// Example:
+    /// ```
+    /// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
+    /// # fn main() -> Result<(), ParserError> {
+    /// let dialect = GenericDialect{};
+    /// let result = Parser::new(&dialect)
+    ///   .with_recursion_limit(1)
+    ///   .try_with_sql("SELECT * FROM foo WHERE (a OR (b OR (c OR d)))")?
+    ///   .parse_statements();
+    /// assert_eq!(result, Err(ParserError::RecursionLimitExceeded));
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn with_recursion_limit(mut self, recursion_limit: usize) -> Self {
+        self.recursion_counter = RecursionCounter::new(recursion_limit);
+        self
+    }
+
+    /// Reset this parser to parse the specified token stream
+    pub fn with_tokens_with_locations(mut self, tokens: Vec<TokenWithLocation>) -> Self {
+        self.tokens = tokens;
+        self.index = 0;
+        self
     }
 
-    pub fn new_without_locations(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
+    /// Reset this parser state to parse the specified tokens
+    pub fn with_tokens(self, tokens: Vec<Token>) -> Self {
         // Put in dummy locations
         let tokens_with_locations: Vec<TokenWithLocation> = tokens
             .into_iter()
@@ -128,49 +272,84 @@ impl<'a> Parser<'a> {
             .map(|token| TokenWithLocation {
                 token,
                 location: Location { line: 0, column: 0 },
             })
             .collect();
-        Parser::new_with_locations(tokens_with_locations, dialect)
+        self.with_tokens_with_locations(tokens_with_locations)
     }
 
-    /// Parse the specified tokens
-    pub fn new_with_locations(tokens: Vec<TokenWithLocation>, dialect: &'a dyn Dialect) -> Self {
-        Parser {
-            tokens,
-            index: 0,
-            dialect,
-        }
+    /// Tokenize the sql string and sets this [`Parser`]'s state to
+    /// parse the resulting tokens
+    ///
+    /// Returns an error if there was an error tokenizing the SQL string.
+    ///
+    /// See example on [`Parser::new()`] for an example
+    pub fn try_with_sql(self, sql: &str) -> Result<Self, ParserError> {
+        debug!("Parsing sql '{}'...", sql);
+        let mut tokenizer = Tokenizer::new(self.dialect, sql);
+        let tokens = tokenizer.tokenize()?;
+        Ok(self.with_tokens(tokens))
     }
 
-    /// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
-    pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
-        let mut tokenizer = Tokenizer::new(dialect, sql);
-        let tokens = tokenizer.tokenize()?;
-        let mut parser = Parser::new(tokens, dialect);
+    /// Parse potentially multiple statements
+    ///
+    /// Example
+    /// ```
+    /// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
+    /// # fn main() -> Result<(), ParserError> {
+    /// let dialect = GenericDialect{};
+    /// let statements = Parser::new(&dialect)
+    ///   // Parse a SQL string with 2 separate statements
+    ///   .try_with_sql("SELECT * FROM foo; SELECT * FROM bar;")?
+    ///   .parse_statements()?;
+    /// assert_eq!(statements.len(), 2);
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn parse_statements(&mut self) -> Result<Vec<Statement>, ParserError> {
         let mut stmts = Vec::new();
         let mut expecting_statement_delimiter = false;
-        debug!("Parsing sql '{}'...", sql);
         loop {
             // ignore empty statements (between successive statement delimiters)
-            while parser.consume_token(&Token::SemiColon) {
+            while self.consume_token(&Token::SemiColon) {
                 expecting_statement_delimiter = false;
             }
 
-            if parser.peek_token() == Token::EOF {
+            if self.peek_token() == Token::EOF {
                 break;
             }
             if expecting_statement_delimiter {
-                return parser.expected("end of statement", parser.peek_token());
+                return self.expected("end of statement", self.peek_token());
             }
 
-            let statement = parser.parse_statement()?;
+            let statement = self.parse_statement()?;
             stmts.push(statement);
             expecting_statement_delimiter = true;
         }
         Ok(stmts)
     }
 
+    /// Convenience method to parse a string with one or more SQL
+    /// statements to produce an Abstract Syntax Tree
(AST).
+    ///
+    /// Example
+    /// ```
+    /// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
+    /// # fn main() -> Result<(), ParserError> {
+    /// let dialect = GenericDialect{};
+    /// let statements = Parser::parse_sql(
+    ///   &dialect, "SELECT * FROM foo"
+    /// )?;
+    /// assert_eq!(statements.len(), 1);
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
+        Parser::new(dialect).try_with_sql(sql)?.parse_statements()
+    }
+
     /// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.),
     /// stopping before the statement separator, if any.
     pub fn parse_statement(&mut self) -> Result<Statement, ParserError> {
+        let _guard = self.recursion_counter.try_decrease()?;
+
         // allow the dialect to override statement parsing
         if let Some(statement) = self.dialect.parse_statement(self) {
             return statement;
@@ -364,6 +543,7 @@ impl<'a> Parser<'a> {
 
     /// Parse a new expression
     pub fn parse_expr(&mut self) -> Result<Expr, ParserError> {
+        let _guard = self.recursion_counter.try_decrease()?;
         self.parse_subexpr(0)
     }
 
@@ -4454,6 +4634,7 @@ impl<'a> Parser<'a> {
     /// by `ORDER BY`. Unlike some other parse_... methods, this one doesn't
     /// expect the initial keyword to be already consumed
     pub fn parse_query(&mut self) -> Result<Query, ParserError> {
+        let _guard = self.recursion_counter.try_decrease()?;
         let with = if self.parse_keyword(Keyword::WITH) {
             Some(With {
                 recursive: self.parse_keyword(Keyword::RECURSIVE),
diff --git a/src/test_utils.rs b/src/test_utils.rs
index cbb929285..45b42f9fd 100644
--- a/src/test_utils.rs
+++ b/src/test_utils.rs
@@ -29,7 +29,6 @@ use core::fmt::Debug;
 
 use crate::ast::*;
 use crate::dialect::*;
 use crate::parser::{Parser, ParserError};
-use crate::tokenizer::Tokenizer;
 
 /// Tests use the methods on this struct to invoke the parser on one or
 /// multiple dialects.
@@ -65,9 +64,8 @@ impl TestedDialects { F: Fn(&mut Parser) -> T, { self.one_of_identical_results(|dialect| { - let mut tokenizer = Tokenizer::new(dialect, sql); - let tokens = tokenizer.tokenize().unwrap(); - f(&mut Parser::new(tokens, dialect)) + let mut parser = Parser::new(dialect).try_with_sql(sql).unwrap(); + f(&mut parser) }) } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index e5ed0bb80..9602c862d 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -6457,3 +6457,91 @@ fn parse_uncache_table() { res.unwrap_err() ); } + +#[test] +fn parse_deeply_nested_parens_hits_recursion_limits() { + let sql = "(".repeat(1000); + let res = parse_sql_statements(&sql); + assert_eq!(ParserError::RecursionLimitExceeded, res.unwrap_err()); +} + +#[test] +fn parse_deeply_nested_expr_hits_recursion_limits() { + let dialect = GenericDialect {}; + + let where_clause = make_where_clause(100); + let sql = format!("SELECT id, user_id FROM test WHERE {where_clause}"); + + let res = Parser::new(&dialect) + .try_with_sql(&sql) + .expect("tokenize to work") + .parse_statements(); + + assert_eq!(res, Err(ParserError::RecursionLimitExceeded)); +} + +#[test] +fn parse_deeply_nested_subquery_expr_hits_recursion_limits() { + let dialect = GenericDialect {}; + + let where_clause = make_where_clause(100); + let sql = format!("SELECT id, user_id where id IN (select id from t WHERE {where_clause})"); + + let res = Parser::new(&dialect) + .try_with_sql(&sql) + .expect("tokenize to work") + .parse_statements(); + + assert_eq!(res, Err(ParserError::RecursionLimitExceeded)); +} + +#[test] +fn parse_with_recursion_limit() { + let dialect = GenericDialect {}; + + let where_clause = make_where_clause(20); + let sql = format!("SELECT id, user_id FROM test WHERE {where_clause}"); + + // Expect the statement to parse with default limit + let res = Parser::new(&dialect) + .try_with_sql(&sql) + .expect("tokenize to work") + .parse_statements(); + + 
assert!(matches!(res, Ok(_)), "{:?}", res); + + // limit recursion to something smaller, expect parsing to fail + let res = Parser::new(&dialect) + .try_with_sql(&sql) + .expect("tokenize to work") + .with_recursion_limit(20) + .parse_statements(); + + assert_eq!(res, Err(ParserError::RecursionLimitExceeded)); + + // limit recursion to 50, expect it to succeed + let res = Parser::new(&dialect) + .try_with_sql(&sql) + .expect("tokenize to work") + .with_recursion_limit(50) + .parse_statements(); + + assert!(matches!(res, Ok(_)), "{:?}", res); +} + +/// Makes a predicate that looks like ((user_id = $id) OR user_id = $2...) +fn make_where_clause(num: usize) -> String { + use std::fmt::Write; + let mut output = "(".repeat(num - 1); + + for i in 0..num { + if i > 0 { + write!(&mut output, " OR ").unwrap(); + } + write!(&mut output, "user_id = {}", i).unwrap(); + if i < num - 1 { + write!(&mut output, ")").unwrap(); + } + } + output +}