Skip to content

Commit

Permalink
[ONNX] Deprecate setter functions for global variables (#85165)
Browse files Browse the repository at this point in the history
`_set_opset_version` and `_set_operator_export_type` were previously deprecated. This PR decorates them with the deprecation decorator, so warnings are emitted.

- Remove usage of `_set_opset_version` and `_set_operator_export_type` in favor of setting the global variables directly in torch.onnx internals
- Update `GLOBALS.operator_export_type`'s default to not be None to tighten types
- Remove usage of `_set_onnx_shape_inference`
Pull Request resolved: #85165
Approved by: https://github.com/BowenBao, https://github.com/AllenTiTaiWang
  • Loading branch information
justinchuby authored and pytorchmergebot committed Sep 28, 2022
1 parent 5deeb09 commit 85d8441
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 93 deletions.
110 changes: 53 additions & 57 deletions test/onnx/test_utility_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,8 @@
unregister_custom_op_symbolic,
utils,
)
from torch.onnx.symbolic_helper import (
_set_operator_export_type,
_set_opset_version,
_unpack_list,
parse_args,
)
from torch.onnx._globals import GLOBALS
from torch.onnx.symbolic_helper import _unpack_list, parse_args
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import skipIfNoCaffe2, skipIfNoLapack
from verify import verify
Expand Down Expand Up @@ -136,8 +132,8 @@ def forward(self, x, y, t):
out, out2 = torch.split(t, splits, dim=1)
return out, out2

_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.randn(2, 3)
y = torch.randn(2, 4)
t = torch.randn(2, 7)
Expand All @@ -157,8 +153,8 @@ def forward(self, x):
b = torch.transpose(a, 1, 0)
return b + x

