Add LightningCLI(auto_registry) #12108

Merged (6 commits, Mar 8, 2022)
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -57,6 +57,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `LightningCLI.configure_optimizers` to override the `configure_optimizers` return value ([#10860](https://github.com/PyTorchLightning/pytorch-lightning/pull/10860))


- Added `LightningCLI(auto_registry)` flag to register all subclasses of the registerable components automatically ([#12108](https://github.com/PyTorchLightning/pytorch-lightning/pull/12108))


- Added a warning that shows when `max_epochs` in the `Trainer` is not set ([#10700](https://github.com/PyTorchLightning/pytorch-lightning/pull/10700))


23 changes: 23 additions & 0 deletions docs/source/common/lightning_cli.rst
@@ -345,6 +345,29 @@ This can be useful to implement custom logic without having to subclass the CLI,
and argument parsing capabilities.


Subclass registration
^^^^^^^^^^^^^^^^^^^^^

To use shorthand notation, the classes to choose from need to be registered beforehand. This can be done with:

.. code-block:: python

    LightningCLI(auto_registry=True)  # False by default

which will register all subclasses of :class:`torch.optim.Optimizer`, :class:`torch.optim.lr_scheduler._LRScheduler`,
:class:`~pytorch_lightning.core.lightning.LightningModule`,
:class:`~pytorch_lightning.core.datamodule.LightningDataModule`, :class:`~pytorch_lightning.callbacks.Callback`, and
:class:`~pytorch_lightning.loggers.LightningLoggerBase` across all imported modules. This includes those in your own
code.

Alternatively, if this is left unset, only the :class:`torch.optim.Optimizer` and
:class:`torch.optim.lr_scheduler._LRScheduler` subclasses defined in PyTorch, as well as Lightning's own
:class:`~pytorch_lightning.callbacks.Callback` and :class:`~pytorch_lightning.loggers.LightningLoggerBase`
subclasses, will be registered.

In subsequent sections, we will go over adding specific classes to specific registries as well as how to use
shorthand notation.
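
For example, with ``auto_registry=True``, a model defined in your own script becomes selectable by class
name. A minimal sketch (``MyModel`` is a hypothetical user class):

.. code-block:: python

    # trainer.py
    from pytorch_lightning import LightningModule
    from pytorch_lightning.utilities.cli import LightningCLI


    class MyModel(LightningModule):
        ...


    # Registers MyModel along with every other loaded subclass, so a run such as
    # `python trainer.py fit --model MyModel --optimizer Adam` resolves both names.
    cli = LightningCLI(auto_registry=True)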


Trainer Callbacks and arguments with class type
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

51 changes: 34 additions & 17 deletions pytorch_lightning/utilities/cli.py
@@ -30,6 +30,7 @@
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _JSONARGPARSE_AVAILABLE
+from pytorch_lightning.utilities.meta import get_all_subclasses
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerType, LRSchedulerTypeTuple, LRSchedulerTypeUnion
@@ -58,9 +59,8 @@ def __call__(self, cls: Type, key: Optional[str] = None, override: bool = False)
elif not isinstance(key, str):
raise TypeError(f"`key` must be a str, found {key}")

-        if key in self and not override:
-            raise MisconfigurationException(f"'{key}' is already present in the registry. HINT: Use `override=True`.")
-        self[key] = cls
+        if key not in self or override:
+            self[key] = cls
return cls
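
The change above makes registration idempotent: a key that is already present is now silently kept instead
of raising a ``MisconfigurationException``, unless ``override=True`` is passed. A minimal sketch of the new
semantics, assuming a hypothetical ``MyAdam`` class:

    import torch
    from pytorch_lightning.utilities.cli import _Registry

    registry = _Registry()

    @registry  # decorator form; registers the class under the key "MyAdam"
    class MyAdam(torch.optim.Adam):
        ...

    registry(MyAdam)  # duplicate key: now a no-op instead of an error
    registry(MyAdam, key="adam2")  # the same class can still be added under another key
    assert "MyAdam" in registry and "adam2" in registry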

def register_classes(self, module: ModuleType, base_cls: Type, override: bool = False) -> None:
@@ -91,10 +91,11 @@ def __str__(self) -> str:


OPTIMIZER_REGISTRY = _Registry()
-OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer)
-
LR_SCHEDULER_REGISTRY = _Registry()
-LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)
+CALLBACK_REGISTRY = _Registry()
+MODEL_REGISTRY = _Registry()
+DATAMODULE_REGISTRY = _Registry()
+LOGGER_REGISTRY = _Registry()


class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
@@ -103,17 +104,29 @@ def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any
self.monitor = monitor


-LR_SCHEDULER_REGISTRY(cls=ReduceLROnPlateau)
-
-CALLBACK_REGISTRY = _Registry()
-CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.callbacks.Callback)
-
-MODEL_REGISTRY = _Registry()
-
-DATAMODULE_REGISTRY = _Registry()
-
-LOGGER_REGISTRY = _Registry()
-LOGGER_REGISTRY.register_classes(pl.loggers, pl.loggers.LightningLoggerBase)
+def _populate_registries(subclasses: bool) -> None:
+    if subclasses:
+        # this will register any subclasses from all loaded modules including userland
+        for cls in get_all_subclasses(torch.optim.Optimizer):
+            OPTIMIZER_REGISTRY(cls)
+        for cls in get_all_subclasses(torch.optim.lr_scheduler._LRScheduler):
+            LR_SCHEDULER_REGISTRY(cls)
+        for cls in get_all_subclasses(pl.Callback):
+            CALLBACK_REGISTRY(cls)
+        for cls in get_all_subclasses(pl.LightningModule):
+            MODEL_REGISTRY(cls)
+        for cls in get_all_subclasses(pl.LightningDataModule):
+            DATAMODULE_REGISTRY(cls)
+        for cls in get_all_subclasses(pl.loggers.LightningLoggerBase):
+            LOGGER_REGISTRY(cls)
+    else:
+        # manually register torch's subclasses and our subclasses
+        OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer)
+        LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)
+        CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.Callback)
+        LOGGER_REGISTRY.register_classes(pl.loggers, pl.loggers.LightningLoggerBase)
+        # `ReduceLROnPlateau` does not subclass `_LRScheduler`
+        LR_SCHEDULER_REGISTRY(cls=ReduceLROnPlateau)
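
Deferring population to ``_populate_registries`` means the registries are filled when ``LightningCLI`` is
instantiated, so any subclass defined or imported by that point is picked up. A minimal sketch of the
effect, assuming a hypothetical user-defined ``MyCallback``:

    import pytorch_lightning as pl
    from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY, _populate_registries

    class MyCallback(pl.Callback):
        ...

    _populate_registries(True)  # what LightningCLI(auto_registry=True) triggers internally
    assert "MyCallback" in CALLBACK_REGISTRY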


class LightningArgumentParser(ArgumentParser):
@@ -471,6 +484,7 @@ def __init__(
subclass_mode_model: bool = False,
subclass_mode_data: bool = False,
run: bool = True,
+        auto_registry: bool = False,
) -> None:
"""Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which
are called / instantiated using a parsed configuration file and / or command line args.
@@ -514,6 +528,7 @@ def __init__(
of the given class.
run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
method. If set to ``False``, the trainer and model classes will be instantiated only.
+            auto_registry: Whether to automatically fill up the registries with all defined subclasses.
"""
self.save_config_callback = save_config_callback
self.save_config_filename = save_config_filename
@@ -533,6 +548,8 @@ def __init__(
self._datamodule_class = datamodule_class or LightningDataModule
self.subclass_mode_data = (datamodule_class is None) or subclass_mode_data

+        _populate_registries(auto_registry)

main_kwargs, subparser_kwargs = self._setup_parser_kwargs(
parser_kwargs or {}, # type: ignore # github.com/python/mypy/issues/6463
{"description": description, "env_prefix": env_prefix, "default_env": env_parse},
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/meta.py
@@ -147,7 +147,7 @@ def init_meta(*_, **__):


# https://stackoverflow.com/a/63851681/9201239
-def get_all_subclasses(cls: Type[nn.Module]) -> Set[nn.Module]:
+def get_all_subclasses(cls: Type) -> Set[Type]:
subclass_list = []

def recurse(cl):
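
For reference, the transitive walk that ``get_all_subclasses`` performs can be sketched as follows; this is
an illustrative variant, not the exact body from ``meta.py``:

    def all_subclasses(cls: type) -> set:
        # recurse through cls.__subclasses__() to collect every descendant
        direct = cls.__subclasses__()
        return set(direct).union(*(all_subclasses(c) for c in direct))
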
80 changes: 54 additions & 26 deletions tests/utilities/test_cli.py
@@ -38,6 +38,7 @@
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.cli import (
+    _populate_registries,
CALLBACK_REGISTRY,
DATAMODULE_REGISTRY,
instantiate_class,
@@ -883,27 +884,38 @@ def test_lightning_cli_run():
assert isinstance(cli.model, LightningModule)


-@OPTIMIZER_REGISTRY
-class CustomAdam(torch.optim.Adam):
-    pass
+@pytest.fixture(autouse=True)
+def clear_registries():
+    # since the registries are global, it's good to clear them after each test to avoid unwanted interactions
+    yield
+    OPTIMIZER_REGISTRY.clear()
+    LR_SCHEDULER_REGISTRY.clear()
+    CALLBACK_REGISTRY.clear()
+    MODEL_REGISTRY.clear()
+    DATAMODULE_REGISTRY.clear()
+    LOGGER_REGISTRY.clear()


-@LR_SCHEDULER_REGISTRY
-class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
-    pass

+def test_registries():
+    # the registries are global so this is only necessary when this test is run standalone
+    _populate_registries(False)

-@CALLBACK_REGISTRY
-class CustomCallback(Callback):
-    pass
+    @OPTIMIZER_REGISTRY
+    class CustomAdam(torch.optim.Adam):
+        pass

+    @LR_SCHEDULER_REGISTRY
+    class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
+        pass

-@LOGGER_REGISTRY
-class CustomLogger(LightningLoggerBase):
-    pass
+    @CALLBACK_REGISTRY
+    class CustomCallback(Callback):
+        pass

+    @LOGGER_REGISTRY
+    class CustomLogger(LightningLoggerBase):
+        pass

-def test_registries():
assert "SGD" in OPTIMIZER_REGISTRY.names
assert "RMSprop" in OPTIMIZER_REGISTRY.names
assert "CustomAdam" in OPTIMIZER_REGISTRY.names
@@ -916,9 +928,13 @@ def test_registries():
assert "EarlyStopping" in CALLBACK_REGISTRY.names
assert "CustomCallback" in CALLBACK_REGISTRY.names

-    with pytest.raises(MisconfigurationException, match="is already present in the registry"):
-        OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer)
-    OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer, override=True)
+    class Foo:
+        ...
+
+    OPTIMIZER_REGISTRY(Foo, key="SGD")  # not overridden by default
+    assert OPTIMIZER_REGISTRY["SGD"] is torch.optim.SGD
+    OPTIMIZER_REGISTRY(Foo, key="SGD", override=True)
+    assert OPTIMIZER_REGISTRY["SGD"] is Foo

# test `_Registry.__call__` returns the class
assert isinstance(CustomCallback(), CustomCallback)
@@ -927,18 +943,24 @@ def test_registries():
assert "CustomLogger" in LOGGER_REGISTRY


-@MODEL_REGISTRY
+def test_registries_register_automatically():
+    assert "SaveConfigCallback" not in CALLBACK_REGISTRY
+    with mock.patch("sys.argv", ["any.py"]):
+        LightningCLI(BoringModel, run=False, auto_registry=True)
+    assert "SaveConfigCallback" in CALLBACK_REGISTRY


class TestModel(BoringModel):
def __init__(self, foo, bar=5):
super().__init__()
self.foo = foo
self.bar = bar


-MODEL_REGISTRY(cls=BoringModel)


def test_lightning_cli_model_choices():
+    MODEL_REGISTRY(cls=TestModel)
+    MODEL_REGISTRY(cls=BoringModel)

with mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]), mock.patch(
"pytorch_lightning.Trainer._fit_impl"
) as run:
Expand All @@ -953,18 +975,18 @@ def test_lightning_cli_model_choices():
assert cli.model.bar == 5


-@DATAMODULE_REGISTRY
class MyDataModule(BoringDataModule):
def __init__(self, foo, bar=5):
super().__init__()
self.foo = foo
self.bar = bar


-DATAMODULE_REGISTRY(cls=BoringDataModule)


def test_lightning_cli_datamodule_choices():
+    MODEL_REGISTRY(cls=BoringModel)
+    DATAMODULE_REGISTRY(cls=MyDataModule)
+    DATAMODULE_REGISTRY(cls=BoringDataModule)

# with set model
with mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]), mock.patch(
"pytorch_lightning.Trainer._fit_impl"
@@ -1001,7 +1023,7 @@ def test_lightning_cli_datamodule_choices():
assert not hasattr(cli.parser.groups["data"], "group_class")

with mock.patch("sys.argv", ["any.py"]), mock.patch.dict(DATAMODULE_REGISTRY, clear=True):
-        cli = LightningCLI(BoringModel, run=False)
+        cli = LightningCLI(BoringModel, run=False, auto_registry=False)
# no registered classes so not added automatically
assert "data" not in cli.parser.groups
assert len(DATAMODULE_REGISTRY) # check state was not modified
@@ -1014,6 +1036,8 @@

@pytest.mark.parametrize("use_class_path_callbacks", [False, True])
def test_registries_resolution(use_class_path_callbacks):
+    MODEL_REGISTRY(cls=BoringModel)

"""This test validates registries are used when simplified command line are being used."""
cli_args = [
"--optimizer",
@@ -1070,6 +1094,7 @@ def test_argv_transformation_single_callback():
}
]
expected = base + ["--trainer.callbacks", str(callbacks)]
+    _populate_registries(False)
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input)
assert argv == expected

@@ -1093,6 +1118,7 @@
},
]
expected = base + ["--trainer.callbacks", str(callbacks)]
+    _populate_registries(False)
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input)
assert argv == expected

Expand Down Expand Up @@ -1120,6 +1146,7 @@ def test_argv_transformation_multiple_callbacks_with_config():
]
expected = base + ["--trainer.callbacks", str(callbacks)]
nested_key = "trainer.callbacks"
+    _populate_registries(False)
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, nested_key, input)
assert argv == expected

@@ -1156,6 +1183,7 @@ def test_argv_transformation_multiple_callbacks_with_config():
def test_argv_transformations_with_optimizers_and_lr_schedulers(args, expected, nested_key, registry):
base = ["any.py", "--trainer.max_epochs=1"]
argv = base + args
+    _populate_registries(False)
new_argv = LightningArgumentParser._convert_argv_issue_84(registry.classes, nested_key, argv)
assert new_argv == base + [f"--{nested_key}", str(expected)]
