Skip to content

Commit

Permalink
Remove Trainer._device_type
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Feb 18, 2022
1 parent 2c75a7b commit e00a14b
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 147 deletions.
5 changes: 3 additions & 2 deletions pytorch_lightning/callbacks/xla_stats_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
import time

import pytorch_lightning as pl
from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info

Expand Down Expand Up @@ -72,7 +73,7 @@ def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule")
if not trainer.logger:
raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")

if trainer._device_type != _AcceleratorType.TPU:
if isinstance(trainer.accelerator, TPUAccelerator):
raise MisconfigurationException(
"You are using XLAStatsMonitor but are not running on TPU"
f" since `tpu_cores` attribute in Trainer is set to {trainer.tpu_cores}."
Expand Down
2 changes: 0 additions & 2 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
# pointer to the trainer object
self.trainer = None

self._device_type = None

# true if using amp
self.use_amp: bool = False

Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/loops/optimization/optimizer_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from torch import Tensor
from torch.optim import Optimizer

from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.optimization.closure import AbstractClosure, OutputResult
Expand All @@ -29,10 +30,9 @@
check_finite_loss,
)
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.utilities import _AcceleratorType, AMPType
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

Expand Down Expand Up @@ -369,7 +369,7 @@ def _optimizer_step(
optimizer,
opt_idx,
train_step_and_backward_closure,
on_tpu=(self.trainer._device_type == _AcceleratorType.TPU and _TPU_AVAILABLE),
on_tpu=isinstance(self.trainer.accelerator, TPUAccelerator),
using_native_amp=(self.trainer.amp_backend == AMPType.NATIVE),
using_lbfgs=is_lbfgs,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import _AcceleratorType, memory
from pytorch_lightning.utilities import memory
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.metrics import metrics_to_scalars
from pytorch_lightning.utilities.model_helpers import is_overridden
Expand Down Expand Up @@ -305,7 +306,7 @@ def gpus_metrics(self) -> Dict[str, float]:
.. deprecated:: v1.5
Will be removed in v1.7.
"""
if self.trainer._device_type == _AcceleratorType.GPU and self.log_gpu_memory:
if isinstance(self.trainer.accelerator, GPUAccelerator) and self.log_gpu_memory:
mem_map = memory.get_memory_profile(self.log_gpu_memory)
self._gpus_metrics.update(mem_map)
return self._gpus_metrics
Expand Down
23 changes: 8 additions & 15 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.accelerators import Accelerator, IPUAccelerator
from pytorch_lightning.accelerators import Accelerator, GPUAccelerator, IPUAccelerator, TPUAccelerator
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.core.datamodule import LightningDataModule
Expand Down Expand Up @@ -75,7 +75,6 @@
from pytorch_lightning.tuner.lr_finder import _LRFinder
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities import (
_AcceleratorType,
_IPU_AVAILABLE,
_TPU_AVAILABLE,
AMPType,
Expand Down Expand Up @@ -1716,33 +1715,31 @@ def __setup_profiler(self) -> None:
self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir)

def _log_device_info(self) -> None:
rank_zero_info(f"GPU available: {torch.cuda.is_available()}, used: {self._device_type == _AcceleratorType.GPU}")
rank_zero_info(
f"GPU available: {torch.cuda.is_available()}, used: {isinstance(self.accelerator, GPUAccelerator)}"
)

num_tpu_cores = (
self.tpu_cores if self.tpu_cores is not None and self._device_type == _AcceleratorType.TPU else 0
self.tpu_cores if self.tpu_cores is not None and isinstance(self.accelerator, TPUAccelerator) else 0
)
rank_zero_info(f"TPU available: {_TPU_AVAILABLE}, using: {num_tpu_cores} TPU cores")

num_ipus = self.ipus if self.ipus is not None else 0
rank_zero_info(f"IPU available: {_IPU_AVAILABLE}, using: {num_ipus} IPUs")

if torch.cuda.is_available() and self._device_type != _AcceleratorType.GPU:
if torch.cuda.is_available() and not isinstance(self.accelerator, GPUAccelerator):
rank_zero_warn(
"GPU available but not used. Set the gpus flag in your trainer `Trainer(gpus=1)` or script `--gpus=1`.",
category=PossibleUserWarning,
)

if _TPU_AVAILABLE and self._device_type != _AcceleratorType.TPU:
if _TPU_AVAILABLE and not isinstance(self.accelerator, TPUAccelerator):
rank_zero_warn(
"TPU available but not used. Set the `tpu_cores` flag in your trainer"
" `Trainer(tpu_cores=8)` or script `--tpu_cores=8`."
)

if (
_IPU_AVAILABLE
and self._device_type != _AcceleratorType.IPU
and not isinstance(self.accelerator, IPUAccelerator)
):
if _IPU_AVAILABLE and not isinstance(self.accelerator, IPUAccelerator):
rank_zero_warn(
"IPU available but not used. Set the `ipus` flag in your trainer"
" `Trainer(ipus=8)` or script `--ipus=8`."
Expand Down Expand Up @@ -1962,10 +1959,6 @@ def should_rank_save_checkpoint(self) -> bool:
isinstance(strategy, pl.strategies.TPUSpawnStrategy) and strategy.local_rank == 0 or strategy.is_global_zero
)

@property
def _device_type(self) -> _AcceleratorType:
return self._accelerator_connector.device_type

@property
def num_nodes(self) -> int:
return self._accelerator_connector.num_nodes
Expand Down
21 changes: 1 addition & 20 deletions tests/accelerators/test_accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
ParallelStrategy,
SingleDeviceStrategy,
)
from pytorch_lightning.utilities import _AcceleratorType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.runif import RunIf

Expand Down Expand Up @@ -447,10 +446,7 @@ def test_accelerator_choice_multi_node_gpu(

@mock.patch("torch.cuda.is_available", return_value=False)
def test_accelerator_cpu(_):

trainer = Trainer(accelerator="cpu")

assert trainer._device_type == "cpu"
assert isinstance(trainer.accelerator, CPUAccelerator)

with pytest.raises(MisconfigurationException, match="You requested gpu:"):
Expand All @@ -464,36 +460,25 @@ def test_accelerator_cpu(_):

@RunIf(min_gpus=1)
def test_accelerator_gpu():

trainer = Trainer(accelerator="gpu", gpus=1)

assert trainer._device_type == "gpu"
assert isinstance(trainer.accelerator, GPUAccelerator)

trainer = Trainer(accelerator="gpu")
assert isinstance(trainer.accelerator, GPUAccelerator)

trainer = Trainer(accelerator="auto", gpus=1)

assert trainer._device_type == "gpu"
assert isinstance(trainer.accelerator, GPUAccelerator)


@RunIf(min_gpus=1)
def test_accelerator_cpu_with_gpus_flag():

trainer = Trainer(accelerator="cpu", gpus=1)

assert trainer._device_type == "cpu"
assert isinstance(trainer.accelerator, CPUAccelerator)


@RunIf(min_gpus=2)
def test_accelerator_cpu_with_multiple_gpus():

trainer = Trainer(accelerator="cpu", gpus=2)

assert trainer._device_type == "cpu"
assert isinstance(trainer.accelerator, CPUAccelerator)


Expand Down Expand Up @@ -532,10 +517,8 @@ def test_accelerator_gpu_with_devices(devices, plugin):

@RunIf(min_gpus=1)
def test_accelerator_auto_with_devices_gpu():

trainer = Trainer(accelerator="auto", devices=1)

assert trainer._device_type == "gpu"
assert isinstance(trainer.accelerator, GPUAccelerator)
assert trainer.gpus == 1


Expand Down Expand Up @@ -662,10 +645,8 @@ def test_strategy_choice_gpu_plugin(tmpdir, plugin):
@RunIf(min_gpus=2)
@pytest.mark.parametrize("plugin", [DDPSpawnStrategy, DDPStrategy])
def test_device_type_when_training_plugin_gpu_passed(tmpdir, plugin):

trainer = Trainer(strategy=plugin(), gpus=2)
assert isinstance(trainer.strategy, plugin)
assert trainer._device_type == _AcceleratorType.GPU
assert isinstance(trainer.accelerator, GPUAccelerator)


Expand Down
16 changes: 2 additions & 14 deletions tests/accelerators/test_ipu.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from pytorch_lightning.strategies.ipu import IPUStrategy
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities import _AcceleratorType, _IPU_AVAILABLE
from pytorch_lightning.utilities import _IPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
from tests.helpers.datamodules import ClassifDataModule
Expand Down Expand Up @@ -499,27 +499,19 @@ def test_precision_plugin(tmpdir):

@RunIf(ipu=True)
def test_accelerator_ipu():

trainer = Trainer(accelerator="ipu", ipus=1)

assert trainer._device_type == "ipu"
assert isinstance(trainer.accelerator, IPUAccelerator)

trainer = Trainer(accelerator="ipu")
assert isinstance(trainer.accelerator, IPUAccelerator)

trainer = Trainer(accelerator="auto", ipus=8)

assert trainer._device_type == "ipu"
assert isinstance(trainer.accelerator, IPUAccelerator)


@RunIf(ipu=True)
def test_accelerator_cpu_with_ipus_flag():

trainer = Trainer(accelerator="cpu", ipus=1)

assert trainer._device_type == "cpu"
assert isinstance(trainer.accelerator, CPUAccelerator)


Expand All @@ -535,10 +527,8 @@ def test_accelerator_ipu_with_devices():

@RunIf(ipu=True)
def test_accelerator_auto_with_devices_ipu():

trainer = Trainer(accelerator="auto", devices=8)

assert trainer._device_type == "ipu"
assert isinstance(trainer.accelerator, IPUAccelerator)
assert trainer.ipus == 8


Expand Down Expand Up @@ -568,10 +558,8 @@ def test_strategy_choice_ipu_plugin(tmpdir):

@RunIf(ipu=True)
def test_device_type_when_training_plugin_ipu_passed(tmpdir):

trainer = Trainer(strategy=IPUStrategy(), ipus=8)
assert isinstance(trainer.strategy, IPUStrategy)
assert trainer._device_type == _AcceleratorType.IPU
assert isinstance(trainer.accelerator, IPUAccelerator)


Expand Down
44 changes: 11 additions & 33 deletions tests/accelerators/test_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.accelerators.tpu import TPUAccelerator
from pytorch_lightning.plugins import PrecisionPlugin, TPUPrecisionPlugin, XLACheckpointIO
from pytorch_lightning.strategies import DDPStrategy, TPUSpawnStrategy
from pytorch_lightning.strategies import DDPStrategy, SingleTPUStrategy, TPUSpawnStrategy
from pytorch_lightning.utilities import find_shared_parameters
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel, RandomDataset
Expand Down Expand Up @@ -86,48 +86,33 @@ def test_accelerator_tpu():
assert TPUAccelerator.is_available()

trainer = Trainer(accelerator="tpu", tpu_cores=8)

assert trainer._device_type == "tpu"
assert isinstance(trainer.accelerator, TPUAccelerator)
assert isinstance(trainer.strategy, TPUSpawnStrategy)

trainer = Trainer(accelerator="tpu")
assert isinstance(trainer.accelerator, TPUAccelerator)
assert isinstance(trainer.strategy, SingleTPUStrategy)


@RunIf(tpu=True)
def test_accelerator_cpu_with_tpu_cores_flag():

trainer = Trainer(accelerator="cpu", tpu_cores=8)

assert trainer._device_type == "cpu"
assert isinstance(trainer.accelerator, CPUAccelerator)


@RunIf(tpu=True)
def test_accelerator_tpu_with_auto():

trainer = Trainer(accelerator="auto", tpu_cores=8)

assert trainer._device_type == "tpu"
assert isinstance(trainer.accelerator, TPUAccelerator)


@RunIf(tpu=True)
def test_accelerator_tpu_with_devices():

trainer = Trainer(accelerator="tpu", devices=8)

assert trainer.tpu_cores == 8
assert isinstance(trainer.strategy, TPUSpawnStrategy)
assert isinstance(trainer.accelerator, TPUAccelerator)


@RunIf(tpu=True)
def test_accelerator_auto_with_devices_tpu():
assert TPUAccelerator.is_available()

trainer = Trainer(accelerator="auto", devices=8)
assert isinstance(trainer.accelerator, TPUAccelerator)
assert isinstance(trainer.strategy, TPUSpawnStrategy)
assert trainer.tpu_cores == 8

assert trainer._device_type == "tpu"
trainer = Trainer(accelerator="auto", devices="auto")
assert isinstance(trainer.accelerator, TPUAccelerator)
assert isinstance(trainer.strategy, TPUSpawnStrategy)
assert trainer.devices == 8
assert trainer.tpu_cores == 8


Expand Down Expand Up @@ -328,10 +313,3 @@ def test_mp_device_dataloader_attribute(_):
dataset = RandomDataset(32, 64)
dataloader = TPUSpawnStrategy().process_dataloader(DataLoader(dataset))
assert dataloader.dataset == dataset


@RunIf(tpu=True)
def test_devices_auto_choice_tpu():
trainer = Trainer(accelerator="auto", devices="auto")
assert trainer.devices == 8
assert trainer.tpu_cores == 8
4 changes: 1 addition & 3 deletions tests/models/test_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.strategies import TPUSpawnStrategy
from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync
from pytorch_lightning.utilities import _AcceleratorType, _TPU_AVAILABLE
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
Expand Down Expand Up @@ -469,8 +469,6 @@ def teardown(self, stage):
@RunIf(tpu=True)
@pl_multi_process_test
def test_device_type_when_training_plugin_tpu_passed(tmpdir):

trainer = Trainer(strategy=TPUSpawnStrategy(), tpu_cores=8)
assert isinstance(trainer.strategy, TPUSpawnStrategy)
assert trainer._device_type == _AcceleratorType.TPU
assert isinstance(trainer.accelerator, TPUAccelerator)

0 comments on commit e00a14b

Please sign in to comment.