
Equivalent Implementations to ScatterElements operator #6022

Open
fabecode opened this issue Mar 14, 2024 · 3 comments
Labels
question Questions about ONNX

Comments

@fabecode

Question

I am implementing a GNN model and compiling it with Concrete ML, which converts the model to ONNX under the hood (ONNX 1.13.1). However, Concrete ML does not implement the ScatterElements ONNX operator yet, so compilation fails with `ValueError: The following ONNX operators are required to convert the torch model to numpy but are not currently implemented: ScatterElements`. To solve this issue, I have the following questions:

  1. Based on the function and graph below, is the ScatterElements node more likely to come from
  • the `x = x[edge_index.T].reshape(-1, 2 * self.n_hidden).relu()` line, or
  • the `x = torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1)` line?
  2. Do you have any ideas on how to replace the line or function that uses the ScatterElements operator with an alternative that only uses ONNX operators supported by Concrete ML?
  3. Any suggestions for alternative workarounds (e.g. a simplified implementation of ScatterElements)?
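For question 3, a small NumPy reference of what ScatterElements computes can help when designing a replacement. This is a minimal sketch based on my reading of the ONNX operator spec: it covers `axis` and the `reduction` values "none", "add" and "mul", and omits the "max"/"min" reductions added in later opsets. The function name `scatter_elements` is hypothetical; it is meant for checking semantics, not as a drop-in Concrete ML operator.

```python
import numpy as np


def scatter_elements(data, indices, updates, axis=0, reduction="none"):
    """Hypothetical NumPy reference for ONNX ScatterElements (sketch)."""
    out = np.array(data, copy=True)
    indices = np.asarray(indices)
    updates = np.asarray(updates, dtype=out.dtype)
    axis = axis % out.ndim
    # Start from the coordinates of every element of `indices` ...
    coords = list(np.indices(indices.shape))
    # ... and replace the coordinate along `axis` with the index value
    # (negative indices count from the end, as in ONNX)
    coords[axis] = np.where(indices < 0, indices + out.shape[axis], indices)
    coords = tuple(coords)
    if reduction == "none":
        # ONNX leaves duplicate targets undefined here; NumPy keeps one of the writes
        out[coords] = updates
    elif reduction == "add":
        np.add.at(out, coords, updates)  # duplicate targets accumulate
    elif reduction == "mul":
        np.multiply.at(out, coords, updates)
    else:
        raise ValueError(f"unsupported reduction: {reduction}")
    return out


# With axis=0, updates[i][j] is written to out[indices[i][j]][j]
data = np.zeros((3, 3), dtype=np.float32)
indices = np.array([[1, 0, 2], [0, 2, 1]])
updates = np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]], dtype=np.float32)
print(scatter_elements(data, indices, updates, axis=0))
```

With `reduction="add"` and zero-initialized `data`, this is the scatter-add pattern used for sum aggregation in message-passing layers.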

Below is the forward pass function of the GNN.

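```python
import torch
import torch.nn.functional as F


def forward(self, x, edge_index, edge_attr):
    # edge_index has shape [2, num_edges]: row 0 = source nodes, row 1 = destination nodes
    src, dst = edge_index

    # Embed node and edge features
    x = self.node_emb(x)
    edge_attr = self.edge_emb(edge_attr)

    for i in range(self.num_gnn_layers):
        # Message passing + batch norm, averaged with a residual connection
        x = (x + F.relu(self.batch_norms[i](self.convs[i](x, edge_index, edge_attr)))) / 2
        if self.edge_updates:
            # Optional residual edge update from both endpoint embeddings and the edge features
            edge_attr = edge_attr + self.emlps[i](torch.cat([x[src], x[dst], edge_attr], dim=-1)) / 2

    # Per-edge representation: source and destination node embeddings side by side
    x = x[edge_index.T].reshape(-1, 2 * self.n_hidden).relu()
    # Append the edge features
    x = torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1)
    out = x

    return self.mlp(out)
```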
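Regarding question 1: as far as I can tell, neither candidate line writes into a tensor. `x[edge_index.T]` only reads rows of `x` (a gather) and `torch.cat` is a concatenation, so I would expect them to export as Gather/Reshape and Concat rather than ScatterElements. The sketch below uses NumPy as a stand-in for PyTorch (integer-array indexing behaves the same way here; the sizes are made up) to check that the reshape line is just the source and destination embeddings concatenated per edge.

```python
import numpy as np

# Toy sizes (hypothetical): 4 nodes, hidden size 3, 3 edges
n_nodes, n_hidden = 4, 3
x = np.arange(n_nodes * n_hidden, dtype=np.float32).reshape(n_nodes, n_hidden)
edge_index = np.array([[0, 1, 3],   # source node of each edge
                       [2, 2, 0]])  # destination node of each edge
src, dst = edge_index

# Same indexing as in forward(): (E, 2) index array -> (E, 2, n_hidden) -> (E, 2 * n_hidden)
paired = x[edge_index.T].reshape(-1, 2 * n_hidden)

# Equivalent explicit gather + concat
concat = np.concatenate([x[src], x[dst]], axis=1)

print(paired.shape, np.array_equal(paired, concat))  # (3, 6) True
```

This only checks the indexing semantics; the exact ONNX ops emitted still depend on the exporter version and opset.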

Below is a snippet of the ONNX graph showing one of the ScatterElements instances.
[Screenshot: ONNX graph snippet around a ScatterElements node]

Thank you!

@fabecode fabecode added the question Questions about ONNX label Mar 14, 2024
@justinchuby
Contributor

justinchuby commented Mar 14, 2024

It looks like it comes from within `self.emlps[i]`, because the node appears after the Concat and before the Add.

@fabecode
Author

Thank you for the suggestion! However, I did not enable `self.edge_updates` when running the code, so that line was never executed. To confirm, I removed that line entirely, and the ONNX graph remained the same.
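With the edge-update branch ruled out, one remaining candidate is the neighbourhood aggregation inside `self.convs[i]`. Assuming these are PyTorch Geometric message-passing layers with sum aggregation (the conv type is not shown in this thread), that aggregation is a scatter-add of per-edge messages into their destination nodes. A sum of this kind can be rewritten as a dense matrix product with a 0/1 incidence matrix, which only needs MatMul; whether this fits Concrete ML's supported operators is an assumption to verify. The function names below are hypothetical, and NumPy again stands in for PyTorch.

```python
import numpy as np


def scatter_add_aggregate(messages, dst, num_nodes):
    """Reference: sum each edge's message into its destination node (scatter-add)."""
    out = np.zeros((num_nodes, messages.shape[1]), dtype=messages.dtype)
    np.add.at(out, dst, messages)  # duplicate destinations accumulate
    return out


def incidence_matrix(dst, num_nodes, dtype=np.float32):
    """A[v, e] = 1 if edge e points to node v, else 0 (shape: num_nodes x num_edges)."""
    return (np.arange(num_nodes)[:, None] == dst[None, :]).astype(dtype)


def matmul_aggregate(messages, dst, num_nodes):
    """Same result as scatter_add_aggregate, expressed as a single matrix product."""
    return incidence_matrix(dst, num_nodes, messages.dtype) @ messages


# Toy example (hypothetical sizes): 3 nodes, 4 edges, 2 features per message
dst = np.array([0, 2, 0, 1])
messages = np.arange(8, dtype=np.float32).reshape(4, 2)
print(np.allclose(scatter_add_aggregate(messages, dst, 3),
                  matmul_aggregate(messages, dst, 3)))  # True
```

In a real model, the incidence matrix could be built outside the network from `edge_index` and passed in as an extra input, so this step would export as a MatMul only. The trade-off is a dense num_nodes × num_edges matrix, which is only practical for small graphs.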

Below are screenshots of the full ONNX graph:
[Screenshots: onnx1, onnx2, onnx3]

I would appreciate any further suggestions!

@justinchuby
Contributor

If you click on the nodes, there should be annotations showing the source code location.
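To narrow this down without clicking through every node, a short script can list each ScatterElements node together with its neighbours. This is a sketch assuming an `onnx.ModelProto` (for example from `onnx.load("model.onnx")`, where the path is a placeholder); `find_nodes` is a hypothetical helper that only reads the `graph.node`, `op_type`, `name`, `input`, `output` and `doc_string` fields. Whether source locations end up in `doc_string` depends on the exporter version and settings.

```python
from typing import Any, Dict, List


def find_nodes(model: Any, op_type: str) -> List[Dict[str, Any]]:
    """Hypothetical helper: list nodes of `op_type` with their graph neighbours.

    `model` is expected to look like an onnx.ModelProto (model.graph.node).
    """
    nodes = list(model.graph.node)
    # Tensor name -> op type of the node that produces it
    producer = {name: n.op_type for n in nodes for name in n.output}
    # Tensor name -> op types of the nodes that consume it
    consumers: Dict[str, List[str]] = {}
    for n in nodes:
        for name in n.input:
            consumers.setdefault(name, []).append(n.op_type)

    found = []
    for n in nodes:
        if n.op_type != op_type:
            continue
        found.append({
            "name": n.name,
            # Graph inputs and initializers have no producing node
            "input_producers": [producer.get(i, "<graph input/initializer>") for i in n.input],
            "output_consumers": [c for o in n.output for c in consumers.get(o, [])],
            "doc_string": getattr(n, "doc_string", ""),
        })
    return found
```

For example, `for info in find_nodes(onnx.load("model.onnx"), "ScatterElements"): print(info)`. A ScatterElements node fed by a Concat and consumed by an Add would match the pattern mentioned earlier in this thread.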
