Skip to content

Commit

Permalink
Fix ONNX operator_export_type on the new registry (#87735) (#90044)
Browse files Browse the repository at this point in the history
Fixes #87313

Our ONNX pipelines do not run with BUILD_CAFFE2=0, so the tests for operator_export_type ONNX_ATEN and ONNX_ATEN_FALLBACK are not fully exercised, which would allow regressions to recur.

We need to run the same set of tests under both BUILD_CAFFE2=0 and BUILD_CAFFE2=1.
Pull Request resolved: #87735
Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao

(cherry picked from commit 2aed670)

Co-authored-by: Thiago Crepaldi <thiago.crepaldi@microsoft.com>
  • Loading branch information
izaitsevfb and thiagocrepaldi committed Dec 3, 2022
1 parent ae2fe40 commit 642edcd
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 17 deletions.
88 changes: 88 additions & 0 deletions test/onnx/test_pytorch_onnx_no_runtime.py
Expand Up @@ -32,6 +32,7 @@ def export_to_onnx(
mocks: Optional[Iterable] = None,
operator_export_type: torch.onnx.OperatorExportTypes = torch.onnx.OperatorExportTypes.ONNX,
opset_version: int = GLOBALS.export_onnx_opset_version,
**torch_onnx_export_kwargs,
) -> onnx.ModelProto:
"""Exports `model(input)` to ONNX and returns it.
Expand All @@ -44,6 +45,7 @@ def export_to_onnx(
mocks: list of mocks to use during export
operator_export_type: export type as described by `torch.onnx.export(...operator_export_type,...)`
opset_version: ONNX opset version as described by `torch.onnx.export(...opset_version,...)`
torch_onnx_export_kwargs: extra torch.onnx.export kwargs arguments
Returns:
A valid ONNX model (`onnx.ModelProto`)
"""
Expand All @@ -60,6 +62,7 @@ def export_to_onnx(
f,
operator_export_type=operator_export_type,
opset_version=opset_version,
**torch_onnx_export_kwargs,
)

# Validate ONNX graph before returning it
Expand Down Expand Up @@ -777,6 +780,91 @@ def forward(self, x):
model, inputs, f, dynamic_axes={"x": [0, 1]}, input_names=["x"]
)

@common_utils.skipIfNoCaffe2
def test_caffe2_aten_fallback(self):
class ModelWithAtenNotONNXOp(torch.nn.Module):
def forward(self, x, y):
abcd = x + y
defg = torch.linalg.qr(abcd)
return defg

x = torch.rand(3, 4)
y = torch.rand(3, 4)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenNotONNXOp(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
# support for linalg.qr was added in later op set versions.
opset_version=9,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertAtenOp(onnx_model, "linalg_qr")

@common_utils.skipIfNoCaffe2
def test_caffe2_onnx_aten(self):
class ModelWithAtenFmod(torch.nn.Module):
def forward(self, x, y):
return torch.fmod(x, y)

x = torch.randn(3, 4, dtype=torch.float32)
y = torch.randn(3, 4, dtype=torch.float32)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenFmod(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN,
opset_version=10, # or higher
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
assert onnx_model.graph.node[0].op_type == "Mod"

@common_utils.skipIfCaffe2
def test_aten_fallback(self):
class ModelWithAtenNotONNXOp(torch.nn.Module):
def forward(self, x, y):
abcd = x + y
defg = torch.linalg.qr(abcd)
return defg

x = torch.rand(3, 4)
y = torch.rand(3, 4)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenNotONNXOp(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
# support for linalg.qr was added in later op set versions.
opset_version=9,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertAtenOp(onnx_model, "linalg_qr")

@common_utils.skipIfCaffe2
def test_onnx_aten(self):
class ModelWithAtenFmod(torch.nn.Module):
def forward(self, x, y):
return torch.fmod(x, y)

x = torch.randn(3, 4, dtype=torch.float32)
y = torch.randn(3, 4, dtype=torch.float32)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenFmod(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertAtenOp(onnx_model, "fmod", "Tensor")


if __name__ == "__main__":
common_utils.run_tests()
43 changes: 26 additions & 17 deletions torch/onnx/utils.py
Expand Up @@ -1739,17 +1739,22 @@ def _add_output_to_block(block: _C.Block, value: _C.Value) -> int:

@_beartype.beartype
def _should_aten_fallback(
    name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes
) -> bool:
    """Decide whether ``name`` should be exported as an ``aten::ATen`` op.

    Args:
        name: Fully qualified op name, e.g. ``"aten::linalg_qr"``.
        opset_version: Target ONNX opset, used to look up a registered symbolic.
        operator_export_type: Export mode requested by the caller.

    Returns:
        True when the op must be emitted as ``aten::ATen`` instead of a
        regular ONNX op.
    """
    # For BUILD_CAFFE2=0 builds, if domain=="aten" and operator_export_type==ONNX_ATEN,
    # an aten::ATen operator is created regardless of symbolics existence.
    # For BUILD_CAFFE2=1, the same applies only if there is no symbolic available.

    is_exportable_aten_op = registration.registry.is_registered_op(name, opset_version)
    is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN
    is_aten_fallback_export = (
        operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
    )
    is_caffe2_build = _C_onnx._CAFFE2_ATEN_FALLBACK

    # Only ops in the "aten" domain may fall back; ops from other domains never do.
    return name.startswith("aten::") and (
        ((is_onnx_aten_export or is_aten_fallback_export) and not is_caffe2_build)
        or (not is_exportable_aten_op and is_aten_fallback_export)
    )


Expand Down Expand Up @@ -1844,6 +1849,21 @@ def _run_symbolic_function(
env=env,
)

# Direct ATen export requested
if _should_aten_fallback(ns_op_name, opset_version, operator_export_type):
attrs = {
k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)
for k in node.attributeNames()
}
outputs = node.outputsSize()
attrs["outputs"] = outputs
return graph_context.at(
op_name,
*inputs,
overload_name=_get_aten_op_overload_name(node),
**attrs,
)

try:
# Caffe2-specific: Quantized op symbolics are registered for opset 9 only.
if symbolic_helper.is_caffe2_aten_fallback() and opset_version == 9:
Expand All @@ -1861,6 +1881,7 @@ def _run_symbolic_function(
if symbolic_function_group is not None:
symbolic_fn = symbolic_function_group.get(opset_version)
if symbolic_fn is not None:
# TODO Wrap almost identical attrs assignment or comment the difference.
attrs = {
k: symbolic_helper._node_get(node, k) for k in node.attributeNames()
}
Expand All @@ -1874,18 +1895,6 @@ def _run_symbolic_function(
# Clone node to trigger ONNX shape inference
return graph_context.op(op_name, *inputs, **attrs, outputs=node.outputsSize()) # type: ignore[attr-defined]

if _should_aten_fallback(ns_op_name, opset_version, operator_export_type):
# Direct ATen export requested
outputs = node.outputsSize()
attrs["outputs"] = outputs
# `overload_name` is set for non-Caffe2 builds only
return graph_context.at(
op_name,
*inputs,
overload_name=_get_aten_op_overload_name(node),
**attrs,
)

raise errors.UnsupportedOperatorError(
domain,
op_name,
Expand Down

0 comments on commit 642edcd

Please sign in to comment.