Re-enable inference mode support with HPU & update links (#15918)
* Re-enable inference mode support with HPU
* Remove unused imports
* Update documentation link and address review comment

Signed-off-by: Jerome <janand@habana.ai>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

(cherry picked from commit 6aaac8b)
jerome-habana authored and Borda committed Dec 7, 2022
1 parent a528d56 commit d90f624
Showing 5 changed files with 6 additions and 13 deletions.
1 change: 0 additions & 1 deletion docs/source-pytorch/accelerators/hpu_basic.rst
@@ -113,4 +113,3 @@ Known limitations
 -----------------
 
 * `Habana dataloader <https://docs.habana.ai/en/latest/PyTorch_User_Guide/PyTorch_User_Guide.html#habana-data-loader>`__ is not supported.
-* :func:`torch.inference_mode` is not supported
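With this limitation removed, evaluation on HPU can run under torch.inference_mode again. A minimal sketch, assuming a Gaudi device with the Habana PyTorch plugin installed; the model and datamodule are hypothetical placeholders:

    import pytorch_lightning as pl

    # inference_mode=True is already the Trainer default; it is spelled out
    # here only to show the flag this commit makes effective on HPU.
    trainer = pl.Trainer(accelerator="hpu", devices=1, inference_mode=True)
    # trainer.validate(model, datamodule=dm)  # placeholders, not defined here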
2 changes: 1 addition & 1 deletion docs/source-pytorch/accelerators/hpu_intermediate.rst
@@ -96,4 +96,4 @@ The below snippet shows how DeviceStatsMonitor can be enabled.
     device_stats = DeviceStatsMonitor()
     trainer = Trainer(accelerator="hpu", callbacks=[device_stats])
 
-For more details, please refer to `Memory Stats APIs <https://docs.habana.ai/en/v1.5.0/PyTorch/PyTorch_User_Guide/Python_Packages.html#memory-stats-apis>`__.
+For more details, please refer to `Memory Stats APIs <https://docs.habana.ai/en/latest/PyTorch/PyTorch_User_Guide/Python_Packages.html#memory-stats-apis>`__.
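For reference, the documented snippet omits its imports; a self-contained version might look like this (model and data omitted):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import DeviceStatsMonitor

    # DeviceStatsMonitor logs per-device memory statistics (the Memory Stats
    # APIs linked above, in the HPU case) to the attached logger.
    device_stats = DeviceStatsMonitor()
    trainer = Trainer(accelerator="hpu", devices=1, callbacks=[device_stats])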
1 change: 1 addition & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ## [1.8.4] - 2022-12-06
 
 - Direct support for compiled models ([#15922](https://github.com/Lightning-AI/lightning/pull/15922))
+- Fixed issue with unsupported torch.inference_mode() on hpu backends ([#15918](https://github.com/Lightning-AI/lightning/pull/15918))
 
 
 ## [1.8.3] - 2022-11-22
9 changes: 2 additions & 7 deletions src/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -80,12 +80,7 @@
 from pytorch_lightning.strategies.ddp_spawn import _DDP_FORK_ALIASES
 from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import (
-    _HOROVOD_AVAILABLE,
-    _HPU_AVAILABLE,
-    _IPU_AVAILABLE,
-    _TORCH_GREATER_EQUAL_1_11,
-)
+from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE, _IPU_AVAILABLE, _TORCH_GREATER_EQUAL_1_11
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
 
 log = logging.getLogger(__name__)
@@ -499,7 +494,7 @@ def _choose_auto_accelerator(self) -> str:
return "tpu"
if _IPU_AVAILABLE:
return "ipu"
if _HPU_AVAILABLE:
if HPUAccelerator.is_available():
return "hpu"
if MPSAccelerator.is_available():
return "mps"
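Replacing the import-time _HPU_AVAILABLE flag with HPUAccelerator.is_available() turns the check into a runtime probe on the accelerator class. A simplified, standalone sketch of the selection order in this hunk; the TPU/IPU flags are placeholders, and the real method also probes CUDA before falling back to CPU:

    from pytorch_lightning.accelerators import HPUAccelerator, MPSAccelerator

    # Placeholder flags standing in for Lightning's internal availability checks.
    _TPU_AVAILABLE = False
    _IPU_AVAILABLE = False

    def choose_auto_accelerator() -> str:
        """Sketch of the fallback order behind accelerator='auto'."""
        if _TPU_AVAILABLE:
            return "tpu"
        if _IPU_AVAILABLE:
            return "ipu"
        if HPUAccelerator.is_available():  # runtime probe, not an import-time flag
            return "hpu"
        if MPSAccelerator.is_available():
            return "mps"
        return "cpu"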
6 changes: 2 additions & 4 deletions src/pytorch_lightning/trainer/trainer.py
@@ -48,7 +48,7 @@
 from lightning_lite.utilities.data import _auto_add_worker_init_fn
 from lightning_lite.utilities.types import _PATH
 from lightning_lite.utilities.warnings import PossibleUserWarning
-from pytorch_lightning.accelerators import Accelerator, HPUAccelerator, TPUAccelerator
+from pytorch_lightning.accelerators import Accelerator, TPUAccelerator
 from pytorch_lightning.callbacks import Callback, Checkpoint, EarlyStopping, ProgressBarBase
 from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
 from pytorch_lightning.core.datamodule import LightningDataModule
@@ -2265,13 +2265,11 @@ def configure_optimizers(self):

 @contextmanager
 def _evaluation_context(accelerator: Accelerator, inference_mode: bool = True) -> Generator:
-    # inference mode is not supported with gloo backend (#9431),
-    # and HPU & TPU accelerators.
+    # inference mode is not supported with gloo backend (#9431) and TPU accelerators.
     context_manager_class = (
         torch.inference_mode
         if inference_mode
         and not (dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo")
-        and not isinstance(accelerator, HPUAccelerator)
         and not isinstance(accelerator, TPUAccelerator)
         else torch.no_grad
     )
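The net effect of this hunk: only the gloo backend (#9431) and TPU accelerators still force the torch.no_grad fallback, so HPU evaluation regains torch.inference_mode. A standalone sketch of the selection logic after the change:

    from contextlib import contextmanager
    from typing import Generator

    import torch
    import torch.distributed as dist
    from pytorch_lightning.accelerators import Accelerator, TPUAccelerator

    @contextmanager
    def evaluation_context(accelerator: Accelerator, inference_mode: bool = True) -> Generator:
        # Prefer torch.inference_mode unless the gloo backend or a TPU
        # accelerator is in use; both still require the no_grad fallback.
        use_inference = (
            inference_mode
            and not (dist.is_available() and dist.is_initialized() and dist.get_backend() == "gloo")
            and not isinstance(accelerator, TPUAccelerator)
        )
        context_manager = torch.inference_mode if use_inference else torch.no_grad
        with context_manager():
            yield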
