SQUASH ME see extended
once this actually works, squash it with subsequent commits

attempting to get rid of pytorch_scatter in favor of native pytorch scatter
JacksonBurns committed Mar 15, 2024
1 parent 73da67d commit d9bab46
Showing 8 changed files with 88 additions and 20 deletions.
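For context, a minimal sketch of the equivalence this commit relies on (illustrative only, not code from the diff): torch_scatter's scatter_sum can be reproduced in native PyTorch by expanding the 1-D group index to the source tensor's shape and calling scatter_add_ on a zero-initialized output.

import torch

src = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
index = torch.tensor([0, 1, 0])  # rows 0 and 2 belong to group 0

# native scatter ops require the index to match src's shape
index_2d = index.unsqueeze(-1).expand_as(src)
out = torch.zeros(2, src.size(1)).scatter_add_(0, index_2d, src)
# out == tensor([[6., 8.], [3., 4.]]), matching torch_scatter.scatter_sum(src, index, dim=0)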
4 changes: 1 addition & 3 deletions .github/workflows/ci.yml
@@ -81,9 +81,7 @@ jobs:
shell: bash -l {0}
run: |
python -m pip install flake8 pytest parameterized nbmake
-python -m pip install torch==2.2.0
-python -m pip install torch-scatter -f https://data.pyg.org/whl/torch-2.2.0+cpu.html
-python -m pip install -e ".[dev,docs,test]"
+python -m pip install ".[dev,docs,test]"
- name: Test with pytest
shell: bash -l {0}
run: |
5 changes: 0 additions & 5 deletions .readthedocs.yml
@@ -11,11 +11,6 @@ build:
tools:
python: "3.11"
jobs:
-# We need to install torch and torch-scatter separately as a pre-install step
-# due to an issue with torch-scatter
-pre_install:
-  - python -m pip install --upgrade --no-cache-dir torch
-  - python -m pip install --upgrade --no-cache-dir torch-scatter
post_install:
- python -m pip install --upgrade --upgrade-strategy only-if-needed --no-cache-dir ".[docs]"

3 changes: 1 addition & 2 deletions chemprop/nn/agg.py
@@ -1,8 +1,7 @@
from abc import abstractmethod
from torch import Tensor, nn
-from torch_scatter import scatter, scatter_softmax

-from chemprop.utils import ClassRegistry
+from chemprop.utils import ClassRegistry, scatter, scatter_softmax
from chemprop.nn.hparams import HasHParams


2 changes: 1 addition & 1 deletion chemprop/nn/message_passing/base.py
@@ -3,11 +3,11 @@
from lightning.pytorch.core.mixins import HyperparametersMixin
import torch
from torch import Tensor, nn
-from torch_scatter import scatter_sum

from chemprop.conf import DEFAULT_ATOM_FDIM, DEFAULT_BOND_FDIM, DEFAULT_HIDDEN_DIM
from chemprop.exceptions import InvalidShapeError
from chemprop.data import BatchMolGraph
+from chemprop.utils import scatter_sum
from chemprop.nn.utils import Activation, get_activation_function
from chemprop.nn.message_passing.proto import MessagePassing

13 changes: 12 additions & 1 deletion chemprop/utils/__init__.py
@@ -1,5 +1,16 @@
from .mixins import ReprMixin
from .registry import ClassRegistry, Factory
from .utils import EnumMapping, make_mol, pretty_shape
+from .scatter import scatter_softmax, scatter, scatter_sum

__all__ = ["ReprMixin", "ClassRegistry", "Factory", "EnumMapping", "make_mol", "pretty_shape"]
__all__ = [
"ReprMixin",
"ClassRegistry",
"Factory",
"EnumMapping",
"make_mol",
"pretty_shape",
"scatter_softmax",
"scatter",
"scatter_sum",
]
73 changes: 73 additions & 0 deletions chemprop/utils/scatter.py
@@ -0,0 +1,73 @@
# scatter.py
#
# This file wraps the native PyTorch scatter API to act more like
# the third-party torch_scatter package, which we used previously
# but which has since fallen into disrepair. PyTorch's scatter_reduce
# API is also still in beta, so keeping all of the implementation
# details here will simplify future updates.
#
# Most of the functionality here has been adapted from
# github.com/rusty1s/pytorch_scatter, which is licensed under the MIT
# license.
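#
# Rough correspondence between the two APIs (our reading, not an official mapping;
# index_expanded below means the 1-D group index broadcast to src's shape):
#   torch_scatter.scatter_sum(src, index, dim)     -> zeros.scatter_add_(dim, index_expanded, src)
#   torch_scatter.scatter(src, index, dim, reduce) -> zeros.scatter_reduce_(dim, index_expanded, src, reduce, include_self=False)
#   torch_scatter.scatter_softmax(src, index, dim) -> composed from scatter_reduce ("amax"), scatter_add, and gather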
import torch


# copied verbatim from:
# https://github.com/rusty1s/pytorch_scatter/blob/c095c62e4334fcd05e4ac3c4bb09d285960d6be6/torch_scatter/utils.py#L4
def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    if dim < 0:
        dim = other.dim() + dim
    if src.dim() == 1:
        for _ in range(0, dim):
            src = src.unsqueeze(0)
    for _ in range(src.dim(), other.dim()):
        src = src.unsqueeze(-1)
    src = src.expand(other.size())
    return src
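
# e.g. (hypothetical values) _broadcast(torch.tensor([0, 0, 1]), torch.rand(3, 4), dim=0)
# unsqueezes the index to shape (3, 1) and expands it to (3, 4), so each row of the
# result repeats that row's group id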


def _get_dimsize(src: torch.Tensor, index: torch.Tensor, dim: int):
    # deduce the size of the output tensor that a scatter along dim will produce
    size = list(src.size())
    size[dim] = int(index.max()) + 1
    return size


def _get_dimsize_zeros_like(src: torch.Tensor, index: torch.Tensor, dim: int):
    size = _get_dimsize(src, index, dim)
    return torch.zeros(size, dtype=src.dtype, device=src.device)


def _batch_to_index(src: torch.Tensor, batch: torch.Tensor, dim: int) -> torch.Tensor:
    # convert a 1-D torch_scatter style index ("batch") to a native pytorch scatter
    # style index, which must match the shape of src
    return _broadcast(batch, src, dim)


# adapted from:
# https://github.com/rusty1s/pytorch_scatter/blob/c095c62e4334fcd05e4ac3c4bb09d285960d6be6/torch_scatter/composite/softmax.py#L9
def scatter_softmax(src: torch.Tensor, batch: torch.Tensor, dim: int = -1) -> torch.Tensor:
    index = _batch_to_index(src, batch, dim)
    scatter_out = _get_dimsize_zeros_like(src, index, dim)
    # native scatter_reduce returns the reduced tensor directly, unlike torch_scatter's
    # scatter_max, which returned a (values, argmax) tuple; include_self=False keeps
    # the zero-initialized output out of the max
    max_value_per_index = scatter_out.scatter_reduce(dim, index, src, "amax", include_self=False)
    max_per_src_element = max_value_per_index.gather(dim, index)
    recentered_scores = src - max_per_src_element
    recentered_scores_exp = recentered_scores.exp_()
    sum_per_index = scatter_out.scatter_add(dim, index, recentered_scores_exp)
    normalizing_constants = sum_per_index.gather(dim, index)
    return recentered_scores_exp.div(normalizing_constants)
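
# e.g. (hypothetical values) scatter_softmax(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([0, 0, 1]))
# softmaxes within each group independently, returning approximately
# tensor([0.2689, 0.7311, 1.0000])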


def scatter(src: torch.Tensor, batch: torch.Tensor, dim: int, reduce: str):
    # reduce must use the native names: "sum", "prod", "mean", "amax", or "amin"
    index = _batch_to_index(src, batch, dim)
    scatter_out = _get_dimsize_zeros_like(src, index, dim)
    # include_self=False keeps the zero-initialized output out of the reduction,
    # matching torch_scatter's semantics
    return scatter_out.scatter_reduce_(dim, index, src, reduce, include_self=False)


def scatter_sum(src: torch.Tensor, batch: torch.Tensor, dim: int, dim_size=None):
    index = _batch_to_index(src, batch, dim)
    size = _get_dimsize(src, index, dim)
    if dim_size is not None:
        # honor an explicitly requested output length along dim (torch_scatter's dim_size)
        size[dim] = dim_size
    return torch.zeros(size, dtype=src.dtype, device=src.device).scatter_add_(dim, index, src)
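
A short usage sketch of the wrappers above (illustrative only; the shapes mirror how chemprop's message passing aggregates per-bond messages onto atoms, but the tensors and names here are made up):

import torch
from chemprop.utils import scatter, scatter_softmax, scatter_sum

H_b = torch.rand(6, 300)                # hypothetical hidden states for 6 directed bonds
dst = torch.tensor([0, 0, 1, 1, 2, 2])  # hypothetical destination atom of each bond

M_v = scatter_sum(H_b, dst, dim=0, dim_size=3)    # (3, 300) summed messages per atom
mean_v = scatter(H_b, dst, dim=0, reduce="mean")  # (3, 300) mean-aggregated messages
attn = scatter_softmax(H_b, dst, dim=0)           # (6, 300) softmax within each atom's bonds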
7 changes: 0 additions & 7 deletions docs/source/installation.rst
@@ -28,16 +28,11 @@ Start by setting up your virtual environment. We assume you are using ``conda``
conda install pytorch cpuonly -c pytorch
-.. note::
-   We are aware that some users may experience issues during installation while trying to install `torch-scatter`. This is an issue with the `torch-scatter` package and not with Chemprop. We will resolve this issue before the release of v2.0.0, most likely by replacing our `torch-scatter` functions with native PyTorch functions and removing the `torch-scatter` dependency. You can follow along with this issue here: https://github.com/chemprop/chemprop/issues/580.

Option 1: Installing from PyPI
------------------------------

.. code-block::
-pip install torch
-pip install torch-scatter
pip install chemprop --pre
@@ -49,8 +49,6 @@ Option 2: Installing from source
git clone https://github.com/chemprop/chemprop.git
cd chemprop
git checkout v2/dev
-pip install torch
-pip install torch-scatter
pip install .
Option 3: Installing via Docker
1 change: 0 additions & 1 deletion pyproject.toml
@@ -34,7 +34,6 @@ dependencies = [
"scikit-learn",
"scipy",
"torch >= 2.1",
"torch_scatter",
"astartes[molecules]",
]

