Add configurable recursion limit to parser, to protect against stack overflows (#764)
alamb committed Dec 28, 2022
1 parent 2c20ec0 commit 79d0baa
Showing 4 changed files with 301 additions and 30 deletions.
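For reference, the new public API added here can be exercised roughly as in the sketch below, adapted from the doc examples in the diff; the SQL strings and the recursion limit of 1 are illustrative.

```rust
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::{Parser, ParserError};

fn main() -> Result<(), ParserError> {
    let dialect = GenericDialect {};

    // New builder-style API: create a parser for a dialect, feed it SQL,
    // then parse one or more statements.
    let statements = Parser::new(&dialect)
        .try_with_sql("SELECT * FROM foo; SELECT * FROM bar;")?
        .parse_statements()?;
    assert_eq!(statements.len(), 2);

    // The recursion limit is now configurable: deeply nested expressions
    // return an error instead of overflowing the stack.
    let result = Parser::new(&dialect)
        .with_recursion_limit(1)
        .try_with_sql("SELECT * FROM foo WHERE (a OR (b OR (c OR d)))")?
        .parse_statements();
    assert_eq!(result, Err(ParserError::RecursionLimitExceeded));

    Ok(())
}
```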
8 changes: 6 additions & 2 deletions src/lib.rs
@@ -12,11 +12,15 @@

//! SQL Parser for Rust
//!
//! Example code:
//!
//! This crate provides an ANSI SQL:2011 lexer and parser that can parse SQL
//! into an Abstract Syntax Tree (AST).
//!
//! See [`Parser::parse_sql`](crate::parser::Parser::parse_sql) and
//! [`Parser::new`](crate::parser::Parser::new) for the Parsing API
//! and the [`ast`](crate::ast) module for the AST structure.
//!
//! Example:
//!
//! ```
//! use sqlparser::dialect::GenericDialect;
//! use sqlparser::parser::Parser;
229 changes: 205 additions & 24 deletions src/parser.rs
@@ -37,6 +37,7 @@ use crate::tokenizer::*;
pub enum ParserError {
TokenizerError(String),
ParserError(String),
RecursionLimitExceeded,
}

// Use `Parser::expected` instead, if possible
@@ -55,6 +56,92 @@ macro_rules! return_ok_if_some {
}};
}

#[cfg(feature = "std")]
/// Implementation of [`RecursionCounter`] when std is available
mod recursion {
use core::sync::atomic::{AtomicUsize, Ordering};
use std::rc::Rc;

use super::ParserError;

/// Tracks remaining recursion depth. This value is decremented on
/// each call to `try_decrease()`; when it reaches 0, an error is
/// returned.
///
/// Note: Uses an Rc and AtomicUsize so that the [`DepthGuard`],
/// which automatically restores the depth on drop, can hold a
/// reference to the counter while satisfying the Rust borrow
/// checker. The actual value is never modified concurrently.
pub(crate) struct RecursionCounter {
remaining_depth: Rc<AtomicUsize>,
}

impl RecursionCounter {
/// Creates a [`RecursionCounter`] with the specified maximum
/// depth
pub fn new(remaining_depth: usize) -> Self {
Self {
remaining_depth: Rc::new(remaining_depth.into()),
}
}

/// Decreases the remaining depth by 1.
///
/// Returns `Err` if the remaining depth falls to 0.
///
/// Returns a [`DepthGuard`] which adds 1 back to the
/// remaining depth when it is dropped.
pub fn try_decrease(&self) -> Result<DepthGuard, ParserError> {
let old_value = self.remaining_depth.fetch_sub(1, Ordering::SeqCst);
// ran out of space
if old_value == 0 {
Err(ParserError::RecursionLimitExceeded)
} else {
Ok(DepthGuard::new(Rc::clone(&self.remaining_depth)))
}
}
}

/// Guard that increases the remaining depth by 1 on drop
pub struct DepthGuard {
remaining_depth: Rc<AtomicUsize>,
}

impl DepthGuard {
fn new(remaining_depth: Rc<AtomicUsize>) -> Self {
Self { remaining_depth }
}
}
impl Drop for DepthGuard {
fn drop(&mut self) {
self.remaining_depth.fetch_add(1, Ordering::SeqCst);
}
}
}

#[cfg(not(feature = "std"))]
mod recursion {
/// Implementation of [`RecursionCounter`] when std is NOT available
/// (and which does not guard against stack overflow).
///
/// Has the same API as the std RecursionCounter implementation
/// but does not actually limit stack depth.
pub(crate) struct RecursionCounter {}

impl RecursionCounter {
pub fn new(_remaining_depth: usize) -> Self {
Self {}
}
pub fn try_decrease(&self) -> Result<DepthGuard, super::ParserError> {
Ok(DepthGuard {})
}
}

pub struct DepthGuard {}
}

use recursion::RecursionCounter;

#[derive(PartialEq, Eq)]
pub enum IsOptional {
Optional,
@@ -96,6 +183,7 @@ impl fmt::Display for ParserError {
match self {
ParserError::TokenizerError(s) => s,
ParserError::ParserError(s) => s,
ParserError::RecursionLimitExceeded => "recursion limit exceeded",
}
)
}
@@ -104,22 +192,78 @@ impl fmt::Display for ParserError {
#[cfg(feature = "std")]
impl std::error::Error for ParserError {}

// By default, allow expressions up to this depth before erroring
const DEFAULT_REMAINING_DEPTH: usize = 50;

pub struct Parser<'a> {
tokens: Vec<TokenWithLocation>,
/// The index of the first unprocessed token in `self.tokens`
index: usize,
/// The current dialect to use
dialect: &'a dyn Dialect,
/// ensure the stack does not overflow by limiting recursion depth
recursion_counter: RecursionCounter,
}

impl<'a> Parser<'a> {
/// Parse the specified tokens
/// To avoid breaking backwards compatibility, this function accepts
/// bare tokens.
pub fn new(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
Parser::new_without_locations(tokens, dialect)
/// Create a parser for a [`Dialect`]
///
/// See also [`Parser::parse_sql`]
///
/// Example:
/// ```
/// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
/// # fn main() -> Result<(), ParserError> {
/// let dialect = GenericDialect{};
/// let statements = Parser::new(&dialect)
/// .try_with_sql("SELECT * FROM foo")?
/// .parse_statements()?;
/// # Ok(())
/// # }
/// ```
pub fn new(dialect: &'a dyn Dialect) -> Self {
Self {
tokens: vec![],
index: 0,
dialect,
recursion_counter: RecursionCounter::new(DEFAULT_REMAINING_DEPTH),
}
}

/// Specify the maximum recursion limit while parsing.
///
/// [`Parser`] prevents stack overflows by returning
/// [`ParserError::RecursionLimitExceeded`] if the parser exceeds
/// this depth while processing the query.
///
/// Example:
/// ```
/// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
/// # fn main() -> Result<(), ParserError> {
/// let dialect = GenericDialect{};
/// let result = Parser::new(&dialect)
/// .with_recursion_limit(1)
/// .try_with_sql("SELECT * FROM foo WHERE (a OR (b OR (c OR d)))")?
/// .parse_statements();
/// assert_eq!(result, Err(ParserError::RecursionLimitExceeded));
/// # Ok(())
/// # }
/// ```
pub fn with_recursion_limit(mut self, recursion_limit: usize) -> Self {
self.recursion_counter = RecursionCounter::new(recursion_limit);
self
}

/// Reset this parser to parse the specified token stream
pub fn with_tokens_with_locations(mut self, tokens: Vec<TokenWithLocation>) -> Self {
self.tokens = tokens;
self.index = 0;
self
}

pub fn new_without_locations(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
/// Reset this parser state to parse the specified tokens
pub fn with_tokens(self, tokens: Vec<Token>) -> Self {
// Put in dummy locations
let tokens_with_locations: Vec<TokenWithLocation> = tokens
.into_iter()
Expand All @@ -128,49 +272,84 @@ impl<'a> Parser<'a> {
location: Location { line: 0, column: 0 },
})
.collect();
Parser::new_with_locations(tokens_with_locations, dialect)
self.with_tokens_with_locations(tokens_with_locations)
}

/// Parse the specified tokens
pub fn new_with_locations(tokens: Vec<TokenWithLocation>, dialect: &'a dyn Dialect) -> Self {
Parser {
tokens,
index: 0,
dialect,
}
/// Tokenizes the SQL string and sets this [`Parser`]'s state to
/// parse the resulting tokens
///
/// Returns an error if there was an error tokenizing the SQL string.
///
/// See example on [`Parser::new()`] for an example
pub fn try_with_sql(self, sql: &str) -> Result<Self, ParserError> {
debug!("Parsing sql '{}'...", sql);
let mut tokenizer = Tokenizer::new(self.dialect, sql);
let tokens = tokenizer.tokenize()?;
Ok(self.with_tokens(tokens))
}

/// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
let mut tokenizer = Tokenizer::new(dialect, sql);
let tokens = tokenizer.tokenize()?;
let mut parser = Parser::new(tokens, dialect);
/// Parse potentially multiple statements
///
/// Example
/// ```
/// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
/// # fn main() -> Result<(), ParserError> {
/// let dialect = GenericDialect{};
/// let statements = Parser::new(&dialect)
/// // Parse a SQL string with 2 separate statements
/// .try_with_sql("SELECT * FROM foo; SELECT * FROM bar;")?
/// .parse_statements()?;
/// assert_eq!(statements.len(), 2);
/// # Ok(())
/// # }
/// ```
pub fn parse_statements(&mut self) -> Result<Vec<Statement>, ParserError> {
let mut stmts = Vec::new();
let mut expecting_statement_delimiter = false;
debug!("Parsing sql '{}'...", sql);
loop {
// ignore empty statements (between successive statement delimiters)
while parser.consume_token(&Token::SemiColon) {
while self.consume_token(&Token::SemiColon) {
expecting_statement_delimiter = false;
}

if parser.peek_token() == Token::EOF {
if self.peek_token() == Token::EOF {
break;
}
if expecting_statement_delimiter {
return parser.expected("end of statement", parser.peek_token());
return self.expected("end of statement", self.peek_token());
}

let statement = parser.parse_statement()?;
let statement = self.parse_statement()?;
stmts.push(statement);
expecting_statement_delimiter = true;
}
Ok(stmts)
}

/// Convenience method to parse a string with one or more SQL
/// statements into an Abstract Syntax Tree (AST).
///
/// Example
/// ```
/// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
/// # fn main() -> Result<(), ParserError> {
/// let dialect = GenericDialect{};
/// let statements = Parser::parse_sql(
/// &dialect, "SELECT * FROM foo"
/// )?;
/// assert_eq!(statements.len(), 1);
/// # Ok(())
/// # }
/// ```
pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
Parser::new(dialect).try_with_sql(sql)?.parse_statements()
}

/// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.),
/// stopping before the statement separator, if any.
pub fn parse_statement(&mut self) -> Result<Statement, ParserError> {
let _guard = self.recursion_counter.try_decrease()?;

// allow the dialect to override statement parsing
if let Some(statement) = self.dialect.parse_statement(self) {
return statement;
Expand Down Expand Up @@ -364,6 +543,7 @@ impl<'a> Parser<'a> {

/// Parse a new expression
pub fn parse_expr(&mut self) -> Result<Expr, ParserError> {
let _guard = self.recursion_counter.try_decrease()?;
self.parse_subexpr(0)
}

@@ -4512,6 +4692,7 @@ impl<'a> Parser<'a> {
/// by `ORDER BY`. Unlike some other parse_... methods, this one doesn't
/// expect the initial keyword to be already consumed
pub fn parse_query(&mut self) -> Result<Query, ParserError> {
let _guard = self.recursion_counter.try_decrease()?;
let with = if self.parse_keyword(Keyword::WITH) {
Some(With {
recursive: self.parse_keyword(Keyword::RECURSIVE),
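The following is a minimal, self-contained sketch (not part of the commit) of the RAII pattern that `RecursionCounter` and `DepthGuard` implement above: each recursive `parse_*` call consumes one unit of depth budget via `try_decrease()`, and the returned guard gives it back when that call frame unwinds, so only the current nesting depth counts against the limit.

```rust
use core::sync::atomic::{AtomicUsize, Ordering};
use std::rc::Rc;

struct RecursionCounter {
    remaining_depth: Rc<AtomicUsize>,
}

struct DepthGuard {
    remaining_depth: Rc<AtomicUsize>,
}

impl RecursionCounter {
    fn new(depth: usize) -> Self {
        Self { remaining_depth: Rc::new(AtomicUsize::new(depth)) }
    }

    fn try_decrease(&self) -> Result<DepthGuard, &'static str> {
        // fetch_sub returns the *previous* value; 0 means the budget is spent.
        if self.remaining_depth.fetch_sub(1, Ordering::SeqCst) == 0 {
            Err("recursion limit exceeded")
        } else {
            Ok(DepthGuard { remaining_depth: Rc::clone(&self.remaining_depth) })
        }
    }
}

impl Drop for DepthGuard {
    fn drop(&mut self) {
        // Give the budget back when the recursive call returns.
        self.remaining_depth.fetch_add(1, Ordering::SeqCst);
    }
}

// Stand-in for a recursive parse_* method: the guard lives for the whole
// call frame, so only the current nesting depth counts against the limit.
fn parse_nested(counter: &RecursionCounter, nesting: usize) -> Result<(), &'static str> {
    let _guard = counter.try_decrease()?;
    if nesting == 0 {
        return Ok(());
    }
    parse_nested(counter, nesting - 1)
}

fn main() {
    let counter = RecursionCounter::new(3);
    assert!(parse_nested(&counter, 2).is_ok());  // three frames fit in a budget of 3
    assert!(parse_nested(&counter, 5).is_err()); // the fourth frame exceeds the budget
}
```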
6 changes: 2 additions & 4 deletions src/test_utils.rs
@@ -29,7 +29,6 @@ use core::fmt::Debug;
use crate::ast::*;
use crate::dialect::*;
use crate::parser::{Parser, ParserError};
use crate::tokenizer::Tokenizer;

/// Tests use the methods on this struct to invoke the parser on one or
/// multiple dialects.
@@ -65,9 +64,8 @@ impl TestedDialects {
F: Fn(&mut Parser) -> T,
{
self.one_of_identical_results(|dialect| {
let mut tokenizer = Tokenizer::new(dialect, sql);
let tokens = tokenizer.tokenize().unwrap();
f(&mut Parser::new(tokens, dialect))
let mut parser = Parser::new(dialect).try_with_sql(sql).unwrap();
f(&mut parser)
})
}

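For downstream code, the `test_utils.rs` change above shows the migration pattern: instead of tokenizing by hand and passing bare tokens to `Parser::new`, build the parser from the dialect and hand it the SQL string. A hedged sketch of the same pattern follows; the helper name `with_parser` is illustrative.

```rust
use sqlparser::ast::Statement;
use sqlparser::dialect::{Dialect, GenericDialect};
use sqlparser::parser::Parser;

// Run a parser callback against a SQL string, mirroring run_parser_method.
fn with_parser<F, T>(dialect: &dyn Dialect, sql: &str, f: F) -> T
where
    F: Fn(&mut Parser) -> T,
{
    // Old: Tokenizer::new(dialect, sql).tokenize() then Parser::new(tokens, dialect).
    // New: one chained call; tokenizer errors now surface from try_with_sql.
    let mut parser = Parser::new(dialect).try_with_sql(sql).unwrap();
    f(&mut parser)
}

fn main() {
    let dialect = GenericDialect {};
    let stmt: Statement = with_parser(&dialect, "SELECT 1", |p| p.parse_statement().unwrap());
    println!("{}", stmt);
}
```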
