From ee2e4763cd30638d2d5b41bef2970e6f0bede512 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 11 Oct 2022 15:46:47 -0700 Subject: [PATCH 1/8] [ONNX] Update training state logic to support ScriptedModule In https://github.com/pytorch/pytorch/issues/86325, it was reported that ScriptedModule do not have a training attribute and will fail export because we don't expect them as input. We should also relax the type constraints in a following PR. Fixes #86325 --- torch/onnx/utils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 1e05b73f9d65..04fc984ded2b 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -90,7 +90,7 @@ def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode): ) originally_training: bool = False - if not isinstance(model, torch.jit.ScriptFunction): + if hasattr(model, "training"): originally_training = model.training # ONNX opset 12 has better support for training amenable models, with updated @@ -119,10 +119,7 @@ def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode): try: yield finally: - if not ( - isinstance(model, torch.jit.ScriptFunction) - or mode == _C_onnx.TrainingMode.PRESERVE - ): + if hasattr(model, "training") and not mode == _C_onnx.TrainingMode.PRESERVE: model.train(originally_training) From aa3172fcb36e0e5632004bc87a9b37d2496dbbaf Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 13 Oct 2022 21:46:28 +0000 Subject: [PATCH 2/8] Add test --- test/onnx/test_utility_funs.py | 77 +++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 34 deletions(-) diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index bb8cf8f9c5df..07042802770c 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -6,6 +6,7 @@ from typing import Callable import onnx +import parameterized import torch import torch.onnx @@ -18,6 +19,7 @@ skipIfUnsupportedMinOpsetVersion, ) from torch.onnx import OperatorExportTypes, TrainingMode, utils +from torch.onnx import _constants from torch.onnx._globals import GLOBALS from torch.onnx.symbolic_helper import _unpack_list, parse_args from torch.testing._internal import common_utils @@ -128,8 +130,15 @@ def test_it_returns_empty_list_when_all_ops_convertible( self.assertEqual(unconvertible_ops, []) -class TestUtilityFuns_opset9(_BaseTestCase): - opset_version = 9 +@parameterized.parameterized_class( + ("opset_version",), + [ + (opset,) + for opset in range(_constants.ONNX_BASE_OPSET, _constants.ONNX_MAX_OPSET + 1) + ], +) +class TestUtilityFuns(_BaseTestCase): + opset_version = None def test_is_in_onnx_export(self): test_self = self @@ -792,6 +801,30 @@ def forward(self, x): # verify that the model state is preserved self.assertEqual(model.training, old_state) + def test_export_frozen_scripted_module(self): + class Inner(torch.nn.Module): + def forward(self, x): + if x > 0: + return x + else: + return x * x + + class Outer(torch.nn.Module): + def __init__(self): + super().__init__() + self.inner = torch.jit.script(Inner()) + + def forward(self, x): + return self.inner(x) + + x = torch.zeros(1) + outer_module = Outer() + module = torch.jit.trace_module(outer_module, {"forward": (x)}) + # borisf: passes if you comment this line out + module = torch.jit.optimize_for_inference(torch.jit.freeze(module)) + + torch.onnx.export(module, (x,), io.BytesIO(), opset_version=self.opset_version) + @skipIfUnsupportedMinOpsetVersion(15) def test_local_function(self): class N(torch.nn.Module): @@ -1059,20 +1092,20 @@ def forward(self, x, y, z): model = M(3) expected_scope_names = { - "test_utility_funs.TestUtilityFuns_opset9.test_node_scope..M::/" + "test_utility_funs.TestUtilityFuns.test_node_scope..M::/" "torch.nn.modules.activation.GELU::gelu1", - "test_utility_funs.TestUtilityFuns_opset9.test_node_scope..M::/" + "test_utility_funs.TestUtilityFuns.test_node_scope..M::/" "torch.nn.modules.activation.GELU::gelu2", - "test_utility_funs.TestUtilityFuns_opset9.test_node_scope..M::/" + "test_utility_funs.TestUtilityFuns.test_node_scope..M::/" "torch.nn.modules.normalization.LayerNorm::lns.0", - "test_utility_funs.TestUtilityFuns_opset9.test_node_scope..M::/" + "test_utility_funs.TestUtilityFuns.test_node_scope..M::/" "torch.nn.modules.normalization.LayerNorm::lns.1", - "test_utility_funs.TestUtilityFuns_opset9.test_node_scope..M::/" + "test_utility_funs.TestUtilityFuns.test_node_scope..M::/" "torch.nn.modules.normalization.LayerNorm::lns.2", - "test_utility_funs.TestUtilityFuns_opset9.test_node_scope..M::/" - "test_utility_funs.TestUtilityFuns_opset9.test_node_scope..N::relu/" + "test_utility_funs.TestUtilityFuns.test_node_scope..M::/" + "test_utility_funs.TestUtilityFuns.test_node_scope..N::relu/" "torch.nn.modules.activation.ReLU::relu", - "test_utility_funs.TestUtilityFuns_opset9.test_node_scope..M::", + "test_utility_funs.TestUtilityFuns.test_node_scope..M::", } graph, _, _ = self._model_to_graph( @@ -1884,29 +1917,5 @@ def forward(self, x): torch.onnx.unregister_custom_op_symbolic("::cat", _onnx_opset_version) -class TestUtilityFuns_opset10(TestUtilityFuns_opset9): - opset_version = 10 - - -class TestUtilityFuns_opset11(TestUtilityFuns_opset9): - opset_version = 11 - - -class TestUtilityFuns_opset12(TestUtilityFuns_opset9): - opset_version = 12 - - -class TestUtilityFuns_opset13(TestUtilityFuns_opset9): - opset_version = 13 - - -class TestUtilityFuns_opset14(TestUtilityFuns_opset9): - opset_version = 14 - - -class TestUtilityFuns_opset15(TestUtilityFuns_opset9): - opset_version = 15 - - if __name__ == "__main__": common_utils.run_tests() From adcec982e270081444ce26da73617c2f3e6b845f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 13 Oct 2022 22:57:18 +0000 Subject: [PATCH 3/8] Add test --- test/onnx/test_utility_funs.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index 07042802770c..afe72f95c927 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -131,11 +131,11 @@ def test_it_returns_empty_list_when_all_ops_convertible( @parameterized.parameterized_class( - ("opset_version",), [ - (opset,) + {"opset_version": opset} for opset in range(_constants.ONNX_BASE_OPSET, _constants.ONNX_MAX_OPSET + 1) ], + class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_opset_{params_dict['opset_version']}", ) class TestUtilityFuns(_BaseTestCase): opset_version = None @@ -801,7 +801,7 @@ def forward(self, x): # verify that the model state is preserved self.assertEqual(model.training, old_state) - def test_export_frozen_scripted_module(self): + def test_export_does_not_fail_on_frozen_scripted_module(self): class Inner(torch.nn.Module): def forward(self, x): if x > 0: @@ -818,10 +818,11 @@ def forward(self, x): return self.inner(x) x = torch.zeros(1) - outer_module = Outer() + # Freezing is only implemented in eval mode. So we need to call eval() + outer_module = Outer().eval() module = torch.jit.trace_module(outer_module, {"forward": (x)}) - # borisf: passes if you comment this line out - module = torch.jit.optimize_for_inference(torch.jit.freeze(module)) + # jit.freeze removes the training attribute in the module + module = torch.jit.freeze(module) torch.onnx.export(module, (x,), io.BytesIO(), opset_version=self.opset_version) From 44d2eaed62e15b520a44a08a342087026563d6d3 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 13 Oct 2022 23:26:00 +0000 Subject: [PATCH 4/8] Imports --- test/onnx/test_utility_funs.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index afe72f95c927..eda71ac48699 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -4,6 +4,7 @@ import functools import io from typing import Callable +import warnings import onnx import parameterized @@ -157,8 +158,6 @@ def forward(self, x): self.assertFalse(torch.onnx.is_in_onnx_export()) def test_validate_dynamic_axes_invalid_input_output_name(self): - import warnings - with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") utils._validate_dynamic_axes( From e5a0c2cdd2f41da4c1977abead7cea454808f45e Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 13 Oct 2022 23:37:07 +0000 Subject: [PATCH 5/8] Format --- test/onnx/test_utility_funs.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index eda71ac48699..97d11b3c0a45 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -3,8 +3,8 @@ import copy import functools import io -from typing import Callable import warnings +from typing import Callable import onnx import parameterized @@ -19,8 +19,7 @@ skipIfUnsupportedMaxOpsetVersion, skipIfUnsupportedMinOpsetVersion, ) -from torch.onnx import OperatorExportTypes, TrainingMode, utils -from torch.onnx import _constants +from torch.onnx import _constants, OperatorExportTypes, TrainingMode, utils from torch.onnx._globals import GLOBALS from torch.onnx.symbolic_helper import _unpack_list, parse_args from torch.testing._internal import common_utils From 399fe2ee5cd0ffa1ab363ded2156865f7802e6b6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 14 Oct 2022 16:04:34 +0000 Subject: [PATCH 6/8] Update from master --- test/onnx/test_utility_funs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index 97d11b3c0a45..e95a6b3bc79c 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -1162,7 +1162,7 @@ def forward(self, x): # so we expect 3 constants with different scopes. The 3 constants are for the 3 layers. # If CSE in exporter is improved later, this test needs to be updated. # It should expect 1 constant, with same scope as root. - scope_prefix = "test_utility_funs.TestUtilityFuns_opset9.test_scope_of_constants_when_combined_by_cse_pass." + scope_prefix = f"test_utility_funs.TestUtilityFuns_opset{self.opset_version}.test_scope_of_constants_when_combined_by_cse_pass." expected_root_scope_name = f"{scope_prefix}.N::" expected_layer_scope_name = f"{scope_prefix}.M::layers" expected_constant_scope_name = [ @@ -1212,7 +1212,7 @@ def forward(self, x): graph, _, _ = self._model_to_graph( N(), (torch.randn(2, 3)), input_names=[], dynamic_axes={} ) - scope_prefix = "test_utility_funs.TestUtilityFuns_opset9.test_scope_of_nodes_when_combined_by_cse_pass." + scope_prefix = f"test_utility_funs.TestUtilityFuns_opset{self.opset_version}.test_scope_of_nodes_when_combined_by_cse_pass." expected_root_scope_name = f"{scope_prefix}.N::" expected_layer_scope_name = f"{scope_prefix}.M::layers" expected_add_scope_names = [ From 5ca31cda0350cddc7a9d701a7f599797b5da271a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 14 Oct 2022 16:29:16 +0000 Subject: [PATCH 7/8] Format --- test/onnx/test_utility_funs.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index e95a6b3bc79c..560ecac9fb12 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -1162,7 +1162,10 @@ def forward(self, x): # so we expect 3 constants with different scopes. The 3 constants are for the 3 layers. # If CSE in exporter is improved later, this test needs to be updated. # It should expect 1 constant, with same scope as root. - scope_prefix = f"test_utility_funs.TestUtilityFuns_opset{self.opset_version}.test_scope_of_constants_when_combined_by_cse_pass." + scope_prefix = ( + f"test_utility_funs.TestUtilityFuns_opset{self.opset_version}." + "test_scope_of_constants_when_combined_by_cse_pass." + ) expected_root_scope_name = f"{scope_prefix}.N::" expected_layer_scope_name = f"{scope_prefix}.M::layers" expected_constant_scope_name = [ @@ -1212,7 +1215,10 @@ def forward(self, x): graph, _, _ = self._model_to_graph( N(), (torch.randn(2, 3)), input_names=[], dynamic_axes={} ) - scope_prefix = f"test_utility_funs.TestUtilityFuns_opset{self.opset_version}.test_scope_of_nodes_when_combined_by_cse_pass." + scope_prefix = ( + f"test_utility_funs.TestUtilityFuns_opset{self.opset_version}." + "test_scope_of_nodes_when_combined_by_cse_pass." + ) expected_root_scope_name = f"{scope_prefix}.N::" expected_layer_scope_name = f"{scope_prefix}.M::layers" expected_add_scope_names = [ From 9b440547f9892c29ba3b84f37dff59baae81d781 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 14 Oct 2022 16:38:06 +0000 Subject: [PATCH 8/8] Fix string --- test/onnx/test_utility_funs.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index 560ecac9fb12..26467d54c1c6 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -1162,10 +1162,7 @@ def forward(self, x): # so we expect 3 constants with different scopes. The 3 constants are for the 3 layers. # If CSE in exporter is improved later, this test needs to be updated. # It should expect 1 constant, with same scope as root. - scope_prefix = ( - f"test_utility_funs.TestUtilityFuns_opset{self.opset_version}." - "test_scope_of_constants_when_combined_by_cse_pass." - ) + scope_prefix = "test_utility_funs.TestUtilityFuns.test_scope_of_constants_when_combined_by_cse_pass." expected_root_scope_name = f"{scope_prefix}.N::" expected_layer_scope_name = f"{scope_prefix}.M::layers" expected_constant_scope_name = [ @@ -1215,10 +1212,7 @@ def forward(self, x): graph, _, _ = self._model_to_graph( N(), (torch.randn(2, 3)), input_names=[], dynamic_axes={} ) - scope_prefix = ( - f"test_utility_funs.TestUtilityFuns_opset{self.opset_version}." - "test_scope_of_nodes_when_combined_by_cse_pass." - ) + scope_prefix = "test_utility_funs.TestUtilityFuns.test_scope_of_nodes_when_combined_by_cse_pass." expected_root_scope_name = f"{scope_prefix}.N::" expected_layer_scope_name = f"{scope_prefix}.M::layers" expected_add_scope_names = [