Custom Operator Design for torch.compile: Must Output Tensors Always Be Returned? #124918
Comments
You must annotate the schema to specify which arguments mutate. Once you do so, the right behavior will occur. cc @zou3519
Use "mylib::custom_func", "(Tensor x, Tensor y, Tensor(a!) out) -> None" to mark the out Tensor as mutated. Also, if you're using PyTorch nightlies, we have a new custom ops API that makes this more explicit: https://pytorch.org/docs/main/library.html#creating-new-custom-ops-in-python
I'm very grateful for your answer. @zou3519 @ezyang If using the old API, in addition to using 'mylib::custom_func', '(Tensor x, Tensor y, Tensor(a!) out) -> None' to mark the out Tensor as mutating, it is also necessary to add the boilerplate functionalization logic described in the documentation (https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa). Otherwise, the following error occurs when executing:
That doc (https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa) seems to be targeted at C++ custom operators. For Python custom operators on an older version of PyTorch, is it true that the only way to avoid such issues is to avoid mutable arguments when defining the interface?
Try
I filed an issue for supporting schemas with None returns (#125044)
The error message is outdated (and misleading); we've already fixed it on main.
Changing the return value from None to () resolved the issue. Thanks again @zou3519
🐛 Describe the bug
Hello PyTorch Community,
I am currently developing a custom operator and considering its compatibility with torch.compile. I have a specific question regarding the function signature of the operator: Does the output tensor need to be explicitly returned by the function, or can it be passed as an input argument?
If I pass the output as an input parameter when creating a custom operator, so that the operator interface has no return value, I get incorrect results when running with torch.compile, but correct results in eager mode. Based on my analysis, torch.compile seems to optimize away custom operators that return no value, leading to the incorrect outcomes.
Error logs
Traceback (most recent call last):
File "/tmp/test.py", line 40, in <module>
assert torch.allclose(out, x * y)
AssertionError
Minified repro
Versions
PyTorch version: 2.3.0a0+40ec155e58.nv24.03
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.28.3
Libc version: glibc-2.35
Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-150-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100 80GB PCIe
GPU 1: NVIDIA A100 80GB PCIe
GPU 2: NVIDIA A100 80GB PCIe
GPU 3: NVIDIA A100 80GB PCIe
GPU 4: NVIDIA A100 80GB PCIe
GPU 5: NVIDIA A100 80GB PCIe
GPU 6: NVIDIA A100 80GB PCIe
GPU 7: NVIDIA A100 80GB PCIe
Nvidia driver version: 525.89.02
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.0.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] onnx==1.15.0rc2
[pip3] optree==0.10.0
[pip3] pytorch-quantization==2.1.2
[pip3] pytorch-triton==2.2.0+e28a256d7
[pip3] torch==2.3.0a0+40ec155e58.nv24.3
[pip3] torch-tensorrt==2.3.0a0
[pip3] torchdata==0.7.1a0
[pip3] torchtext==0.17.0a0
[pip3] torchvision==0.18.0a0
[conda] Could not collect
cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @zou3519