Validation runs only for one iteration when restarting from checkpoint mid-epoch, wrongly reporting validation loss #19549

Open
pimdh opened this issue Feb 29, 2024 · 2 comments · May be fixed by #19583
Labels: bug (Something isn't working), help wanted (Open to be worked on), loops (Related to the Loop API)

Comments

pimdh commented Feb 29, 2024

Bug description

When resuming from a mid-epoch checkpoint (which I have to use because my dataset is large), the training loop runs the validation loop for only one iteration, which leads to an incorrect validation loss being logged.

It appears that the batch_progress of lightning.pytorch.loops._EvaluationLoop is wrongly restored from the checkpoint as if the validation loop had already finished, and is not properly reset after the checkpoint is loaded.
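
For reference, here is a minimal sketch of one way to observe this; it assumes the internal attribute path trainer.fit_loop.epoch_loop.val_loop.batch_progress, which is private and may differ between Lightning versions:

from lightning.pytorch.callbacks import Callback


class PrintValProgress(Callback):
    # Hypothetical debugging callback, not part of the reproduction below: it
    # prints the validation batch progress right after the checkpoint has been
    # restored, where the counters already look as if the loop had finished.
    def on_train_start(self, trainer, pl_module):
        # Internal attribute path; may change between Lightning versions.
        print(trainer.fit_loop.epoch_loop.val_loop.batch_progress)

Adding PrintValProgress() to the Trainer callbacks in the reproduction below should show the difference between the fresh and the resumed run.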

What version are you seeing the problem on?

v2.2

How to reproduce the bug

import os

import torch
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        print("Training", batch.shape, loss.item(), batch_idx)
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        print("Validation", batch.shape, loss.item())
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def on_train_epoch_end(self) -> None:
        return super().on_train_epoch_end()


def run(ckpt_path):
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=4,
        limit_val_batches=3,
        val_check_interval=2,  # run validation every 2 training batches
        max_epochs=1,
        enable_model_summary=False,
        enable_progress_bar=False,
        num_sanity_val_steps=0,
        logger=False,
        callbacks=[
            ModelCheckpoint(
                save_last=False,
                save_top_k=10,
                monitor="valid_loss",
                every_n_train_steps=1,  # save a checkpoint after every training step (mid-epoch)
                dirpath="./checkpoints",
                enable_version_counter=False,
            )
        ],
    )
    trainer.fit(
        model,
        train_dataloaders=train_data,
        val_dataloaders=val_data,
        ckpt_path=ckpt_path,
    )


run(ckpt_path=None)  # first run: train from scratch, saving mid-epoch checkpoints
run(ckpt_path="checkpoints/epoch=0-step=3.ckpt")  # second run: resume from a mid-epoch checkpoint

Error messages and logs

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/.../lightning/pytorch/callbacks/model_checkpoint.py:652: Checkpoint directory /local/mnt/workspace/pim/projects/equi-scaling/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/.../lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/.../lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
Training torch.Size([2, 32]) -0.8351932168006897 0
/.../lightning/pytorch/callbacks/model_checkpoint.py:382: `ModelCheckpoint(monitor='valid_loss')` could not find the monitored key in the returned metrics: ['train_loss', 'epoch', 'step']. HINT: Did you call `log('valid_loss', value)` in the `LightningModule`?
Training torch.Size([2, 32]) -6.302826881408691 1
Validation torch.Size([2, 32]) -3.5122809410095215
Validation torch.Size([2, 32]) 2.169618844985962
Validation torch.Size([2, 32]) -6.107339859008789
Training torch.Size([2, 32]) -7.338858604431152 2
Training torch.Size([2, 32]) 9.845891952514648 3
Validation torch.Size([2, 32]) -4.606845378875732
Validation torch.Size([2, 32]) 8.005867004394531
Validation torch.Size([2, 32]) -4.361298561096191
`Trainer.fit` stopped: `max_epochs=1` reached.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at checkpoints/epoch=0-step=3.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Restored all states from the checkpoint at checkpoints/epoch=0-step=3.ckpt
/.../lightning/pytorch/loops/training_epoch_loop.py:156: You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable results if further training is done. Consider using an end-of-epoch checkpoint
Validation torch.Size([2, 32]) 8.310140609741211   # <--- only one iteration
Training torch.Size([2, 32]) -4.653291702270508 2
Training torch.Size([2, 32]) -8.187255859375 3
Validation torch.Size([2, 32]) 9.539518356323242
Validation torch.Size([2, 32]) -8.617881774902344
Validation torch.Size([2, 32]) 0.5192334651947021
`Trainer.fit` stopped: `max_epochs=1` reached.

Environment

* CUDA:
        - GPU:
                - NVIDIA GeForce RTX 2080 Ti
        - available:         True
        - version:           11.7
* Lightning:
        - lightning:         2.2.0.post0
        - lightning-utilities: 0.10.1
        - pytorch-lightning: 2.2.0.post0
        - torch:             2.0.1
        - torch-ema:         0.3
        - torch-geometric:   2.5.0
        - torch-scatter:     2.1.2+pt20cu117
        - torchmetrics:      1.3.1
        - torchvision:       0.15.2
* System:
        - OS:                Linux
        - architecture:
                - 64bit
        - processor:         x86_64
        - python:            3.10.13
        - release:           5.4.0-152-generic

More info

A fix/workaround for this issue is to add self.batch_progress.reset_on_run() at the end of _EvaluationLoop.run.
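
For illustration, a minimal monkey-patch sketch of that workaround (untested; _EvaluationLoop, its run method, and batch_progress are private APIs and may change between versions):

from lightning.pytorch.loops import _EvaluationLoop

_original_run = _EvaluationLoop.run


def _patched_run(self, *args, **kwargs):
    results = _original_run(self, *args, **kwargs)
    # Reset the batch progress once the validation run completes, so that a
    # restored mid-epoch checkpoint does not make the next validation run
    # appear already finished.
    self.batch_progress.reset_on_run()
    return results


_EvaluationLoop.run = _patched_run

Applying this patch before constructing the Trainer is intended to mimic the change proposed in the linked PR.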

cc @carmocca @justusschock

pimdh added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on Feb 29, 2024

awaelchli (Member) commented

@pimdh Thank you for already investigating this.

Since the training loop is quite complex, I can't say for sure that this is the right solution, but it sounds reasonable. Would you be interested in sending a PR with this change? We can then run the full test suite on your PR and see whether there are any edge cases. If it works, I can help add a test case.

awaelchli added the help wanted (Open to be worked on) and loops (Related to the Loop API) labels and removed the needs triage label on Mar 3, 2024
pimdh linked pull request #19583 on Mar 6, 2024 that will close this issue

pimdh commented Mar 6, 2024

Hi @awaelchli, I've filed the PR at #19583. While this suffices for my use case, unfortunately I won't have time to add unit tests to validate it.
Thanks
