Skip to content

Commit

Permalink
[App] Fix AutoScaler trying to replicate multiple works in a single…
Browse files Browse the repository at this point in the history
… machine (#15991)

* don't try to replicate new works in the existing machine

* update changelog

* Update comment

* Update src/lightning_app/components/auto_scaler.py

* add test

(cherry picked from commit c1d0156)
  • Loading branch information
akihironitta authored and Borda committed Dec 14, 2022
1 parent f01e4fc commit ed223e8
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
3 changes: 3 additions & 0 deletions src/lightning_app/CHANGELOG.md
Expand Up @@ -16,6 +16,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed


- Fixed `AutoScaler` raising an exception when non-default cloud compute is specified ([#15991](https://github.com/Lightning-AI/lightning/pull/15991))


## [1.8.4] - 2022-12-08

### Added
Expand Down
11 changes: 9 additions & 2 deletions src/lightning_app/components/auto_scaler.py
Expand Up @@ -449,8 +449,15 @@ def workers(self) -> List[LightningWork]:

def create_work(self) -> LightningWork:
    """Replicates a LightningWork instance with args and kwargs provided via ``__init__``.

    The stored keyword arguments are updated in place before the work is built:
    ``start_with_flow`` is forced to ``False`` and ``cloud_compute`` is replaced by a
    clone of the stored object (or ``None`` when none was supplied), so every call
    hands a fresh compute object to the new work.

    Returns:
        A new instance of ``self._work_cls`` built from ``self._work_args`` and the
        updated ``self._work_kwargs``.
    """
    cloud_compute = self._work_kwargs.get("cloud_compute")
    self._work_kwargs.update(
        # TODO: Remove `start_with_flow=False` for faster initialization on the cloud
        start_with_flow=False,
        # don't try to create multiple works in a single machine;
        # compare against None explicitly so a supplied object is always cloned
        cloud_compute=cloud_compute.clone() if cloud_compute is not None else None,
    )
    return self._work_cls(*self._work_args, **self._work_kwargs)

def add_work(self, work) -> str:
Expand Down
10 changes: 9 additions & 1 deletion tests/tests_app/components/test_auto_scaler.py
Expand Up @@ -3,7 +3,7 @@

import pytest

from lightning_app import LightningWork
from lightning_app import CloudCompute, LightningWork
from lightning_app.components import AutoScaler


Expand Down Expand Up @@ -90,3 +90,11 @@ def test_scale(replicas, metrics, expected_replicas):
)

assert auto_scaler.scale(replicas, metrics) == expected_replicas


def test_create_work_cloud_compute_cloned():
    """Check that ``create_work`` does not keep the caller's ``CloudCompute`` instance.

    After one call to ``create_work`` the ``cloud_compute`` entry stored in the
    scaler's work kwargs must be a different object from the one passed to the
    constructor, so works are not created on a single shared machine.
    """
    original_compute = CloudCompute("gpu")
    scaler = AutoScaler(EmptyWork, cloud_compute=original_compute)

    # Only the side effect on the stored kwargs matters; the created work is discarded.
    scaler.create_work()

    stored_compute = scaler._work_kwargs["cloud_compute"]
    assert stored_compute is not original_compute

0 comments on commit ed223e8

Please sign in to comment.