Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/remove double forward #984

Merged
merged 50 commits into from May 10, 2022
Merged

Refactor/remove double forward #984

merged 50 commits into from May 10, 2022

Conversation

SkafteNicki
Copy link
Member

@SkafteNicki SkafteNicki commented Apr 25, 2022

What does this PR do?

Redo of #612
Fixes part of #344 (needs review after if we can close the issue)

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@SkafteNicki SkafteNicki added Priority Critical task/issue refactoring refactoring and code health labels Apr 25, 2022
@SkafteNicki SkafteNicki added this to the v0.9 milestone Apr 25, 2022
@Borda Borda added Important milestonish and removed Priority Critical task/issue labels Apr 25, 2022
@codecov
Copy link

codecov bot commented Apr 26, 2022

Codecov Report

Merging #984 (fc1e2c4) into master (bf0fa97) will increase coverage by 0%.
The diff coverage is 99%.

@@          Coverage Diff           @@
##           master   #984    +/-   ##
======================================
  Coverage      95%    95%            
======================================
  Files         180    180            
  Lines        7666   7823   +157     
======================================
+ Hits         7276   7430   +154     
- Misses        390    393     +3     

Copy link
Member

@Borda Borda left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we also include in the PR the speed-up chat? 🐰

Copy link
Member

@justusschock justusschock left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a substantial change. While it does indeed prevent two update calls, I feel from the user experience it might be complicating the process of implementing custom metrics.

Not saying that I am in general against merging this one, but I don't want to rush it here... I feel like we should discuss the implications a bit more

torchmetrics/metric.py Show resolved Hide resolved
torchmetrics/metric.py Show resolved Hide resolved
@SkafteNicki
Copy link
Member Author

@justusschock I completely agree that this is not a trival change and should only be done with care.
However, our current base tester class actually calls forward:
https://github.com/PyTorchLightning/metrics/blob/c494f3575f542e7c264555a71fb97e8676f64700/tests/helpers/testers.py#L180
therefore, if I can get tests to pass (which they currently are not doing) I think the change is at least safe to do to our own codebase.
Does this mean that it will not break any user implementation? Not sure to be fair. If you have a particular metric that you think is a problem then consider comment it to this PR so we can check it.

We can also consider if this should be an opt-in feature. With the changes we did to the additional metric arguments, collapsing them all into the **kwargs arg, we can easily do something like:

class Metric(nn.Module):
    def __init__(self, **kwargs):
        ...
        self.use_fast_forward = kwargs.pop('use_fast_forward', False)  # find better name, what should default be?
        ...

    def forward(self, *args, **kwargs):
        if self.use_fast_forward:
            return self.forward_method_that_only_calls_update_once(*args, **kwargs)
        else:
            return self.forward_as_it_already_is(*args, **kwargs)

@Borda any opinions?

@Borda Borda self-requested a review April 27, 2022 13:56
@mergify mergify bot added the has conflicts label May 5, 2022
@mergify mergify bot removed the has conflicts label May 5, 2022
@mergify mergify bot added the ready label May 5, 2022
@mergify mergify bot added the has conflicts label May 6, 2022
@mergify mergify bot removed the has conflicts label May 6, 2022
@Borda
Copy link
Member

Borda commented May 8, 2022

@SkafteNicki ready to go? 🎉

@SkafteNicki
Copy link
Member Author

@SkafteNicki ready to go? 🎉

Code should be ready. Let me run some speed tests before we merge to make sure that we actually get the speedup that we expect :]

@Borda
Copy link
Member

Borda commented May 8, 2022

Code should be ready. Let me run some speed tests before we merge to make sure that we actually get the speedup that we expect :]

Sweet, could you pls include it also here in this PR? :)

@SkafteNicki
Copy link
Member Author

Code should be ready. Let me run some speed tests before we merge to make sure that we actually get the speedup that we expect :]

Sweet, could you pls include it also here in this PR? :)

Yes will do :]
Then we also have something to show for next release blog post.

@SkafteNicki
Copy link
Member Author

@justusschock and @Borda, finally created an updated figure:

results

The TLDR is that metrics where update is expensive to evaluate we can more or less cut computational time in half and for metrics where update is cheap it does not really matter if the feature is enabled or disabled.

Code to create figure
from time import perf_counter

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

import torchmetrics

