Skip to content

Commit

Permalink
Add unit test for torch_geometric library (#85937)
Browse files Browse the repository at this point in the history
Fixes #65138

Pull Request resolved: #85937
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
  • Loading branch information
Thiago Crepaldi authored and pytorchmergebot committed Nov 1, 2022
1 parent 5c3666c commit fdc4197
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Expand Up @@ -12575,6 +12575,59 @@ def forward(self, x):
x,
)

@skipScriptTest()
@skipIfUnsupportedMinOpsetVersion(16)
@unittest.skipIf(
not torch.hub._check_module_exists("torch_geometric"),
"torch_geometric not installed.",
)
def test_sage_conv(self):
from torch_geometric import nn as torch_geometric_nn

# Input
coords0 = torch.randn(1, 6)
coords1 = torch.randn(1, 6)
coords = torch.transpose(torch.cat((coords0, coords1), dim=0), 0, 1)
adj = torch_geometric_nn.knn_graph(coords, k=2, batch=None, loop=True)
edge_from = adj[0:1, :]
edge_to = adj[1:, :]
inputs = (coords0, coords1, edge_from, edge_to)

class MySAGEConv(torch.nn.Module):
def __init__(self):
super().__init__()
self.SAGEConvBlock1 = torch_geometric_nn.SAGEConv(
2, 512, normalize=True
)
self.bano1 = torch_geometric_nn.BatchNorm(512)
self.relu = torch.nn.ReLU()
self.dense1 = torch.nn.Seq(Lin(512, 1))
self.sigmoid = torch.nn.Sigmoid()

def forward(self, coords0, coords1, edge_from, edge_to):
adj = torch.cat((edge_from, edge_to), dim=0)
gra = torch.transpose(torch.cat((coords0, coords1), dim=0), 0, 1)
x1 = self.SAGEConvBlock1(gra, edge_index=adj)
x = torch.unsqueeze(torch.sum(x1), dim=0)
return x

input_names = ["coords0", "coords1", "edge_from", "edge_to"]
output_names = ["outputs"]
dynamic_axes = {
"coords0": {0: "batch_size", 1: "features"},
"coords1": {0: "batch_size", 1: "features"},
"edge_from": {0: "batch_size", 1: "features"},
"edge_to": {0: "batch_size", 1: "features"},
"outputs": {0: "batch_size"},
}
self.run_test(
MySAGEConv(),
inputs,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
)

# Cannot export with older opsets because of "ConstantFill" op
# ConstantFill was a temp op removed at opset 8. This is no longer supported by onnxruntime
# There are still some issues prevent us from enabling script test for these scenarios:
Expand Down

0 comments on commit fdc4197

Please sign in to comment.