Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Reland: Update training state logic to support ScriptedModule #86745

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
86 changes: 47 additions & 39 deletions test/onnx/test_utility_funs.py
Expand Up @@ -3,9 +3,11 @@
import copy
import functools
import io
import warnings
from typing import Callable

import onnx
import parameterized

import torch
import torch.onnx
Expand All @@ -17,7 +19,7 @@
skipIfUnsupportedMaxOpsetVersion,
skipIfUnsupportedMinOpsetVersion,
)
from torch.onnx import OperatorExportTypes, TrainingMode, utils
from torch.onnx import _constants, OperatorExportTypes, TrainingMode, utils
from torch.onnx._globals import GLOBALS
from torch.onnx.symbolic_helper import _unpack_list, parse_args
from torch.testing._internal import common_utils
Expand Down Expand Up @@ -128,8 +130,15 @@ def test_it_returns_empty_list_when_all_ops_convertible(
self.assertEqual(unconvertible_ops, [])


class TestUtilityFuns_opset9(_BaseTestCase):
opset_version = 9
@parameterized.parameterized_class(
[
{"opset_version": opset}
for opset in range(_constants.ONNX_BASE_OPSET, _constants.ONNX_MAX_OPSET + 1)
],
class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_opset_{params_dict['opset_version']}",
)
class TestUtilityFuns(_BaseTestCase):
opset_version = None

def test_is_in_onnx_export(self):
test_self = self
Expand All @@ -148,8 +157,6 @@ def forward(self, x):
self.assertFalse(torch.onnx.is_in_onnx_export())

def test_validate_dynamic_axes_invalid_input_output_name(self):
import warnings

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
utils._validate_dynamic_axes(
Expand Down Expand Up @@ -792,6 +799,31 @@ def forward(self, x):
# verify that the model state is preserved
self.assertEqual(model.training, old_state)

def test_export_does_not_fail_on_frozen_scripted_module(self):
class Inner(torch.nn.Module):
def forward(self, x):
if x > 0:
return x
else:
return x * x

class Outer(torch.nn.Module):
def __init__(self):
super().__init__()
self.inner = torch.jit.script(Inner())

def forward(self, x):
return self.inner(x)

x = torch.zeros(1)
# Freezing is only implemented in eval mode. So we need to call eval()
outer_module = Outer().eval()
module = torch.jit.trace_module(outer_module, {"forward": (x)})
# jit.freeze removes the training attribute in the module
module = torch.jit.freeze(module)

torch.onnx.export(module, (x,), io.BytesIO(), opset_version=self.opset_version)

@skipIfUnsupportedMinOpsetVersion(15)
def test_local_function(self):
class N(torch.nn.Module):
Expand Down Expand Up @@ -1059,20 +1091,20 @@ def forward(self, x, y, z):

model = M(3)
expected_scope_names = {
"test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.M::/"
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
"torch.nn.modules.activation.GELU::gelu1",
"test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.M::/"
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
"torch.nn.modules.activation.GELU::gelu2",
"test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.M::/"
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
"torch.nn.modules.normalization.LayerNorm::lns.0",
"test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.M::/"
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
"torch.nn.modules.normalization.LayerNorm::lns.1",
"test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.M::/"
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
"torch.nn.modules.normalization.LayerNorm::lns.2",
"test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.M::/"
"test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.N::relu/"
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::/"
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.N::relu/"
"torch.nn.modules.activation.ReLU::relu",
"test_utility_funs.TestUtilityFuns_opset9.test_node_scope.<locals>.M::",
"test_utility_funs.TestUtilityFuns.test_node_scope.<locals>.M::",
}

graph, _, _ = self._model_to_graph(
Expand Down Expand Up @@ -1130,7 +1162,7 @@ def forward(self, x):
# so we expect 3 constants with different scopes. The 3 constants are for the 3 layers.
# If CSE in exporter is improved later, this test needs to be updated.
# It should expect 1 constant, with same scope as root.
scope_prefix = "test_utility_funs.TestUtilityFuns_opset9.test_scope_of_constants_when_combined_by_cse_pass.<locals>"
scope_prefix = "test_utility_funs.TestUtilityFuns.test_scope_of_constants_when_combined_by_cse_pass.<locals>"
expected_root_scope_name = f"{scope_prefix}.N::"
expected_layer_scope_name = f"{scope_prefix}.M::layers"
expected_constant_scope_name = [
Expand Down Expand Up @@ -1180,7 +1212,7 @@ def forward(self, x):
graph, _, _ = self._model_to_graph(
N(), (torch.randn(2, 3)), input_names=[], dynamic_axes={}
)
scope_prefix = "test_utility_funs.TestUtilityFuns_opset9.test_scope_of_nodes_when_combined_by_cse_pass.<locals>"
scope_prefix = "test_utility_funs.TestUtilityFuns.test_scope_of_nodes_when_combined_by_cse_pass.<locals>"
expected_root_scope_name = f"{scope_prefix}.N::"
expected_layer_scope_name = f"{scope_prefix}.M::layers"
expected_add_scope_names = [
Expand Down Expand Up @@ -1884,29 +1916,5 @@ def forward(self, x):
torch.onnx.unregister_custom_op_symbolic("::cat", _onnx_opset_version)


class TestUtilityFuns_opset10(TestUtilityFuns_opset9):
opset_version = 10


class TestUtilityFuns_opset11(TestUtilityFuns_opset9):
opset_version = 11


class TestUtilityFuns_opset12(TestUtilityFuns_opset9):
opset_version = 12


class TestUtilityFuns_opset13(TestUtilityFuns_opset9):
opset_version = 13


class TestUtilityFuns_opset14(TestUtilityFuns_opset9):
opset_version = 14


class TestUtilityFuns_opset15(TestUtilityFuns_opset9):
opset_version = 15


if __name__ == "__main__":
common_utils.run_tests()
7 changes: 2 additions & 5 deletions torch/onnx/utils.py
Expand Up @@ -90,7 +90,7 @@ def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
)
originally_training: bool = False

if not isinstance(model, torch.jit.ScriptFunction):
if hasattr(model, "training"):
originally_training = model.training

# ONNX opset 12 has better support for training amenable models, with updated
Expand Down Expand Up @@ -119,10 +119,7 @@ def select_model_mode_for_export(model, mode: _C_onnx.TrainingMode):
try:
yield
finally:
if not (
isinstance(model, torch.jit.ScriptFunction)
or mode == _C_onnx.TrainingMode.PRESERVE
):
if hasattr(model, "training") and not mode == _C_onnx.TrainingMode.PRESERVE:
model.train(originally_training)


Expand Down