Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DimeNet Model #178

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ There are many options for HydraGNN; the dataset and model type are particularly
important:
- `["Verbosity"]["level"]`: `0`, `1`, `2`, `3`, `4`
- `["Dataset"]["name"]`: `CuAu_32atoms`, `FePt_32atoms`, `FeSi_1024atoms`
- `["NeuralNetwork"]["Architecture"]["model_type"]`: `PNA`, `MFC`, `GIN`, `GAT`, `CGCNN`, `SchNet`
- `["NeuralNetwork"]["Architecture"]["model_type"]`: `PNA`, `MFC`, `GIN`, `GAT`, `CGCNN`, `SchNet`, `DimeNet`

### Citations
"HydraGNN: Distributed PyTorch implementation of multi-headed graph convolutional neural networks", Copyright ID#: 81929619
Expand Down
202 changes: 202 additions & 0 deletions hydragnn/models/DIMEStack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
##############################################################################
# Copyright (c) 2021, Oak Ridge National Laboratory #
# All rights reserved. #
# #
# This file is part of HydraGNN and is distributed under a BSD 3-clause #
# license. For the licensing terms see the LICENSE file in the top-level #
# directory. #
# #
# SPDX-License-Identifier: BSD-3-Clause #
##############################################################################

from typing import Callable, Optional, Tuple
from torch_geometric.typing import SparseTensor

import torch
from torch import Tensor
from torch.nn import SiLU

from torch_geometric.nn import Linear, Sequential
from torch_geometric.nn.models.dimenet import (
BesselBasisLayer,
EmbeddingBlock,
InteractionBlock,
OutputBlock,
SphericalBasisLayer,
)
from torch_geometric.utils import scatter

from .Base import Base


class DIMEStack(Base):
    """
    DimeNet-style convolution stack for HydraGNN.

    Generates angles, distances, to/from indices, radial basis
    functions and spherical basis functions for learning.
    """

    def __init__(
        self,
        envelope_exponent,
        num_after_skip,
        num_before_skip,
        num_bilinear,
        num_radial,
        num_spherical,
        radius,
        *args,
        max_neighbours: Optional[int] = None,
        **kwargs
    ):
        """
        Store the DimeNet hyper-parameters and build the basis layers.

        Args:
            envelope_exponent: Forwarded to both ``BesselBasisLayer`` and
                ``SphericalBasisLayer``.
            num_after_skip: Forwarded to every interaction block.
            num_before_skip: Forwarded to every interaction block.
            num_bilinear: Forwarded to every interaction block.
            num_radial: Number of radial basis functions.
            num_spherical: Number of spherical basis functions.
            radius: Forwarded to both basis layers
                (presumably the interaction cutoff -- confirm against PyG).
            *args: Positional arguments forwarded unchanged to ``Base``.
            max_neighbours: Accepted for interface compatibility; it is not
                used by this class and is not forwarded to ``Base``.
            **kwargs: Keyword arguments forwarded unchanged to ``Base``.
        """
        # These attributes are assigned before Base.__init__ runs.
        # NOTE(review): presumably Base.__init__ builds the convolution
        # layers through get_conv(), which reads them -- confirm in Base.
        self.num_bilinear = num_bilinear
        self.num_radial = num_radial
        self.num_spherical = num_spherical
        self.num_before_skip = num_before_skip
        self.num_after_skip = num_after_skip
        self.radius = radius

        super().__init__(*args, **kwargs)

        self.rbf = BesselBasisLayer(num_radial, radius, envelope_exponent)
        self.sbf = SphericalBasisLayer(
            num_spherical, num_radial, radius, envelope_exponent
        )

    def get_conv(self, input_dim, output_dim):
        """
        Build one convolution layer mapping ``input_dim`` to ``output_dim``.

        The layer is a PyG ``Sequential`` that chains a linear projection, an
        embedding block, an interaction block and an output block. It expects
        the inputs ``x, rbf, sbf, i, j, idx_kj, idx_ji``.

        Raises:
            AssertionError: If the derived hidden dimension is not > 1.
        """
        # Use input_dim as the hidden width unless it is 1, in which case
        # fall back to output_dim.
        hidden_dim = output_dim if input_dim == 1 else input_dim
        assert (
            hidden_dim > 1
        ), "DimeNet requires more than one hidden dimension between input_dim and output_dim."
        lin = Linear(input_dim, hidden_dim)
        emb = HydraEmbeddingBlock(self.num_radial, hidden_dim, act=SiLU())
        inter = HydraInteractionBlock(
            hidden_channels=hidden_dim,
            num_bilinear=self.num_bilinear,
            num_spherical=self.num_spherical,
            num_radial=self.num_radial,
            num_before_skip=self.num_before_skip,
            num_after_skip=self.num_after_skip,
            act=SiLU(),
        )
        # NOTE(review): the literal 1 is presumably OutputBlock's num_layers.
        dec = OutputBlock(self.num_radial, hidden_dim, output_dim, 1, SiLU())
        return Sequential(
            "x, rbf, sbf, i, j, idx_kj, idx_ji",
            [
                (lin, "x -> x"),
                (emb, "x, rbf, i, j -> x1"),
                (inter, "x1, rbf, sbf, idx_kj, idx_ji -> x2"),
                (dec, "x2, rbf, i -> c"),
            ],
        )

    def _conv_args(self, data):
        """
        Compute the geometric inputs required by the convolution layers.

        Args:
            data: Graph object providing ``pos``, ``edge_index`` and ``x``.

        Returns:
            dict with keys ``rbf``, ``sbf``, ``i``, ``j``, ``idx_kj`` and
            ``idx_ji``.

        Raises:
            AssertionError: If ``data.pos`` is ``None``.
        """
        assert (
            data.pos is not None
        ), "DimeNet requires node positions (data.pos) to be set."
        i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets(
            data.edge_index, num_nodes=data.x.size(0)
        )
        dist = (data.pos[i] - data.pos[j]).pow(2).sum(dim=-1).sqrt()

        # Calculate angles.
        pos_i = data.pos[idx_i]
        pos_ji, pos_ki = data.pos[idx_j] - pos_i, data.pos[idx_k] - pos_i
        a = (pos_ji * pos_ki).sum(dim=-1)
        # dim must be explicit: without it torch.cross uses the first
        # dimension of size 3, which is the triplet axis (not the coordinate
        # axis) whenever there are exactly three triplets.
        b = torch.cross(pos_ji, pos_ki, dim=-1).norm(dim=-1)
        angle = torch.atan2(b, a)

        rbf = self.rbf(dist)
        sbf = self.sbf(dist, angle, idx_kj)

        conv_args = {
            "rbf": rbf,
            "sbf": sbf,
            "i": i,
            "j": j,
            "idx_kj": idx_kj,
            "idx_ji": idx_ji,
        }

        return conv_args


