
[v2]: CI Overhaul (and make v2 actually _pass_ the CI) #714

Merged
merged 18 commits into chemprop:v2/dev on Apr 1, 2024

Conversation

JacksonBurns
Member

@JacksonBurns commented Mar 7, 2024

This PR accomplishes a few things which are connected, and though the file diff is kinda big the changes are actually pretty simple:

  • replaced the pytorch_scatter dependency with calls to pytorch's native scatter functions (thanks @davidegraff!), which involves some changes to the codebase, Dockerfile, CI, and documentation for installation (a minimal sketch of the kind of replacement is shown just after this list)
  • reworked the CI to run the linting, build checking, tests, etc. as we had always intended to, but had put off during beta development for the sake of speed
  • extensive formatting changes to satisfy said CI
  • runs the CI on all platforms and supported Python versions; the platform-specific issues this surfaced are discussed in the comments below
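For readers skimming the diff, here is a minimal sketch of the kind of replacement involved (the tensors and sizes are made up for illustration, not the actual Chemprop call sites): a torch_scatter.scatter call becomes a torch.Tensor.scatter_reduce_ call on a zero-initialized output.

import torch

H = torch.rand(10, 4)               # e.g. per-atom hidden states
batch = torch.randint(0, 3, (10,))  # molecule index for each atom
n_mols = 3

# before: torch_scatter.scatter(H, batch, dim=0, dim_size=n_mols, reduce="sum")
# after (native torch): expand the index and reduce into a pre-allocated output
index = batch.unsqueeze(1).repeat(1, H.shape[1])
out = torch.zeros(n_mols, H.shape[1]).scatter_reduce_(
    0, index, H, reduce="sum", include_self=False
)
print(out.shape)  # torch.Size([3, 4])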

@JacksonBurns
Member Author

[image: screenshot of the GitHub Actions run summary]

Has a nicely formatted action summary now.

@JacksonBurns
Member Author

Have to table this for the time being, but here are some notes from the future:

  • MacOS never gets to run the tests because the dependency install fails. It seems that neither conda nor PyPI has the binaries, and the wheel build fails. My last effort is to change the compiler, but that probably won't work.
  • On Python 3.12, all platforms fail the tests. When attempting to load a model from a checkpoint file, lightning somehow incorrectly tries to pass a self argument, which should have been filtered out by its own internal logic (it uses inspect to get all of the model class's attributes and then filters out self). I have no idea why this is broken, but lightning is famously finicky with loading from checkpoints, so this might just require changing things seemingly at random until it works. A rough illustration of the kind of introspection involved is included just below this list.
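As an aside on that inspect-based filtering, here is a rough, purely illustrative sketch in plain Python (the Toy class is hypothetical and this is not Lightning's actual implementation): collect a class's __init__ parameters and drop self before re-using them.

import inspect

class Toy:
    def __init__(self, hidden_dim: int = 16, lr: float = 1e-3):
        self.hidden_dim = hidden_dim
        self.lr = lr

# gather the constructor's parameters and filter out `self`,
# roughly the kind of introspection described in the bullet above
params = inspect.signature(Toy.__init__).parameters
init_kwargs = {name: p.default for name, p in params.items() if name != "self"}
print(init_kwargs)  # {'hidden_dim': 16, 'lr': 0.001}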

@JacksonBurns
Member Author

I have rolled back the commits attempting to fix the install problem.

@JacksonBurns
Member Author

@davidegraff can you look at caf728d? It's a (bad) initial attempt at removing pytorch_scatter in favor of native torch (but it's wrong because I don't know what I am doing).

@JacksonBurns changed the title from [v2]: CI Overhaul to [v2]: CI Overhaul (and make v2 actually _pass_ the CI) on Mar 13, 2024
@JacksonBurns
Member Author

JacksonBurns commented Mar 13, 2024

The good news is that all the tests are failing (meaning that they are all finally running!) 🎉

Apologies for the double ping, @davidegraff. I know the index errors come from not properly passing the index tensor in the way that you showed in the issue, but I wasn't sure where to include it here. I have made a first pass at it in 6856fe8.

The torch_scatter API took what we internally label a 'batch' as its 'index' argument, so I have retained that naming difference here.
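A tiny illustration of that naming difference (made-up tensors; assumes torch_scatter is installed):

import torch
import torch_scatter

H = torch.rand(5, 3)
batch = torch.tensor([0, 0, 1, 1, 1])  # what Chemprop calls `batch`
# torch_scatter calls the same tensor `index`
out = torch_scatter.scatter(H, index=batch, dim=0, reduce="sum")
print(out.shape)  # torch.Size([2, 3])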

@JacksonBurns
Member Author

JacksonBurns commented Mar 13, 2024

TODO: remove the torch_scatter installation workaround from the Dockerfile

@davidegraff
Contributor

I don't feel that copying the code from torch_scatter is the right solution. We're not using the package for its python-level code, but rather because these operations are implemented efficiently at the C++/CUDA level. Copying the python source into the repo is the worst of all three options IMO because it's (1) slow and (2) not maintained (in that we're now responsible for code that we didn't actually write).

@davidegraff
Contributor

davidegraff commented Mar 13, 2024

This reply suggests that installing from source works for MacOS, but I have no problem when I install: pip install torch_scatter -f URL --no-index

@JacksonBurns
Member Author

I don't feel that copying the code from torch_scatter is the right solution. We're not using the package for its python-level code, but rather because these operations are implemented efficiently at the C++/CUDA level. Copying the python source into the repo is the worst of all three options IMO because it's (1) slow and (2) not maintained (in that we're now responsible for code that we didn't actually write).

Agreed that we don't want to keep it exactly the same as the reference implementation. I've copied it here (1) as a starting point and (2) to establish the structure I think we should shoot for when re-implementing this. I think we should have the scatter calls wrapped in a single util like I have done here; the actual implementation could be totally different, but this setup saves us from having to update a lot of places if pytorch changes the API (since this feature is only in beta).
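A minimal sketch of the 'single util' idea (the helper name and module layout here are hypothetical, not the structure in this PR): every call site imports one helper, and only that helper knows whether the backend is torch_scatter or native torch.

import torch
from torch import Tensor

def scatter_sum(src: Tensor, index: Tensor, dim_size: int) -> Tensor:
    """Sum rows of a 2D `src` into `dim_size` groups given by `index` (dim 0 only)."""
    index_expanded = index.unsqueeze(1).repeat(1, src.shape[1])
    out = torch.zeros(dim_size, src.shape[1], dtype=src.dtype, device=src.device)
    return out.scatter_reduce_(0, index_expanded, src, reduce="sum", include_self=False)

If pytorch's scatter API changes, only this one function would need to be updated.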

This reply suggests that installing from source works for MacOS, but I have no problem when I install: pip install torch_scatter -f URL --no-index

I tried the source install on MacOS and it didn't work on GitHub actions (failed during build). The former might work, but the larger issue in my head is that this installation procedure (install torch -> tomfoolery to install torch_scatter -> install chemprop) has to go.

@kevingreenman linked an issue on Mar 13, 2024 that may be closed by this pull request
@davidegraff
Contributor

But that still doesn't solve the problem of slow python-level code. I think it's much preferable that we move to a torch-native implementation based on a beta API than to a slow one using yanked code, for a few reasons:

  1. The API is unlikely to change that much considering that both gather and scatter (from which these APIs derive) are established APIs.
  2. it's fast and "maintained" (I think it's a very bad idea to copy source code we don't understand for the core computation in our package)
  3. we can limit the versions until the API is stable.

Also FWIW, I don't think this install process is too cumbersome. It's the exact same ask as pyg, a significantly larger package than ours. Consider that until recently both RDKit and PyTorch required some finicky install procedure and the community just dealt with it. I agree that we should strive for a package that is pip-installable, but the current procedure is really not that bad: it doesn't require specifying magic URLs, an involved setup script, etc.

And I'm not sure where the MacOS problems are coming from--I have no problems installing on my Mac:

$ conda create -n test_env python=3.11
$ conda activate test_env
$ pip install torch
$ pip install torch_scatter
$ python -c "import torch, torch_scatter; torch.manual_seed(42); print(torch_scatter.scatter_sum(torch.rand(10), torch.randint(4, size=(10,)), dim_size=4))"
tensor([0.6828, 1.9506, 0.5650, 1.6426])
$ pip uninstall torch torch_scatter -y
$ pip cache purge
$ pip install 'torch==2.2.0'
$ pip install torch_scatter
$ python -c "import torch, torch_scatter; torch.manual_seed(42); print(torch_scatter.scatter_sum(torch.rand(10), torch.randint(4, size=(10,)), dim_size=4))"
$ pip uninstall torch torch_scatter -y
$ pip cache purge
$ pip install torch-scatter -f https://data.pyg.org/whl/torch-2.2.0+cpu.html
$ python -c "import torch, torch_scatter; torch.manual_seed(42); print(torch_scatter.scatter_sum(torch.rand(10), torch.randint(4, size=(10,)), dim_size=4))"

@JacksonBurns
Member Author

But that still doesn't solve the problem of slow python-level code. I think it's much preferable that we move to a torch-native implementation based on a beta API than to a slow one using yanked code, for a few reasons:

  1. The API is unlikely to change that much considering that both gather and scatter (from which these APIs derive) are established APIs.
  2. it's fast and "maintained" (I think it's a very bad idea to copy source code we don't understand for the core computation in our package)
  3. we can limit the versions until the API is stable.

We can change the actual implementation to be pure pytorch; I don't possess the skills to do so myself, but that's a great idea. The only thing I am seeking to demonstrate here is the idea of keeping our existing calls to scatter in the codebase and keeping the 'nitty gritty' in one place. That 'nitty gritty' detail can look like whatever works best.

Also FWIW, I don't think this install process is too cumbersome. It's the exact same ask as pyg, a significantly larger package than ours. Consider that until recently both RDKit and PyTorch required some finicky install procedure and the community just dealt with it. I agree that we should strive for a package that is pip-installable, but the current procedure is really not that bad: it doesn't require specifying magic URLs, an involved setup script, etc.

It's not that bad, but it could be so much better 😉

And I'm not sure where the MacOS problems are coming from--I have no problems installing on my Mac:

It's not an issue on all Macs, but it doesn't even work all the time on Linux either - see @kevingreenman's comment here.

@davidegraff
Contributor

I'm fine with including black-magic code and hiding it somewhere in the codebase if it's fast and something that we wrote. I just think it's really bad practice to include it when it's not actually our code: then we don't understand it, but it's also technically our problem if it's bugged somewhere (as opposed to abstractly relying on its correctness via a 3rd-party package). As for a torch-native implementation, I included an experiment here that should provide some guidelines for switching the code over.

@JacksonBurns
Member Author

I have pulled your suggestion for setting up the index and put it here:

https://github.com/chemprop/chemprop/pull/714/files#diff-81dbd7f93e533247fbfddddbf857f3974b4324778d27cd562fe072ebbb9d8f6dR43

and pulled from your suggested zero tensor initialization and put it here:

https://github.com/chemprop/chemprop/pull/714/files#diff-81dbd7f93e533247fbfddddbf857f3974b4324778d27cd562fe072ebbb9d8f6dR71

but as I have mentioned here: #714 (comment)

This new code does not work because I don't know what I'm doing.

The only reason I have even started to edit the scatter stuff is because chempropv2 is currently spuriously un-installable on half the world's computers because of our usage of torch_scatter. I need more help getting this implementation to work, whether that's with pure pytorch throughout the codebase, custom python-based wrappers in one place, or code shamelessly stolen from torch_scatter.

@davidegraff
Contributor

My bad! I thought you had actually just copied most of the source from torch_scatter and placed it inside our codebase. I see now what you're trying to do and that's on me, so I apologize for that. I can try and take a look at this later to see what's going on.

@davidegraff
Contributor

davidegraff commented Mar 15, 2024

FWIW I don't think it's worth it to keep our torch_scatter client code the same and abstract away the "guts" into a utils/scatter.py module. I think we should actually just change the client code to work with torch. The only spots we would need to change are seven uses across nn/message_passing/base.py and nn/agg.py.

I think these are all the recipes we would need:

Starting from:

>>> import torch
>>> import torch_scatter
>>> X = torch.tensor([
    [0.3891, 0.0895, 0.3264, 0.5580],
    [0.4385, 0.3270, 0.6044, 0.6146],
    [0.7692, 0.8121, 0.4114, 0.0574],
    [0.8758, 0.6162, 0.7566, 0.2899],
    [0.7039, 0.5815, 0.0581, 0.2298],
    [0.0574, 0.8424, 0.0066, 0.7251],
    [0.5908, 0.7087, 0.5777, 0.2527],
    [0.8578, 0.8803, 0.8242, 0.3267],
    [0.2310, 0.4105, 0.9994, 0.6507],
    [0.8728, 0.1607, 0.9016, 0.1866]]
)
>>> index = torch.tensor([2, 3, 0, 1, 2, 0, 3, 3, 1, 0])
>>> index_torch = index.unsqueeze(1).repeat(1, X.shape[1])
>>> dim = 0
>>> dim_size = 4

Sum

torch_scatter

>>> torch_scatter.scatter(X, index, dim, dim_size=dim_size, reduce='sum')
tensor([[1.6994, 1.8152, 1.3196, 0.9690],
        [1.1068, 1.0266, 1.7560, 0.9405],
        [1.0930, 0.6710, 0.3845, 0.7878],
        [1.8871, 1.9160, 2.0063, 1.1941]])

torch

>>> torch.zeros(dim_size, X.shape[1]).scatter_reduce_(dim, index_torch, X, reduce='sum', include_self=False)
tensor([[1.6994, 1.8152, 1.3196, 0.9690],
        [1.1068, 1.0266, 1.7560, 0.9405],
        [1.0930, 0.6710, 0.3845, 0.7878],
        [1.8871, 1.9160, 2.0063, 1.1941]])

Mean

torch_scatter

>>> torch_scatter.scatter(X, index, dim, dim_size=dim_size, reduce='mean')
tensor([[0.5665, 0.6051, 0.4399, 0.3230],
        [0.5534, 0.5133, 0.8780, 0.4703],
        [0.5465, 0.3355, 0.1922, 0.3939],
        [0.6290, 0.6387, 0.6688, 0.3980]])

torch

>>> torch.zeros(dim_size, X.shape[1]).scatter_reduce_(dim, index_torch, X, reduce='mean', include_self=False)
tensor([[0.5665, 0.6051, 0.4399, 0.3230],
        [0.5534, 0.5133, 0.8780, 0.4703],
        [0.5465, 0.3355, 0.1922, 0.3939],
        [0.6290, 0.6387, 0.6688, 0.3980]])

Softmax

torch_scatter

>>> torch_scatter.scatter_softmax(X, index, dim=0, dim_size=dim_size)
tensor([[0.4219, 0.3794, 0.5667, 0.5813],
        [0.2713, 0.2379, 0.3106, 0.4088],
        [0.3846, 0.3918, 0.3031, 0.2446],
        [0.6558, 0.5513, 0.4396, 0.4108],
        [0.5781, 0.6206, 0.4333, 0.4187],
        [0.1888, 0.4039, 0.2022, 0.4770],
        [0.3160, 0.3485, 0.3024, 0.2847],
        [0.4127, 0.4137, 0.3870, 0.3065],
        [0.3442, 0.4487, 0.5604, 0.5892],
        [0.4266, 0.2043, 0.4948, 0.2784]])

torch

Note

This recipe uses both index and index_torch. You can use index_torch[:, 0] in place of index

>>> X_exp = X.exp()
>>> Z = torch.zeros(dim_size, X.shape[1]).scatter_reduce_(dim, index_torch, X_exp, reduce='sum', include_self=False)
>>> X_exp / Z[index]
tensor([[0.4219, 0.3794, 0.5667, 0.5813],
        [0.2713, 0.2379, 0.3106, 0.4088],
        [0.3846, 0.3918, 0.3031, 0.2446],
        [0.6558, 0.5513, 0.4396, 0.4108],
        [0.5781, 0.6206, 0.4333, 0.4187],
        [0.1888, 0.4039, 0.2022, 0.4770],
        [0.3160, 0.3485, 0.3024, 0.2847],
        [0.4127, 0.4137, 0.3870, 0.3065],
        [0.3442, 0.4487, 0.5604, 0.5892],
        [0.4266, 0.2043, 0.4948, 0.2784]])

EDIT: updated to use the identity as src per my comment below

Contributor

@KnathanM left a comment


I haven't looked at the torch scatter bits yet, so will cherry-pick that and run some tests. Some small comments on the rest of it. I am grateful you put so much into overhauling our CI for Chemprop.

.github/workflows/ci.yml (outdated, resolved)
.github/workflows/ci.yml (resolved)
# clone the repo, attempt to build
- uses: actions/checkout@v4
- run: python -m pip install build
- run: python -m build .
Contributor


I appreciate all the inline documentation because I'm not familiar with most of this. What is the difference between building the repo and running tests? I would think that you need to build the repo before you can run tests, but I don't see a build step in "Execute Tests".

Member Author


The 'build' step refers to the literal python -m build command; if it fails, we know we can't build the distributable Python package or install Chemprop to run the tests.

The test running doesn't use build but just installs the package. We could just go straight to running the tests, but it's kinda nice to have one expected point of failure for each step in the CI (like, the build only fails in one place, the tests only fail in one place, etc.).

chemprop/data/datasets.py (outdated, resolved)
chemprop/nn/predictors.py (resolved)
"astartes[molecules]",
]

[project.optional-dependencies]
dev = ["black", "bumpversion", "flake8", "pytest", "pytest-cov"]
docs = ["nbsphinx", "sphinx", "sphinx-argparse", "sphinx-autobuild", "sphinx-autoapi", "sphinxcontrib-bibtex", "sphinx-book-theme","nbsphinx-link","ipykernel"]
test = ["parameterized > 0.8", "pytest >= 6.2", "pytest-cov"]
Contributor


I see it was used in the v1 tests. Our testing framework doesn't seem to use it.

.github/workflows/ci.yml (outdated, resolved)

[tool.autopep8]
select = ["E262", "W293"]
Contributor


@davidegraff Do you remember why you included this in #504? Is it okay to remove?

tests/cli/test_cli_classification_mol.py (resolved)
Comment on lines 123 to 129
index_torch = batch.unsqueeze(1).repeat(1, H.shape[1])
dim_size = batch.max().int() + 1
H_exp = H.exp()
Z = torch.zeros(dim_size, H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_(
self.dim, index_torch, H_exp, reduce="sum", include_self=False
)
return H_exp / Z[batch]
Contributor


This isn't quite correct. Note how it doesn't make use of self.W while the current v2/dev version does.
You can see a dimension mismatch error by replacing agg = nn.MeanAggregation() in examples/training.ipynb with agg = nn.AttentiveAggregation(output_size=mp.output_dim).

Here is a version of forward that works in my tests.

def forward(self, H: Tensor, batch: Tensor) -> Tensor:
    # number of molecules in the batch
    dim_size = batch.max().int() + 1
    # unnormalized attention weights, shifted by the max for numerical stability
    attention_logits = (self.W(H) - self.W(H).max()).exp()
    # per-molecule normalizer: scatter-sum of the attention weights
    Z = torch.zeros(dim_size, 1, dtype=H.dtype, device=H.device).scatter_reduce_(
        self.dim, batch.unsqueeze(1), attention_logits, reduce="sum", include_self=False
    )
    alphas = attention_logits / Z[batch]
    index_torch = batch.unsqueeze(1).repeat(1, H.shape[1])
    # weighted scatter-sum of the atom embeddings, one row per molecule
    return torch.zeros(dim_size, H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_(
        self.dim, index_torch, alphas * H, reduce="sum", include_self=False
    )

I tested it by replacing agg = nn.AttentiveAggregation(output_size=mp.output_dim) in examples/training.ipynb with:

import torch
from torch import Tensor
class NewAgg(nn.AttentiveAggregation):
    def forward(self, H: Tensor, batch: Tensor) -> Tensor:
        dim_size = batch.max().int() + 1
        attention_logits = (self.W(H) - self.W(H).max()).exp()
        Z = torch.zeros(dim_size, 1, dtype=H.dtype, device=H.device).scatter_reduce_(
            self.dim, batch.unsqueeze(1), attention_logits, reduce="sum", include_self=False
        )
        alphas = attention_logits / Z[batch]
        index_torch = batch.unsqueeze(1).repeat(1, H.shape[1])
        return torch.zeros(dim_size, H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_(
            self.dim, index_torch, alphas * H, reduce="sum", include_self=False
        )
agg = NewAgg(output_size=mp.output_dim)

I'm not sure attention_logits is the best name for that variable. I'm not even sure it is a correct name, so I'm open to suggestions.

Also note that I subtract self.W(H).max() before exponentiating. This doesn't change the output, but I think it helps with numerical stability? At least this is what scatter_softmax does. IMPORTANT: torch_scatter actually uses scatter_max to get the max within each index grouping (each molecule) in a batch, while my version only gets the single largest value across the whole self.W(H) vector. I couldn't find a way to get the max value for each grouping using only native pytorch (though a possible approach is sketched just below). Using the max value of the whole vector may cause problems if one molecule's message passing output is much larger than the other molecules', so it would be good to find a more correct way to do this.
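For reference, a sketch of one way a per-group max might be computed with only native torch, using scatter_reduce_ with reduce="amax" (available in recent PyTorch releases); this is untested against the actual AttentiveAggregation code path, so treat it as a starting point only:

import torch

H = torch.rand(10, 1)                              # stand-in for self.W(H)
batch = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 2])
dim_size = int(batch.max()) + 1

# per-molecule max via the "amax" reduction, then broadcast back to each atom row
group_max = torch.full((dim_size, 1), float("-inf")).scatter_reduce_(
    0, batch.unsqueeze(1), H, reduce="amax", include_self=True
)
H_shifted = H - group_max[batch]                   # analogous to scatter_softmax's shift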

Second IMPORTANT note: I found numerical instability when testing a tensor of all ones (H = torch.ones_like(mp(bmg, V_d))). The output should be a tensor of ones. Sometimes, however, some of the values are 1.0000001192092896 or 0.9999999403953552, which are one bit above and below 1.0 (determined using this method). I also sometimes saw 0.9999992847442627, 1.0000003576278687, 1.0000004768371582, 0.9999992251396179, and potentially others that are all off from 1.0 by a couple of bits.

In the test below I simulate the output of message passing with a variable number of atoms, a batch size of 1 (so all molecule indices are 0, i.e. batch = torch.zeros(n_atoms, dtype=torch.int64)), and an mp output of all ones (H = torch.ones(n_atoms, mp_output_dim, dtype=torch.float32)). The dtypes are the same as what I get out of a MolGraphDataLoader. I compare three implementations of the forward method of AttentiveAggregation: forward_v2dev is the current torch_scatter version, forward_notscaled is my version but without subtracting the max value before exponentiating, and forward_scaled is the version I presented. The off-by-one-bit error doesn't happen for every random initialization of W (the attention layer, which is just a torch.nn.Linear(mp_output_dim, 1)), so I loop through the test 1000 times to get a sense of how often it fails. My results are summarized below, followed by the code.

| n_atoms | v2/dev errors | not scaled errors | scaled errors |
|---------|---------------|-------------------|---------------|
| 3  | 0    | 326  | 0    |
| 4  | 0    | 0    | 0    |
| 5  | 0    | 193  | 0    |
| 6  | 0    | 442  | 0    |
| 7  | 0    | 466  | 0    |
| 8  | 0    | 447  | 0    |
| 9  | 0    | 375  | 0    |
| 10 | 1000 | 777  | 1000 |
| 11 | 1000 | 1000 | 1000 |
| 12 | 1000 | 1000 | 1000 |
| 13 | 0    | 432  | 0    |
| 14 | 1000 | 1000 | 1000 |
| 15 | 0    | 474  | 0    |
| 16 | 0    | 697  | 0    |
| 17 | 0    | 431  | 0    |
| 18 | 1000 | 1000 | 1000 |
| 19 | 1000 | 1000 | 1000 |
| 20 | 1000 | 832  | 1000 |
| 21 | 0    | 551  | 0    |
| 22 | 1000 | 585  | 1000 |
| 23 | 1000 | 1000 | 1000 |
| 24 | 1000 | 1000 | 1000 |
| 25 | 1000 | 1000 | 1000 |
| 26 | 1000 | 1000 | 1000 |
| 27 | 1000 | 1000 | 1000 |
| 28 | 1000 | 920  | 1000 |
| 29 | 1000 | 1000 | 1000 |
| 30 | 1000 | 1000 | 1000 |
| 31 | 1000 | 1000 | 1000 |
| 32 | 0    | 844  | 0    |

I'll stop there, but I continued to see ups and downs in how often a couple of bits are off. I'll note that most of the time torch_scatter and my scaled version have no failures while the unscaled version does, but there are also n_atoms values for which those two versions fail every time and the unscaled version only fails most of the time.

Ultimately the fact that the torch_scatter version also has failures makes me think that we don't need to worry about this. And my proposed scaled version fails in the same cases that torch_scatter does, so it could be a replacement.

import torch
from torch_scatter import scatter, scatter_softmax

dim = 0

# current v2/dev implementation (torch_scatter)
def forward_v2dev(H, batch, W):
    alphas = scatter_softmax(W(H), batch, dim)
    return scatter(alphas * H, batch, dim, reduce="sum")

# native torch, without subtracting the max before exponentiating
def forward_notscaled(H, batch, W):
    dim_size = batch.max().int() + 1
    attention_logits = W(H).exp()
    Z = torch.zeros(dim_size, 1, dtype=H.dtype, device=H.device).scatter_reduce_(
        dim, batch.unsqueeze(1), attention_logits, reduce="sum", include_self=False
    )
    alphas = attention_logits / Z[batch]
    index_torch = batch.unsqueeze(1).repeat(1, H.shape[1])
    return torch.zeros(dim_size, H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_(
        0, index_torch, alphas * H, reduce="sum", include_self=False
    )

# native torch, with the max subtracted before exponentiating (the version presented above)
def forward_scaled(H, batch, W):
    dim_size = batch.max().int() + 1
    attention_logits = (W(H) - W(H).max()).exp()
    Z = torch.zeros(dim_size, 1, dtype=H.dtype, device=H.device).scatter_reduce_(
        dim, batch.unsqueeze(1), attention_logits, reduce="sum", include_self=False
    )
    alphas = attention_logits / Z[batch]
    index_torch = batch.unsqueeze(1).repeat(1, H.shape[1])
    return torch.zeros(dim_size, H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_(
        0, index_torch, alphas * H, reduce="sum", include_self=False
    )

n_atoms = 29

mp_output_dim = 1
fail_v2dev = 0
fail_notscaled = 0
fail_scaled = 0
n = 1000
for i in range(n):
    W = torch.nn.Linear(mp_output_dim, 1)
    batch = torch.zeros(n_atoms, dtype=torch.int64)
    H = torch.ones(n_atoms, mp_output_dim, dtype=torch.float32)
    v2dev = forward_v2dev(H, batch, W)
    notscaled = forward_notscaled(H, batch, W)
    scaled = forward_scaled(H, batch, W)
    if v2dev.item() != 1.0:
        fail_v2dev += 1
    if notscaled.item() != 1.0:
        fail_notscaled += 1
    if scaled.item() != 1.0:
        fail_scaled += 1
print(f"v2dev fails: {fail_v2dev}/{n}")
print(f"not scaled fails: {fail_notscaled}/{n}")
print(f"scaled fails: {fail_scaled}/{n}")
print(notscaled.item())

@KnathanM
Contributor

I've reviewed the other torch_scatter changes and believe they are correct.

JacksonBurns and others added 14 commits April 1, 2024 12:18
see inline comments and PR description for more details
 - removed unused imports
 - change tests with unused variables to instead use said variables
 - update the `pyproject.toml` specification for pep8 checking to match what we were doing in the CI
 - replaced star imports with explicit imports
 - add explicit exports in `__init__.py`'s
 - generic formatting changes i.e. trailing spaces, newlines, etc.
otherwise it can pass locally and fail on GHA
 - simplifies installation process throughout as well
this resolves the hanging issue on windows and for some reason mac
@JacksonBurns
Member Author

With the merging of #741 I have dropped the commits related to disallowing loading models from checkpoint in Python 3.12 and rebased. I will now look at the review comment from Nathan and push another commit.

Contributor

@KnathanM left a comment


Thanks for all the work on this. I'll add a note for the future that we decided to use self.W(H).exp() instead of (self.W(H) - self.W(H).max()).exp() because it is unlikely that self.W(H).exp() will overflow a 32-bit floating point number.
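For context on that overflow threshold, a quick standalone illustration (not from the PR): float32 exp() only overflows once its input exceeds roughly 88, which the attention logits are unlikely to reach.

import torch

x = torch.tensor([80.0, 88.0, 89.0], dtype=torch.float32)
print(x.exp())  # approximately tensor([5.54e+34, 1.65e+38, inf])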

@JacksonBurns merged commit fb01916 into chemprop:v2/dev on Apr 1, 2024
8 checks passed