LightningCLI support for external accelerators #55

Open
ankitgola005 opened this issue Jul 19, 2023 · 3 comments
Labels: enhancement (New feature or request), help wanted (Extra attention is needed), won't fix

Comments

@ankitgola005
Contributor

🚀 Feature

LightningCLI support for external accelerators

Motivation

LightningCLI helps avoid boilerplate code when building command-line tools. The current implementation does not appear to support external accelerators; it only accepts the accelerators bundled with the lightning source.
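
For contrast, the built-in accelerators already work through the CLI. A minimal sketch (the file name cli_demo.py and the cpu choice are just for illustration), run with a subcommand such as fit:

from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import BoringModel

if __name__ == "__main__":
    # Built-in accelerators resolve by name, e.g.:
    #   python cli_demo.py fit --trainer.accelerator=cpu --trainer.fast_dev_run=true
    cli = LightningCLI(BoringModel, trainer_defaults={"accelerator": "cpu"})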

Pitch

Extend support for external accelerators in LightningCLI.
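
One shape this could take (purely a sketch; the class-path style below is an assumption about how the resolution might be exposed, not a verified feature of the versions listed under Env) is letting the CLI resolve an external accelerator by its import path, so the Trainer ends up with the lightning_habana class rather than a lookup limited to the lightning source:

from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import BoringModel

if __name__ == "__main__":
    # Desired behaviour (currently failing, see the repro below): an external
    # accelerator selected via trainer_defaults or the command line ends up as
    # the Trainer's accelerator. A class-path form such as
    #   python temp.py fit --trainer.accelerator=lightning_habana.HPUAccelerator
    # is one possible interface for this (assumption, not a verified feature).
    cli = LightningCLI(BoringModel)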

Alternatives

Additional context

First mentioned in #54

To reproduce:

from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning_habana import HPUAccelerator

class BMAccelerator(BoringModel):
    def on_fit_start(self):
        # Passes only if the Trainer picked up the external HPUAccelerator
        assert isinstance(self.trainer.accelerator, HPUAccelerator), self.trainer.accelerator

model = BMAccelerator
accelerator = HPUAccelerator()

if __name__ == "__main__":
    # Run one method at a time, e.g. python temp.py fit --trainer.fast_dev_run=true

    # Method 1: passing an accelerator instance from an external library
    cli = LightningCLI(model, trainer_defaults={'accelerator': accelerator})

    # Method 2: passing the accelerator as a string (comment out Method 1 first)
    # cli = LightningCLI(model, trainer_defaults={'accelerator': 'hpu'})

Gives the following tracebacks:

Method 1, passing supported accelerator class instance from an external library

Traceback (most recent call last):
  File "temp.py", line 34, in <module>
    cli = LightningCLI(model, trainer_defaults={'accelerator': HPUAccelerator()})
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/cli.py", line 353, in __init__
    self._run_subcommand(self.subcommand)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/cli.py", line 642, in _run_subcommand
    fn(**fn_kwargs)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 520, in fit
    call._call_and_handle_interrupt(
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 559, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 893, in _run
    self.strategy.setup_environment()
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/strategies/strategy.py", line 127, in setup_environment
    self.accelerator.setup_device(self.root_device)
  File "/home/agola/lightning-habana-fork/src/lightning_habana/pytorch/accelerator.py", line 50, in setup_device
    raise MisconfigurationException(f"Device should be HPU, got {device} instead.")
lightning.fabric.utilities.exceptions.MisconfigurationException: Device should be HPU, got cpu instead.

Method 2, passing accelerator as string

Traceback (most recent call last):
  File "temp.py", line 33, in <module>
    cli = LightningCLI(model, trainer_defaults={'accelerator': "hpu"})
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/cli.py", line 353, in __init__
    self._run_subcommand(self.subcommand)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/cli.py", line 642, in _run_subcommand
    fn(**fn_kwargs)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 520, in fit
    call._call_and_handle_interrupt(
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 559, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 916, in _run
    call._call_lightning_module_hook(self, "on_fit_start")
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 142, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "temp.py", line 15, in on_fit_start
    assert isinstance(self.trainer.accelerator,
AssertionError: <lightning.pytorch.accelerators.hpu.HPUAccelerator object at 0x7f37f62917c0>

Env

lightning                     2.0.0
lightning-fabric              2.0.3
lightning-habana              1.0.0
lightning-utilities           0.9.0
pytorch-lightning             2.0.5
ankitgola005 added the enhancement (New feature or request) and help wanted (Extra attention is needed) labels on Jul 19, 2023
@github-actions

Hi! Thanks for your contribution, great first issue!

@jerome-habana
Collaborator

cc @Borda

@stale

stale bot commented Sep 17, 2023

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

The stale bot added the won't fix label on Sep 17, 2023