Add support for returning callback from LightningModule.configure_callbacks #11060

Merged: 6 commits, Dec 18, 2021
Changes from 2 commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -43,6 +43,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a warning that shows when `max_epochs` in the `Trainer` is not set ([#10700](https://github.com/PyTorchLightning/pytorch-lightning/issues/10700))


+- Added support for returning a Callback from `LightningModule.configure_callbacks` ([#11060](https://github.com/PyTorchLightning/pytorch-lightning/issues/11060))


### Changed

- Raised exception in `init_dist_connection()` when torch distributed is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418))
14 changes: 8 additions & 6 deletions pytorch_lightning/core/lightning.py
@@ -21,7 +21,7 @@
import tempfile
from contextlib import contextmanager
from pathlib import Path
-from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Sequence, Tuple, Union

import torch
from torch import ScriptModule, Tensor
@@ -31,6 +31,7 @@
from typing_extensions import Literal

import pytorch_lightning as pl
+from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks.progress import base as progress_base
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin
@@ -1119,12 +1120,13 @@ def predicts_step(self, batch, batch_idx, dataloader_idx=0):
        """
        return self(batch)

-    def configure_callbacks(self):
+    def configure_callbacks(self) -> Optional[Union[Sequence[Callback], Callback]]:
        """Configure model-specific callbacks. When the model gets attached, e.g., when ``.fit()`` or ``.test()``
-        gets called, the list returned here will be merged with the list of callbacks passed to the Trainer's
-        ``callbacks`` argument. If a callback returned here has the same type as one or several callbacks already
-        present in the Trainer's callbacks list, it will take priority and replace them. In addition, Lightning
-        will make sure :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks run last.
+        gets called, the list or a callback returned here will be merged with the list of callbacks passed to the
+        Trainer's ``callbacks`` argument. If a callback returned here has the same type as one or several callbacks
+        already present in the Trainer's callbacks list, it will take priority and replace them. In addition,
+        Lightning will make sure :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks
+        run last.

        Return:
            A list of callbacks which will extend the list of callbacks in the Trainer.
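The merge semantics the docstring describes can be sketched in plain Python. Everything below is a stand-in for illustration only: the local `Callback` classes and the `merge_callbacks` helper are hypothetical, not Lightning's actual implementation.

```python
from typing import List, Optional, Sequence, Union


class Callback:
    """Stand-in for pytorch_lightning.callbacks.Callback."""


class EarlyStopping(Callback):
    pass


class ModelCheckpoint(Callback):
    pass


def merge_callbacks(
    trainer_callbacks: List[Callback],
    model_callbacks: Optional[Union[Callback, Sequence[Callback]]],
) -> List[Callback]:
    # nothing returned from configure_callbacks: keep trainer callbacks as-is
    if not model_callbacks:
        return list(trainer_callbacks)
    # this PR's change: a bare Callback is treated like a one-element list
    if not isinstance(model_callbacks, Sequence):
        model_callbacks = [model_callbacks]
    # model callbacks replace trainer callbacks of the same type
    override_types = {type(c) for c in model_callbacks}
    merged = [c for c in trainer_callbacks if type(c) not in override_types]
    merged.extend(model_callbacks)
    # ModelCheckpoint callbacks are moved to the end so they run last
    others = [c for c in merged if not isinstance(c, ModelCheckpoint)]
    checkpoints = [c for c in merged if isinstance(c, ModelCheckpoint)]
    return others + checkpoints


trainer_cbs = [EarlyStopping(), ModelCheckpoint()]
model_es = EarlyStopping()

merged = merge_callbacks(trainer_cbs, model_es)  # a single callback, no list
assert merged[0] is model_es                     # replaced the trainer's EarlyStopping
assert isinstance(merged[-1], ModelCheckpoint)   # checkpoint still runs last
```

The key point of the PR is visible in the call site: `model_es` is passed bare, and the wrapping step makes it equivalent to returning `[model_es]`.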
@@ -13,7 +13,7 @@
# limitations under the License.
import os
from datetime import timedelta
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Sequence, Union

from pytorch_lightning.callbacks import (
Callback,
@@ -271,6 +271,9 @@ def _attach_model_callbacks(self) -> None:
        model_callbacks = self.trainer._call_lightning_module_hook("configure_callbacks")
        if not model_callbacks:
            return

+        model_callbacks = [model_callbacks] if not isinstance(model_callbacks, Sequence) else model_callbacks
+        model_callbacks = list(model_callbacks)
        model_callback_types = {type(c) for c in model_callbacks}
        trainer_callback_types = {type(c) for c in self.trainer.callbacks}
        override_types = model_callback_types.intersection(trainer_callback_types)
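The two added normalization lines can be exercised in isolation. A minimal sketch under stated assumptions (a local `Callback` stand-in and a local `normalize` helper, not the connector's real code):

```python
from typing import Sequence


class Callback:
    """Stand-in for pytorch_lightning.callbacks.base.Callback."""


def normalize(model_callbacks):
    # mirrors the two lines added in _attach_model_callbacks: wrap a bare
    # Callback instance, then copy any Sequence into a fresh list
    model_callbacks = [model_callbacks] if not isinstance(model_callbacks, Sequence) else model_callbacks
    return list(model_callbacks)


cb = Callback()
assert normalize(cb) == [cb]       # single instance gets wrapped
assert normalize([cb]) == [cb]     # a list passes through (as a copy)
assert normalize((cb,)) == [cb]    # any Sequence, e.g. a tuple, is accepted
```

The `list(...)` copy matters: later merging mutates the list, and the copy keeps the user's original sequence untouched.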
2 changes: 1 addition & 1 deletion tests/callbacks/test_callbacks.py
@@ -83,7 +83,7 @@ def test_configure_callbacks_hook_multiple_calls(tmpdir):

    class TestModel(BoringModel):
        def configure_callbacks(self):
-            return [model_callback_mock]
+            return model_callback_mock

    model = TestModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, enable_checkpointing=False)
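Note why this one-line test change is meaningful: the mock is returned bare, and since a `MagicMock` is not a `Sequence`, it takes the same wrapping path as a real `Callback` instance. A sketch under stated assumptions (the local `normalize` helper mimics the connector's normalization, it is not the test's actual machinery):

```python
from typing import Sequence
from unittest.mock import MagicMock

model_callback_mock = MagicMock()


def normalize(value):
    # mimics the connector: wrap anything that is not a Sequence
    return list(value) if isinstance(value, Sequence) else [value]


# a MagicMock does not satisfy the Sequence ABC, so the bare-return
# code path introduced by this PR is the one the test exercises
assert not isinstance(model_callback_mock, Sequence)
assert normalize(model_callback_mock) == [model_callback_mock]
```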