
MetricWrapper for Target Binarization #2371

Closed
lgienapp opened this issue Feb 12, 2024 · 3 comments · Fixed by #2392
Labels
enhancement New feature or request

Comments

@lgienapp
Contributor

lgienapp commented Feb 12, 2024

🚀 Feature

Add a TargetBinarizationWrapper that casts continuous labels to binary labels given a threshold.

Motivation

Evaluating two metrics that require different label formats (e.g., one binary, the other continuous) is cumbersome: it requires two separate evaluation stacks, one fed binarized label data and the other the original continuous data, which leads to code duplication. Persisting binarized labels into the dataset, whenever a metric requires different input than the ground-truth data provides, also diminishes the clarity of the evaluation code.

Pitch

A metric wrapper that casts target data to binary targets during the .update() and .forward() methods. Can be applied to either a single Metric, or a whole MetricCollection.

Alternatives

  • using a MultiTaskWrapper is possible, but has two caveats: (1) metrics with a signature other than update(pred, target) are not supported, and (2) the user has to implement the thresholding logic themselves before feeding data into the MultiTaskWrapper
  • using a more generic target-processing wrapper that accepts, e.g., a custom lambda applied to the targets; more flexible, but it also requires the user to implement their own logic. I think binarization is a common enough problem in torchmetrics (since its metrics make a binary vs. non-binary distinction) to warrant its own wrapper.
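To illustrate the second caveat, the manual thresholding a user currently has to write at every call site looks roughly like this (the `threshold` value here is illustrative, not an existing torchmetrics API):

```python
import torch

# Manual pre-binarization a user must currently perform before feeding
# targets into a binary metric; `threshold` is an illustrative choice.
targets = torch.tensor([1, 0, 0, 0, 0, 2, 1, 0, 0, 0])
threshold = 0
binary_targets = (targets > threshold).long()
# Graded relevance labels {0, 1, 2} collapse to binary relevance {0, 1}.
```

This is trivial but has to be duplicated wherever a binary metric is evaluated alongside a graded one, which is exactly what the proposed wrapper would avoid.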

Additional Information

Consider the following example of the desired behaviour:

import torch
from torchmetrics.wrappers import BinarizedTargetWrapper # <-- This does not exist
from torchmetrics.collections import MetricCollection
from torchmetrics.retrieval import RetrievalNormalizedDCG, RetrievalMRR 

preds = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5, 0.6, 0.7, 0.8, 0.5, 0.4])
targets = torch.tensor([1,0,0,0,0,2,1,0,0,0])
topics = torch.tensor([0,0,0,0,0,1,1,1,1,1])

metrics = MetricCollection({
    "RetrievalNormalizedDCG": RetrievalNormalizedDCG(),
    "RetrievalMRR": BinarizedTargetWrapper(RetrievalMRR(), threshold=0),  # <-- Enable this kind of metric composition, which is not possible currently
})

metrics.update(preds, targets, indexes=topics)
metrics.compute()

If simple binarization as in the example is a desired solution, I have all the code needed for a pull request ready and can take on this issue.


Hi! Thanks for your contribution, great first issue!

@SkafteNicki
Member

Hi @lgienapp, thanks for opening this issue.
I am not against adding this feature, but maybe we should consider adding:

  • a general MetricInputTransformer class where users can provide custom functions for transforming the input
  • BinarizedTargetWrapper as a subclass of MetricInputTransformer with a pre-selected transform

How does that sound?

@lgienapp
Contributor Author

lgienapp commented Feb 13, 2024

Establishing a general class sounds good. Just for clarification: the general MetricInputTransformer would still subclass the WrapperMetric to inherit all the "reset-sync" code from there (e.g., live in torchmetrics.wrappers.transformations)? Or be its own thing (e.g., live in torchmetrics.transformations.abstract), possibly duplicating the sync code from the wrapper base class?

In either case, I would propose an implementation like the following (subclassing the wrappers here), assuming that only positional parameters like preds and target are interesting to modify (and thus passing kwargs such as indexes through untransformed):

from typing import Any, Tuple, Union

import torch
from torchmetrics import Metric, MetricCollection
from torchmetrics.wrappers.abstract import WrapperMetric


class MetricInputTransformer(WrapperMetric):
    """Wraps a metric (or collection) and transforms its positional inputs before each update."""

    def __init__(self, wrapped_metric: Union[Metric, MetricCollection], **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.wrapped_metric = wrapped_metric

    def transform(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Subclasses override this to transform the positional inputs."""
        raise NotImplementedError

    def update(self, *args: torch.Tensor, **kwargs: Any) -> None:
        self.wrapped_metric.update(*self.transform(*args), **kwargs)

    def compute(self) -> Any:
        return self.wrapped_metric.compute()

    def forward(self, *args: torch.Tensor, **kwargs: Any) -> Any:
        return self.wrapped_metric.forward(*self.transform(*args), **kwargs)
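A binarizing subclass of that sketch could then be very small. The version below is a standalone illustration so the thresholding logic is visible in isolation; in the actual proposal the class would subclass MetricInputTransformer, and both the class name and the threshold parameter are from this issue, not an existing torchmetrics API:

```python
from typing import Tuple

import torch


class BinarizedTargetWrapper:
    """Standalone sketch of the proposed subclass; in the proposal it
    would subclass MetricInputTransformer and inherit update/compute/
    forward, so only transform() is shown here."""

    def __init__(self, threshold: float = 0.0) -> None:
        self.threshold = threshold

    def transform(self, preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Predictions pass through unchanged; targets above the
        # threshold become 1, everything else becomes 0.
        return preds, (target > self.threshold).to(target.dtype)


preds = torch.tensor([0.9, 0.8, 0.7])
targets = torch.tensor([2, 0, 1])
_, binary = BinarizedTargetWrapper(threshold=0).transform(preds, targets)
```

Keeping the thresholding inside transform() means the wrapped metric only ever sees binary targets, while update/forward signatures and extra kwargs (e.g. indexes) stay untouched.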
