
DDP strategy doesn't work for on_validation_epoch_end, always hang #19783

Open
jzhanghzau opened this issue Apr 16, 2024 · 4 comments
Labels
logging Related to the `LoggerConnector` and `log()` question Further information is requested ver: 2.1.x

Comments

@jzhanghzau

jzhanghzau commented Apr 16, 2024

Bug description

My code looks like the snippet below. I want to compute a validation metric over the entire validation dataset, so I append each batch's results to a list and then compute the metric in the on_validation_epoch_end function.

It works fine with a single GPU, but when I use the DDP strategy to do the same thing, it always hangs during validation.

What version are you seeing the problem on?

v2.1, v2.2

How to reproduce the bug

# in __init__
self.validation_step_outputs = []
self.validation_step_clusters = []

def validation_step(self, batch, batch_idx):
    batch_tokens, clusters = batch
    projection = self._common_step(batch_tokens)

    self.validation_step_outputs.append(projection)
    self.validation_step_clusters.append(clusters)

def on_validation_epoch_end(self):
    if self.trainer.is_global_zero:  # only the rank 0 process enters this block
        all_preds = torch.cat(self.validation_step_outputs, dim=0)

        all_clusters = LabelEncoder().fit_transform(
            list(itertools.chain.from_iterable(self.validation_step_clusters)))
        all_clusters = torch.tensor(all_clusters)

        self.validation_step_outputs.clear()
        self.validation_step_clusters.clear()

        loss = loss_func(all_preds, all_clusters)
        accuracy = self._cal_accuracy(all_preds, all_clusters)

        self.log('validation_loss', loss, on_epoch=True, prog_bar=True)
        self.log('accuracy', accuracy, on_epoch=True, prog_bar=True)

    self.trainer.strategy.barrier()

Error messages and logs

[screenshot of the hang attached in the original issue; no text logs provided]

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response

cc @carmocca

@jzhanghzau jzhanghzau added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Apr 16, 2024
@awaelchli
Member

@jzhanghzau self.log() issues collective calls, so you can't just call it on "rank zero" only. If you want to do that, pass self.log(rank_zero_only=True).

Here are the relevant docs for this:
https://lightning.ai/docs/pytorch/stable/visualize/logging_advanced.html#rank-zero-only
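Applied to the snippet in the issue, the guarded branch would look roughly like the sketch below. It is written as a standalone function so the flow is visible outside a LightningModule: `log`, `is_global_zero`, and `barrier` stand in for `self.log`, `self.trainer.is_global_zero`, and `self.trainer.strategy.barrier`, and `compute_metrics` stands in for the loss/accuracy helpers — all hypothetical names, not Lightning API:

```python
import torch

def epoch_end_rank_zero(outputs, compute_metrics, log, is_global_zero, barrier):
    # Only rank 0 enters this branch, so every log() call inside it must
    # opt out of cross-rank synchronization via rank_zero_only=True;
    # otherwise rank 0 blocks waiting for collectives the other ranks
    # never issue, and the job hangs.
    if is_global_zero:
        all_preds = torch.cat(outputs, dim=0)
        loss, accuracy = compute_metrics(all_preds)
        log("validation_loss", loss, prog_bar=True, rank_zero_only=True)
        log("accuracy", accuracy, prog_bar=True, rank_zero_only=True)
    barrier()
```

The final barrier is still safe here because every rank reaches it regardless of the guard.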

@awaelchli awaelchli added question Further information is requested logging Related to the `LoggerConnector` and `log()` and removed bug Something isn't working labels Apr 16, 2024
@jzhanghzau
Author

@jzhanghzau self.log() issues collective calls, so you can't just call it on "rank zero" only. If you want to do that, pass self.log(rank_zero_only=True).

Here are the relevant docs for this: https://lightning.ai/docs/pytorch/stable/visualize/logging_advanced.html#rank-zero-only

Thanks for your quick reply! If I set self.log(rank_zero_only=True), it seems I'm not allowed to use callbacks that monitor the logged metric. What if I still want to use a callback (EarlyStopping) — how should I organize my code? Remove the if self.trainer.is_global_zero: guard and let every process go through the logic I presented above?
Thanks again.

@awaelchli
Member

Yes, that would be another option.
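A rough sketch of that option: every rank runs the epoch-end logic on its own shard, so the self.log() collectives match up across processes and EarlyStopping can monitor the metric. The Lightning-specific calls are shown as comments; `aggregate_epoch` is a hypothetical helper, and the tensor math assumes integer class labels (the original code ran LabelEncoder over string cluster ids first):

```python
import torch

def aggregate_epoch(outputs, clusters):
    # Concatenate the per-batch tensors collected in validation_step.
    all_preds = torch.cat(outputs, dim=0)      # (N, num_classes)
    all_clusters = torch.cat(clusters, dim=0)  # (N,)
    accuracy = (all_preds.argmax(dim=1) == all_clusters).float().mean()
    return all_preds, all_clusters, accuracy

# Inside on_validation_epoch_end, with the is_global_zero guard removed,
# every rank would execute roughly:
#
#     preds, labels, acc = aggregate_epoch(self.validation_step_outputs,
#                                          self.validation_step_clusters)
#     self.validation_step_outputs.clear()
#     self.validation_step_clusters.clear()
#     # sync_dist=True reduces the per-rank values across processes
#     self.log("accuracy", acc, on_epoch=True, prog_bar=True, sync_dist=True)
```

Note that with this approach each rank's metric only covers its shard of the validation set; the logged value is the cross-rank reduction, not a metric computed on the concatenated full dataset.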

@awaelchli awaelchli removed the needs triage Waiting to be triaged by maintainers label Apr 16, 2024