Skip to content

Commit

Permalink
[JIT][Security] Do not blindly eval input string (pytorch#89189)
Browse files Browse the repository at this point in the history
Introduce `_eval_no_call` method, that evaluates statement only if it
does not contain any calls(done by examining the bytecode), thus preventing command injection exploit

Added simple unit test to check for that
`torch.jit.annotations.get_signature` would not result in calling random
code.

Although, this code path exists for Python-2 compatibility, and perhaps
should be simply removed.

Fixes pytorch#88868

Pull Request resolved: pytorch#89189
Approved by: https://github.com/suo
  • Loading branch information
malfet authored and atalman committed Nov 30, 2022
1 parent 7c98e70 commit 78cad99
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
8 changes: 8 additions & 0 deletions test/test_jit.py
Expand Up @@ -3912,6 +3912,14 @@ def invalid4(a):
return a + 2
torch.jit.script(invalid4)

def test_calls_in_type_annotations(self):
with self.assertRaisesRegex(RuntimeError, "Type annotation should not contain calls"):
def spooky(a):
# type: print("Hello") -> Tensor # noqa: F723
return a + 2
print(torch.__file__)
torch.jit.annotations.get_signature(spooky, None, 1, True)

def test_is_optional(self):
ann = Union[List[int], List[float]]
torch._jit_internal.is_optional(ann)
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/frontend/script_type_parser.cpp
Expand Up @@ -316,7 +316,7 @@ std::vector<IValue> ScriptTypeParser::evaluateDefaults(
// We then run constant prop on this graph and check the results are
// constant. This approach avoids having to have separate handling of
// default arguments from standard expressions by piecing together existing
// machinery for graph generation, constant propgation, and constant
// machinery for graph generation, constant propagation, and constant
// extraction.
auto tuple_type = Subscript::create(
r,
Expand Down
14 changes: 12 additions & 2 deletions torch/jit/annotations.py
@@ -1,4 +1,5 @@
import ast
import dis
import enum
import inspect
import re
Expand Down Expand Up @@ -144,6 +145,15 @@ def check_fn(fn, loc):
raise torch.jit.frontend.FrontendError(loc, "Expected a single top-level function")


def _eval_no_call(stmt, glob, loc):
"""Evaluate statement as long as it does not contain any method/function calls"""
bytecode = compile(stmt, "", mode="eval")
for insn in dis.get_instructions(bytecode):
if "CALL" in insn.opname:
raise RuntimeError(f"Type annotation should not contain calls, but '{stmt}' does")
return eval(bytecode, glob, loc) # type: ignore[arg-type] # noqa: P204


def parse_type_line(type_line, rcb, loc):
"""Parses a type annotation specified as a comment.
Expand All @@ -154,15 +164,15 @@ def parse_type_line(type_line, rcb, loc):
arg_ann_str, ret_ann_str = split_type_line(type_line)

try:
arg_ann = eval(arg_ann_str, {}, EvalEnv(rcb)) # type: ignore[arg-type] # noqa: P204
arg_ann = _eval_no_call(arg_ann_str, {}, EvalEnv(rcb))
except (NameError, SyntaxError) as e:
raise RuntimeError("Failed to parse the argument list of a type annotation") from e

if not isinstance(arg_ann, tuple):
arg_ann = (arg_ann,)

try:
ret_ann = eval(ret_ann_str, {}, EvalEnv(rcb)) # type: ignore[arg-type] # noqa: P204
ret_ann = _eval_no_call(ret_ann_str, {}, EvalEnv(rcb))
except (NameError, SyntaxError) as e:
raise RuntimeError("Failed to parse the return type of a type annotation") from e

Expand Down

0 comments on commit 78cad99

Please sign in to comment.