Avoid stack overflows via configurable with_recursion_limit #764

Merged
merged 1 commit into from Dec 28, 2022
8 changes: 6 additions & 2 deletions src/lib.rs
@@ -12,11 +12,15 @@

//! SQL Parser for Rust
//!
//! Example code:
//!
//! This crate provides an ANSI:SQL 2011 lexer and parser that can parse SQL
//! into an Abstract Syntax Tree (AST).
//!
//! See [`Parser::parse_sql`](crate::parser::Parser::parse_sql) and
//! [`Parser::new`](crate::parser::Parser::new) for the Parsing API
//! and the [`ast`](crate::ast) crate for the AST structure.
//!
//! Example:
//!
//! ```
//! use sqlparser::dialect::GenericDialect;
//! use sqlparser::parser::Parser;
229 changes: 205 additions & 24 deletions src/parser.rs
@@ -37,6 +37,7 @@ use crate::tokenizer::*;
pub enum ParserError {
TokenizerError(String),
ParserError(String),
RecursionLimitExceeded,
}

// Use `Parser::expected` instead, if possible
@@ -55,6 +56,92 @@ macro_rules! return_ok_if_some {
}};
}

#[cfg(feature = "std")]
/// Implementation of [`RecursionCounter`] if std is available
Review comment (Collaborator Author):
This is the key change for actually limiting recursion depth

mod recursion {
use core::sync::atomic::{AtomicUsize, Ordering};
use std::rc::Rc;

use super::ParserError;

/// Tracks remaining recursion depth. This value is decremented on
/// each call to `try_decrease()`, when it reaches 0 an error will
/// be returned.
///
/// Note: Uses an Rc and AtomicUsize in order to satisfy the Rust
Review comment (Collaborator Author):
The Rc approach is from @46bit -- if anyone has ideas about how to avoid it, I would love a PR to help.

/// borrow checker so the automatic [`DepthGuard`] can hold a
/// reference to the counter. The actual value is not modified
/// concurrently.
pub(crate) struct RecursionCounter {
remaining_depth: Rc<AtomicUsize>,
}

impl RecursionCounter {
/// Creates a [`RecursionCounter`] with the specified maximum
/// depth
pub fn new(remaining_depth: usize) -> Self {
Self {
remaining_depth: Rc::new(remaining_depth.into()),
}
}

/// Decreases the remaining depth by 1.
///
/// Returns `Err` if the remaining depth is already 0.
///
/// Otherwise returns a [`DepthGuard`] that adds 1 back to the
/// remaining depth when dropped.
pub fn try_decrease(&self) -> Result<DepthGuard, ParserError> {
let old_value = self.remaining_depth.fetch_sub(1, Ordering::SeqCst);
// ran out of space
if old_value == 0 {
Err(ParserError::RecursionLimitExceeded)
} else {
Ok(DepthGuard::new(Rc::clone(&self.remaining_depth)))
}
}
}

/// Guard that increases the remaining depth by 1 on drop
Review comment (Collaborator Author):
The DepthGuard is used to automatically ensure the recursion depth is restored upon function return

pub struct DepthGuard {
remaining_depth: Rc<AtomicUsize>,
}

impl DepthGuard {
fn new(remaining_depth: Rc<AtomicUsize>) -> Self {
Self { remaining_depth }
}
}
impl Drop for DepthGuard {
fn drop(&mut self) {
self.remaining_depth.fetch_add(1, Ordering::SeqCst);
}
}
}
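
To make the counter and guard interaction concrete, here is a minimal, self-contained sketch of the same pattern applied to a toy recursive-descent parser. It is not the crate's code: `ToyParser`, `ParseError`, and `enter` are hypothetical names, and it assumes a single-threaded `Rc<Cell<usize>>` in place of the `Rc<AtomicUsize>` used in this PR. It also uses `checked_sub`, so a failed attempt leaves the counter unchanged (an atomic `fetch_sub` wraps on underflow instead).

```rust
use std::cell::Cell;
use std::rc::Rc;

/// Hypothetical error type standing in for `ParserError`.
#[derive(Debug, PartialEq)]
enum ParseError {
    RecursionLimitExceeded,
    Unexpected,
}

/// Gives one unit of depth back when dropped.
struct DepthGuard(Rc<Cell<usize>>);

impl Drop for DepthGuard {
    fn drop(&mut self) {
        self.0.set(self.0.get() + 1);
    }
}

/// Toy parser for the grammar: expr := 'x' | '(' expr ')'
struct ToyParser {
    chars: Vec<char>,
    index: usize,
    remaining_depth: Rc<Cell<usize>>,
}

impl ToyParser {
    fn new(input: &str, limit: usize) -> Self {
        Self {
            chars: input.chars().collect(),
            index: 0,
            remaining_depth: Rc::new(Cell::new(limit)),
        }
    }

    /// Takes one unit of depth, or fails without touching the counter.
    /// The guard owns a clone of the `Rc`, so it does not borrow `self`.
    fn enter(&self) -> Result<DepthGuard, ParseError> {
        let next = self
            .remaining_depth
            .get()
            .checked_sub(1)
            .ok_or(ParseError::RecursionLimitExceeded)?;
        self.remaining_depth.set(next);
        Ok(DepthGuard(Rc::clone(&self.remaining_depth)))
    }

    /// Returns how many pairs of parentheses wrap the innermost `x`.
    fn parse_expr(&mut self) -> Result<usize, ParseError> {
        // Dropped on every return path, including `?` early returns.
        let _guard = self.enter()?;
        match self.chars.get(self.index).copied() {
            Some('x') => {
                self.index += 1;
                Ok(0)
            }
            Some('(') => {
                self.index += 1;
                let inner = self.parse_expr()?;
                if self.chars.get(self.index) == Some(&')') {
                    self.index += 1;
                    Ok(inner + 1)
                } else {
                    Err(ParseError::Unexpected)
                }
            }
            _ => Err(ParseError::Unexpected),
        }
    }
}
```

Because the guard holds its own `Rc` handle rather than a reference into the parser, `parse_expr` can keep mutating `self` while the guard is alive, which appears to be the borrow-checker constraint the note above describes.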

#[cfg(not(feature = "std"))]
mod recursion {
/// Implementation of [`RecursionCounter`] if std is NOT available (and does not
/// guard against stack overflow).
///
/// Has the same API as the std RecursionCounter implementation
/// but does not actually limit stack depth.
pub(crate) struct RecursionCounter {}

impl RecursionCounter {
pub fn new(_remaining_depth: usize) -> Self {
Self {}
}
pub fn try_decrease(&self) -> Result<DepthGuard, super::ParserError> {
Ok(DepthGuard {})
}
}

pub struct DepthGuard {}
}

use recursion::RecursionCounter;

#[derive(PartialEq, Eq)]
pub enum IsOptional {
Optional,
@@ -96,6 +183,7 @@ impl fmt::Display for ParserError {
match self {
ParserError::TokenizerError(s) => s,
ParserError::ParserError(s) => s,
ParserError::RecursionLimitExceeded => "recursion limit exceeded",
}
)
}
@@ -104,22 +192,78 @@ impl fmt::Display for ParserError {
#[cfg(feature = "std")]
impl std::error::Error for ParserError {}

// By default, allow expressions up to this deep before erroring
const DEFAULT_REMAINING_DEPTH: usize = 50;

pub struct Parser<'a> {
tokens: Vec<TokenWithLocation>,
/// The index of the first unprocessed token in `self.tokens`
index: usize,
/// The current dialect to use
dialect: &'a dyn Dialect,
/// Ensures the stack does not overflow by limiting recursion depth
recursion_counter: RecursionCounter,
}

impl<'a> Parser<'a> {
/// Parse the specified tokens
/// To avoid breaking backwards compatibility, this function accepts
/// bare tokens.
pub fn new(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
Parser::new_without_locations(tokens, dialect)
/// Create a parser for a [`Dialect`]
///
/// See also [`Parser::parse_sql`]
///
/// Example:
/// ```
/// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
/// # fn main() -> Result<(), ParserError> {
/// let dialect = GenericDialect{};
/// let statements = Parser::new(&dialect)
/// .try_with_sql("SELECT * FROM foo")?
/// .parse_statements()?;
/// # Ok(())
/// # }
/// ```
pub fn new(dialect: &'a dyn Dialect) -> Self {
Review comment (Collaborator Author):
The signature here has changed; I tried to illustrate the intended usage with doc comments.

Self {
tokens: vec![],
index: 0,
dialect,
recursion_counter: RecursionCounter::new(DEFAULT_REMAINING_DEPTH),
}
}

/// Specify the maximum recursion limit while parsing.
///
///
/// [`Parser`] prevents stack overflows by returning
/// [`ParserError::RecursionLimitExceeded`] if the parser exceeds
/// this depth while processing the query.
///
/// Example:
/// ```
/// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
/// # fn main() -> Result<(), ParserError> {
/// let dialect = GenericDialect{};
/// let result = Parser::new(&dialect)
/// .with_recursion_limit(1)
/// .try_with_sql("SELECT * FROM foo WHERE (a OR (b OR (c OR d)))")?
/// .parse_statements();
/// assert_eq!(result, Err(ParserError::RecursionLimitExceeded));
/// # Ok(())
/// # }
/// ```
pub fn with_recursion_limit(mut self, recursion_limit: usize) -> Self {
self.recursion_counter = RecursionCounter::new(recursion_limit);
self
}
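
The new API is a consuming builder: each `with_*` method takes `self` by value and returns it, so settings applied early in a chain persist through later steps. A minimal sketch of that shape follows, assuming a hypothetical `ToyParser` with whitespace "tokenization"; none of these names or behaviors come from the crate.

```rust
/// Illustrative default, mirroring the value of `DEFAULT_REMAINING_DEPTH`.
const DEFAULT_LIMIT: usize = 50;

/// Hypothetical stand-in for `Parser`; fields are illustrative only.
#[derive(Debug)]
struct ToyParser {
    tokens: Vec<String>,
    index: usize,
    recursion_limit: usize,
}

impl ToyParser {
    /// Starts with no tokens and the default limit.
    fn new() -> Self {
        Self {
            tokens: Vec::new(),
            index: 0,
            recursion_limit: DEFAULT_LIMIT,
        }
    }

    /// Overrides the limit and hands the parser back for chaining.
    fn with_recursion_limit(mut self, limit: usize) -> Self {
        self.recursion_limit = limit;
        self
    }

    /// Replaces the token stream and rewinds; the limit is left untouched.
    fn with_tokens(mut self, tokens: Vec<String>) -> Self {
        self.tokens = tokens;
        self.index = 0;
        self
    }

    /// Fallible step: "tokenizes" by whitespace (an assumption for this
    /// sketch) and rejects blank input.
    fn try_with_text(self, text: &str) -> Result<Self, String> {
        if text.trim().is_empty() {
            return Err("empty input".to_string());
        }
        let tokens = text.split_whitespace().map(str::to_string).collect();
        Ok(self.with_tokens(tokens))
    }
}
```

A chain such as `ToyParser::new().with_recursion_limit(1).try_with_text("SELECT 1")` yields a parser whose limit is still 1 after the tokens are installed.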

/// Reset this parser to parse the specified token stream
pub fn with_tokens_with_locations(mut self, tokens: Vec<TokenWithLocation>) -> Self {
self.tokens = tokens;
self.index = 0;
self
}

pub fn new_without_locations(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
/// Reset this parser state to parse the specified tokens
pub fn with_tokens(self, tokens: Vec<Token>) -> Self {
// Put in dummy locations
let tokens_with_locations: Vec<TokenWithLocation> = tokens
.into_iter()
@@ -128,49 +272,84 @@ impl<'a> Parser<'a> {
location: Location { line: 0, column: 0 },
})
.collect();
Parser::new_with_locations(tokens_with_locations, dialect)
self.with_tokens_with_locations(tokens_with_locations)
}

/// Parse the specified tokens
pub fn new_with_locations(tokens: Vec<TokenWithLocation>, dialect: &'a dyn Dialect) -> Self {
Parser {
tokens,
index: 0,
dialect,
}
/// Tokenizes the SQL string and sets this [`Parser`]'s state to
/// parse the resulting tokens.
///
/// Returns an error if the SQL string could not be tokenized.
///
/// See [`Parser::new()`] for an example.
pub fn try_with_sql(self, sql: &str) -> Result<Self, ParserError> {
debug!("Parsing sql '{}'...", sql);
let mut tokenizer = Tokenizer::new(self.dialect, sql);
let tokens = tokenizer.tokenize()?;
Ok(self.with_tokens(tokens))
}

/// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
Review comment (Collaborator Author):
GitHub mangles the diff: this function still exists with the same signature. It now also has docstrings.

let mut tokenizer = Tokenizer::new(dialect, sql);
let tokens = tokenizer.tokenize()?;
let mut parser = Parser::new(tokens, dialect);
/// Parse potentially multiple statements
///
/// Example
/// ```
/// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
/// # fn main() -> Result<(), ParserError> {
/// let dialect = GenericDialect{};
/// let statements = Parser::new(&dialect)
/// // Parse a SQL string with 2 separate statements
/// .try_with_sql("SELECT * FROM foo; SELECT * FROM bar;")?
/// .parse_statements()?;
/// assert_eq!(statements.len(), 2);
/// # Ok(())
/// # }
/// ```
pub fn parse_statements(&mut self) -> Result<Vec<Statement>, ParserError> {
Review comment (Collaborator Author):
This code was factored out of parse_sql so that people directly using Parser::new() could also have access to that logic

let mut stmts = Vec::new();
let mut expecting_statement_delimiter = false;
debug!("Parsing sql '{}'...", sql);
loop {
// ignore empty statements (between successive statement delimiters)
while parser.consume_token(&Token::SemiColon) {
while self.consume_token(&Token::SemiColon) {
expecting_statement_delimiter = false;
}

if parser.peek_token() == Token::EOF {
if self.peek_token() == Token::EOF {
break;
}
if expecting_statement_delimiter {
return parser.expected("end of statement", parser.peek_token());
return self.expected("end of statement", self.peek_token());
}

let statement = parser.parse_statement()?;
let statement = self.parse_statement()?;
stmts.push(statement);
expecting_statement_delimiter = true;
}
Ok(stmts)
}
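
The delimiter handling in this loop can be sketched in isolation. The version below is a simplified illustration, not the crate's code: it assumes a token stream of plain strings in which `";"` is the delimiter and every other token is a complete one-token statement.

```rust
/// Splits a token stream into statements using the same three rules as the
/// loop above: skip runs of `;`, stop at end of input, and reject two
/// statements that are not separated by a `;`.
fn parse_statements(tokens: &[&str]) -> Result<Vec<String>, String> {
    let mut stmts = Vec::new();
    let mut index = 0;
    let mut expecting_statement_delimiter = false;
    loop {
        // Ignore empty statements (between successive delimiters).
        while index < tokens.len() && tokens[index] == ";" {
            index += 1;
            expecting_statement_delimiter = false;
        }
        // End of input.
        if index == tokens.len() {
            break;
        }
        // A statement was just parsed and no `;` followed it.
        if expecting_statement_delimiter {
            return Err(format!(
                "expected end of statement, found {}",
                tokens[index]
            ));
        }
        // Assumption: one token is one whole statement.
        stmts.push(tokens[index].to_string());
        index += 1;
        expecting_statement_delimiter = true;
    }
    Ok(stmts)
}
```

With this sketch, `["a", ";", ";", "b", ";"]` produces two statements, while `["a", "b"]` is rejected because no delimiter separates them.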

/// Convenience method to parse a string with one or more SQL
/// statements and produce an Abstract Syntax Tree (AST).
///
/// Example
/// ```
/// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
/// # fn main() -> Result<(), ParserError> {
/// let dialect = GenericDialect{};
/// let statements = Parser::parse_sql(
/// &dialect, "SELECT * FROM foo"
/// )?;
/// assert_eq!(statements.len(), 1);
/// # Ok(())
/// # }
/// ```
pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
Parser::new(dialect).try_with_sql(sql)?.parse_statements()
}

/// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.),
/// stopping before the statement separator, if any.
pub fn parse_statement(&mut self) -> Result<Statement, ParserError> {
let _guard = self.recursion_counter.try_decrease()?;

// allow the dialect to override statement parsing
if let Some(statement) = self.dialect.parse_statement(self) {
return statement;
@@ -364,6 +543,7 @@ impl<'a> Parser<'a> {

/// Parse a new expression
pub fn parse_expr(&mut self) -> Result<Expr, ParserError> {
let _guard = self.recursion_counter.try_decrease()?;
self.parse_subexpr(0)
}

@@ -4454,6 +4634,7 @@ impl<'a> Parser<'a> {
/// by `ORDER BY`. Unlike some other parse_... methods, this one doesn't
/// expect the initial keyword to be already consumed
pub fn parse_query(&mut self) -> Result<Query, ParserError> {
let _guard = self.recursion_counter.try_decrease()?;
let with = if self.parse_keyword(Keyword::WITH) {
Some(With {
recursive: self.parse_keyword(Keyword::RECURSIVE),
6 changes: 2 additions & 4 deletions src/test_utils.rs
@@ -29,7 +29,6 @@ use core::fmt::Debug;
use crate::ast::*;
use crate::dialect::*;
use crate::parser::{Parser, ParserError};
use crate::tokenizer::Tokenizer;

/// Tests use the methods on this struct to invoke the parser on one or
/// multiple dialects.
@@ -65,9 +64,8 @@ impl TestedDialects {
F: Fn(&mut Parser) -> T,
{
self.one_of_identical_results(|dialect| {
let mut tokenizer = Tokenizer::new(dialect, sql);
let tokens = tokenizer.tokenize().unwrap();
f(&mut Parser::new(tokens, dialect))
let mut parser = Parser::new(dialect).try_with_sql(sql).unwrap();
f(&mut parser)
})
}
