Skip to content

Commit

Permalink
FIX parse pre-dispatch with AST instead of calling eval (joblib#1327)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrinjalali authored and jeremiedbb committed Oct 7, 2022
1 parent 01b1ed4 commit dae0b93
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 4 deletions.
6 changes: 6 additions & 0 deletions CHANGES.rst
Expand Up @@ -4,6 +4,12 @@ Latest changes
Development version
-------------------

Release 1.1.1

- Fix a security issue where ``eval(pre_dispatch)`` could potentially run
arbitrary code. Now only basic numerics are supported.
https://github.com/joblib/joblib/pull/1327

Release 1.1.0
--------------

Expand Down
44 changes: 44 additions & 0 deletions joblib/_utils.py
@@ -0,0 +1,44 @@
# Adapted from https://stackoverflow.com/a/9558001/2536294

import ast
import operator as op

# supported operators
operators = {
ast.Add: op.add,
ast.Sub: op.sub,
ast.Mult: op.mul,
ast.Div: op.truediv,
ast.FloorDiv: op.floordiv,
ast.Mod: op.mod,
ast.Pow: op.pow,
ast.USub: op.neg,
}


def eval_expr(expr):
"""
>>> eval_expr('2*6')
12
>>> eval_expr('2**6')
64
>>> eval_expr('1 + 2*3**(4) / (6 + -7)')
-161.0
"""
try:
return eval_(ast.parse(expr, mode="eval").body)
except (TypeError, SyntaxError, KeyError) as e:
raise ValueError(
f"{expr!r} is not a valid or supported arithmetic expression."
) from e


def eval_(node):
if isinstance(node, ast.Num): # <number>
return node.n
elif isinstance(node, ast.BinOp): # <left> <operator> <right>
return operators[type(node.op)](eval_(node.left), eval_(node.right))
elif isinstance(node, ast.UnaryOp): # <operator> <operand> e.g., -1
return operators[type(node.op)](eval_(node.operand))
else:
raise TypeError(node)
7 changes: 3 additions & 4 deletions joblib/parallel.py
Expand Up @@ -28,6 +28,7 @@
LokyBackend)
from .externals.cloudpickle import dumps, loads
from .externals import loky
from ._utils import eval_expr

# Make sure that those two classes are part of the public joblib.parallel API
# so that 3rd party backend implementers can import them from here.
Expand Down Expand Up @@ -1014,10 +1015,8 @@ def _batched_calls_reducer_callback():
else:
self._original_iterator = iterator
if hasattr(pre_dispatch, 'endswith'):
pre_dispatch = eval(
pre_dispatch,
{"n_jobs": n_jobs, "__builtins__": {}}, # globals
{} # locals
pre_dispatch = eval_expr(
pre_dispatch.replace("n_jobs", str(n_jobs))
)
self._pre_dispatch_amount = pre_dispatch = int(pre_dispatch)

Expand Down
27 changes: 27 additions & 0 deletions joblib/test/test_utils.py
@@ -0,0 +1,27 @@
import pytest

from joblib._utils import eval_expr


@pytest.mark.parametrize(
"expr",
["exec('import os')", "print(1)", "import os", "1+1; import os", "1^1"],
)
def test_eval_expr_invalid(expr):
with pytest.raises(
ValueError, match="is not a valid or supported arithmetic"
):
eval_expr(expr)


@pytest.mark.parametrize(
"expr, result",
[
("2*6", 12),
("2**6", 64),
("1 + 2*3**(4) / (6 + -7)", -161.0),
("(20 // 3) % 5", 1),
],
)
def test_eval_expr_valid(expr, result):
assert eval_expr(expr) == result

0 comments on commit dae0b93

Please sign in to comment.