I am implementing a GNN model and compiling it with Concrete ML, which converts the model to ONNX under the hood (ONNX 1.13.1). However, Concrete ML does not yet implement the ONNX ScatterElements operator, so compilation fails with:
ValueError: The following ONNX operators are required to convert the torch model to numpy but are not currently implemented: ScatterElements.
To solve this issue, I have the following questions:
1. Based on the function and graph below, is ScatterElements more likely to come from the line
   x = x[edge_index.T].reshape(-1, 2 * self.n_hidden).relu()
   or from the line
   x = torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1)?
2. Do you have any ideas on how to replace the line or function that uses ScatterElements with an alternative that only uses ONNX operators supported by Concrete ML?
3. Do you have suggestions for other workarounds (e.g. a simplified implementation of ScatterElements)?
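For question 3, the semantics of ScatterElements are simple enough to write down in NumPy. The sketch below covers only the 2-D, axis=0 case with reduction "none" (the default) or "add" (available from opset 16), following the ONNX operator specification. The function name scatter_elements_axis0 is hypothetical and this is not part of Concrete ML. It is mainly useful as a reference for checking that a rewritten model gives the same result, since plain NumPy code like this does not end up in the exported ONNX graph by itself.

```python
import numpy as np


def scatter_elements_axis0(data, indices, updates, reduction="none"):
    """Reference semantics of ONNX ScatterElements for 2-D inputs, axis=0 (sketch).

    For every (i, j):
        reduction="none": out[indices[i, j], j] = updates[i, j]
        reduction="add":  out[indices[i, j], j] += updates[i, j]
    """
    data = np.asarray(data)
    indices = np.asarray(indices)
    updates = np.asarray(updates)
    if indices.shape != updates.shape:
        raise ValueError("indices and updates must have the same shape")

    out = data.copy()  # ScatterElements never modifies its input in place
    # Column index j for every element of `indices` (axis=0 scatters along rows)
    cols = np.broadcast_to(np.arange(indices.shape[1]), indices.shape)

    if reduction == "none":
        # Duplicate target positions should be avoided with reduction="none"
        out[indices, cols] = updates
    elif reduction == "add":
        # np.add.at accumulates correctly when a target position repeats
        np.add.at(out, (indices, cols), updates)
    else:
        raise ValueError(f"unsupported reduction: {reduction}")
    return out
```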
Below is the forward pass function of the GNN.
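import torch
import torch.nn.functional as F

def forward(self, x, edge_index, edge_attr):
    src, dst = edge_index
    # Embed node and edge features
    x = self.node_emb(x)
    edge_attr = self.edge_emb(edge_attr)
    for i in range(self.num_gnn_layers):
        # Message passing + batch norm, averaged with the residual input
        x = (x + F.relu(self.batch_norms[i](self.convs[i](x, edge_index, edge_attr)))) / 2
        if self.edge_updates:
            edge_attr = edge_attr + self.emlps[i](torch.cat([x[src], x[dst], edge_attr], dim=-1)) / 2
    # Look up the (src, dst) node embeddings of every edge and flatten them per edge
    x = x[edge_index.T].reshape(-1, 2 * self.n_hidden).relu()
    # Append the edge features to each edge representation
    x = torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1)
    out = x
    return self.mlp(out)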
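Regarding questions 1 and 2: integer-tensor indexing such as x[edge_index.T] is usually exported by PyTorch as Gather, and torch.cat as Concat, so neither line is an obvious source of ScatterElements. A more likely source, assuming self.convs[i] are PyTorch Geometric message-passing layers (their definition is not shown here), is their sum aggregation, which is a scatter-add of edge messages into dst nodes. A common workaround is to express that sum as a product with a one-hot destination matrix, which needs only comparisons and a MatMul. The NumPy sketch below checks that the two formulations agree; both function names are hypothetical.

```python
import numpy as np


def aggregate_scatter_add(messages, dst, num_nodes):
    """Sum each edge message into its destination node (what scatter-add does)."""
    out = np.zeros((num_nodes, messages.shape[1]), dtype=messages.dtype)
    np.add.at(out, dst, messages)  # out[dst[e]] += messages[e] for every edge e
    return out


def aggregate_matmul(messages, dst, num_nodes):
    """The same sum aggregation written as a one-hot matrix product.

    A[n, e] = 1 if edge e points to node n, so (A @ messages)[n] is the sum
    of the messages of all edges ending at node n.
    """
    node_ids = np.arange(num_nodes)[:, None]                # shape (N, 1)
    A = (node_ids == dst[None, :]).astype(messages.dtype)   # shape (N, E)
    return A @ messages                                     # shape (N, H)
```

In PyTorch, the same matrix could be built with (torch.arange(num_nodes)[:, None] == dst[None, :]).float(). If the graph structure is fixed and public, precomputing it once (for example as a registered buffer) would leave only a MatMul in the exported graph, at the cost of an N x E dense matrix. How Concrete ML handles edge_index when it is an encrypted input has not been verified here.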
Below is a snippet of the ONNX graph showing one of the ScatterElements instances.
Thank you!
Thank you for the suggestion! However, self.edge_updates was disabled when I ran the code, so that line was never executed. To confirm, I removed the line entirely, and the ONNX graph stayed the same.