Fix ONNX operator_export_type on the new registry #87735

Closed
test/onnx/test_pytorch_onnx_no_runtime.py (88 additions, 0 deletions)
@@ -35,6 +35,7 @@ def export_to_onnx(
mocks: Optional[Iterable] = None,
operator_export_type: torch.onnx.OperatorExportTypes = torch.onnx.OperatorExportTypes.ONNX,
opset_version: int = GLOBALS.export_onnx_opset_version,
**torch_onnx_export_kwargs,
) -> onnx.ModelProto:
"""Exports `model(input)` to ONNX and returns it.

@@ -47,6 +48,7 @@ def export_to_onnx(
mocks: list of mocks to use during export
operator_export_type: export type as described by `torch.onnx.export(...operator_export_type,...)`
opset_version: ONNX opset version as described by `torch.onnx.export(...opset_version,...)`
torch_onnx_export_kwargs: extra keyword arguments forwarded to torch.onnx.export
Returns:
A valid ONNX model (`onnx.ModelProto`)
"""
@@ -63,6 +65,7 @@
f,
operator_export_type=operator_export_type,
opset_version=opset_version,
**torch_onnx_export_kwargs,
)

# Validate ONNX graph before returning it
@@ -994,6 +997,91 @@ def test_lower_graph_conv3d(self):
data = torch.from_numpy(data_numpy).to(dtype=torch.float)
self._test_lower_graph_impl(model, data)

@common_utils.skipIfNoCaffe2
def test_caffe2_aten_fallback(self):
class ModelWithAtenNotONNXOp(torch.nn.Module):
def forward(self, x, y):
abcd = x + y
defg = torch.linalg.qr(abcd)
return defg

x = torch.rand(3, 4)
y = torch.rand(3, 4)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenNotONNXOp(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
# Support for linalg.qr was added in later opset versions.
opset_version=9,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertAtenOp(onnx_model, "linalg_qr")

@common_utils.skipIfNoCaffe2
def test_caffe2_onnx_aten(self):
class ModelWithAtenFmod(torch.nn.Module):
def forward(self, x, y):
return torch.fmod(x, y)

x = torch.randn(3, 4, dtype=torch.float32)
y = torch.randn(3, 4, dtype=torch.float32)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenFmod(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN,
opset_version=10, # or higher
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
assert onnx_model.graph.node[0].op_type == "Mod"

@common_utils.skipIfCaffe2
def test_aten_fallback(self):
class ModelWithAtenNotONNXOp(torch.nn.Module):
def forward(self, x, y):
abcd = x + y
defg = torch.linalg.qr(abcd)
return defg

x = torch.rand(3, 4)
y = torch.rand(3, 4)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenNotONNXOp(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
# Support for linalg.qr was added in later opset versions.
opset_version=9,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertAtenOp(onnx_model, "linalg_qr")

@common_utils.skipIfCaffe2
def test_onnx_aten(self):
class ModelWithAtenFmod(torch.nn.Module):
def forward(self, x, y):
return torch.fmod(x, y)

x = torch.randn(3, 4, dtype=torch.float32)
y = torch.randn(3, 4, dtype=torch.float32)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenFmod(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertAtenOp(onnx_model, "fmod", "Tensor")


if __name__ == "__main__":
common_utils.run_tests()
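The tests above rely on assertAtenOp to confirm that an operator was exported through the ATen fallback path rather than as a standard ONNX op. For readers without the test harness at hand, the check amounts to locating an ATen node in the exported graph. Below is a minimal sketch, assuming the exporter encodes the ATen op name in an "operator" string attribute and the overload in "overload_name" (as the non-Caffe2 path in torch/onnx/utils.py below suggests); find_aten_op is a hypothetical helper, not the harness implementation:

import onnx

def find_aten_op(model: onnx.ModelProto, operator: str, overload: str = ""):
    """Return the first ATen fallback node whose attributes match, else None."""
    for node in model.graph.node:
        if node.op_type != "ATen":
            continue
        # String attributes arrive as bytes in the protobuf representation.
        attrs = {
            a.name: a.s.decode()
            for a in node.attribute
            if a.type == onnx.AttributeProto.STRING
        }
        if attrs.get("operator") == operator and attrs.get("overload_name", "") == overload:
            return node
    return None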
torch/onnx/utils.py (26 additions, 17 deletions)
@@ -1739,17 +1739,22 @@ def _add_output_to_block(block: _C.Block, value: _C.Value) -> int:

@_beartype.beartype
def _should_aten_fallback(
name: str, opset_version: int, operator_export_type: _C_onnx.OperatorExportTypes
):
# For BUILD_CAFFE2=0 builds, if domain=="aten" and operator_export_type==ONNX_ATEN,
# an aten::ATen operator is created regardless of symbolics existence
# For BUILD_CAFFE2=1, the same applies only if there is no symbolic available

is_exportable_aten_op = registration.registry.is_registered_op(name, opset_version)
is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN
is_aten_fallback_export = (
operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
)
is_caffe2_build = _C_onnx._CAFFE2_ATEN_FALLBACK

return name.startswith("aten::") and (
((is_onnx_aten_export or is_aten_fallback_export) and not is_caffe2_build)
or (not is_exportable_aten_op and is_aten_fallback_export)
)
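The reworked predicate now gates on the aten:: namespace and distinguishes build flavors: on BUILD_CAFFE2=0 builds, ONNX_ATEN and ONNX_ATEN_FALLBACK always produce aten::ATen nodes, while on BUILD_CAFFE2=1 builds the fallback applies only when no symbolic is registered. A minimal sketch of the decision table with plain booleans (should_fallback is a hypothetical stand-in for _should_aten_fallback, with the registry lookup replaced by a flag):

def should_fallback(name, is_registered, is_onnx_aten, is_fallback, is_caffe2):
    # Mirrors the predicate above: non-Caffe2 builds fall back for both
    # export types; Caffe2 builds fall back only without a symbolic.
    return name.startswith("aten::") and (
        ((is_onnx_aten or is_fallback) and not is_caffe2)
        or (not is_registered and is_fallback)
    )

# Non-Caffe2 build, ONNX_ATEN: every aten:: op becomes an aten::ATen node.
assert should_fallback("aten::fmod", True, True, False, False)
# Caffe2 build, ONNX_ATEN_FALLBACK with a registered symbolic: no fallback.
assert not should_fallback("aten::fmod", True, False, True, True)
# Caffe2 build, ONNX_ATEN_FALLBACK without a symbolic: fall back.
assert should_fallback("aten::linalg_qr", False, False, True, True)
# Ops outside the aten namespace never fall back.
assert not should_fallback("prim::Constant", False, True, False, False)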


@@ -1844,6 +1849,21 @@ def _run_symbolic_function(
env=env,
)

# Direct ATen export requested
if _should_aten_fallback(ns_op_name, opset_version, operator_export_type):
attrs = {
k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)
for k in node.attributeNames()
}
outputs = node.outputsSize()
attrs["outputs"] = outputs
return graph_context.at(
op_name,
*inputs,
overload_name=_get_aten_op_overload_name(node),
**attrs,
)

try:
# Caffe2-specific: Quantized op symbolics are registered for opset 9 only.
if symbolic_helper.is_caffe2_aten_fallback() and opset_version == 9:
@@ -1861,6 +1881,7 @@
if symbolic_function_group is not None:
symbolic_fn = symbolic_function_group.get(opset_version)
if symbolic_fn is not None:
# TODO: Factor out the nearly identical attrs assignments, or document how they differ.
attrs = {
k: symbolic_helper._node_get(node, k) for k in node.attributeNames()
}
@@ -1874,18 +1895,6 @@
# Clone node to trigger ONNX shape inference
return graph_context.op(op_name, *inputs, **attrs, outputs=node.outputsSize()) # type: ignore[attr-defined]


raise errors.UnsupportedOperatorError(
symbolic_function_name,
opset_version,
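One note on the relocated fallback block: before calling graph_context.at, each attribute name is suffixed with the first letter of its JIT kind (node.kindOf(k)[0]), which is how aten::ATen nodes carry typed attributes. A minimal sketch of that encoding outside the exporter, with plain dicts standing in for node.attributeNames() and node.kindOf(); the kind strings ("i", "f", "s", "t", "is", ...) are assumptions based on the TorchScript attribute kinds:

def encode_aten_attrs(attrs_with_kinds):
    # Mirrors k + "_" + node.kindOf(k)[0] from _run_symbolic_function.
    return {
        name + "_" + kind[0]: value
        for name, (kind, value) in attrs_with_kinds.items()
    }

# An int list kind ("is") collapses to the same "i" suffix as a scalar int.
assert encode_aten_attrs({"dim": ("i", 0), "sizes": ("is", [3, 4])}) == {
    "dim_i": 0,
    "sizes_i": [3, 4],
}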