
return_attention_weights when set to False returns attention weights in GATv2Conv #9319

Open

kar655 opened this issue May 13, 2024 · 1 comment

kar655 commented May 13, 2024

🐛 Describe the bug

When the optional argument return_attention_weights in GATv2Conv's forward is set to False, the layer still returns the attention weights. I believe the code only checks whether the parameter is not None (i.e. its type), not its actual value:

        if isinstance(return_attention_weights, bool):
            if isinstance(edge_index, Tensor):
                if is_torch_sparse_tensor(edge_index):
                    # TODO TorchScript requires to return a tuple
                    adj = set_sparse_value(edge_index, alpha)
                    return out, (adj, alpha)
                else:
                    return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out
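For reference, here is a minimal reproduction sketch (the layer sizes and the toy graph are arbitrary illustration):

    import torch
    from torch_geometric.nn import GATv2Conv

    conv = GATv2Conv(in_channels=8, out_channels=16, heads=2)
    x = torch.randn(4, 8)                       # 4 nodes, 8 features each
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 0, 3, 2]])   # 4 directed edges

    # Expected: only the output tensor. Actual: a tuple (out, (edge_index, alpha)),
    # because only the *type* of return_attention_weights is checked, not its value.
    result = conv(x, edge_index, return_attention_weights=False)
    print(type(result))  # <class 'tuple'>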

My guess is that the fix would be to change
return_attention_weights: Optional[bool] = None to just return_attention_weights: bool = False and adjust the if accordingly, as sketched below.
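A rough sketch of what that adjusted check could look like (untested, and it may not be feasible under TorchScript, see the maintainer reply below):

    def forward(self, x, edge_index, edge_attr=None,
                return_attention_weights: bool = False):
        ...
        # Check the parameter's value rather than its type:
        if return_attention_weights:
            if isinstance(edge_index, Tensor):
                if is_torch_sparse_tensor(edge_index):
                    adj = set_sparse_value(edge_index, alpha)
                    return out, (adj, alpha)
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        return out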

Running rg "isinstance(.*, bool)" -C 5 in the repo suggests that other files might have the same issue:

  • torch_geometric/nn/conv/gat_conv.py
  • torch_geometric/nn/conv/fa_conv.py
  • torch_geometric/nn/conv/gatv2_conv.py
  • torch_geometric/nn/models/mlp.py
  • torch_geometric/nn/conv/rgat_conv.py

Let me know if this is a bug or if I'm misunderstanding something.
Thanks

Versions

I'm using torch-geometric 2.4.0, but the same code is still present in the latest version of the repository.

kar655 added the bug label May 13, 2024

rusty1s (Member) commented May 22, 2024

Yes, this is an expected hack due to TorchScript. We cannot yet change the return type based on the value of a boolean argument, so we opted to condition the return type on whether return_attention_weights is None vs. a bool.
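So with the current API, the intended usage is to leave the argument at its default None when the weights are not needed, and to pass a bool (conventionally True) when they are, e.g. (reusing conv, x, and edge_index from the reproduction snippet above):

    out = conv(x, edge_index)                               # returns only the output tensor
    out, (ei, alpha) = conv(x, edge_index,
                            return_attention_weights=True)  # also returns the attention weights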
