Skip to content

Commit

Permalink
FIX parse pre-dispatch with AST instead of calling eval (#1327)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrinjalali committed Sep 12, 2022
1 parent 1f00a1c commit 54f4d21
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CHANGES.rst
Expand Up @@ -30,7 +30,7 @@ Development version

- Fix a security issue where ``eval(pre_dispatch)`` could potentially run
arbitrary code. Now only basic numerics are supported.
https://github.com/joblib/joblib/pull/1321
https://github.com/joblib/joblib/pull/1327

- Vendor cloudpickle 2.2.0 which adds support for PyPy 3.8+.

Expand Down
44 changes: 44 additions & 0 deletions joblib/_utils.py
@@ -0,0 +1,44 @@
# Adapted from https://stackoverflow.com/a/9558001/2536294

import ast
import operator as op

# supported operators
operators = {
ast.Add: op.add,
ast.Sub: op.sub,
ast.Mult: op.mul,
ast.Div: op.truediv,
ast.FloorDiv: op.floordiv,
ast.Mod: op.mod,
ast.Pow: op.pow,
ast.USub: op.neg,
}


def eval_expr(expr):
"""
>>> eval_expr('2*6')
12
>>> eval_expr('2**6')
64
>>> eval_expr('1 + 2*3**(4) / (6 + -7)')
-161.0
"""
try:
return eval_(ast.parse(expr, mode="eval").body)
except (TypeError, SyntaxError, KeyError) as e:
raise ValueError(
f"{expr!r} is not a valid or supported arithmetic expression."
) from e


def eval_(node):
if isinstance(node, ast.Num): # <number>
return node.n
elif isinstance(node, ast.BinOp): # <left> <operator> <right>
return operators[type(node.op)](eval_(node.left), eval_(node.right))
elif isinstance(node, ast.UnaryOp): # <operator> <operand> e.g., -1
return operators[type(node.op)](eval_(node.operand))
else:
raise TypeError(node)
7 changes: 3 additions & 4 deletions joblib/parallel.py
Expand Up @@ -27,6 +27,7 @@
ThreadingBackend, SequentialBackend,
LokyBackend)
from .externals.cloudpickle import dumps, loads
from ._utils import eval_expr

# Make sure that those two classes are part of the public joblib.parallel API
# so that 3rd party backend implementers can import them from here.
Expand Down Expand Up @@ -1051,10 +1052,8 @@ def _batched_calls_reducer_callback():
else:
self._original_iterator = iterator
if hasattr(pre_dispatch, 'endswith'):
pre_dispatch = eval(
pre_dispatch,
{"n_jobs": n_jobs, "__builtins__": {}}, # globals
{} # locals
pre_dispatch = eval_expr(
pre_dispatch.replace("n_jobs", str(n_jobs))
)
self._pre_dispatch_amount = pre_dispatch = int(pre_dispatch)

Expand Down
27 changes: 27 additions & 0 deletions joblib/test/test_utils.py
@@ -0,0 +1,27 @@
import pytest

from joblib._utils import eval_expr


@pytest.mark.parametrize(
"expr",
["exec('import os')", "print(1)", "import os", "1+1; import os", "1^1"],
)
def test_eval_expr_invalid(expr):
with pytest.raises(
ValueError, match="is not a valid or supported arithmetic"
):
eval_expr(expr)


@pytest.mark.parametrize(
"expr, result",
[
("2*6", 12),
("2**6", 64),
("1 + 2*3**(4) / (6 + -7)", -161.0),
("(20 // 3) % 5", 1),
],
)
def test_eval_expr_valid(expr, result):
assert eval_expr(expr) == result

0 comments on commit 54f4d21

Please sign in to comment.