Refactor codebase to use the loggers property instead of logger #11731

Closed

Changes from all commits
89 commits
2883fb2
Implement logger property, replace self.logger
akashkw Feb 1, 2022
3ae0a6e
Fix small bugs
akashkw Feb 1, 2022
6d917a4
Fixed initalization bug
akashkw Feb 1, 2022
30d3a57
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 1, 2022
62c9335
Fix for case where creating a LoggerCollection of size 1
akashkw Feb 1, 2022
99a4d98
Better fix for the LoggerCollection of size 1 issue
akashkw Feb 1, 2022
01287ba
Change trainer.loggers from a property to an instance variable
akashkw Feb 1, 2022
41cbab0
Revert all instances of trainer.loggers being used
akashkw Feb 1, 2022
d78d98f
Use logger param to initialize trainer.loggers
akashkw Feb 1, 2022
7f809ff
Remove unneeded newlines
akashkw Feb 1, 2022
179f3d4
Implement unit test for loggers property
akashkw Feb 1, 2022
1ad542a
make trainer.loggers by default an empty list
akashkw Feb 1, 2022
c5df0ad
Update changelog
akashkw Feb 1, 2022
9d564d8
fix unit test according to suggestions
akashkw Feb 1, 2022
9d7c1bf
Update CHANGELOG.md
akashkw Feb 1, 2022
d263659
Remove unnecessary Trainer params
akashkw Feb 2, 2022
befad11
Remove tmpdir parameter for unit test
akashkw Feb 2, 2022
4773c26
Write setters for logger and loggers
akashkw Feb 2, 2022
8871c36
Unit test for setters
akashkw Feb 2, 2022
65ae649
Fix bug where logger setter is called twice
akashkw Feb 2, 2022
df992de
Fix initialization bug with trainer test
akashkw Feb 2, 2022
5e47ef4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 2, 2022
d55f626
Get rid of extra DummyLogger assignment
akashkw Feb 2, 2022
3debee7
Merge branch 'refactor/create-loggers-property' of github.com:akashkw…
akashkw Feb 2, 2022
29778b1
flake and mypy fixes
akashkw Feb 2, 2022
506e5fd
Flake fix did not commit properly
akashkw Feb 2, 2022
144169b
Small changes based on suggestions
akashkw Feb 2, 2022
bdcbcfb
Shorten setters and update unit test
akashkw Feb 2, 2022
d42ce90
Move unit test to a new file
akashkw Feb 2, 2022
4001efc
flake and mypy fixes
akashkw Feb 2, 2022
9401fb7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 2, 2022
e2be787
Refactor setter to handle special case of size 1 LoggerCollection
akashkw Feb 3, 2022
1e71c40
Remove DummyLogger changes
akashkw Feb 3, 2022
4418f10
Merge branch 'master' into refactor/create-loggers-property
akashkw Feb 3, 2022
c96534b
Commit suggestion to change None to []
akashkw Feb 3, 2022
6a6e2ef
Merge branch 'master' into refactor/create-loggers-property
akashkw Feb 3, 2022
7fef6fd
Decouple both setters for readability
akashkw Feb 3, 2022
f2a598f
Fix merge conflicts
akashkw Feb 3, 2022
a669a98
Fix tiny bug in trainer.loggers setter
akashkw Feb 3, 2022
b6d6fba
Merge branch 'master' into refactor/create-loggers-property
akashkw Feb 3, 2022
8714857
Add loggers property to lightning.py
akashkw Feb 3, 2022
a7a47f1
update changelog
akashkw Feb 3, 2022
15fa585
Add logger property to docs
akashkw Feb 3, 2022
3f195be
Fix typo
akashkw Feb 3, 2022
c944f9a
update trainer.rst
akashkw Feb 3, 2022
e3312d5
correct spacing
akashkw Feb 3, 2022
88b3c24
remove typing for now
akashkw Feb 3, 2022
89f5037
Merge branch 'master' into refactor/create-loggers-property
akashkw Feb 3, 2022
85256e7
Fix jit unused issue
akashkw Feb 3, 2022
7ab3683
Fix underlines in docs
akashkw Feb 3, 2022
e390c0d
convert some unit tests
akashkw Feb 4, 2022
aeaa448
Modify more tests
akashkw Feb 4, 2022
c51f336
Updates based on suggestions
akashkw Feb 4, 2022
820ce32
Merge branch 'refactor/create-loggers-property' into refactor/switch-…
akashkw Feb 4, 2022
b2bd5a1
refactor two more tests
akashkw Feb 4, 2022
ec99acd
More updates to docs based on suggestions
akashkw Feb 4, 2022
ba09e27
Create unit test for lightningmodule loggers property
akashkw Feb 4, 2022
bc6fd72
Replace Mock with Trainer
akashkw Feb 4, 2022
17895f3
Merge branch 'refactor/create-loggers-property' into refactor/switch-…
akashkw Feb 4, 2022
8278519
update all the logger tests with loggers property
akashkw Feb 4, 2022
40c3c73
update the last of the unit tests to use loggers
akashkw Feb 4, 2022
d23cac8
Bugfixes for unit tests
akashkw Feb 4, 2022
f5b492d
Update types
akashkw Feb 4, 2022
8800ec8
Remove list cast
akashkw Feb 4, 2022
15b2644
Merge branch 'refactor/create-loggers-property' into refactor/switch-…
akashkw Feb 4, 2022
5852ee2
Remove unit test for unsupported behavior
akashkw Feb 4, 2022
9a29e3f
Merge branch 'refactor/create-loggers-property' into refactor/switch-…
akashkw Feb 4, 2022
701d4fa
Add TODO for tests that can't be easily adapted
akashkw Feb 4, 2022
ac84709
Change TODO messages
akashkw Feb 4, 2022
0752fb5
Handle special case for setter
akashkw Feb 4, 2022
46b5ae1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2022
39a85da
Merge branch 'refactor/create-loggers-property' into refactor/switch-…
akashkw Feb 4, 2022
15ae1f6
Merge branch 'refactor/create-loggers-property' into refactor/switch-…
akashkw Feb 4, 2022
023802f
refactor logger_connector, batch_size_scaling, and lr_finder
akashkw Feb 4, 2022
7b8a79f
update signal_connector and trainer.py
akashkw Feb 4, 2022
b97724c
small bugfix
akashkw Feb 4, 2022
d186a7f
Fix handling of special case
akashkw Feb 4, 2022
68eaf65
refactor all logger files
akashkw Feb 5, 2022
be12180
Modify more files
akashkw Feb 5, 2022
83339b2
Update docs
akashkw Feb 5, 2022
ffbf225
Fix flake8 error
akashkw Feb 5, 2022
07707a6
Restore broken file
akashkw Feb 5, 2022
70d2461
Add TODO to broken file
akashkw Feb 5, 2022
5c5d87c
Review and resolve a few TODO's
akashkw Feb 7, 2022
d30bb02
Sanity check on a few more files
akashkw Feb 7, 2022
dd32d6a
rewrite log_dir
akashkw Feb 7, 2022
b5fa1d8
Revise unit test
akashkw Feb 7, 2022
87a8eb0
flake fixes
akashkw Feb 7, 2022
9b3fa48
resolve merge conflict
akashkw Feb 7, 2022
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -83,6 +83,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a `MisconfigurationException` if user provided `opt_idx` in scheduler config doesn't match with actual optimizer index of its respective optimizer ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247))


- Added a `loggers` property to `Trainer` which returns a list of loggers provided by the user ([#11683](https://github.com/PyTorchLightning/pytorch-lightning/pull/11683))


- Added a `loggers` property to `LightningModule` which retrieves the `loggers` property from `Trainer` ([#11683](https://github.com/PyTorchLightning/pytorch-lightning/pull/11683))


- Added support for DDP when using a `CombinedLoader` for the training data ([#11648](https://github.com/PyTorchLightning/pytorch-lightning/pull/11648))


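For orientation before the docs and code diffs below, a minimal sketch (not part of the diff) of how the two properties described in these changelog entries are meant to be read; the logger choices and the metric name are illustrative only:

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

# `logger` still accepts a single logger, a list of loggers, or a bool,
# while the new `trainer.loggers` is always a plain list (possibly empty).
trainer = Trainer(logger=[TensorBoardLogger("logs/"), CSVLogger("logs/")])

for logger in trainer.loggers:
    logger.log_metrics({"fake_metric": 1.0}, step=0)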
13 changes: 12 additions & 1 deletion docs/source/common/lightning_module.rst
@@ -952,7 +952,7 @@ The current step (does not reset each epoch)
.. code-block:: python

def training_step(self, batch, batch_idx):
self.logger.experiment.log_image(..., step=self.global_step)
self.loggers[logger_index].experiment.log_image(..., step=self.global_step)

hparams
~~~~~~~
@@ -983,6 +983,17 @@ The current logger being used (tensorboard or other supported logger)
# the particular logger
tensorboard_logger = self.logger.experiment

loggers
~~~~~~~

The list of loggers currently being used.

.. code-block:: python

def training_step(self, batch, batch_idx):
# List of LightningLoggerBase objects
self.loggers

local_rank
~~~~~~~~~~~

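As a companion to the loggers entry just added above, a hedged sketch of iterating self.loggers from inside a hook; the metric is illustrative, and the single-logger equivalence noted in the comment is the behaviour this PR aims for rather than something the snippet verifies:

from pytorch_lightning import LightningModule


class MyModule(LightningModule):
    def training_step(self, batch, batch_idx):
        # `self.loggers` is the list of LightningLoggerBase objects from the Trainer;
        # with a single configured logger the intent is self.loggers == [self.logger].
        for logger in self.loggers:
            logger.log_metrics({"batch_idx": float(batch_idx)}, step=self.global_step)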
17 changes: 7 additions & 10 deletions docs/source/common/loggers.rst
@@ -64,7 +64,7 @@ The :class:`~pytorch_lightning.loggers.CometLogger` is available anywhere except
class MyModule(LightningModule):
def any_lightning_module_function_or_hook(self):
some_img = fake_image()
self.logger.experiment.add_image("generated_images", some_img, 0)
self.loggers[logger_index].experiment.add_image("generated_images", some_img, 0)

.. seealso::
:class:`~pytorch_lightning.loggers.CometLogger` docs.
@@ -135,9 +135,9 @@ The :class:`~pytorch_lightning.loggers.NeptuneLogger` is available anywhere exce
def any_lightning_module_function_or_hook(self):
# generic recipe for logging custom metadata (neptune specific)
metadata = ...
self.logger.experiment["your/metadata/structure"].log(metadata)
self.loggers[logger_index].experiment["your/metadata/structure"].log(metadata)

Note that syntax: ``self.logger.experiment["your/metadata/structure"].log(metadata)``
Note that syntax: ``self.loggers[logger_index].experiment["your/metadata/structure"].log(metadata)``
is specific to Neptune and it extends logger capabilities.
Specifically, it allows you to log various types of metadata like scores, files,
images, interactive visuals, CSVs, etc. Refer to the
@@ -173,7 +173,7 @@ The :class:`~pytorch_lightning.loggers.TensorBoardLogger` is available anywhere
class MyModule(LightningModule):
def any_lightning_module_function_or_hook(self):
some_img = fake_image()
self.logger.experiment.add_image("generated_images", some_img, 0)
self.loggers[logger_index].experiment.add_image("generated_images", some_img, 0)

.. seealso::
:class:`~pytorch_lightning.loggers.TensorBoardLogger` docs.
@@ -213,9 +213,9 @@ The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except
def any_lightning_module_function_or_hook(self):
some_img = fake_image()
# Option 1
self.logger.experiment.log({"generated_images": [wandb.Image(some_img, caption="...")]})
self.loggers[logger_index].experiment.log({"generated_images": [wandb.Image(some_img, caption="...")]})
# Option 2 for specifically logging images
self.logger.log_image(key="generated_images", images=[some_img])
self.loggers[logger_index].log_image(key="generated_images", images=[some_img])

.. seealso::
- :class:`~pytorch_lightning.loggers.WandbLogger` docs.
@@ -246,7 +246,4 @@ The loggers are available as a list anywhere except ``__init__`` in your
class MyModule(LightningModule):
def any_lightning_module_function_or_hook(self):
some_img = fake_image()
# Option 1
self.logger.experiment[0].add_image("generated_images", some_img, 0)
# Option 2
self.logger[0].experiment.add_image("generated_images", some_img, 0)
self.loggers[logger_index].experiment.add_image("generated_images", some_img, 0)
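One way to avoid hard-coding a logger_index as in the snippets above is to select a logger by type. This is an illustrative sketch only: it assumes a TensorBoardLogger was passed to the Trainer, and fake_image stands in for the docs' placeholder helper:

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.loggers import TensorBoardLogger


def fake_image() -> torch.Tensor:
    # stand-in for the docs' fake_image() helper: a random CHW image tensor
    return torch.rand(3, 64, 64)


class MyModule(LightningModule):
    def any_lightning_module_function_or_hook(self):
        some_img = fake_image()
        # Iterate the list and target only the logger types we care about.
        for logger in self.loggers:
            if isinstance(logger, TensorBoardLogger):
                logger.experiment.add_image("generated_images", some_img, 0)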
18 changes: 15 additions & 3 deletions docs/source/common/trainer.rst
@@ -1741,9 +1741,21 @@ The current logger being used. Here's an example using tensorboard

.. code-block:: python

def training_step(self, batch, batch_idx):
logger = self.trainer.logger
tensorboard = logger.experiment
logger = trainer.logger
tensorboard = logger.experiment


loggers (p)
***********

The list of loggers currently being used.

.. code-block:: python

# List of LightningLoggerBase objects
loggers = trainer.loggers
for logger in loggers:
logger.log_metrics({"foo": 1.0})


logged_metrics
5 changes: 3 additions & 2 deletions docs/source/extensions/logging.rst
@@ -192,7 +192,7 @@ If you want to log anything that is not a scalar, like histograms, text, images,
def training_step(self):
...
# the logger you used (in this case tensorboard)
tensorboard = self.logger.experiment
tensorboard = self.loggers[logger_index].experiment
tensorboard.add_image()
tensorboard.add_histogram(...)
tensorboard.add_figure(...)
@@ -378,7 +378,8 @@ in the `hparams tab <https://pytorch.org/docs/stable/tensorboard.html#torch.util

# Using custom or multiple metrics (default_hp_metric=False)
def on_train_start(self):
self.logger.log_hyperparams(self.hparams, {"hp/metric_1": 0, "hp/metric_2": 0})
for logger in self.loggers:
logger.log_hyperparams(self.hparams, {"hp/metric_1": 0, "hp/metric_2": 0})


def validation_step(self, batch, batch_idx):
2 changes: 1 addition & 1 deletion docs/source/starter/introduction_guide.rst
@@ -492,7 +492,7 @@ You can also use any method of your logger directly:
.. code-block:: python

def training_step(self, batch, batch_idx):
tensorboard = self.logger.experiment
tensorboard = self.loggers[logger_index].experiment
tensorboard.any_summary_writer_method_you_want()

Once your training starts, you can view the logs by using your favorite logger or booting up the Tensorboard logs:
@@ -206,7 +206,8 @@ def on_train_epoch_end(self):
# log sampled images
sample_imgs = self(z)
grid = torchvision.utils.make_grid(sample_imgs)
self.logger.experiment.add_image("generated_images", grid, self.current_epoch)
for logger in self.loggers:
logger.experiment.add_image("generated_images", grid, self.current_epoch)


def main(args: Namespace) -> None:
20 changes: 11 additions & 9 deletions pytorch_lightning/callbacks/device_stats_monitor.py
@@ -44,7 +44,7 @@ class DeviceStatsMonitor(Callback):
"""

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use DeviceStatsMonitor callback with Trainer that has no logger.")

def on_train_batch_start(
@@ -55,17 +55,18 @@ def on_train_batch_start(
batch_idx: int,
unused: Optional[int] = 0,
) -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")

if not trainer.logger_connector.should_update_logs:
return

device = trainer.strategy.root_device
device_stats = trainer.accelerator.get_device_stats(device)
separator = trainer.logger.group_separator
prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator)
trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
for logger in trainer.loggers:
separator = logger.group_separator
prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator)
logger.log_metrics(prefixed_device_stats, step=trainer.global_step)

def on_train_batch_end(
self,
@@ -76,17 +77,18 @@ def on_train_batch_end(
batch_idx: int,
unused: Optional[int] = 0,
) -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")

if not trainer.logger_connector.should_update_logs:
return

device = trainer.strategy.root_device
device_stats = trainer.accelerator.get_device_stats(device)
separator = trainer.logger.group_separator
prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator)
trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
for logger in trainer.loggers:
separator = logger.group_separator
prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator)
logger.log_metrics(prefixed_device_stats, step=trainer.global_step)


def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
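The same loop-over-loggers shape recurs in the callback diffs above and below. As a reference point, here is a small self-contained sketch of the before/after pattern; log_device_stats is an illustrative name, while _prefix_metric_keys mirrors the helper in device_stats_monitor.py:

from typing import Dict

import pytorch_lightning as pl


def _prefix_metric_keys(metrics: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
    # Mirrors the DeviceStatsMonitor helper: e.g. "on_train_batch_end" + "/" + key.
    return {prefix + separator + k: v for k, v in metrics.items()}


def log_device_stats(trainer: "pl.Trainer", device_stats: Dict[str, float], hook_name: str) -> None:
    # Before: trainer.logger.log_metrics(...), guarded by `if not trainer.logger`.
    # After: iterate trainer.loggers, so zero, one, or many loggers all take the
    # same code path, and each logger keeps its own group separator.
    for logger in trainer.loggers:
        prefixed = _prefix_metric_keys(device_stats, hook_name, logger.group_separator)
        logger.log_metrics(prefixed, step=trainer.global_step)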
12 changes: 7 additions & 5 deletions pytorch_lightning/callbacks/gpu_stats_monitor.py
@@ -124,7 +124,7 @@ def __init__(
self._gpu_ids: List[str] = [] # will be assigned later in setup()

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use GPUStatsMonitor callback with Trainer that has no logger.")

if trainer._device_type != _AcceleratorType.GPU:
@@ -162,8 +162,9 @@ def on_train_batch_start(
# First log at beginning of second step
logs["batch_time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000

assert trainer.logger is not None
trainer.logger.log_metrics(logs, step=trainer.global_step)
assert trainer.loggers
for logger in trainer.loggers:
logger.log_metrics(logs, step=trainer.global_step)

@rank_zero_only
def on_train_batch_end(
@@ -187,8 +188,9 @@ def on_train_batch_end(
if self._log_stats.intra_step_time and self._snap_intra_step_time:
logs["batch_time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000

assert trainer.logger is not None
trainer.logger.log_metrics(logs, step=trainer.global_step)
assert trainer.loggers
for logger in trainer.loggers:
logger.log_metrics(logs, step=trainer.global_step)

@staticmethod
def _get_gpu_ids(device_ids: List[int]) -> List[str]:
12 changes: 7 additions & 5 deletions pytorch_lightning/callbacks/lr_monitor.py
@@ -104,7 +104,7 @@ def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> No
MisconfigurationException:
If ``Trainer`` has no ``logger``.
"""
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException(
"Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger."
)
@@ -149,7 +149,7 @@ def _check_no_key(key: str) -> bool:
self.last_momentum_values = {name + "-momentum": None for name in names_flatten}

def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
assert trainer.logger is not None
assert trainer.loggers
if not trainer.logger_connector.should_update_logs:
return

@@ -158,16 +158,18 @@ def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any)
latest_stat = self._extract_stats(trainer, interval)

if latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
for logger in trainer.loggers:
logger.log_metrics(latest_stat, step=trainer.global_step)

def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
assert trainer.logger is not None
assert trainer.loggers
if self.logging_interval != "step":
interval = "epoch" if self.logging_interval is None else "any"
latest_stat = self._extract_stats(trainer, interval)

if latest_stat:
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
for logger in trainer.loggers:
logger.log_metrics(latest_stat, step=trainer.global_step)

def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]:
latest_stat = {}
15 changes: 10 additions & 5 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -385,8 +385,9 @@ def save_checkpoint(self, trainer: "pl.Trainer") -> None:
self._save_last_checkpoint(trainer, monitor_candidates)

# notify loggers
if trainer.is_global_zero and trainer.logger:
trainer.logger.after_save_checkpoint(proxy(self))
if trainer.is_global_zero:
for logger in trainer.loggers:
logger.after_save_checkpoint(proxy(self))

def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
from pytorch_lightning.trainer.states import TrainerFn
@@ -580,18 +581,22 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
if self.dirpath is not None:
return # short circuit

if trainer.logger is not None:
if trainer.loggers:
if trainer.weights_save_path != trainer.default_root_dir:
# the user has changed weights_save_path, it overrides anything
save_dir = trainer.weights_save_path
else:
save_dir = trainer.logger.save_dir or trainer.default_root_dir

if len(trainer.loggers) == 1 and trainer.loggers[0].save_dir:
save_dir = trainer.loggers[0].save_dir
else:
save_dir = trainer.default_root_dir
# TODO: Find out how we handle trainer.logger.version without LoggerCollection

Contributor Author commented: How should we handle trainer.logger.version and trainer.logger.name here? LoggerCollection will concatenate the values from each logger, but with trainer.loggers we no longer have that logic, and that logic might not be the best approach anyway.

version = (
trainer.logger.version
if isinstance(trainer.logger.version, str)
else f"version_{trainer.logger.version}"
)
# TODO: Find out how we handle trainer.logger.name without LoggerCollection
ckpt_path = os.path.join(save_dir, str(trainer.logger.name), version, "checkpoints")
else:
ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")
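On the two TODOs and the inline question above: one hypothetical way to derive version and name from trainer.loggers without LoggerCollection's concatenation is to fall back to the first logger. This is only a sketch of the trade-off being discussed, not the behaviour this PR settles on; _resolve_ckpt_dir_sketch and its arguments are illustrative names:

import os

import pytorch_lightning as pl


def _resolve_ckpt_dir_sketch(trainer: "pl.Trainer", save_dir: str) -> str:
    # Hypothetical: take name/version from the first logger instead of the
    # concatenated values a LoggerCollection used to produce.
    if trainer.loggers:
        first = trainer.loggers[0]
        version = first.version if isinstance(first.version, str) else f"version_{first.version}"
        return os.path.join(save_dir, str(first.name), version, "checkpoints")
    return os.path.join(trainer.weights_save_path, "checkpoints")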
1 change: 1 addition & 0 deletions pytorch_lightning/callbacks/progress/base.py
@@ -213,6 +213,7 @@ def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule")
if pl_module.truncated_bptt_steps > 0:
items_dict["split_idx"] = trainer.fit_loop.split_idx

# TODO: Find out how we handle trainer.logger.version without LoggerCollection
if trainer.logger is not None and trainer.logger.version is not None:
version = trainer.logger.version
if isinstance(version, str):
13 changes: 7 additions & 6 deletions pytorch_lightning/callbacks/xla_stats_monitor.py
@@ -69,7 +69,7 @@ def __init__(self, verbose: bool = True) -> None:
self._verbose = verbose

def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")

if trainer._device_type != _AcceleratorType.TPU:
@@ -87,7 +87,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
self._start_time = time.time()

def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if not trainer.logger:
if not trainer.loggers:
raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")

device = trainer.strategy.root_device
@@ -101,10 +101,11 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
peak_memory = trainer.strategy.reduce(peak_memory) * 0.001
epoch_time = trainer.strategy.reduce(epoch_time)

trainer.logger.log_metrics(
{"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
step=trainer.current_epoch,
)
for logger in trainer.loggers:
logger.log_metrics(
{"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
step=trainer.current_epoch,
)

if self._verbose:
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
2 changes: 1 addition & 1 deletion pytorch_lightning/core/hooks.py
@@ -285,7 +285,7 @@ def on_before_optimizer_step(self, optimizer, optimizer_idx):
# example to inspect gradient information in tensorboard
if self.trainer.global_step % 25 == 0: # don't make the tf file huge
for k, v in self.named_parameters():
self.logger.experiment.add_histogram(
self.loggers[logger_index].experiment.add_histogram(
tag=k, values=v.grad, global_step=self.trainer.global_step
)
"""