Skip to content

Commit

Permalink
rename verbose -> detail and extra logging
Browse files Browse the repository at this point in the history
  • Loading branch information
edward-io committed Dec 23, 2021
1 parent 63910a5 commit bb72db5
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 20 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Expand Up @@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Add new `VERBOSE` log level to provide useful logs for production use case monitoring and debugging
- Add new `DETAIL` log level to provide useful logs for improving monitoring and debugging of batch jobs


- Added a flag `SLURMEnvironment(auto_requeue=True|False)` to control whether Lightning handles the requeuing ([#10601](https://github.com/PyTorchLightning/pytorch-lightning/issues/10601))
Expand Down
15 changes: 8 additions & 7 deletions pytorch_lightning/__init__.py
@@ -1,20 +1,21 @@
"""Root package info."""

import logging
from typing import Any

from pytorch_lightning.__about__ import * # noqa: F401, F403

VERBOSE = 15 # between logging.INFO and logging.DEBUG, used for logging in production use cases
DETAIL = 15 # between logging.INFO and logging.DEBUG, used for logging in production use cases


def verbose(self, message, *args, **kws):
if self.isEnabledFor(VERBOSE):
self._log(VERBOSE, message, args, **kws)
def detail(self: Any, message: str, *args: Any, **kwargs: Any) -> None:
    """Log ``message`` on this logger at the custom ``DETAIL`` level.

    The record is only created when the logger is enabled for ``DETAIL``.

    Args:
        self: the logger to emit on (this function is attached to ``logging.Logger`` as a method).
        message: the message, optionally a %-style format string.
        *args: values merged into ``message`` by the logging machinery.
        **kwargs: keyword arguments forwarded unchanged to ``Logger._log``.
    """
    if self.isEnabledFor(DETAIL):
        # ``Logger._log`` expects the format arguments as ONE tuple parameter. Unpacking them
        # (``*args``) raises ``TypeError`` for calls without format arguments and shifts them
        # into the wrong parameters otherwise, so the tuple is passed through as-is.
        self._log(DETAIL, message, args, **kwargs)


logging.addLevelName(VERBOSE, "VERBOSE")
logging.verbose = verbose
logging.Logger.verbose = verbose
logging.addLevelName(DETAIL, "DETAIL")
logging.detail = detail
logging.Logger.detail = detail

_root_logger = logging.getLogger()
_logger = logging.getLogger(__name__)
Expand Down
3 changes: 3 additions & 0 deletions pytorch_lightning/loops/fit_loop.py
Expand Up @@ -204,6 +204,7 @@ def on_advance_start(self) -> None: # type: ignore[override]

# reset train dataloader
if not self._is_fresh_start_epoch and self.trainer._should_reload_dl_epoch:
log.detail("resetting train dataloader")
self.trainer.reset_train_dataloader(model)
self._is_fresh_start_epoch = False

Expand All @@ -223,6 +224,7 @@ def on_advance_start(self) -> None: # type: ignore[override]

def advance(self) -> None: # type: ignore[override]
"""Runs one whole epoch."""
log.detail("advancing fit loop")
dataloader = self.trainer.strategy.process_dataloader(self.trainer.train_dataloader)
data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader)

Expand All @@ -243,6 +245,7 @@ def on_advance_end(self) -> None:

def on_run_end(self) -> None:
"""Calls the ``on_train_end`` hook."""
log.detail("fit loop ended")
# NOTE: the current_epoch is already incremented
# Lightning today does not increment the current epoch at the last epoch run in Trainer.fit
# To simulate that current behavior, we decrement here.
Expand Down
18 changes: 11 additions & 7 deletions pytorch_lightning/strategies/ddp.py
Expand Up @@ -102,7 +102,7 @@ def __init__(
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)
log.verbose(f"Initializing DDP: {self.__class__.__name__}")
log.detail(f"{self.__class__.__name__}: initializing DDP plugin")
self.interactive_ddp_procs = []
self._num_nodes = 1
self.sync_batchnorm = False
Expand Down Expand Up @@ -172,7 +172,9 @@ def setup(self, trainer: "pl.Trainer") -> None:

def _setup_model(self, model: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)
device_ids = self.determine_ddp_device_ids()
log.detail(f"setting up DDP model with device ids: {device_ids}, kwargs: {self._ddp_kwargs}")
return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)

def _call_children_scripts(self):
# bookkeeping of spawned processes
Expand Down Expand Up @@ -244,7 +246,7 @@ def _call_children_scripts(self):
self._rank_0_has_called_call_children_scripts = True

def setup_distributed(self):
log.verbose(f"{self.__class__.__name__}: setting up distributed...")
log.detail(f"{self.__class__.__name__}: setting up distributed...")
reset_seed()

