Fix ATen Fallback for BUILD_CAFFE2=0 for ONNX-only ops #88504

220 changes: 134 additions & 86 deletions test/onnx/test_pytorch_onnx_no_runtime.py
@@ -18,7 +18,7 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.onnx import symbolic_helper, utils
from torch.onnx import OperatorExportTypes, symbolic_helper, utils
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import registration
from torch.testing._internal import common_quantization, common_utils, jit_utils
@@ -935,6 +935,139 @@ def forward(self, x, w):

torch.onnx.export_to_pretty_string(Mod(), (torch.rand(3, 4), torch.rand(4, 5)))

@common_utils.skipIfNoCaffe2
def test_caffe2_aten_fallback_must_fallback(self):
class ModelWithAtenNotONNXOp(torch.nn.Module):
def forward(self, x, y):
abcd = x + y
defg = torch.linalg.qr(abcd)
return defg

# TODO: Refactor common_utils._decide_skip_caffe2 to support parametrize
for operator_export_type in (
OperatorExportTypes.ONNX_ATEN,
OperatorExportTypes.ONNX_ATEN_FALLBACK,
):
x = torch.rand(3, 4)
y = torch.rand(3, 4)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenNotONNXOp(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=operator_export_type,
# support for linalg.qr was added in later op set versions.
opset_version=9,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertAtenOp(onnx_model, "linalg_qr")

@common_utils.skipIfNoCaffe2
def test_caffe2_onnx_aten_must_not_fallback(self):
class ModelWithAtenFmod(torch.nn.Module):
def forward(self, x, y):
return torch.fmod(x, y)

# TODO: Refactor common_utils._decide_skip_caffe2 to support parametrize
for operator_export_type in (
OperatorExportTypes.ONNX_ATEN_FALLBACK,
OperatorExportTypes.ONNX_ATEN,
):
x = torch.randn(3, 4, dtype=torch.float32)
y = torch.randn(3, 4, dtype=torch.float32)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenFmod(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=operator_export_type,
opset_version=10, # or higher
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
assert onnx_model.graph.node[0].op_type == "Mod"

@common_utils.skipIfCaffe2
def test_aten_fallback_must_fallback(self):
class ModelWithAtenNotONNXOp(torch.nn.Module):
def forward(self, x, y):
abcd = x + y
defg = torch.linalg.qr(abcd)
return defg

x = torch.rand(3, 4)
y = torch.rand(3, 4)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenNotONNXOp(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
# support for linalg.qr was added in later op set versions.
opset_version=9,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertAtenOp(onnx_model, "linalg_qr")

@common_utils.skipIfCaffe2
def test_onnx_aten(self):
class ModelWithAtenFmod(torch.nn.Module):
def forward(self, x, y):
return torch.fmod(x, y)

x = torch.randn(3, 4, dtype=torch.float32)
y = torch.randn(3, 4, dtype=torch.float32)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenFmod(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertAtenOp(onnx_model, "fmod", "Tensor")

@common_utils.skipIfCaffe2
def test_onnx_aten_fallback_must_not_fallback(self):
# For BUILD_CAFFE2=0, aten fallback only when not exportable
class ONNXExportable(torch.nn.Module):
def __init__(self):
super(ONNXExportable, self).__init__()
self.quant = torch.quantization.QuantStub()
self.fc1 = torch.nn.Linear(12, 8)
self.fc2 = torch.nn.Linear(8, 4)
self.fc3 = torch.nn.Linear(4, 6)
self.dequant = torch.quantization.DeQuantStub()

def forward(self, x):
x = self.quant(x)
x = x.view((-1, 12))
h = F.relu(self.fc1(x))
h = F.relu(self.fc2(h))
h = F.relu(self.fc3(h))
h = self.dequant(h)
return h

dummy_input = torch.randn(12)
f = io.BytesIO()
torch.onnx.export(
ONNXExportable(),
(dummy_input,),
f,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
all_aten_nodes = [
p
for p in onnx_model.graph.node
if p.op_type == "ATen" and p.domain == "org.pytorch.aten"
]
self.assertEqual(len(all_aten_nodes), 0)


class TestQuantizeEagerONNXExport(common_utils.TestCase):
def _test_lower_graph_impl(self, model, data):
@@ -997,91 +1130,6 @@ def test_lower_graph_conv3d(self):
data = torch.from_numpy(data_numpy).to(dtype=torch.float)
self._test_lower_graph_impl(model, data)

@common_utils.skipIfNoCaffe2
def test_caffe2_aten_fallback(self):
class ModelWithAtenNotONNXOp(torch.nn.Module):
def forward(self, x, y):
abcd = x + y
defg = torch.linalg.qr(abcd)
return defg

x = torch.rand(3, 4)
y = torch.rand(3, 4)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenNotONNXOp(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
# support for linalg.qr was added in later op set versions.
opset_version=9,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertAtenOp(onnx_model, "linalg_qr")

@common_utils.skipIfNoCaffe2
def test_caffe2_onnx_aten(self):
class ModelWithAtenFmod(torch.nn.Module):
def forward(self, x, y):
return torch.fmod(x, y)

x = torch.randn(3, 4, dtype=torch.float32)
y = torch.randn(3, 4, dtype=torch.float32)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenFmod(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN,
opset_version=10, # or higher
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
assert onnx_model.graph.node[0].op_type == "Mod"

@common_utils.skipIfCaffe2
def test_aten_fallback(self):
class ModelWithAtenNotONNXOp(torch.nn.Module):
def forward(self, x, y):
abcd = x + y
defg = torch.linalg.qr(abcd)
return defg

x = torch.rand(3, 4)
y = torch.rand(3, 4)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenNotONNXOp(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
# support for linalg.qr was added in later op set versions.
opset_version=9,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertAtenOp(onnx_model, "linalg_qr")

@common_utils.skipIfCaffe2
def test_onnx_aten(self):
class ModelWithAtenFmod(torch.nn.Module):
def forward(self, x, y):
return torch.fmod(x, y)

x = torch.randn(3, 4, dtype=torch.float32)
y = torch.randn(3, 4, dtype=torch.float32)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenFmod(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertAtenOp(onnx_model, "fmod", "Tensor")


if __name__ == "__main__":
common_utils.run_tests()
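
The tests above rely on an assertAtenOp helper defined elsewhere in this test file and not shown in the diff. As a rough sketch of what that check amounts to (a hypothetical stand-in, not the suite's actual helper), an ATen fallback node can be recognized the same way test_onnx_aten_fallback_must_not_fallback does, by its op_type "ATen" and domain "org.pytorch.aten"; the string attribute name "operator" used below is an assumption about how the exporter tags such nodes, not something shown in this diff.

import onnx


def has_aten_op(onnx_model: onnx.ModelProto, operator: str) -> bool:
    # Collect nodes emitted by the ATen fallback path.
    aten_nodes = [
        node
        for node in onnx_model.graph.node
        if node.op_type == "ATen" and node.domain == "org.pytorch.aten"
    ]
    for node in aten_nodes:
        # String attributes are stored as bytes; non-string attrs have empty .s.
        attrs = {attr.name: attr.s.decode() for attr in node.attribute if attr.s}
        if attrs.get("operator") == operator:
            return True
    return False

# Usage mirrors the tests above, e.g. has_aten_op(onnx_model, "linalg_qr").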
19 changes: 15 additions & 4 deletions torch/onnx/utils.py
@@ -1752,10 +1752,21 @@ def _should_aten_fallback(
)
is_caffe2_build = _C_onnx._CAFFE2_ATEN_FALLBACK

return name.startswith("aten::") and (
((is_onnx_aten_export or is_aten_fallback_export) and not is_caffe2_build)
or (not is_exportable_aten_op and is_aten_fallback_export)
)
if not name.startswith("aten::"):
return False

if is_caffe2_build:
if (
is_onnx_aten_export or is_aten_fallback_export
) and not is_exportable_aten_op:
return True
else:
if is_onnx_aten_export or (
is_aten_fallback_export and not is_exportable_aten_op
):
return True

return False


@_beartype.beartype
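
Restated outside the diff, the new control flow makes the decision easy to read off per build and export type. The sketch below is a simplified paraphrase of _should_aten_fallback for illustration only; the real function takes the operator export type and derives these flags internally (e.g. is_caffe2_build from _C_onnx._CAFFE2_ATEN_FALLBACK, as the hunk shows) rather than accepting them as booleans.

# Simplified, self-contained paraphrase of the decision above: illustration
# only, not the function as it exists in torch/onnx/utils.py.
def should_aten_fallback(
    name: str,
    is_exportable_aten_op: bool,
    is_onnx_aten_export: bool,       # OperatorExportTypes.ONNX_ATEN
    is_aten_fallback_export: bool,   # OperatorExportTypes.ONNX_ATEN_FALLBACK
    is_caffe2_build: bool,           # BUILD_CAFFE2=1
) -> bool:
    if not name.startswith("aten::"):
        return False
    if is_caffe2_build:
        # Caffe2 builds: either export type emits an ATen node only when the
        # op has no ONNX symbolic at the chosen opset.
        return (
            is_onnx_aten_export or is_aten_fallback_export
        ) and not is_exportable_aten_op
    # BUILD_CAFFE2=0: ONNX_ATEN always emits ATen; ONNX_ATEN_FALLBACK falls
    # back only when the op cannot be exported as native ONNX.
    return is_onnx_aten_export or (
        is_aten_fallback_export and not is_exportable_aten_op
    )


# Consistent with the tests above: with BUILD_CAFFE2=0 and ONNX_ATEN_FALLBACK,
# an exportable op such as aten::relu is not forced to ATen, while
# aten::linalg_qr at opset 9 (no ONNX symbolic) still falls back.
assert not should_aten_fallback(
    "aten::relu",
    is_exportable_aten_op=True,
    is_onnx_aten_export=False,
    is_aten_fallback_export=True,
    is_caffe2_build=False,
)
assert should_aten_fallback(
    "aten::linalg_qr",
    is_exportable_aten_op=False,
    is_onnx_aten_export=False,
    is_aten_fallback_export=True,
    is_caffe2_build=False,
)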