[ONNX] Reland: Update training state logic to support ScriptedModule (pytorch#86745)

In pytorch#86325, it was reported that ScriptedModules do not have a training attribute, so export fails because the exporter does not expect them as input.
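For context, a minimal sketch of the failure mode (the module below is an illustrative stand-in, not taken from the PR): freezing strips the training attribute from a scripted module, so exporter code that reads model.training unconditionally breaks.

import io

import torch


class Square(torch.nn.Module):  # hypothetical stand-in for the reporter's module
    def forward(self, x):
        return x * x


# Freezing is only implemented for eval-mode modules, and jit.freeze drops
# the training attribute from the frozen result.
frozen = torch.jit.freeze(torch.jit.script(Square()).eval())
print(hasattr(frozen, "training"))  # False

# Before this patch the export path accessed model.training and failed here;
# with the hasattr guards introduced below, it succeeds.
torch.onnx.export(frozen, (torch.zeros(1),), io.BytesIO())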

Also:

- Parameterized the test_utility_funs tests

Thanks @borisfom for the suggestion!

Fixes pytorch#86325

Pull Request resolved: pytorch#86745
Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao
justinchuby authored and atalman committed Oct 21, 2022
1 parent 0c0df0b commit 2f9cb1b
Showing 2 changed files with 49 additions and 44 deletions.
86 changes: 47 additions & 39 deletions test/onnx/test_utility_funs.py
@@ -3,9 +3,11 @@
 import copy
 import functools
 import io
+import warnings
 from typing import Callable
 
 import onnx
+import parameterized
 
 import torch
 import torch.onnx
@@ -17,7 +19,7 @@
     skipIfUnsupportedMaxOpsetVersion,
     skipIfUnsupportedMinOpsetVersion,
 )
-from torch.onnx import OperatorExportTypes, TrainingMode, utils
+from torch.onnx import _constants, OperatorExportTypes, TrainingMode, utils
 from torch.onnx._globals import GLOBALS
 from torch.onnx.symbolic_helper import _unpack_list, parse_args
 from torch.testing._internal import common_utils
@@ -128,8 +130,15 @@ def test_it_returns_empty_list_when_all_ops_convertible(
         self.assertEqual(unconvertible_ops, [])
 
 
-class TestUtilityFuns_opset9(_BaseTestCase):
-    opset_version = 9
+@parameterized.parameterized_class(
+    [
+        {"opset_version": opset}
+        for opset in range(_constants.ONNX_BASE_OPSET, _constants.ONNX_MAX_OPSET + 1)
+    ],
+    class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_opset_{params_dict['opset_version']}",
+)
+class TestUtilityFuns(_BaseTestCase):
+    opset_version = None
 
     def test_is_in_onnx_export(self):
         test_self = self
@@ -148,8 +157,6 @@ def forward(self, x):
         self.assertFalse(torch.onnx.is_in_onnx_export())
 
     def test_validate_dynamic_axes_invalid_input_output_name(self):
-        import warnings
-
         with warnings.catch_warnings(record=True) as w:
             warnings.simplefilter("always")
             utils._validate_dynamic_axes(
@@ -792,6 +799,31 @@ def forward(self, x):
         # verify that the model state is preserved
         self.assertEqual(model.training, old_state)
 
+    def test_export_does_not_fail_on_frozen_scripted_module(self):
+        class Inner(torch.nn.Module):
+            def forward(self, x):
+                if x > 0:
+                    return x
+                else:
+                    return x * x
+
+        class Outer(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.inner = torch.jit.script(Inner())
+
+            def forward(self, x):
+                return self.inner(x)
+
+        x = torch.zeros(1)
+        # Freezing is only implemented in eval mode. So we need to call eval()
+        outer_module = Outer().eval()
+        module = torch.jit.trace_module(outer_module, {"forward": (x)})
+        # jit.freeze removes the training attribute in the module
+        module = torch.jit.freeze(module)
+
+        torch.onnx.export(module, (x,), io.BytesIO(), opset_version=self.opset_version)
+
     @skipIfUnsupportedMinOpsetVersion(15)
     def test_local_function(self):
         class N(torch.nn.Module):
@@ -1059,20 +1091,20 @@ def forward(self, x, y, z):
 
         model = M(3)
         expected_scope_names = {
-            "test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.M::/"
+            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
             "torch.nn.modules.activation.GELU::gelu1",
-            "test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.M::/"
+            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
             "torch.nn.modules.activation.GELU::gelu2",
-            "test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.M::/"
+            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
             "torch.nn.modules.normalization.LayerNorm::lns.0",
-            "test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.M::/"
+            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
             "torch.nn.modules.normalization.LayerNorm::lns.1",
-            "test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.M::/"
+            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
             "torch.nn.modules.normalization.LayerNorm::lns.2",
-            "test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.M::/"
-            "test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.N::relu/"
+            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
+            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.N::relu/"
             "torch.nn.modules.activation.ReLU::relu",
-            "test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.M::",
+            "test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::",
         }
 
         graph, _, _ = self._model_to_graph(
@@ -1130,7 +1162,7 @@ def forward(self, x):
         # so we expect 3 constants with different scopes. The 3 constants are for the 3 layers.
         # If CSE in exporter is improved later, this test needs to be updated.
         # It should expect 1 constant, with same scope as root.
-        scope_prefix = "test_utility_funs.TestUtilityFuns_opset9.test_scope_of_constants_when_combined_by_cse_pass.<locals>"
+        scope_prefix = "test_utility_funs.TestUtilityFuns.test_scope_of_constants_when_combined_by_cse_pass.<locals>"
         expected_root_scope_name = f"{scope_prefix}.N::"
         expected_layer_scope_name = f"{scope_prefix}.M::layers"
         expected_constant_scope_name = [
@@ -1180,7 +1212,7 @@ def forward(self, x):
         graph, _, _ = self._model_to_graph(
             N(), (torch.randn(2, 3)), input_names=[], dynamic_axes={}
         )
-        scope_prefix = "test_utility_funs.TestUtilityFuns_opset9.test_scope_of_nodes_when_combined_by_cse_pass.<locals>"
+        scope_prefix = "test_utility_funs.TestUtilityFuns.test_scope_of_nodes_when_combined_by_cse_pass.<locals>"
         expected_root_scope_name = f"{scope_prefix}.N::"
         expected_layer_scope_name = f"{scope_prefix}.M::layers"
         expected_add_scope_names = [
@@ -1884,29 +1916,5 @@ def forward(self, x):
         torch.onnx.unregister_custom_op_symbolic("::cat", _onnx_opset_version)
 
 
-class TestUtilityFuns_opset10(TestUtilityFuns_opset9):
-    opset_version = 10
-
-
-class TestUtilityFuns_opset11(TestUtilityFuns_opset9):
-    opset_version = 11
-
-
-class TestUtilityFuns_opset12(TestUtilityFuns_opset9):
-    opset_version = 12
-
-
-class TestUtilityFuns_opset13(TestUtilityFuns_opset9):
-    opset_version = 13
-
-
-class TestUtilityFuns_opset14(TestUtilityFuns_opset9):
-    opset_version = 14
-
-
-class TestUtilityFuns_opset15(TestUtilityFuns_opset9):
-    opset_version = 15
-
-
 if __name__ == "__main__":
     common_utils.run_tests()
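For readers unfamiliar with parameterized.parameterized_class, here is a minimal standalone sketch of the pattern the test file now uses; the opset bounds are hard-coded for illustration, whereas the real test reads them from torch.onnx._constants.

import unittest

import parameterized


@parameterized.parameterized_class(
    [{"opset_version": opset} for opset in range(9, 17)],  # illustrative bounds
    class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_opset_{params_dict['opset_version']}",
)
class DemoCase(unittest.TestCase):
    opset_version = None  # overridden in each generated class

    def test_opset_is_set(self):
        self.assertIsNotNone(self.opset_version)


if __name__ == "__main__":
    unittest.main()  # discovers DemoCase_opset_9 through DemoCase_opset_16

One generated class per opset replaces the six hand-written TestUtilityFuns_opsetN subclasses that the hunk above deletes.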
7 changes: 2 additions & 5 deletions torch/onnx/utils.py
@@ -90,7 +90,7 @@ def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
     )
     originally_training: bool = False
 
-    if not isinstance(model, torch.jit.ScriptFunction):
+    if hasattr(model, "training"):
         originally_training = model.training
 
     # ONNX opset 12 has better support for training amenable models, with updated
@@ -119,10 +119,7 @@ def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
     try:
         yield
     finally:
-        if not (
-            isinstance(model, torch.jit.ScriptFunction)
-            or mode == _C_onnx.TrainingMode.PRESERVE
-        ):
+        if hasattr(model, "training") and not mode == _C_onnx.TrainingMode.PRESERVE:
             model.train(originally_training)
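Taken together, the utils.py hunks replace a type check (is this a ScriptFunction?) with a capability check (does this object carry a training flag?). Below is a hedged distillation of the resulting restore logic, not the actual torch.onnx.utils source; the function name is invented and the entry-side mode handling is omitted.

import contextlib

import torch


@contextlib.contextmanager
def _training_state_guard(model, mode: torch.onnx.TrainingMode):
    # Frozen ScriptModules and ScriptFunctions have no training attribute,
    # so every access is guarded rather than special-cased on type.
    originally_training = getattr(model, "training", False)
    try:
        yield
    finally:
        # Restore the original state only for objects that actually track one.
        if hasattr(model, "training") and mode != torch.onnx.TrainingMode.PRESERVE:
            model.train(originally_training)

An object without the attribute passes through untouched on both entry and exit, which is exactly what the frozen module in pytorch#86325 needed.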
