
Centralize rank_zero_only utilities into their own module #11747

Merged
15 commits merged on Feb 7, 2022
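This PR collapses the scattered rank-zero helpers into a single `pytorch_lightning.utilities.rank_zero` module; the diff below mostly rewrites import statements to point at it. A minimal sketch of the resulting usage pattern (the `log_once` function is a hypothetical example, not code from this PR):

```python
# Centralized import location introduced by this PR:
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only

# Previously the same names were imported from several places, e.g.
# `pytorch_lightning.utilities`, `pytorch_lightning.utilities.distributed`,
# or `pytorch_lightning.utilities.warnings`.


@rank_zero_only
def log_once(message: str) -> None:
    # The decorator runs the body on global rank 0 only; other ranks return None.
    rank_zero_info(message)


log_once("initialized")
```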
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -92,6 +92,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for `Bagua` training strategy ([#11146](https://github.com/PyTorchLightning/pytorch-lightning/pull/11146))


- Added `rank_zero` module to centralize utilities ([#11747](https://github.com/PyTorchLightning/pytorch-lightning/pull/11747))


### Changed

- Implemented a new native and rich format in `_print_results` method of the `EvaluationLoop` ([#11332](https://github.com/PyTorchLightning/pytorch-lightning/pull/11332))
1 change: 1 addition & 0 deletions docs/source/api_references.rst
@@ -289,5 +289,6 @@ Utilities API
memory
model_summary
parsing
rank_zero
seed
warnings
2 changes: 1 addition & 1 deletion docs/source/common/checkpointing.rst
@@ -98,7 +98,7 @@ Lightning automatically ensures that the model is saved only on the main process
trainer.save_checkpoint("example.ckpt")

Not using :meth:`~pytorch_lightning.trainer.trainer.Trainer.save_checkpoint` can lead to unexpected behavior and potential deadlock. Using other saving functions will result in all devices attempting to save the checkpoint. As a result, we highly recommend using the Trainer's save functionality.
If using custom saving functions cannot be avoided, we recommend using the :func:`~pytorch_lightning.utilities.distributed.rank_zero_only` decorator to ensure saving occurs only on the main process. Note that this will only work if all ranks hold the exact same state and won't work when using
If using custom saving functions cannot be avoided, we recommend using the :func:`~pytorch_lightning.utilities.rank_zero.rank_zero_only` decorator to ensure saving occurs only on the main process. Note that this will only work if all ranks hold the exact same state and won't work when using
model parallel distributed strategies such as deepspeed or sharded training.
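
A minimal sketch of that pattern, assuming every rank holds an identical model state (the helper name ``save_on_main`` is hypothetical, not part of the Lightning API):

.. code-block:: python

    import torch

    from pytorch_lightning.utilities.rank_zero import rank_zero_only


    @rank_zero_only
    def save_on_main(model: torch.nn.Module, path: str) -> None:
        # Executes on global rank 0 only; every other rank skips the body and returns None.
        torch.save(model.state_dict(), path)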


2 changes: 1 addition & 1 deletion docs/source/extensions/logging.rst
@@ -205,7 +205,7 @@ Make a Custom Logger
********************

You can implement your own logger by writing a class that inherits from :class:`~pytorch_lightning.loggers.base.LightningLoggerBase`.
Use the :func:`~pytorch_lightning.loggers.base.rank_zero_experiment` and :func:`~pytorch_lightning.utilities.distributed.rank_zero_only` decorators to make sure that only the first process in DDP training creates the experiment and logs the data respectively.
Use the :func:`~pytorch_lightning.loggers.base.rank_zero_experiment` and :func:`~pytorch_lightning.utilities.rank_zero.rank_zero_only` decorators to make sure that only the first process in DDP training creates the experiment and logs the data respectively.
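
A condensed, hypothetical skeleton of such a logger (the class name ``MyLogger`` and its method bodies are placeholders; the full documented example follows in the test code below):

.. code-block:: python

    from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
    from pytorch_lightning.utilities.rank_zero import rank_zero_only


    class MyLogger(LightningLoggerBase):
        @property
        def name(self):
            return "my_logger"

        @property
        def version(self):
            return "0.1"

        @property
        @rank_zero_experiment
        def experiment(self):
            # The wrapped experiment object is created on global rank 0 only.
            return None

        @rank_zero_only
        def log_hyperparams(self, params):
            # Only rank 0 records hyperparameters.
            pass

        @rank_zero_only
        def log_metrics(self, metrics, step=None):
            # Only rank 0 writes metric values.
            pass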

.. testcode::

2 changes: 1 addition & 1 deletion pl_examples/basic_examples/autoencoder.py
@@ -25,9 +25,9 @@
import pytorch_lightning as pl
from pl_examples import _DATASETS_PATH, cli_lightning_logo
from pl_examples.basic_examples.mnist_datamodule import MNIST
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.cli import LightningCLI
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_only

if _TORCHVISION_AVAILABLE:
import torchvision
@@ -58,8 +58,8 @@
from pl_examples import cli_lightning_logo
from pytorch_lightning import LightningDataModule
from pytorch_lightning.callbacks.finetuning import BaseFinetuning
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.cli import LightningCLI
from pytorch_lightning.utilities.rank_zero import rank_zero_info

log = logging.getLogger(__name__)
DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/early_stopping.py
@@ -26,8 +26,8 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_warn

log = logging.getLogger(__name__)

2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/finetuning.py
@@ -26,8 +26,8 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_warn

log = logging.getLogger(__name__)

3 changes: 2 additions & 1 deletion pytorch_lightning/callbacks/gpu_stats_monitor.py
@@ -29,9 +29,10 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import _AcceleratorType, rank_zero_deprecation, rank_zero_only
from pytorch_lightning.utilities import _AcceleratorType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only
from pytorch_lightning.utilities.types import STEP_OUTPUT


2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/lr_monitor.py
@@ -27,8 +27,8 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerConfig


2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
@@ -33,9 +33,9 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/progress/base.py
@@ -15,7 +15,7 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_warn


class ProgressBarBase(Callback):
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/progress/tqdm_progress.py
@@ -27,7 +27,7 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
from pytorch_lightning.utilities.distributed import rank_zero_debug
from pytorch_lightning.utilities.rank_zero import rank_zero_debug

_PAD_SIZE = 5

2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/pruning.py
@@ -30,8 +30,8 @@
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_only

log = logging.getLogger(__name__)

2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -24,8 +24,8 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerConfig

_AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor]
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/timer.py
@@ -24,8 +24,8 @@
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import LightningEnum
from pytorch_lightning.utilities.distributed import rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info

log = logging.getLogger(__name__)

3 changes: 2 additions & 1 deletion pytorch_lightning/callbacks/xla_stats_monitor.py
@@ -22,8 +22,9 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE, rank_zero_deprecation, rank_zero_info
from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info

if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm
4 changes: 1 addition & 3 deletions pytorch_lightning/core/decorators.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn

rank_zero_deprecation(
"Using `pytorch_lightning.core.decorators.parameter_validation` is deprecated in v1.5, "
@@ -22,8 +22,6 @@
from functools import wraps # noqa: E402
from typing import Callable # noqa: E402

from pytorch_lightning.utilities import rank_zero_warn # noqa: E402


def parameter_validation(fn: Callable) -> Callable:
"""Validates that the module parameter lengths match after moving to the device. It is useful when tying
11 changes: 3 additions & 8 deletions pytorch_lightning/core/lightning.py
@@ -38,20 +38,15 @@
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ModelIO
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
from pytorch_lightning.utilities import (
_IS_WINDOWS,
_TORCH_GREATER_EQUAL_1_10,
GradClipAlgorithmType,
rank_zero_deprecation,
rank_zero_warn,
)
from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.distributed import distributed_available, rank_zero_debug, sync_ddp
from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import get_model_size_mb
from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
from pytorch_lightning.utilities.parsing import collect_init_args
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, LRSchedulerTypeUnion, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache
2 changes: 1 addition & 1 deletion pytorch_lightning/core/optimizer.py
@@ -21,9 +21,9 @@
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import _Stateful, LRSchedulerConfig, LRSchedulerTypeTuple, ReduceLROnPlateau


3 changes: 2 additions & 1 deletion pytorch_lightning/core/saving.py
@@ -26,12 +26,13 @@
import torch
import yaml

from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict, rank_zero_warn
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.migration import pl_legacy_patch
from pytorch_lightning.utilities.parsing import parse_class_init_keys
from pytorch_lightning.utilities.rank_zero import rank_zero_warn

log = logging.getLogger(__name__)
PRIMITIVE_TYPES = (bool, int, float, str)
3 changes: 1 addition & 2 deletions pytorch_lightning/loggers/base.py
@@ -26,8 +26,7 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.warnings import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only


def rank_zero_experiment(fn: Callable) -> Callable:
3 changes: 2 additions & 1 deletion pytorch_lightning/loggers/comet.py
@@ -26,9 +26,10 @@

import pytorch_lightning as pl
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _module_available
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
from pytorch_lightning.utilities.rank_zero import rank_zero_only

log = logging.getLogger(__name__)
_COMET_AVAILABLE = _module_available("comet_ml")
3 changes: 1 addition & 2 deletions pytorch_lightning/loggers/csv_logs.py
@@ -28,9 +28,8 @@

from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn

log = logging.getLogger(__name__)

3 changes: 2 additions & 1 deletion pytorch_lightning/loggers/mlflow.py
@@ -23,8 +23,9 @@
from typing import Any, Dict, Optional, Union

from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.imports import _module_available
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn

log = logging.getLogger(__name__)
LOCAL_FILE_URI_PREFIX = "file:"
2 changes: 1 addition & 1 deletion pytorch_lightning/loggers/neptune.py
@@ -32,10 +32,10 @@
from pytorch_lightning import __version__
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.imports import _NEPTUNE_AVAILABLE, _NEPTUNE_GREATER_EQUAL_0_9
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
from pytorch_lightning.utilities.model_summary import ModelSummary
from pytorch_lightning.utilities.rank_zero import rank_zero_only

if _NEPTUNE_AVAILABLE and _NEPTUNE_GREATER_EQUAL_0_9:
try:
3 changes: 2 additions & 1 deletion pytorch_lightning/loggers/tensorboard.py
@@ -29,10 +29,11 @@
import pytorch_lightning as pl
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
from pytorch_lightning.utilities.logger import _sanitize_params as _utils_sanitize_params
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn

log = logging.getLogger(__name__)

4 changes: 2 additions & 2 deletions pytorch_lightning/loggers/test_tube.py
@@ -20,9 +20,9 @@

import pytorch_lightning as pl
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available, rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities import _module_available
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only, rank_zero_warn

_TESTTUBE_AVAILABLE = _module_available("test_tube")

5 changes: 2 additions & 3 deletions pytorch_lightning/loggers/wandb.py
@@ -26,11 +26,10 @@

from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _compare_version
from pytorch_lightning.utilities.imports import _compare_version, _module_available
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _sanitize_callable_params
from pytorch_lightning.utilities.warnings import rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn

_WANDB_AVAILABLE = _module_available("wandb")
_WANDB_GREATER_EQUAL_0_10_22 = _compare_version("wandb", operator.ge, "0.10.22")
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -23,14 +23,14 @@
from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import _collect_states_on_rank_zero_over_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, WarningCache
from pytorch_lightning.utilities.warnings import WarningCache

_OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]

3 changes: 1 addition & 2 deletions pytorch_lightning/loops/fit_loop.py
@@ -22,11 +22,10 @@
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.warnings import rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn

log = logging.getLogger(__name__)

2 changes: 1 addition & 1 deletion pytorch_lightning/overrides/data_parallel.py
@@ -19,8 +19,8 @@

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.rank_zero import rank_zero_warn


def _ignore_scalar_return_in_dp() -> None:
@@ -16,7 +16,7 @@
import socket

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.rank_zero import rank_zero_only


class LightningEnvironment(ClusterEnvironment):
@@ -16,7 +16,7 @@
import os

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn

log = logging.getLogger(__name__)

2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/io/torch_plugin.py
@@ -17,9 +17,9 @@

import pytorch_lightning as pl
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import _PATH

log = logging.getLogger(__name__)