diff --git a/test/test_jit.py b/test/test_jit.py index a4f535921e55..055c990f8598 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3912,6 +3912,14 @@ def invalid4(a): return a + 2 torch.jit.script(invalid4) + def test_calls_in_type_annotations(self): + with self.assertRaisesRegex(RuntimeError, "Type annotation should not contain calls"): + def spooky(a): + # type: print("Hello") -> Tensor # noqa: F723 + return a + 2 + print(torch.__file__) + torch.jit.annotations.get_signature(spooky, None, 1, True) + def test_is_optional(self): ann = Union[List[int], List[float]] torch._jit_internal.is_optional(ann) diff --git a/torch/csrc/jit/frontend/script_type_parser.cpp b/torch/csrc/jit/frontend/script_type_parser.cpp index f5d6f640d413..d05ec95fb9fa 100644 --- a/torch/csrc/jit/frontend/script_type_parser.cpp +++ b/torch/csrc/jit/frontend/script_type_parser.cpp @@ -316,7 +316,7 @@ std::vector ScriptTypeParser::evaluateDefaults( // We then run constant prop on this graph and check the results are // constant. This approach avoids having to have separate handling of // default arguments from standard expressions by piecing together existing - // machinery for graph generation, constant propgation, and constant + // machinery for graph generation, constant propagation, and constant // extraction. auto tuple_type = Subscript::create( r, diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py index a4a36ce36a5e..a6ff2d04d207 100644 --- a/torch/jit/annotations.py +++ b/torch/jit/annotations.py @@ -1,4 +1,5 @@ import ast +import dis import enum import inspect import re @@ -144,6 +145,15 @@ def check_fn(fn, loc): raise torch.jit.frontend.FrontendError(loc, "Expected a single top-level function") +def _eval_no_call(stmt, glob, loc): + """Evaluate statement as long as it does not contain any method/function calls""" + bytecode = compile(stmt, "", mode="eval") + for insn in dis.get_instructions(bytecode): + if "CALL" in insn.opname: + raise RuntimeError(f"Type annotation should not contain calls, but '{stmt}' does") + return eval(bytecode, glob, loc) # type: ignore[arg-type] # noqa: P204 + + def parse_type_line(type_line, rcb, loc): """Parses a type annotation specified as a comment. @@ -154,7 +164,7 @@ def parse_type_line(type_line, rcb, loc): arg_ann_str, ret_ann_str = split_type_line(type_line) try: - arg_ann = eval(arg_ann_str, {}, EvalEnv(rcb)) # type: ignore[arg-type] # noqa: P204 + arg_ann = _eval_no_call(arg_ann_str, {}, EvalEnv(rcb)) except (NameError, SyntaxError) as e: raise RuntimeError("Failed to parse the argument list of a type annotation") from e @@ -162,7 +172,7 @@ def parse_type_line(type_line, rcb, loc): arg_ann = (arg_ann,) try: - ret_ann = eval(ret_ann_str, {}, EvalEnv(rcb)) # type: ignore[arg-type] # noqa: P204 + ret_ann = _eval_no_call(ret_ann_str, {}, EvalEnv(rcb)) except (NameError, SyntaxError) as e: raise RuntimeError("Failed to parse the return type of a type annotation") from e