"""
Code adapted from PyG
---------------------
The code below is adapted from
https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/models/dimenet.py

"""


def triplets(
    edge_index: Tensor,
    num_nodes: int,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
    """
    Enumerate the k->j->i triplets of a graph.

    Args:
        edge_index: Edge list; its first row holds the ``j`` end and its
            second row the ``i`` end of each ``j->i`` edge.
        num_nodes: Number of nodes, used to size the sparse adjacency.

    Returns:
        ``(i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji)``: the per-edge ``i``
        and ``j`` node indices, the per-triplet node indices, and the
        per-triplet edge indices of the ``k->j`` and ``j->i`` edges.
        Triplets with ``i == k`` are excluded.
    """
    node_j, node_i = edge_index  # j->i

    edge_ids = torch.arange(node_j.size(0), device=node_j.device)
    adjacency = SparseTensor(
        row=node_i, col=node_j, value=edge_ids, sparse_sizes=(num_nodes, num_nodes)
    )
    # One sparse row per edge, indexed by that edge's j node.
    # NOTE(review): presumably each row lists the edges k->j -- confirm
    # against the PyG implementation.
    per_edge = adjacency[node_j]
    triplets_per_edge = per_edge.set_value(None).sum(dim=1).to(torch.long)

    # Node indices (k->j->i) for triplets.
    trip_i = node_i.repeat_interleave(triplets_per_edge)
    trip_j = node_j.repeat_interleave(triplets_per_edge)
    trip_k = per_edge.storage.col()

    # Remove i == k triplets.
    keep = trip_i != trip_k
    trip_i = trip_i[keep]
    trip_j = trip_j[keep]
    trip_k = trip_k[keep]

    # Edge indices (k->j, j->i) for triplets.
    edge_kj = per_edge.storage.value()[keep]
    edge_ji = per_edge.storage.row()[keep]

    return node_i, node_j, trip_i, trip_j, trip_k, edge_kj, edge_ji


class HydraEmbeddingBlock(EmbeddingBlock):
    """
    Embedding block without the atomic-number embedding table.

    Node features are consumed as given; only the radial-basis projection
    and the final linear layer of the parent block are used.
    """

    def __init__(self, num_radial: int, hidden_channels: int, act: Callable):
        super().__init__(
            num_radial=num_radial,
            hidden_channels=hidden_channels,
            act=act,
        )
        # Atomic embeddings are handled by Hydra, so drop the parent's table.
        del self.emb
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the two linear layers (there is no embedding table)."""
        self.lin_rbf.reset_parameters()
        self.lin.reset_parameters()

    def forward(self, x: Tensor, rbf: Tensor, i: Tensor, j: Tensor) -> Tensor:
        """
        Build features from ``x`` gathered at ``i`` and ``j`` plus ``rbf``.

        ``x`` is used directly, without an embedding lookup.
        """
        radial = self.act(self.lin_rbf(rbf))
        edge_inputs = torch.cat((x[i], x[j], radial), dim=-1)
        return self.act(self.lin(edge_inputs))


class HydraInteractionBlock(InteractionBlock):
    """
    Interaction block whose bilinear contraction is split into two einsums.
    """

    def forward(
        self, x: Tensor, rbf: Tensor, sbf: Tensor, idx_kj: Tensor, idx_ji: Tensor
    ) -> Tensor:
        """
        Combine ``x`` with the radial and spherical bases and return the
        updated features after the skip connection and residual layers.
        """
        radial = self.lin_rbf(rbf)
        angular = self.lin_sbf(sbf)

        msg_ji = self.act(self.lin_ji(x))
        msg_kj = self.act(self.lin_kj(x)) * radial

        # The contraction of (angular, msg_kj[idx_kj], W) is done in two
        # steps; the single three-operand einsum
        # 'wj,wl,ijl->wi' could not be handled by the optimal-path search.
        outer = torch.einsum("wj,wl->wlj", angular, msg_kj[idx_kj])
        contracted = torch.einsum("wlj,ijl->wi", outer, self.W)
        aggregated = scatter(
            contracted, idx_ji, dim=0, dim_size=x.size(0), reduce="sum"
        )

        hidden = msg_ji + aggregated
        for block in self.layers_before_skip:
            hidden = block(hidden)

        # Skip connection back to the block input.
        hidden = self.act(self.lin(hidden)) + x
        for block in self.layers_after_skip:
            hidden = block(hidden)

        return hidden
45 changes: 45 additions & 0 deletions hydragnn/models/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from hydragnn.models.CGCNNStack import CGCNNStack
from hydragnn.models.SAGEStack import SAGEStack
from hydragnn.models.SCFStack import SCFStack
from hydragnn.models.DIMEStack import DIMEStack

from hydragnn.utils.distributed import get_device
from hydragnn.utils.print_utils import print_distributed
Expand Down Expand Up @@ -47,6 +48,12 @@ def create_model_config(
config["Architecture"]["max_neighbours"],
config["Architecture"]["edge_dim"],
config["Architecture"]["pna_deg"],
config["Architecture"]["num_before_skip"],
config["Architecture"]["num_after_skip"],
config["Architecture"]["num_bilinear"],
config["Architecture"]["num_radial"],
config["Architecture"]["envelope_exponent"],
config["Architecture"]["num_spherical"],
config["Architecture"]["num_gaussians"],
config["Architecture"]["num_filters"],
config["Architecture"]["radius"],
Expand All @@ -72,6 +79,12 @@ def create_model(
max_neighbours: int = None,
edge_dim: int = None,
pna_deg: torch.tensor = None,
num_before_skip: int = None,
num_after_skip: int = None,
num_bilinear: int = None,
num_radial: int = None,
envelope_exponent: int = None,
num_spherical: int = None,
num_gaussians: int = None,
num_filters: int = None,
radius: float = None,
Expand Down Expand Up @@ -206,6 +219,38 @@ def create_model(
num_nodes=num_nodes,
)

elif model_type == "DimeNet":
assert (
envelope_exponent is not None
), "DimeNet requires envelope_exponent input."
assert num_after_skip is not None, "DimeNet requires num_after_skip input."
assert num_before_skip is not None, "DimeNet requires num_before_skip input."
assert num_bilinear is not None, "DimeNet requires num_bilinear input."
assert num_radial is not None, "DimeNet requires num_radial input."
assert num_spherical is not None, "DimeNet requires num_spherical input."
assert radius is not None, "DimeNet requires radius input."
model = DIMEStack(
envelope_exponent,
num_after_skip,
num_before_skip,
num_bilinear,
num_radial,
num_spherical,
radius,
input_dim,
hidden_dim,
output_dim,
output_type,
output_heads,
loss_function_type,
max_neighbours=max_neighbours,
loss_weights=task_weights,
freeze_conv=freeze_conv,
initial_bias=initial_bias,
num_conv_layers=num_conv_layers,
num_nodes=num_nodes,
)

else:
raise ValueError("Unknown model_type: {0}".format(model_type))

Expand Down
14 changes: 13 additions & 1 deletion hydragnn/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,18 @@ def update_config(config, train_loader, val_loader, test_loader):
config["NeuralNetwork"]["Architecture"]["num_gaussians"] = None
if "num_filters" not in config["NeuralNetwork"]["Architecture"]:
config["NeuralNetwork"]["Architecture"]["num_filters"] = None
if "envelope_exponent" not in config["NeuralNetwork"]["Architecture"]:
config["NeuralNetwork"]["Architecture"]["envelope_exponent"] = None
if "num_after_skip" not in config["NeuralNetwork"]["Architecture"]:
config["NeuralNetwork"]["Architecture"]["num_after_skip"] = None
if "num_before_skip" not in config["NeuralNetwork"]["Architecture"]:
config["NeuralNetwork"]["Architecture"]["num_before_skip"] = None
if "num_bilinear" not in config["NeuralNetwork"]["Architecture"]:
config["NeuralNetwork"]["Architecture"]["num_bilinear"] = None
if "num_radial" not in config["NeuralNetwork"]["Architecture"]:
config["NeuralNetwork"]["Architecture"]["num_radial"] = None
if "num_spherical" not in config["NeuralNetwork"]["Architecture"]:
config["NeuralNetwork"]["Architecture"]["num_spherical"] = None

config["NeuralNetwork"]["Architecture"] = update_config_edge_dim(
config["NeuralNetwork"]["Architecture"]
Expand Down Expand Up @@ -84,7 +96,7 @@ def update_config_edge_dim(config):
if "edge_features" in config and config["edge_features"]:
assert (
config["model_type"] in edge_models
), "Edge features can only be used with PNA and CGCNN."
), "Edge features can only be used with SchNet, PNA and CGCNN."
config["edge_dim"] = len(config["edge_features"])
elif config["model_type"] == "CGCNN":
# CG always needs an integer edge_dim
Expand Down
6 changes: 6 additions & 0 deletions tests/inputs/ci.json
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@
"radius": 2.0,
"max_neighbours": 100,
"num_gaussians": 50,
"envelope_exponent": 5,
"num_after_skip": 2,
"num_before_skip": 1,
"num_bilinear": 8,
"num_radial": 6,
"num_spherical": 7,
"num_filters": 126,
"periodic_boundary_conditions": false,
"hidden_dim": 8,
Expand Down
6 changes: 6 additions & 0 deletions tests/inputs/ci_multihead.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@
"radius": 2.0,
"max_neighbours": 100,
"num_gaussians": 50,
"envelope_exponent": 5,
"num_after_skip": 2,
"num_before_skip": 1,
"num_bilinear": 8,
"num_radial": 6,
"num_spherical": 7,
"num_filters": 126,
"periodic_boundary_conditions": false,
"hidden_dim": 8,
Expand Down
3 changes: 2 additions & 1 deletion tests/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def unittest_train_model(model_type, ci_input, use_lengths, overwrite_data=False
"GAT": [0.60, 0.70],
"CGCNN": [0.50, 0.40],
"SchNet": [0.20, 0.20],
"DimeNet": [0.50, 0.50],
}
if use_lengths and ("vector" not in ci_input):
thresholds["CGCNN"] = [0.175, 0.175]
Expand Down Expand Up @@ -173,7 +174,7 @@ def unittest_train_model(model_type, ci_input, use_lengths, overwrite_data=False

# Test across all models with both single/multihead
@pytest.mark.parametrize(
"model_type", ["SAGE", "GIN", "GAT", "MFC", "PNA", "CGCNN", "SchNet"]
"model_type", ["SAGE", "GIN", "GAT", "MFC", "PNA", "CGCNN", "SchNet", "DimeNet"]
)
@pytest.mark.parametrize("ci_input", ["ci.json", "ci_multihead.json"])
def pytest_train_model(model_type, ci_input, overwrite_data=False):
Expand Down