Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JIT][Security] Do not blindly eval input string #89189

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions test/test_jit.py
Expand Up @@ -3951,6 +3951,14 @@ def invalid4(a):
return a + 2
torch.jit.script(invalid4)

def test_calls_in_type_annotations(self):
with self.assertRaisesRegex(RuntimeError, "Type annotation should not contain calls"):
def spooky(a):
# type: print("Hello") -> Tensor # noqa: F723
return a + 2
print(torch.__file__)
torch.jit.annotations.get_signature(spooky, None, 1, True)

def test_is_optional(self):
ann = Union[List[int], List[float]]
torch._jit_internal.is_optional(ann)
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/frontend/script_type_parser.cpp
Expand Up @@ -316,7 +316,7 @@ std::vector<IValue> ScriptTypeParser::evaluateDefaults(
// We then run constant prop on this graph and check the results are
// constant. This approach avoids having to have separate handling of
// default arguments from standard expressions by piecing together existing
// machinery for graph generation, constant propgation, and constant
// machinery for graph generation, constant propagation, and constant
// extraction.
auto tuple_type = Subscript::create(
r,
Expand Down
14 changes: 12 additions & 2 deletions torch/jit/annotations.py
@@ -1,4 +1,5 @@
import ast
import dis
import enum
import inspect
import re
Expand Down Expand Up @@ -144,6 +145,15 @@ def check_fn(fn, loc):
raise torch.jit.frontend.FrontendError(loc, "Expected a single top-level function")


def _eval_no_call(stmt, glob, loc):
"""Evaluate statement as long as it does not contain any method/function calls"""
bytecode = compile(stmt, "", mode="eval")
for insn in dis.get_instructions(bytecode):
if "CALL" in insn.opname:
raise RuntimeError(f"Type annotation should not contain calls, but '{stmt}' does")
return eval(bytecode, glob, loc) # type: ignore[arg-type] # noqa: P204


def parse_type_line(type_line, rcb, loc):
"""Parses a type annotation specified as a comment.

Expand All @@ -154,15 +164,15 @@ def parse_type_line(type_line, rcb, loc):
arg_ann_str, ret_ann_str = split_type_line(type_line)

try:
arg_ann = eval(arg_ann_str, {}, EvalEnv(rcb)) # type: ignore[arg-type] # noqa: P204
arg_ann = _eval_no_call(arg_ann_str, {}, EvalEnv(rcb))
except (NameError, SyntaxError) as e:
raise RuntimeError("Failed to parse the argument list of a type annotation") from e

if not isinstance(arg_ann, tuple):
arg_ann = (arg_ann,)

try:
ret_ann = eval(ret_ann_str, {}, EvalEnv(rcb)) # type: ignore[arg-type] # noqa: P204
ret_ann = _eval_no_call(ret_ann_str, {}, EvalEnv(rcb))
except (NameError, SyntaxError) as e:
raise RuntimeError("Failed to parse the return type of a type annotation") from e

Expand Down