Fix ATen Fallback for BUILD_CAFFE2=0 for ONNX-only ops (#88504) #90104

Merged

Changes from 4 commits
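This PR makes `_should_aten_fallback` consistent for BUILD_CAFFE2=0 builds: `OperatorExportTypes.ONNX_ATEN` always exports aten ops as ATen nodes, while `OperatorExportTypes.ONNX_ATEN_FALLBACK` falls back to ATen only for ops that have no ONNX symbolic. A minimal sketch of the intended post-fix behavior (assuming a BUILD_CAFFE2=0 build of torch with the onnx package installed; `FmodModel` is an illustrative name, not part of this PR):

import io

import onnx
import torch


class FmodModel(torch.nn.Module):
    def forward(self, x, y):
        return torch.fmod(x, y)  # exportable as ONNX Mod since opset 10


f = io.BytesIO()
torch.onnx.export(
    FmodModel(),
    (torch.randn(3, 4), torch.randn(3, 4)),
    f,
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
    opset_version=10,
)
# fmod is exportable, so no ATen fallback node should be emitted.
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
assert not any(
    node.op_type == "ATen" and node.domain == "org.pytorch.aten"
    for node in onnx_model.graph.node
)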
175 changes: 143 additions & 32 deletions test/onnx/test_pytorch_onnx_no_runtime.py
@@ -15,10 +15,11 @@
import torch
import torch.nn.functional as F
from torch import Tensor
-from torch.onnx import symbolic_helper, utils
+import numpy as np
+from torch.onnx import OperatorExportTypes, symbolic_helper, utils
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import registration
-from torch.testing._internal import common_utils
+from torch.testing._internal import common_quantization, common_utils


def export_to_onnx(
@@ -781,50 +782,60 @@ def forward(self, x):
        )

    @common_utils.skipIfNoCaffe2
-    def test_caffe2_aten_fallback(self):
+    def test_caffe2_aten_fallback_must_fallback(self):
        class ModelWithAtenNotONNXOp(torch.nn.Module):
            def forward(self, x, y):
                abcd = x + y
                defg = torch.linalg.qr(abcd)
                return defg

-        x = torch.rand(3, 4)
-        y = torch.rand(3, 4)
-        f = io.BytesIO()
-        torch.onnx.export(
-            ModelWithAtenNotONNXOp(),
-            (x, y),
-            f,
-            do_constant_folding=False,
-            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
-            # support for linalg.qr was added in later op set versions.
-            opset_version=9,
-        )
-        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
-        self.assertAtenOp(onnx_model, "linalg_qr")
+        # TODO: Refactor common_utils._decide_skip_caffe2 to support parametrize
+        for operator_export_type in (
+            OperatorExportTypes.ONNX_ATEN,
+            OperatorExportTypes.ONNX_ATEN_FALLBACK,
+        ):
+            x = torch.rand(3, 4)
+            y = torch.rand(3, 4)
+            f = io.BytesIO()
+            torch.onnx.export(
+                ModelWithAtenNotONNXOp(),
+                (x, y),
+                f,
+                do_constant_folding=False,
+                operator_export_type=operator_export_type,
+                # support for linalg.qr was added in later op set versions.
+                opset_version=9,
+            )
+            onnx_model = onnx.load(io.BytesIO(f.getvalue()))
+            self.assertAtenOp(onnx_model, "linalg_qr")

    @common_utils.skipIfNoCaffe2
-    def test_caffe2_onnx_aten(self):
+    def test_caffe2_onnx_aten_must_not_fallback(self):
        class ModelWithAtenFmod(torch.nn.Module):
            def forward(self, x, y):
                return torch.fmod(x, y)

-        x = torch.randn(3, 4, dtype=torch.float32)
-        y = torch.randn(3, 4, dtype=torch.float32)
-        f = io.BytesIO()
-        torch.onnx.export(
-            ModelWithAtenFmod(),
-            (x, y),
-            f,
-            do_constant_folding=False,
-            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN,
-            opset_version=10,  # or higher
-        )
-        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
-        assert onnx_model.graph.node[0].op_type == "Mod"
+        # TODO: Refactor common_utils._decide_skip_caffe2 to support parametrize
+        for operator_export_type in (
+            OperatorExportTypes.ONNX_ATEN_FALLBACK,
+            OperatorExportTypes.ONNX_ATEN,
+        ):
+            x = torch.randn(3, 4, dtype=torch.float32)
+            y = torch.randn(3, 4, dtype=torch.float32)
+            f = io.BytesIO()
+            torch.onnx.export(
+                ModelWithAtenFmod(),
+                (x, y),
+                f,
+                do_constant_folding=False,
+                operator_export_type=operator_export_type,
+                opset_version=10,  # or higher
+            )
+            onnx_model = onnx.load(io.BytesIO(f.getvalue()))
+            assert onnx_model.graph.node[0].op_type == "Mod"

    @common_utils.skipIfCaffe2
-    def test_aten_fallback(self):
+    def test_aten_fallback_must_fallback(self):
        class ModelWithAtenNotONNXOp(torch.nn.Module):
            def forward(self, x, y):
                abcd = x + y
@@ -865,6 +876,106 @@ def forward(self, x, y):
        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
        self.assertAtenOp(onnx_model, "fmod", "Tensor")

+    @common_utils.skipIfCaffe2
+    def test_onnx_aten_fallback_must_not_fallback(self):
+        # For BUILD_CAFFE2=0, aten fallback only when not exportable
+        class ONNXExportable(torch.nn.Module):
+            def __init__(self):
+                super(ONNXExportable, self).__init__()
+                self.quant = torch.quantization.QuantStub()
+                self.fc1 = torch.nn.Linear(12, 8)
+                self.fc2 = torch.nn.Linear(8, 4)
+                self.fc3 = torch.nn.Linear(4, 6)
+                self.dequant = torch.quantization.DeQuantStub()
+
+            def forward(self, x):
+                x = self.quant(x)
+                x = x.view((-1, 12))
+                h = F.relu(self.fc1(x))
+                h = F.relu(self.fc2(h))
+                h = F.relu(self.fc3(h))
+                h = self.dequant(h)
+                return h
+
+        dummy_input = torch.randn(12)
+        f = io.BytesIO()
+        torch.onnx.export(
+            ONNXExportable(),
+            (dummy_input,),
+            f,
+            do_constant_folding=False,
+            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
+        )
+        onnx_model = onnx.load(io.BytesIO(f.getvalue()))
+        all_aten_nodes = [
+            p
+            for p in onnx_model.graph.node
+            if p.op_type == "ATen" and p.domain == "org.pytorch.aten"
+        ]
+        self.assertEqual(len(all_aten_nodes), 0)
+
+
+class TestQuantizeEagerONNXExport(common_utils.TestCase):
+    def _test_lower_graph_impl(self, model, data):
+        model.qconfig = torch.ao.quantization.default_qconfig
+        model = torch.ao.quantization.prepare(model)
+        model = torch.ao.quantization.convert(model)
+
+        _ = model(data)
+        input_names = ["x"]
+
+        def _export_to_onnx(model, input, input_names):
+            traced = torch.jit.trace(model, input)
+            buf = io.BytesIO()
+            torch.jit.save(traced, buf)
+            buf.seek(0)
+
+            model = torch.jit.load(buf)
+            f = io.BytesIO()
+            torch.onnx.export(
+                model,
+                input,
+                f,
+                input_names=input_names,
+                operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
+                opset_version=9,
+            )
+
+        _export_to_onnx(model, data, input_names)
+
+    @common_quantization.skipIfNoFBGEMM
+    @common_utils.skipIfNoCaffe2
+    def test_lower_graph_linear(self):
+        model = torch.ao.quantization.QuantWrapper(
+            torch.nn.Linear(5, 10, bias=True)
+        ).to(dtype=torch.float)
+        data_numpy = np.random.rand(1, 2, 5).astype(np.float32)
+        data = torch.from_numpy(data_numpy).to(dtype=torch.float)
+        self._test_lower_graph_impl(model, data)
+
+    @common_quantization.skipIfNoFBGEMM
+    @common_utils.skipIfNoCaffe2
+    def test_lower_graph_conv2d(self):
+        model = torch.ao.quantization.QuantWrapper(
+            torch.nn.Conv2d(3, 5, 2, bias=True)
+        ).to(dtype=torch.float)
+        data_numpy = np.random.rand(1, 3, 6, 6).astype(np.float32)
+        data = torch.from_numpy(data_numpy).to(dtype=torch.float)
+        self._test_lower_graph_impl(model, data)
+
+    @common_quantization.skipIfNoFBGEMM
+    @unittest.skip(
+        "onnx opset9 does not support quantize_per_tensor and caffe2 \
+does not support conv3d"
+    )
+    def test_lower_graph_conv3d(self):
+        model = torch.ao.quantization.QuantWrapper(
+            torch.nn.Conv3d(3, 5, 2, bias=True)
+        ).to(dtype=torch.float)
+        data_numpy = np.random.rand(1, 3, 6, 6, 6).astype(np.float32)
+        data = torch.from_numpy(data_numpy).to(dtype=torch.float)
+        self._test_lower_graph_impl(model, data)


if __name__ == "__main__":
    common_utils.run_tests()
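Conversely, `test_aten_fallback_must_fallback` above (its body is collapsed in this view) checks that an op with no ONNX symbolic still falls back on BUILD_CAFFE2=0 builds. A rough standalone sketch of that case (same assumptions as the sketch above; `linalg.qr` has no ONNX symbolic at opset 9):

import io

import onnx
import torch


class ModelWithAtenNotONNXOp(torch.nn.Module):
    def forward(self, x, y):
        return torch.linalg.qr(x + y)


f = io.BytesIO()
torch.onnx.export(
    ModelWithAtenNotONNXOp(),
    (torch.rand(3, 4), torch.rand(3, 4)),
    f,
    do_constant_folding=False,
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
    opset_version=9,  # linalg.qr gained ONNX support only in later opsets
)
# The qr call should survive as an ATen fallback node in the graph.
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
assert any(node.op_type == "ATen" for node in onnx_model.graph.node)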
19 changes: 15 additions & 4 deletions torch/onnx/utils.py
@@ -1752,10 +1752,21 @@ def _should_aten_fallback(
    )
    is_caffe2_build = _C_onnx._CAFFE2_ATEN_FALLBACK

-    return name.startswith("aten::") and (
-        ((is_onnx_aten_export or is_aten_fallback_export) and not is_caffe2_build)
-        or (not is_exportable_aten_op and is_aten_fallback_export)
-    )
+    if not name.startswith("aten::"):
+        return False
+
+    if is_caffe2_build:
+        if (
+            is_onnx_aten_export or is_aten_fallback_export
+        ) and not is_exportable_aten_op:
+            return True
+    else:
+        if is_onnx_aten_export or (
+            is_aten_fallback_export and not is_exportable_aten_op
+        ):
+            return True
+
+    return False


@_beartype.beartype
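For reference, the new decision logic restated as a standalone predicate (a sketch only; the boolean parameters stand in for the flags computed inside the real `_should_aten_fallback`):

def should_aten_fallback(
    name: str,
    is_exportable_aten_op: bool,
    is_onnx_aten_export: bool,      # OperatorExportTypes.ONNX_ATEN
    is_aten_fallback_export: bool,  # OperatorExportTypes.ONNX_ATEN_FALLBACK
    is_caffe2_build: bool,          # _C_onnx._CAFFE2_ATEN_FALLBACK
) -> bool:
    if not name.startswith("aten::"):
        return False
    if is_caffe2_build:
        # Caffe2 builds: both ATen modes fall back only when the op
        # has no ONNX symbolic.
        return (
            is_onnx_aten_export or is_aten_fallback_export
        ) and not is_exportable_aten_op
    # BUILD_CAFFE2=0: ONNX_ATEN always falls back; ONNX_ATEN_FALLBACK
    # falls back only when the op has no ONNX symbolic.
    return is_onnx_aten_export or (
        is_aten_fallback_export and not is_exportable_aten_op
    )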