forked from joblib/joblib
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
FIX parse pre-dispatch with AST instead of calling eval (joblib#1327)
- Loading branch information
1 parent
01b1ed4
commit dae0b93
Showing
4 changed files
with
80 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Adapted from https://stackoverflow.com/a/9558001/2536294 | ||
|
||
import ast | ||
import operator as op | ||
|
||
# supported operators | ||
operators = { | ||
ast.Add: op.add, | ||
ast.Sub: op.sub, | ||
ast.Mult: op.mul, | ||
ast.Div: op.truediv, | ||
ast.FloorDiv: op.floordiv, | ||
ast.Mod: op.mod, | ||
ast.Pow: op.pow, | ||
ast.USub: op.neg, | ||
} | ||
|
||
|
||
def eval_expr(expr): | ||
""" | ||
>>> eval_expr('2*6') | ||
12 | ||
>>> eval_expr('2**6') | ||
64 | ||
>>> eval_expr('1 + 2*3**(4) / (6 + -7)') | ||
-161.0 | ||
""" | ||
try: | ||
return eval_(ast.parse(expr, mode="eval").body) | ||
except (TypeError, SyntaxError, KeyError) as e: | ||
raise ValueError( | ||
f"{expr!r} is not a valid or supported arithmetic expression." | ||
) from e | ||
|
||
|
||
def eval_(node): | ||
if isinstance(node, ast.Num): # <number> | ||
return node.n | ||
elif isinstance(node, ast.BinOp): # <left> <operator> <right> | ||
return operators[type(node.op)](eval_(node.left), eval_(node.right)) | ||
elif isinstance(node, ast.UnaryOp): # <operator> <operand> e.g., -1 | ||
return operators[type(node.op)](eval_(node.operand)) | ||
else: | ||
raise TypeError(node) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import pytest | ||
|
||
from joblib._utils import eval_expr | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"expr", | ||
["exec('import os')", "print(1)", "import os", "1+1; import os", "1^1"], | ||
) | ||
def test_eval_expr_invalid(expr): | ||
with pytest.raises( | ||
ValueError, match="is not a valid or supported arithmetic" | ||
): | ||
eval_expr(expr) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"expr, result", | ||
[ | ||
("2*6", 12), | ||
("2**6", 64), | ||
("1 + 2*3**(4) / (6 + -7)", -161.0), | ||
("(20 // 3) % 5", 1), | ||
], | ||
) | ||
def test_eval_expr_valid(expr, result): | ||
assert eval_expr(expr) == result |