[ONNX] Deprecate various args #65962

Merged: 6 commits, Oct 8, 2021
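The test changes below remove the deprecated ``example_outputs`` argument from calls to ``torch.onnx.export``, ``torch.onnx._export``, and ``torch.onnx.export_to_pretty_string``. A minimal before/after sketch of the caller-side change (using a toy scripted module invented here for illustration):

    import io
    import torch

    class AddSelf(torch.nn.Module):
        def forward(self, x):
            return x + x

    mod = torch.jit.script(AddSelf())
    x = torch.zeros(1, 2, 3)
    f = io.BytesIO()

    # Previously (deprecated): example outputs were supplied by the caller.
    # torch.onnx.export(mod, (x,), f, example_outputs=(mod(x),))

    # With this change: example_outputs is no longer passed.
    torch.onnx.export(mod, (x,), f)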
31 changes: 18 additions & 13 deletions docs/source/onnx.rst
@@ -307,11 +307,11 @@ If the operator is an ATen operator (shows up in the TorchScript graph with the

* Define the symbolic function in ``torch/onnx/symbolic_opset<version>.py``, for example
`torch/onnx/symbolic_opset9.py <https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset9.py>`_.
Make sure the function has the same name as the ATen function, which may declared in
Make sure the function has the same name as the ATen function, which may be declared in
``torch/_C/_VariableFunctions.pyi`` or ``torch/nn/functional.pyi`` (these files are generated at
build time, so will not appear in your checkout until you build PyTorch).
* The first arg is always the ONNX graph that is being built for export.
Other arg names must EXACTLY match the names in ``_VariableFunctions.pyi``,
Other arg names must EXACTLY match the names in the ``.pyi`` file,
because dispatch is done with keyword arguments.
* In the symbolic function, if the operator is in the
`ONNX standard operator set <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_,
@@ -365,8 +365,8 @@ See the ``symbolic_opset*.py`` files for more examples.
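For illustration, a hedged sketch of a symbolic function for a hypothetical op ``aten::my_scale(Tensor self, float alpha)`` (invented for this example; not a real ATen op) could look like::

    import torch

    def my_scale(g, self, alpha):
        # ``g`` is the ONNX graph being built. The remaining parameter names
        # mirror the op's declaration in the generated .pyi file, because
        # dispatch is done with keyword arguments.
        scale = g.op("Constant", value_t=torch.tensor(alpha, dtype=torch.float))
        return g.op("Mul", self, scale)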
torch.autograd.Functions
^^^^^^^^^^^^^^^^^^^^^^^^

If the operator is defined in a sub-class of :class:`torch.autograd.Function`,
there are two ways to export it.
If the operator is a sub-class of :class:`torch.autograd.Function`, there are two ways
to export it.

Static Symbolic Method
~~~~~~~~~~~~~~~~~~~~~~
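A minimal sketch of this approach, using a ``MyRelu``-like Function invented here for illustration::

    import torch

    class MyRelu(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            return input.clamp(min=0)

        @staticmethod
        def symbolic(g, input):
            # Called during ONNX export with the graph and the input Values.
            return g.op("Relu", input)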
@@ -388,23 +388,24 @@ PythonOp Symbolic
~~~~~~~~~~~~~~~~~

Alternatively, you can register a custom symbolic function.
This gives the symoblic function access to more info through the
This gives the symbolic function access to more info through the
TorchScript ``Node`` object for the original operation, which gets passed in as the second
argument (after the ``Graph`` object).

All autograd ``Function``s are emitted in the TorchScript graph as ``prim::PythonOp`` nodes.
All autograd ``Function``\ s appear in the TorchScript graph as ``prim::PythonOp`` nodes.
In order to differentiate between different ``Function`` subclasses, the
symbolic function should use the ``name`` kwarg which gets set to the name of the class.

:func:`register_custom_op_symbolic` does not allow registration for ops in
the ``prim`` namespace, so for this use case, there's a back door: register the
symbolic for ``"::prim_PythonOp"``.

Please also consider adding shape inference logic when you regiester a custom symbolic function
via setType API. This can help the exporter to obtain correct shape inference.
An example of setType is test_aten_embedding_2 in test_operators.py.
Although it is not required to add shape inference logic,
the exporter emits a warning message if it is not added.
Custom symbolic functions should add type and shape information by calling ``setType(...)``
on Value objects before returning them (implemented in C++ by
``torch::jit::Value::setType``). This is not required, but it can help the exporter's
shape and type inference for down-stream nodes. For a non-trivial example of ``setType``, see
``test_aten_embedding_2`` in
`test_operators.py <https://github.com/pytorch/pytorch/blob/master/test/onnx/test_operators.py>`_.

The example below shows how you can access ``requires_grad`` via the ``Node`` object::

@@ -430,13 +431,17 @@ The example below shows how you can access ``requires_grad`` via the ``Node`` ob
print("arg {}: {}, requires grad: {}".format(i, arg, requires_grad))

name = kwargs["name"]
ret = None
if name == "MyClip":
return g.op("Clip", args[0], min_f=args[1])
ret = g.op("Clip", args[0], min_f=args[1])
elif name == "MyRelu":
return g.op("Relu", args[0])
ret = g.op("Relu", args[0])
else:
# Logs a warning and returns None
return _unimplemented("prim::PythonOp", "unknown node kind: " + name)
# Copy type and shape from original node.
ret.setType(n.type())
return ret

from torch.onnx import register_custom_op_symbolic
register_custom_op_symbolic("::prim_PythonOp", symbolic_pythonop, 1)
8 changes: 3 additions & 5 deletions test/jit/test_export_modes.py
@@ -64,7 +64,7 @@ def foo(a):
return (a, a)
f = io.BytesIO()
x = torch.ones(3)
torch.onnx._export(foo, (x,), f, example_outputs=(x, x))
torch.onnx._export(foo, (x,), f)

@skipIfNoLapack
def test_aten_fallback(self):
@@ -76,9 +76,8 @@ def forward(self, x, y):

x = torch.rand(3, 4)
y = torch.rand(3, 4)
f = io.BytesIO()
torch.onnx.export_to_pretty_string(
ModelWithAtenNotONNXOp(), (x, y), f,
ModelWithAtenNotONNXOp(), (x, y), None,
add_node_names=False,
do_constant_folding=False,
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)
@@ -91,11 +90,10 @@ class ModelWithAtenFmod(nn.Module):
def forward(self, x, y):
return torch.fmod(x, y)

f = io.BytesIO()
x = torch.randn(3, 4, dtype=torch.float32)
y = torch.randn(3, 4, dtype=torch.float32)
torch.onnx.export_to_pretty_string(
ModelWithAtenFmod(), (x, y), f,
ModelWithAtenFmod(), (x, y), None,
add_node_names=False,
do_constant_folding=False,
operator_export_type=OperatorExportTypes.ONNX_ATEN)
61 changes: 17 additions & 44 deletions test/jit/test_onnx_export.py
@@ -49,20 +49,17 @@ def forward(self, x):

tm = TraceMe()
tm = torch.jit.trace(tm, torch.rand(3, 4))
example_outputs = (tm(torch.rand(3, 4)),)
f = io.BytesIO()
torch.onnx._export(tm, (torch.rand(3, 4),), f, example_outputs=example_outputs)
torch.onnx._export(tm, (torch.rand(3, 4),), f)

def test_export_tensoroption_to(self):
def foo(x):
return x[0].clone().detach().cpu() + x

traced = torch.jit.trace(foo, (torch.rand([2])))
example_outputs = traced(torch.rand([2]))

f = io.BytesIO()
torch.onnx._export_to_pretty_string(traced, (torch.rand([2]),), f,
example_outputs=example_outputs)
torch.onnx._export_to_pretty_string(traced, (torch.rand([2]),), f)

def test_onnx_export_script_module(self):
class ModuleToExport(torch.jit.ScriptModule):
@@ -75,10 +72,8 @@ def forward(self, x):
return x + x

mte = ModuleToExport()
outputs = mte(torch.zeros(1, 2, 3))
torch.onnx.export_to_pretty_string(
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
example_outputs=outputs)
mte, (torch.zeros(1, 2, 3),), None, verbose=False)

@suppress_warnings
def test_onnx_export_func_with_warnings(self):
@@ -93,11 +88,9 @@ def __init__(self):
def forward(self, x):
return func_with_warning(x)

outputs = WarningTest()(torch.randn(42))
# no exception
torch.onnx.export_to_pretty_string(
WarningTest(), torch.randn(42), None, verbose=False,
example_outputs=outputs)
WarningTest(), torch.randn(42), None, verbose=False)

def test_onnx_export_script_python_fail(self):
class PythonModule(torch.jit.ScriptModule):
@@ -119,11 +112,9 @@ def forward(self, x):
return y + y

mte = ModuleToExport()
outputs = mte(torch.zeros(1, 2, 3))
f = io.BytesIO()
with self.assertRaisesRegex(RuntimeError, "Couldn't export Python"):
torch.onnx._export(mte, (torch.zeros(1, 2, 3),), f, verbose=False,
example_outputs=outputs)
torch.onnx._export(mte, (torch.zeros(1, 2, 3),), f, verbose=False)

def test_onnx_export_script_inline_trace(self):
class ModuleToInline(torch.nn.Module):
@@ -144,10 +135,8 @@ def forward(self, x):
return y + y

mte = ModuleToExport()
outputs = mte(torch.zeros(1, 2, 3))
torch.onnx.export_to_pretty_string(
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
example_outputs=outputs)
mte, (torch.zeros(1, 2, 3),), None, verbose=False)

def test_onnx_export_script_inline_script(self):
class ModuleToInline(torch.jit.ScriptModule):
@@ -169,10 +158,8 @@ def forward(self, x):
return y + y

mte = ModuleToExport()
outputs = mte(torch.zeros(1, 2, 3))
torch.onnx.export_to_pretty_string(
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
example_outputs=outputs)
mte, (torch.zeros(1, 2, 3),), None, verbose=False)

def test_onnx_export_script_module_loop(self):
class ModuleToExport(torch.jit.ScriptModule):
@@ -189,10 +176,8 @@ def forward(self, x):
return x

mte = ModuleToExport()
outputs = mte(torch.zeros(1, 2, 3))
torch.onnx.export_to_pretty_string(
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
example_outputs=outputs)
mte, (torch.zeros(1, 2, 3),), None, verbose=False)

@suppress_warnings
def test_onnx_export_script_truediv(self):
@@ -206,11 +191,9 @@ def forward(self, x):
return x + z

mte = ModuleToExport()
outputs = mte(torch.zeros(1, 2, 3))

torch.onnx.export_to_pretty_string(
mte, (torch.zeros(1, 2, 3, dtype=torch.float),), None, verbose=False,
example_outputs=outputs)
mte, (torch.zeros(1, 2, 3, dtype=torch.float),), None, verbose=False)

def test_onnx_export_script_non_alpha_add_sub(self):
class ModuleToExport(torch.jit.ScriptModule):
@@ -223,10 +206,8 @@ def forward(self, x):
return bs - 1

mte = ModuleToExport()
outputs = torch.LongTensor([mte(torch.rand(3, 4))])
torch.onnx.export_to_pretty_string(
mte, (torch.rand(3, 4),), None, verbose=False,
example_outputs=outputs)
mte, (torch.rand(3, 4),), None, verbose=False)

def test_onnx_export_script_module_if(self):
class ModuleToExport(torch.jit.ScriptModule):
@@ -240,10 +221,8 @@ def forward(self, x):
return x

mte = ModuleToExport()
outputs = mte(torch.zeros(1, 2, 3, dtype=torch.long))
torch.onnx.export_to_pretty_string(
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
example_outputs=outputs)
mte, (torch.zeros(1, 2, 3),), None, verbose=False)

def test_onnx_export_script_inline_params(self):
class ModuleToInline(torch.jit.ScriptModule):
@@ -272,8 +251,7 @@ def forward(self, x):
reference = torch.mm(torch.mm(torch.zeros(2, 3), torch.ones(3, 3)), torch.ones(3, 4))
self.assertEqual(result, reference)
torch.onnx.export_to_pretty_string(
mte, (torch.ones(2, 3),), None, verbose=False,
example_outputs=result)
mte, (torch.ones(2, 3),), None, verbose=False)

def test_onnx_export_speculate(self):

@@ -305,18 +283,16 @@ def transpose(x):
return x.t()

f1 = Foo(transpose)
outputs_f1 = f1(torch.ones(1, 10, dtype=torch.float))
f2 = Foo(linear)
outputs_f2 = f2(torch.ones(1, 10, dtype=torch.float))

torch.onnx.export_to_pretty_string(
f1,
(torch.ones(1, 10, dtype=torch.float), ),
None, verbose=False, example_outputs=outputs_f1)
None, verbose=False)
torch.onnx.export_to_pretty_string(
f2,
(torch.ones(1, 10, dtype=torch.float), ),
None, verbose=False, example_outputs=outputs_f2)
None, verbose=False)

def test_onnx_export_shape_reshape(self):
class Foo(torch.nn.Module):
@@ -328,10 +304,8 @@ def forward(self, x):
return reshaped

foo = torch.jit.trace(Foo(), torch.zeros(1, 2, 3))
outputs = foo(torch.zeros(1, 2, 3))
f = io.BytesIO()
torch.onnx.export_to_pretty_string(foo, (torch.zeros(1, 2, 3)), f,
example_outputs=outputs)
torch.onnx.export_to_pretty_string(foo, (torch.zeros(1, 2, 3)), f)

def test_listconstruct_erasure(self):
class FooMod(torch.nn.Module):
@@ -360,11 +334,10 @@ def forward(self, x):
mod = DynamicSliceExportMod()

input = torch.rand(3, 4, 5)
example_outs = mod(input)

f = io.BytesIO()
torch.onnx.export_to_pretty_string(
DynamicSliceExportMod(), (input,), f, example_outputs=example_outs, opset_version=10)
DynamicSliceExportMod(), (input,), f, opset_version=10)

def test_export_dict(self):
class DictModule(torch.nn.Module):
Expand All @@ -380,4 +353,4 @@ def forward(self, x_in: torch.Tensor) -> typing.Dict[str, torch.Tensor]:

with self.assertRaisesRegex(RuntimeError, r"DictConstruct.+is not supported."):
torch.onnx.export_to_pretty_string(
torch.jit.script(mod), (x_in,), f, example_outputs=(mod(x_in),))
torch.jit.script(mod), (x_in,), f)
3 changes: 1 addition & 2 deletions test/jit/test_tracer.py
@@ -1115,9 +1115,8 @@ class Mod(torch.nn.Module):
def forward(self, x, w):
return torch.matmul(x, w).detach()

f = io.BytesIO()
torch.onnx.export_to_pretty_string(
Mod(), (torch.rand(3, 4), torch.rand(4, 5)), f)
Mod(), (torch.rand(3, 4), torch.rand(4, 5)), None)

def test_trace_slice_full_dim(self):
def foo(x):
2 changes: 1 addition & 1 deletion test/onnx/test_models_onnxruntime.py
@@ -19,7 +19,7 @@ def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):

outputs = model(inputs)
script_model = torch.jit.script(model)
run_model_test(self, script_model, False, example_outputs=outputs,
run_model_test(self, script_model, False,
input=inputs, rtol=rtol, atol=atol)


12 changes: 4 additions & 8 deletions test/onnx/test_onnx_opset.py
@@ -40,14 +40,13 @@ def check_onnx_opset_operator(model, ops, opset_version=_export_onnx_opset_versi
assert attributes[j][attribute_field] == getattr(graph.node[i].attribute[j], attribute_field)


def check_onnx_opsets_operator(module, x, ops, opset_versions, training=torch.onnx.TrainingMode.EVAL, example_outputs=None,
def check_onnx_opsets_operator(module, x, ops, opset_versions, training=torch.onnx.TrainingMode.EVAL,
input_names=None, dynamic_axes=None):
for opset_version in opset_versions:
f = io.BytesIO()
torch.onnx.export(module, x, f,
opset_version=opset_version,
training=training,
example_outputs=example_outputs,
input_names=input_names,
dynamic_axes=dynamic_axes)
model = onnx.load(io.BytesIO(f.getvalue()))
@@ -91,10 +90,8 @@ def forward(self, input, k):
x = torch.arange(1., 6., requires_grad=True)
k = torch.tensor(3)
module = MyModuleDynamic()
example_output = module(x, k)
check_onnx_opsets_operator(module, [x, k], ops,
opset_versions=[10],
example_outputs=example_output)
opset_versions=[10])

def test_maxpool(self):
module = torch.nn.MaxPool1d(2, stride=1)
@@ -191,7 +188,6 @@ def forward(self, x):

module = DynamicSliceModel()
x = torch.rand(1, 2)
example_output = module(x)
ops_10 = [{"op_name" : "Shape"},
{"op_name" : "Constant"},
{"op_name" : "Gather",
@@ -202,7 +198,7 @@ def forward(self, x):
{"op_name" : "Slice",
"attributes" : []}]
ops = {10 : ops_10}
check_onnx_opsets_operator(module, x, ops, opset_versions=[10], example_outputs=example_output,
check_onnx_opsets_operator(module, x, ops, opset_versions=[10],
input_names=['x'], dynamic_axes={"x": [0, 1]})

ops_10 = [{"op_name" : "Constant"},
@@ -212,7 +208,7 @@ def forward(self, x):
{"op_name" : "Slice",
"attributes" : []}]
ops = {10 : ops_10}
check_onnx_opsets_operator(module, x, ops, opset_versions=[10], example_outputs=example_output)
check_onnx_opsets_operator(module, x, ops, opset_versions=[10])

def test_flip(self):
class MyModule(Module):