NUM_REPS = 5
NUM_CALLS = [1, 10, 25, 50, 75, 100, 250, 500, 750, 1000, 2500]

metrics = [
    torchmetrics.MeanSquaredError,
    torchmetrics.CosineSimilarity,
    torchmetrics.Accuracy,
    torchmetrics.ConfusionMatrix,
    torchmetrics.StructuralSimilarityIndexMeasure,
    torchmetrics.audio.sdr.SignalDistortionRatio,
    torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity,
    torchmetrics.SQuAD,
    torchmetrics.WordErrorRate,
    torchmetrics.AUROC,
]

metric_args = [{}, {}, {"num_classes": 10}, {"num_classes": 10}, {}, {}, {"net_type": "alex"}, {}, {}, {"num_classes": 3}]

inputs = [
    (torch.randn(100,), torch.randn(100,)),
    (torch.randn(100,), torch.randn(100,)),
    (torch.randn(100, 10).softmax(dim=-1), torch.randint(10, (100,))),
    (torch.randn(100, 10).softmax(dim=-1), torch.randint(10, (100,))),
    (torch.rand(5, 3, 25, 25), torch.rand(5, 3, 25, 25)),
    (torch.randn(1, 8000), torch.randn(1, 8000)),
    (torch.rand(1, 3, 32, 32), torch.rand(1, 3, 32, 32)),
    ([{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}], [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]),
    (["this is the prediction", "there is an other sample"], ["this is the reference", "there is another one"]),
    (torch.tensor([[0.90, 0.05, 0.05], [0.05, 0.90, 0.05], [0.05, 0.05, 0.90], [0.85, 0.05, 0.10], [0.10, 0.10, 0.80]]), torch.tensor([0, 1, 1, 2, 2]))
]


def get_metric_classes(base_class):
    class Old(base_class):
        full_state_update = True

    class New(base_class):
        full_state_update = False

    return [Old, New]


if __name__ == "__main__":
    res = {True: {}, False: {}}
    for base_metric_class, metric_args, args in zip(metrics, metric_args, inputs):
        print(f"Testing {base_metric_class}")
        name = base_metric_class.__name__
        OldClass, NewClass = get_metric_classes(base_metric_class)
        for metric, enabled in zip([OldClass(**metric_args), NewClass(**metric_args)], [False, True]):
            res[enabled][name] = np.zeros((len(NUM_CALLS), NUM_REPS))
            for i, s in tqdm(enumerate(NUM_CALLS), total=len(NUM_CALLS)):
                for r in range(NUM_REPS):
                    start = perf_counter()
                    for _ in range(s):
                        val = metric(*args)
                    end = perf_counter()
                    metric.reset()
                    res[enabled][name][i, r] = end - start

    fig, ax = plt.subplots(nrows=2, ncols=5)
    for count, metric in enumerate(metrics):
        i = count % 2
        j = count % 5
        name = metric.__name__
        mean_old = res[False][name].mean(axis=-1)
        std_old = res[False][name].std(axis=-1)
        mean_new = res[True][name].mean(axis=-1)
        std_new = res[True][name].std(axis=-1)

        ax[i, j].plot(NUM_CALLS, mean_old, label="Old standard")
        ax[i, j].fill_between(NUM_CALLS, mean_old + std_old, mean_old - std_old, alpha=0.1)

        ax[i, j].plot(NUM_CALLS, mean_new, label="New standard")
        ax[i, j].fill_between(NUM_CALLS, mean_new + std_new, mean_new - std_new, alpha=0.1)

        if i == 1:
            ax[i, j].set_xlabel("Number of forward calls", fontsize=10)
        if j == 0:
            ax[i, j].set_ylabel("Time (sec)", fontsize=10)
        ax[i, j].set_title(name, fontsize=10, fontweight='bold')
        plt.setp(ax[i, j].get_yticklabels(), fontsize=5)
        ax[i, j].legend(loc="upper left", fontsize=10)

    plt.show()

Copy link
Member

@Borda Borda left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sweet!

@Borda Borda enabled auto-merge (squash) May 10, 2022 10:18
@Borda Borda merged commit a971c6b into master May 10, 2022
@Borda Borda deleted the refactor/remove_double_forward branch May 10, 2022 10:18
@SkafteNicki SkafteNicki mentioned this pull request May 16, 2022
4 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Important milestonish ready refactoring refactoring and code health
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants