Add CW-SSIM support to torchmetrics #2428

Open
michael080808 opened this issue Mar 5, 2024 · 8 comments
Labels
enhancement New feature or request New metric
Milestone

Comments

@michael080808

michael080808 commented Mar 5, 2024

🚀 Feature

Support for Complex Wavelet Structural Similarity (CW-SSIM), built on either a Steerable Pyramid (SP) or the Dual-Tree Complex Wavelet Transform (DTCWT). Maybe also support all possible Q-shift and first-level filters?
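For reference, the CW-SSIM index is usually written as below. Here c_{x,i} and c_{y,i} are the complex coefficients of the two images at the same subband and location, N is the number of coefficients in the local window, * denotes complex conjugation, and K is a small positive stabilizing constant. This is the quantity that `_ssim` computes in the `main.py` code later in this thread, with the sums taken over a sliding window.

```math
\tilde{S}(\mathbf{c}_x, \mathbf{c}_y) = \frac{2\left|\sum_{i=1}^{N} c_{x,i}\, c_{y,i}^{*}\right| + K}{\sum_{i=1}^{N} \left|c_{x,i}\right|^{2} + \sum_{i=1}^{N} \left|c_{y,i}\right|^{2} + K}
```

The magnitude of the summed cross product does not change under a common phase rotation of the coefficients. Small translations show up roughly as such phase changes, which is the usual motivation for using complex coefficients here.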

Motivation

I noticed that someone mentioned this in #799:

Would you accept a PR for one not on the list? I have an implementation of complex wavelet structural similarity (CW-SSIM) that I could contribute. (https://ieeexplore.ieee.org/abstract/document/5109651)

For research purposes, I looked around and found a few projects with CW-SSIM code, but they have not been updated to recent PyTorch versions. Will torchmetrics add CW-SSIM support in torchmetrics.image? Here are some collections of older code.

I tried to implement this with the latest PyTorch version, which now supports complex convolution (important for my use case). It is very difficult for me to understand the math behind all the complex wavelet machinery, so I would appreciate any suggestions on the math. Please do not confuse the Complex Wavelet Transform with the Continuous Wavelet Transform; both are abbreviated CWT.

Pitch

Alternatives

I tried SciPy and PyWavelets, but neither supports SP or DTCWT; they only include the Discrete Wavelet Transform (DWT, with multi-dimensional support) and the Continuous Wavelet Transform. The listed projects are too old to run on the latest PyTorch.

Additional context

If there is further discussion of the math behind SP or DTCWT, I'll try to implement it myself and open a pull request to torchmetrics. I'm confused about the relationship between the scaling function and the wavelet function, and whether it should be considered in SP or DTCWT. How are the SP or DTCWT filters calculated? Sorry for my poor math regarding the conversion between the discrete and continuous domains.
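A small NumPy sketch of how the radial and angular parts of the steerable-pyramid filters in the `pyramid.py` code later in this thread are built (the helper names are hypothetical; `boundary` and `width` play the role of `boundary` and `transition_width`). The low-pass is defined as the power complement of the high-pass, so together they split the spectrum without losing energy. In that sense the low-pass roughly plays the role a scaling function plays in an ordinary wavelet transform. The angular constant (the same expression as `SteeringFilter.constant`) makes the squared orientation responses sum to one.

```python
import math

import numpy as np


def high_pass(r, boundary=1.0, width=1.0):
    # log2(0) is -inf at the DC bin; clipping maps it to the fully-stopped end.
    with np.errstate(divide="ignore"):
        diff = np.log2(r) - np.log2(boundary)
    # t is 0 at/above `boundary` and -1 at/below boundary / 2**width.
    t = np.clip(diff, -width, 0.0) / width
    return np.abs(np.cos(t * np.pi / 2))


def low_pass(r, boundary=1.0, width=1.0):
    # Power-complementary partner: low**2 + high**2 == 1 everywhere.
    return np.sqrt(np.clip(1.0 - high_pass(r, boundary, width) ** 2, 0.0, None))


def angular_gain(orientations):
    # Normalising constant for the cos**(orientations - 1) angular responses.
    n = orientations - 1
    return 2.0 ** (2 * n) * math.factorial(n) ** 2 / (orientations * math.factorial(2 * n))


def angular_power(theta, orientations):
    # Sum over all orientations of the squared angular responses.
    n = orientations - 1
    c = angular_gain(orientations)
    return sum(c * np.cos(theta - math.pi * k / orientations) ** (2 * n) for k in range(orientations))


r = np.linspace(0.0, 1.5, 301)
print(np.allclose(low_pass(r) ** 2 + high_pass(r) ** 2, 1.0))  # True
theta = np.linspace(-math.pi, math.pi, 181)
print(np.allclose(angular_power(theta, 4), 1.0))  # True
```

This sketch covers only the real (non-complex) angular responses. The half-plane mask used when `support_cplx=True` keeps only one side of the spectrum, so those filters no longer sum to one in this way.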

@michael080808 michael080808 added the enhancement New feature or request label Mar 5, 2024
@SkafteNicki SkafteNicki added this to the future milestone Mar 5, 2024
@SkafteNicki
Member

Hi @michael080808, thanks for opening this issue. We would be more than welcome to receive a pull request with this metric (either a partial or a full implementation), but I do not think anyone on the core team has the bandwidth or the experience to implement such a complex metric at the moment. If you can point me to a specific reference implementation, maybe I can give it a stab.

@michael080808
Author

michael080808 commented Mar 5, 2024

There are some implementations as mentioned.

@SkafteNicki
Member

@michael080808 could you point me to which one of these has the most promising reference implementation? In which specific file is the metric implemented?

@michael080808
Author

michael080808 commented Mar 5, 2024

The others are references for the DTCWT implementation. As I understand it, once the Steerable Pyramid or DTCWT has been implemented, CW-SSIM will be easy to code, because CW-SSIM just uses the SP or DTCWT coefficients to evaluate a modified SSIM definition. It may be very hard to implement DTCWT without some reference code. The following two are specifically about DTCWT implementation.
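To illustrate how small the metric itself is once the coefficients exist, here is a NumPy sketch of the CW-SSIM index for one set of complex coefficients. The function name is hypothetical. It sums over the whole array; the `CwSSIM` module posted later in this thread takes the same sums over a sliding window via `multidim_conv2d`.

```python
import numpy as np


def cw_ssim_index(cx, cy, k=0.0):
    """CW-SSIM between two arrays of complex coefficients from the same subband/location."""
    cx = np.asarray(cx, dtype=complex).ravel()
    cy = np.asarray(cy, dtype=complex).ravel()
    # Magnitude of the summed cross-correlation: unchanged by a common phase rotation.
    num = 2.0 * np.abs(np.sum(cx * np.conj(cy))) + k
    den = np.sum(np.abs(cx) ** 2) + np.sum(np.abs(cy) ** 2) + k
    return float(num / den)


rng = np.random.default_rng(0)
c = rng.standard_normal(64) + 1j * rng.standard_normal(64)
print(cw_ssim_index(c, c))                     # 1.0 up to rounding
print(cw_ssim_index(c, c * np.exp(1j * 0.3)))  # still 1.0 up to rounding: constant phase shift
print(cw_ssim_index(c, rng.standard_normal(64) + 1j * rng.standard_normal(64)))  # well below 1
```

With `k=0` the index is undefined where both windows have zero energy. That is one reason the literature uses a small positive K.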

I hope that this information is helpful. If there are any other questions, I would be pleased to answer them. 🙂❤

@SkafteNicki
Member

@michael080808 thanks for providing this overview, it really helps me.
So it seems to me that the best course of action is to take the implementation from https://github.com/dingkeyan93/IQA-optimization/blob/master/IQA_pytorch/CW_SSIM.py, make sure it works with newer versions of PyTorch, and then replace the backend to use https://github.com/LabForComputationalVision/plenoptic/blob/main/src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py instead. I think that is definitely doable, but I have to look more closely at the code and understand the metric a bit better.

@michael080808
Author

michael080808 commented Mar 7, 2024

I took a quick look at

and wrote a simple version of CW-SSIM. Below are two files that run with PyTorch 2.2. I changed the coordinate convention so that inputs with an odd width or height are handled correctly. Note that the code depends heavily on complex convolution support, which is a fairly new PyTorch feature, so that part deserves extra attention. I hope it helps with understanding the metric.
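To make the coordinate convention concrete, here is a NumPy restatement of `normalized_lin_spaces` from the code below (the helper name is hypothetical). Pixel centres are placed symmetrically around 0 and scaled by `length // 2`. An odd length therefore has a pixel exactly at 0 and reaches ±1, while an even length sits on half steps. This assumes `length >= 2`; for `length == 1` the divisor is zero.

```python
import numpy as np


def normalized_lin_space(length):
    # Pixel-centre coordinates, symmetric around 0, scaled by length // 2.
    n = np.arange(length)
    return (n - length / 2 + 0.5) / (length // 2)


print(normalized_lin_space(4).tolist())  # [-0.75, -0.25, 0.25, 0.75]
print(normalized_lin_space(5).tolist())  # [-1.0, -0.5, 0.0, 0.5, 1.0]
```

One consequence worth checking: `fftshift` places the DC bin at index `length // 2`. For odd lengths this convention maps that index to 0. For even lengths it maps it to a half step (0.25 for length 4), so the radial filters are not centred exactly on DC there.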

# pyramid.py

"""
Map pixel indices [0, Length - 1] onto a symmetric grid within [-1, 1].
I prefer to use the pixel center as the coordinate position.
                  +-----+-----+-----+
                  |     |     |     |
                  |  A  |  B  |  C  |
                  |     |     |     |
+-----+-----+     +-----+-----+-----+
|     |     |     |     |     |     |
|  A  |  B  |     |  D  |  O  |  E  |
|     |     |     |     |     |     |
+-----O-----+     +-----+-----+-----+
|     |     |     |     |     |     |
|  C  |  D  |     |  F  |  G  |  H  |
|     |     |     |     |     |     |
+-----+-----+     +-----+-----+-----+
Here, O is the coordinate origin.
With an even number of pixels, the coordinates of A, B, C, D fall on half-pixel values.
With an odd  number of pixels, the coordinates of A, B, C, D, E, F, G, H fall on whole-pixel values.
"""
# Requires Python 3.10+ (itertools.pairwise).
import functools
import itertools
import math
import operator
from abc import ABCMeta, abstractmethod
from typing import List, Tuple, Union

import torch.fft
from torch import Tensor
from torch.types import Device


class SteerablePyramid:
    class _Filter(metaclass=ABCMeta):
        @staticmethod
        def bound_convert_2_tuple(boundary: Union[float, Tuple[float], Tuple[float, float]]) -> Tuple[float, float]:
            if isinstance(boundary, float):
                boundary = (boundary,)
            if isinstance(boundary, tuple) and len(boundary) == 1:
                boundary = (boundary[0], boundary[0])
            return boundary[0], boundary[1]

        @staticmethod
        def normalized_lin_spaces(length: int, device: Device = None) -> Tensor:
            number = torch.arange(length, device=device)
            return (number - length / 2 + 0.5) / (length // 2)

        @staticmethod
        def normalized_coordinate(shapes: Tuple[int, int], device: Device = None) -> Tuple[Tensor, ...]:
            coords = [SteerablePyramid._Filter.normalized_lin_spaces(length, device) for length in reversed(shapes)]
            return torch.meshgrid(*list(reversed(coords)), indexing='ij')

        @staticmethod
        def polars(shapes: Tuple[int, int], device: Device = None) -> Tuple[Tensor, Tensor]:
            x, y = SteerablePyramid._Filter.normalized_coordinate(shapes, device)
            return torch.sqrt(x ** 2 + y ** 2), torch.arctan2(y, x)

        @staticmethod
        def angles(shapes: Tuple[int, int], device: Device = None) -> Tensor:
            x, y = SteerablePyramid._Filter.normalized_coordinate(shapes, device)
            return torch.arctan2(y, x)

        @staticmethod
        def radius(shapes: Tuple[int, int], device: Device = None) -> Tensor:
            x, y = SteerablePyramid._Filter.normalized_coordinate(shapes, device)
            return torch.sqrt(x ** 2 + y ** 2)

        @staticmethod
        def bounds(shapes: Tuple[int, int], boundary: Union[float, Tuple[float], Tuple[float, float]], device: Device = None) -> Tensor:
            boundary = SteerablePyramid._Filter.bound_convert_2_tuple(boundary)
            return (boundary[0] * boundary[1]) / torch.sqrt((boundary[0] * torch.cos(SteerablePyramid._Filter.angles(shapes, device))) ** 2 + (boundary[1] * torch.sin(SteerablePyramid._Filter.angles(shapes, device))) ** 2)

        @staticmethod
        def high_band_pass_filter(shapes: Tuple[int, int], boundary: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0, device: Device = None):
            diff = torch.log2(SteerablePyramid._Filter.radius(shapes, device)) - torch.log2(SteerablePyramid._Filter.bounds(shapes, boundary, device))
            return torch.abs(torch.cos((torch.clamp(diff, min=-transition_width, max=0) / transition_width) * (math.pi / 2)))

        @staticmethod
        def bass_band_pass_filter(shapes: Tuple[int, int], boundary: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0, device: Device = None):
            high = SteerablePyramid._Filter.high_band_pass_filter(shapes, boundary=boundary, transition_width=transition_width, device=device)
            return torch.sqrt(1 - high ** 2)

        @abstractmethod
        def __init__(self):
            super().__init__()

        @abstractmethod
        def __call__(self, shapes: Tuple[int, int], device: Device = None):
            raise NotImplementedError

    class BassPassFilter(_Filter):
        def __init__(self, boundary: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0):
            super().__init__()
            self.boundary = boundary
            self.transition_width = transition_width

        def __call__(self, shapes: Tuple[int, int], device: Device = None):
            return self.bass_band_pass_filter(shapes=shapes, boundary=self.boundary, transition_width=self.transition_width, device=device)

    class HighPassFilter(_Filter):
        def __init__(self, boundary: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0):
            super().__init__()
            self.boundary = boundary
            self.transition_width = transition_width

        def __call__(self, shapes: Tuple[int, int], device: Device = None):
            return self.high_band_pass_filter(shapes=shapes, boundary=self.boundary, transition_width=self.transition_width, device=device)

    class BandPassFilter(_Filter):
        def __init__(self, boundary_high: Union[float, Tuple[float], Tuple[float, float]] = 1.0, boundary_bass: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0):
            super().__init__()
            assert functools.reduce(operator.__and__, itertools.starmap(operator.ge, zip((boundary_high,) if isinstance(boundary_high, float) else boundary_high, (boundary_bass,) if isinstance(boundary_bass, float) else boundary_bass))), 'All elements from "boundary_high" must be greater than or equal to the corresponding elements in "boundary_bass".'
            self.boundary_bass, self.boundary_high, self.transition_width = boundary_bass, boundary_high, transition_width

        def __call__(self, shapes: Tuple[int, int], device: Device = None):
            return self.bass_band_pass_filter(shapes=shapes, boundary=self.boundary_high, transition_width=self.transition_width, device=device) * self.high_band_pass_filter(shapes=shapes, boundary=self.boundary_bass, transition_width=self.transition_width, device=device)

    class SteeringFilter(BandPassFilter):
        def __init__(self, boundary_bass: Union[float, Tuple[float], Tuple[float, float]] = 1.0, boundary_high: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0, index: int = 0, orientations: int = 2, support_cplx: bool = False):
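            # Arguments are passed positionally, so this class's "boundary_bass" becomes BandPassFilter's
            # "boundary_high" (outer radius) and this class's "boundary_high" becomes its "boundary_bass" (inner radius).
            # filter_iteration passes the larger radius as boundary_bass, so the resulting band is correct.
            super().__init__(boundary_bass, boundary_high, transition_width)
            assert index < orientations, '"index" must be less than "orientations".'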
            self.index, self.orientations, self.support_cplx = index, orientations, support_cplx

        def __call__(self, shapes: Tuple[int, int], device: Device = None):
            return super().__call__(shapes, device) * self.orientation_filter(shapes, self.support_cplx, device)

        @property
        def constant(self):
            order = self.orientations - 1
            return math.pow(2, (2 * order)) * math.pow(math.factorial(order), 2) / (self.orientations * math.factorial(2 * order))

        def orientation_filter(self, shapes: Tuple[int, int], u4cplx: bool = False, device: Device = None):
            angles = torch.remainder(math.pi + self.angles(shapes, device) - math.pi * self.index / self.orientations, 2 * math.pi) - math.pi
            return (torch.abs(math.sqrt(self.constant) * torch.pow(torch.cos(angles), self.orientations - 1))) * (torch.lt(torch.abs(angles), math.pi / 2) if u4cplx else 1)

    @staticmethod
    def to_freq_domain(x: Tensor) -> Tensor:
        assert x.dim() >= 2, 'Not enough dimensions to run "to_freq_domain" procedure.'
        return torch.fft.fftshift(torch.fft.fft2(x, dim=[-2, -1]), dim=[-2, -1])

    @staticmethod
    def to_time_domain(x: Tensor) -> Tensor:
        assert x.dim() >= 2, 'Not enough dimensions to run "to_time_domain" procedure.'
        return torch.fft.ifft2(torch.fft.ifftshift(x, dim=[-2, -1]), dim=[-2, -1])

    @staticmethod
    def to_crop_region(entire: Tuple[int, int], region: Tuple[int, int]) -> Tuple[List[int], ...]:
        assert functools.reduce(operator.__and__, itertools.starmap(operator.ge, zip(entire, region))) and functools.reduce(operator.__and__, itertools.starmap(lambda x, y: (x - y) % 2 == 0, zip(entire, region))), 'Each element of "entire" must be greater than or equal to the corresponding element of "region", with an even difference so the crop is symmetric.'
        return tuple([(shape - focal) // 2, focal, (shape - focal) // 2] for shape, focal in zip(entire, region))

    @staticmethod
    def to_crop_tensor(inputs: Tensor, region: Tuple[int, int]) -> Tensor:
        splits = SteerablePyramid.to_crop_region(entire=(inputs.shape[-2], inputs.shape[-1]), region=region)
        return torch.split(torch.split(inputs, splits[-1], dim=-1)[1], splits[-2], dim=-2)[1]

    @staticmethod
    def to_join_tensor(fronts: Tensor, backed: Tensor) -> Tensor:
        assert fronts.dim() == backed.dim() >= 2 and fronts.shape[-1] < backed.shape[-1] and fronts.shape[-2] < backed.shape[-2] and fronts.shape[:-2] == backed.shape[:-2], 'Unable to join two tensors into one due to the shape mismatch.'
        return torch.nn.functional.pad(fronts, [(backed.shape[-1] - fronts.shape[-1]) // 2] * 2 + [(backed.shape[-2] - fronts.shape[-2]) // 2] * 2, mode='constant', value=0) + backed * torch.nn.functional.pad(torch.zeros_like(fronts), [(backed.shape[-1] - fronts.shape[-1]) // 2] * 2 + [(backed.shape[-2] - fronts.shape[-2]) // 2] * 2, mode='constant', value=1)

    def __init__(self, group_levels: int = 6, orientations: int = 16, support_cplx: bool = True, transition_w: float = 1.0):
        super().__init__()
        self.group_levels = group_levels
        self.orientations = orientations
        self.support_cplx = support_cplx
        self.transition_w = transition_w

    def region_iteration(self, shapes: Tuple[int, int]):
        last = shapes
        yield last
        for i in range(self.group_levels):
            last = (last[0] - (last[0] // 4) * 2, last[1] - (last[1] // 4) * 2)
            yield last

    def factor_iteration(self, shapes: Tuple[int, int]):
        last = shapes
        yield tuple(itertools.starmap(operator.truediv, zip(last, shapes)))
        for _ in range(self.group_levels):
            last = (last[0] - (last[0] // 4) * 2, last[1] - (last[1] // 4) * 2)
            yield tuple(itertools.starmap(operator.truediv, zip(last, shapes)))

    def filter_iteration(self, shapes: Tuple[int, int], device: Device = None):
        iteration = zip(itertools.pairwise(self.factor_iteration(shapes)), self.region_iteration(shapes))
        for level, ((prev_f, curr_f), region) in enumerate(iteration):
            if level == 0:
                yield self.to_crop_tensor(self.HighPassFilter(boundary=prev_f, transition_width=self.transition_w)(shapes, device), region), f'H{level}'
                yield self.to_crop_tensor(self.BassPassFilter(boundary=prev_f, transition_width=self.transition_w)(shapes, device), region), f'L{level}'
            for orientation in range(self.orientations):
                yield self.to_crop_tensor(self.SteeringFilter(boundary_bass=prev_f, boundary_high=curr_f, transition_width=self.transition_w, index=orientation, orientations=self.orientations, support_cplx=self.support_cplx)(shapes, device), region), f'B{level + 1}o{orientation}'
            yield self.to_crop_tensor(self.BassPassFilter(boundary=curr_f, transition_width=self.transition_w)(shapes, device), region), f'L{level + 1}'

    def encode_iteration(self, tensor: Tensor):
        shapes = (tensor.shape[-2], tensor.shape[-1])
        target, window = self.to_freq_domain(tensor), None
        it_filter = self.filter_iteration(shapes, tensor.device)

        # L0 HighPass Output
        window = next(it_filter)
        time_domain = self.to_time_domain(target * window[0])
        yield time_domain if self.support_cplx else torch.real(time_domain), window[1]

        # L0 BassPass Remove
        window = next(it_filter)
        target = target * window[0]
        # yield time_domain if self.support_cplx else torch.real(time_domain), window[1] <- Removed due to definition.

        # Each Level Steering BandPass
        for level, (curr_r, next_r) in enumerate(itertools.pairwise(self.region_iteration(shapes))):
            for orientation in range(self.orientations):
                window = next(it_filter)
                time_domain = self.to_time_domain(target * window[0])
                yield time_domain if self.support_cplx else torch.real(time_domain), window[1]
            window = next(it_filter)
            target = self.to_crop_tensor(target * window[0], next_r)

        # Final BassPass
        time_domain = self.to_time_domain(target)
        yield time_domain if self.support_cplx else torch.real(time_domain), window[1]
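The shrink rule in `region_iteration` and `factor_iteration` removes `2 * (n // 4)` samples per level. The plain-Python restatement below uses a hypothetical helper name. Because the removed count is always even, each level keeps the parity that `to_crop_region` asserts for a symmetric crop, while the size roughly halves. Very small odd sizes stop shrinking at 3. The 512×512 astronaut image ends at 8×8 after the default 6 levels.

```python
def region_sizes(n, levels):
    # Same rule as region_iteration: drop 2 * (n // 4) samples per level.
    sizes = [n]
    for _ in range(levels):
        n = n - (n // 4) * 2
        sizes.append(n)
    return sizes


print(region_sizes(512, 6))  # [512, 256, 128, 64, 32, 16, 8]
print(region_sizes(513, 6))  # [513, 257, 129, 65, 33, 17, 9]
print(region_sizes(7, 3))    # [7, 5, 3, 3]
```

This matters for `group_levels`. With the default 7×7 window, `multidim_conv2d` pads by 3 with `mode='reflect'`, and PyTorch requires reflect padding to be smaller than the padded dimension. The coarsest subbands should therefore stay larger than 3 pixels.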
# main.py

from typing import Tuple

import skimage.transform
import skimage.util
import torch
from skimage.data import astronaut
from torch import Tensor
from torch.nn import Module

from pyramid import SteerablePyramid


class CwSSIM(Module):
    result_pyramid = SteerablePyramid()
    ground_pyramid = SteerablePyramid()

    def __init__(self, kernel: int = 7, k: float = 0, levels: int = 6, orientations: int = 16, transition_w: float = 1.0):
        super().__init__()
        self.k = k
        self.kernel = torch.ones([kernel] * 2)[None, None, ...]
        self.result_pyramid = SteerablePyramid(group_levels=levels, orientations=orientations, transition_w=transition_w, support_cplx=True)
        self.ground_pyramid = SteerablePyramid(group_levels=levels, orientations=orientations, transition_w=transition_w, support_cplx=True)

    def multidim_conv2d(self, inputs: Tensor, *args, **kwargs) -> Tensor:
        if inputs.dim() <= 1:
            raise ValueError('One Dimensional Input is not supported.')
        channels = inputs.shape[-3] if inputs.dim() >= 3 else 1
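        # F.pad pads the last dim first: (left, right, top, bottom).
        paddings = [self.kernel.size(dim) // 2 for dim in [-1, -2] for _ in range(2)]
        # conv2d(input, weight, bias, stride, padding, dilation, groups): groups is args[4] when passed positionally,
        # so only set the keyword default when it was not given positionally (avoids a "multiple values" error).
        if len(args) >= 5:
            groups = args[4]
        else:
            groups = kwargs.setdefault('groups', channels)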
        shapes = inputs.shape
        kernel = self.kernel.repeat(1, 1, 1, 1) if inputs.dim() == 2 else self.kernel.repeat(channels, channels // groups, 1, 1)

        if inputs.dim() >= 5:
            return torch.nn.functional.conv2d(torch.nn.functional.pad(inputs.flatten(0, -4), paddings, mode='reflect'), kernel.to(device=inputs.device, dtype=inputs.dtype), *args, **kwargs).unflatten(0, shapes[:-3])
        if 2 <= inputs.dim() <= 4:
            return torch.nn.functional.conv2d(torch.nn.functional.pad(inputs.repeat(1, 1, 1, 1), paddings, mode='reflect'), kernel.to(device=inputs.device, dtype=inputs.dtype), *args, **kwargs).squeeze(tuple(range(0, 4 - inputs.dim())))

    def statistics(self, result: Tensor, ground: Tensor, *args, **kwargs) -> Tuple[Tensor, ...]:
        conj_prods = self.multidim_conv2d(result * torch.conj(ground), *args, **kwargs)
        sum_mod_sq = self.multidim_conv2d(torch.abs(result) ** 2 + torch.abs(ground) ** 2, *args, **kwargs)
        return conj_prods, sum_mod_sq

    def forward(self, result: Tensor, ground: Tensor, *args, **kwargs) -> Tensor:
        assert result.shape == ground.shape
        # Use the pyramids configured in __init__; the class-level ones ignore the constructor arguments.
        result_encode_iter = self.result_pyramid.encode_iteration(result)
        ground_encode_iter = self.ground_pyramid.encode_iteration(ground)

        count, summarized = 0, 0.0  # plain Python scalars avoid a CPU/GPU device mismatch
        for (result_encode, _), (ground_encode, _) in zip(result_encode_iter, ground_encode_iter):
            conj_prods, sum_mod_sq = self.statistics(result_encode, ground_encode, *args, **kwargs)
            _ssim = (2 * torch.abs(conj_prods) + self.k) / (sum_mod_sq + self.k)
            count, summarized = count + 1, summarized + torch.mean(_ssim, dim=[-2, -1], keepdim=True)
        return summarized / count


ssim = CwSSIM()
if __name__ == '__main__':
    image = skimage.util.img_as_float32(astronaut())
    noise = skimage.util.random_noise(image, mode='speckle')
    print(ssim(torch.tensor(image).movedim(-1, 0), torch.tensor(noise).movedim(-1, 0)))

    image = skimage.util.img_as_float32(astronaut())
    noise = 0.8 * image
    print(ssim(torch.tensor(image).movedim(-1, 0), torch.tensor(noise).movedim(-1, 0)))

    image = skimage.util.img_as_float32(astronaut())
    noise = skimage.transform.rotate(image, 1)
    print(ssim(torch.tensor(image).movedim(-1, 0), torch.tensor(noise).movedim(-1, 0)))
# Output of main.py (one value per RGB channel): speckle noise, 0.8 * image, 1-degree rotation
tensor([[[0.8367]], [[0.8667]], [[0.8853]]], dtype=torch.float64)
tensor([[[0.9756]], [[0.9756]], [[0.9756]]])
tensor([[[0.8905]], [[0.8921]], [[0.8893]]])
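The second result can be checked by hand. The pyramid is linear, so every coefficient of `0.8 * image` is 0.8 times the corresponding coefficient of `image`. With `k = 0`, each windowed ratio then reduces to 2a / (1 + a²) with a = 0.8, independent of the image content. That explains why all three channels give the same value:

```python
a = 0.8
# With K = 0 and c_y = a * c_x: 2a * sum|c|^2 / ((1 + a^2) * sum|c|^2) = 2a / (1 + a^2).
expected = 2 * a / (1 + a ** 2)
print(round(expected, 4))  # 0.9756
```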

@SkafteNicki
Member

@michael080808 is what you posted a full implementation? It does not seem to rely on any third-party packages.

@michael080808
Author

michael080808 commented Mar 8, 2024

It's a fairly complete implementation. I did not write the update and compute procedures that torchmetrics metrics need; I just rewrote the steerable pyramid method to meet my requirements. So it does not rely on any other packages (scikit-image is only used for the demo in main.py).
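torchmetrics metrics split the work into `update` (accumulate state for each batch) and `compute` (reduce the accumulated state). Below is a plain-Python sketch of that split for averaging per-image CW-SSIM scores. The class name is hypothetical, and the sketch does not use the torchmetrics API. A real port would presumably subclass `torchmetrics.Metric` and register `total` and `count` with `add_state` so they can be synchronized across processes.

```python
class RunningMeanScore:
    """Hypothetical accumulator: sums per-image scores in update(), averages in compute()."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, scores):
        # `scores` is an iterable of per-image CW-SSIM values for one batch.
        for s in scores:
            self.total += float(s)
            self.count += 1

    def compute(self):
        if self.count == 0:
            raise ValueError("compute() called before any update()")
        return self.total / self.count


metric = RunningMeanScore()
metric.update([0.5, 0.25])  # first batch
metric.update([0.75])       # second batch
print(metric.compute())     # 0.5
```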

Projects
None yet
Development

No branches or pull requests

2 participants