
SSIM has values larger than 1 #2327

Open
TuanDTr opened this issue Jan 29, 2024 · 8 comments
Labels: bug / fix, help wanted, v1.3.x

Comments

@TuanDTr

TuanDTr commented Jan 29, 2024

🐛 Bug

It seems like SSIM can take values larger than 1 when computed over an epoch. I cannot reproduce this error directly; I only observe it in TensorBoard after training.
(TensorBoard screenshot showing SSIM values above 1)

Environment

  • Python 3.9.13
  • PyTorch 2.0.1
  • TorchMetrics 1.3.0

Additional context

I am comparing a synthetic image against a target image. The target image has values between 0 and 1, while the synthetic image has values between -1 and 1. The data_range is set to None. I know the ranges of the synthetic and target images differ, but shouldn't SSIM still have a maximum value of 1?
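
In case it helps, here is a minimal stand-in for the setup with random tensors in the same ranges (shapes are only illustrative, and this snippet alone does not reproduce the problem for me):

import torch
from torchmetrics.image import StructuralSimilarityIndexMeasure

# Stand-ins matching my data ranges: target in [0, 1], synthetic output in [-1, 1].
target = torch.rand(2, 1, 32, 32, 32)
synthetic = torch.rand(2, 1, 32, 32, 32) * 2 - 1

ssim = StructuralSimilarityIndexMeasure(data_range=None)
print(ssim(synthetic, target))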

@TuanDTr added the bug / fix and help wanted labels Jan 29, 2024

Hi! Thanks for your contribution, great first issue!

@SkafteNicki
Member

Hi @TuanDTr, thanks for reporting this issue.
I just tried different input tensors based on your description and I cannot reproduce the behavior either. I am pretty sure that the metric should not be able to output values larger than 1.
Would it be possible to share how you initialize the metric? Alternatively, what does the training step where you log the metric look like?

@TuanDTr
Author

TuanDTr commented Jan 29, 2024

@SkafteNicki Thank you for your quick response. Please find below the forward, training, and validation step methods where I initialize the metrics. Basically, the metric is initialized in on_validation_model_eval, updated per step, and aggregated and reset in on_validation_epoch_end. I have used this setup for a while without any problem but only noticed this after moving from 2D to 3D diffusion models and evaluating 3D tensors.

    def training_step(self, batch: Union[Tuple, torch.Tensor], batch_idx: int) -> torch.Tensor:
        if self.use_profiler:
            self.profiler.step()
            
        x = batch["t1c"]
        z = self.get_latent_code(x)

        z_cond = []

        for m in self.hparams.cond_modality:
            x_cond = batch[m]
            z_cond.append(self.get_latent_code(x_cond))

        z_cond = torch.cat(z_cond, dim=1)

        noise = torch.randn_like(z).to(self.device)
        timesteps = torch.randint(0, self.hparams.num_train_steps, (z.shape[0], ), device=self.device).long()
        noisy_z = self._scheduler.add_noise(original_samples=z, noise=noise, timesteps=timesteps)
        noise_pred = self._unet(torch.cat((noisy_z, z_cond), dim=1), timesteps=timesteps)
        loss = self.criterion(noise_pred, noise)

        self.log("train/noise_recons_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    @torch.inference_mode()
    def forward(self, z_cond):
        self._ema.ema_model.eval()
        z_dim = [z_cond.shape[0], self._unet.out_channels, *z_cond.shape[2:]]
        z = torch.randn(z_dim, device=self.device)

        self._scheduler.set_timesteps(num_inference_steps=self.hparams.num_inference_steps)
        for t in range(self.hparams.num_inference_steps):
            model_output = self._ema.ema_model(
                torch.cat((z, z_cond), dim=1),
                timesteps=torch.Tensor((t,)).to(self.device).long()
            )
            z, _ = self._scheduler.step(model_output, t, z)

        x = self.decode_from_latent_code(z)

        return x, z

    def on_validation_model_eval(self) -> None:
        """Prepare before validation."""

        self.metrics = {
            "PSNR": PeakSignalNoiseRatio(data_range=None).to(self.device),
            "SSIM": StructuralSimilarityIndexMeasure(data_range=None).to(self.device),
            "MAE_image": MeanAbsoluteError().to(self.device),
            "MAE_latent": MeanAbsoluteError().to(self.device)
        }
        super().on_validation_model_eval()

    def validation_step(self, batch: Union[Tuple, torch.Tensor], batch_idx: int) -> None:
        x = batch["t1c"]
        if not self.hparams.preloaded_latent:
            z = self.get_latent_code(x)
        else:
            z = batch["latent_t1c"]

        z_cond = []

        for m in self.hparams.cond_modality:
            if not self.hparams.preloaded_latent:
                x_cond = batch[m]
                z_cond.append(self.get_latent_code(x_cond))
            else:
                z_cond.append(batch[f"latent_{m}"])

        z_cond = torch.cat(z_cond, dim=1)

        preds, latents = self.forward(z_cond)

        # Compute score
        self.metrics["PSNR"].update(preds, x)
        self.metrics["SSIM"].update(preds, x)
        self.metrics["MAE_image"].update(preds, x)
        self.metrics["MAE_latent"].update(latents, z)

        # Inverse transform
        inverse_transform = BatchInverseTransform(self.val_dataloader().dataset.transforms, self.val_dataloader())
        with allow_missing_keys_mode(self.val_dataloader().dataset.transforms):
            preds = inverse_transform({"latent_t1c": preds})
        
        self.save_to_h5_dataset(preds)

    def on_validation_epoch_end(self) -> None:
        psnr = self.metrics["PSNR"].compute()
        ssim = self.metrics["SSIM"].compute()
        mae_image = self.metrics["MAE_image"].compute()
        mae_latent = self.metrics["MAE_latent"].compute()

        self.log("val/psnr", psnr, on_epoch=True, logger=True, prog_bar=True)
        self.log("val/ssim", ssim, on_epoch=True, logger=True, prog_bar=True)
        self.log("val/mae_image", mae_image, on_epoch=True, logger=True, prog_bar=True)
        self.log("val/mae_latent", mae_latent, on_epoch=True, logger=True, prog_bar=True)

        self.metrics["PSNR"].reset()
        self.metrics["SSIM"].reset()
        self.metrics["MAE_image"].reset()
        self.metrics["MAE_latent"].reset()

@Borda added the v1.3.x label Jan 29, 2024
@SkafteNicki
Member

Hi @TuanDTr,
I tried to reproduce the error again, without success. The code you sent also looks fine; nothing stands out there.
My best guess is that the cause is the different scaling of your inputs: the underlying assumption in SSIM is that both inputs are scaled in the same way (see the sketch after the snippet below).
Otherwise, I would need the full metric state when the error happens:

def on_validation_epoch_end(self) -> None:
    ssim = self.metrics["SSIM"].compute()
    if ssim > 1:
        torch.save(self.metrics["SSIM"].metric_state, "ssim_state.pt")
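
And a sketch of what I mean by consistent scaling (tensor names and shapes are just illustrative): bring both inputs into the same range and pass that range explicitly instead of relying on data_range=None:

import torch
from torchmetrics.image import StructuralSimilarityIndexMeasure

# Illustrative tensors: synthetic output in [-1, 1], target in [0, 1].
synthetic = torch.rand(2, 1, 32, 32, 32) * 2 - 1
target = torch.rand(2, 1, 32, 32, 32)

# Rescale the synthetic image into [0, 1] and state the range explicitly.
preds = (synthetic + 1) / 2
ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
print(ssim(preds, target))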

@TuanDTr
Author

TuanDTr commented Feb 6, 2024

Hi @SkafteNicki, here is the state when the error happens:

{'similarity': metatensor(105.5943, device='cuda:0'),
 'total': tensor(101., device='cuda:0')}

I'll further scale all input tensors to the same range and see if this still occurs. I will follow up with you.

@SkafteNicki
Member

@TuanDTr I have been trying to further debug the issue on my end and I am still unable to reproduce the problem. From the output you sent it is clear that 105/101 > 1, but not how the accumulated similarity got to be higher than 101 in the first place.
Have you tried rescaling the ranges to see if that helps?
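
For reference, with the default reduction the epoch value should roughly fall out of the state you posted like this, which is why at least one per-image SSIM must have been above 1:

import torch

# Rough sketch of the aggregation: each update() adds the per-image SSIM values of the
# batch to `similarity` and the number of images to `total`; compute() returns the ratio.
similarity = torch.tensor(105.5943)
total = torch.tensor(101.0)
print(similarity / total)  # ~1.0455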

@TuanDTr
Author

TuanDTr commented Feb 14, 2024

Hi @SkafteNicki, I have tried rescaling the inputs to [0, 1] (see below), but I still encounter SSIM > 1. I am setting data_range to None, which I believe effectively sets the data range to 1, right? I will check the range of the saved outputs; hopefully that sheds light on something else. Sorry for the late response, I am working on other things at the moment, but I am still following up on this.

def forward(self, z_cond):
    ...
    return x.clamp(0, 1), z

@TuanDTr
Author

TuanDTr commented Apr 23, 2024

@SkafteNicki Hello, and sorry for the late update. I might have an idea why SSIM is larger than 1. I inspected my evaluation script and found that the tensors were in float16; if I change them to float32, I get the correct results. However, I cannot reproduce this issue outside my training script, so I think my mixed-precision training setup could have something to do with it. Do you have any idea how to inspect this further? Thanks!
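
Something along these lines (a sketch, inside validation_step from above) is what I mean by changing them to float32 before the metric update:

# Cast to float32 before updating the metrics so SSIM is not computed in half
# precision, even when the rest of the step runs under mixed precision.
self.metrics["SSIM"].update(preds.float(), x.float())
self.metrics["PSNR"].update(preds.float(), x.float())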
