Skip to content

Commit

Permalink
Fix ONNX operator_export_type on the new registry
Browse files Browse the repository at this point in the history
  • Loading branch information
thiagocrepaldi committed Oct 25, 2022
1 parent 44d7ba7 commit 41e5234
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 23 deletions.
33 changes: 24 additions & 9 deletions test/jit/test_export_modes.py
@@ -1,6 +1,7 @@
# Owner(s): ["oncall: jit"]

import io
import onnx
import os
import shutil
import sys
Expand Down Expand Up @@ -79,11 +80,16 @@ def forward(self, x, y):

x = torch.rand(3, 4)
y = torch.rand(3, 4)
torch.onnx.export_to_pretty_string(
ModelWithAtenNotONNXOp(), (x, y),
add_node_names=False,
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenNotONNXOp(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
assert onnx_model.graph.node[2].op_type == "ATen"


@skipIfCaffe2
@skipIfNoLapack
Expand All @@ -96,13 +102,17 @@ def forward(self, x, y):

x = torch.rand(3, 4)
y = torch.rand(3, 4)
torch.onnx.export_to_pretty_string(
ModelWithAtenNotONNXOp(), (x, y),
add_node_names=False,
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenNotONNXOp(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
# support for linalg.qr was added in later op set versions.
opset_version=9)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
assert onnx_model.graph.node[2].op_type == "ATen"

# torch.fmod is used to test ONNX_ATEN.
# If you plan to remove fmod from aten, or find that this test fails,
Expand All @@ -114,8 +124,13 @@ def forward(self, x, y):

x = torch.randn(3, 4, dtype=torch.float32)
y = torch.randn(3, 4, dtype=torch.float32)
torch.onnx.export_to_pretty_string(
ModelWithAtenFmod(), (x, y),
add_node_names=False,
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenFmod(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=OperatorExportTypes.ONNX_ATEN)

onnx_model = onnx.load(io.BytesIO(f.getvalue()))
assert onnx_model.graph.node[0].op_type == "ATen"
36 changes: 22 additions & 14 deletions torch/onnx/utils.py
Expand Up @@ -1743,13 +1743,14 @@ def _should_aten_fallback(
opset_version: int,
operator_export_type: _C_onnx.OperatorExportTypes,
):
namespace, _ = name.split("::")
is_exportable_aten_op = registration.registry.is_registered_op(name, opset_version)
is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN
is_aten_fallback_export = (
operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
)
return is_onnx_aten_export or (
not is_exportable_aten_op and is_aten_fallback_export
return namespace == "aten" and (
is_onnx_aten_export or (not is_exportable_aten_op and is_aten_fallback_export)
)


Expand Down Expand Up @@ -1844,6 +1845,24 @@ def _run_symbolic_function(
env=env,
)

# When the domain is "aten" and operator_export_type is ONNX_ATEN or ONNX_ATEN_FALLBACK,
# create an org.pytorch.aten::ATen operator regardless of whether symbolics exist
if _should_aten_fallback(ns_op_name, opset_version, operator_export_type):
# Direct ATen export requested
attrs = {
k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)
for k in node.attributeNames()
}
outputs = node.outputsSize()
attrs["outputs"] = outputs
# `overload_name` is set for non-Caffe2 builds only
return graph_context.at(
op_name,
*inputs,
overload_name=_get_aten_op_overload_name(node),
**attrs,
)

try:
# Caffe2-specific: Quantized op symbolics are registered for opset 9 only.
if symbolic_helper.is_caffe2_aten_fallback() and opset_version == 9:
Expand All @@ -1861,6 +1880,7 @@ def _run_symbolic_function(
if symbolic_function_group is not None:
symbolic_fn = symbolic_function_group.get(opset_version)
if symbolic_fn is not None:
# TODO(justinchuby): Wrap almost identical attrs assignment or comment the difference.
attrs = {
k: symbolic_helper._node_get(node, k) for k in node.attributeNames()
}
Expand All @@ -1874,18 +1894,6 @@ def _run_symbolic_function(
# Clone node to trigger ONNX shape inference
return graph_context.op(op_name, *inputs, **attrs, outputs=node.outputsSize()) # type: ignore[attr-defined]

if _should_aten_fallback(ns_op_name, opset_version, operator_export_type):
# Direct ATen export requested
outputs = node.outputsSize()
attrs["outputs"] = outputs
# `overload_name` is set for non-Caffe2 builds only
return graph_context.at(
op_name,
*inputs,
overload_name=_get_aten_op_overload_name(node),
**attrs,
)

raise errors.UnsupportedOperatorError(
domain,
op_name,
Expand Down

0 comments on commit 41e5234

Please sign in to comment.