[TODO]: Address issue with torch and torch_scatter installation order - replace torch_scatter functions with native torch? #580
Also note from another comment that we might need special installation instructions for …
@kevingreenman we might not need torch-scatter anymore. According to this comment from the maintainer (rusty1s/pytorch_scatter#379 (comment)), all of the functionality is now in PyTorch. From a cursory glance through the source code, we only use scatter sum, mean, and softmax. The first two are directly implemented in PyTorch now (https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_reduce_.html#torch.Tensor.scatter_reduce_), and for the last I am sure we could find a workaround (i.e., calling exp on the arg to scatter-reduce w/ sum). Just an idea; this might end up being more work.
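For reference, here is a minimal sketch of how the sum and mean reductions could map onto the native `torch.Tensor.scatter_reduce` API. The data, shapes, and variable names are illustrative, not Chemprop's actual code:

```python
import torch

# Illustrative data: three 2-d rows scattered into two groups.
src = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
index = torch.tensor([0, 1, 0])  # group id for each row of src
dim_size = 2

# Unlike torch_scatter, scatter_reduce expects an index tensor with the
# same shape as src, so broadcast the 1-d group ids across the columns.
idx = index[:, None].expand(-1, src.shape[1])

# include_self=False keeps the zero-initialized output tensor from
# contributing to the reduction (this matters for "mean").
summed = torch.zeros(dim_size, src.shape[1]).scatter_reduce(
    0, idx, src, reduce="sum", include_self=False
)
meaned = torch.zeros(dim_size, src.shape[1]).scatter_reduce(
    0, idx, src, reduce="mean", include_self=False
)
print(summed)  # tensor([[6., 8.], [3., 4.]])
print(meaned)  # tensor([[3., 4.], [3., 4.]])
```

Note that `scatter_reduce` also has an in-place variant (`scatter_reduce_`), and its `reduce="mean"` only matches `torch_scatter.scatter_mean` when `include_self=False`.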
We can actually just take the implementation of scatter softmax from torch_scatter (but using PyTorch's built-in scatter) and put it in chemprop: https://github.com/rusty1s/pytorch_scatter/blob/c095c62e4334fcd05e4ac3c4bb09d285960d6be6/torch_scatter/composite/softmax.py#L9
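Along those lines, a hedged sketch of what a torch-native scatter softmax might look like, following the structure of the linked `torch_scatter` composite implementation (subtract the per-group max for stability, exponentiate, then normalize by the per-group sum). The function name and signature are assumptions for illustration, not Chemprop's API:

```python
import torch

def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim_size: int) -> torch.Tensor:
    """Softmax over groups of rows of src, using only native torch.

    src: (N, d) values; index: (N,) group ids in [0, dim_size).
    """
    idx = index[:, None].expand(-1, src.shape[1])
    # Per-group max, for numerical stability (same trick as torch_scatter).
    max_per_group = torch.full((dim_size, src.shape[1]), float("-inf")).scatter_reduce(
        0, idx, src, reduce="amax", include_self=True
    )
    exp = (src - max_per_group.gather(0, idx)).exp()
    # Per-group normalizer, then broadcast back to each row via gather.
    denom = torch.zeros(dim_size, src.shape[1]).scatter_add(0, idx, exp)
    return exp / denom.gather(0, idx)
```

Each group's output should match an ordinary `torch.softmax` computed over just that group's rows.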
Thanks for pointing that out! Curious what @davidegraff thinks about this, since he's the one who added …
I don't think it's a simple drop-in replacement, as the APIs differ:

```python
import torch
from torch_scatter import scatter_sum

X = torch.arange(4) * torch.ones(4, 6)
dim_size = 2
print(X)
# tensor([[0., 0., 0., 0., 0., 0.],
#         [1., 1., 1., 1., 1., 1.],
#         [2., 2., 2., 2., 2., 2.],
#         [3., 3., 3., 3., 3., 3.]])
index_torch_scatter = torch.tensor([0, 1, 0, 1])
index_torch = index_torch_scatter[:, None].repeat(1, 6)
print(scatter_sum(X, index_torch_scatter, dim=0, dim_size=dim_size))
# tensor([[2., 2., 2., 2., 2., 2.],
#         [4., 4., 4., 4., 4., 4.]])
print(torch.zeros(dim_size, X.shape[1]).scatter_add(0, index_torch, X))
# tensor([[2., 2., 2., 2., 2., 2.],
#         [4., 4., 4., 4., 4., 4.]])
```

Testing with random inputs also seems to work:

```python
dim_size = 64
X = torch.randn(256, 256)
index_torch_scatter = torch.randint(dim_size, size=(X.shape[0],))
index_torch = index_torch_scatter[:, None].repeat(1, X.shape[1])
Z_torch_scatter = scatter_sum(X, index_torch_scatter, dim=0, dim_size=dim_size)
Z_torch = torch.zeros(dim_size, X.shape[1]).scatter_add(0, index_torch, X)
print(torch.isclose(Z_torch_scatter, Z_torch).all())
# tensor(True)
```

And FWIW, native torch seems to be slightly faster too:

```python
from timeit import timeit

NUMBER = 10000
print(timeit('scatter_sum(X, index_torch_scatter, dim=0, dim_size=dim_size)', globals=globals(), number=NUMBER))
# 0.4636257500387728
print(timeit('torch.zeros(dim_size, X.shape[1]).scatter_add(0, index_torch, X)', globals=globals(), number=NUMBER))
# 0.4535887080710381
```

If anyone wants to tackle this and fix our environment build problems, the second code snippet should provide a path forward, i.e., for every such instance:

```python
index = index[:, None].repeat(1, src.shape[1])
torch.zeros(dim_size, src.shape[1]).scatter_add(0, index, src)
```

Note: in many places, … If anyone does decide to tackle this, I would first write a unit test to ensure the torch-native reimplementation is mathematically equivalent to the current one.
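As a starting point for such a unit test, here is one possible sketch. To avoid making the test itself depend on `torch_scatter`, the torch-native version is checked against an obviously-correct Python-loop reference; the function names and shapes are illustrative:

```python
import torch

def native_scatter_sum(src: torch.Tensor, index: torch.Tensor, dim_size: int) -> torch.Tensor:
    """Proposed torch-native replacement for torch_scatter.scatter_sum along dim 0."""
    idx = index[:, None].repeat(1, src.shape[1])
    return torch.zeros(dim_size, src.shape[1]).scatter_add(0, idx, src)

def loop_scatter_sum(src: torch.Tensor, index: torch.Tensor, dim_size: int) -> torch.Tensor:
    """Naive loop reference, used only for testing."""
    out = torch.zeros(dim_size, src.shape[1])
    for row, i in zip(src, index):
        out[i] += row
    return out

# Randomized equivalence check, as suggested above.
torch.manual_seed(0)
src = torch.randn(128, 16)
index = torch.randint(8, size=(128,))
assert torch.allclose(native_scatter_sum(src, index, 8), loop_scatter_sum(src, index, 8))
```

In practice this would live in the test suite (e.g., parametrized over shapes and dtypes) rather than as a one-off script.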
I've tested our installation instructions on my Mac and on 5 different Linux machines. They worked without issue on my Mac (2021 M1 Pro), but we found in #695 that even updating the CI to install torch and torch-scatter in the correct order does not resolve the automated build-step error on Mac. On Linux, I was able to install with no issues on 3/5 machines. The other two machines encountered issues at the torch-scatter step, but with different errors than the typical one that comes from not having torch installed. On one machine (called slater, for my reference), I get:
This machine is running the following:
On the other machine (called kohn, for my reference), I get:
This machine currently has an issue with 1 of its 3 GPUs, but I'm not sure why that would lead to this error. Based on these results, I think we should definitely try to move away from the torch-scatter dependency by replacing its functions with native torch alternatives between now and the April MLPDS meeting.
I am going to take a whack at this. I have been trying to get the CI to work (#714) but cannot get …
For macOS, I have to install …
The issue installing `torch` and `torch-scatter` in one go is not something we're equipped to address, as it's a larger problem faced by many others. We should probably add a note in the README about installing correctly.

Originally posted by @davidegraff in #567 (comment)