
Add EvaluationDistributedSampler and examples on distributed evaluation #1886

Draft: wants to merge 44 commits into master
Conversation

SkafteNicki (Member) commented Jul 6, 2023

What does this PR do?

Fixes #1338

The original issue asks whether we should implement a join context so that metrics can be evaluated on an uneven number of samples in distributed settings. As a reminder, we normally discourage users from evaluating in distributed mode, because the default distributed sampler from PyTorch adds extra samples to make all processes do equal work, which messes with the results.

After investigating this issue, it seems that we do not need a join context at all, thanks to the custom synchronization we have for metrics. To understand this, we need to look at the two different kinds of states we can have: tensor states and list-of-tensor states.

  1. For tensor states the logic is fairly simple: even if rank 0 is evaluated on more samples or batches than rank 1, we still only need to do one all-gather operation, regardless of how many samples/batches each rank has seen.
  2. For list states we are saved by the custom logic we have. Imagine that rank 0's state is a list of two tensors [t_01, t_02] and rank 1's state is a list of one tensor [t_11] (rank 0 has seen one more batch than rank 1). When list states are encountered internally, we make sure to concatenate the states into one tensor so we do not need to call all_gather for each tensor in the list:

     if reduction_fn == dim_zero_cat and isinstance(input_dict[attr], list) and len(input_dict[attr]) > 1:
         input_dict[attr] = [dim_zero_cat(input_dict[attr])]

     such that after this, each state is a single tensor, t_0 and t_1 respectively, but clearly t_0.shape != t_1.shape. Again, internally we deal with this by padding to the same size and then doing an all-gather:
     # 3. If not, we need to pad each local tensor to maximum size, gather and then truncate
     pad_dims = []
     pad_by = (max_size - local_size).detach().cpu()
     for val in reversed(pad_by):
         pad_dims.append(0)
         pad_dims.append(val.item())
     result_padded = F.pad(result, pad_dims)
     gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
     torch.distributed.all_gather(gathered_result, result_padded, group)
     for idx, item_size in enumerate(local_sizes):
         slice_param = [slice(dim_size) for dim_size in item_size]
         gathered_result[idx] = gathered_result[idx][slice_param]
     return gathered_result

Thus, in both cases, even if one rank sees more samples/batches than another, we still perform the same number of distributed operations per rank, which means everything should work.
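To make this concrete, here is a minimal sketch (not part of the PR) of two ranks updating a metric on a different number of batches and still synchronizing correctly at compute(). The choice of MeanMetric, the gloo backend and the hard-coded port are illustrative assumptions only:

    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    from torchmetrics import MeanMetric  # any metric with tensor or list states behaves the same way


    def run(rank: int, world_size: int) -> None:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        metric = MeanMetric()
        # uneven work on purpose: rank 0 sees 3 batches, rank 1 sees 2
        n_batches = 3 if rank == 0 else 2
        for i in range(n_batches):
            metric.update(torch.tensor(float(rank * 10 + i)))

        # compute() triggers the synchronization described above; the number of
        # collective calls per rank is the same even though the batch counts differ
        print(rank, metric.compute())
        dist.destroy_process_group()


    if __name__ == "__main__":
        mp.spawn(run, args=(2,), nprocs=2)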

To highlight this feature of TorchMetrics, this PR does a couple of things:

  • Introduce a new EvaluationDistributedSampler that does not add extra samples. Users can thus use it as a drop-in replacement for DistributedSampler when they want to do proper distributed evaluation (otherwise they need to make sure that the number of samples is evenly divisible by the number of processes); see the sketch after this list.
  • Add unit tests that cover the above.
  • Add examples of how to do this distributed evaluation in both Lightning and standard torch.
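A rough usage sketch of the sampler in a plain-torch evaluation loop. The module path is assumed from the files touched in this PR, and val_dataset, model, device, rank and world_size are placeholders:

    import torch
    from torch.utils.data import DataLoader
    from torchmetrics.classification import MulticlassAccuracy
    from torchmetrics.utilities.distributed import EvaluationDistributedSampler  # assumed location

    # no extra samples are added, so some ranks may simply receive one sample fewer
    sampler = EvaluationDistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    loader = DataLoader(val_dataset, batch_size=32, sampler=sampler)

    metric = MulticlassAccuracy(num_classes=10).to(device)
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            metric.update(model(x.to(device)), y.to(device))

    # synchronization across ranks happens here, regardless of the uneven batch counts
    acc = metric.compute()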
Before submitting
  • Was this discussed/agreed via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃


📚 Documentation preview 📚: https://torchmetrics--1886.org.readthedocs.build/en/1886/

@SkafteNicki SkafteNicki added the enhancement New feature or request label Jul 26, 2023
@SkafteNicki SkafteNicki added this to the v1.1.0 milestone Jul 26, 2023
@SkafteNicki SkafteNicki marked this pull request as ready for review July 26, 2023 14:21
@SkafteNicki SkafteNicki changed the title Investigating distributed com Add EvaluationDistributedSampler and examples on distributed evaluation Jul 26, 2023
codecov bot commented Jul 26, 2023

Codecov Report

Merging #1886 (311bce3) into master (29f3289) will decrease coverage by 0%.
The diff coverage is 50%.

Additional details and impacted files
@@          Coverage Diff           @@
##           master   #1886   +/-   ##
======================================
- Coverage      87%     87%   -0%     
======================================
  Files         270     270           
  Lines       15581   15592   +11     
======================================
+ Hits        13483   13488    +5     
- Misses       2098    2104    +6     

@mergify mergify bot added the ready label Aug 21, 2023
justusschock (Member) left a comment


can we please add tests for validation and training as well? And maybe an fsdp test? Also some notes on caveats might be good to add to the sampler docs

super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed)

len_dataset = len(self.dataset) # type: ignore[arg-type]
if not self.drop_last and len_dataset % self.num_replicas != 0:
Member

the issue with this is that it wouldn't necessarily work for validation during training, since not all ranks would reach the same distributed function calls and would therefore time out, which would kill the entire process. Also, this would never work with FSDP, since some ranks have one batch more and, for FSDP, not all processes would reach the forward syncing points, also resulting in timeouts.

Member

I agree that in the context of Lightning this wouldn't work well, as it does not support Join (Lightning-AI/pytorch-lightning#3325)
FSDP also doesn't support join afaik (pytorch/pytorch#64683)

But outside Lightning, and taking FSDP out of the equation, I agree this can work and is a good utility to have IMO. It also suits the metric design well, since synchronization is only necessary when all processes have finished collecting their statistics and .compute() can be called.

justusschock (Member)

calling @awaelchli for distributed review :)

SkafteNicki (Member, Author)

can we please add tests for validation and training as well? And maybe an fsdp test? Also some notes on caveats might be good to add to the sampler docs

You are right that we need to test this feature better to clearly state the limitations.
I am going to move it from the v1.1 milestone to the future milestone, because it is not critical to get done right now.

@SkafteNicki SkafteNicki modified the milestones: v1.1.0, future Aug 21, 2023
@SkafteNicki SkafteNicki marked this pull request as draft August 21, 2023 13:08
SkafteNicki (Member, Author)

Converted to draft until better tested.

@mergify mergify bot removed the ready label Aug 21, 2023
docs/source/references/utilities.rst (outdated review thread, resolved)
src/torchmetrics/utilities/distributed.py (outdated review thread, resolved)

"""

def __init__(
Member

In Lightning we have a very similar class: https://github.com/Lightning-AI/lightning/blob/fbdbe632c67b05158804b52f4345944781ca4f07/src/lightning/pytorch/overrides/distributed.py#L194

I think the main difference is that yours respects the setting drop_last. I'm not sure why we have the __iter__ overridden there but if you are interested you can compare the two.
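For readers without the diff open, a minimal sketch of the kind of sampler being discussed; this is not the PR's exact code, only an illustration consistent with the visible hunk (skip the padding that the parent class would otherwise add and shrink num_samples/total_size accordingly):

    from torch.utils.data import Dataset, DistributedSampler


    class EvaluationDistributedSampler(DistributedSampler):
        """DistributedSampler that does not repeat samples to make the per-rank work even."""

        def __init__(self, dataset: Dataset, num_replicas=None, rank=None, shuffle=False, seed=0, drop_last=False):
            super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed, drop_last=drop_last)
            len_dataset = len(self.dataset)
            if not self.drop_last and len_dataset % self.num_replicas != 0:
                # some ranks get one sample fewer instead of duplicating samples;
                # the parent __iter__ then skips padding because total_size == len_dataset
                self.total_size = len_dataset
                self.num_samples = len(range(self.rank, self.total_size, self.num_replicas))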

metric_class=metric_class,
),
range(NUM_PROCESSES),
)
awaelchli (Member) commented Aug 21, 2023

In addition, a unit test for just the sampler alone could be useful: one that doesn't launch processes (not needed) but rather just asserts that the indices returned on each rank match the expectation, e.g.:

sampler = EvaluationDistributedSampler(dataset, rank=0, num_replicas=3, drop_last=...)
assert list(iter(sampler)) == ....

sampler = EvaluationDistributedSampler(dataset, rank=2, num_replicas=3, drop_last=...)
assert list(iter(sampler)) == ....

and so on to test all edge cases.
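As a rough complement, a launch-free sketch along those lines that only checks that the ranks jointly cover the dataset exactly once (exact per-rank indices are left out since they depend on the implementation; the sampler import is assumed as elsewhere in this thread):

    import torch
    from torch.utils.data import TensorDataset

    dataset = TensorDataset(torch.arange(10))  # length deliberately not divisible by 3

    all_indices = []
    for rank in range(3):
        sampler = EvaluationDistributedSampler(dataset, num_replicas=3, rank=rank, shuffle=False)
        all_indices.extend(iter(sampler))

    # every sample is seen exactly once, with no padded duplicates
    assert sorted(all_indices) == list(range(len(dataset)))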

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Borda (Member) commented May 21, 2024

@SkafteNicki, what is missing here to make it land? 🐿️

Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Investigate use of join context for distributed sync
4 participants