[FEAT] Logging in train callbacks #4258

Closed
tchaton wants to merge 42 commits into master from FEATURE/logging_in_train_callbacks

Commits (42) · Diff shown: changes from 4 commits
ac40d1b
release only training callback
tchaton Oct 20, 2020
8231ae3
update for flake8
tchaton Oct 20, 2020
d6eb8d8
resolve `test_multiple_optimizers_manual_return_and_log`
tchaton Oct 20, 2020
7ae6f79
resolve `test_multiple_optimizers_manual_return`
tchaton Oct 20, 2020
9434d11
release only training callback
tchaton Oct 20, 2020
4741258
update for flake8
tchaton Oct 20, 2020
77fae0e
resolve `test_multiple_optimizers_manual_return_and_log`
tchaton Oct 20, 2020
b47b390
resolve `test_multiple_optimizers_manual_return`
tchaton Oct 20, 2020
2a3c72d
Update pytorch_lightning/core/step_result.py
tchaton Oct 21, 2020
542d8d3
Update pytorch_lightning/core/step_result.py
tchaton Oct 21, 2020
290a160
Update pytorch_lightning/core/step_result.py
tchaton Oct 21, 2020
0dfe8c9
Update pytorch_lightning/core/step_result.py
tchaton Oct 21, 2020
104c9f5
remove mixin
tchaton Oct 21, 2020
abe57d3
Merge branch 'FEATURE/logging_in_train_callbacks' of https://github.c…
tchaton Oct 21, 2020
f4e2477
remove explicit mixin
tchaton Oct 21, 2020
f59d10d
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 22, 2020
bddd61d
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 22, 2020
0740f1e
resolve logging bug
tchaton Oct 22, 2020
82fc4fe
repair bug
tchaton Oct 22, 2020
075d5bf
resolve pep8
tchaton Oct 22, 2020
f71e588
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 22, 2020
36dbc96
resolve formatting bug
tchaton Oct 22, 2020
81ca911
Merge branch 'FEATURE/logging_in_train_callbacks' of https://github.c…
tchaton Oct 22, 2020
25242fb
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 22, 2020
1a6b172
check if metric and grad_norm_dic is defined
tchaton Oct 22, 2020
4690ff5
Merge branch 'FEATURE/logging_in_train_callbacks' of https://github.c…
tchaton Oct 22, 2020
c7c1e7d
resolve pep8
tchaton Oct 22, 2020
1dbe60c
resolve typo
tchaton Oct 22, 2020
54e2799
convert metris and grad_norm_dic to dict when None
tchaton Oct 22, 2020
8a8b54a
resolve pep8
tchaton Oct 22, 2020
43da42a
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 23, 2020
e1652bf
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 23, 2020
536f85f
Merge branch 'FEATURE/logging_in_train_callbacks' of https://github.c…
tchaton Oct 23, 2020
18247fa
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 23, 2020
1058e5e
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 23, 2020
6504731
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 24, 2020
e90f2c0
move files ar
tchaton Oct 25, 2020
292af7d
create connector_logger_utils
tchaton Oct 25, 2020
abdbd9f
resolve flake8
tchaton Oct 25, 2020
8d1c924
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 25, 2020
4aa2317
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 25, 2020
e16a358
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 26, 2020
29 changes: 21 additions & 8 deletions pytorch_lightning/core/lightning.py
@@ -11,33 +11,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import collections
import copy
import inspect
import os
import re
import tempfile
from abc import ABC
from argparse import Namespace
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.core.grads import GradInformation
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities.parsing import (
AttributeDict,
collect_init_args,
get_init_args,
)
from pytorch_lightning.callbacks import Callback
from torch import ScriptModule, Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
@@ -111,6 +111,8 @@ def __init__(self, *args, **kwargs):
self._datamodule = None
self._results: Optional[Result] = None
self._current_fx_name = ''
self._current_hook_fx_name = ''
self._current_dataloader_idx = None

def optimizers(self):
opts = self.trainer.optimizers
@@ -244,6 +246,16 @@ def log(
on_step = self.__auto_choose_log_on_step(on_step)
on_epoch = self.__auto_choose_log_on_epoch(on_epoch)

if self._current_hook_fx_name != '':
self.trainer.callback_connector.validate_callback_logging_arguments(self._current_hook_fx_name,
on_step=on_step,
on_epoch=on_epoch)

# the dataloader_idx suffix is appended automatically, so user-provided keys must not contain it
if "/dataloader_idx_" in name:
raise MisconfigurationException(
f"Logged key: {name} should not contain information about dataloader_idx.")

self._results.log(
name,
value,
@@ -257,7 +269,8 @@ def log(
enable_graph,
sync_dist,
sync_dist_op,
sync_dist_group
sync_dist_group,
self._current_dataloader_idx,
)

def log_dict(
@@ -1273,11 +1286,11 @@ def tbptt_split_batch(self, batch, split_size):
batch_split = []
for i, x in enumerate(batch):
if isinstance(x, torch.Tensor):
split_x = x[:, t : t + split_size]
split_x = x[:, t: t + split_size]
elif isinstance(x, collections.Sequence):
split_x = [None] * len(x)
for batch_idx in range(len(x)):
split_x[batch_idx] = x[batch_idx][t : t + split_size]
split_x[batch_idx] = x[batch_idx][t: t + split_size]

batch_split.append(split_x)

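Taken together, the lightning.py changes let a `Callback` log through its host `LightningModule`: the trainer records the running hook in `_current_hook_fx_name`, and `log()` checks the requested `on_step`/`on_epoch` flags against that hook's permission table before forwarding to `Result.log`. A minimal usage sketch under that assumption — the callback class and metric names below are illustrative, not part of this PR:

```python
from pytorch_lightning.callbacks import Callback


class LossReporter(Callback):
    """Illustrative callback that logs via the host LightningModule."""

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        # _on_train_batch_end_log() permits on_step and on_epoch to be
        # either True or False, so this call passes validation.
        pl_module.log("reporter/batch_metric", 1.0, on_step=True, on_epoch=True)

    def on_train_epoch_end(self, trainer, pl_module, *args, **kwargs):
        # _on_train_epoch_end_log() only permits on_step=False; logging
        # with on_step=True here would raise a MisconfigurationException.
        pl_module.log("reporter/epoch_metric", 1.0, on_step=False)
```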
21 changes: 19 additions & 2 deletions pytorch_lightning/core/step_result.py
@@ -109,6 +109,14 @@ def _assert_grad_tensor_metric(self, name: str, x: Union[torch.Tensor, Any], add
m += f' {additional_err}'
assert x.grad_fn is not None, m

def add_dl_idx(self, name: str, dl_idx: Union[None, int]) -> str:
"""
Appends the dataloader index to the logged key automatically when multiple dataloaders are used.
"""
if dl_idx is not None:
name += f"/dataloader_idx_{dl_idx}"
return name

def log(
self,
name: str,
@@ -124,6 +132,7 @@ def log(
sync_dist: bool = False,
sync_dist_op: Union[Any, str] = 'mean',
sync_dist_group: Optional[Any] = None,
current_dataloader_idx: Optional[int] = None,
):
# no metrics should be logged with graphs
if not enable_graph and isinstance(value, torch.Tensor):
@@ -143,7 +152,9 @@
was_forked = True

# set step version
step_name = f'{name}_step'
# possibly add dataloader_idx
step_name = self.add_dl_idx(f'{name}_step', current_dataloader_idx)

self.__set_meta(
step_name,
value,
@@ -156,10 +167,13 @@
tbptt_pad_token=tbptt_pad_token,
forked=False
)

self.__setitem__(step_name, value)

# set epoch version
epoch_name = f'{name}_epoch'
# possibly add dataloader_idx
epoch_name = self.add_dl_idx(f'{name}_epoch', current_dataloader_idx)

self.__set_meta(
epoch_name,
value,
@@ -174,6 +188,9 @@
)
self.__setitem__(epoch_name, value)

# possibly add dataloader_idx
name = self.add_dl_idx(name, current_dataloader_idx)

# always log the original metric
self.__set_meta(
name,
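Net effect of the step_result.py changes: every logged key is forked into `_step` and `_epoch` variants, and all three names gain a `/dataloader_idx_{i}` suffix whenever a dataloader index is set. A standalone sketch of the naming scheme — plain Python mirroring `add_dl_idx`, not the actual `Result` class:

```python
from typing import Optional


def add_dl_idx(name: str, dl_idx: Optional[int]) -> str:
    # Mirrors Result.add_dl_idx above: append the suffix only when
    # a dataloader index is actually set.
    if dl_idx is not None:
        name += f"/dataloader_idx_{dl_idx}"
    return name


name, dl_idx = "val_loss", 1
print(add_dl_idx(f"{name}_step", dl_idx))   # val_loss_step/dataloader_idx_1
print(add_dl_idx(f"{name}_epoch", dl_idx))  # val_loss_epoch/dataloader_idx_1
print(add_dl_idx(name, dl_idx))             # val_loss/dataloader_idx_1
```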
199 changes: 198 additions & 1 deletion pytorch_lightning/trainer/connectors/callback_connector.py
@@ -12,12 +12,209 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from abc import ABC
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class CallbackConnector:
class CallbackConnectorLoggingMixin(ABC):

# This mixin defines, per `pytorch_lightning.callbacks.Callback` hook, which `self.log` arguments a LightningModule may use when logging from that hook.

@staticmethod
def validate_callback_logging_arguments(current_hook_fx_name: str = None, on_step: bool = None, on_epoch: bool = None) -> None:
current_callback_hook_auth_args = getattr(CallbackConnectorLoggingMixin, f"_{current_hook_fx_name}_log")()

if current_callback_hook_auth_args is not None:
m = "{} hook only supports {} in {}, but {} was provided"
if on_step not in current_callback_hook_auth_args["on_step"]:
msg = m.format(current_hook_fx_name, "on_step", current_callback_hook_auth_args["on_step"], on_step)
raise MisconfigurationException(msg)

if on_epoch not in current_callback_hook_auth_args["on_epoch"]:
msg = m.format(current_hook_fx_name, "on_epoch", current_callback_hook_auth_args["on_epoch"], on_epoch)
raise MisconfigurationException(msg)
else:
raise MisconfigurationException(f"{current_hook_fx_name} hook doesn't support logging with self.log() yet.")

@staticmethod
def _setup_log():
"""Called when fit or test begins"""
return None

@staticmethod
def _teardown_log():
"""Called at the end of fit and test"""
return None

@staticmethod
def _on_init_start_log():
"""Called when the trainer initialization begins, model has not yet been set."""
return None

@staticmethod
def _on_init_end_log():
"""Called when the trainer initialization ends, model has not yet been set."""
return None

@staticmethod
def _on_fit_start_log():
"""Called when the trainer initialization begins, model has not yet been set."""
return None

@staticmethod
def _on_fit_end_log():
"""Called when the trainer initialization begins, model has not yet been set."""
return None

@staticmethod
def _on_sanity_check_start_log():
"""Called when the validation sanity check starts."""
return None

@staticmethod
def _on_sanity_check_end_log():
"""Called when the validation sanity check ends."""
return None

@staticmethod
def _on_train_epoch_start_log():
"""Called when the epoch begins."""
return {"on_step" : [False, True], "on_epoch" : [False, True]}

@staticmethod
def _on_train_epoch_end_log():
"""Called when the epoch ends."""
return {"on_step" : [False], "on_epoch" : [False, True]}

@staticmethod
def _on_validation_epoch_start_log():
"""Called when the epoch begins."""
return {"on_step" : [False, True], "on_epoch" : [False, True]}

@staticmethod
def _on_validation_epoch_end_log():
"""Called when the epoch ends."""
return {"on_step" : [False], "on_epoch" : [False, True]}

@staticmethod
def _on_test_epoch_start_log():
"""Called when the epoch begins."""
return {"on_step" : [False, True], "on_epoch" : [False, True]}

@staticmethod
def _on_test_epoch_end_log():
"""Called when the epoch ends."""
return {"on_step" : [False], "on_epoch" : [False, True]}

@staticmethod
def _on_epoch_start_log():
"""Called when the epoch begins."""
return {"on_step" : [False, True], "on_epoch" : [False, True]}

@staticmethod
def _on_epoch_end_log():
"""Called when the epoch ends."""
return {"on_step" : [False], "on_epoch" : [False, True]}

@staticmethod
def _on_train_start_log():
"""Called when the train begins."""
return {"on_step" : [False, True], "on_epoch" : [False, True]}

@staticmethod
def _on_train_end_log():
"""Called when the train ends."""
return None

@staticmethod
def _on_pretrain_routine_start_log():
"""Called when the train begins."""
return None

@staticmethod
def _on_pretrain_routine_end_log():
"""Called when the train ends."""
return None

@staticmethod
def _on_batch_start_log():
"""Called when the training batch begins."""
return {"on_step" : [False, True], "on_epoch" : [False, True]}

@staticmethod
def _on_batch_end_log():
"""Called when the training batch ends."""
return {"on_step" : [False, True], "on_epoch" : [False, True]}

@staticmethod
def _on_train_batch_start_log():
"""Called when the training batch begins."""
return {"on_step" : [False, True], "on_epoch" : [False, True]}

@staticmethod
def _on_train_batch_end_log():
"""Called when the training batch ends."""
return {"on_step" : [False, True], "on_epoch" : [False, True]}

@staticmethod
def _on_validation_batch_start_log():
"""Called when the validation batch begins."""
return {"on_step" : [False, True], "on_epoch" : [False, True]}

@staticmethod
def _on_validation_batch_end_log():
"""Called when the validation batch ends."""
return {"on_step" : [False, True], "on_epoch" : [False, True]}

@staticmethod
def _on_test_batch_start_log():
"""Called when the test batch begins."""
return {"on_step" : [False, True], "on_epoch" : [False, True]}

@staticmethod
def _on_test_batch_end_log():
"""Called when the test batch ends."""
return {"on_step" : [False, True], "on_epoch" : [False, True]}

@staticmethod
def _on_validation_start_log():
"""Called when the validation loop begins."""
return {"on_step" : [False, True], "on_epoch" : [False, True]}

@staticmethod
def _on_validation_end_log():
"""Called when the validation loop ends."""
return {"on_step" : [False], "on_epoch" : [False, True]}

@staticmethod
def _on_test_start_log():
"""Called when the test begins."""
return {"on_step" : [False, True], "on_epoch" : [False, True]}

@staticmethod
def _on_test_end_log():
"""Called when the test ends."""
return None

@staticmethod
def _on_keyboard_interrupt_log():
"""Called when the training is interrupted by KeyboardInterrupt."""
return None

@staticmethod
def _on_save_checkpoint_log():
"""Called when saving a model checkpoint."""
return None

@staticmethod
def _on_load_checkpoint_log():
"""Called when loading a model checkpoint."""
return None


class CallbackConnector(CallbackConnectorLoggingMixin):

def __init__(self, trainer):
self.trainer = trainer
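The `_*_log` staticmethods above form a per-hook permission table: each returns a dict of allowed `on_step`/`on_epoch` values, or `None` when the hook does not support logging. A condensed sketch of the lookup flow, using a hypothetical three-entry table — not the actual pytorch_lightning API:

```python
from typing import Dict, List, Optional

# Hypothetical excerpt of the per-hook permission table defined above.
HOOK_LOG_PERMISSIONS: Dict[str, Optional[Dict[str, List[bool]]]] = {
    "on_train_batch_end": {"on_step": [False, True], "on_epoch": [False, True]},
    "on_train_epoch_end": {"on_step": [False], "on_epoch": [False, True]},
    "on_train_end": None,  # logging is not supported in this hook
}


def validate(hook_name: str, on_step: bool, on_epoch: bool) -> None:
    allowed = HOOK_LOG_PERMISSIONS.get(hook_name)
    if allowed is None:
        raise ValueError(f"{hook_name} doesn't support logging with self.log() yet.")
    for arg_name, value in (("on_step", on_step), ("on_epoch", on_epoch)):
        if value not in allowed[arg_name]:
            raise ValueError(
                f"{hook_name} only supports {arg_name} in {allowed[arg_name]}, got {value}"
            )


validate("on_train_batch_end", on_step=True, on_epoch=True)   # passes
validate("on_train_epoch_end", on_step=False, on_epoch=True)  # passes
# validate("on_train_epoch_end", on_step=True, on_epoch=True)  -> raises ValueError
```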