Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX parse pre-dispatch with AST instead of calling eval #1327

Merged
merged 6 commits into from Sep 12, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGES.rst
Expand Up @@ -30,7 +30,7 @@ Development version

- Fix a security issue where ``eval(pre_dispatch)`` could potentially run
arbitrary code. Now only basic numerics are supported.
https://github.com/joblib/joblib/pull/1321
https://github.com/joblib/joblib/pull/1327

- Vendor cloudpickle 2.2.0 which adds support for PyPy 3.8+.

Expand Down
44 changes: 44 additions & 0 deletions joblib/_utils.py
@@ -0,0 +1,44 @@
# Adapted from https://stackoverflow.com/a/9558001/2536294

import ast
import operator as op

# supported operators
operators = {
ast.Add: op.add,
ast.Sub: op.sub,
ast.Mult: op.mul,
ast.Div: op.truediv,
ogrisel marked this conversation as resolved.
Show resolved Hide resolved
ast.FloorDiv: op.floordiv,
ast.Mod: op.mod,
ast.Pow: op.pow,
ast.USub: op.neg,
}


def eval_expr(expr):
"""
>>> eval_expr('2*6')
12
>>> eval_expr('2**6')
64
>>> eval_expr('1 + 2*3**(4) / (6 + -7)')
-161.0
"""
try:
return eval_(ast.parse(expr, mode="eval").body)
except (TypeError, SyntaxError, KeyError) as e:
raise ValueError(
f"{expr!r} is not a valid or supported arithmetic expression."
) from e


def eval_(node):
if isinstance(node, ast.Num): # <number>
return node.n
elif isinstance(node, ast.BinOp): # <left> <operator> <right>
return operators[type(node.op)](eval_(node.left), eval_(node.right))
elif isinstance(node, ast.UnaryOp): # <operator> <operand> e.g., -1
return operators[type(node.op)](eval_(node.operand))
else:
raise TypeError(node)
7 changes: 3 additions & 4 deletions joblib/parallel.py
Expand Up @@ -27,6 +27,7 @@
ThreadingBackend, SequentialBackend,
LokyBackend)
from .externals.cloudpickle import dumps, loads
from ._utils import eval_expr

# Make sure that those two classes are part of the public joblib.parallel API
# so that 3rd party backend implementers can import them from here.
Expand Down Expand Up @@ -1051,10 +1052,8 @@ def _batched_calls_reducer_callback():
else:
self._original_iterator = iterator
if hasattr(pre_dispatch, 'endswith'):
pre_dispatch = eval(
pre_dispatch,
{"n_jobs": n_jobs, "__builtins__": {}}, # globals
{} # locals
pre_dispatch = eval_expr(
pre_dispatch.replace("n_jobs", str(n_jobs))
)
self._pre_dispatch_amount = pre_dispatch = int(pre_dispatch)

Expand Down
25 changes: 25 additions & 0 deletions joblib/test/test_utils.py
@@ -0,0 +1,25 @@
import pytest

from joblib._utils import eval_expr


@pytest.mark.parametrize(
"expr",
["exec('import os')", "print(1)", "import os", "1+1; import os", "1^1"],
)
def test_eval_expr_invalid(expr):
with pytest.raises(ValueError, match="is not a valid or supported arithmetic"):
eval_expr(expr)


@pytest.mark.parametrize(
"expr, result",
[
("2*6", 12),
("2**6", 64),
("1 + 2*3**(4) / (6 + -7)", -161.0),
("(20 // 3) % 5", 1),
],
)
def test_eval_expr_valid(expr, result):
assert eval_expr(expr) == result