Moving HPU availability check to a utility function
Signed-off-by: jyothi kumar sambolu <jsambolu@habana.ai>
jyothisambolu committed Apr 15, 2024
1 parent f4a22b3 commit 519a480
Showing 3 changed files with 12 additions and 11 deletions.
19 changes: 10 additions & 9 deletions src/lightning/pytorch/trainer/connectors/accelerator_connector.py
@@ -447,25 +447,26 @@ def _choose_strategy(self) -> Union[Strategy, str]:
             return "ddp_fork"
         return "ddp"
 
+    def _is_hpu_accelerator(self) -> bool:
+        if _habana_available_and_importable():
+            from lightning_habana import HPUAccelerator
+
+            if isinstance(self._accelerator_flag, HPUAccelerator):
+                return True
+        return False
+
     def _check_strategy_and_fallback(self) -> None:
         """Checks edge cases when the strategy selection was a string input, and we need to fall back to a different
         choice depending on other parameters or the environment."""
         # current fallback and check logic only apply to user pass in str config and object config
         # TODO this logic should apply to both str and object config
         strategy_flag = "" if isinstance(self._strategy_flag, Strategy) else self._strategy_flag
 
-        if _habana_available_and_importable():
-            from lightning_habana import HPUAccelerator
-            if isinstance(self._accelerator_flag, HPUAccelerator):
-                if strategy_flag:
-                    self._strategy_flag = strategy_flag
-                return
-
         if (
             strategy_flag in FSDPStrategy.get_registered_strategies() or isinstance(self._strategy_flag, FSDPStrategy)
-        ) and self._accelerator_flag not in ("cuda", "gpu"):
+        ) and self._accelerator_flag not in ("cuda", "gpu") and not self._is_hpu_accelerator():
             raise MisconfigurationException(
-                f"You selected strategy to be `{FSDPStrategy.strategy_name}`, but GPU accelerator is not used."
+                f"You selected strategy to be `{FSDPStrategy.strategy_name}`, but GPU or HPU accelerator is not used."
             )
         if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods():
             raise ValueError(
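The new helper keeps the guarded-optional-import pattern the inline check already used: probe for the optional package first, and only import from it once the probe succeeds. A minimal self-contained sketch of that pattern (the find_spec-based probe and the free-standing function below are illustrative assumptions, not Lightning's actual _habana_available_and_importable implementation):

import importlib.util
from typing import Any


def habana_available_and_importable() -> bool:
    # Probe for the optional lightning-habana package without importing it;
    # find_spec returns None when the package is not installed.
    return importlib.util.find_spec("lightning_habana") is not None


def is_hpu_accelerator(accelerator_flag: Any) -> bool:
    # Import HPUAccelerator only after the probe succeeds, so environments
    # without lightning-habana never hit an ImportError on this path.
    if habana_available_and_importable():
        from lightning_habana import HPUAccelerator

        return isinstance(accelerator_flag, HPUAccelerator)
    return False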
2 changes: 1 addition & 1 deletion tests/tests_pytorch/strategies/test_fsdp.py
@@ -218,7 +218,7 @@ def test_invalid_on_cpu(tmp_path, cuda_count_0):
     """Test to ensure that we raise Misconfiguration for FSDP on CPU."""
     with pytest.raises(
         MisconfigurationException,
-        match=f"You selected strategy to be `{FSDPStrategy.strategy_name}`, but GPU accelerator is not used.",
+        match=f"You selected strategy to be `{FSDPStrategy.strategy_name}`, but GPU or HPU accelerator is not used.",
     ):
         trainer = Trainer(accelerator="cpu", default_root_dir=tmp_path, fast_dev_run=True, strategy="fsdp")
         assert isinstance(trainer.strategy, FSDPStrategy)
2 changes: 1 addition & 1 deletion tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
@@ -566,7 +566,7 @@ def test_strategy_choice_ddp_cpu_slurm(cuda_count_0, strategy):
 def test_check_fsdp_strategy_and_fallback():
     with pytest.raises(
         MisconfigurationException,
-        match=f"You selected strategy to be `{FSDPStrategy.strategy_name}`, but GPU accelerator is not used.",
+        match=f"You selected strategy to be `{FSDPStrategy.strategy_name}`, but GPU or HPU accelerator is not used.",
     ):
         Trainer(accelerator="cpu", strategy="fsdp")
 
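Taken together, the change means the early strategy check now lets an HPU accelerator select an FSDP-family strategy, while CPU is still rejected. A short usage sketch of the behavior the updated tests pin down (assumes a standard lightning install; the HPU lines are commented out because they need the optional lightning-habana package and Gaudi hardware, and later setup steps may impose further HPU-specific requirements):

from lightning.pytorch import Trainer
from lightning.pytorch.utilities.exceptions import MisconfigurationException

# CPU still fails the early check:
try:
    Trainer(accelerator="cpu", strategy="fsdp")
except MisconfigurationException as err:
    print(err)  # "... but GPU or HPU accelerator is not used."

# With lightning-habana installed, an HPU accelerator instance is no longer
# rejected by this early check:
# from lightning_habana import HPUAccelerator
# Trainer(accelerator=HPUAccelerator(), strategy="fsdp")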
