-
Notifications
You must be signed in to change notification settings - Fork 548
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
once this actually works, squash it with subsequent commits attempting to get rid of pytorch_scatter in favor of native pytorch scatter
- Loading branch information
1 parent
b32d523
commit caf728d
Showing
8 changed files
with
81 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,16 @@ | ||
from .mixins import ReprMixin
from .registry import ClassRegistry, Factory
from .scatter import scatter, scatter_softmax, scatter_sum
from .utils import EnumMapping, make_mol, pretty_shape

# Public API of the utils package; keep in sync with the imports above.
__all__ = [
    "ReprMixin",
    "ClassRegistry",
    "Factory",
    "EnumMapping",
    "make_mol",
    "pretty_shape",
    "scatter_softmax",
    "scatter",
    "scatter_sum",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# scatter.py | ||
# | ||
# This file wraps the native pytorch scatter API to act more like | ||
# the third-party torch_scatter package, which we used previously | ||
# but has since fallen into disrepair. The API for torch is also in | ||
# beta, so having all of the implementation details here will | ||
# simplify future updates. | ||
# | ||
# Most of the functionality here has been adapted from | ||
# github.com/rusty1s/pytorch_scatter which is licensed under the MIT | ||
# license. | ||
import torch | ||
|
||
|
||
# copied verbatim from: | ||
# https://github.com/rusty1s/pytorch_scatter/blob/c095c62e4334fcd05e4ac3c4bb09d285960d6be6/torch_scatter/utils.py#L4 | ||
def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): | ||
if dim < 0: | ||
dim = other.dim() + dim | ||
if src.dim() == 1: | ||
for _ in range(0, dim): | ||
src = src.unsqueeze(0) | ||
for _ in range(src.dim(), other.dim()): | ||
src = src.unsqueeze(-1) | ||
src = src.expand(other.size()) | ||
return src | ||
|
||
|
||
def _get_dimsize(src: torch.Tensor, index: torch.Tensor, dim: int): | ||
# deduce the size of the output tensor from calling scatter | ||
size = list(src.size()) | ||
size[dim] = int(index.max()) + 1 | ||
return size | ||
|
||
|
||
def _get_dimsize_zeros(src: torch.Tensor, index: torch.Tensor, dim: int): | ||
size = _get_dimsize(src, index, dim) | ||
# return a zero tensor that looks like the src tensor | ||
return torch.zeros(size, dtype=src.dtype, device=src.device) | ||
|
||
|
||
# adapted from: | ||
# https://github.com/rusty1s/pytorch_scatter/blob/c095c62e4334fcd05e4ac3c4bb09d285960d6be6/torch_scatter/composite/softmax.py#L9 | ||
def scatter_softmax(src: torch.Tensor, index: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Softmax of ``src`` computed independently within each group of ``index``.

    Drop-in replacement for ``torch_scatter.composite.scatter_softmax``.
    Entries of ``src`` sharing an index along ``dim`` form one softmax group;
    the result has ``src``'s shape and each group's entries sum to 1.

    Fixes vs. the previous revision:
    * the max-reduction now runs along ``dim`` (it was hardcoded to 0);
    * dropped the trailing ``[0]`` — ``Tensor.scatter_reduce`` returns a
      tensor, not the ``(values, argmax)`` tuple ``torch_scatter.scatter_max``
      returned, so ``[0]`` sliced off a dimension;
    * ``include_self=False`` so the zero initialization cannot win the
      ``"amax"`` reduction and defeat the numerical-stability recentering.
    """
    if dim < 0:
        dim += src.dim()
    # Broadcast `index` to src's shape (inlined from _broadcast so this
    # routine is self-contained).
    if index.dim() == 1:
        for _ in range(dim):
            index = index.unsqueeze(0)
    for _ in range(index.dim(), src.dim()):
        index = index.unsqueeze(-1)
    index = index.expand(src.size())

    # Zero accumulator sized max(index) + 1 along `dim` (see _get_dimsize_zeros).
    size = list(src.size())
    size[dim] = int(index.max()) + 1
    zeros = torch.zeros(size, dtype=src.dtype, device=src.device)

    # Per-group max, subtracted off before exp() for numerical stability.
    group_max = zeros.scatter_reduce(dim, index, src, "amax", include_self=False)
    shifted = src - group_max.gather(dim, index)
    shifted_exp = shifted.exp()
    # Per-group normalizer, gathered back to src's layout.
    group_sum = zeros.scatter_add(dim, index, shifted_exp)
    return shifted_exp / group_sum.gather(dim, index)
|
||
|
||
def scatter(H: torch.Tensor, batch: torch.Tensor, dim: int, reduce: str) -> torch.Tensor:
    """Reduce entries of ``H`` into buckets given by ``batch`` along ``dim``.

    Drop-in replacement for ``torch_scatter.scatter``. ``batch`` must already
    be broadcast to ``H``'s shape (use ``_broadcast`` if it is 1-D).

    ``reduce`` accepts native torch names ("sum", "prod", "mean", "amax",
    "amin") as well as the legacy torch_scatter spellings ("add", "mul",
    "max", "min"), which ``Tensor.scatter_reduce_`` would otherwise reject.
    """
    # Translate legacy torch_scatter reduce names to native ones.
    reduce = {"add": "sum", "mul": "prod", "max": "amax", "min": "amin"}.get(reduce, reduce)
    # Output shape: H's shape with `dim` wide enough for every bucket
    # (same computation as _get_dimsize).
    size = list(H.size())
    size[dim] = int(batch.max()) + 1
    out = torch.zeros(size, dtype=H.dtype, device=H.device)
    # include_self=False: the zero initialization must not take part in the
    # reduction — it would skew "mean" counts and clamp "amax"/"amin" at 0.
    return out.scatter_reduce_(dim, batch, H, reduce, include_self=False)
|
||
|
||
def scatter_sum(H: torch.Tensor, index: torch.Tensor, dim: int, dim_size=None) -> torch.Tensor:
    """Sum entries of ``H`` sharing an index along ``dim``.

    Drop-in replacement for ``torch_scatter.scatter_sum``. ``index`` must
    already be broadcast to ``H``'s shape (use ``_broadcast`` if it is 1-D).

    dim_size: size of the output along ``dim`` as an int (the torch_scatter
        convention), or a full output shape (list/tuple). Defaults to
        ``index.max() + 1`` along ``dim``.

    Fixes vs. the previous revision: ``scatter_reduce_`` rejects ``"add"``
    (its reduce names are "sum"/"prod"/"mean"/"amax"/"amin"), so the native
    ``scatter_add_`` is used instead; an int ``dim_size`` was previously
    passed whole to ``torch.zeros`` and produced a 1-D output.
    """
    size = list(H.size())
    if dim_size is None:
        # Same computation as _get_dimsize: one slot per distinct index.
        size[dim] = int(index.max()) + 1
    elif isinstance(dim_size, int):
        size[dim] = dim_size
    else:
        size = list(dim_size)
    out = torch.zeros(size, dtype=H.dtype, device=H.device)
    return out.scatter_add_(dim, index, H)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,7 +34,6 @@ dependencies = [ | |
"scikit-learn", | ||
"scipy", | ||
"torch >= 2.1", | ||
"torch_scatter", | ||
"astartes[molecules]", | ||
] | ||
|
||
|