From 767f6aa49fe20a2766b9843d01e3b7f7793df6a3 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Thu, 17 Nov 2022 22:05:27 +0000 Subject: [PATCH] [JIT][Security] Do not blindly eval input string (#89189) Introduce the `_eval_no_call` function, which evaluates a statement only if it does not contain any calls (checked by examining the bytecode), thus preventing a command injection exploit. Added a simple unit test to check that `torch.jit.annotations.get_signature` does not result in calling arbitrary code. Note, however, that this code path exists for Python-2 compatibility and perhaps should simply be removed. Fixes https://github.com/pytorch/pytorch/issues/88868 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89189 Approved by: https://github.com/suo --- test/test_jit.py | 8 ++++++++ torch/csrc/jit/frontend/script_type_parser.cpp | 2 +- torch/jit/annotations.py | 14 ++++++++++++-- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 13c27b0efa55565..6cbc091d506b586 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3951,6 +3951,14 @@ def invalid4(a): return a + 2 torch.jit.script(invalid4) + def test_calls_in_type_annotations(self): + with self.assertRaisesRegex(RuntimeError, "Type annotation should not contain calls"): + def spooky(a): + # type: print("Hello") -> Tensor # noqa: F723 + return a + 2 + print(torch.__file__) + torch.jit.annotations.get_signature(spooky, None, 1, True) + def test_is_optional(self): ann = Union[List[int], List[float]] torch._jit_internal.is_optional(ann) diff --git a/torch/csrc/jit/frontend/script_type_parser.cpp b/torch/csrc/jit/frontend/script_type_parser.cpp index f5d6f640d413d47..d05ec95fb9fa24e 100644 --- a/torch/csrc/jit/frontend/script_type_parser.cpp +++ b/torch/csrc/jit/frontend/script_type_parser.cpp @@ -316,7 +316,7 @@ std::vector ScriptTypeParser::evaluateDefaults( // We then run constant prop on this graph and check the results are // constant. 
This approach avoids having to have separate handling of // default arguments from standard expressions by piecing together existing - // machinery for graph generation, constant propgation, and constant + // machinery for graph generation, constant propagation, and constant // extraction. auto tuple_type = Subscript::create( r, diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py index a4a36ce36a5e89d..a6ff2d04d207670 100644 --- a/torch/jit/annotations.py +++ b/torch/jit/annotations.py @@ -1,4 +1,5 @@ import ast +import dis import enum import inspect import re @@ -144,6 +145,15 @@ def check_fn(fn, loc): raise torch.jit.frontend.FrontendError(loc, "Expected a single top-level function") +def _eval_no_call(stmt, glob, loc): + """Evaluate statement as long as it does not contain any method/function calls""" + bytecode = compile(stmt, "", mode="eval") + for insn in dis.get_instructions(bytecode): + if "CALL" in insn.opname: + raise RuntimeError(f"Type annotation should not contain calls, but '{stmt}' does") + return eval(bytecode, glob, loc) # type: ignore[arg-type] # noqa: P204 + + def parse_type_line(type_line, rcb, loc): """Parses a type annotation specified as a comment. @@ -154,7 +164,7 @@ def parse_type_line(type_line, rcb, loc): arg_ann_str, ret_ann_str = split_type_line(type_line) try: - arg_ann = eval(arg_ann_str, {}, EvalEnv(rcb)) # type: ignore[arg-type] # noqa: P204 + arg_ann = _eval_no_call(arg_ann_str, {}, EvalEnv(rcb)) except (NameError, SyntaxError) as e: raise RuntimeError("Failed to parse the argument list of a type annotation") from e @@ -162,7 +172,7 @@ def parse_type_line(type_line, rcb, loc): arg_ann = (arg_ann,) try: - ret_ann = eval(ret_ann_str, {}, EvalEnv(rcb)) # type: ignore[arg-type] # noqa: P204 + ret_ann = _eval_no_call(ret_ann_str, {}, EvalEnv(rcb)) except (NameError, SyntaxError) as e: raise RuntimeError("Failed to parse the return type of a type annotation") from e