diff --git a/.drone.jsonnet b/.drone.jsonnet new file mode 100644 index 00000000000000..f156881d751502 --- /dev/null +++ b/.drone.jsonnet @@ -0,0 +1,63 @@ +/* +Copyright The PyTorch Lightning team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// https://github.com/drone/drone-jsonnet-config/blob/master/.drone.jsonnet + +local pipeline(name, image) = { + kind: "pipeline", + type: "docker", + name: name, + steps: [ + { + name: "testing", + image: image, + environment: { + "CODECOV_TOKEN": { + from_secret: "codecov_token" + }, + "MKL_THREADING_LAYER": "GNU", + }, + commands: [ + "python --version", + "pip --version", + "nvidia-smi", + "pip install -r ./requirements/devel.txt --upgrade-strategy only-if-needed -v --no-cache-dir", + "pip list", + "coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v -ra --color=yes --durations=25", + "python -m pytest benchmarks pl_examples -v -ra --color=yes --maxfail=2 --durations=0", + "coverage report", + "codecov --token $CODECOV_TOKEN --flags=gpu,pytest --name='GPU-coverage' --env=linux --build $DRONE_BUILD_NUMBER --commit $DRONE_COMMIT", + "python tests/collect_env_details.py" + ], + }, + ], + trigger: { + branch: [ + "master", + "release/*" + ], + event: [ + "push", + "pull_request" + ] + }, + depends_on: if name == "torch-GPU-nightly" then ["torch-GPU"] +}; + +[ + pipeline("torch-GPU", "pytorchlightning/pytorch_lightning:base-cuda-py3.7-torch1.6"), + pipeline("torch-GPU-nightly", "pytorchlightning/pytorch_lightning:base-cuda-py3.7-torch1.7"), +] diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index e1a36d3f576d51..099f5702929880 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -5,23 +5,48 @@ # the repo. Unless a later match takes precedence, # @global-owner1 and @global-owner2 will be requested for # review when someone opens a pull request. -* @williamfalcon @borda @teddykoker @awaelchli @nateraw @justusschock @tchaton @SeanNaren @ananyahjha93 +* @williamfalcon @borda @tchaton @SeanNaren @awaelchli @justusschock # Metrics -/pytorch_lightning/metrics/* @teddykoker @ananyahjha93 @justusschock -/tests/metrics/* @teddykoker @ananyahjha93 @justusschock +/pytorch_lightning/metrics/ @teddykoker @ananyahjha93 @justusschock +/tests/metrics/ @teddykoker @ananyahjha93 @justusschock /docs/source/metrics.rst @teddykoker @ananyahjha93 @justusschock # API -/pytorch_lightning/callbacks/base.py @williamfalcon -/pytorch_lightning/core/datamodule.py @williamfalcon -/pytorch_lightning/trainer/trainer.py @williamfalcon -/pytorch_lightning/core/hooks.py @williamfalcon -/pytorch_lightning/core/lightning.py @williamfalcon +/pytorch_lightning/callbacks/base.py @williamfalcon +/pytorch_lightning/core/datamodule.py @williamfalcon +/pytorch_lightning/trainer/trainer.py @williamfalcon @tchaton +/pytorch_lightning/core/hooks.py @williamfalcon +/pytorch_lightning/core/lightning.py @williamfalcon @tchaton +/pytorch_lightning/core/optimizer.py @tchaton +/pytorch_lightning/trainer/training_loop.py @tchaton @SeanNaren +/pytorch_lightning/trainer/evaluation_loop.py @tchaton @SeanNaren +# Connectors +/pytorch_lightning/trainer/connectors/ @tchaton @SeanNaren # accelerators -/pytorch_lightning/accelerators/* @williamfalcon +/pytorch_lightning/accelerators/ @williamfalcon @tchaton @SeanNaren @awaelchli @justusschock # owners -/pytorch_lightning/.github/CODEOWNERS @williamfalcon +/.github/CODEOWNERS @williamfalcon +# main +/README.md @williamfalcon @edenlightning +# installation +/setup.py @borda @williamfalcon + +# CI/CD +/.github/workflows/ @borda @tchaton +/.github/*.py @borda @tchaton +/dockers/ @borda @tchaton +# configs in root +/*.yml @borda @tchaton + +# Docs +/docs/ @edenlightning @tchaton @borda @awaelchli +/.github/*.md @edenlightning @williamfalcon @borda +/.github/ISSUE_TEMPLATE/*.md @edenlightning @borda @tchaton +/docs/source/conf.py @borda @awaelchli + +# Testing +/tests/base/boring_model.py @williamfalcon diff --git a/CHANGELOG.md b/CHANGELOG.md index d950b24ebd14c3..4ba46ebdc85203 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,8 +4,14 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## Unreleased -## [1.1.0rc1] - 2020-12-02 +### Fixed + +- Fixed `LoggerConnector` to have logged metrics on root device in DP ([#4138](https://github.com/PyTorchLightning/pytorch-lightning/pull/4138)) + + +## [1.1.0rc] - 2020-12-02 ### Added @@ -50,6 +56,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added optimizer refactors ([#4658](https://github.com/PyTorchLightning/pytorch-lightning/pull/4658)) +- Added `PrecisionRecallCurve, ROC, AveragePrecision` class metric ([#4549](https://github.com/PyTorchLightning/pytorch-lightning/pull/4549)) + + - Added custom `Apex` and `NativeAMP` as `Precision plugins` ([#4355](https://github.com/PyTorchLightning/pytorch-lightning/pull/4355)) @@ -72,9 +81,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Removed `multiclass_roc` and `multiclass_precision_recall_curve`, use `roc` and `precision_recall_curve` instead ([#4549](https://github.com/PyTorchLightning/pytorch-lightning/pull/4549)) + + + - Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903)) + - WandbLogger does not force wandb `reinit` arg to True anymore and creates a run only when needed ([#4648](https://github.com/PyTorchLightning/pytorch-lightning/pull/4648)) @@ -89,6 +103,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `prefix` argument in `ModelCheckpoint` ([#4765](https://github.com/PyTorchLightning/pytorch-lightning/pull/4765)) +- Deprecated the old way of assigning hyper-parameters through `self.hparams = ...` ([#4813](https://github.com/PyTorchLightning/pytorch-lightning/pull/4813)) + + +- Deprecated `mode='auto'` from `ModelCheckpoint` and `EarlyStopping` ([#4695](https://github.com/PyTorchLightning/pytorch-lightning/pull/4695)) + + ### Removed @@ -97,6 +117,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added feature to move tensors to CPU before saving ([#4309](https://github.com/PyTorchLightning/pytorch-lightning/pull/4309)) +- Fixed `LoggerConnector` to have logged metrics on root device in DP ([#4138](https://github.com/PyTorchLightning/pytorch-lightning/pull/4138)) + + +- Auto convert tensors to contiguous format when `gather_all` ([#4907](https://github.com/PyTorchLightning/pytorch-lightning/pull/4907)) ## [1.0.8] - 2020-11-24 diff --git a/MANIFEST.in b/MANIFEST.in index 926827faa89254..c88d81243d9df9 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -54,6 +54,7 @@ include pyproject.toml # Exclude build configs exclude *.yml exclude *.yaml +exclude *.jsonnet # Exclude pyright config exclude .pyrightconfig.json diff --git a/docs/source/_images/lightning_icon.svg b/docs/source/_images/lightning_icon.svg new file mode 100644 index 00000000000000..c2213e4f9e0b7e --- /dev/null +++ b/docs/source/_images/lightning_icon.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/docs/source/hyperparameters.rst b/docs/source/hyperparameters.rst index 7e5f349ba0ca8f..fb27636b9bc1b1 100644 --- a/docs/source/hyperparameters.rst +++ b/docs/source/hyperparameters.rst @@ -167,8 +167,8 @@ improve readability and reproducibility. def train_dataloader(self): return DataLoader(mnist_train, batch_size=self.hparams.batch_size) - .. warning:: Deprecated. This method of assigning hyperparameters to the LightningModule is no longer - recommended and will not be supported in future versions of Lightning. + .. warning:: Deprecated since v1.1.0. This method of assigning hyperparameters to the LightningModule + will no longer be supported from v1.3.0. Use the ``self.save_hyperparameters()`` method from above instead. 4. You can also save full objects such as `dict` or `Namespace` to the checkpoint. diff --git a/docs/source/logging.rst b/docs/source/logging.rst index 906240ce6e2efc..79452b0ca87884 100644 --- a/docs/source/logging.rst +++ b/docs/source/logging.rst @@ -6,7 +6,7 @@ .. role:: hidden :class: hidden-section - + .. _logging: @@ -57,9 +57,11 @@ Logging from a LightningModule Lightning offers automatic log functionalities for logging scalars, or manual logging for anything else. -Automatic logging +Automatic Logging ================= -Use the :func:`~~pytorch_lightning.core.lightning.LightningModule.log` method to log from anywhere in a :ref:`lightning_module`. +Use the :func:`~~pytorch_lightning.core.lightning.LightningModule.log` +method to log from anywhere in a :ref:`lightning_module` and :ref:`callbacks` +except functions with `batch_start` in their names. .. code-block:: python @@ -95,6 +97,9 @@ The :func:`~~pytorch_lightning.core.lightning.LightningModule.log` method has a argument of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` or in the graphs plotted to the logger of your choice. +If your work requires to log in an unsupported function, please open an issue with a clear description of why it is blocking you. + + Manual logging ============== If you want to log anything that is not a scalar, like histograms, text, images, etc... you may need to use the logger object directly. @@ -144,8 +149,8 @@ Use the :func:`~pytorch_lightning.loggers.base.rank_zero_experiment` and :func:` def experiment(self): # Return the experiment object associated with this logger. pass - - @property + + @property def version(self): # Return the experiment version, int or str. return '0.1' @@ -238,7 +243,7 @@ if you are using a logger. These defaults can be customized by overriding the :func:`~pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict` hook in your module. .. code-block:: python - + def get_progress_bar_dict(self): # don't show the version number items = super().get_progress_bar_dict() diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index d47c872f350474..6cc5a9387a8efd 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -239,6 +239,24 @@ ConfusionMatrix .. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix :noindex: +PrecisionRecallCurve +~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.classification.PrecisionRecallCurve + :noindex: + +AveragePrecision +~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.classification.AveragePrecision + :noindex: + +ROC +~~~ + +.. autoclass:: pytorch_lightning.metrics.classification.ROC + :noindex: + Regression Metrics ------------------ @@ -326,7 +344,7 @@ multiclass_auroc [func] average_precision [func] ~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.classification.average_precision +.. autofunction:: pytorch_lightning.metrics.functional.average_precision :noindex: @@ -365,10 +383,10 @@ iou [func] :noindex: -multiclass_roc [func] +roc [func] ~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.classification.multiclass_roc +.. autofunction:: pytorch_lightning.metrics.functional.roc :noindex: @@ -389,7 +407,7 @@ precision_recall [func] precision_recall_curve [func] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.classification.precision_recall_curve +.. autofunction:: pytorch_lightning.metrics.functional.precision_recall_curve :noindex: @@ -400,13 +418,6 @@ recall [func] :noindex: -roc [func] -~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.metrics.functional.classification.roc - :noindex: - - stat_scores [func] ~~~~~~~~~~~~~~~~~~ @@ -424,14 +435,14 @@ stat_scores_multiple_classes [func] to_categorical [func] ~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.classification.to_categorical +.. autofunction:: pytorch_lightning.metrics.utils.to_categorical :noindex: to_onehot [func] ~~~~~~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.classification.to_onehot +.. autofunction:: pytorch_lightning.metrics.utils.to_onehot :noindex: diff --git a/docs/source/weights_loading.rst b/docs/source/weights_loading.rst index cddee7e1feb3fe..5dc80b51e5b898 100644 --- a/docs/source/weights_loading.rst +++ b/docs/source/weights_loading.rst @@ -14,6 +14,7 @@ Lightning automates saving and loading checkpoints. Checkpoints capture the exac Checkpointing your training allows you to resume a training process in case it was interrupted, fine-tune a model or use a pre-trained model for inference without having to retrain the model. + ***************** Checkpoint saving ***************** @@ -68,7 +69,7 @@ You can customize the checkpointing behavior to monitor any quantity of your tra # 4. Add your callback to the callbacks list trainer = Trainer(callbacks=[checkpoint_callback]) -You can also control more advanced options, like `save_top_k`, to save the best k models and the mode of the monitored quantity (min/max/auto, where the mode is automatically inferred from the name of the monitored quantity), `save_weights_only` or `period` to set the interval of epochs between checkpoints, to avoid slowdowns. +You can also control more advanced options, like `save_top_k`, to save the best k models and the `mode` of the monitored quantity (min/max), `save_weights_only` or `period` to set the interval of epochs between checkpoints, to avoid slowdowns. .. code-block:: python @@ -84,10 +85,11 @@ You can also control more advanced options, like `save_top_k`, to save the best # saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt checkpoint_callback = ModelCheckpoint( monitor='val_loss', - dirpath='my/path/, + dirpath='my/path/', filename='sample-mnist-{epoch:02d}-{val_loss:.2f}', save_top_k=3, - mode='min') + mode='min', + ) trainer = Trainer(callbacks=[checkpoint_callback]) @@ -137,6 +139,23 @@ You can manually save checkpoints and restore your model from the checkpointed s trainer.save_checkpoint("example.ckpt") new_model = MyModel.load_from_checkpoint(checkpoint_path="example.ckpt") +Manual saving with accelerators +=============================== + +Lightning also handles accelerators where multiple processes are running, such as DDP. For example, when using the DDP accelerator our training script is running across multiple devices at the same time. +Lightning automatically ensures that the model is saved only on the main process, whilst other processes do not interfere with saving checkpoints. This requires no code changes as seen below. + +.. code-block:: python + + trainer = Trainer(accelerator="ddp") + model = MyLightningModule(hparams) + trainer.fit(model) + # Saves only on the main process + trainer.save_checkpoint("example.ckpt") + +Not using `trainer.save_checkpoint` can lead to unexpected behaviour and potential deadlock. Using other saving functions will result in all devices attempting to save the checkpoint. As a result, we highly recommend using the trainer's save functionality. +If using custom saving functions cannot be avoided, we recommend using :func:`~pytorch_lightning.loggers.base.rank_zero_only` to ensure saving occurs only on the main process. + ****************** Checkpoint loading ****************** diff --git a/pl_examples/basic_examples/simple_image_classifier.py b/pl_examples/basic_examples/simple_image_classifier.py index 371b62fbb9dcaf..a341728554d31a 100644 --- a/pl_examples/basic_examples/simple_image_classifier.py +++ b/pl_examples/basic_examples/simple_image_classifier.py @@ -96,7 +96,9 @@ def cli_main(): # ------------ # testing # ------------ - result = trainer.test(datamodule=dm) + # todo: without passing model it fails for missing best weights + # MisconfigurationException, 'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.' + result = trainer.test(model, datamodule=dm) pprint(result) diff --git a/pl_examples/domain_templates/imagenet.py b/pl_examples/domain_templates/imagenet.py index f43483984e6004..ba4b292fb4af0a 100644 --- a/pl_examples/domain_templates/imagenet.py +++ b/pl_examples/domain_templates/imagenet.py @@ -19,7 +19,6 @@ """ import os from argparse import ArgumentParser, Namespace -from collections import OrderedDict import torch import torch.nn.functional as F @@ -37,7 +36,6 @@ class ImageNetLightningModel(LightningModule): - # pull out resnet names from torchvision models MODEL_NAMES = sorted( name for name in models.__dict__ @@ -45,16 +43,16 @@ class ImageNetLightningModel(LightningModule): ) def __init__( - self, - arch: str, - pretrained: bool, - lr: float, - momentum: float, - weight_decay: int, - data_path: str, - batch_size: int, - workers: int, - **kwargs, + self, + arch: str, + pretrained: bool, + lr: float, + momentum: float, + weight_decay: int, + data_path: str, + batch_size: int, + workers: int, + **kwargs, ): super().__init__() self.save_hyperparameters() @@ -74,39 +72,21 @@ def forward(self, x): def training_step(self, batch, batch_idx): images, target = batch output = self(images) - loss_val = F.cross_entropy(output, target) + loss_train = F.cross_entropy(output, target) acc1, acc5 = self.__accuracy(output, target, topk=(1, 5)) - - tqdm_dict = {'train_loss': loss_val} - output = OrderedDict({ - 'loss': loss_val, - 'acc1': acc1, - 'acc5': acc5, - 'progress_bar': tqdm_dict, - 'log': tqdm_dict - }) - return output + self.log('train_loss', loss_train, on_step=True, on_epoch=True, logger=True) + self.log('train_acc1', acc1, on_step=True, prog_bar=True, on_epoch=True, logger=True) + self.log('train_acc5', acc5, on_step=True, on_epoch=True, logger=True) + return loss_train def validation_step(self, batch, batch_idx): images, target = batch output = self(images) loss_val = F.cross_entropy(output, target) acc1, acc5 = self.__accuracy(output, target, topk=(1, 5)) - - output = OrderedDict({ - 'val_loss': loss_val, - 'val_acc1': acc1, - 'val_acc5': acc5, - }) - return output - - def validation_epoch_end(self, outputs): - tqdm_dict = {} - for metric_name in ["val_loss", "val_acc1", "val_acc5"]: - tqdm_dict[metric_name] = torch.stack([output[metric_name] for output in outputs]).mean() - - result = {'progress_bar': tqdm_dict, 'log': tqdm_dict, 'val_loss': tqdm_dict["val_loss"]} - return result + self.log('val_loss', loss_val, on_step=True, on_epoch=True) + self.log('val_acc1', acc1, on_step=True, prog_bar=True, on_epoch=True) + self.log('val_acc5', acc5, on_step=True, on_epoch=True) @staticmethod def __accuracy(output, target, topk=(1,)): @@ -121,7 +101,7 @@ def __accuracy(output, target, topk=(1,)): res = [] for k in topk: - correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) res.append(correct_k.mul_(100.0 / batch_size)) return res diff --git a/pl_examples/domain_templates/reinforce_learn_Qnet.py b/pl_examples/domain_templates/reinforce_learn_Qnet.py index 502f60942029be..c0599dc74c5a95 100644 --- a/pl_examples/domain_templates/reinforce_learn_Qnet.py +++ b/pl_examples/domain_templates/reinforce_learn_Qnet.py @@ -1,19 +1,22 @@ """ Deep Reinforcement Learning: Deep Q-network (DQN) -This example is based on https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On- -Second-Edition/blob/master/Chapter06/02_dqn_pong.py - The template illustrates using Lightning for Reinforcement Learning. The example builds a basic DQN using the classic CartPole environment. To run the template, just run: -python reinforce_learn_Qnet.py +`python reinforce_learn_Qnet.py` + +After ~1500 steps, you will see the total_reward hitting the max score of 200. +Open up TensorBoard to see the metrics: -After ~1500 steps, you will see the total_reward hitting the max score of 200. Open up TensorBoard to -see the metrics: +`tensorboard --logdir default` -tensorboard --logdir default +References +---------- + +[1] https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On- +Second-Edition/blob/master/Chapter06/02_dqn_pong.py """ import argparse diff --git a/pl_examples/test_examples.py b/pl_examples/test_examples.py index 3b9613a1320303..da213841901637 100644 --- a/pl_examples/test_examples.py +++ b/pl_examples/test_examples.py @@ -8,6 +8,7 @@ from pl_examples import DALI_AVAILABLE ARGS_DEFAULT = """ +--default_root_dir %(tmpdir)s \ --max_epochs 1 \ --batch_size 32 \ --limit_train_batches 2 \ @@ -18,33 +19,41 @@ --gpus 1 \ """ -ARGS_DP_AMP = ARGS_DEFAULT + """ +ARGS_DP = ARGS_DEFAULT + """ --gpus 2 \ ---distributed_backend dp \ +--accelerator dp \ +""" + +ARGS_DP_AMP = ARGS_DP + """ --precision 16 \ """ -ARGS_DDP_AMP = ARGS_DEFAULT + """ +ARGS_DDP = ARGS_DEFAULT + """ --gpus 2 \ ---distributed_backend ddp \ +--accelerator ddp \ --precision 16 \ """ +ARGS_DDP_AMP = ARGS_DEFAULT + """ +--precision 16 \ +""" -# ToDo: fix this failing example -# @pytest.mark.parametrize('import_cli', [ -# 'pl_examples.basic_examples.simple_image_classifier', -# 'pl_examples.basic_examples.backbone_image_classifier', -# 'pl_examples.basic_examples.autoencoder', -# ]) -# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -# @pytest.mark.parametrize('cli_args', [ARGS_DP_AMP]) -# def test_examples_dp(import_cli, cli_args): -# -# module = importlib.import_module(import_cli) -# -# with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()): -# module.cli_main() + +@pytest.mark.parametrize('import_cli', [ + 'pl_examples.basic_examples.simple_image_classifier', + 'pl_examples.basic_examples.backbone_image_classifier', + 'pl_examples.basic_examples.autoencoder', +]) +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.parametrize('cli_args', [ARGS_DP, ARGS_DP_AMP]) +def test_examples_dp(tmpdir, import_cli, cli_args): + + module = importlib.import_module(import_cli) + # update the temp dir + cli_args = cli_args % {'tmpdir': tmpdir} + + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()): + module.cli_main() # ToDo: fix this failing example @@ -54,10 +63,12 @@ # 'pl_examples.basic_examples.autoencoder', # ]) # @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -# @pytest.mark.parametrize('cli_args', [ARGS_DDP_AMP]) -# def test_examples_ddp(import_cli, cli_args): +# @pytest.mark.parametrize('cli_args', [ARGS_DDP, ARGS_DDP_AMP]) +# def test_examples_ddp(tmpdir, import_cli, cli_args): # # module = importlib.import_module(import_cli) +# # update the temp dir +# cli_args = cli_args % {'tmpdir': tmpdir} # # with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()): # module.cli_main() @@ -69,9 +80,11 @@ 'pl_examples.basic_examples.autoencoder', ]) @pytest.mark.parametrize('cli_args', [ARGS_DEFAULT]) -def test_examples_cpu(import_cli, cli_args): +def test_examples_cpu(tmpdir, import_cli, cli_args): module = importlib.import_module(import_cli) + # update the temp dir + cli_args = cli_args % {'tmpdir': tmpdir} with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()): module.cli_main() @@ -81,8 +94,10 @@ def test_examples_cpu(import_cli, cli_args): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") @pytest.mark.skipif(platform.system() != 'Linux', reason='Only applies to Linux platform.') @pytest.mark.parametrize('cli_args', [ARGS_GPU]) -def test_examples_mnist_dali(cli_args): +def test_examples_mnist_dali(tmpdir, cli_args): from pl_examples.basic_examples.dali_image_classifier import cli_main + # update the temp dir + cli_args = cli_args % {'tmpdir': tmpdir} with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()): cli_main() diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index a22a8fb3702eec..9d36f76876a08d 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -15,21 +15,18 @@ import torch +from pytorch_lightning.utilities import HOROVOD_AVAILABLE from pytorch_lightning import _logger as log from pytorch_lightning import accelerators from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment -from pytorch_lightning.utilities import XLA_AVAILABLE, device_parser, rank_zero_only, TPU_AVAILABLE +from pytorch_lightning.utilities import device_parser, rank_zero_only, TPU_AVAILABLE from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException -try: +if HOROVOD_AVAILABLE: import horovod.torch as hvd -except (ModuleNotFoundError, ImportError): - HOROVOD_AVAILABLE = False -else: - HOROVOD_AVAILABLE = True class AcceleratorConnector: diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index b2cec906178f92..460f5a83d25827 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -18,15 +18,11 @@ from torch.optim.lr_scheduler import _LRScheduler from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp -from pytorch_lightning.utilities import AMPType +from pytorch_lightning.utilities import AMPType, HOROVOD_AVAILABLE from pytorch_lightning.utilities.distributed import rank_zero_only -try: +if HOROVOD_AVAILABLE: import horovod.torch as hvd -except (ModuleNotFoundError, ImportError): - HOROVOD_AVAILABLE = False -else: - HOROVOD_AVAILABLE = True class HorovodAccelerator(Accelerator): diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 005a3f8cde4adc..88f1881643c9aa 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -16,7 +16,7 @@ Early Stopping ^^^^^^^^^^^^^^ -Monitor a validation metric and stop training when it stops improving. +Monitor a metric and stop training when it stops improving. """ import os @@ -26,14 +26,12 @@ from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_warn, TPU_AVAILABLE - -torch_inf = torch.tensor(np.Inf) +from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn, TPU_AVAILABLE class EarlyStopping(Callback): r""" - Monitor a validation metric and stop training when it stops improving. + Monitor a metric and stop training when it stops improving. Args: monitor: quantity to be monitored. Default: ``'early_stop_on'``. @@ -50,7 +48,11 @@ class EarlyStopping(Callback): mode it will stop when the quantity monitored has stopped increasing; in `auto` mode, the direction is automatically inferred - from the name of the monitored quantity. Default: ``'auto'``. + from the name of the monitored quantity. + + .. warning:: + Setting ``mode='auto'`` has been deprecated in v1.1 and will be removed in v1.3. + strict: whether to crash the training if `monitor` is not found in the validation metrics. Default: ``True``. @@ -66,8 +68,15 @@ class EarlyStopping(Callback): 'max': torch.gt, } - def __init__(self, monitor: str = 'early_stop_on', min_delta: float = 0.0, patience: int = 3, - verbose: bool = False, mode: str = 'auto', strict: bool = True): + def __init__( + self, + monitor: str = 'early_stop_on', + min_delta: float = 0.0, + patience: int = 3, + verbose: bool = False, + mode: str = 'auto', + strict: bool = True, + ): super().__init__() self.monitor = monitor self.patience = patience @@ -82,21 +91,36 @@ def __init__(self, monitor: str = 'early_stop_on', min_delta: float = 0.0, patie # It is set to False initially and overwritten, if eval results have been validated self.based_on_eval_results = False - if mode not in self.mode_dict and mode != 'auto': + self.__init_monitor_mode() + + self.min_delta *= 1 if self.monitor_op == torch.gt else -1 + torch_inf = torch.tensor(np.Inf) + self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf + + def __init_monitor_mode(self): + # TODO: Update with MisconfigurationException when auto mode is removed in v1.3 + if self.mode not in self.mode_dict and self.mode != 'auto': if self.verbose > 0: - log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.') + rank_zero_warn( + f'EarlyStopping mode={self.mode} is unknown, fallback to auto mode.', + RuntimeWarning, + ) self.mode = 'auto' if self.mode == 'auto': - if self.monitor == 'acc': + rank_zero_warn( + "mode='auto' is deprecated in v1.1 and will be removed in v1.3." + " Default value for mode with be 'min' in v1.3.", + DeprecationWarning + ) + + if "acc" in self.monitor or self.monitor.startswith("fmeasure"): self.mode = 'max' else: self.mode = 'min' - if self.verbose > 0: - log.info(f'EarlyStopping mode set to {self.mode} for monitoring {self.monitor}.') - self.min_delta *= 1 if self.monitor_op == torch.gt else -1 - self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf + if self.verbose > 0: + rank_zero_info(f'EarlyStopping mode set to {self.mode} for monitoring {self.monitor}.') def _validate_condition_metric(self, logs): monitor_val = logs.get(self.monitor) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index d41928cd55aea6..eb669736ada3a1 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -22,16 +22,17 @@ import os import re -import yaml from copy import deepcopy -from typing import Any, Dict, Optional, Union from pathlib import Path +from typing import Any, Dict, Optional, Union import numpy as np import torch +import yaml + from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn, rank_zero_info +from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -71,6 +72,10 @@ class ModelCheckpoint(Callback): this should be `max`, for `val_loss` this should be `min`, etc. In `auto` mode, the direction is automatically inferred from the name of the monitored quantity. + + .. warning:: + Setting ``mode='auto'`` has been deprecated in v1.1 and will be removed in v1.3. + save_weights_only: if ``True``, then only the model's weights will be saved (``model.save_weights(filepath)``), else the full model is saved (``model.save(filepath)``). @@ -85,7 +90,7 @@ class ModelCheckpoint(Callback): Example:: # custom path - # saves a file like: my/path/epoch=0.ckpt + # saves a file like: my/path/epoch=0-step=10.ckpt >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/') By default, dirpath is ``None`` and will be set at runtime to the location @@ -135,6 +140,7 @@ class ModelCheckpoint(Callback): CHECKPOINT_JOIN_CHAR = "-" CHECKPOINT_NAME_LAST = "last" + FILE_EXTENSION = ".ckpt" def __init__( self, @@ -311,18 +317,29 @@ def __init_monitor_mode(self, monitor, mode): mode_dict = { "min": (torch_inf, "min"), "max": (-torch_inf, "max"), - "auto": (-torch_inf, "max") - if monitor is not None and ("acc" in monitor or monitor.startswith("fmeasure")) - else (torch_inf, "min"), } - if mode not in mode_dict: + # TODO: Update with MisconfigurationException when auto mode is removed in v1.3 + if mode not in mode_dict and mode != 'auto': rank_zero_warn( f"ModelCheckpoint mode {mode} is unknown, fallback to auto mode", RuntimeWarning, ) mode = "auto" + if mode == 'auto': + rank_zero_warn( + "mode='auto' is deprecated in v1.1 and will be removed in v1.3." + " Default value for mode with be 'min' in v1.3.", + DeprecationWarning + ) + + mode_dict['auto'] = ( + (-torch_inf, "max") + if monitor is not None and ("acc" in monitor or monitor.startswith("fmeasure")) + else (torch_inf, "min") + ) + self.kth_value, self.mode = mode_dict[mode] @rank_zero_only @@ -426,7 +443,7 @@ def format_checkpoint_name( ) if ver is not None: filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}")) - ckpt_name = f"{filename}.ckpt" + ckpt_name = f"{filename}{self.FILE_EXTENSION}" return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name def __resolve_ckpt_dir(self, trainer, pl_module): @@ -529,7 +546,7 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath) ckpt_name_metrics, prefix=self.prefix ) - last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt") + last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}") self._save_model(last_filepath, trainer, pl_module) if ( diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 78fc740e389aa1..c33297934eed16 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1651,6 +1651,13 @@ def hparams_initial(self) -> AttributeDict: @hparams.setter def hparams(self, hp: Union[dict, Namespace, Any]): + # TODO: remove this method in v1.3.0. + rank_zero_warn( + "The setter for self.hparams in LightningModule is deprecated since v1.1.0 and will be" + " removed in v1.3.0. Replace the assignment `self.hparams = hparams` with " + " `self.save_hyperparameters()`.", + DeprecationWarning + ) hparams_assignment_name = self.__get_hparams_assignment_variable() self._hparams_name = hparams_assignment_name self._set_hparams(hp) @@ -1670,7 +1677,7 @@ def __get_hparams_assignment_variable(self): line = re.sub(r"\s+", "", line, flags=re.UNICODE) if ".hparams=" in line: return line.split("=")[1] - except Exception as e: + except Exception: return "hparams" return None diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index a397fd7bcb8160..142fe9048cb0ea 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -400,11 +400,15 @@ def detach(self): if isinstance(v, torch.Tensor): self.__setitem__(k, v.detach()) - def cpu(self): - """Move all self attributes to CPU.""" + def to(self, *args, **kwargs): + """Move all self attributes to the given device.""" for k, v in self.items(): if isinstance(v, torch.Tensor): - self.__setitem__(k, v.cpu()) + self.__setitem__(k, v.to(*args, **kwargs)) + + def cpu(self): + """Move all self attributes to CPU.""" + self.to(torch.device("cpu")) def __repr__(self): self_copy = self.copy() diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index fc40db4e69b16d..a27998366b6712 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -409,6 +409,11 @@ def nop(*args, **kw): def __getattr__(self, _): return self.nop + def __getitem__(self, idx): + # enables self.logger[0].experiment.add_image + # and self.logger.experiment[0].add_image(...) + return self + class DummyLogger(LightningLoggerBase): """ Dummy logger for internal use. Is usefull if we want to disable users @@ -437,6 +442,9 @@ def name(self): def version(self): pass + def __getitem__(self, idx): + return self + def merge_dicts( dicts: Sequence[Mapping], diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 653ad23c68f7e4..59f3a9cec06cf9 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -17,9 +17,12 @@ Accuracy, Precision, Recall, + ConfusionMatrix, + PrecisionRecallCurve, + AveragePrecision, + ROC, FBeta, F1, - ConfusionMatrix ) from pytorch_lightning.metrics.regression import ( diff --git a/pytorch_lightning/metrics/classification/__init__.py b/pytorch_lightning/metrics/classification/__init__.py index db643c227abedc..13cb705f30b179 100644 --- a/pytorch_lightning/metrics/classification/__init__.py +++ b/pytorch_lightning/metrics/classification/__init__.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. from pytorch_lightning.metrics.classification.accuracy import Accuracy -from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall -from pytorch_lightning.metrics.classification.f_beta import FBeta, F1 +from pytorch_lightning.metrics.classification.average_precision import AveragePrecision from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix +from pytorch_lightning.metrics.classification.f_beta import FBeta, F1 +from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall +from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve +from pytorch_lightning.metrics.classification.roc import ROC diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 7ba695d71899b5..330691a379574e 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -11,15 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math -import functools -from abc import ABC, abstractmethod -from typing import Any, Callable, Optional, Union -from collections.abc import Mapping, Sequence -from collections import namedtuple +from typing import Any, Callable, Optional import torch -from torch import nn + from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.metrics.utils import _input_format_classification diff --git a/pytorch_lightning/metrics/classification/average_precision.py b/pytorch_lightning/metrics/classification/average_precision.py new file mode 100644 index 00000000000000..0a8a952470dbc8 --- /dev/null +++ b/pytorch_lightning/metrics/classification/average_precision.py @@ -0,0 +1,130 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Any, Union, List + +import torch + +from pytorch_lightning.metrics import Metric +from pytorch_lightning.metrics.functional.average_precision import ( + _average_precision_update, + _average_precision_compute +) +from pytorch_lightning.utilities import rank_zero_warn + + +class AveragePrecision(Metric): + """ + Computes the average precision score, which summarises the precision recall + curve into one number. Works for both binary and multiclass problems. + In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach. + + Forward accepts + + - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) + where C is the number of classes + + - ``target`` (long tensor): ``(N, ...)`` + + Args: + num_classes: integer with number of classes. Not nessesary to provide + for binary problems. + pos_label: integer determining the positive class. Default is ``None`` + which for binary problem is translate to 1. For multiclass problems + this argument should not be set as we iteratively change it in the + range [0,num_classes-1] + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + + Example (binary case): + + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 1, 1]) + >>> average_precision = AveragePrecision(pos_label=1) + >>> average_precision(pred, target) + tensor(1.) + + Example (multiclass case): + + >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> average_precision = AveragePrecision(num_classes=5) + >>> average_precision(pred, target) + [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] + + """ + def __init__( + self, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + ) + + self.num_classes = num_classes + self.pos_label = pos_label + + self.add_state("preds", default=[], dist_reduce_fx=None) + self.add_state("target", default=[], dist_reduce_fx=None) + + rank_zero_warn( + 'Metric `AveragePrecision` will save all targets and' + ' predictions in buffer. For large datasets this may lead' + ' to large memory footprint.' + ) + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + preds, target, num_classes, pos_label = _average_precision_update( + preds, + target, + self.num_classes, + self.pos_label + ) + self.preds.append(preds) + self.target.append(target) + self.num_classes = num_classes + self.pos_label = pos_label + + def compute(self) -> Union[torch.Tensor, List[torch.Tensor]]: + """ + Compute the average precision score + + Returns: + tensor with average precision. If multiclass will return list + of such tensors, one for each class + + """ + preds = torch.cat(self.preds, dim=0) + target = torch.cat(self.target, dim=0) + return _average_precision_compute(preds, target, self.num_classes, self.pos_label) diff --git a/pytorch_lightning/metrics/classification/confusion_matrix.py b/pytorch_lightning/metrics/classification/confusion_matrix.py index e62fd37880de0a..b9b0c20e9e30eb 100644 --- a/pytorch_lightning/metrics/classification/confusion_matrix.py +++ b/pytorch_lightning/metrics/classification/confusion_matrix.py @@ -15,11 +15,11 @@ import torch -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.metrics.functional.confusion_matrix import ( _confusion_matrix_update, _confusion_matrix_compute ) +from pytorch_lightning.metrics.metric import Metric class ConfusionMatrix(Metric): diff --git a/pytorch_lightning/metrics/classification/f_beta.py b/pytorch_lightning/metrics/classification/f_beta.py index 92cc22228befa5..56cc00f9a5dce2 100755 --- a/pytorch_lightning/metrics/classification/f_beta.py +++ b/pytorch_lightning/metrics/classification/f_beta.py @@ -15,11 +15,11 @@ import torch -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.metrics.functional.f_beta import ( _fbeta_update, _fbeta_compute ) +from pytorch_lightning.metrics.metric import Metric class FBeta(Metric): diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index 9ab05891852d67..7e1f843b9c331c 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -11,17 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math -import functools -from abc import ABC, abstractmethod -from typing import Any, Callable, Optional, Union -from collections.abc import Mapping, Sequence -from collections import namedtuple +from typing import Any, Optional import torch -from torch import nn + from pytorch_lightning.metrics.metric import Metric -from pytorch_lightning.metrics.utils import to_onehot, METRIC_EPS, _input_format_classification_one_hot +from pytorch_lightning.metrics.utils import METRIC_EPS, _input_format_classification_one_hot class Precision(Metric): diff --git a/pytorch_lightning/metrics/classification/precision_recall_curve.py b/pytorch_lightning/metrics/classification/precision_recall_curve.py new file mode 100644 index 00000000000000..052a25a7a977de --- /dev/null +++ b/pytorch_lightning/metrics/classification/precision_recall_curve.py @@ -0,0 +1,150 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Any, Union, Tuple, List + +import torch + +from pytorch_lightning.metrics import Metric +from pytorch_lightning.metrics.functional.precision_recall_curve import ( + _precision_recall_curve_update, + _precision_recall_curve_compute +) +from pytorch_lightning.utilities import rank_zero_warn + + +class PrecisionRecallCurve(Metric): + """ + Computes precision-recall pairs for different thresholds. Works for both + binary and multiclass problems. In the case of multiclass, the values will + be calculated based on a one-vs-the-rest approach. + + Forward accepts + + - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) + where C is the number of classes + + - ``target`` (long tensor): ``(N, ...)`` + + Args: + num_classes: integer with number of classes. Not nessesary to provide + for binary problems. + pos_label: integer determining the positive class. Default is ``None`` + which for binary problem is translate to 1. For multiclass problems + this argument should not be set as we iteratively change it in the + range [0,num_classes-1] + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + + Example (binary case): + + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> pr_curve = PrecisionRecallCurve(pos_label=1) + >>> precision, recall, thresholds = pr_curve(pred, target) + >>> precision + tensor([0.6667, 0.5000, 0.0000, 1.0000]) + >>> recall + tensor([1.0000, 0.5000, 0.0000, 0.0000]) + >>> thresholds + tensor([1, 2, 3]) + + Example (multiclass case): + + >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> pr_curve = PrecisionRecallCurve(num_classes=5) + >>> precision, recall, thresholds = pr_curve(pred, target) + >>> precision + [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])] + >>> recall + [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])] + >>> thresholds # doctest: +NORMALIZE_WHITESPACE + [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] + + """ + def __init__( + self, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + ) + + self.num_classes = num_classes + self.pos_label = pos_label + + self.add_state("preds", default=[], dist_reduce_fx=None) + self.add_state("target", default=[], dist_reduce_fx=None) + + rank_zero_warn( + 'Metric `PrecisionRecallCurve` will save all targets and' + ' predictions in buffer. For large datasets this may lead' + ' to large memory footprint.' + ) + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + preds, target, num_classes, pos_label = _precision_recall_curve_update( + preds, + target, + self.num_classes, + self.pos_label + ) + self.preds.append(preds) + self.target.append(target) + self.num_classes = num_classes + self.pos_label = pos_label + + def compute(self) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]: + """ + Compute the precision-recall curve + + Returns: 3-element tuple containing + + precision: + tensor where element i is the precision of predictions with + score >= thresholds[i] and the last element is 1. + If multiclass, this is a list of such tensors, one for each class. + recall: + tensor where element i is the recall of predictions with + score >= thresholds[i] and the last element is 0. + If multiclass, this is a list of such tensors, one for each class. + thresholds: + Thresholds used for computing precision/recall scores + + """ + preds = torch.cat(self.preds, dim=0) + target = torch.cat(self.target, dim=0) + return _precision_recall_curve_compute(preds, target, self.num_classes, self.pos_label) diff --git a/pytorch_lightning/metrics/classification/roc.py b/pytorch_lightning/metrics/classification/roc.py new file mode 100644 index 00000000000000..89e8265b19fc14 --- /dev/null +++ b/pytorch_lightning/metrics/classification/roc.py @@ -0,0 +1,151 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Any, Union, List, Tuple + +import torch + +from pytorch_lightning.metrics import Metric +from pytorch_lightning.metrics.functional.roc import ( + _roc_update, + _roc_compute +) +from pytorch_lightning.utilities import rank_zero_warn + + +class ROC(Metric): + """ + Computes the Receiver Operating Characteristic (ROC). Works for both + binary and multiclass problems. In the case of multiclass, the values will + be calculated based on a one-vs-the-rest approach. + + Forward accepts + + - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) + where C is the number of classes + + - ``target`` (long tensor): ``(N, ...)`` + + Args: + num_classes: integer with number of classes. Not nessesary to provide + for binary problems. + pos_label: integer determining the positive class. Default is ``None`` + which for binary problem is translate to 1. For multiclass problems + this argument should not be set as we iteratively change it in the + range [0,num_classes-1] + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + + Example (binary case): + + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 1, 1]) + >>> roc = ROC(pos_label=1) + >>> fpr, tpr, thresholds = roc(pred, target) + >>> fpr + tensor([0., 0., 0., 0., 1.]) + >>> tpr + tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]) + >>> thresholds + tensor([4, 3, 2, 1, 0]) + + Example (multiclass case): + + >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05], + ... [0.05, 0.05, 0.05, 0.75]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> roc = ROC(num_classes=4) + >>> fpr, tpr, thresholds = roc(pred, target) + >>> fpr + [tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])] + >>> tpr + [tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.])] + >>> thresholds # doctest: +NORMALIZE_WHITESPACE + [tensor([1.7500, 0.7500, 0.0500]), + tensor([1.7500, 0.7500, 0.0500]), + tensor([1.7500, 0.7500, 0.0500]), + tensor([1.7500, 0.7500, 0.0500])] + + """ + def __init__( + self, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + ) + + self.num_classes = num_classes + self.pos_label = pos_label + + self.add_state("preds", default=[], dist_reduce_fx=None) + self.add_state("target", default=[], dist_reduce_fx=None) + + rank_zero_warn( + 'Metric `ROC` will save all targets and' + ' predictions in buffer. For large datasets this may lead' + ' to large memory footprint.' + ) + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + preds, target, num_classes, pos_label = _roc_update( + preds, + target, + self.num_classes, + self.pos_label + ) + self.preds.append(preds) + self.target.append(target) + self.num_classes = num_classes + self.pos_label = pos_label + + def compute(self) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]: + """ + Compute the receiver operating characteristic + + Returns: 3-element tuple containing + + fpr: + tensor with false positive rates. + If multiclass, this is a list of such tensors, one for each class. + tpr: + tensor with true positive rates. + If multiclass, this is a list of such tensors, one for each class. + thresholds: + thresholds used for computing false- and true postive rates + + """ + preds = torch.cat(self.preds, dim=0) + target = torch.cat(self.target, dim=0) + return _roc_compute(preds, target, self.num_classes, self.pos_label) diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index 3bb5313db7b271..e13242e40b0acf 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -11,36 +11,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from pytorch_lightning.metrics.functional.average_precision import average_precision from pytorch_lightning.metrics.functional.classification import ( accuracy, auc, auroc, - average_precision, dice_score, - multiclass_precision_recall_curve, - multiclass_roc, multiclass_auroc, precision, precision_recall, - precision_recall_curve, recall, - roc, stat_scores, stat_scores_multiple_classes, - to_categorical, - to_onehot, iou, ) -from pytorch_lightning.metrics.functional.nlp import bleu_score -from pytorch_lightning.metrics.functional.self_supervised import ( - embedding_similarity -) +from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # TODO: unify metrics between class and functional, add below from pytorch_lightning.metrics.functional.explained_variance import explained_variance +from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 from pytorch_lightning.metrics.functional.mean_absolute_error import mean_absolute_error from pytorch_lightning.metrics.functional.mean_squared_error import mean_squared_error from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error +from pytorch_lightning.metrics.functional.nlp import bleu_score +from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve from pytorch_lightning.metrics.functional.psnr import psnr +from pytorch_lightning.metrics.functional.roc import roc +from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity from pytorch_lightning.metrics.functional.ssim import ssim -from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix -from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 diff --git a/pytorch_lightning/metrics/functional/average_precision.py b/pytorch_lightning/metrics/functional/average_precision.py new file mode 100644 index 00000000000000..da4f37b0732066 --- /dev/null +++ b/pytorch_lightning/metrics/functional/average_precision.py @@ -0,0 +1,95 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Sequence, Tuple, Union, List + +import torch + +from pytorch_lightning.metrics.functional.precision_recall_curve import ( + _precision_recall_curve_update, + _precision_recall_curve_compute +) + + +def _average_precision_update( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor, int, int]: + return _precision_recall_curve_update(preds, target, num_classes, pos_label) + + +def _average_precision_compute( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: int, + pos_label: int, + sample_weights: Optional[Sequence] = None +) -> Union[List[torch.Tensor], torch.Tensor]: + precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label) + # Return the step function integral + # The following works because the last entry of precision is + # guaranteed to be 1, as returned by precision_recall_curve + if num_classes == 1: + return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1]) + + res = [] + for p, r in zip(precision, recall): + res.append(-torch.sum((r[1:] - r[:-1]) * p[:-1])) + return res + + +def average_precision( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + sample_weights: Optional[Sequence] = None, +) -> Union[List[torch.Tensor], torch.Tensor]: + """ + Computes the average precision score. + + Args: + num_classes: integer with number of classes. Not nessesary to provide + for binary problems. + pos_label: integer determining the positive class. Default is ``None`` + which for binary problem is translate to 1. For multiclass problems + this argument should not be set as we iteratively change it in the + range [0,num_classes-1] + sample_weight: sample weights for each data point + + Returns: + tensor with average precision. If multiclass will return list + of such tensors, one for each class + + Example (binary case): + + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 1, 1]) + >>> average_precision(pred, target, pos_label=1) + tensor(1.) + + Example (multiclass case): + + >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> average_precision(pred, target, num_classes=5) + [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] + + """ + preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes, pos_label) + return _average_precision_compute(preds, target, num_classes, pos_label, sample_weights) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 75eeeca3b8e171..1b407f2a7ec9e4 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -15,98 +15,12 @@ from typing import Callable, Optional, Sequence, Tuple import torch -from pytorch_lightning.metrics.functional.reduction import class_reduce, reduce from torch.nn import functional as F +from pytorch_lightning.metrics.utils import to_categorical, get_num_classes, reduce, class_reduce from pytorch_lightning.utilities import rank_zero_warn -def to_onehot( - tensor: torch.Tensor, - num_classes: Optional[int] = None, -) -> torch.Tensor: - """ - Converts a dense label tensor to one-hot format - - Args: - tensor: dense label tensor, with shape [N, d1, d2, ...] - num_classes: number of classes C - - Output: - A sparse label tensor with shape [N, C, d1, d2, ...] - - Example: - - >>> x = torch.tensor([1, 2, 3]) - >>> to_onehot(x) - tensor([[0, 1, 0, 0], - [0, 0, 1, 0], - [0, 0, 0, 1]]) - - """ - if num_classes is None: - num_classes = int(tensor.max().detach().item() + 1) - dtype, device, shape = tensor.dtype, tensor.device, tensor.shape - tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], - dtype=dtype, device=device) - index = tensor.long().unsqueeze(1).expand_as(tensor_onehot) - return tensor_onehot.scatter_(1, index, 1.0) - - -def to_categorical( - tensor: torch.Tensor, - argmax_dim: int = 1 -) -> torch.Tensor: - """ - Converts a tensor of probabilities to a dense label tensor - - Args: - tensor: probabilities to get the categorical label [N, d1, d2, ...] - argmax_dim: dimension to apply - - Return: - A tensor with categorical labels [N, d2, ...] - - Example: - - >>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]]) - >>> to_categorical(x) - tensor([1, 0]) - - """ - return torch.argmax(tensor, dim=argmax_dim) - - -def get_num_classes( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, -) -> int: - """ - Calculates the number of classes for a given prediction and target tensor. - - Args: - pred: predicted values - target: true labels - num_classes: number of classes if known - - Return: - An integer that represents the number of classes. - """ - num_target_classes = int(target.max().detach().item() + 1) - num_pred_classes = int(pred.max().detach().item() + 1) - num_all_classes = max(num_target_classes, num_pred_classes) - - if num_classes is None: - num_classes = num_all_classes - elif num_classes != num_all_classes: - rank_zero_warn(f'You have set {num_classes} number of classes which is' - f' different from predicted ({num_pred_classes}) and' - f' target ({num_target_classes}) number of classes', - RuntimeWarning) - return num_classes - - def stat_scores( pred: torch.Tensor, target: torch.Tensor, @@ -462,7 +376,8 @@ def _binary_clf_curve( return fps, tps, pred[threshold_idxs] -def roc( +# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py +def __roc( pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None, @@ -471,6 +386,8 @@ def roc( """ Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary. + .. warning:: Deprecated + Args: pred: estimated probabilities target: ground-truth labels @@ -484,7 +401,7 @@ def roc( >>> x = torch.tensor([0, 1, 2, 3]) >>> y = torch.tensor([0, 1, 1, 1]) - >>> fpr, tpr, thresholds = roc(x, y) + >>> fpr, tpr, thresholds = __roc(x, y) >>> fpr tensor([0., 0., 0., 0., 1.]) >>> tpr @@ -516,7 +433,8 @@ def roc( return fpr, tpr, thresholds -def multiclass_roc( +# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py +def __multiclass_roc( pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None, @@ -525,6 +443,8 @@ def multiclass_roc( """ Computes the Receiver Operating Characteristic (ROC) for multiclass predictors. + .. warning:: Deprecated + Args: pred: estimated probabilities target: ground-truth labels @@ -542,7 +462,7 @@ def multiclass_roc( ... [0.05, 0.05, 0.85, 0.05], ... [0.05, 0.05, 0.05, 0.85]]) >>> target = torch.tensor([0, 1, 3, 2]) - >>> multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE + >>> __multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE ((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), (tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])), @@ -554,119 +474,11 @@ def multiclass_roc( for c in range(num_classes): pred_c = pred[:, c] - class_roc_vals.append(roc(pred=pred_c, target=target, - sample_weight=sample_weight, pos_label=c)) + class_roc_vals.append(__roc(pred=pred_c, target=target, sample_weight=sample_weight, pos_label=c)) return tuple(class_roc_vals) -def precision_recall_curve( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1., -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Computes precision-recall pairs for different thresholds. - - Args: - pred: estimated probabilities - target: ground-truth labels - sample_weight: sample weights - pos_label: the label for the positive class - - Return: - precision, recall, thresholds - - Example: - - >>> pred = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 1, 0]) - >>> precision, recall, thresholds = precision_recall_curve(pred, target) - >>> precision - tensor([0.6667, 0.5000, 0.0000, 1.0000]) - >>> recall - tensor([1.0000, 0.5000, 0.0000, 0.0000]) - >>> thresholds - tensor([1, 2, 3]) - - """ - fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target, - sample_weight=sample_weight, - pos_label=pos_label) - - precision = tps / (tps + fps) - recall = tps / tps[-1] - - # stop when full recall attained - # and reverse the outputs so recall is decreasing - last_ind = torch.where(tps == tps[-1])[0][0] - sl = slice(0, last_ind.item() + 1) - - # need to call reversed explicitly, since including that to slice would - # introduce negative strides that are not yet supported in pytorch - precision = torch.cat([reversed(precision[sl]), - torch.ones(1, dtype=precision.dtype, - device=precision.device)]) - - recall = torch.cat([reversed(recall[sl]), - torch.zeros(1, dtype=recall.dtype, - device=recall.device)]) - - thresholds = torch.tensor(reversed(thresholds[sl])) - - return precision, recall, thresholds - - -def multiclass_precision_recall_curve( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - num_classes: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Computes precision-recall pairs for different thresholds given a multiclass scores. - - Args: - pred: estimated probabilities - target: ground-truth labels - sample_weight: sample weight - num_classes: number of classes - - Return: - number of classes, precision, recall, thresholds - - Example: - - >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], - ... [0.05, 0.85, 0.05, 0.05], - ... [0.05, 0.05, 0.85, 0.05], - ... [0.05, 0.05, 0.05, 0.85]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> nb_classes, precision, recall, thresholds = multiclass_precision_recall_curve(pred, target) - >>> nb_classes - (tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])) - >>> precision - (tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])) - >>> recall - (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])) - >>> thresholds # doctest: +NORMALIZE_WHITESPACE - (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])) - """ - num_classes = get_num_classes(pred, target, num_classes) - - class_pr_vals = [] - for c in range(num_classes): - pred_c = pred[:, c] - - class_pr_vals.append(precision_recall_curve( - pred=pred_c, - target=target, - sample_weight=sample_weight, pos_label=c)) - - return tuple(class_pr_vals) - - def auc( x: torch.Tensor, y: torch.Tensor, @@ -777,7 +589,7 @@ def auroc( @auc_decorator(reorder=True) def _auroc(pred, target, sample_weight, pos_label): - return roc(pred, target, sample_weight, pos_label) + return __roc(pred, target, sample_weight, pos_label) return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label) @@ -808,7 +620,7 @@ def multiclass_auroc( ... [0.05, 0.05, 0.85, 0.05], ... [0.05, 0.05, 0.05, 0.85]]) >>> target = torch.tensor([0, 1, 3, 2]) - >>> multiclass_auroc(pred, target) # doctest: +NORMALIZE_WHITESPACE + >>> multiclass_auroc(pred, target, num_classes=4) tensor(0.6667) """ if not torch.allclose(pred.sum(dim=1), torch.tensor(1.0)): @@ -830,7 +642,7 @@ def multiclass_auroc( @multiclass_auc_decorator(reorder=False) def _multiclass_auroc(pred, target, sample_weight, num_classes): - return multiclass_roc(pred, target, sample_weight, num_classes) + return __multiclass_roc(pred, target, sample_weight, num_classes) class_aurocs = _multiclass_auroc(pred=pred, target=target, sample_weight=sample_weight, @@ -838,40 +650,6 @@ def _multiclass_auroc(pred, target, sample_weight, num_classes): return torch.mean(class_aurocs) -def average_precision( - pred: torch.Tensor, - target: torch.Tensor, - sample_weight: Optional[Sequence] = None, - pos_label: int = 1., -) -> torch.Tensor: - """ - Compute average precision from prediction scores - - Args: - pred: estimated probabilities - target: ground-truth labels - sample_weight: sample weights - pos_label: the label for the positive class - - Return: - Tensor containing average precision score - - Example: - - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 2, 2]) - >>> average_precision(x, y) - tensor(0.3333) - """ - precision, recall, _ = precision_recall_curve(pred=pred, target=target, - sample_weight=sample_weight, - pos_label=pos_label) - # Return the step function integral - # The following works because the last entry of precision is - # guaranteed to be 1, as returned by precision_recall_curve - return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1]) - - def dice_score( pred: torch.Tensor, target: torch.Tensor, diff --git a/pytorch_lightning/metrics/functional/confusion_matrix.py b/pytorch_lightning/metrics/functional/confusion_matrix.py index 143d237b3b2c6c..3370e24215123f 100644 --- a/pytorch_lightning/metrics/functional/confusion_matrix.py +++ b/pytorch_lightning/metrics/functional/confusion_matrix.py @@ -15,8 +15,8 @@ import torch -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.metrics.utils import _input_format_classification +from pytorch_lightning.utilities import rank_zero_warn def _confusion_matrix_update(preds: torch.Tensor, diff --git a/pytorch_lightning/metrics/functional/explained_variance.py b/pytorch_lightning/metrics/functional/explained_variance.py index cab15201b4d018..012e1486ebb1f9 100644 --- a/pytorch_lightning/metrics/functional/explained_variance.py +++ b/pytorch_lightning/metrics/functional/explained_variance.py @@ -14,6 +14,7 @@ from typing import Union, Tuple, Sequence import torch + from pytorch_lightning.metrics.utils import _check_same_shape diff --git a/pytorch_lightning/metrics/functional/f_beta.py b/pytorch_lightning/metrics/functional/f_beta.py index 9700545019a41c..3f0a7a04493257 100755 --- a/pytorch_lightning/metrics/functional/f_beta.py +++ b/pytorch_lightning/metrics/functional/f_beta.py @@ -15,8 +15,7 @@ import torch -from pytorch_lightning.metrics.utils import _input_format_classification_one_hot -from pytorch_lightning.metrics.functional.reduction import class_reduce +from pytorch_lightning.metrics.utils import _input_format_classification_one_hot, class_reduce def _fbeta_update( diff --git a/pytorch_lightning/metrics/functional/mean_absolute_error.py b/pytorch_lightning/metrics/functional/mean_absolute_error.py index 4645499fb79aac..359eadc389038f 100644 --- a/pytorch_lightning/metrics/functional/mean_absolute_error.py +++ b/pytorch_lightning/metrics/functional/mean_absolute_error.py @@ -14,6 +14,7 @@ from typing import Tuple import torch + from pytorch_lightning.metrics.utils import _check_same_shape diff --git a/pytorch_lightning/metrics/functional/mean_squared_error.py b/pytorch_lightning/metrics/functional/mean_squared_error.py index 0223e7b9e1f369..e418536b26973f 100644 --- a/pytorch_lightning/metrics/functional/mean_squared_error.py +++ b/pytorch_lightning/metrics/functional/mean_squared_error.py @@ -14,6 +14,7 @@ from typing import Tuple import torch + from pytorch_lightning.metrics.utils import _check_same_shape diff --git a/pytorch_lightning/metrics/functional/mean_squared_log_error.py b/pytorch_lightning/metrics/functional/mean_squared_log_error.py index aac2fbd0bc0d11..1b96e1a7abc10c 100644 --- a/pytorch_lightning/metrics/functional/mean_squared_log_error.py +++ b/pytorch_lightning/metrics/functional/mean_squared_log_error.py @@ -14,6 +14,7 @@ from typing import Tuple import torch + from pytorch_lightning.metrics.utils import _check_same_shape diff --git a/pytorch_lightning/metrics/functional/precision_recall_curve.py b/pytorch_lightning/metrics/functional/precision_recall_curve.py new file mode 100644 index 00000000000000..6c112fe0103780 --- /dev/null +++ b/pytorch_lightning/metrics/functional/precision_recall_curve.py @@ -0,0 +1,221 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Sequence, Tuple, List, Union + +import torch +import torch.nn.functional as F + +from pytorch_lightning.utilities import rank_zero_warn + + +def _binary_clf_curve( + preds: torch.Tensor, + target: torch.Tensor, + sample_weights: Optional[Sequence] = None, + pos_label: int = 1., +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py + """ + if sample_weights is not None and not isinstance(sample_weights, torch.Tensor): + sample_weights = torch.tensor(sample_weights, device=preds.device, dtype=torch.float) + + # remove class dimension if necessary + if preds.ndim > target.ndim: + preds = preds[:, 0] + desc_score_indices = torch.argsort(preds, descending=True) + + preds = preds[desc_score_indices] + target = target[desc_score_indices] + + if sample_weights is not None: + weight = sample_weights[desc_score_indices] + else: + weight = 1. + + # pred typically has many tied values. Here we extract + # the indices associated with the distinct values. We also + # concatenate a value for the end of the curve. + distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0] + threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1) + target = (target == pos_label).to(torch.long) + tps = torch.cumsum(target * weight, dim=0)[threshold_idxs] + + if sample_weights is not None: + # express fps as a cumsum to ensure fps is increasing even in + # the presence of floating point errors + fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs] + else: + fps = 1 + threshold_idxs - tps + + return fps, tps, preds[threshold_idxs] + + +def _precision_recall_curve_update( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor, int, int]: + if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): + raise ValueError( + "preds and target must have same number of dimensions, or one additional dimension for preds" + ) + # single class evaluation + if len(preds.shape) == len(target.shape): + if num_classes is not None and num_classes != 1: + raise ValueError('Preds and target have equal shape, but number of classes is different from 1') + num_classes = 1 + if pos_label is None: + rank_zero_warn('`pos_label` automatically set 1.') + pos_label = 1 + preds = preds.flatten() + target = target.flatten() + + # multi class evaluation + if len(preds.shape) == len(target.shape) + 1: + if pos_label is not None: + rank_zero_warn('Argument `pos_label` should be `None` when running' + f'multiclass precision recall curve. Got {pos_label}') + if num_classes != preds.shape[1]: + raise ValueError(f'Argument `num_classes` was set to {num_classes} in' + f'metric `precision_recall_curve` but detected {preds.shape[1]}' + 'number of classes from predictions') + preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1) + target = target.flatten() + + return preds, target, num_classes, pos_label + + +def _precision_recall_curve_compute( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: int, + pos_label: int, + sample_weights: Optional[Sequence] = None, +) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]: + + if num_classes == 1: + fps, tps, thresholds = _binary_clf_curve( + preds=preds, + target=target, + sample_weights=sample_weights, + pos_label=pos_label + ) + + precision = tps / (tps + fps) + recall = tps / tps[-1] + + # stop when full recall attained + # and reverse the outputs so recall is decreasing + last_ind = torch.where(tps == tps[-1])[0][0] + sl = slice(0, last_ind.item() + 1) + + # need to call reversed explicitly, since including that to slice would + # introduce negative strides that are not yet supported in pytorch + precision = torch.cat([reversed(precision[sl]), + torch.ones(1, dtype=precision.dtype, + device=precision.device)]) + + recall = torch.cat([reversed(recall[sl]), + torch.zeros(1, dtype=recall.dtype, + device=recall.device)]) + + thresholds = reversed(thresholds[sl]).clone() + + return precision, recall, thresholds + + # Recursively call per class + precision, recall, thresholds = [], [], [] + for c in range(num_classes): + preds_c = preds[:, c] + res = precision_recall_curve( + preds=preds_c, + target=target, + num_classes=1, + pos_label=c, + sample_weights=sample_weights, + ) + precision.append(res[0]) + recall.append(res[1]) + thresholds.append(res[2]) + + return precision, recall, thresholds + + +def precision_recall_curve( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + sample_weights: Optional[Sequence] = None, +) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]: + """ + Computes precision-recall pairs for different thresholds. + + Args: + num_classes: integer with number of classes. Not nessesary to provide + for binary problems. + pos_label: integer determining the positive class. Default is ``None`` + which for binary problem is translate to 1. For multiclass problems + this argument should not be set as we iteratively change it in the + range [0,num_classes-1] + sample_weight: sample weights for each data point + + Returns: 3-element tuple containing + + precision: + tensor where element i is the precision of predictions with + score >= thresholds[i] and the last element is 1. + If multiclass, this is a list of such tensors, one for each class. + recall: + tensor where element i is the recall of predictions with + score >= thresholds[i] and the last element is 0. + If multiclass, this is a list of such tensors, one for each class. + thresholds: + Thresholds used for computing precision/recall scores + + Example (binary case): + + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> precision, recall, thresholds = precision_recall_curve(pred, target, pos_label=1) + >>> precision + tensor([0.6667, 0.5000, 0.0000, 1.0000]) + >>> recall + tensor([1.0000, 0.5000, 0.0000, 0.0000]) + >>> thresholds + tensor([1, 2, 3]) + + Example (multiclass case): + + >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> precision, recall, thresholds = precision_recall_curve(pred, target, num_classes=5) + >>> precision + [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])] + >>> recall + [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])] + >>> thresholds + [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] + + """ + preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, + num_classes, pos_label) + return _precision_recall_curve_compute(preds, target, num_classes, pos_label, sample_weights) diff --git a/pytorch_lightning/metrics/functional/reduction.py b/pytorch_lightning/metrics/functional/reduction.py deleted file mode 100644 index 197b1dd7097a32..00000000000000 --- a/pytorch_lightning/metrics/functional/reduction.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch - - -def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: - """ - Reduces a given tensor by a given reduction method - - Args: - to_reduce : the tensor, which shall be reduced - reduction : a string specifying the reduction method ('elementwise_mean', 'none', 'sum') - - Return: - reduced Tensor - - Raise: - ValueError if an invalid reduction parameter was given - """ - if reduction == 'elementwise_mean': - return torch.mean(to_reduce) - if reduction == 'none': - return to_reduce - if reduction == 'sum': - return torch.sum(to_reduce) - raise ValueError('Reduction parameter unknown.') - - -def class_reduce(num: torch.Tensor, - denom: torch.Tensor, - weights: torch.Tensor, - class_reduction: str = 'none') -> torch.Tensor: - """ - Function used to reduce classification metrics of the form `num / denom * weights`. - For example for calculating standard accuracy the num would be number of - true positives per class, denom would be the support per class, and weights - would be a tensor of 1s - - Args: - num: numerator tensor - decom: denominator tensor - weights: weights for each class - class_reduction: reduction method for multiclass problems - - - ``'micro'``: calculate metrics globally (default) - - ``'macro'``: calculate metrics for each label, and find their unweighted mean. - - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - - ``'none'`` or ``None``: returns calculated metric per class - - """ - valid_reduction = ('micro', 'macro', 'weighted', 'none', None) - if class_reduction == 'micro': - fraction = torch.sum(num) / torch.sum(denom) - else: - fraction = num / denom - - # We need to take care of instances where the denom can be 0 - # for some (or all) classes which will produce nans - fraction[fraction != fraction] = 0 - - if class_reduction == 'micro': - return fraction - elif class_reduction == 'macro': - return torch.mean(fraction) - elif class_reduction == 'weighted': - return torch.sum(fraction * (weights.float() / torch.sum(weights))) - elif class_reduction == 'none' or class_reduction is None: - return fraction - - raise ValueError(f'Reduction parameter {class_reduction} unknown.' - f' Choose between one of these: {valid_reduction}') diff --git a/pytorch_lightning/metrics/functional/roc.py b/pytorch_lightning/metrics/functional/roc.py new file mode 100644 index 00000000000000..ffd5f9f0ac79cf --- /dev/null +++ b/pytorch_lightning/metrics/functional/roc.py @@ -0,0 +1,146 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Sequence, Tuple, List, Union + +import torch + +from pytorch_lightning.metrics.functional.precision_recall_curve import ( + _precision_recall_curve_update, + _binary_clf_curve +) + + +def _roc_update( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor, int, int]: + return _precision_recall_curve_update(preds, target, num_classes, pos_label) + + +def _roc_compute( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: int, + pos_label: int, + sample_weights: Optional[Sequence] = None, +) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]: + + if num_classes == 1: + fps, tps, thresholds = _binary_clf_curve( + preds=preds, + target=target, + sample_weights=sample_weights, + pos_label=pos_label + ) + # Add an extra threshold position + # to make sure that the curve starts at (0, 0) + tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps]) + fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps]) + thresholds = torch.cat([thresholds[0][None] + 1, thresholds]) + + if fps[-1] <= 0: + raise ValueError("No negative samples in targets, false positive value should be meaningless") + fpr = fps / fps[-1] + + if tps[-1] <= 0: + raise ValueError("No positive samples in targets, true positive value should be meaningless") + tpr = tps / tps[-1] + + return fpr, tpr, thresholds + + # Recursively call per class + fpr, tpr, thresholds = [], [], [] + for c in range(num_classes): + preds_c = preds[:, c] + res = roc( + preds=preds_c, + target=target, + num_classes=1, + pos_label=c, + sample_weights=sample_weights, + ) + fpr.append(res[0]) + tpr.append(res[1]) + thresholds.append(res[2]) + + return fpr, tpr, thresholds + + +def roc( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + sample_weights: Optional[Sequence] = None, +) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]: + """ + Computes the Receiver Operating Characteristic (ROC). + + Args: + num_classes: integer with number of classes. Not nessesary to provide + for binary problems. + pos_label: integer determining the positive class. Default is ``None`` + which for binary problem is translate to 1. For multiclass problems + this argument should not be set as we iteratively change it in the + range [0,num_classes-1] + sample_weight: sample weights for each data point + + Returns: 3-element tuple containing + + fpr: + tensor with false positive rates. + If multiclass, this is a list of such tensors, one for each class. + tpr: + tensor with true positive rates. + If multiclass, this is a list of such tensors, one for each class. + thresholds: + thresholds used for computing false- and true postive rates + + Example (binary case): + + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 1, 1]) + >>> fpr, tpr, thresholds = roc(pred, target, pos_label=1) + >>> fpr + tensor([0., 0., 0., 0., 1.]) + >>> tpr + tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]) + >>> thresholds + tensor([4, 3, 2, 1, 0]) + + Example (multiclass case): + + >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05], + ... [0.05, 0.05, 0.05, 0.75]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> fpr, tpr, thresholds = roc(pred, target, num_classes=4) + >>> fpr + [tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])] + >>> tpr + [tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.])] + >>> thresholds # doctest: +NORMALIZE_WHITESPACE + [tensor([1.7500, 0.7500, 0.0500]), + tensor([1.7500, 0.7500, 0.0500]), + tensor([1.7500, 0.7500, 0.0500]), + tensor([1.7500, 0.7500, 0.0500])] + + """ + preds, target, num_classes, pos_label = _roc_update(preds, target, num_classes, pos_label) + return _roc_compute(preds, target, num_classes, pos_label, sample_weights) diff --git a/pytorch_lightning/metrics/functional/ssim.py b/pytorch_lightning/metrics/functional/ssim.py index 05f9f00a88b273..b52744421aef2a 100644 --- a/pytorch_lightning/metrics/functional/ssim.py +++ b/pytorch_lightning/metrics/functional/ssim.py @@ -14,10 +14,10 @@ from typing import Optional, Sequence, Tuple import torch -from pytorch_lightning.metrics.functional.reduction import reduce -from pytorch_lightning.metrics.utils import _check_same_shape from torch.nn import functional as F +from pytorch_lightning.metrics.utils import _check_same_shape, reduce + def _gaussian(kernel_size: int, sigma: int, dtype: torch.dtype, device: torch.device): dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=dtype, device=device) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 2bc7977be25cff..0f61b94c551396 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -13,19 +13,16 @@ # limitations under the License. import functools from abc import ABC, abstractmethod -from typing import Any, Callable, Optional, Union -from collections.abc import Mapping, Sequence -from collections import namedtuple +from collections.abc import Sequence from copy import deepcopy -from distutils.version import LooseVersion +from typing import Any, Callable, Optional, Union -import os import torch from torch import nn +from pytorch_lightning.metrics.utils import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.distributed import gather_all_tensors -from pytorch_lightning.metrics.utils import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum class Metric(nn.Module, ABC): diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index e1ff95b94f4719..9aaa5578edb803 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple +from typing import Tuple, Optional import torch +from pytorch_lightning.utilities import rank_zero_warn METRIC_EPS = 1e-6 @@ -35,34 +36,6 @@ def _flatten(x): return [item for sublist in x for item in sublist] -def to_onehot( - tensor: torch.Tensor, - num_classes: int, -) -> torch.Tensor: - """ - Converts a dense label tensor to one-hot format - - Args: - tensor: dense label tensor, with shape [N, d1, d2, ...] - num_classes: number of classes C - - Output: - A sparse label tensor with shape [N, C, d1, d2, ...] - - Example: - >>> x = torch.tensor([1, 2, 3]) - >>> to_onehot(x, num_classes=4) - tensor([[0, 1, 0, 0], - [0, 0, 1, 0], - [0, 0, 0, 1]]) - """ - dtype, device, shape = tensor.dtype, tensor.device, tensor.shape - tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], - dtype=dtype, device=device) - index = tensor.long().unsqueeze(1).expand_as(tensor_onehot) - return tensor_onehot.scatter_(1, index, 1.0) - - def _check_same_shape(pred: torch.Tensor, target: torch.Tensor): """ Check that predictions and target have the same shape, else raise error """ if pred.shape != target.shape: @@ -146,3 +119,157 @@ def _input_format_classification_one_hot( target = target.transpose(1, 0) return preds.reshape(num_classes, -1), target.reshape(num_classes, -1) + + +def to_onehot( + tensor: torch.Tensor, + num_classes: Optional[int] = None, +) -> torch.Tensor: + """ + Converts a dense label tensor to one-hot format + + Args: + tensor: dense label tensor, with shape [N, d1, d2, ...] + num_classes: number of classes C + + Output: + A sparse label tensor with shape [N, C, d1, d2, ...] + + Example: + + >>> x = torch.tensor([1, 2, 3]) + >>> to_onehot(x) + tensor([[0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + + """ + if num_classes is None: + num_classes = int(tensor.max().detach().item() + 1) + dtype, device, shape = tensor.dtype, tensor.device, tensor.shape + tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], + dtype=dtype, device=device) + index = tensor.long().unsqueeze(1).expand_as(tensor_onehot) + return tensor_onehot.scatter_(1, index, 1.0) + + +def to_categorical( + tensor: torch.Tensor, + argmax_dim: int = 1 +) -> torch.Tensor: + """ + Converts a tensor of probabilities to a dense label tensor + + Args: + tensor: probabilities to get the categorical label [N, d1, d2, ...] + argmax_dim: dimension to apply + + Return: + A tensor with categorical labels [N, d2, ...] + + Example: + + >>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]]) + >>> to_categorical(x) + tensor([1, 0]) + + """ + return torch.argmax(tensor, dim=argmax_dim) + + +def get_num_classes( + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, +) -> int: + """ + Calculates the number of classes for a given prediction and target tensor. + + Args: + pred: predicted values + target: true labels + num_classes: number of classes if known + + Return: + An integer that represents the number of classes. + """ + num_target_classes = int(target.max().detach().item() + 1) + num_pred_classes = int(pred.max().detach().item() + 1) + num_all_classes = max(num_target_classes, num_pred_classes) + + if num_classes is None: + num_classes = num_all_classes + elif num_classes != num_all_classes: + rank_zero_warn(f'You have set {num_classes} number of classes which is' + f' different from predicted ({num_pred_classes}) and' + f' target ({num_target_classes}) number of classes', + RuntimeWarning) + return num_classes + + +def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: + """ + Reduces a given tensor by a given reduction method + + Args: + to_reduce : the tensor, which shall be reduced + reduction : a string specifying the reduction method ('elementwise_mean', 'none', 'sum') + + Return: + reduced Tensor + + Raise: + ValueError if an invalid reduction parameter was given + """ + if reduction == 'elementwise_mean': + return torch.mean(to_reduce) + if reduction == 'none': + return to_reduce + if reduction == 'sum': + return torch.sum(to_reduce) + raise ValueError('Reduction parameter unknown.') + + +def class_reduce(num: torch.Tensor, + denom: torch.Tensor, + weights: torch.Tensor, + class_reduction: str = 'none') -> torch.Tensor: + """ + Function used to reduce classification metrics of the form `num / denom * weights`. + For example for calculating standard accuracy the num would be number of + true positives per class, denom would be the support per class, and weights + would be a tensor of 1s + + Args: + num: numerator tensor + decom: denominator tensor + weights: weights for each class + class_reduction: reduction method for multiclass problems + + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. + - ``'none'`` or ``None``: returns calculated metric per class + + """ + valid_reduction = ('micro', 'macro', 'weighted', 'none', None) + if class_reduction == 'micro': + fraction = torch.sum(num) / torch.sum(denom) + else: + fraction = num / denom + + # We need to take care of instances where the denom can be 0 + # for some (or all) classes which will produce nans + fraction[fraction != fraction] = 0 + + if class_reduction == 'micro': + return fraction + elif class_reduction == 'macro': + return torch.mean(fraction) + elif class_reduction == 'weighted': + return torch.sum(fraction * (weights.float() / torch.sum(weights))) + elif class_reduction == 'none' or class_reduction is None: + return fraction + + raise ValueError(f'Reduction parameter {class_reduction} unknown.' + f' Choose between one of these: {valid_reduction}') diff --git a/pytorch_lightning/plugins/plugin_connector.py b/pytorch_lightning/plugins/plugin_connector.py index a7ec7932740d83..b6ede3ab7c7a6d 100644 --- a/pytorch_lightning/plugins/plugin_connector.py +++ b/pytorch_lightning/plugins/plugin_connector.py @@ -120,7 +120,7 @@ def _convert_str_to_plugin(self, plugin): f" {plugin} is not a supported lightning custom plugin." " If you're trying to pass a custom plugin, please pass this as an object to" " Trainer(plugins=[MyPlugin()]." - f" Supported plugins as string input: {(e.name for e in LightningCustomPlugins)}." + f" Supported plugins as string input: {[e.name for e in LightningCustomPlugins]}." ) plugin_cls = LightningCustomPlugins[plugin].value return plugin_cls(trainer=self.trainer) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 394a65bf7e9f5b..9a8e12c9419ab4 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -71,7 +71,7 @@ def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpo ) if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True: - self.trainer.callbacks.append(ModelCheckpoint(dirpath=None, filename=None)) + self.trainer.callbacks.append(ModelCheckpoint(dirpath=None, filename=None, mode='min')) def configure_progress_bar(self, refresh_rate=1, process_position=0): # smaller refresh rate on colab causes crashes, warn user about this diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index a8a035e8132c4d..8dc993df7dedb3 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -11,47 +11,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -from collections import ChainMap, defaultdict -from copy import deepcopy +from collections import defaultdict from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union -from pytorch_lightning.core.step_result import Result -from pytorch_lightning.utilities.exceptions import MisconfigurationException - - -# used to map boolean to right LoggerStage values -class FrozenDict(dict): - def __init__(self, *args, **kwargs): - self._hash = None - super(FrozenDict, self).__init__(*args, **kwargs) - - def __hash__(self): - if self._hash is None: - self._hash = hash(tuple(sorted(self.items()))) # iteritems() on py2 - return self._hash - - def _immutable(self, *args, **kws): - raise TypeError('cannot change object - object is immutable') - - __setitem__ = _immutable - __delitem__ = _immutable - pop = _immutable - popitem = _immutable - clear = _immutable - update = _immutable - setdefault = _immutable +import torch - -LOOKUP_TABLE = FrozenDict({"1": "test", "0": "validation", "True": "test", "False": "validation"}) +from pytorch_lightning.core.step_result import Result -class LoggerStages(Enum): +class LoggerStages(str, Enum): TRAIN = "train" VAL = "validation" TEST = "test" + @staticmethod + def determine_stage(stage_or_testing: Union[str, bool]) -> 'LoggerStages': + if isinstance(stage_or_testing, str) and stage_or_testing in list(LoggerStages): + return LoggerStages(stage_or_testing) + if isinstance(stage_or_testing, (bool, int)): + # stage_or_testing is trainer.testing + return LoggerStages.TEST if bool(stage_or_testing) else LoggerStages.VAL + raise RuntimeError(f"Invalid stage {stage_or_testing} of type {type(stage_or_testing)} given") + class ResultStoreType(Enum): INSIDE_BATCH_TRAIN_LOOP = "inside_batch_train_loop" @@ -64,7 +46,7 @@ class HookResultStore: It holds all metrics logged using the self.log function in the scope of ModelHooks or Callback functions. - We need to differiante 3 different scenarios: + We need to differentiate 3 different scenarios: - (1): We are outside of a batch loop * It means no dataloader_idx, no optimizer idx, etc.. - (2): We are inside the training batch loop @@ -74,19 +56,14 @@ class HookResultStore: The data store `Result` objects for those 3 scenarios in `self._internals`. - (1): self._internals = {"dataloader_idx": [Result(), ..., Result()]} + (1): self._internals = {dataloader_idx: [Result(), ..., Result()]} * dataloader_idx not being defined, it is set to 0 b default - (2): self._internals = {"dataloader_idx": - {"optimizer_idx": - {"batch_idx": - [Result(), Result()] - } - } - } + (2): self._internals = {dataloader_idx: {optimizer_idx: {batch_idx: [Result(), ..., Result()]}}} (3): Same as (1) for simplicity Those data structures enables us to reduce properly Result object when batch loop is finished. """ + def __init__(self, fx_name): self._fx_name = fx_name self._internals = {} @@ -101,22 +78,21 @@ def has_several_dataloaders(self) -> bool: @property def num_dataloaders(self) -> int: - _inter = self._internals_reduced if self.has_reduced else self._internals - return len(_inter) + inter = self._internals_reduced if self.has_reduced else self._internals + return len(inter) def check_dataloader_idx(self, result: Result) -> bool: - random_key = [*result.keys()][-1] - add_dataloader_idx = result["meta"][random_key]["dataloader_idx"] is not None - return add_dataloader_idx + random_key = list(result.keys())[-1] + return result["meta"][random_key]["dataloader_idx"] is not None - def get_lastest_from_func_name(self, latest_result, func_name: str, *args, **kwargs) -> Dict: + def get_latest_from_func_name(self, latest_result, func_name: str, *args, **kwargs) -> Dict: results = {} add_dataloader_idx = self.check_dataloader_idx(latest_result) func = getattr(latest_result, func_name) results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs)) return results - def run_lastest_batch_metrics_with_func_name(self, func_name, *args, **kwargs) -> List[Dict]: + def run_latest_batch_metrics_with_func_name(self, func_name, *args, **kwargs) -> List[Dict]: """ This function used cache_ref and cache_result to optimize loading metrics @@ -126,39 +102,27 @@ def run_lastest_batch_metrics_with_func_name(self, func_name, *args, **kwargs) - HookResultStore keeps track of its latest added result object, and cache its pbar and log metrics if already called on, """ - results = [] - for dl_idx in range(self.num_dataloaders): - dl_idx = str(dl_idx) - latest_result = self._latest_ref[dl_idx] - result = self.get_lastest_from_func_name(latest_result, func_name, *args, **kwargs) - results.append(result) - return results + return [ + self.get_latest_from_func_name(self._latest_ref[dl_idx], func_name, *args, **kwargs) + for dl_idx in range(self.num_dataloaders) + ] def get_batch_pbar_metrics(self, *args, **kwargs): - return self.run_lastest_batch_metrics_with_func_name("get_batch_pbar_metrics", - *args, - **kwargs) + return self.run_latest_batch_metrics_with_func_name("get_batch_pbar_metrics", *args, **kwargs) def get_batch_log_metrics(self, *args, **kwargs): - return self.run_lastest_batch_metrics_with_func_name("get_batch_log_metrics", - *args, - **kwargs) + return self.run_latest_batch_metrics_with_func_name("get_batch_log_metrics", *args, **kwargs) def run_epoch_func(self, results, opt_metric, func_name, *args, **kwargs) -> None: - if isinstance(opt_metric, Result): - func = getattr(opt_metric, func_name) - metrics_to_log = func( - *args, - add_dataloader_idx=self.has_several_dataloaders, - **kwargs) - results.append(metrics_to_log) - else: + if not isinstance(opt_metric, Result): raise Exception("The provided opt_metric should be a Result Object. Something is wrong") + func = getattr(opt_metric, func_name) + metrics_to_log = func(*args, add_dataloader_idx=self.has_several_dataloaders, **kwargs) + results.append(metrics_to_log) def get_epoch_from_func_name(self, func_name, *args, **kwargs) -> List[Dict]: results = [] for dl_idx in range(self.num_dataloaders): - dl_idx = str(dl_idx) opt_metrics = self._internals_reduced[dl_idx] if isinstance(opt_metrics, defaultdict): for opt_metric in opt_metrics.values(): @@ -167,58 +131,51 @@ def get_epoch_from_func_name(self, func_name, *args, **kwargs) -> List[Dict]: self.run_epoch_func(results, opt_metrics, func_name, *args, **kwargs) return results - def get_epoch_pbar_metrics(self, *args, **kwargs) -> List[Dict]: + def get_epoch_pbar_metrics(self, *_, **__) -> List[Dict]: return self.get_epoch_from_func_name("get_epoch_pbar_metrics") - def get_epoch_log_metrics(self, *args, **kwargs) -> List[Dict]: + def get_epoch_log_metrics(self, *_, **__) -> List[Dict]: return self.get_epoch_from_func_name("get_epoch_log_metrics") - def get_forked_metrics(self, *args, **kwargs) -> List[Dict]: + def get_forked_metrics(self, *_, **__) -> List[Dict]: return self.get_epoch_from_func_name("get_forked_metrics") @staticmethod def _append_to_structure(primary_dict, opt_idx, batch_idx, result) -> None: - if opt_idx not in primary_dict: - primary_dict[opt_idx] = {} - - if batch_idx not in primary_dict[opt_idx]: - primary_dict[opt_idx][batch_idx] = [] - + primary_dict.setdefault(opt_idx, {}) + primary_dict[opt_idx].setdefault(batch_idx, []) primary_dict[opt_idx][batch_idx].append(result) - def append(self, result, dataloader_idx=None, extra_info: dict = {}) -> None: - + def append(self, result, dataloader_idx: Optional[int] = None, extra_info: Optional[dict] = None) -> None: assert isinstance(result, Result) - if dataloader_idx is None: dataloader_idx = 0 - - primary_key = f"{dataloader_idx}" + if extra_info is None: + extra_info = {} # [dataloader_idx][optimizer_idx][training_step_idx] is a list if len(extra_info) > 0: self._internal_type = ResultStoreType.INSIDE_BATCH_TRAIN_LOOP # initialize dictionary - if primary_key not in self._internals: - self._internals[primary_key] = {} - self._internals_reduced[primary_key] = defaultdict(dict) + if dataloader_idx not in self._internals: + self._internals[dataloader_idx] = {} + self._internals_reduced[dataloader_idx] = defaultdict(dict) # extract infos - opt_idx = str(extra_info["opt_idx"]) - batch_idx = str(extra_info["batch_idx"]) + opt_idx = extra_info["opt_idx"] + batch_idx = extra_info["batch_idx"] - self._append_to_structure(self._internals[primary_key], opt_idx, batch_idx, result) + self._append_to_structure(self._internals[dataloader_idx], opt_idx, batch_idx, result) - self._latest_ref[primary_key] = result + self._latest_ref[dataloader_idx] = result # [dataloader_idx] is a list else: self._internal_type = ResultStoreType.OUTSIDE_BATCH_TRAIN_LOOP - if primary_key not in self._internals: - self._internals[primary_key] = [] - self._internals[primary_key].append(result) + self._internals.setdefault(dataloader_idx, []) + self._internals[dataloader_idx].append(result) - self._latest_ref[primary_key] = result + self._latest_ref[dataloader_idx] = result def auto_reduce_results_on_epoch_end(self) -> None: """ @@ -226,75 +183,65 @@ def auto_reduce_results_on_epoch_end(self) -> None: The reduced Result object will be saved into `self._internals_reduced` The `self._internals` stored Result objects will be deleted to save memory. """ - if not self.has_reduced: - epoch_log_metrics = {} - epoch_progress_bar_metrics = {} - - for dl_idx in range(self.num_dataloaders): - dl_idx = str(dl_idx) - epoch_metrics = self._internals[dl_idx] + if self.has_reduced: + return + for dl_idx in range(self.num_dataloaders): + epoch_metrics = self._internals[dl_idx] - if self._internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP: + if self._internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP: - num_opt_idx = len(self._internals[dl_idx]) - 1 + num_opt_idx = len(self._internals[dl_idx]) - 1 - # Make sure we didn't create key - assert num_opt_idx >= 0 + # Make sure we didn't create key + assert num_opt_idx >= 0 - for opt_idx in range(num_opt_idx + 1): - opt_idx = str(opt_idx) - # TODO: Figure out to reduce memory - # TODO: How to start training in middle of epoch - opt_outputs = epoch_metrics[opt_idx] + for opt_idx in range(num_opt_idx + 1): + # TODO: Figure out to reduce memory + # TODO: How to start training in middle of epoch + opt_outputs = epoch_metrics[opt_idx] - num_batch_idx = len(self._internals[dl_idx][str(num_opt_idx)]) - 1 - assert num_batch_idx >= 0 - batch_indexes = self._internals[dl_idx][str(num_opt_idx)].keys() + num_batch_idx = len(self._internals[dl_idx][num_opt_idx]) - 1 + assert num_batch_idx >= 0 + batch_indexes = self._internals[dl_idx][num_opt_idx].keys() - # reduce across time first - time_reduced_outputs = [] - for batch_idx in batch_indexes: - batch_idx = str(batch_idx) - tbptt_outs = opt_outputs[str(batch_idx)] - tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs) - if len(tbptt_outs) > 1: - time_reduced_outputs.append(tbptt_outs) + # reduce across time first + time_reduced_outputs = [] + for batch_idx in batch_indexes: + tbptt_outs = opt_outputs[batch_idx] + tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs) + if len(tbptt_outs) > 1: + time_reduced_outputs.append(tbptt_outs) - if len(time_reduced_outputs) == 0: - continue + if len(time_reduced_outputs) == 0: + continue - # reduce across training steps - opt_outputs = time_reduced_outputs[0].__class__.reduce_on_epoch_end(time_reduced_outputs) + # reduce across training steps + opt_outputs = time_reduced_outputs[0].__class__.reduce_on_epoch_end(time_reduced_outputs) - # with manual opt need 1 + metrics because meta is always there - if opt_outputs.minimize is not None: - opt_outputs.minimize = opt_outputs.minimize.mean() + # with manual opt need 1 + metrics because meta is always there + if opt_outputs.minimize is not None: + opt_outputs.minimize = opt_outputs.minimize.mean() - self._internals_reduced[dl_idx][str(opt_idx)] = opt_outputs + self._internals_reduced[dl_idx][opt_idx] = opt_outputs - # free memory - del self._internals[dl_idx][opt_idx] + # free memory + del self._internals[dl_idx][opt_idx] + else: + # no need to reduce as called only once + if len(epoch_metrics) == 1: + reduced_epoch_metrics = epoch_metrics[0] else: - # no need to reduce as called only once - if len(epoch_metrics) == 1: - reduced_epoch_metrics = epoch_metrics[0] - else: - reduced_epoch_metrics = epoch_metrics[0].__class__.reduce_on_epoch_end(epoch_metrics) + reduced_epoch_metrics = epoch_metrics[0].__class__.reduce_on_epoch_end(epoch_metrics) - self._internals_reduced[dl_idx] = reduced_epoch_metrics + self._internals_reduced[dl_idx] = reduced_epoch_metrics - # free memory - del self._internals[dl_idx] + # free memory + del self._internals[dl_idx] - self.has_reduced = True + self.has_reduced = True def __getitem__(self, key: str) -> Any: - try: - if key in self._internals: - return self._internals[key] - return self[key] - except KeyError: - return None + return self._internals.get(key, None) def __repr__(self): return self._internals.__repr__() @@ -314,36 +261,28 @@ class EpochResultStore: epoch_result_store.cache_result() ``` """ + def __init__(self, trainer, stage): self.trainer = trainer self._stage = stage self.reset() def __getitem__(self, key: str) -> Any: - try: - if key in self._internals: - return self._internals[key] - return None - except KeyError: - return None + return self._internals.get(key, None) @property def has_split_and_opt_idx(self): """ This function informs if we are running within training batch loop """ - if self._split_idx is not None and self._opt_idx is not None: - return True - return False + return self._split_idx is not None and self._opt_idx is not None @property def extra_info(self): """ This function provides necessary parameters to properly configure HookResultStore obj """ - return {"batch_idx": self.trainer.batch_idx, - "split_idx": self._split_idx, - "opt_idx": self._opt_idx} + return {"batch_idx": self.trainer.batch_idx, "split_idx": self._split_idx, "opt_idx": self._opt_idx} def reset_model(self): """ @@ -361,9 +300,7 @@ def current_model_info(self): """ model_ref = self.trainer.get_model() # extract hook information - fx_name = model_ref._current_hook_fx_name - if fx_name is None: - fx_name = model_ref._current_fx_name + fx_name = model_ref._current_hook_fx_name or model_ref._current_fx_name dataloader_idx = model_ref._current_dataloader_idx return fx_name, dataloader_idx @@ -386,15 +323,9 @@ def cache_result(self) -> None: # extract model information fx_name, dataloader_idx = self.current_model_info() - # add only if anything as been logged - # default len is 1 due to _internals + self._internals.setdefault(fx_name, HookResultStore(fx_name)) - if fx_name not in self._internals: - self._internals[fx_name] = HookResultStore(fx_name) - - extra_info = {} - if self.has_split_and_opt_idx: - extra_info = self.extra_info + extra_info = self.extra_info if self.has_split_and_opt_idx else {} # attach capture batch_size Result.attach_batch_size(self._batch_size, hook_result) @@ -402,18 +333,19 @@ def cache_result(self) -> None: hook_result.detach() if self.trainer.move_metrics_to_cpu: hook_result.cpu() + elif self.trainer.use_dp: + hook_result.to(torch.device("cuda", self.trainer.root_gpu)) - self._internals[fx_name].append( - hook_result, - dataloader_idx=dataloader_idx, - extra_info=extra_info) + self._internals[fx_name].append(hook_result, dataloader_idx=dataloader_idx, extra_info=extra_info) # update logged_metrics, progress_bar_metrics, callback_metrics - self.update_logger_connector(fx_name) + + if "epoch_end" in fx_name: + self.update_logger_connector() self.reset_model() - def update_logger_connector(self, fx_name: str = None) -> None: + def update_logger_connector(self) -> None: """ This function is called every time we capture a hook It automatically updates the logger_connector followings: @@ -425,23 +357,22 @@ def update_logger_connector(self, fx_name: str = None) -> None: logger_connector = self.trainer.logger_connector callback_metrics = {} + batch_pbar_metrics = {} + batch_log_metrics = {} is_train = self._stage in LoggerStages.TRAIN.value if not self._has_batch_loop_finished: # get pbar batch_pbar_metrics = self.get_latest_batch_pbar_metrics() logger_connector.add_progress_bar_metrics(batch_pbar_metrics) + batch_log_metrics = self.get_latest_batch_log_metrics() if is_train: # Only log and add to callback epoch step during evaluation, test. - batch_log_metrics = self.get_latest_batch_log_metrics() logger_connector.logged_metrics.update(batch_log_metrics) - callback_metrics.update(batch_pbar_metrics) callback_metrics.update(batch_log_metrics) else: - epoch_dict = {"epoch": self.trainer.current_epoch} - # get pbar epoch_pbar_metrics = self.get_epoch_pbar_metrics() logger_connector.add_progress_bar_metrics(epoch_pbar_metrics) @@ -449,7 +380,7 @@ def update_logger_connector(self, fx_name: str = None) -> None: # get logged_metrics epoch_log_metrics = self.get_epoch_log_metrics() logger_connector.logged_metrics.update(epoch_log_metrics) - logger_connector.logged_metrics.update(epoch_dict) + logger_connector.logged_metrics.update(epoch=self.trainer.current_epoch) # get forked_metrics forked_metrics = self.get_forked_metrics() @@ -465,12 +396,13 @@ def update_logger_connector(self, fx_name: str = None) -> None: logger_connector.callback_metrics.update(callback_metrics) logger_connector.callback_metrics.pop("epoch", None) + batch_pbar_metrics.pop("debug_epoch", None) + return batch_pbar_metrics, batch_log_metrics + def run_batch_from_func_name(self, func_name) -> Dict: - results = [] - for fx_name, hook_result in self._internals.items(): - func = getattr(hook_result, func_name) - results.append(func(include_forked_originals=False)) - return dict(ChainMap(*sum(results, []))) + results = [getattr(hook_result, func_name) for hook_result in self._internals.values()] + results = [func(include_forked_originals=False) for func in results] + return {k: v for d in sum(results, []) for k, v in d.items()} # List[List[dict]] -> dict def get_latest_batch_log_metrics(self) -> Dict: batch_log_metrics = self.run_batch_from_func_name("get_batch_log_metrics") @@ -485,11 +417,11 @@ def get_latest_batch_pbar_metrics(self) -> Dict: @property def has_reduced(self) -> bool: hook_results = self._internals.values() - return len(hook_results) == sum([h.has_reduced for h in hook_results]) + return len(hook_results) == sum(h.has_reduced for h in hook_results) def auto_reduce_results_on_epoch_end(self) -> None: if not self.has_reduced: - for fx_name, hook_result in self._internals.items(): + for hook_result in self._internals.values(): hook_result.auto_reduce_results_on_epoch_end() @property @@ -511,11 +443,9 @@ def has_batch_loop_finished(self, has_batch_loop_finished): def run_epoch_by_func_name(self, func_name) -> Dict: if not self.has_reduced: self.auto_reduce_results_on_epoch_end() - results = [] - for fx_name, hook_result in self._internals.items(): - func = getattr(hook_result, func_name) - results.append(func()) - return dict(ChainMap(*sum(results, []))) + results = [getattr(hook_result, func_name) for hook_result in self._internals.values()] + results = [func() for func in results] + return {k: v for d in sum(results, []) for k, v in d.items()} # List[List[dict]] -> dict def get_epoch_pbar_metrics(self) -> Dict: return self.run_epoch_by_func_name("get_epoch_pbar_metrics") @@ -528,22 +458,22 @@ def get_forked_metrics(self) -> Dict: def reset(self): self._internals = {} - self._dataloader_idx: Union[int, None] = None - self._split_idx: Union[int, None] = None - self._opt_idx: Union[int, None] = None - self._batch_size: Union[int, None] = None + self._dataloader_idx: Optional[int] = None + self._split_idx: Optional[int] = None + self._opt_idx: Optional[int] = None + self._batch_size: Optional[int] = None self._has_batch_loop_finished = False self.legacy_batch_log_metrics = {} self.legacy_batch_pbar_metrics = {} def __call__( - self, - fx_name: Optional[Union[str, int]] = None, - dl_idx: Optional[Union[str, int]] = None, - opt_idx: Optional[Union[str, int]] = None, - batch_idx: Optional[Union[str, int]] = None, - split_idx: Optional[Union[str, int]] = None, - reduced: bool = False, + self, + fx_name: str, + dl_idx: Optional[int] = None, + opt_idx: Optional[int] = None, + batch_idx: Optional[int] = None, + split_idx: Optional[int] = None, + reduced: bool = False, ): """ This function is an helper to access stored data @@ -586,42 +516,21 @@ def __call__( reduced: Data are being aggregated on on_epoch_end. Indicates if we want to access aggregated Result or not. """ - - hook_result = self[str(fx_name)] - - dl_idx = str(dl_idx) if dl_idx is not None else None - opt_idx = str(opt_idx) if opt_idx is not None else None - batch_idx = str(batch_idx) if batch_idx is not None else None - split_idx = int(split_idx) if split_idx is not None else None - + hook_result = self[fx_name] internal_type = hook_result._internal_type - - if reduced: - result = hook_result._internals_reduced - else: - result = hook_result._internals - - if internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP: - if not reduced: - if dl_idx is not None: - result = result[dl_idx] - if opt_idx is not None: - result = result[opt_idx] - if batch_idx is not None: - result = result[batch_idx] - if split_idx is not None: - result = result[split_idx] - else: - if dl_idx is not None: - result = result[dl_idx] - if opt_idx is not None: - result = result[opt_idx] - else: - if dl_idx is not None: - result = result[dl_idx] - if batch_idx and not reduced: - result = result[batch_idx] - + result = hook_result._internals_reduced if reduced else hook_result._internals + + if dl_idx is not None: + result = result[dl_idx] + if internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP: + if opt_idx is not None: + result = result[opt_idx] + if not reduced and batch_idx is not None: + result = result[batch_idx] + if split_idx is not None: + result = result[split_idx] + elif not reduced and batch_idx is not None: + result = result[batch_idx] return result def __repr__(self): diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index cab08edd58531d..851a48e01434d3 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from collections import ChainMap from copy import deepcopy from pprint import pprint -from typing import Iterable, Union, cast +from typing import Iterable, Union import torch @@ -23,18 +22,13 @@ from pytorch_lightning.core.step_result import EvalResult, Result from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator -from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import ( - LOOKUP_TABLE, - EpochResultStore, - LoggerStages, -) +from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore, LoggerStages from pytorch_lightning.utilities import flatten_dict from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_utils import is_overridden class LoggerConnector: - def __init__(self, trainer): self.trainer = trainer self.callback_metrics = {} @@ -42,24 +36,23 @@ def __init__(self, trainer): self.logged_metrics = {} self.progress_bar_metrics = {} self.eval_loop_results = [] - self._stages = sorted([s.value for s in LoggerStages]) - self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in self._stages} + self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in LoggerStages} self._callback_hook_validator = CallbackHookNameValidator() self._current_stage = None @property def cached_results(self) -> Union[EpochResultStore, None]: - return self._cached_results[self._current_stage] + return self._cached_results.get(self._current_stage) - def set_stage(self, stage_or_testing: str, reset:bool = False) -> None: - self._current_stage = self._determine_stage(stage_or_testing) + def set_stage(self, stage_or_testing: Union[str, bool], reset: bool = False) -> None: + self._current_stage = LoggerStages.determine_stage(stage_or_testing) if reset: self.cached_results.reset() def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoch: bool = None) -> None: - self._callback_hook_validator.check_logging_in_callbacks(current_hook_fx_name=hook_fx_name, - on_step=on_step, - on_epoch=on_epoch) + self._callback_hook_validator.check_logging_in_callbacks( + current_hook_fx_name=hook_fx_name, on_step=on_step, on_epoch=on_epoch + ) def on_evaluation_batch_start(self, testing, batch, dataloader_idx, num_dataloaders): model = self.trainer.get_model() @@ -78,25 +71,11 @@ def on_train_batch_end(self) -> None: self.cached_results._opt_idx = None self.cached_results._batch_size = None - def _determine_stage(self, stage_or_testing: Union[str, bool]) -> str: - stage_or_testing = str(stage_or_testing) - stages = self._stages - if stage_or_testing in stages: - return stage_or_testing - if stage_or_testing in LOOKUP_TABLE: - # Acces using trainer.testing - return LOOKUP_TABLE[stage_or_testing] - raise MisconfigurationException( - f"Provide stage_or_testing {stage_or_testing} doesn't belong either to {stages}" - f" or {LOOKUP_TABLE.keys()}" - ) - - def cache_logged_metrics(self) -> Union[EpochResultStore, None]: + def cache_logged_metrics(self): if self._current_stage is not None: self._cached_results[self._current_stage].cache_result() - def on_trainer_init(self, logger, flush_logs_every_n_steps: int, - log_every_n_steps: int, move_metrics_to_cpu: bool): + def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool): # logging self.configure_logger(logger) # todo: IDE is complaining, these shall be initialized in the Trainer init at leas as placeholders @@ -122,9 +101,7 @@ def configure_logger(self, logger): # default logger self.trainer.logger = TensorBoardLogger( - save_dir=self.trainer.default_root_dir, - version=version, - name='lightning_logs' + save_dir=self.trainer.default_root_dir, version=version, name='lightning_logs' ) elif logger is False: self.trainer.logger = None @@ -289,10 +266,7 @@ def get_evaluate_epoch_results(self, test_mode): return results def _track_callback_metrics(self, eval_results, using_eval_result): - if ( - len(eval_results) > 0 and - (eval_results[0] is None or not isinstance(eval_results[0], Result)) - ): + if len(eval_results) > 0 and (eval_results[0] is None or not isinstance(eval_results[0], Result)): return if using_eval_result: @@ -379,20 +353,22 @@ def __process_eval_epoch_end_results_and_log_legacy(self, eval_results, test_mod _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.trainer.process_dict_result(result) if num_loaders > 1: - self.__process_eval_epoch_end_results_and_log_legacy_update(prog_bar_metrics, log_metrics, callback_metrics) + self.__process_eval_epoch_end_results_and_log_legacy_update( + prog_bar_metrics, log_metrics, callback_metrics + ) if num_loaders == 1: - self.__process_eval_epoch_end_results_and_log_legacy_update(prog_bar_metrics, log_metrics, callback_metrics) + self.__process_eval_epoch_end_results_and_log_legacy_update( + prog_bar_metrics, log_metrics, callback_metrics + ) def on_train_epoch_end(self): # inform cached logger connector epoch finished self.cached_results.has_batch_loop_finished = True - def log_train_epoch_end_metrics(self, - epoch_output, - checkpoint_accumulator, - early_stopping_accumulator, - num_optimizers): + def log_train_epoch_end_metrics( + self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers + ): # epoch output is a list. Each item in that list has all the outputs per optimizer # epoch_output[optimizer_idx][training_step_idx][tbptt_index] # remember that not using truncated backprop is equivalent with truncated back prop of len(1) @@ -438,11 +414,7 @@ def log_train_epoch_end_metrics(self, # TODO: deprecate 1.0 else: out = self.__run_legacy_training_epoch_end( - num_optimizers, - epoch_output, - model, - is_result_obj, - epoch_callback_metrics + num_optimizers, epoch_output, model, is_result_obj, epoch_callback_metrics ) epoch_log_metrics, epoch_progress_bar_metrics, epoch_callback_metrics = out @@ -489,18 +461,15 @@ def training_epoch_end(self, model, epoch_output, num_optimizers): epoch_output = model.training_epoch_end(epoch_output) if epoch_output is not None: - raise MisconfigurationException('training_epoch_end expects a return of None. ' - 'HINT: remove the return statement in training_epoch_end') + raise MisconfigurationException( + 'training_epoch_end expects a return of None. ' + 'HINT: remove the return statement in training_epoch_end' + ) # capture logging self.trainer.logger_connector.cache_logged_metrics() def __run_legacy_training_epoch_end( - self, - num_optimizers, - epoch_output, - model, - is_result_obj, - epoch_callback_metrics + self, num_optimizers, epoch_output, model, is_result_obj, epoch_callback_metrics ): epoch_log_metrics = {} @@ -618,11 +587,13 @@ def __gather_result_across_time_and_optimizers(self, epoch_output): return gathered_epoch_outputs def log_train_step_metrics(self, batch_output): + _, batch_log_metrics = self.cached_results.update_logger_connector() # when metrics should be logged if self.should_update_logs or self.trainer.fast_dev_run: # logs user requested information to logger - metrics = self.cached_results.get_latest_batch_log_metrics() grad_norm_dic = batch_output.grad_norm_dic - if len(metrics) > 0 or len(grad_norm_dic) > 0: - self.log_metrics(metrics, grad_norm_dic, log_train_step_metrics=True) - self.callback_metrics.update(metrics) + if grad_norm_dic is None: + grad_norm_dic = {} + if len(batch_log_metrics) > 0 or len(grad_norm_dic) > 0: + self.log_metrics(batch_log_metrics, grad_norm_dic, log_train_step_metrics=True) + self.callback_metrics.update(batch_log_metrics) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index a15e9bba2af631..4a7b14d0b1fe97 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -16,14 +16,14 @@ import platform from abc import ABC from copy import deepcopy -from typing import Callable, Iterable, List, Optional, Tuple, Union +from typing import Union, List, Tuple, Callable, Optional, Iterable from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.core import LightningModule -from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_warn +from pytorch_lightning.utilities import rank_zero_warn, TPU_AVAILABLE, HOROVOD_AVAILABLE from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -32,12 +32,8 @@ if TPU_AVAILABLE: import torch_xla.core.xla_model as xm -try: +if HOROVOD_AVAILABLE: import horovod.torch as hvd -except (ModuleNotFoundError, ImportError): - HOROVOD_AVAILABLE = False -else: - HOROVOD_AVAILABLE = True class TrainerDataLoadingMixin(ABC): diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 097727a6bed786..4b70917c8c43db 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -106,9 +106,9 @@ def on_evaluation_model_train(self, *args, **kwargs): def on_evaluation_end(self, *args, **kwargs): if self.testing: - self.trainer.call_hook('on_test_end', *args, capture=True, **kwargs) + self.trainer.call_hook('on_test_end', *args, **kwargs) else: - self.trainer.call_hook('on_validation_end', *args, capture=True, **kwargs) + self.trainer.call_hook('on_validation_end', *args, **kwargs) def reload_evaluation_dataloaders(self): model = self.trainer.get_model() @@ -329,9 +329,9 @@ def store_predictions(self, output, batch_idx, dataloader_idx): def on_evaluation_epoch_end(self, *args, **kwargs): # call the callback hook if self.testing: - self.trainer.call_hook('on_test_epoch_end', *args, capture=True, **kwargs) + self.trainer.call_hook('on_test_epoch_end', *args, **kwargs) else: - self.trainer.call_hook('on_validation_epoch_end', *args, capture=True, **kwargs) + self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs) def log_evaluation_step_metrics(self, output, batch_idx): if self.trainer.running_sanity_check: @@ -346,10 +346,8 @@ def log_evaluation_step_metrics(self, output, batch_idx): self.__log_result_step_metrics(step_log_metrics, step_pbar_metrics, batch_idx) def __log_result_step_metrics(self, step_log_metrics, step_pbar_metrics, batch_idx): - cached_batch_log_metrics = \ - self.trainer.logger_connector.cached_results.get_latest_batch_log_metrics() - cached_batch_pbar_metrics = \ - self.trainer.logger_connector.cached_results.get_latest_batch_pbar_metrics() + cached_results = self.trainer.logger_connector.cached_results + cached_batch_pbar_metrics, cached_batch_log_metrics = cached_results.update_logger_connector() step_log_metrics.update(cached_batch_log_metrics) step_pbar_metrics.update(cached_batch_pbar_metrics) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 012d9b3a6fd5e9..92e3b6af2e1fd9 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -206,10 +206,9 @@ def __init__( log_every_n_steps: How often to log within steps (defaults to every 50 steps). - automatic_optimization: If False you are responsible for calling .backward, .step, zero_grad. - If False you are responsible for calling .backward, .step, zero_grad in LightningModule. - This argument has been moved to LightningModule. It is deprecated here in v1.1 and - will be removed in v1.3. + automatic_optimization: If False you are responsible for calling .backward, .step, zero_grad + in LightningModule. This argument has been moved to LightningModule. It is deprecated + here in v1.1 and will be removed in v1.3. prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data. Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data @@ -505,11 +504,9 @@ def train(self): # hook self.train_loop.on_train_start() - if self.train_loop.should_skip_training(): - self.train_loop.on_train_end() - return - try: + if self.train_loop.should_skip_training(): + return # run all epochs for epoch in range(self.current_epoch, self.max_epochs): @@ -521,9 +518,6 @@ def train(self): self.train_loop.run_training_epoch() if self.max_steps and self.max_steps <= self.global_step: - - # hook - self.train_loop.on_train_end() return # update LR schedulers @@ -535,7 +529,6 @@ def train(self): if self.should_stop: if met_min_epochs and met_min_steps: - self.train_loop.on_train_end() return log.info( 'Trainer was signaled to stop but required minimum epochs' @@ -543,9 +536,6 @@ def train(self): ' not been met. Training will continue...' ) - # hook - self.train_loop.on_train_end() - except KeyboardInterrupt: rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...') @@ -554,9 +544,9 @@ def train(self): self.interrupted = True self._state = TrainerState.INTERRUPTED self.on_keyboard_interrupt() - - # hook - self.train_loop.on_train_end() + finally: + # hook + self.train_loop.on_train_end() def run_evaluation(self, test_mode: bool = False, max_batches=None): @@ -865,6 +855,8 @@ def call_setup_hook(self, model): model.setup(stage_name) def _reset_result_and_set_hook_fx_name(self, hook_name): + if "batch_start" in hook_name: + return True model_ref = self.get_model() if model_ref is not None: # used to track current hook name called @@ -878,10 +870,9 @@ def _cache_logged_metrics(self): # capture logging for this hook self.logger_connector.cache_logged_metrics() - def call_hook(self, hook_name, *args, capture=False, **kwargs): + def call_hook(self, hook_name, *args, **kwargs): # set hook_name to model + reset Result obj - if capture: - self._reset_result_and_set_hook_fx_name(hook_name) + skip = self._reset_result_and_set_hook_fx_name(hook_name) # always profile hooks with self.profiler.profile(hook_name): @@ -904,7 +895,7 @@ def call_hook(self, hook_name, *args, capture=False, **kwargs): accelerator_hook = getattr(self.accelerator_backend, hook_name) output = accelerator_hook(*args, **kwargs) - if capture: + if not skip: self._cache_logged_metrics() return output diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 9a4f324033d399..679f59c05e7c43 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -825,8 +825,8 @@ def run_on_epoch_end_hook(self, epoch_output): # inform logger the batch loop has finished self.trainer.logger_connector.on_train_epoch_end() - self.trainer.call_hook('on_epoch_end', capture=True) - self.trainer.call_hook('on_train_epoch_end', epoch_output, capture=True) + self.trainer.call_hook('on_epoch_end') + self.trainer.call_hook('on_train_epoch_end', epoch_output) def increment_accumulated_grad_global_step(self): num_accumulated_batches_reached = self._accumulated_batches_reached() diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 916e434e5ff06d..1e2eeea9f456c5 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -47,6 +47,7 @@ def _module_available(module_path: str) -> bool: NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast") OMEGACONF_AVAILABLE = _module_available("omegaconf") HYDRA_AVAILABLE = _module_available("hydra") +HOROVOD_AVAILABLE = _module_available("horovod.torch") TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel') diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py index 7dc9c90e16dbd4..e94934020107d4 100644 --- a/pytorch_lightning/utilities/cloud_io.py +++ b/pytorch_lightning/utilities/cloud_io.py @@ -25,8 +25,8 @@ def load(path_or_url: Union[str, IO, Path], map_location=None): if not isinstance(path_or_url, (str, Path)): # any sort of BytesIO or similiar return torch.load(path_or_url, map_location=map_location) - if path_or_url.startswith("http"): - return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location) + if str(path_or_url).startswith("http"): + return torch.hub.load_state_dict_from_url(str(path_or_url), map_location=map_location) fs = get_filesystem(path_or_url) with fs.open(path_or_url, "rb") as f: return torch.load(f, map_location=map_location) diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py index f7b9e79b7f9325..9264e2a49810d7 100644 --- a/pytorch_lightning/utilities/debugging.py +++ b/pytorch_lightning/utilities/debugging.py @@ -87,7 +87,7 @@ def track_load_dataloader_call(self, name, dataloaders): for dl in dataloaders: try: length = len(dl) - except Exception as e: + except Exception: length = -1 lengths.append(length) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 98d322ce0a3a2f..ffa1be87cd3ca6 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -89,6 +89,9 @@ def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None) if group is None: group = torch.distributed.group.WORLD + # convert tensors to contiguous format + result = result.contiguous() + world_size = torch.distributed.get_world_size(group) gathered_result = [torch.zeros_like(result) for _ in range(world_size)] diff --git a/tests/backends/ddp_model.py b/tests/backends/ddp_model.py index b625d8cc985fc1..32b30c05538be7 100644 --- a/tests/backends/ddp_model.py +++ b/tests/backends/ddp_model.py @@ -14,20 +14,24 @@ """ Runs either `.fit()` or `.test()` on a single node across multiple gpus. """ +import os from argparse import ArgumentParser +import tests as pl_tests from pytorch_lightning import Trainer, seed_everything from tests.base import EvalModelTemplate -import os + import torch def main(): seed_everything(1234) + parser = ArgumentParser(add_help=False) parser = Trainer.add_argparse_args(parser) parser.add_argument('--trainer_method', default='fit') parser.add_argument('--tmpdir') + parser.add_argument('--workdir') parser.set_defaults(gpus=2) parser.set_defaults(distributed_backend="ddp") args = parser.parse_args() @@ -38,14 +42,26 @@ def main(): result = {} if args.trainer_method == 'fit': trainer.fit(model) - result = {'status': 'complete', 'method': args.trainer_method, 'result': None} + result = { + 'status': 'complete', + 'method': args.trainer_method, + 'result': None + } if args.trainer_method == 'test': result = trainer.test(model) - result = {'status': 'complete', 'method': args.trainer_method, 'result': result} + result = { + 'status': 'complete', + 'method': args.trainer_method, + 'result': result + } if args.trainer_method == 'fit_test': trainer.fit(model) result = trainer.test(model) - result = {'status': 'complete', 'method': args.trainer_method, 'result': result} + result = { + 'status': 'complete', + 'method': args.trainer_method, + 'result': result + } if len(result) > 0: file_path = os.path.join(args.tmpdir, 'ddp.result') diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index abe21b9d28e25e..b7711e8aae3feb 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -13,15 +13,17 @@ # limitations under the License. import os import pickle -from unittest import mock import cloudpickle +import numpy as np import pytest import torch +from unittest import mock +from pytorch_lightning import _logger from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from tests.base import EvalModelTemplate +from tests.base import EvalModelTemplate, BoringModel from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -204,3 +206,92 @@ def validation_epoch_end(self, outputs): ) trainer.fit(model) assert trainer.current_epoch >= 5, 'early_stopping failed' + + +@pytest.mark.parametrize('step_freeze, min_steps, min_epochs',[(5, 1, 1), (5, 1, 3), (3, 15, 1)]) +def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, min_steps, min_epochs): + """Excepted Behaviour: + IF `min_steps` was set to a higher value than the `trainer.global_step` when `early_stopping` is being triggered, + THEN the trainer should continue until reaching `trainer.global_step` == `min_steps`, and stop. + + IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step` when `early_stopping` is being triggered, + THEN the trainer should continue until reaching `trainer.global_step` == `min_epochs * len(train_dataloader)`, and stop. + This test validate this expected behaviour + + IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step` when `early_stopping` is being triggered, + THEN the highest between `min_epochs * len(train_dataloader)` and `min_steps` would be reached. + + Caviat: IF min_steps is divisible by len(train_dataloader), then it will do min_steps + len(train_dataloader) + + This test validate those expected behaviours + """ + + _logger.disabled = True + + original_loss_value = 10 + limit_train_batches = 3 + patience = 3 + + class Model(BoringModel): + + def __init__(self, step_freeze): + super(Model, self).__init__() + + self._step_freeze = step_freeze + + self._loss_value = 10.0 + self._eps = 1e-1 + self._count_decrease = 0 + self._values = [] + + def training_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + return {"test_val_loss": self._loss_value} + + def validation_epoch_end(self, outputs): + _mean = np.mean([x['test_val_loss'] for x in outputs]) + if self.trainer.global_step <= self._step_freeze: + self._count_decrease += 1 + self._loss_value -= self._eps + self._values.append(_mean) + return {"test_val_loss": _mean} + + model = Model(step_freeze) + model.training_step_end = None + model.test_dataloader = None + early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[early_stop_callback], + limit_train_batches=limit_train_batches, + limit_val_batches=2, + min_steps=min_steps, + min_epochs=min_epochs + ) + trainer.fit(model) + + # Make sure loss was properly decreased + assert abs(original_loss_value - (model._count_decrease) * model._eps - model._loss_value) < 1e-6 + + pos_diff = (np.diff(model._values) == 0).nonzero()[0][0] + + # Compute when the latest validation epoch end happened + latest_validation_epoch_end = (pos_diff // limit_train_batches) * limit_train_batches + if pos_diff % limit_train_batches == 0: + latest_validation_epoch_end += limit_train_batches + + # Compute early stopping latest step + by_early_stopping = latest_validation_epoch_end + (1 + limit_train_batches) * patience + + # Compute min_epochs latest step + by_min_epochs = min_epochs * limit_train_batches + + # Make sure the trainer stops for the max of all minimun requirements + assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), \ + (trainer.global_step, max(min_steps, by_early_stopping, by_min_epochs), step_freeze, min_steps, min_epochs) + + _logger.disabled = False diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 33bc19a894d8f9..6d1d3edea5be98 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -261,6 +261,29 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir): assert ckpt_name == filepath / 'test-epoch=3-step=2.ckpt' +class ModelCheckpointExtensionTest(ModelCheckpoint): + FILE_EXTENSION = '.tpkc' + + +def test_model_checkpoint_file_extension(tmpdir): + """ + Test ModelCheckpoint with different file extension. + """ + + model = LogInTwoMethods() + model_checkpoint = ModelCheckpointExtensionTest(monitor='early_stop_on', dirpath=tmpdir, save_top_k=1, save_last=True) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[model_checkpoint], + max_steps=1, + logger=False, + ) + trainer.fit(model) + + expected = ['epoch=0-step=0.tpkc', 'last.tpkc'] + assert set(expected) == set(os.listdir(tmpdir)) + + def test_model_checkpoint_save_last(tmpdir): """Tests that save_last produces only one last checkpoint.""" seed_everything() diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py index 4b270927a43d39..3d89c8cd85310b 100644 --- a/tests/loggers/test_base.py +++ b/tests/loggers/test_base.py @@ -20,6 +20,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection +from pytorch_lightning.loggers.base import DummyLogger, DummyExperiment from pytorch_lightning.utilities import rank_zero_only from tests.base import EvalModelTemplate @@ -215,6 +216,16 @@ def log_metrics(self, metrics, step): assert logger.history == {0: {'loss': 0.5623850983416314}, 1: {'loss': 0.4778883735637184}} +def test_dummyexperiment_support_indexing(): + experiment = DummyExperiment() + assert experiment[0] == experiment + + +def test_dummylogger_support_indexing(): + logger = DummyLogger() + assert logger[0] == logger + + def test_np_sanitization(): class CustomParamsLogger(CustomLogger): def __init__(self): diff --git a/tests/metrics/classification/test_average_precision.py b/tests/metrics/classification/test_average_precision.py new file mode 100644 index 00000000000000..e4492349f3272f --- /dev/null +++ b/tests/metrics/classification/test_average_precision.py @@ -0,0 +1,102 @@ +from functools import partial + +import numpy as np +import pytest +import torch +from sklearn.metrics import average_precision_score as _sk_average_precision_score + +from pytorch_lightning.metrics.classification.average_precision import AveragePrecision +from pytorch_lightning.metrics.functional.average_precision import average_precision +from tests.metrics.classification.inputs import ( + _binary_prob_inputs, + _multiclass_prob_inputs, + _multidim_multiclass_prob_inputs, +) +from tests.metrics.utils import NUM_CLASSES, MetricTester + +torch.manual_seed(42) + + +def sk_average_precision_score(y_true, probas_pred, num_classes=1): + if num_classes == 1: + return _sk_average_precision_score(y_true, probas_pred) + + res = [] + for i in range(num_classes): + y_true_temp = np.zeros_like(y_true) + y_true_temp[y_true == i] = 1 + res.append(_sk_average_precision_score(y_true_temp, probas_pred[:, i])) + return res + + +def _binary_prob_sk_metric(preds, target, num_classes=1): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +def _multiclass_prob_sk_metric(preds, target, num_classes=1): + sk_preds = preds.reshape(-1, num_classes).numpy() + sk_target = target.view(-1).numpy() + + return sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +def _multidim_multiclass_prob_sk_metric(preds, target, num_classes=1): + sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() + sk_target = target.view(-1).numpy() + return sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +@pytest.mark.parametrize("preds, target, sk_metric, num_classes", [ + (_binary_prob_inputs.preds, _binary_prob_inputs.target, _binary_prob_sk_metric, 1), + ( + _multiclass_prob_inputs.preds, + _multiclass_prob_inputs.target, + _multiclass_prob_sk_metric, + NUM_CLASSES), + ( + _multidim_multiclass_prob_inputs.preds, + _multidim_multiclass_prob_inputs.target, + _multidim_multiclass_prob_sk_metric, + NUM_CLASSES + ), +]) +class TestAveragePrecision(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_average_precision(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=AveragePrecision, + sk_metric=partial(sk_metric, num_classes=num_classes), + dist_sync_on_step=dist_sync_on_step, + metric_args={"num_classes": num_classes} + ) + + def test_average_precision_functional(self, preds, target, sk_metric, num_classes): + self.run_functional_metric_test( + preds, + target, + metric_functional=average_precision, + sk_metric=partial(sk_metric, num_classes=num_classes), + metric_args={"num_classes": num_classes}, + ) + + +@pytest.mark.parametrize(['scores', 'target', 'expected_score'], [ + # Check the average_precision_score of a constant predictor is + # the TPR + # Generate a dataset with 25% of positives + # And a constant score + # The precision is then the fraction of positive whatever the recall + # is, as there is only one threshold: + pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25), + # With threshold 0.8 : 1 TP and 2 TN and one FN + pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75), +]) +def test_average_precision(scores, target, expected_score): + assert average_precision(scores, target) == expected_score diff --git a/tests/metrics/classification/test_precision_recall_curve.py b/tests/metrics/classification/test_precision_recall_curve.py new file mode 100644 index 00000000000000..07e942c5b10f4f --- /dev/null +++ b/tests/metrics/classification/test_precision_recall_curve.py @@ -0,0 +1,104 @@ +from functools import partial + +import numpy as np +import pytest +import torch +from sklearn.metrics import precision_recall_curve as _sk_precision_recall_curve + +from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve +from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve +from tests.metrics.classification.inputs import ( + _binary_prob_inputs, + _multiclass_prob_inputs, + _multidim_multiclass_prob_inputs, +) +from tests.metrics.utils import NUM_CLASSES, MetricTester + +torch.manual_seed(42) + + +def sk_precision_recall_curve(y_true, probas_pred, num_classes=1): + """ Adjusted comparison function that can also handles multiclass """ + if num_classes == 1: + return _sk_precision_recall_curve(y_true, probas_pred) + + precision, recall, thresholds = [], [], [] + for i in range(num_classes): + y_true_temp = np.zeros_like(y_true) + y_true_temp[y_true == i] = 1 + res = _sk_precision_recall_curve(y_true_temp, probas_pred[:, i]) + precision.append(res[0]) + recall.append(res[1]) + thresholds.append(res[2]) + return precision, recall, thresholds + + +def _binary_prob_sk_metric(preds, target, num_classes=1): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +def _multiclass_prob_sk_metric(preds, target, num_classes=1): + sk_preds = preds.reshape(-1, num_classes).numpy() + sk_target = target.view(-1).numpy() + + return sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +def _multidim_multiclass_prob_sk_metric(preds, target, num_classes=1): + sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() + sk_target = target.view(-1).numpy() + return sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +@pytest.mark.parametrize("preds, target, sk_metric, num_classes", [ + (_binary_prob_inputs.preds, _binary_prob_inputs.target, _binary_prob_sk_metric, 1), + ( + _multiclass_prob_inputs.preds, + _multiclass_prob_inputs.target, + _multiclass_prob_sk_metric, + NUM_CLASSES), + ( + _multidim_multiclass_prob_inputs.preds, + _multidim_multiclass_prob_inputs.target, + _multidim_multiclass_prob_sk_metric, + NUM_CLASSES + ), +]) +class TestPrecisionRecallCurve(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_precision_recall_curve(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=PrecisionRecallCurve, + sk_metric=partial(sk_metric, num_classes=num_classes), + dist_sync_on_step=dist_sync_on_step, + metric_args={"num_classes": num_classes} + ) + + def test_precision_recall_curve_functional(self, preds, target, sk_metric, num_classes): + self.run_functional_metric_test( + preds, + target, + metric_functional=precision_recall_curve, + sk_metric=partial(sk_metric, num_classes=num_classes), + metric_args={"num_classes": num_classes}, + ) + + +@pytest.mark.parametrize(['pred', 'target', 'expected_p', 'expected_r', 'expected_t'], [ + pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4]) +]) +def test_pr_curve(pred, target, expected_p, expected_r, expected_t): + p, r, t = precision_recall_curve(torch.tensor(pred), torch.tensor(target)) + assert p.size() == r.size() + assert p.size(0) == t.size(0) + 1 + + assert torch.allclose(p, torch.tensor(expected_p).to(p)) + assert torch.allclose(r, torch.tensor(expected_r).to(r)) + assert torch.allclose(t, torch.tensor(expected_t).to(t)) diff --git a/tests/metrics/classification/test_roc.py b/tests/metrics/classification/test_roc.py new file mode 100644 index 00000000000000..c3db4ed769221f --- /dev/null +++ b/tests/metrics/classification/test_roc.py @@ -0,0 +1,107 @@ +from functools import partial + +import numpy as np +import pytest +import torch +from sklearn.metrics import roc_curve as _sk_roc_curve + +from pytorch_lightning.metrics.classification.roc import ROC +from pytorch_lightning.metrics.functional.roc import roc +from tests.metrics.classification.inputs import ( + _binary_prob_inputs, + _multiclass_prob_inputs, + _multidim_multiclass_prob_inputs, +) +from tests.metrics.utils import NUM_CLASSES, MetricTester + +torch.manual_seed(42) + + +def sk_roc_curve(y_true, probas_pred, num_classes=1): + """ Adjusted comparison function that can also handles multiclass """ + if num_classes == 1: + return _sk_roc_curve(y_true, probas_pred, drop_intermediate=False) + + fpr, tpr, thresholds = [], [], [] + for i in range(num_classes): + y_true_temp = np.zeros_like(y_true) + y_true_temp[y_true == i] = 1 + res = _sk_roc_curve(y_true_temp, probas_pred[:, i], drop_intermediate=False) + fpr.append(res[0]) + tpr.append(res[1]) + thresholds.append(res[2]) + return fpr, tpr, thresholds + + +def _binary_prob_sk_metric(preds, target, num_classes=1): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +def _multiclass_prob_sk_metric(preds, target, num_classes=1): + sk_preds = preds.reshape(-1, num_classes).numpy() + sk_target = target.view(-1).numpy() + + return sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +def _multidim_multiclass_prob_sk_metric(preds, target, num_classes=1): + sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() + sk_target = target.view(-1).numpy() + return sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) + + +@pytest.mark.parametrize("preds, target, sk_metric, num_classes", [ + (_binary_prob_inputs.preds, _binary_prob_inputs.target, _binary_prob_sk_metric, 1), + ( + _multiclass_prob_inputs.preds, + _multiclass_prob_inputs.target, + _multiclass_prob_sk_metric, + NUM_CLASSES), + ( + _multidim_multiclass_prob_inputs.preds, + _multidim_multiclass_prob_inputs.target, + _multidim_multiclass_prob_sk_metric, + NUM_CLASSES + ), +]) +class TestROC(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_roc(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=ROC, + sk_metric=partial(sk_metric, num_classes=num_classes), + dist_sync_on_step=dist_sync_on_step, + metric_args={"num_classes": num_classes} + ) + + def test_roc_functional(self, preds, target, sk_metric, num_classes): + self.run_functional_metric_test( + preds, + target, + metric_functional=roc, + sk_metric=partial(sk_metric, num_classes=num_classes), + metric_args={"num_classes": num_classes}, + ) + + +@pytest.mark.parametrize(['pred', 'target', 'expected_tpr', 'expected_fpr'], [ + pytest.param([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]), + pytest.param([1, 0], [0, 1], [0, 0, 1], [0, 1, 1]), + pytest.param([1, 1], [1, 0], [0, 1], [0, 1]), + pytest.param([1, 0], [1, 0], [0, 1, 1], [0, 0, 1]), + pytest.param([0.5, 0.5], [0, 1], [0, 1], [0, 1]), +]) +def test_roc_curve(pred, target, expected_tpr, expected_fpr): + fpr, tpr, thresh = roc(torch.tensor(pred), torch.tensor(target)) + + assert fpr.shape == tpr.shape + assert fpr.size(0) == thresh.size(0) + assert torch.allclose(fpr, torch.tensor(expected_fpr).to(fpr)) + assert torch.allclose(tpr, torch.tensor(expected_tpr).to(tpr)) diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index 307aeea1f9ac14..f7bd7d558f5b4f 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -7,18 +7,11 @@ jaccard_score as sk_jaccard_score, precision_score as sk_precision, recall_score as sk_recall, - f1_score as sk_f1_score, - fbeta_score as sk_fbeta_score, - roc_curve as sk_roc_curve, roc_auc_score as sk_roc_auc_score, - precision_recall_curve as sk_precision_recall_curve ) from pytorch_lightning import seed_everything from pytorch_lightning.metrics.functional.classification import ( - to_onehot, - to_categorical, - get_num_classes, stat_scores, stat_scores_multiple_classes, accuracy, @@ -26,14 +19,12 @@ recall, _binary_clf_curve, dice_score, - average_precision, auroc, multiclass_auroc, - precision_recall_curve, - roc, auc, iou, ) +from pytorch_lightning.metrics.utils import to_onehot, get_num_classes, to_categorical @pytest.mark.parametrize(['sklearn_metric', 'torch_metric', 'only_binary'], [ @@ -41,8 +32,6 @@ pytest.param(partial(sk_jaccard_score, average='macro'), iou, False, id='iou'), pytest.param(partial(sk_precision, average='micro'), precision, False, id='precision'), pytest.param(partial(sk_recall, average='micro'), recall, False, id='recall'), - pytest.param(sk_roc_curve, roc, True, id='roc'), - pytest.param(sk_precision_recall_curve, precision_recall_curve, True, id='precision_recall_curve'), pytest.param(sk_roc_auc_score, auroc, True, id='auroc') ]) def test_against_sklearn(sklearn_metric, torch_metric, only_binary): @@ -243,35 +232,6 @@ def test_binary_clf_curve(sample_weight, pos_label, exp_shape): assert thresh.shape == (exp_shape,) -@pytest.mark.parametrize(['pred', 'target', 'expected_p', 'expected_r', 'expected_t'], [ - pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4]) -]) -def test_pr_curve(pred, target, expected_p, expected_r, expected_t): - p, r, t = precision_recall_curve(torch.tensor(pred), torch.tensor(target)) - assert p.size() == r.size() - assert p.size(0) == t.size(0) + 1 - - assert torch.allclose(p, torch.tensor(expected_p).to(p)) - assert torch.allclose(r, torch.tensor(expected_r).to(r)) - assert torch.allclose(t, torch.tensor(expected_t).to(t)) - - -@pytest.mark.parametrize(['pred', 'target', 'expected_tpr', 'expected_fpr'], [ - pytest.param([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]), - pytest.param([1, 0], [0, 1], [0, 0, 1], [0, 1, 1]), - pytest.param([1, 1], [1, 0], [0, 1], [0, 1]), - pytest.param([1, 0], [1, 0], [0, 1, 1], [0, 0, 1]), - pytest.param([0.5, 0.5], [0, 1], [0, 1], [0, 1]), -]) -def test_roc_curve(pred, target, expected_tpr, expected_fpr): - fpr, tpr, thresh = roc(torch.tensor(pred), torch.tensor(target)) - - assert fpr.shape == tpr.shape - assert fpr.size(0) == thresh.size(0) - assert torch.allclose(fpr, torch.tensor(expected_fpr).to(fpr)) - assert torch.allclose(tpr, torch.tensor(expected_tpr).to(tpr)) - - @pytest.mark.parametrize(['pred', 'target', 'expected'], [ pytest.param([0, 1, 0, 1], [0, 1, 0, 1], 1.), pytest.param([1, 1, 0, 0], [0, 0, 1, 1], 0.), @@ -337,21 +297,6 @@ def test_auc(x, y, expected): assert auc(torch.tensor(x), torch.tensor(y)) == expected -@pytest.mark.parametrize(['scores', 'target', 'expected_score'], [ - # Check the average_precision_score of a constant predictor is - # the TPR - # Generate a dataset with 25% of positives - # And a constant score - # The precision is then the fraction of positive whatever the recall - # is, as there is only one threshold: - pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25), - # With threshold 0.8 : 1 TP and 2 TN and one FN - pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75), -]) -def test_average_precision(scores, target, expected_score): - assert average_precision(scores, target) == expected_score - - @pytest.mark.parametrize(['pred', 'target', 'expected'], [ pytest.param([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.), pytest.param([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.), diff --git a/tests/metrics/functional/test_reduction.py b/tests/metrics/functional/test_reduction.py index aec54c18067155..99fecf65b534f5 100644 --- a/tests/metrics/functional/test_reduction.py +++ b/tests/metrics/functional/test_reduction.py @@ -1,7 +1,7 @@ import pytest import torch -from pytorch_lightning.metrics.functional.reduction import reduce, class_reduce +from pytorch_lightning.metrics.utils import reduce, class_reduce def test_reduce(): diff --git a/tests/metrics/test_ddp.py b/tests/metrics/test_ddp.py index 21931a365efe32..4cac03cc16e2be 100644 --- a/tests/metrics/test_ddp.py +++ b/tests/metrics/test_ddp.py @@ -3,6 +3,7 @@ import pytest import torch +from pytorch_lightning.metrics import Metric from tests.metrics.test_metric import Dummy from tests.metrics.utils import setup_ddp @@ -43,3 +44,28 @@ def _test_ddp_sum_cat(rank, worldsize): @pytest.mark.parametrize("process", [_test_ddp_cat, _test_ddp_sum, _test_ddp_sum_cat]) def test_ddp(process): torch.multiprocessing.spawn(process, args=(2,), nprocs=2) + + +def _test_non_contiguous_tensors(rank, worldsize): + setup_ddp(rank, worldsize) + + class DummyMetric(Metric): + def __init__(self): + super().__init__() + self.add_state("x", default=[], dist_reduce_fx=None) + + def update(self, x): + self.x.append(x) + + def compute(self): + x = torch.cat(self.x, dim=0) + return x.sum() + + metric = DummyMetric() + metric.update(torch.randn(10, 5)[:, 0]) + + +@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") +def test_non_contiguous_tensors(): + """ Test that gather_all operation works for non contiguous tensors """ + torch.multiprocessing.spawn(_test_non_contiguous_tensors, args=(2,), nprocs=2) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 34abee84738630..5c00384da1e14b 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -13,7 +13,7 @@ NUM_PROCESSES = 2 NUM_BATCHES = 10 -BATCH_SIZE = 16 +BATCH_SIZE = 32 NUM_CLASSES = 5 EXTRA_DIM = 3 THRESHOLD = 0.5 @@ -28,6 +28,32 @@ def setup_ddp(rank, world_size): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) +def _assert_allclose(pl_result, sk_result, atol: float = 1e-8): + """ Utility function for recursively asserting that two results are within + a certain tolerance + """ + # single output compare + if isinstance(pl_result, torch.Tensor): + assert np.allclose(pl_result.numpy(), sk_result, atol=atol, equal_nan=True) + # multi output compare + elif isinstance(pl_result, (tuple, list)): + for pl_res, sk_res in zip(pl_result, sk_result): + _assert_allclose(pl_res, sk_res, atol=atol) + else: + raise ValueError('Unknown format for comparison') + + +def _assert_tensor(pl_result): + """ Utility function for recursively checking that some input only consist of + torch tensors + """ + if isinstance(pl_result, (list, tuple)): + for plr in pl_result: + _assert_tensor(plr) + else: + assert isinstance(pl_result, torch.Tensor) + + def _class_test( rank: int, worldsize: int, @@ -71,28 +97,28 @@ def _class_test( if metric.dist_sync_on_step: if rank == 0: - ddp_preds = torch.stack([preds[i + r] for r in range(worldsize)]) - ddp_target = torch.stack([target[i + r] for r in range(worldsize)]) + ddp_preds = torch.cat([preds[i + r] for r in range(worldsize)]) + ddp_target = torch.cat([target[i + r] for r in range(worldsize)]) sk_batch_result = sk_metric(ddp_preds, ddp_target) # assert for dist_sync_on_step if check_dist_sync_on_step: - assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol) + _assert_allclose(batch_result, sk_batch_result, atol=atol) else: sk_batch_result = sk_metric(preds[i], target[i]) # assert for batch if check_batch: - assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol) + _assert_allclose(batch_result, sk_batch_result, atol=atol) # check on all batches on all ranks result = metric.compute() - assert isinstance(result, torch.Tensor) + _assert_tensor(result) - total_preds = torch.stack([preds[i] for i in range(NUM_BATCHES)]) - total_target = torch.stack([target[i] for i in range(NUM_BATCHES)]) + total_preds = torch.cat([preds[i] for i in range(NUM_BATCHES)]) + total_target = torch.cat([target[i] for i in range(NUM_BATCHES)]) sk_result = sk_metric(total_preds, total_target) # assert after aggregation - assert np.allclose(result.numpy(), sk_result, atol=atol) + _assert_allclose(result, sk_result, atol=atol) def _functional_test( @@ -120,7 +146,7 @@ def _functional_test( sk_result = sk_metric(preds[i], target[i]) # assert its the same - assert np.allclose(lightning_result.numpy(), sk_result, atol=atol) + _assert_allclose(lightning_result, sk_result, atol=atol) class MetricTester: diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index 5b31c678177e4a..94daaedb4fa638 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -21,18 +21,19 @@ import os import sys - -try: - import horovod.torch as hvd -except ImportError: - print('You requested to import Horovod which is missing or not supported for your OS.') - PATH_HERE = os.path.abspath(os.path.dirname(__file__)) PATH_ROOT = os.path.abspath(os.path.join(PATH_HERE, '..', '..', '..', '..')) sys.path.insert(0, os.path.abspath(PATH_ROOT)) from pytorch_lightning import Trainer # noqa: E402 from pytorch_lightning.callbacks import ModelCheckpoint # noqa: E402 +from pytorch_lightning.utilities import HOROVOD_AVAILABLE # noqa: E402 + +if HOROVOD_AVAILABLE: + import horovod.torch as hvd # noqa: E402 +else: + print('You requested to import Horovod which is missing or not supported for your OS.') + # Move project root to the front of the search path, as some imports may have reordered things idx = sys.path.index(PATH_ROOT) diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index 0fc68a226eae61..1a38b12d37ba02 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -29,33 +29,25 @@ from pytorch_lightning.accelerators.horovod_accelerator import HorovodAccelerator from pytorch_lightning.core.step_result import EvalResult, Result, TrainResult from pytorch_lightning.metrics.classification.accuracy import Accuracy -from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVAILABLE +from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVAILABLE, HOROVOD_AVAILABLE, _module_available from tests.base import EvalModelTemplate +from tests.base.boring_model import BoringModel from tests.base.models import BasicGAN -try: +if HOROVOD_AVAILABLE: import horovod - from horovod.common.util import nccl_built -except ImportError: - HOROVOD_AVAILABLE = False -else: - HOROVOD_AVAILABLE = True - + import horovod.torch as hvd # This script will run the actual test model training in parallel TEST_SCRIPT = os.path.join(os.path.dirname(__file__), 'data', 'horovod', 'train_default_model.py') - -def _nccl_available(): - if not HOROVOD_AVAILABLE: - return False - - try: - return nccl_built() - except AttributeError: - # Horovod 0.19.1 nccl_built() does not yet work with Python 3.8: - # See: https://github.com/horovod/horovod/issues/1891 - return False +try: + from horovod.common.util import nccl_built + nccl_built() +except (ImportError, ModuleNotFoundError, AttributeError): + HOROVOD_NCCL_AVAILABLE = False +finally: + HOROVOD_NCCL_AVAILABLE = True def _run_horovod(trainer_options, on_gpu=False): @@ -114,7 +106,7 @@ def test_horovod_cpu_implicit(enable_pl_optimizer, tmpdir): @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") -@pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support") +@pytest.mark.skipif(not HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") def test_horovod_multi_gpu(tmpdir): """Test Horovod with multi-GPU support.""" @@ -134,7 +126,7 @@ def test_horovod_multi_gpu(tmpdir): @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") -@pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support") +@pytest.mark.skipif(not HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex") def test_horovod_apex(tmpdir): @@ -158,7 +150,7 @@ def test_horovod_apex(tmpdir): @pytest.mark.skip(reason="Skip till Horovod fixes integration with Native torch.cuda.amp") @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") -@pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support") +@pytest.mark.skipif(not HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="test requires torch.cuda.amp") def test_horovod_amp(tmpdir): @@ -181,7 +173,7 @@ def test_horovod_amp(tmpdir): @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") -@pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support") +@pytest.mark.skipif(not HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support") @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") def test_horovod_transfer_batch_to_gpu(tmpdir): class TestTrainingStepModel(EvalModelTemplate): @@ -263,10 +255,6 @@ def hvd_test_fn(): path_root = os.path.abspath(os.path.join(path_here, '..', '..')) sys.path.insert(0, os.path.abspath(path_root)) - import horovod.torch as hvd - - from tests.base.boring_model import BoringModel - class TestModel(BoringModel): def training_step(self, batch, batch_idx): self.training_step_called = True @@ -318,8 +306,6 @@ def sk_metric(preds, target): target = torch.randint(high=2, size=(num_batches, batch_size)) def _compute_batch(): - import horovod.torch as hvd - trainer = Trainer( fast_dev_run=True, distributed_backend='horovod', diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 2369922c31a7cb..20d1c6fdd5cc02 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Test deprecated functionality which will be removed in vX.Y.Z""" import sys from argparse import ArgumentParser @@ -6,8 +19,8 @@ import pytest import torch -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.metrics.functional.classification import auc from pytorch_lightning.profiler.profilers import PassThroughProfiler, SimpleProfiler from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -21,7 +34,21 @@ def test_tbd_remove_in_v1_3_0(tmpdir): # Deprecate prefix with pytest.deprecated_call(match='will be removed in v1.3'): - callback = ModelCheckpoint(prefix='temp') + ModelCheckpoint(prefix='temp') + + # Deprecate auto mode + with pytest.deprecated_call(match='will be removed in v1.3'): + ModelCheckpoint(mode='auto') + + with pytest.deprecated_call(match='will be removed in v1.3'): + EarlyStopping(mode='auto') + + with pytest.deprecated_call(match="The setter for self.hparams in LightningModule is deprecated"): + class DeprecatedHparamsModel(LightningModule): + def __init__(self, hparams): + super().__init__() + self.hparams = hparams + DeprecatedHparamsModel({}) def test_tbd_remove_in_v1_2_0(): @@ -110,6 +137,6 @@ def test_end(self, outputs): return {'test_loss': torch.tensor(0.7)} -def test_auc_reorder_remove_in_v1_1_0(): +def test_reorder_remove_in_v1_1(): with pytest.deprecated_call(match='The `reorder` parameter to `auc` has been deprecated'): _ = auc(torch.tensor([0, 1, 2, 3]), torch.tensor([0, 1, 2, 2]), reorder=True) diff --git a/tests/trainer/logging/test_logger_connector.py b/tests/trainer/logging/test_logger_connector.py index 75a48b9a92d5b5..56e5765c7f4b86 100644 --- a/tests/trainer/logging/test_logger_connector.py +++ b/tests/trainer/logging/test_logger_connector.py @@ -14,54 +14,51 @@ """ Tests to ensure that the training loop works with a dict (1.0) """ -import os from copy import deepcopy -from unittest import mock import pytest import torch +from torch.utils.data import DataLoader from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.step_result import Result from pytorch_lightning.trainer import Trainer -from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator -from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base.boring_model import BoringModel, RandomDataset -class Helper: - def decorator_with_arguments(fx_name='', hook_fx_name=None): - def decorator(func): - def wrapper(self, *args, **kwargs): - # Set information - self._current_fx_name = fx_name - self._current_hook_fx_name = hook_fx_name - self._results = Result() +def decorator_with_arguments(fx_name='', hook_fx_name=None): + def decorator(func): + def wrapper(self, *args, **kwargs): + # Set information + self._current_fx_name = fx_name + self._current_hook_fx_name = hook_fx_name + self._results = Result() - result = func(self, *args, **kwargs) + result = func(self, *args, **kwargs) - # cache metrics - self.trainer.logger_connector.cache_logged_metrics() - return result - return wrapper + # cache metrics + self.trainer.logger_connector.cache_logged_metrics() + return result - return decorator + return wrapper + return decorator -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) -def test__logger_connector__epoch_result_store__train(tmpdir): + +def test__logger_connector__epoch_result_store__train(tmpdir, monkeypatch): """ Tests that LoggerConnector will properly capture logged information and reduce them """ + monkeypatch.setenv("PL_DEV_DEBUG", "1") class TestModel(BoringModel): train_losses = [] - @Helper.decorator_with_arguments(fx_name="training_step") + @decorator_with_arguments(fx_name="training_step") def training_step(self, batch, batch_idx): output = self.layer(batch) loss = self.loss(batch, output) @@ -91,18 +88,10 @@ def training_step_end(self, *_): train_results = model.train_results - assert len(train_results(fx_name="training_step", dl_idx="0", opt_idx="0")) == 2 - generated = train_results(fx_name="training_step", - dl_idx="0", - opt_idx="0", - batch_idx="0", - split_idx="0")["train_loss"] + assert len(train_results(fx_name="training_step", dl_idx=0, opt_idx=0)) == 2 + generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, batch_idx=0, split_idx=0)["train_loss"] assert generated == model.train_losses[0] - generated = train_results(fx_name="training_step", - dl_idx="0", - opt_idx="0", - batch_idx="1", - split_idx="0")["train_loss"] + generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, batch_idx=1, split_idx=0)["train_loss"] assert generated == model.train_losses[1] assert train_results.has_reduced is not True @@ -111,7 +100,7 @@ def training_step_end(self, *_): assert train_results.has_reduced is True - generated = train_results(fx_name="training_step", dl_idx="0", opt_idx="0", reduced=True)['train_loss_epoch'].item() + generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, reduced=True)['train_loss_epoch'].item() excepted = torch.stack(model.train_losses).mean().item() assert generated == excepted @@ -144,7 +133,7 @@ def __init__(self): self.test_hidden = None self.layer = torch.nn.Linear(2, 2) - @Helper.decorator_with_arguments(fx_name="training_step") + @decorator_with_arguments(fx_name="training_step") def training_step(self, batch, batch_idx, hiddens): self.test_hidden = torch.rand(1) @@ -155,8 +144,7 @@ def training_step(self, batch, batch_idx, hiddens): assert y_tensor.shape[1] == truncated_bptt_steps, "tbptt split list failed" pred = self(x_tensor.view(batch_size, truncated_bptt_steps)) - loss = torch.nn.functional.mse_loss( - pred, y_tensor.view(batch_size, truncated_bptt_steps)) + loss = torch.nn.functional.mse_loss(pred, y_tensor.view(batch_size, truncated_bptt_steps)) self.train_losses.append(loss) @@ -195,7 +183,7 @@ def training_step_end(self, *_): train_results = model.train_results - generated = train_results(fx_name="training_step", dl_idx="0", opt_idx="0", batch_idx="0") + generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, batch_idx=0) assert len(generated) == len(model.train_losses) # assert reduction didn't happen yet @@ -207,31 +195,28 @@ def training_step_end(self, *_): # assert reduction did happen assert train_results.has_reduced is True - generated = train_results(fx_name="training_step", dl_idx="0", opt_idx="0", reduced=True)['a_epoch'].item() + generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, reduced=True)['a_epoch'].item() assert generated == torch.stack(model.train_losses).mean().item() -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @pytest.mark.parametrize('num_dataloaders', [1, 2]) -def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, num_dataloaders): +def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, monkeypatch, num_dataloaders): """ Tests that LoggerConnector will properly capture logged information in multi_dataloaders scenario """ + monkeypatch.setenv("PL_DEV_DEBUG", "1") class TestModel(BoringModel): test_losses = {} - @Helper.decorator_with_arguments(fx_name="test_step") + @decorator_with_arguments(fx_name="test_step") def test_step(self, batch, batch_idx, dl_idx=0): output = self.layer(batch) loss = self.loss(batch, output) - primary_key = str(dl_idx) - if primary_key not in self.test_losses: - self.test_losses[primary_key] = [] - - self.test_losses[primary_key].append(loss) + self.test_losses.setdefault(dl_idx, []) + self.test_losses[dl_idx].append(loss) self.log("test_loss", loss, on_step=True, on_epoch=True) return {"test_loss": loss} @@ -270,14 +255,14 @@ def test_dataloader(self): assert len(generated) == num_dataloaders for dl_idx in range(num_dataloaders): - generated = len(test_results(fx_name="test_step", dl_idx=str(dl_idx))) + generated = len(test_results(fx_name="test_step", dl_idx=dl_idx)) assert generated == limit_test_batches test_results = model.reduce_results for dl_idx in range(num_dataloaders): - expected = torch.stack(model.test_losses[str(dl_idx)]).mean() - generated = test_results(fx_name="test_step", dl_idx=str(dl_idx), reduced=True)["test_loss_epoch"] + expected = torch.stack(model.test_losses[dl_idx]).mean() + generated = test_results(fx_name="test_step", dl_idx=dl_idx, reduced=True)["test_loss_epoch"] assert abs(expected.item() - generated.item()) < 1e-6 @@ -294,7 +279,8 @@ def test_call_back_validator(tmpdir): 'on_epoch_start', 'on_fit_end', 'on_fit_start', - 'on_init_end', 'on_init_start', + 'on_init_end', + 'on_init_start', 'on_keyboard_interrupt', 'on_load_checkpoint', 'on_pretrain_routine_end', @@ -343,24 +329,22 @@ def test_call_back_validator(tmpdir): "teardown", ] - assert funcs_name == callbacks_func, """Detected new callback function. + assert ( + funcs_name == callbacks_func + ), """Detected new callback function. Need to add its logging permission to CallbackHookNameValidator and update this test""" validator = CallbackHookNameValidator() for func_name in funcs_name: - # This summurize where and what is currently possible to log using `self.log` function. + # This summarizes where and what is currently possible to log using `self.log` is_stage = "train" in func_name or "test" in func_name or "validation" in func_name is_start = "start" in func_name or "batch" in func_name on_step = is_stage and is_start on_epoch = True # creating allowed condition allowed = ( - is_stage - or "batch" in func_name - or "epoch" in func_name - or "grad" in func_name - or "backward" in func_name + is_stage or "batch" in func_name or "epoch" in func_name or "grad" in func_name or "backward" in func_name ) allowed = ( allowed @@ -368,22 +352,76 @@ def test_call_back_validator(tmpdir): and func_name not in ["on_train_end", "on_test_end", "on_validation_end"] ) if allowed: - validator.check_logging_in_callbacks(current_hook_fx_name=func_name, - on_step=on_step, - on_epoch=on_epoch) + validator.check_logging_in_callbacks(current_hook_fx_name=func_name, on_step=on_step, on_epoch=on_epoch) if not is_start and is_stage: with pytest.raises(MisconfigurationException, match="function supports only"): - validator.check_logging_in_callbacks(current_hook_fx_name=func_name, - on_step=True, - on_epoch=on_epoch) + validator.check_logging_in_callbacks( + current_hook_fx_name=func_name, on_step=True, on_epoch=on_epoch + ) else: assert func_name in not_supported with pytest.raises(MisconfigurationException, match="function doesn't support"): - validator.check_logging_in_callbacks(current_hook_fx_name=func_name, - on_step=on_step, - on_epoch=on_epoch) - - result = validator.check_logging_in_callbacks(current_hook_fx_name=None, - on_step=None, - on_epoch=None) - assert result is None + validator.check_logging_in_callbacks(current_hook_fx_name=func_name, on_step=on_step, on_epoch=on_epoch) + + # should not fail + validator.check_logging_in_callbacks(current_hook_fx_name=None, on_step=None, on_epoch=None) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires two GPUs") +def test_epoch_results_cache_dp(tmpdir): + + root_device = torch.device("cuda", 0) + + class TestModel(BoringModel): + + def training_step(self, *args, **kwargs): + result = super().training_step(*args, **kwargs) + self.log("train_loss_epoch", result["loss"], on_step=False, on_epoch=True) + return result + + def training_step_end(self, training_step_outputs): # required for dp + loss = training_step_outputs["loss"].mean() + return loss + + def training_epoch_end(self, outputs): + assert all(out["loss"].device == root_device for out in outputs) + assert self.trainer.callback_metrics["train_loss_epoch"].device == root_device + + def validation_step(self, *args, **kwargs): + val_loss = torch.rand(1, device=torch.device("cuda", 1)) + self.log("val_loss_epoch", val_loss, on_step=False, on_epoch=True) + return val_loss + + def validation_epoch_end(self, outputs): + assert all(loss.device == root_device for loss in outputs) + assert self.trainer.callback_metrics["val_loss_epoch"].device == root_device + + def test_step(self, *args, **kwargs): + test_loss = torch.rand(1, device=torch.device("cuda", 1)) + self.log("test_loss_epoch", test_loss, on_step=False, on_epoch=True) + return test_loss + + def test_epoch_end(self, outputs): + assert all(loss.device == root_device for loss in outputs) + assert self.trainer.callback_metrics["test_loss_epoch"].device == root_device + + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64), batch_size=4) + + def val_dataloader(self): + return DataLoader(RandomDataset(32, 64), batch_size=4) + + def test_dataloader(self): + return DataLoader(RandomDataset(32, 64), batch_size=4) + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + accelerator="dp", + gpus=2, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + ) + trainer.fit(model) + trainer.test(model, ckpt_path=None) diff --git a/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py index 022796c275d368..1a928913228f22 100644 --- a/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py @@ -472,8 +472,6 @@ def make_logging(self, pl_module, func_name, "forked": False, "func_name": func_name} - """ - def on_validation_start(self, trainer, pl_module): self.make_logging(pl_module, 'on_validation_start', 1, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices) @@ -486,6 +484,7 @@ def on_validation_epoch_start(self, trainer, pl_module): self.make_logging(pl_module, 'on_validation_epoch_start', 3, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices) + """ def on_batch_start(self, trainer, pl_module): self.make_logging(pl_module, 'on_batch_start', 4, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices) @@ -493,6 +492,7 @@ def on_batch_start(self, trainer, pl_module): def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): self.make_logging(pl_module, 'on_validation_batch_start', 5, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices) + """ def on_batch_end(self, trainer, pl_module): self.make_logging(pl_module, 'on_batch_end', 6, on_steps=self.choices, @@ -510,8 +510,6 @@ def on_epoch_end(self, trainer, pl_module): self.make_logging(pl_module, 'on_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices) - """ - def on_validation_epoch_end(self, trainer, pl_module): self.make_logging(pl_module, 'on_validation_epoch_end', 9, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices) @@ -541,16 +539,14 @@ def validation_step(self, batch, batch_idx): trainer.fit(model) trainer.test() - """ assert test_callback.funcs_called_count["on_epoch_start"] == 1 - assert test_callback.funcs_called_count["on_batch_start"] == 1 + # assert test_callback.funcs_called_count["on_batch_start"] == 1 assert test_callback.funcs_called_count["on_batch_end"] == 1 assert test_callback.funcs_called_count["on_validation_start"] == 1 assert test_callback.funcs_called_count["on_validation_epoch_start"] == 1 - assert test_callback.funcs_called_count["on_validation_batch_start"] == 4 + # assert test_callback.funcs_called_count["on_validation_batch_start"] == 4 assert test_callback.funcs_called_count["on_validation_batch_end"] == 4 assert test_callback.funcs_called_count["on_epoch_end"] == 1 - """ assert test_callback.funcs_called_count["on_validation_epoch_end"] == 1 # Make sure the func_name exists within callback_metrics. If not, we missed some @@ -662,7 +658,6 @@ def make_logging(self, pl_module, func_name, "forked": False, "func_name": func_name} - """ def on_test_start(self, trainer, pl_module): self.make_logging(pl_module, 'on_test_start', 1, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices) @@ -675,11 +670,8 @@ def on_test_epoch_start(self, trainer, pl_module): self.make_logging(pl_module, 'on_test_epoch_start', 3, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices) - def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): - self.make_logging(pl_module, 'on_test_batch_start', 4, on_steps=self.choices, - on_epochs=self.choices, prob_bars=self.choices) - def on_test_step_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - self.make_logging(pl_module, 'on_test_step_end', 5, on_steps=self.choices, + def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + self.make_logging(pl_module, 'on_test_batch_end', 5, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices) # used to make sure aggregation works fine. @@ -690,7 +682,6 @@ def on_test_step_end(self, trainer, pl_module, outputs, batch, batch_idx, datalo def on_epoch_end(self, trainer, pl_module): self.make_logging(pl_module, 'on_epoch_end', 6, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices) - """ def on_test_epoch_end(self, trainer, pl_module): self.make_logging(pl_module, 'on_test_epoch_end', 7, on_steps=[False], @@ -728,13 +719,11 @@ def test_dataloader(self): ) trainer.fit(model) trainer.test() - """ + assert test_callback.funcs_called_count["on_test_start"] == 1 assert test_callback.funcs_called_count["on_epoch_start"] == 2 assert test_callback.funcs_called_count["on_test_epoch_start"] == 1 - assert test_callback.funcs_called_count["on_test_batch_start"] == 4 - assert test_callback.funcs_called_count["on_test_step_end"] == 4 - """ + assert test_callback.funcs_called_count["on_test_batch_end"] == 4 assert test_callback.funcs_called_count["on_test_epoch_end"] == 1 # Make sure the func_name exists within callback_metrics. If not, we missed some diff --git a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py index 9be44c68fa812b..c148748888af48 100644 --- a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py @@ -558,7 +558,7 @@ def make_logging(self, pl_module: pl.LightningModule, func_name, func_idx, "prog_bar": prog_bar, "forked": False, "func_name": func_name} - """ + def on_train_start(self, trainer, pl_module): self.make_logging(pl_module, 'on_train_start', 1, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices) @@ -571,15 +571,6 @@ def on_train_epoch_start(self, trainer, pl_module): self.make_logging(pl_module, 'on_train_epoch_start', 3, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices) - def on_batch_start(self, trainer, pl_module): - self.make_logging(pl_module, 'on_batch_start', 4, on_steps=self.choices, - on_epochs=self.choices, prob_bars=self.choices) - - def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): - self.make_logging(pl_module, 'on_train_batch_start', 5, on_steps=self.choices, - on_epochs=self.choices, prob_bars=self.choices) - - def on_batch_end(self, trainer, pl_module): self.make_logging(pl_module, 'on_batch_end', 6, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices) @@ -592,7 +583,6 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data # with func = np.mean if on_epoch else func = np.max self.count += 1 - """ def on_epoch_end(self, trainer, pl_module): self.make_logging(pl_module, 'on_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices) @@ -629,17 +619,12 @@ def training_step(self, batch, batch_idx): ) trainer.fit(model) - """ assert test_callback.funcs_called_count["on_train_start"] == 1 assert test_callback.funcs_called_count["on_epoch_start"] == 2 assert test_callback.funcs_called_count["on_train_epoch_start"] == 2 - assert test_callback.funcs_called_count["on_batch_start"] == 4 - assert test_callback.funcs_called_count["on_train_batch_start"] == 4 assert test_callback.funcs_called_count["on_batch_end"] == 4 assert test_callback.funcs_called_count["on_epoch_end"] == 2 assert test_callback.funcs_called_count["on_train_batch_end"] == 4 - - """ assert test_callback.funcs_called_count["on_epoch_end"] == 2 assert test_callback.funcs_called_count["on_train_epoch_end"] == 2 diff --git a/tests/utilities/distributed.py b/tests/utilities/distributed.py index f6b9a686b21bf1..80c0246ce6c577 100644 --- a/tests/utilities/distributed.py +++ b/tests/utilities/distributed.py @@ -29,11 +29,10 @@ def call_training_script(module_file, cli_args, method, tmpdir, timeout=60): # need to set the PYTHONPATH in case pytorch_lightning was not installed into the environment env = os.environ.copy() - env['PYTHONPATH'] = f'{pytorch_lightning.__file__}:' + env.get('PYTHONPATH', '') + env['PYTHONPATH'] = env.get('PYTHONPATH', '') + f'{pytorch_lightning.__file__}:' # for running in ddp mode, we need to lauch it's own process or pytest will get stuck p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env) - try: std, err = p.communicate(timeout=timeout) err = str(err.decode("utf-8")) @@ -42,5 +41,4 @@ def call_training_script(module_file, cli_args, method, tmpdir, timeout=60): except TimeoutExpired: p.kill() std, err = p.communicate() - return std, err