
stats logging in "on_train_epoch_end" ends up on wrong progress bar #19322

Open
jojje opened this issue Jan 21, 2024 · 5 comments · May be fixed by #19578
Labels: bug (Something isn't working), logging (Related to the `LoggerConnector` and `log()`), ver: 2.1.x

Comments

jojje commented Jan 21, 2024

Bug description

When logging statistics at the end of an epoch from within on_train_epoch_end, the statistics end up on the wrong progress bar.

Since there doesn't seem to be a way to tell Lightning or the TQDMProgressBar to retain the bar for each epoch, I've been forced to inject a newline after each epoch ends, so as not to lose any of the valuable statistics in the console output.

The following is the output from a 3-epoch run:

Epoch 0: 100%|█████████████████████████| 938/938 [00:04<00:00, 206.06it/s, v_num=207]
Epoch 1: 100%|█████████████| 938/938 [00:04<00:00, 233.29it/s, v_num=207, loss=0.553]
Epoch 2: 100%|█████████████| 938/938 [00:04<00:00, 233.39it/s, v_num=207, loss=0.329]
`Trainer.fit` stopped: `max_epochs=3` reached.
Epoch 2: 100%|█████████████| 938/938 [00:04<00:00, 232.93it/s, v_num=207, loss=0.329]
  • Loss for epoch 0 is incorrectly shown for epoch 1.
  • Loss for epoch 1 is incorrectly shown for epoch 2.
  • No logged loss at all is reported for epoch 0 or epoch 2.

If there is a proper way to retain the progress bar for each epoch that is different from what I'm doing, then please let me know and this ticket can then be closed. If not, hopefully a fix can be found.

What version are you seeing the problem on?

v2.1

How to reproduce the bug

import torch
import torchvision
import pytorch_lightning as pl

class DemoNet(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(784, 10)
        self.batch_losses = []

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def training_step(self, batch:torch.Tensor, _):
        x, y = batch
        x = x.reshape(x.size(0), -1)
        yh = self.fc(x)
        loss = torch.nn.functional.cross_entropy(yh, y)
        self.batch_losses.append(loss)
        return loss

    def on_train_epoch_end(self):
        # Log the epoch-mean loss here; this is the value that shows up on the
        # wrong (next epoch's) progress bar.
        loss = torch.stack(self.batch_losses).mean()
        self.log('loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.batch_losses.clear()
        print("")  # force a newline so each epoch's bar is retained in the console

ds = torchvision.datasets.MNIST(root="dataset/", train=True, transform=torchvision.transforms.ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(dataset=ds, batch_size=64, shuffle=False)
trainer = pl.Trainer(max_epochs=3)
trainer.fit(DemoNet(), train_loader)

Error messages and logs

N/A

Environment

Current environment
  • Lightning:

    • lightning-utilities: 0.9.0
    • pytorch-lightning: 2.1.3
    • torch: 2.1.2+cu118
    • torchaudio: 2.1.2+cu118
    • torchmetrics: 1.2.1
    • torchvision: 0.16.2+cu118
    • tqdm: 4.66.1
  • System:

    • OS: Windows
    • architecture:
      • 64bit
      • WindowsPE
    • processor: AMD64 Family 25 Model 97 Stepping 2, AuthenticAMD
    • python: 3.10.0
    • release: 10
    • version: 10.0.19045
  • CUDA:

    • GPU:
      • NVIDIA GeForce RTX 4090
    • available: True
    • version: 11.8
  • How you installed Lightning (conda, pip, source): pip

  • Running environment of LightningApp (e.g. local, cloud): local

More info

No response

cc @carmocca

@jojje jojje added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Jan 21, 2024
ManoBharathi93 commented Feb 29, 2024

Injecting a new line:

def on_train_epoch_end(self):
    loss = torch.stack(self.batch_losses).mean()
    self.log('loss', loss, on_step=False, on_epoch=True, prog_bar=True)
    print(f"Epoch {self.current_epoch} --> loss {loss.item()}")
    self.batch_losses.clear()
    print("")

Output will look like this:
[screenshot of the resulting console output]

@jojje Did you mean this approach?
Also, Epoch 2 is printed twice.
cc @Borda Do you consider this an issue or not? If it is, could you please share your insights so I can work on it?

@awaelchli awaelchli added logging Related to the `LoggerConnector` and `log()` and removed needs triage Waiting to be triaged by maintainers labels Feb 29, 2024

awaelchli (Member) commented:

Hey @jojje, it could be seen as an issue or not; it depends.
Fixing this might be very hard. If I interpret this correctly, it has to do with the fact that the callback hooks run before the LightningModule hooks.

If you `self.log` in your training step with `on_epoch=True`, it will work correctly.

Regarding "why does Epoch 2 show twice": it is because you have print statements, and the TQDM bar continues to write updates to the progress bar after your prints. If you want to avoid that, use `self.print(...)` instead.

jojje (Author) commented Mar 5, 2024

@awaelchli I tried the two changes you proposed. They solved the "off by one" problem, but at the cost of a performance hit. They also don't solve the problem of the individual epoch progress bars vanishing, which causes data loss in the console output.

Change:

@@ -7,5 +7,4 @@ class DemoNet(pl.LightningModule):
         super().__init__()
         self.fc = torch.nn.Linear(784, 10)
-        self.batch_losses = []

     def configure_optimizers(self):
@@ -17,12 +16,9 @@ class DemoNet(pl.LightningModule):
         yh = self.fc(x)
         loss = torch.nn.functional.cross_entropy(yh, y)
-        self.batch_losses.append(loss)
+        self.log('loss', loss, on_step=False, on_epoch=True, prog_bar=True)
         return loss

     def on_train_epoch_end(self):
-        loss = torch.stack(self.batch_losses).mean()
-        self.log('loss', loss, on_step=False, on_epoch=True, prog_bar=True)
-        self.batch_losses.clear()
-        print("")
+        self.print("")

 ds = torchvision.datasets.MNIST(root="dataset/", train=True, transform=torchvision.transforms.ToTensor(), download=True)

Resulting output:



Epoch 2: 100%|███████| 938/938 [00:04<00:00, 217.09it/s, v_num=6, loss=0.300]`Trainer.fit` stopped: `max_epochs=3` reached.
Epoch 2: 100%|███████| 938/938 [00:04<00:00, 216.87it/s, v_num=6, loss=0.300]

As you can see,

  1. the previous progress bars are closed and thus not retained. That's why there are two leading blank lines.
  2. there is still a duplicate of the final progress bar for the third epoch (Epoch 2).

The reason I didn't let Lightning calculate the stats automatically via the on_epoch (end) flag is that it's expensive. On my test run above, training takes a 25% performance (throughput) hit when logging on each training step with on_step=True, on_epoch=True, and about a 7% hit with on_step=False, on_epoch=True. I've researched the issues and discussion forums, and the consensus seems to be "log as little and as seldom as possible, and calculate statistics only when you need to, so as not to slow down training". That's why I'm performing the cheapest operation possible in the training step, just storing the losses, and then doing the expensive tensor creation, mean calculation and logging at the end of the epoch, since that's the only point at which logging the epoch loss is relevant. I'm simply trying to find a near "zero cost" stats logging solution here that keeps the training observability ergonomics of our pure PyTorch training loops.
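
For concreteness, here is a rough sketch of the kind of near zero-cost accumulation I have in mind, using a torchmetrics.MeanMetric instead of a Python list (just an illustration, not a claim about how Lightning wants this done; it still logs from on_train_epoch_end and is therefore still subject to the progress-bar issue described above):

import torch
import torchmetrics
import pytorch_lightning as pl

class DemoNetMeanMetric(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(784, 10)
        # Running mean kept on-device; update() per step is cheap,
        # compute() runs only once per epoch.
        self.epoch_loss = torchmetrics.MeanMetric()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def training_step(self, batch, _):
        x, y = batch
        x = x.reshape(x.size(0), -1)
        loss = torch.nn.functional.cross_entropy(self.fc(x), y)
        self.epoch_loss.update(loss.detach())  # no logging, no sync per step
        return loss

    def on_train_epoch_end(self):
        # All the expensive work (reduction + logging) happens once per epoch.
        self.log('loss', self.epoch_loss.compute(), prog_bar=True)
        self.epoch_loss.reset()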

Right now I'm just in an evaluation phase, seeing if Lightning might be something we can use going forward, but these initial training-101 ergonomics have put such notions on ice. I like the idea of bringing more structure to training, but I unfortunately cannot sell the idea of a new framework if even the basics aren't handled correctly, which is why I opened this issue. I look forward to hearing further suggestions on how to leverage Lightning correctly, so as to pass the initial sniff test ;)

To reiterate the composite objective:

  1. Log loss or any other statistic at the end of each epoch.
  2. Retain the progress bar and statistics for each epoch.
  3. Avoid incurring significant training slowdown due to logging.

jojje (Author) commented Mar 5, 2024

Update: here is a workaround that makes Lightning log as expected:

import torch
import torchvision
import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar

class LitProgressBar(TQDMProgressBar):
    # Skip the hard-coded *_progress_bar.close() calls so that each epoch's
    # bar (and its metrics) is retained in the console.
    def on_train_end(self, *_):
        # self.train_progress_bar.close()
        pass

    def on_validation_end(self, trainer, pl_module):
        # self.val_progress_bar.close()
        self.reset_dataloader_idx_tracker()
        if self._train_progress_bar is not None and trainer.state.fn == "fit":
            self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

    def on_test_end(self, trainer, pl_module):
        # self.test_progress_bar.close()
        self.reset_dataloader_idx_tracker()

    def on_predict_end(self, trainer, pl_module):
        # self.predict_progress_bar.close()
        self.reset_dataloader_idx_tracker()


class DemoNet(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(784, 10)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def training_step(self, batch:torch.Tensor, _):
        x, y = batch
        x = x.reshape(x.size(0), -1)
        yh = self.fc(x)
        loss = torch.nn.functional.cross_entropy(yh, y)
        self.log('loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        # import ipdb; ipdb.set_trace(context=15)
        print("")
        pass

ds = torchvision.datasets.MNIST(root="dataset/", train=True, transform=torchvision.transforms.ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(dataset=ds, batch_size=64, shuffle=False)
trainer = pl.Trainer(max_epochs=3, callbacks=[LitProgressBar()])
trainer.fit(DemoNet(), train_loader)

The key bit of information here is the need to subclass the TQDMProgressBar, just to be able to disable all the hard-coded *bar.close() calls you make in the default progress bar.

It would be great if every user didn't have to deal with all that boilerplate in every project, and the TQDMProgressBar constructor instead took an optional argument such as `leave: bool` (same as tqdm) that is then checked in the code to decide whether or not to close the progress bars.

E.g.

class TQDMProgressBar(ProgressBar):
    def __init__(self, refresh_rate: int = 1, process_position: int = 0, leave: bool = False):

jojje (Author) commented Mar 5, 2024

A PR for discussion and review has been submitted to address this issue.
If anyone has time to look at it and provide feedback, that'd be great.

Reviewer note: there was a failed test, but it seems entirely unrelated. The change was made such that there is zero change in behavior by default, and explicitly setting a new flag (which no existing tests could possibly be aware of) is required to enable the new behavior, so I don't see how this change could be related to the failure of core/test_metric_result_integration.py::test_result_reduce_ddp.