_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.ones(3, 2)
graph, _, __ = self._model_to_graph(
TransposeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
Expand All @@ -176,8 +172,8 @@ def forward(self, x):
b = torch.norm(a, p=2, dim=-2, keepdim=False)
return b + x

_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.ones(2, 3)
graph, _, __ = self._model_to_graph(
ReduceModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
Expand All @@ -194,8 +190,8 @@ def forward(self, x):
b = torch.norm(a, p=1, dim=-2)
return b + x

_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.ones(2, 3)
graph, _, __ = self._model_to_graph(
NormModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
Expand All @@ -212,8 +208,8 @@ def forward(self, x):
b = torch.narrow(a, 0, 0, 1)
return b + x

_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.ones(1, 3)
graph, _, __ = self._model_to_graph(
NarrowModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
Expand All @@ -231,8 +227,8 @@ def forward(self, x):
b = a[1:10] # index exceeds dimension
return b + x

_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.ones(1, 3)
graph, _, __ = self._model_to_graph(
SliceIndexExceedsDimModule(),
Expand All @@ -255,8 +251,8 @@ def forward(self, x):
d = torch.select(a, dim=1, index=0)
return b + x, c + d

_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.ones(1, 3)
graph, _, __ = self._model_to_graph(
SliceNegativeIndexModule(),
Expand All @@ -277,8 +273,8 @@ def forward(self, x):
c = torch.index_select(a, dim=-2, index=torch.tensor([0, 1]))
return b + 1, c + x

_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.ones(1, 3)
model = GatherModule()
model(x)
Expand All @@ -296,8 +292,8 @@ def forward(self, x):
b = torch.unsqueeze(a, -2)
return b + x

_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.ones(1, 2, 3)
graph, _, __ = self._model_to_graph(
UnsqueezeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1, 2]}
Expand All @@ -318,8 +314,8 @@ def forward(self, x):
a = torch.randn(2, 3, 4, 5, 8, 7)
return self.prelu(x) + a

_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.randn(2, 3, 4, 5, 8, 7)
graph, _, __ = self._model_to_graph(
PReluModel(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2, 3, 4, 5]}
Expand All @@ -336,8 +332,8 @@ def forward(self, x):
a = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
return torch.squeeze(a) + x + torch.squeeze(a)

_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.ones(2, 3)
graph, _, __ = self._model_to_graph(
SqueezeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
Expand All @@ -353,8 +349,8 @@ def forward(self, x):
a = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])
return torch.squeeze(a, dim=-3) + x

_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.ones(2, 3)
graph, _, __ = self._model_to_graph(
SqueezeAxesModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
Expand Down Expand Up @@ -389,8 +385,8 @@ def forward(self, x):
d = b + c
return x + d

_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.ones(2, 3)
graph, _, __ = self._model_to_graph(
ConcatModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
Expand All @@ -410,8 +406,8 @@ def __init__(self):
def forward(self, input, initial_state):
return self.mygru(input, initial_state)

_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
input = torch.randn(5, 3, 7)
h0 = torch.randn(1, 3, 3)
graph, _, __ = self._model_to_graph(
Expand Down Expand Up @@ -441,8 +437,8 @@ def __init__(self):
def forward(self, A):
return torch.matmul(A, torch.transpose(self.B, -1, -2))

_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
A = torch.randn(2, 3)
graph, _, __ = self._model_to_graph(
MatMulNet(), (A,), input_names=["A"], dynamic_axes={"A": [0, 1]}
Expand All @@ -464,8 +460,8 @@ def forward(self, x):
b = self.weight.reshape(1, -1, 1, 1)
return x * b

_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
x = torch.randn(4, 5)
graph, _, __ = self._model_to_graph(
ReshapeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
Expand All @@ -488,8 +484,8 @@ def forward(self, x):
return div * x

x = torch.randn(2, 5)
_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
graph, _, __ = self._model_to_graph(
Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
)
Expand All @@ -511,8 +507,8 @@ def forward(self, x):
return mul / x

x = torch.randn(2, 5)
_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
graph, _, __ = self._model_to_graph(
Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
)
Expand All @@ -534,8 +530,8 @@ def forward(self, x):
return add - x

x = torch.randn(2, 5)
_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
graph, params_dict, __ = self._model_to_graph(
Module(),
(x,),
Expand Down Expand Up @@ -566,8 +562,8 @@ def forward(self, x):
return sub + x

x = torch.randn(2, 5)
_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
graph, params_dict, __ = self._model_to_graph(
Module(),
(x,),
Expand Down Expand Up @@ -598,8 +594,8 @@ def forward(self, x):
return sqrt / x

x = torch.randn(2, 5)
_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
graph, _, __ = self._model_to_graph(
Module(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
)
Expand All @@ -618,8 +614,8 @@ def forward(self, x):
return x + shape

x = torch.randn(2, 5)
_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
graph, _, __ = self._model_to_graph(
ShapeModule(), (x,), input_names=["x"], dynamic_axes={"x": [0, 1]}
)
Expand Down Expand Up @@ -1063,7 +1059,7 @@ def forward(self, x):
return torch.erfc(x)

x = torch.randn(2, 3, 4)
_set_opset_version(self.opset_version)
GLOBALS.export_onnx_opset_version = self.opset_version
graph, _, __ = self._model_to_graph(
Module(),
(x,),
Expand Down Expand Up @@ -1325,8 +1321,8 @@ def forward(self, x):
return x

x = torch.randn(20, 16, 50, 100)
_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
_, params_dict, __ = self._model_to_graph(
Model(),
(x,),
Expand Down Expand Up @@ -1354,8 +1350,8 @@ def forward(self, x):

model = torch.jit.script(MyModule())
x = torch.randn(10, 3, 128, 128)
_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
graph, _, __ = self._model_to_graph(
model,
(x,),
Expand Down Expand Up @@ -1449,8 +1445,8 @@ def forward(self, x, y):

input_1 = torch.tensor([11])
input_2 = torch.tensor([12])
_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
GLOBALS.export_onnx_opset_version = self.opset_version
GLOBALS.operator_export_type = OperatorExportTypes.ONNX
graph, _, __ = self._model_to_graph(
MyModule(),
(input_1, input_2),
Expand Down
5 changes: 3 additions & 2 deletions torch/onnx/_globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
variables unless they are absolutely necessary.
"""
import os
from typing import Optional

import torch._C._onnx as _C_onnx

Expand All @@ -28,7 +27,9 @@ def __init__(self):
self._in_onnx_export: bool = False
# Whether the user's model is training during export
self.export_training: bool = False
self.operator_export_type: Optional[_C_onnx.OperatorExportTypes] = None
self.operator_export_type: _C_onnx.OperatorExportTypes = (
_C_onnx.OperatorExportTypes.ONNX
)
self.onnx_shape_inference: bool = True

# Internal feature flags
Expand Down
23 changes: 22 additions & 1 deletion torch/onnx/symbolic_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
from torch import _C

# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
from torch.onnx import _constants, _patch_torch, _type_utils, errors # noqa: F401
from torch.onnx import ( # noqa: F401
_constants,
_deprecation,
_patch_torch,
_type_utils,
errors,
)
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype, jit_utils
from torch.types import Number
Expand Down Expand Up @@ -1632,16 +1638,31 @@ def args_have_same_dtype(args):


# TODO(justinchuby): Delete these setters, users should set the vars directly.
@_deprecation.deprecated(
    "1.13",
    "1.14",
    "remove its usage and avoid setting internal variables directly",
)
def _set_opset_version(opset_version: int) -> None:
    """Deprecated setter for the ONNX opset version used during export.

    Writes ``opset_version`` to ``GLOBALS.export_onnx_opset_version``. The
    ``_deprecation.deprecated`` decorator marks this as deprecated since 1.13
    (removal planned for 1.14); callers should set the global directly.
    """
    GLOBALS.export_onnx_opset_version = opset_version


@_deprecation.deprecated(
    "1.13",
    "1.14",
    "remove its usage and avoid setting internal variables directly",
)
def _set_operator_export_type(operator_export_type) -> None:
    """Deprecated setter for the operator export type used during export.

    Writes ``operator_export_type`` to ``GLOBALS.operator_export_type``.
    Deprecated since 1.13 (removal planned for 1.14); callers should set the
    global directly. ``operator_export_type`` is presumably a
    ``torch._C._onnx.OperatorExportTypes`` value — confirm against callers.
    """
    GLOBALS.operator_export_type = operator_export_type


# This function is for debug use only.
# onnx_shape_inference = True by default.
@_deprecation.deprecated(
    "1.13",
    "1.14",
    "remove its usage and avoid setting internal variables directly",
)
def _set_onnx_shape_inference(onnx_shape_inference: bool) -> None:
    """Deprecated setter toggling ONNX shape inference during export.

    Writes ``onnx_shape_inference`` to ``GLOBALS.onnx_shape_inference``
    (which defaults to True). Deprecated since 1.13 (removal planned for
    1.14); intended for debugging only — callers should set the global
    directly.
    """
    GLOBALS.onnx_shape_inference = onnx_shape_inference

Expand Down
8 changes: 4 additions & 4 deletions torch/onnx/symbolic_opset9.py
Original file line number Diff line number Diff line change
Expand Up @@ -6502,7 +6502,7 @@ def prim_loop(g: jit_utils.GraphContext, *inputs, **attrs) -> List[_C.Value]:
torch._C._jit_pass_onnx_block(
old_block,
new_block_context.block,
operator_export_type, # type:ignore[arg-type]
operator_export_type,
env,
False,
)
Expand Down Expand Up @@ -6564,8 +6564,8 @@ def prim_if(g: jit_utils.GraphContext, *inputs, **attrs) -> List[_C.Value]:
env = torch._C._jit_pass_onnx_block(
current_b,
block,
operator_export_type, # type:ignore[arg-type]
env, # type:ignore[arg-type]
operator_export_type,
env,
True,
)
if_output_list = list(n.outputs())
Expand All @@ -6591,7 +6591,7 @@ def prim_if(g: jit_utils.GraphContext, *inputs, **attrs) -> List[_C.Value]:
torch._C._jit_pass_onnx_block(
old_block,
new_block_context.block,
operator_export_type, # type:ignore[arg-type]
operator_export_type,
env,
False,
)
Expand Down

0 comments on commit 85d8441

Please sign in to comment.