Update LearningRateMonitor docs and tests for log_weight_decay #19805

Merged
merged 3 commits into from May 21, 2024
4 changes: 3 additions & 1 deletion src/lightning/pytorch/callbacks/lr_monitor.py
@@ -44,6 +44,8 @@ class LearningRateMonitor(Callback):
according to the ``interval`` key of each scheduler. Defaults to ``None``.
log_momentum: option to also log the momentum values of the optimizer, if the optimizer
has the ``momentum`` or ``betas`` attribute. Defaults to ``False``.
log_weight_decay: option to also log the weight decay values of the optimizer. Defaults to
``False``.

Raises:
MisconfigurationException:
@@ -58,7 +60,7 @@ class LearningRateMonitor(Callback):

Logging names are automatically determined based on optimizer class name.
In case of multiple optimizers of same type, they will be named ``Adam``,
- ``Adam-1`` etc. If a optimizer has multiple parameter groups they will
+ ``Adam-1`` etc. If an optimizer has multiple parameter groups they will
be named ``Adam/pg1``, ``Adam/pg2`` etc. To control naming, pass in a
``name`` keyword in the construction of the learning rate schedulers.
A ``name`` keyword can also be used for parameter groups in the
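
For reference, a minimal usage sketch of the documented option. Only LearningRateMonitor and Trainer come from Lightning itself; the commented-out model variable is a placeholder for any LightningModule with a configured optimizer, not something defined in this PR.

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor

# Track the learning rate and, with the new flag, the weight decay of each optimizer.
lr_monitor = LearningRateMonitor(logging_interval="step", log_weight_decay=True)
trainer = Trainer(callbacks=[lr_monitor], max_epochs=1)
# trainer.fit(model)  # `model` is any LightningModule; placeholder, not part of this diff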
3 changes: 3 additions & 0 deletions tests/tests_pytorch/callbacks/test_lr_monitor.py
@@ -44,6 +44,9 @@ def test_lr_monitor_single_lr(tmp_path):

assert lr_monitor.lrs, "No learning rates logged"
assert all(v is None for v in lr_monitor.last_momentum_values.values()), "Momentum should not be logged by default"
assert all(
v is None for v in lr_monitor.last_weight_decay_values.values()
), "Weight decay should not be logged by default"
assert len(lr_monitor.lrs) == len(trainer.lr_scheduler_configs)
assert list(lr_monitor.lrs) == ["lr-SGD"]

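
A complementary test for the enabled flag could look roughly like the sketch below. The test name, the use of BoringModel, and the trainer limits are assumptions for illustration and are not part of this diff.

def test_lr_monitor_single_lr_with_weight_decay(tmp_path):
    """Sketch: with log_weight_decay=True the monitor should record actual values."""
    model = BoringModel()  # Lightning's standard test model, assumed imported in this module
    lr_monitor = LearningRateMonitor(log_weight_decay=True)
    trainer = Trainer(
        default_root_dir=tmp_path,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=0,
        callbacks=[lr_monitor],
    )
    trainer.fit(model)

    assert lr_monitor.last_weight_decay_values, "No weight decay values logged"
    assert all(v is not None for v in lr_monitor.last_weight_decay_values.values())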