Checkpoint every_n_steps reruns epoch on restore #19815

Open
heth27 opened this issue Apr 25, 2024 · 3 comments
Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.2.x

Comments

@heth27 commented Apr 25, 2024

Bug description

The checkpoint callback is run before batch_progress.increment_completed() in the training epoch loop's advance method. As a result, in the saved checkpoint,
checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['completed'] (e.g. 9)
is one smaller than, for example,
checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['processed'] (e.g. 10) or the global step.
The same holds for checkpoint['loops']['fit_loop']['epoch_loop.state_dict']['_batches_that_stepped'].

When restoring from that checkpoint, the batch with batch_idx 9 is therefore run again, even though the optimizer step was already performed for it.

This behavior is unexpected enough to at least warrant a hint in the documentation, if it is not regarded as a bug.
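
A quick way to see the off-by-one is to load one of the saved checkpoints and compare the counters directly. This is a minimal sketch; the path is just an example produced by the repro script below, and torch.load defaults may need adjusting on newer PyTorch versions.

import torch

ckpt = torch.load("a_test_logs/epoch=0-step=10.ckpt", map_location="cpu")
progress = ckpt["loops"]["fit_loop"]["epoch_loop.batch_progress"]
print(progress["total"]["processed"])  # e.g. 10, in line with the global step
print(progress["total"]["completed"])  # e.g. 9, one behind 'processed'
print(ckpt["loops"]["fit_loop"]["epoch_loop.state_dict"]["_batches_that_stepped"])  # e.g. 9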

What version are you seeing the problem on?

master

How to reproduce the bug

from typing import Any

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, Sampler

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint, OnExceptionCheckpoint
from lightning.pytorch.utilities.types import STEP_OUTPUT


class TestModule(nn.Module):
    def __init__(self, in_dim=512, out_dim=16):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.simple_layer = nn.Linear(self.in_dim, self.out_dim, bias=True)

    def forward(self, input):
        return self.simple_layer(input)


class TestBatchSampler(Sampler):
    def __init__(self):
        super().__init__()

    def __len__(self) -> int:
        # effectively infinite; __len__ must return an int, not a float
        return int(1e100)

    def __iter__(self):
        return self

    def __next__(self):
        # yields an endless stream of single-index batches
        return torch.tensor([1])


class TestDataset(Dataset):
    def __init__(self, in_dim):
        super().__init__()
        self.in_dim = in_dim
        self.total_len = 512

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        return torch.randn(self.in_dim)


class TestDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.in_dim = 512
        self.val_batch_size = 1

    def train_dataloader(self):
        train_ds = TestDataset(self.in_dim)
        train_dl = DataLoader(train_ds, batch_sampler=TestBatchSampler(), num_workers=4, shuffle=False)
        return train_dl

    def val_dataloader(self):
        val_ds = TestDataset(self.in_dim)
        val_dl = DataLoader(val_ds, batch_size=self.val_batch_size, num_workers=4, shuffle=False)
        return val_dl


class TestLitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.test_module_obj = TestModule(in_dim=512, out_dim=16)
        self.automatic_optimization = False

    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
        print(f"train_batch ended:{batch_idx}")

    def on_save_checkpoint(self, checkpoint):
        # Commented-out workaround: copy 'processed' into 'completed' so the
        # already-stepped batch is not replayed after restoring.
        # checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['completed'] = \
        #     checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['total']['processed']
        # checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] = \
        #     checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['processed']
        print("creating checkpoint")

    def validation_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        print(f"validation step called")
        return torch.tensor(1.0)

    def training_step(self, batch, batch_idx):
        print(f"batch_idx: {batch_idx}")
        optimizer = self.optimizers()

        output = self.test_module_obj(batch)

        loss = output.sum()

        self.manual_backward(loss)

        optimizer.step()

        if batch_idx > 25:
            raise Exception("This is to stop the program :)")

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.test_module_obj.parameters()
        )
        return optimizer


if __name__ == '__main__':
    test_data_loader = TestDataModule()
    test_lit_model = TestLitModel()

    checkpoint_dir = 'a_test_logs/'

    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        every_n_train_steps=10,
        save_top_k=-1,
    )
    exception_checkpoint_callback = OnExceptionCheckpoint(
        dirpath=checkpoint_dir,
        filename="error"
    )
    trainer = pl.Trainer(
        callbacks=[checkpoint_callback, exception_checkpoint_callback],
        max_epochs=-1,
        max_steps=400000,
        val_check_interval=5,
    )
    trainer.fit(test_lit_model, test_data_loader)

    # trainer.fit(test_lit_model,
    #             datamodule=test_data_loader,
    #             ckpt_path='a_test_logs/epoch=0-step=10.ckpt')
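
For completeness, here is a hedged sketch of the workaround hinted at in the commented-out on_save_checkpoint above: a standalone callback (hypothetical name AlignBatchProgress) that copies the 'processed' counters into 'completed' right before the checkpoint is written, so the already-stepped batch is not replayed when resuming via ckpt_path. I have not verified that this is safe in all configurations.

from lightning.pytorch.callbacks import Callback

class AlignBatchProgress(Callback):
    # Hypothetical workaround sketch: make 'completed' match 'processed' in the
    # saved loop state so restoring does not rerun the batch that already stepped.
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_progress"]
        progress["total"]["completed"] = progress["total"]["processed"]
        progress["current"]["completed"] = progress["current"]["processed"]

# usage: pl.Trainer(callbacks=[checkpoint_callback, exception_checkpoint_callback, AlignBatchProgress()])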

Error messages and logs

None

Environment

Current environment
- Lightning Component: Trainer
- PyTorch Lightning Version: 2.2.3

More info

No response

heth27 added the bug and needs triage labels Apr 25, 2024
@johnzielke

I think this is also related to #18595. The fact that the ModelCheckpoint is saved before all parts of the counters are properly incremented seems to lead to a host of unforeseen and hard-to-debug issues.

@ordabayevy

I think it is also related to this issue #18060

@heth27 (Author) commented Apr 26, 2024

> I think it is also related to this issue #18060

Yes, it's the same issue; I didn't check thoroughly enough whether it already existed.
