Improper model conversion from PyTorch to ONNX with torch.onnx.OperatorExportTypes.ONNX_ATEN flag #87313
Comments
I got the example code snippet from an existing unit test: pytorch/test/jit/test_export_modes.py, lines 107 to 121 at commit 664058f.
The operator substitution happens inside lines 726 to 743 at commit 664058f: after that call the ATen operator has already been replaced, despite the export type being set to ONNX_ATEN.
Fix: #87735
Fixes pytorch#87313

Our ONNX pipelines do not run with BUILD_CAFFE2=0, so tests for operator_export_type ONNX_ATEN and ONNX_ATEN_FALLBACK are not fully exercised, allowing regressions to happen again. We need to run the same set of tests for both BUILD_CAFFE2=0 and BUILD_CAFFE2=1.

Pull Request resolved: pytorch#87735
Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao
Follow-up for #87735. Once again, because BUILD_CAFFE2=0 is not tested for the ONNX exporter, one scenario slipped through: a use case where the model can be exported without ATen fallback when operator_export_type=ONNX_ATEN_FALLBACK and BUILD_CAFFE2=0. A new unit test has been added, but it won't prevent regressions if BUILD_CAFFE2=0 is not executed on CI again.

Fixes #87313
Pull Request resolved: #88504
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
🐛 Describe the bug
I have spotted improper model conversion to ONNX when using the torch.onnx.OperatorExportTypes.ONNX_ATEN flag with PyTorch >= 1.12.0. My understanding from the documentation is that when one uses the torch.onnx.OperatorExportTypes.ONNX_ATEN flag, all ATen operators have to be exported to the ONNX graph as ATen nodes. I run the code snippet below:
With PyTorch 1.11.0 I obtain an ONNX graph with the ATen operator, which is correct. However, when I run the same code with PyTorch 1.12.1, the ATen operator gets substituted with a native ONNX operator, which is not the expected behaviour.

Versions
PyTorch version: 1.12.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 12.6 (arm64)
GCC version: Could not collect
Clang version: 14.0.6
CMake version: version 3.24.1
Libc version: N/A
Python version: 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:04:14) [Clang 12.0.1 ] (64-bit runtime)
Python platform: macOS-12.6-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.23.4
[pip3] torch==1.12.1
[pip3] torchvision==0.13.1
[conda] numpy 1.23.4 py38h09ac2d9_0 conda-forge
[conda] pytorch 1.12.1 py3.8_0 pytorch
[conda] torchvision 0.13.1 py38_cpu pytorch
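For reference, a report like the Versions section above can be regenerated programmatically; a minimal sketch using PyTorch's bundled collect_env module:

```python
from torch.utils import collect_env

# Prints the same environment summary that
# `python -m torch.utils.collect_env` produces on the command line.
print(collect_env.get_pretty_env_info())
```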