diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index 7439d6a4becba..5dc5ca769c0b3 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -43,6 +43,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed Registration for CloudComputes of Works in `L.app.structures` ([#15964](https://github.com/Lightning-AI/lightning/pull/15964)) +- Fixed `AutoScaler` raising an exception when non-default cloud compute is specified ([#15991](https://github.com/Lightning-AI/lightning/pull/15991)) + + ## [1.8.4] - 2022-12-08 ### Added diff --git a/src/lightning_app/components/auto_scaler.py b/src/lightning_app/components/auto_scaler.py index fc6a1a873769b..13948ba50af89 100644 --- a/src/lightning_app/components/auto_scaler.py +++ b/src/lightning_app/components/auto_scaler.py @@ -449,8 +449,15 @@ def workers(self) -> List[LightningWork]: def create_work(self) -> LightningWork: """Replicates a LightningWork instance with args and kwargs provided via ``__init__``.""" - # TODO: Remove `start_with_flow=False` for faster initialization on the cloud - self._work_kwargs.update(dict(start_with_flow=False)) + cloud_compute = self._work_kwargs.get("cloud_compute", None) + self._work_kwargs.update( + dict( + # TODO: Remove `start_with_flow=False` for faster initialization on the cloud + start_with_flow=False, + # don't try to create multiple works in a single machine + cloud_compute=cloud_compute.clone() if cloud_compute else None, + ) + ) return self._work_cls(*self._work_args, **self._work_kwargs) def add_work(self, work) -> str: diff --git a/tests/tests_app/components/test_auto_scaler.py b/tests/tests_app/components/test_auto_scaler.py index 436c3517d01ca..672b05bbc9a15 100644 --- a/tests/tests_app/components/test_auto_scaler.py +++ b/tests/tests_app/components/test_auto_scaler.py @@ -3,7 +3,7 @@ import pytest -from lightning_app import LightningWork +from lightning_app import CloudCompute, LightningWork from lightning_app.components import AutoScaler @@ -90,3 +90,11 @@ def test_scale(replicas, metrics, expected_replicas): ) assert auto_scaler.scale(replicas, metrics) == expected_replicas + + +def test_create_work_cloud_compute_cloned(): + """Test CloudCompute is cloned to avoid creating multiple works in a single machine.""" + cloud_compute = CloudCompute("gpu") + auto_scaler = AutoScaler(EmptyWork, cloud_compute=cloud_compute) + _ = auto_scaler.create_work() + assert auto_scaler._work_kwargs["cloud_compute"] is not cloud_compute