# determine which process we are and world size
Expand Down Expand Up @@ -290,6 +292,7 @@ def pre_configure_ddp(self):
self._ddp_kwargs["find_unused_parameters"] = True

def _register_ddp_hooks(self) -> None:
log.detail(f"registering ddp hooks")
# In 1.8, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
# Since 1.9, DDP communication hooks can work on all backends.
if _TORCH_GREATER_EQUAL_1_9 or (
Expand All @@ -309,6 +312,7 @@ def _register_ddp_hooks(self) -> None:
self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)

def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
log.detail("reinitializing optimizers with post localSGD")
optimizers = self.lightning_module.trainer.optimizers
if self._model_averaging_period is None:
raise ValueError(
Expand Down Expand Up @@ -352,7 +356,7 @@ def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
trainer.convert_to_lightning_optimizers()

def configure_ddp(self) -> None:
log.verbose(f"{self.__class__.__name__}: configuring DDP...")
log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")
self.pre_configure_ddp()
self.model = self._setup_model(LightningDistributedModule(self.model))
self._register_ddp_hooks()
Expand Down Expand Up @@ -383,7 +387,7 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None:
prepare_for_backward(self.model, closure_loss)

def model_to_device(self):
log.verbose(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
log.detail(f"{self.__class__.__qualname__}: moving model to device [{self.root_device}]...")
self.model.to(self.root_device)

def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor:
Expand Down Expand Up @@ -504,7 +508,7 @@ def reconciliate_processes(self, trace: str) -> None:
raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")

def teardown(self) -> None:
log.verbose(f"{self.__class__.__name__}: tearing down plugin...")
log.detail(f"{self.__class__.__qualname__}: tearing down DDP plugin")
super().teardown()
if isinstance(self.model, DistributedDataParallel):
self.model = self.lightning_module
Expand All @@ -514,7 +518,7 @@ def teardown(self) -> None:

if self.on_gpu:
# GPU teardown
log.verbose(f"{self.__class__.__name__}: moving model to CPU...")
log.detail(f"{self.__class__.__qualname__}: moving model to CPU")
self.lightning_module.cpu()
# clean up memory
torch.cuda.empty_cache()
8 changes: 4 additions & 4 deletions pytorch_lightning/strategies/fully_sharded.py
Expand Up @@ -147,7 +147,7 @@ def setup(self, trainer: "pl.Trainer") -> None:

@contextlib.contextmanager
def model_sharded_context(self) -> Generator:
log.verbose(f"{self.__class__.__name__}: entered model_sharded_context.")
log.detail(f"{self.__class__.__name__}: entered model_sharded_context.")
precision = self.precision_plugin.precision

def wrap_policy(*args, **kwargs):
Expand All @@ -169,10 +169,10 @@ def wrap_policy(*args, **kwargs):
):
yield

log.verbose(f"{self.__class__.__name__}: exiting model_sharded_context.")
log.detail(f"{self.__class__.__name__}: exiting model_sharded_context.")

def configure_ddp(self) -> None:
log.verbose(f"{self.__class__.__name__}: configuring DDP... (cpu_offload: [{self.cpu_offload}])")
log.detail(f"{self.__class__.__name__}: configuring DDP... (cpu_offload: [{self.cpu_offload}])")
if not self.cpu_offload:
# When using CPU Offload, FSDP will manage the CUDA movement for us.
# Note: this would be problematic for large model (which could not fit in one GPU)
Expand All @@ -184,7 +184,7 @@ def configure_ddp(self) -> None:
self.setup_optimizers(self.lightning_module.trainer)

def model_to_device(self) -> None:
log.verbose(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
log.detail(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
# ensure we update the device type in the lightning module
self.lightning_module.to(self.root_device)

Expand Down
8 changes: 8 additions & 0 deletions pytorch_lightning/trainer/callback_hook.py
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from abc import ABC
from copy import deepcopy
from typing import Any, Dict, List, Optional, Type, Union
Expand All @@ -25,6 +26,9 @@
from pytorch_lightning.utilities.types import STEP_OUTPUT


log = logging.getLogger(__name__)


class TrainerCallbackHookMixin(ABC):
r"""
.. deprecated:: v1.6
Expand Down Expand Up @@ -384,6 +388,7 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0):
"`TrainerCallbackHookMixin.on_train_batch_start` was deprecated in v1.6 and will be removed in v1.8."
)
for callback in self.callbacks:
log.detail(f"calling callback {callback.__class__.__name__}.on_train_batch_start")
if is_param_in_hook_signature(callback.on_train_batch_start, "dataloader_idx", explicit=True):
callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx, 0)
else:
Expand All @@ -401,6 +406,7 @@ def on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_
"`TrainerCallbackHookMixin.on_train_batch_end` was deprecated in v1.6 and will be removed in v1.8."
)
for callback in self.callbacks:
log.detail(f"calling callback {callback.__class__.__name__}.on_train_batch_end")
if is_param_in_hook_signature(callback.on_train_batch_end, "dataloader_idx", explicit=True):
callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, 0)
else:
Expand Down Expand Up @@ -597,6 +603,7 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
)
callback_states = {}
for callback in self.callbacks:
log.detail(f"calling callback {callback.__class__.__name__}.on_save_checkpoint")
state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
if state:
callback_states[callback.state_key] = state
Expand Down Expand Up @@ -631,6 +638,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
)

for callback in self.callbacks:
log.detail(f"calling callback {callback.__class__.__name__}.on_load_checkpoint")
state = callback_states.get(callback.state_key, callback_states.get(callback._legacy_state_key))
if state:
state = deepcopy(state)
Expand Down
Expand Up @@ -78,7 +78,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
self.resume_checkpoint_path = self._hpc_resume_path or self._fault_tolerant_auto_resume_path or checkpoint_path
checkpoint_path = self.resume_checkpoint_path
if not checkpoint_path:
log.verbose("`checkpoint_path` not specified. Skipping checkpoint loading.")
log.detail("`checkpoint_path` not specified. Skipping checkpoint loading.")
return

rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}")
Expand Down
13 changes: 13 additions & 0 deletions pytorch_lightning/trainer/trainer.py
Expand Up @@ -429,6 +429,7 @@ def __init__(
"""
super().__init__()
Trainer._log_api_event("init")
log.detail(f"{self.__class__.__name__}: Initializing trainer with parameters: {locals()}")
self.state = TrainerState()

gpu_ids, tpu_cores = self._parse_devices(gpus, auto_select_gpus, tpu_cores)
Expand Down Expand Up @@ -740,6 +741,7 @@ def _fit_impl(
ckpt_path: Optional[str] = None,
) -> None:
Trainer._log_api_event("fit")
log.detail("fitting")

self.state.fn = TrainerFn.FITTING
self.state.status = TrainerStatus.RUNNING
Expand Down Expand Up @@ -815,6 +817,7 @@ def _validate_impl(
# SETUP HOOK
# --------------------
Trainer._log_api_event("validate")
log.detail("validating")

self.state.fn = TrainerFn.VALIDATING
self.state.status = TrainerStatus.RUNNING
Expand Down Expand Up @@ -900,6 +903,7 @@ def _test_impl(
# SETUP HOOK
# --------------------
Trainer._log_api_event("test")
log.detail("testing")

self.state.fn = TrainerFn.TESTING
self.state.status = TrainerStatus.RUNNING
Expand Down Expand Up @@ -986,6 +990,7 @@ def _predict_impl(
# SETUP HOOK
# --------------------
Trainer._log_api_event("predict")
log.detail("predicting")

self.state.fn = TrainerFn.PREDICTING
self.state.status = TrainerStatus.RUNNING
Expand Down Expand Up @@ -1101,19 +1106,23 @@ def _run(
verify_loop_configurations(self)

# hook
log.detail("preparing data")
self._data_connector.prepare_data()

# ----------------------------
# SET UP TRAINING
# ----------------------------
self._call_callback_hooks("on_before_accelerator_backend_setup")
log.detail("setting up strategy environment")
self.strategy.setup_environment()
self._call_setup_hook() # allow user to setup lightning_module in accelerator environment

# check if we should delay restoring checkpoint till later
if not self.strategy.restore_checkpoint_after_setup:
log.detail(f"restoring module and callbacks from checkpoint path: {ckpt_path}")
self._restore_modules_and_callbacks(ckpt_path)

log.detail("configuring sharded model")
self._call_configure_sharded_model() # allow user to setup in model sharded environment

# ----------------------------
Expand Down Expand Up @@ -1159,14 +1168,17 @@ def _run(
self._log_hyperparams()

if self.strategy.restore_checkpoint_after_setup:
log.detail(f"restoring module and callbacks from checkpoint path: {ckpt_path}")
self._restore_modules_and_callbacks(ckpt_path)

# restore optimizers, etc.
log.detail("restoring training state")
self.checkpoint_connector.restore_training_state()

self.checkpoint_connector.resume_end()

results = self._run_stage()
log.detail("tearing down trainer")
self._teardown()

# ----------------------------
Expand All @@ -1177,6 +1189,7 @@ def _run(
self._call_callback_hooks("on_fit_end")
self._call_lightning_module_hook("on_fit_end")

log.detail("calling teardown hooks")
self._call_teardown_hook()

if self.state.status != TrainerStatus.INTERRUPTED:
Expand Down

0 comments on commit bb72db5

Please sign in to comment.