Improper model conversion from PyTorch to ONNX with torch.onnx.OperatorExportTypes.ONNX_ATEN flag #87313

Closed
oviazlo opened this issue Oct 19, 2022 · 5 comments
Labels: module: onnx (Related to torch.onnx), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


oviazlo commented Oct 19, 2022

🐛 Describe the bug

I have spotted improper model conversion to ONNX when using the torch.onnx.OperatorExportTypes.ONNX_ATEN flag with PyTorch >= 1.12.0.

My understanding from the documentation

OperatorExportTypes.ONNX_ATEN: All ATen ops (in the TorchScript namespace “aten”) are exported as ATen ops (in opset domain “org.pytorch.aten”). ATen is PyTorch’s built-in tensor library, so this instructs the runtime to use PyTorch’s implementation of these ops.

is that when one uses the torch.onnx.OperatorExportTypes.ONNX_ATEN flag, all ATen operators have to be exported to the ONNX graph as ATen ops. When I run the code snippet below:

import torch


class ModelWithAtenFmod(torch.nn.Module):
    def forward(self, x, y):
        return torch.fmod(x, y)


x = torch.randn(3, 4, dtype=torch.float32)
y = torch.randn(3, 4, dtype=torch.float32)
OUT = torch.onnx.export_to_pretty_string(
    ModelWithAtenFmod(), (x, y),
    add_node_names=False,
    do_constant_folding=False,
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN)
print(OUT)

with PyTorch 1.11.0, I obtain an ONNX graph with the ATen operator, which is correct:

### PyTorch 1.11.0

ModelProto {
  producer_name: "pytorch"
  domain: ""
  doc_string: ""
  graph:
    GraphProto {
      name: "torch-jit-export"
      inputs: [{name: "aten::ATen_0", type:Tensor dtype: 1, Tensor dims: 3 4},{name: "aten::ATen_1", type:Tensor dtype: 1, Tensor dims: 3 4}]
      outputs: [{name: "2", type:Tensor dtype: 1, Tensor dims: ? ?}]
      value_infos: []
      initializers: []
      nodes: [
        Node {type: "ATen", inputs: [aten::ATen_0,aten::ATen_1], outputs: [2], attributes: [{ name: 'operator', type: string, value: 'fmod'}]}
      ]
    }
  opset_import: [OperatorSetIdProto { domain: , version: 9}OperatorSetIdProto { domain: org.pytorch.aten, version: 1}],
}

However, when I run the same code with PyTorch 1.12.1, the ATen operator gets substituted with a native ONNX operator, which is not the expected behaviour:

### PyTorch 1.12.1

ModelProto {
  producer_name: "pytorch"
  domain: ""
  doc_string: ""
  graph:
    GraphProto {
      name: "torch_jit"
      inputs: [{name: "onnx::Mod_0", type:Tensor dtype: 1, Tensor dims: 3 4},{name: "onnx::Mod_1", type:Tensor dtype: 1, Tensor dims: 3 4}]
      outputs: [{name: "2", type:Tensor dtype: 1, Tensor dims: 3 4}]
      value_infos: []
      initializers: []
      nodes: [
        Node {type: "Mod", inputs: [onnx::Mod_0,onnx::Mod_1], outputs: [2], attributes: [{ name: 'fmod', type: int, value: 1}]}
      ]
    }
  opset_import: [OperatorSetIdProto { domain: , version: 13}],
}
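
For completeness, one way to check this programmatically instead of eyeballing the pretty-printed graph is to export to an in-memory buffer and inspect the node types with the onnx package. This is a minimal sketch, not part of the original report; it assumes the onnx package is installed:

import io

import onnx
import torch


class ModelWithAtenFmod(torch.nn.Module):
    def forward(self, x, y):
        return torch.fmod(x, y)


x = torch.randn(3, 4, dtype=torch.float32)
y = torch.randn(3, 4, dtype=torch.float32)

buffer = io.BytesIO()
torch.onnx.export(
    ModelWithAtenFmod(), (x, y), buffer,
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN)
model = onnx.load_from_string(buffer.getvalue())

# Expected with ONNX_ATEN: {'ATen'}; observed with PyTorch 1.12.1: {'Mod'}
print({node.op_type for node in model.graph.node})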

Versions

PyTorch version: 1.12.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 12.6 (arm64)
GCC version: Could not collect
Clang version: 14.0.6
CMake version: version 3.24.1
Libc version: N/A

Python version: 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:04:14) [Clang 12.0.1 ] (64-bit runtime)
Python platform: macOS-12.6-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.4
[pip3] torch==1.12.1
[pip3] torchvision==0.13.1
[conda] numpy 1.23.4 py38h09ac2d9_0 conda-forge
[conda] pytorch 1.12.1 py3.8_0 pytorch
[conda] torchvision 0.13.1 py38_cpu pytorch


oviazlo commented Oct 19, 2022

I took the code snippet above from an existing unit test:

# torch.fmod is used to test ONNX_ATEN.
# If you plan to remove fmod from aten, or find this test failing,
# please contact @Rui.
def test_onnx_aten(self):
    class ModelWithAtenFmod(nn.Module):
        def forward(self, x, y):
            return torch.fmod(x, y)

    x = torch.randn(3, 4, dtype=torch.float32)
    y = torch.randn(3, 4, dtype=torch.float32)
    torch.onnx.export_to_pretty_string(
        ModelWithAtenFmod(), (x, y),
        add_node_names=False,
        do_constant_folding=False,
        operator_export_type=OperatorExportTypes.ONNX_ATEN)
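
Note that this test only exercises the export call and never asserts that the ATen node survives in the output, which may be why the substitution went unnoticed. A stricter variant could assert on the pretty-printed string; the following is a hypothetical sketch, not the upstream test:

import torch
from torch import nn


class ModelWithAtenFmod(nn.Module):
    def forward(self, x, y):
        return torch.fmod(x, y)


def test_onnx_aten_keeps_aten_node():
    x = torch.randn(3, 4, dtype=torch.float32)
    y = torch.randn(3, 4, dtype=torch.float32)
    out = torch.onnx.export_to_pretty_string(
        ModelWithAtenFmod(), (x, y),
        add_node_names=False,
        do_constant_folding=False,
        operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN)
    # With ONNX_ATEN the exported graph should keep the ATen op
    # rather than the native onnx::Mod substitute.
    assert "ATen" in out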


oviazlo commented Oct 19, 2022

The operator substitution happens inside the _optimize_graph function, called in the code below:

pytorch/torch/onnx/utils.py, lines 726 to 743 at 664058f:

model = _pre_trace_quant_model(model, args)
graph, params, torch_out, module = _create_jit_graph(model, args)
params_dict = _get_named_param_dict(graph, params)
try:
    graph = _optimize_graph(
        graph,
        operator_export_type,
        _disable_torch_constant_prop=_disable_torch_constant_prop,
        fixed_batch_size=fixed_batch_size,
        params_dict=params_dict,
        dynamic_axes=dynamic_axes,
        input_names=input_names,
        module=module,
    )
except Exception as e:
    torch.onnx.log("Torch IR graph at exception: ", graph)
    raise

The graph before the _optimize_graph call looks like:

graph(%0 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %6 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::fmod(%0, %1)
  return (%6)

And after the call:

graph(%0 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %1 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %2 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = onnx::Mod[fmod=1](%0, %1)
  return (%2)

despite operator_export_type being equal to OperatorExportTypes.ONNX_ATEN.
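
The "before" graph can be reproduced without touching exporter internals by tracing the model directly and printing its TorchScript IR; a minimal sketch, assuming the same model definition as above:

import torch


class ModelWithAtenFmod(torch.nn.Module):
    def forward(self, x, y):
        return torch.fmod(x, y)


x = torch.randn(3, 4, dtype=torch.float32)
y = torch.randn(3, 4, dtype=torch.float32)

traced = torch.jit.trace(ModelWithAtenFmod(), (x, y))
# The TorchScript IR still contains aten::fmod at this point,
# before any ONNX lowering pass runs.
print(traced.graph)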

thiagocrepaldi (Collaborator) commented

This is happening inside graph = _C._jit_pass_onnx(graph, operator_export_type)

thiagocrepaldi self-assigned this on Oct 20, 2022
albanD added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Oct 21, 2022
thiagocrepaldi (Collaborator) commented

ca37477 - [ONNX] update default opset_version to 13 (#73898) introduced this regression. Looking into it further.
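
If the regression really does stem from the default-opset bump, explicitly pinning the opset to the pre-1.12 default might sidestep it until a fix lands. This is an untested sketch, not a confirmed workaround:

import torch


class ModelWithAtenFmod(torch.nn.Module):
    def forward(self, x, y):
        return torch.fmod(x, y)


x = torch.randn(3, 4, dtype=torch.float32)
y = torch.randn(3, 4, dtype=torch.float32)

out = torch.onnx.export_to_pretty_string(
    ModelWithAtenFmod(), (x, y),
    opset_version=9,  # pre-regression default; may or may not avoid the Mod substitution
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN)
print(out)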

thiagocrepaldi (Collaborator) commented

Fix: #87735

kulinseth pushed a commit to kulinseth/pytorch that referenced this issue Nov 5, 2022
Fixes pytorch#87313

Our ONNX pipelines do not run with BUILD_CAFFE2=0, so the operator_export_type values ONNX_ATEN and ONNX_ATEN_FALLBACK are not fully tested, allowing regressions to happen again.

We need to run the same set of tests for both BUILD_CAFFE2=0 and 1.
Pull Request resolved: pytorch#87735
Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao
pytorchmergebot pushed a commit that referenced this issue Nov 11, 2022
Follow-up for #87735

Once again, because BUILD_CAFFE2=0 is not tested for the ONNX exporter, one scenario slipped through: a use case where the model can be exported without ATen fallback when operator_export_type=ONNX_ATEN_FALLBACK and BUILD_CAFFE2=0.

A new unit test has been added, but it won't prevent regressions if BUILD_CAFFE2=0 is not executed on CI again.

Fixes #87313

Pull Request resolved: #88504
Approved by: https://github.com/justinchuby, https://github.com/BowenBao