From 56a744bf47edd1adb423593955b786a2ede8bd4f Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 14 Oct 2022 19:44:44 +0000
Subject: [PATCH] [ONNX] Reland: Update training state logic to support ScriptedModule (#86745)

In https://github.com/pytorch/pytorch/issues/86325, it was reported that a
ScriptedModule does not have a training attribute and will fail export
because we don't expect it as input.

Also:

- Parameterized the test_utility_funs tests

Thanks @borisfom for the suggestion!

Fixes #86325
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86745
Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao
---
 test/onnx/test_utility_funs.py | 86 +++++++++++++++++++---------------
 torch/onnx/utils.py            |  7 +--
 2 files changed, 49 insertions(+), 44 deletions(-)

diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py
index bb8cf8f9c5df..26467d54c1c6 100644
--- a/test/onnx/test_utility_funs.py
+++ b/test/onnx/test_utility_funs.py
@@ -3,9 +3,11 @@
 import copy
 import functools
 import io
+import warnings
 from typing import Callable
 
 import onnx
+import parameterized
 
 import torch
 import torch.onnx
@@ -17,7 +19,7 @@
     skipIfUnsupportedMaxOpsetVersion,
     skipIfUnsupportedMinOpsetVersion,
 )
-from torch.onnx import OperatorExportTypes, TrainingMode, utils
+from torch.onnx import _constants, OperatorExportTypes, TrainingMode, utils
 from torch.onnx._globals import GLOBALS
 from torch.onnx.symbolic_helper import _unpack_list, parse_args
 from torch.testing._internal import common_utils
@@ -128,8 +130,15 @@ def test_it_returns_empty_list_when_all_ops_convertible(
         self.assertEqual(unconvertible_ops, [])
 
 
-class TestUtilityFuns_opset9(_BaseTestCase):
-    opset_version = 9
+@parameterized.parameterized_class(
+    [
+        {"opset_version": opset}
+        for opset in range(_constants.ONNX_BASE_OPSET, _constants.ONNX_MAX_OPSET + 1)
+    ],
+    class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_opset_{params_dict['opset_version']}",
+)
+class TestUtilityFuns(_BaseTestCase):
+    opset_version = None
 
     def test_is_in_onnx_export(self):
         test_self = self
@@ -148,8 +157,6 @@ def forward(self, x):
         self.assertFalse(torch.onnx.is_in_onnx_export())
 
     def test_validate_dynamic_axes_invalid_input_output_name(self):
-        import warnings
-
         with warnings.catch_warnings(record=True) as w:
             warnings.simplefilter("always")
             utils._validate_dynamic_axes(
@@ -792,6 +799,31 @@ def forward(self, x):
         # verify that the model state is preserved
         self.assertEqual(model.training, old_state)
 
+    def test_export_does_not_fail_on_frozen_scripted_module(self):
+        class Inner(torch.nn.Module):
+            def forward(self, x):
+                if x > 0:
+                    return x
+                else:
+                    return x * x
+
+        class Outer(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.inner = torch.jit.script(Inner())
+
+            def forward(self, x):
+                return self.inner(x)
+
+        x = torch.zeros(1)
+        # Freezing is only implemented in eval mode. So we need to call eval()
+        outer_module = Outer().eval()
+        module = torch.jit.trace_module(outer_module, {"forward": (x)})
+        # jit.freeze removes the training attribute in the module
+        module = torch.jit.freeze(module)
+
+        torch.onnx.export(module, (x,), io.BytesIO(), opset_version=self.opset_version)
+
     @skipIfUnsupportedMinOpsetVersion(15)
     def test_local_function(self):
         class N(torch.nn.Module):
@@ -1059,20 +1091,20 @@ def forward(self, x, y, z):
         model = M(3)
 
         expected_scope_names = {
-            "test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.M::/"
+            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
             "torch.nn.modules.activation.GELU::gelu1",
-            "test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.M::/"
+            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
             "torch.nn.modules.activation.GELU::gelu2",
-            "test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.M::/"
+            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
             "torch.nn.modules.normalization.LayerNorm::lns.0",
-            "test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.M::/"
+            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
             "torch.nn.modules.normalization.LayerNorm::lns.1",
-            "test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.M::/"
+            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
             "torch.nn.modules.normalization.LayerNorm::lns.2",
-            "test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.M::/"
-            "test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.N::relu/"
+            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
+            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.N::relu/"
             "torch.nn.modules.activation.ReLU::relu",
-            "test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.M::",
+            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::",
         }
 
         graph, _, _ = self._model_to_graph(
@@ -1130,7 +1162,7 @@ def forward(self, x):
         # so we expect 3 constants with different scopes. The 3 constants are for the 3 layers.
         # If CSE in exporter is improved later, this test needs to be updated.
         # It should expect 1 constant, with same scope as root.
-        scope_prefix = "test_utility_funs.TestUtilityFuns_opset9.test_scope_of_constants_when_combined_by_cse_pass.<locals>"
+        scope_prefix = "test_utility_funs.TestUtilityFuns.test_scope_of_constants_when_combined_by_cse_pass.<locals>"
         expected_root_scope_name = f"{scope_prefix}.N::"
         expected_layer_scope_name = f"{scope_prefix}.M::layers"
         expected_constant_scope_name = [
@@ -1180,7 +1212,7 @@ def forward(self, x):
         graph, _, _ = self._model_to_graph(
             N(), (torch.randn(2, 3)), input_names=[], dynamic_axes={}
         )
-        scope_prefix = "test_utility_funs.TestUtilityFuns_opset9.test_scope_of_nodes_when_combined_by_cse_pass.<locals>"
+        scope_prefix = "test_utility_funs.TestUtilityFuns.test_scope_of_nodes_when_combined_by_cse_pass.<locals>"
         expected_root_scope_name = f"{scope_prefix}.N::"
         expected_layer_scope_name = f"{scope_prefix}.M::layers"
         expected_add_scope_names = [
@@ -1884,29 +1916,5 @@ def forward(self, x):
         torch.onnx.unregister_custom_op_symbolic("::cat", _onnx_opset_version)
 
 
-class TestUtilityFuns_opset10(TestUtilityFuns_opset9):
-    opset_version = 10
-
-
-class TestUtilityFuns_opset11(TestUtilityFuns_opset9):
-    opset_version = 11
-
-
-class TestUtilityFuns_opset12(TestUtilityFuns_opset9):
-    opset_version = 12
-
-
-class TestUtilityFuns_opset13(TestUtilityFuns_opset9):
-    opset_version = 13
-
-
-class TestUtilityFuns_opset14(TestUtilityFuns_opset9):
-    opset_version = 14
-
-
-class TestUtilityFuns_opset15(TestUtilityFuns_opset9):
-    opset_version = 15
-
-
 if __name__ == "__main__":
     common_utils.run_tests()
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 1e05b73f9d65..04fc984ded2b 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -90,7 +90,7 @@ def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
         )
 
     originally_training: bool = False
-    if not isinstance(model, torch.jit.ScriptFunction):
+    if hasattr(model, "training"):
         originally_training = model.training
 
         # ONNX opset 12 has better support for training amenable models, with updated
@@ -119,10 +119,7 @@ def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
     try:
         yield
     finally:
-        if not (
-            isinstance(model, torch.jit.ScriptFunction)
-            or mode == _C_onnx.TrainingMode.PRESERVE
-        ):
+        if hasattr(model, "training") and not mode == _C_onnx.TrainingMode.PRESERVE:
             model.train(originally_training)
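
Reviewer note (not part of the patch): the heart of the fix is swapping the
isinstance(model, torch.jit.ScriptFunction) check for a duck-typed
hasattr(model, "training") guard. A minimal sketch of why that matters,
using only public PyTorch APIs; the Net module below is illustrative, not
taken from the patch:

    import io

    import torch


    class Net(torch.nn.Module):
        def forward(self, x):
            return x * 2


    # Freezing requires eval mode; per the patch notes, jit.freeze then
    # removes the training attribute from the resulting module.
    frozen = torch.jit.freeze(torch.jit.script(Net()).eval())
    print(hasattr(frozen, "training"))  # False for a frozen module

    # Before this patch, select_model_mode_for_export read model.training
    # for any non-ScriptFunction input, so exporting a frozen module raised
    # AttributeError (issue #86325); with the hasattr guard it goes through.
    torch.onnx.export(frozen, (torch.zeros(1),), io.BytesIO())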
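On the test side, parameterized.parameterized_class replaces the six
hand-maintained opset subclasses. A self-contained sketch of that pattern,
with an illustrative DemoTest and a hard-coded opset range standing in for
_constants.ONNX_BASE_OPSET/_constants.ONNX_MAX_OPSET:

    import unittest

    import parameterized


    @parameterized.parameterized_class(
        [{"opset_version": opset} for opset in (9, 10, 11)],
        class_name_func=lambda cls, num, params_dict: (
            f"{cls.__name__}_opset_{params_dict['opset_version']}"
        ),
    )
    class DemoTest(unittest.TestCase):
        opset_version = None

        def test_opset_version_is_injected(self):
            # parameterized_class generates one subclass per dict:
            # DemoTest_opset_9, DemoTest_opset_10, DemoTest_opset_11.
            self.assertIn(self.opset_version, (9, 10, 11))


    if __name__ == "__main__":
        unittest.main()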