
torch.clamp issue due to update of pytorch 1.12.0 #237

Open
CryptoSalamander opened this issue Jul 14, 2022 · 2 comments

Comments


CryptoSalamander commented Jul 14, 2022

Since torch.clamp was updated in PyTorch 1.12.0 (the latest version), its min and max tensor arguments must be on the same device as the input tensor. See pytorch/pytorch#77035.

I got an error with PyTorch 1.12.0 in this line,

logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()

Error:

backbones/SwinV2.py:153, in WindowAttention.forward(self, x, mask)
    151 # cosine attention
    152 attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
--> 153 logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
    154 attn = attn * logit_scale
    156 relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument max in method wrapper_clamp_Tensor)

In 1.11.0 this line worked without problems because torch.clamp did not perform argument type promotion before 1.12.0,
but now I guess it should be fixed.
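A minimal sketch of the problem and two device-safe workarounds (these are assumed fixes, not the repository's official patch; `logit_scale` stands in for `self.logit_scale`):

```python
import math

import torch

# Stand-in for self.logit_scale in WindowAttention.
logit_scale = torch.nn.Parameter(torch.zeros(1))

# Fails on torch >= 1.12 when logit_scale lives on cuda:0, because the
# max tensor is created on the CPU:
#   torch.clamp(logit_scale, max=torch.log(torch.tensor(1. / 0.01)))

# Option 1: pass a plain Python float, which carries no device at all.
clamped_scalar = torch.clamp(logit_scale, max=math.log(1. / 0.01)).exp()

# Option 2: create the max tensor on the input's device explicitly.
clamped_tensor = torch.clamp(
    logit_scale,
    max=torch.log(torch.tensor(1. / 0.01, device=logit_scale.device)),
).exp()
```

Both variants produce the same result; option 1 is the smaller change since the clamp bound is a constant anyway.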

CryptoSalamander added a commit to CryptoSalamander/Swin-Transformer that referenced this issue Jul 14, 2022

jaehyunnn commented Aug 3, 2022

I have simply solved it as follows:

logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.get_device())).exp()
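A variant of the same one-line fix (an assumption on my part, not a patch from this repo): using the `Tensor.device` attribute instead of `get_device()` works uniformly for CPU and CUDA tensors, whereas `get_device()` is primarily meant for CUDA tensors. With `logit_scale` standing in for `self.logit_scale`:

```python
import torch

# Stand-in for self.logit_scale; on GPU this would be a cuda tensor.
logit_scale = torch.nn.Parameter(torch.zeros(1))

# Move the clamp bound onto whatever device the input is on.
max_val = torch.log(torch.tensor(1. / 0.01)).to(logit_scale.device)
logit_scale_clamped = torch.clamp(logit_scale, max=max_val).exp()
```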


haraldger commented Dec 17, 2022

I solved it as follows:

[screenshots of the fix; images not preserved]

A fix to this problem would be very useful.

CryptoSalamander added a commit to CryptoSalamander/Swin-Transformer that referenced this issue Apr 11, 2023