
Custom Operator Design for torch.compile: Must Output Tensors Always Be Returned? #124918

Open
hk3911 opened this issue Apr 25, 2024 · 5 comments
Labels
actionable · module: custom-operators · module: pt2-dispatcher · oncall: pt2 · triaged

Comments

hk3911 commented Apr 25, 2024

🐛 Describe the bug

Hello PyTorch Community,

I am currently developing a custom operator and considering its compatibility with torch.compile. I have a specific question regarding the function signature of the operator: Does the output tensor need to be explicitly returned by the function, or can it be passed as an input argument?

If I pass the output as an input parameter when creating the custom operator, so that the operator interface has no return value, I get incorrect results under torch.compile but correct results in eager mode. From my analysis, torch.compile appears to optimize away custom operators that do not return any value, which leads to the incorrect outcome.

Error logs

Traceback (most recent call last):
  File "/tmp/test.py", line 40, in <module>
    assert torch.allclose(out, x * y)
AssertionError

Minified repro

import torch
import numpy as np

def custom_func(x, y, out):
    torch._check(x.shape == y.shape)
    torch._check(x.device == y.device)
    x_np = x.numpy()
    y_np = y.numpy()
    z_np = np.multiply(x_np, y_np)
    out.copy_(torch.from_numpy(z_np))
    return

torch.library.define("mylib::custom_func", "(Tensor x, Tensor y, Tensor out) -> None")
# Add the implementation of the custom op.
torch.library.impl("mylib::custom_func", "default", custom_func)

# Add an abstract impl that describes the properties of the output tensor, given the properties of the input tensors.
@torch.library.impl_abstract("mylib::custom_func")
def custom_func_abstract(x, y, out):
    torch._check(x.shape == y.shape)
    torch._check(x.device == y.device)
    torch._check(out.shape == y.shape)
    torch._check(out.device == y.device)
    return
                                                                                                      
@torch.compile(backend="inductor", fullgraph=True)
def f(x, y, out):
    return torch.ops.mylib.custom_func.default(x, y, out)

x = torch.randn(3)
y = torch.randn(3)
out = torch.empty_like(x)
z = f(x, y, out)
assert torch.allclose(out, x * y)

Versions

PyTorch version: 2.3.0a0+40ec155e58.nv24.03
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.28.3
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-150-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100 80GB PCIe
GPU 1: NVIDIA A100 80GB PCIe
GPU 2: NVIDIA A100 80GB PCIe
GPU 3: NVIDIA A100 80GB PCIe
GPU 4: NVIDIA A100 80GB PCIe
GPU 5: NVIDIA A100 80GB PCIe
GPU 6: NVIDIA A100 80GB PCIe
GPU 7: NVIDIA A100 80GB PCIe

Nvidia driver version: 525.89.02
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.0.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] onnx==1.15.0rc2
[pip3] optree==0.10.0
[pip3] pytorch-quantization==2.1.2
[pip3] pytorch-triton==2.2.0+e28a256d7
[pip3] torch==2.3.0a0+40ec155e58.nv24.3
[pip3] torch-tensorrt==2.3.0a0
[pip3] torchdata==0.7.1a0
[pip3] torchtext==0.17.0a0
[pip3] torchvision==0.18.0a0
[conda] Could not collect

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @zou3519

ezyang (Contributor) commented Apr 25, 2024

You must annotate the schema to specify which arguments mutate. Once you do so, the right behavior will occur.

cc @zou3519

zou3519 (Contributor) commented Apr 25, 2024

Use "mylib::custom_func", "(Tensor x, Tensor y, Tensor(a!) out) -> None" to mark the out Tensor as mutating. Also if you're using PyTorch nightlies, we have a new custom ops API that makes this more explicit: https://pytorch.org/docs/main/library.html#creating-new-custom-ops-in-python

jbschlosser added the triaged and module: custom-operators labels Apr 25, 2024
hk3911 (Author) commented Apr 26, 2024

I'm very grateful for your answer, @zou3519 @ezyang.
Following your suggestion, using the new custom ops API (https://pytorch.org/docs/main/library.html#creating-new-custom-ops-in-python) with a PyTorch nightly wheel resolves the issue.

If using the old API, in addition to using "mylib::custom_func", "(Tensor x, Tensor y, Tensor(a!) out) -> None" to mark the out Tensor as mutated, it is also necessary to add the boilerplate functionalization logic described in the documentation (https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa). Otherwise, the following error occurs at execution:

RuntimeError: Found a custom (non-ATen) operator that either mutates or its inputs: mylib::custom_func.. Getting these operators to work with functionalization requires some extra work. For mutable ops you need to register a corresponding out-of-place variant of the op, and you also need to register a Functionalization kernel that performs some boilerplate, telling functionalization to map from the mutable op to the out-of-place op. See a more complete example of how to do this at https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa.

The doc (https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa) seems to be targeted at C++ custom operators. For Python custom operators on an older version of PyTorch, is it true that the only way to avoid such issues is to not use mutable arguments when defining the interface, i.e., an out-of-place variant along the lines of the sketch below?
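
For concreteness, by not using mutable arguments I mean something roughly like this (sketch):

torch.library.define("mylib::custom_func", "(Tensor x, Tensor y) -> Tensor")

def custom_func(x, y):
    # Allocate and return the result instead of writing into a
    # caller-provided output buffer.
    return torch.from_numpy(np.multiply(x.numpy(), y.numpy()))

torch.library.impl("mylib::custom_func", "default", custom_func)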

zou3519 (Contributor) commented Apr 26, 2024

Try a () return instead of the None return. There's no need to add the boilerplate functionalization logic (as of PyTorch 2.2 or 2.3):

import torch

lib = torch.library.Library("mylib", "FRAGMENT")
lib.define("foo(Tensor x, Tensor(a!) out) -> ()")

def foo_impl(x, out):
    out.copy_(x)

lib.impl("foo", foo_impl, "CompositeExplicitAutograd")

x = torch.randn(3)
out = torch.zeros(3)

@torch.compile
def f(x, out):
    torch.ops.mylib.foo(x, out)

f(x, out)

I filed an issue for supporting schemas with None returns (#125044)

> If using the old API, in addition to using "mylib::custom_func", "(Tensor x, Tensor y, Tensor(a!) out) -> None" to mark the out Tensor as mutated, it is also necessary to add the boilerplate functionalization logic described in the documentation (https://gist.github.com/bdhirsh/7dadbf6296f8f7d1abcf4c482f438aaa). Otherwise, the following error occurs at execution:

The error message is outdated (and misleading); we've already fixed it on main.

zou3519 added the module: pt2-dispatcher label Apr 26, 2024
hk3911 (Author) commented Apr 28, 2024

Changing the return type from None to () resolved the issue. Thanks again, @zou3519.
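
For anyone who finds this later, the fix applied to my original repro under the old API is just the schema change:

torch.library.define("mylib::custom_func", "(Tensor x, Tensor y, Tensor(a!) out) -> ()")

With the (a!) mutation annotation and the () return, torch.compile no longer optimizes the call away and the assertion passes.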
