[v2]: CI Overhaul (and make v2 actually _pass_ the CI) #714
Conversation
(force-pushed from ab36dde to c83bfa6)
(force-pushed from 535fca5 to 6a1849d)
Have to table this for the time being, but here are some notes from the future:
(force-pushed from e5bce06 to b32d523)
I have rolled back the commits attempting to fix the install problem.
@davidegraff can you look at caf728d? It's a (bad) initial attempt at removing `torch_scatter`.
(force-pushed from 9087fee to caf728d)
The good news is that all the tests are failing (meaning that they are all finally running!) 🎉 Apologies for the double ping, @davidegraff -- I know the … The pytorch scatter API took in what we internally labeled as a 'batch' as its 'index', so I have retained that difference here.
(force-pushed from 7a9fd1b to 6856fe8)
TODO: remove the torch_scatter installation workaround from the Dockerfile
I don't feel that copying the code from `torch_scatter` is a good idea.
This reply suggests that installing from source works for MacOS, but I have no problem when I install:
Agreed that we don't want to keep it exactly the same as the reference implementation. I've copied it here as (1) a starting point and (2) to establish the structure I think we should shoot for when re-implementing this. I think we should have the scatter calls wrapped in a single util like I have done here - the actual implementation could be totally different, but this setup saves us having to update this in a lot of places if pytorch changes the API (since this feature is only in beta).
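To make that concrete, here is a rough sketch of what such a single util could look like (the name `scatter_sum` and its signature are hypothetical, not the actual chemprop code):

```python
# Hypothetical sketch of a single scatter util: callers use scatter_sum(...) everywhere,
# and only this function knows about torch's (beta) scatter_reduce_ API.
import torch
from torch import Tensor


def scatter_sum(src: Tensor, index: Tensor, dim_size: int) -> Tensor:
    """Sum rows of `src` into `dim_size` groups, where `index` gives one group id per row."""
    index_expanded = index.unsqueeze(1).repeat(1, src.shape[1])
    out = torch.zeros(dim_size, src.shape[1], dtype=src.dtype, device=src.device)
    return out.scatter_reduce_(0, index_expanded, src, reduce="sum", include_self=False)
```

If pytorch changes the `scatter_reduce_` API, only this one function would need updating.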
I tried the source install on MacOS and it didn't work on GitHub Actions (failed during build). The former might work, but the larger issue in my head is that this installation procedure (…)
That still doesn't solve the problem of slow Python-level code, though. I think it's much preferable that we move to a torch-native implementation based on a beta API than a slow one using yanked code, for a few reasons:
Also FWIW, I don't think this install process is too cumbersome. It's the exact same ask as … And I'm not sure where the MacOS problems are coming from--I have no problems installing on my Mac:

```shell
$ conda create -n test_env python=3.11
$ conda activate test_env
$ pip install torch
$ pip install torch_scatter
$ python -c "import torch, torch_scatter; torch.manual_seed(42); print(torch_scatter.scatter_sum(torch.rand(10), torch.randint(4, size=(10,)), dim_size=4))"
tensor([0.6828, 1.9506, 0.5650, 1.6426])
$ pip uninstall torch torch_scatter -y
$ pip cache purge
$ pip install 'torch==2.2.0'
$ pip install torch_scatter
$ python -c "import torch, torch_scatter; torch.manual_seed(42); print(torch_scatter.scatter_sum(torch.rand(10), torch.randint(4, size=(10,)), dim_size=4))"
$ pip uninstall torch torch_scatter -y
$ pip cache purge
$ pip install torch-scatter -f https://data.pyg.org/whl/torch-2.2.0+cpu.html
$ python -c "import torch, torch_scatter; torch.manual_seed(42); print(torch_scatter.scatter_sum(torch.rand(10), torch.randint(4, size=(10,)), dim_size=4))"
```
We can change the actual implementation to be pure pytorch -- I don't possess the skills to do so, but that's a great idea. The only thing I am seeking to demonstrate here is the idea of keeping our existing calls to scatter in the codebase and keeping the 'nitty gritty' in one place. That 'nitty gritty' detail can look like whatever we want.
It's not that bad, but it could be so much better 😉
It's not an issue on all Macs, but it doesn't even work all the time on Linux either - see @kevingreenman's comment here.
I'm fine with including black magic code and hiding it somewhere in the code-base if it's fast and something that we wrote. I just think it's really bad practice to include it when it's not actually our code. Then we don't understand it, but it's also technically our problem if it's bugged somewhere (as opposed to abstractly relying on its correctness via a 3rd-party package). As for a torch-native implementation, I included an experiment here that should provide some guidelines for switching the code over.
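As a rough illustration of the kind of comparison such an experiment might include (a hypothetical sketch only; the array sizes and the 'sum' case are arbitrary, and `torch_scatter` is treated as optional):

```python
# Hypothetical micro-benchmark sketch: native scatter_reduce_ vs. torch_scatter.scatter
import time
import torch

X = torch.rand(100_000, 300)
index = torch.randint(64, size=(100_000,))
index_torch = index.unsqueeze(1).repeat(1, X.shape[1])

def native_sum() -> torch.Tensor:
    out = torch.zeros(64, X.shape[1])
    return out.scatter_reduce_(0, index_torch, X, reduce="sum", include_self=False)

def time_it(fn, n=20) -> float:
    fn()  # warm-up
    start = time.perf_counter()
    for _ in range(n):
        fn()
    return (time.perf_counter() - start) / n

print(f"native scatter_reduce_: {time_it(native_sum):.4f} s/call")
try:
    import torch_scatter
    print(f"torch_scatter.scatter:  {time_it(lambda: torch_scatter.scatter(X, index, 0, dim_size=64, reduce='sum')):.4f} s/call")
except ImportError:
    print("torch_scatter not installed; skipping comparison")
```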
(force-pushed from 6856fe8 to d9bab46)
I have pulled your suggestion for setting up the index and put it here, and pulled from your suggested zero-tensor initialization and put it here, but as I mentioned in #714 (comment), this new code does not work because I don't know what I'm doing. The only reason I have even started to edit the scatter stuff is that chempropv2 is currently spuriously un-installable on half the world's computers because of our usage of `torch_scatter`.
My bad! I thought you had actually just copied most of the source from `torch_scatter`.
FWIW I don't think it's worth it to keep our … I think these are all the recipes we would need.

Starting from:

```python
>>> import torch_scatter
>>> X = torch.tensor([
        [0.3891, 0.0895, 0.3264, 0.5580],
        [0.4385, 0.3270, 0.6044, 0.6146],
        [0.7692, 0.8121, 0.4114, 0.0574],
        [0.8758, 0.6162, 0.7566, 0.2899],
        [0.7039, 0.5815, 0.0581, 0.2298],
        [0.0574, 0.8424, 0.0066, 0.7251],
        [0.5908, 0.7087, 0.5777, 0.2527],
        [0.8578, 0.8803, 0.8242, 0.3267],
        [0.2310, 0.4105, 0.9994, 0.6507],
        [0.8728, 0.1607, 0.9016, 0.1866]]
    )
>>> index = torch.tensor([2, 3, 0, 1, 2, 0, 3, 3, 1, 0])
>>> index_torch = index.unsqueeze(1).repeat(1, X.shape[1])
>>> dim = 0
>>> dim_size = 4
```

**Sum**

```python
>>> torch_scatter.scatter(X, index, dim, dim_size=dim_size, reduce='sum')
tensor([[1.6994, 1.8152, 1.3196, 0.9690],
        [1.1068, 1.0266, 1.7560, 0.9405],
        [1.0930, 0.6710, 0.3845, 0.7878],
        [1.8871, 1.9160, 2.0063, 1.1941]])
>>> torch.zeros(dim_size, X.shape[1]).scatter_reduce_(dim, index_torch, X, reduce='sum', include_self=False)
tensor([[1.6994, 1.8152, 1.3196, 0.9690],
        [1.1068, 1.0266, 1.7560, 0.9405],
        [1.0930, 0.6710, 0.3845, 0.7878],
        [1.8871, 1.9160, 2.0063, 1.1941]])
```

**Mean**

```python
>>> torch_scatter.scatter(X, index, dim, dim_size=dim_size, reduce='mean')
tensor([[0.5665, 0.6051, 0.4399, 0.3230],
        [0.5534, 0.5133, 0.8780, 0.4703],
        [0.5465, 0.3355, 0.1922, 0.3939],
        [0.6290, 0.6387, 0.6688, 0.3980]])
>>> torch.zeros(dim_size, X.shape[1]).scatter_reduce_(dim, index_torch, X, reduce='mean', include_self=False)
tensor([[0.5665, 0.6051, 0.4399, 0.3230],
        [0.5534, 0.5133, 0.8780, 0.4703],
        [0.5465, 0.3355, 0.1922, 0.3939],
        [0.6290, 0.6387, 0.6688, 0.3980]])
```

**Softmax**

```python
>>> torch_scatter.scatter_softmax(X, index, dim=0, dim_size=dim_size)
tensor([[0.4219, 0.3794, 0.5667, 0.5813],
        [0.2713, 0.2379, 0.3106, 0.4088],
        [0.3846, 0.3918, 0.3031, 0.2446],
        [0.6558, 0.5513, 0.4396, 0.4108],
        [0.5781, 0.6206, 0.4333, 0.4187],
        [0.1888, 0.4039, 0.2022, 0.4770],
        [0.3160, 0.3485, 0.3024, 0.2847],
        [0.4127, 0.4137, 0.3870, 0.3065],
        [0.3442, 0.4487, 0.5604, 0.5892],
        [0.4266, 0.2043, 0.4948, 0.2784]])
```

Note: this recipe uses both …

```python
>>> X_exp = X.exp()
>>> Z = torch.zeros(dim_size, X.shape[1]).scatter_reduce_(dim, index_torch, X_exp, reduce='sum', include_self=False)
>>> X_exp / Z[index]
tensor([[0.4219, 0.3794, 0.5667, 0.5813],
        [0.2713, 0.2379, 0.3106, 0.4088],
        [0.3846, 0.3918, 0.3031, 0.2446],
        [0.6558, 0.5513, 0.4396, 0.4108],
        [0.5781, 0.6206, 0.4333, 0.4187],
        [0.1888, 0.4039, 0.2022, 0.4770],
        [0.3160, 0.3485, 0.3024, 0.2847],
        [0.4127, 0.4137, 0.3870, 0.3065],
        [0.3442, 0.4487, 0.5604, 0.5892],
        [0.4266, 0.2043, 0.4948, 0.2784]])
```

EDIT: updated to use the identity as …
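A quick way to sanity-check recipes like these (a small sketch, not part of the original comment) is to compare the two implementations directly:

```python
# Hypothetical equivalence check between torch_scatter and the native recipe above.
import torch
import torch_scatter

torch.manual_seed(0)
X = torch.rand(10, 4)
index = torch.tensor([2, 3, 0, 1, 2, 0, 3, 3, 1, 0])
index_torch = index.unsqueeze(1).repeat(1, X.shape[1])
dim, dim_size = 0, 4

expected = torch_scatter.scatter(X, index, dim, dim_size=dim_size, reduce="sum")
native = torch.zeros(dim_size, X.shape[1]).scatter_reduce_(
    dim, index_torch, X, reduce="sum", include_self=False
)
assert torch.allclose(expected, native)
```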
I haven't looked at the torch scatter bits yet, so will cherry-pick that and run some tests. Some small comments on the rest of it. I am grateful you put so much into overhauling our CI for Chemprop.
```yaml
# clone the repo, attempt to build
- uses: actions/checkout@v4
- run: python -m pip install build
- run: python -m build .
```
I appreciate all the inline documentation because I'm not familiar with most of this. What is the difference between building the repo and running tests? I would think that you need to build the repo before you can run tests, but I don't see a build step in "Execute Tests".
The 'build' step refers to the literal Python `build` command. If it fails, we couldn't build the Python package at the end or install Chemprop to run the tests. The test running doesn't use `build` but just installs the package. We could just go straight to running the tests, but it's kinda nice to have one expected point of failure for each step in the CI (like, the build only fails in one place, the tests only fail in one place, etc.).
"astartes[molecules]", | ||
] | ||
|
||
[project.optional-dependencies] | ||
dev = ["black", "bumpversion", "flake8", "pytest", "pytest-cov"] | ||
docs = ["nbsphinx", "sphinx", "sphinx-argparse", "sphinx-autobuild", "sphinx-autoapi", "sphinxcontrib-bibtex", "sphinx-book-theme","nbsphinx-link","ipykernel"] | ||
test = ["parameterized > 0.8", "pytest >= 6.2", "pytest-cov"] |
I see it was used in v1 tests. Our testing framework doesn't seem to use it.
```toml
[tool.autopep8]
select = ["E262", "W293"]
```
@davidegraff Do you remember why you included this in #504? Is it okay to remove?
chemprop/nn/agg.py (Outdated)

```python
index_torch = batch.unsqueeze(1).repeat(1, H.shape[1])
dim_size = batch.max().int() + 1
H_exp = H.exp()
Z = torch.zeros(dim_size, H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_(
    self.dim, index_torch, H_exp, reduce="sum", include_self=False
)
return H_exp / Z[batch]
```
This isn't quite correct. Note how it doesn't make use of `self.W` while the current v2/dev version does. You can see a dimension mismatch error by replacing `agg = nn.MeanAggregation()` in `examples/training.ipynb` with `agg = nn.AttentiveAggregation(output_size=mp.output_dim)`.

Here is a version of `forward` that works by my tests:
```python
def forward(self, H: Tensor, batch: Tensor) -> Tensor:
    dim_size = batch.max().int() + 1
    attention_logits = (self.W(H) - self.W(H).max()).exp()
    Z = torch.zeros(dim_size, 1, dtype=H.dtype, device=H.device).scatter_reduce_(
        self.dim, batch.unsqueeze(1), attention_logits, reduce="sum", include_self=False
    )
    alphas = attention_logits / Z[batch]
    index_torch = batch.unsqueeze(1).repeat(1, H.shape[1])
    return torch.zeros(dim_size, H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_(
        self.dim, index_torch, alphas * H, reduce="sum", include_self=False
    )
```
I tested it by replacing `agg = nn.AttentiveAggregation(output_size=mp.output_dim)` in `examples/training.ipynb` with:
```python
import torch
from torch import Tensor


class NewAgg(nn.AttentiveAggregation):
    def forward(self, H: Tensor, batch: Tensor) -> Tensor:
        dim_size = batch.max().int() + 1
        attention_logits = (self.W(H) - self.W(H).max()).exp()
        Z = torch.zeros(dim_size, 1, dtype=H.dtype, device=H.device).scatter_reduce_(
            self.dim, batch.unsqueeze(1), attention_logits, reduce="sum", include_self=False
        )
        alphas = attention_logits / Z[batch]
        index_torch = batch.unsqueeze(1).repeat(1, H.shape[1])
        return torch.zeros(dim_size, H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_(
            self.dim, index_torch, alphas * H, reduce="sum", include_self=False
        )


agg = NewAgg(output_size=mp.output_dim)
```
I'm not sure `attention_logits` is the best name for that variable. I'm not even sure that is a correct name, so open to suggestions.

Also note that I subtract `self.W(H).max()` before exponentiating. This doesn't change the output, but I think it helps with numerical stability? At least this is what scatter_softmax does. IMPORTANT: `torch_scatter` actually uses `scatter_max` to get the max from each index grouping (each molecule) in a batch, while my version only gets the single largest value across the whole `self.W(H)` vector. I couldn't find a way to get the max value for each grouping using only native pytorch. Using the max value of the whole vector may cause problems if one molecule's message passing output is much larger than the other molecules', so it would be good to find a more correct way to do this.
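In case it's useful, `Tensor.scatter_reduce_` also accepts `reduce="amax"`, which might be one way to get a per-group max with native pytorch (a hypothetical, untested sketch with made-up shapes, not the actual chemprop code):

```python
# Hypothetical sketch: per-group max via native scatter_reduce_(reduce="amax")
import torch

H = torch.rand(10, 4)                                 # stand-in for message-passing outputs
batch = torch.tensor([2, 3, 0, 1, 2, 0, 3, 3, 1, 0])  # molecule index for each atom
dim_size = int(batch.max()) + 1

index_torch = batch.unsqueeze(1).repeat(1, H.shape[1])
# -inf init so groups that receive no rows stay at -inf rather than a spurious 0
group_max = torch.full((dim_size, H.shape[1]), float("-inf")).scatter_reduce_(
    0, index_torch, H, reduce="amax", include_self=False
)
H_shifted = H - group_max[batch]  # subtract each row's own group max before exponentiating
```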
Second IMPORTANT note: I found numerical instability when testing a tensor of all ones (`H = torch.ones_like(mp(bmg, V_d))`). The output should be a tensor of ones. Sometimes, however, some of the values are 1.0000001192092896 or 0.9999999403953552. These are one bit above and below 1.0 (determined using this method). I also sometimes saw 0.9999992847442627, 1.0000003576278687, 1.0000004768371582, 0.9999992251396179, and potentially others that are all off of 1.0 by a couple of bits.
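Those values are consistent with being one float32 ULP above or below 1.0; a quick way to check this (a small illustrative sketch, not part of the PR):

```python
import torch

one = torch.tensor(1.0)  # float32 by default
print(torch.nextafter(one, torch.tensor(2.0)).item())  # 1.0000001192092896, one ULP above 1.0
print(torch.nextafter(one, torch.tensor(0.0)).item())  # 0.9999999403953552, one ULP below 1.0
```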
In the test below I simulate the output of message passing with a variable number of atoms, a batch size of 1 (so all molecule indices are 0, i.e. `batch = torch.zeros(n_atoms, dtype=torch.int64)`), and an mp output of all ones (`H = torch.ones(n_atoms, mp_output_dim, dtype=torch.float32)`). The dtypes are the same as what I get out of a `MolGraphDataLoader`. I compare three methods for the forward method of `AttentiveAggregation`: `forward_v2dev` is the current `torch_scatter` version, `forward_notscaled` is my version but without subtracting the max value before exponentiating, and `forward_scaled` is the version I presented. The off-by-one-bit error doesn't happen for every random initialization of `W` (the attention layer, which is just a `torch.nn.Linear(mp_output_dim, 1)`), so I loop through the test 1000 times to get a sense of how often it fails. My results are summarized below, followed by the code.
n_atoms | v2/dev errors | not scaled errors | scaled errors |
---|---|---|---|
3 | 0 | 326 | 0 |
4 | 0 | 0 | 0 |
5 | 0 | 193 | 0 |
6 | 0 | 442 | 0 |
7 | 0 | 466 | 0 |
8 | 0 | 447 | 0 |
9 | 0 | 375 | 0 |
10 | 1000 | 777 | 1000 |
11 | 1000 | 1000 | 1000 |
12 | 1000 | 1000 | 1000 |
13 | 0 | 432 | 0 |
14 | 1000 | 1000 | 1000 |
15 | 0 | 474 | 0 |
16 | 0 | 697 | 0 |
17 | 0 | 431 | 0 |
18 | 1000 | 1000 | 1000 |
19 | 1000 | 1000 | 1000 |
20 | 1000 | 832 | 1000 |
21 | 0 | 551 | 0 |
22 | 1000 | 585 | 1000 |
23 | 1000 | 1000 | 1000 |
24 | 1000 | 1000 | 1000 |
25 | 1000 | 1000 | 1000 |
26 | 1000 | 1000 | 1000 |
27 | 1000 | 1000 | 1000 |
28 | 1000 | 920 | 1000 |
29 | 1000 | 1000 | 1000 |
30 | 1000 | 1000 | 1000 |
31 | 1000 | 1000 | 1000 |
32 | 0 | 844 | 0 |
I'll stop there, but I continued to see ups and downs in how often a couple of bits are off. I'll note that most of the time `torch_scatter` and my scaled version have no failures while the unscaled version does, but there are also `n_atoms` values for which those two versions fail every time and the unscaled version only fails most of the time.

Ultimately, the fact that the `torch_scatter` version also has failures makes me think that we don't need to worry about this. And my proposed scaled version fails at the same times that `torch_scatter` does, so it could be a replacement.
```python
import torch
from torch_scatter import scatter, scatter_softmax

dim = 0

def forward_v2dev(H, batch, W):
    alphas = scatter_softmax(W(H), batch, dim)
    return scatter(alphas * H, batch, dim, reduce="sum")

def forward_notscaled(H, batch, W):
    dim_size = batch.max().int() + 1
    attention_logits = W(H).exp()
    Z = torch.zeros(dim_size, 1, dtype=H.dtype, device=H.device).scatter_reduce_(
        dim, batch.unsqueeze(1), attention_logits, reduce="sum", include_self=False
    )
    alphas = attention_logits / Z[batch]
    index_torch = batch.unsqueeze(1).repeat(1, H.shape[1])
    return torch.zeros(dim_size, H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_(
        0, index_torch, alphas * H, reduce="sum", include_self=False
    )

def forward_scaled(H, batch, W):
    dim_size = batch.max().int() + 1
    attention_logits = (W(H) - W(H).max()).exp()
    Z = torch.zeros(dim_size, 1, dtype=H.dtype, device=H.device).scatter_reduce_(
        dim, batch.unsqueeze(1), attention_logits, reduce="sum", include_self=False
    )
    alphas = attention_logits / Z[batch]
    index_torch = batch.unsqueeze(1).repeat(1, H.shape[1])
    return torch.zeros(dim_size, H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_(
        0, index_torch, alphas * H, reduce="sum", include_self=False
    )

n_atoms = 29
mp_output_dim = 1
fail_v2dev = 0
fail_notscaled = 0
fail_scaled = 0
n = 1000
for i in range(n):
    W = torch.nn.Linear(mp_output_dim, 1)
    batch = torch.zeros(n_atoms, dtype=torch.int64)
    H = torch.ones(n_atoms, mp_output_dim, dtype=torch.float32)
    v2dev = forward_v2dev(H, batch, W)
    notscaled = forward_notscaled(H, batch, W)
    scaled = forward_scaled(H, batch, W)
    if v2dev.item() != 1.0:
        fail_v2dev += 1
    if notscaled.item() != 1.0:
        fail_notscaled += 1
    if scaled.item() != 1.0:
        fail_scaled += 1
print(f"v2dev fails: {fail_v2dev}/{n}")
print(f"new fails: {fail_notscaled}/{n}")
print(f"fix fails: {fail_scaled}/{n}")
print(notscaled.item())
```
I've reviewed the other torch_scatter changes and believe they are correct.
(force-pushed from 5cee9cb to cf1ccbb)
see inline comments and PR description for more details
- removed unused imports
- change tests with unused variables to instead use said variables
- update the `pyproject.toml` specification for pep8 checking to match what we were doing in the CI
- replaced star imports with explicit imports
- add explicit exports in `__init__.py`'s
- generic formatting changes i.e. trailing spaces, newlines, etc.
otherwise it can pass locally and fail on GHA
- simplifies installation process throughout as well
this resolves the hanging issue on windows and for some reason mac
(force-pushed from cf1ccbb to 05a0802)
With the merging of #741 I have dropped the commits related to disallowing loading models from checkpoint in Python 3.12 and rebased. I will now look at the review comment from Nathan and push another commit.
linter was not run on chemprop#741
Thanks for all the work on this. I'll add a note for the future that we decided to use `self.W(H).exp()` instead of `(self.W(H) - self.W(H).max()).exp()` because it is unlikely that `self.W(H).exp()` will overflow a 32-bit floating point number.
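For context on why overflow is unlikely (an illustrative sketch, not from the PR): float32 `exp` only becomes `inf` once its argument exceeds roughly 88.7, which the attention logits are very unlikely to reach.

```python
import torch

# float32 exp overflows to inf once the input exceeds ~88.72 (the log of the float32 max)
print(torch.exp(torch.tensor(88.0)))  # tensor(1.6516e+38) -- still finite
print(torch.exp(torch.tensor(89.0)))  # tensor(inf)
```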
This PR accomplishes a few things which are connected, and though the file diff is kinda big the changes are actually pretty simple:

- replaces the `pytorch_scatter` dependency with calls to `pytorch`'s native `scatter` functions (thanks @davidegraff!), which involves some changes to the codebase, Dockerfile, CI, and documentation for installation
- resolves the issue where `num_workers` > 0 hangs on Windows and sometimes MacOS (#740)