Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ignore num_nodes when running MultiNode components locally #15806

Merged
merged 9 commits into from Nov 24, 2022
1 change: 1 addition & 0 deletions src/lightning_app/CHANGELOG.md
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- `lightning add ssh-key` CLI command has been transitioned to `lightning create ssh-key` with the same calling signature ([#15761](https://github.com/Lightning-AI/lightning/pull/15761))
- `lightning remove ssh-key` CLI command has been transitioned to `lightning delete ssh-key` with the same calling signature ([#15761](https://github.com/Lightning-AI/lightning/pull/15761))
- The `MultiNode` components now warn the user when running with `num_nodes > 1` locally ([#15806](https://github.com/Lightning-AI/lightning/pull/15806))


### Deprecated
Expand Down
15 changes: 13 additions & 2 deletions src/lightning_app/components/multi_node/base.py
@@ -1,8 +1,10 @@
import warnings
from typing import Any, Type

from lightning_app import structures
from lightning_app.core.flow import LightningFlow
from lightning_app.core.work import LightningWork
from lightning_app.utilities.cloud import is_running_in_cloud
from lightning_app.utilities.packaging.cloud_compute import CloudCompute


Expand Down Expand Up @@ -45,12 +47,21 @@ def run(

Arguments:
work_cls: The work to be executed
num_nodes: Number of nodes.
cloud_compute: The cloud compute object used in the cloud.
num_nodes: Number of nodes. Gets ignored when running locally. Launch the app with --cloud to run on
multiple cloud machines.
cloud_compute: The cloud compute object used in the cloud. The value provided here gets ignored when
running locally.
work_args: Arguments to be provided to the work on instantiation.
work_kwargs: Keywords arguments to be provided to the work on instantiation.
"""
super().__init__()
if num_nodes > 1 and not is_running_in_cloud():
num_nodes = 1
warnings.warn(
f"You set {type(self).__name__}(num_nodes={num_nodes}, ...)` but this app is running locally. "
" We assume you are debugging and will ignore the `num_nodes` argument. "
" To run on multiple nodes in the cloud, launch your app with `--cloud`."
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
)
self.ws = structures.List(
*[
work_cls(
Expand Down
19 changes: 19 additions & 0 deletions tests/tests_app/components/multi_node/test_base.py
@@ -0,0 +1,19 @@
from re import escape
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

import pytest
from tests_app.helpers.utils import no_warning_call

from lightning_app import CloudCompute, LightningWork
from lightning_app.components import MultiNode


def test_multi_node_warn_running_locally():
class Work(LightningWork):
def run(self):
pass

with pytest.warns(UserWarning, match=escape("You set MultiNode(num_nodes=1, ...)` but ")):
MultiNode(Work, num_nodes=2, cloud_compute=CloudCompute("gpu"))

with no_warning_call(UserWarning, match=escape("You set MultiNode(num_nodes=1, ...)` but ")):
MultiNode(Work, num_nodes=1, cloud_compute=CloudCompute("gpu"))
Empty file.
30 changes: 30 additions & 0 deletions tests/tests_app/helpers/utils.py
@@ -0,0 +1,30 @@
import re
from contextlib import contextmanager
from typing import Optional, Type

import pytest


@contextmanager
def no_warning_call(expected_warning: Type[Warning] = UserWarning, match: Optional[str] = None):
# TODO: Replace with `lightning_utilities.test.warning.no_warning_call`
# https://github.com/Lightning-AI/utilities/issues/57

with pytest.warns(None) as record:
yield

if match is None:
try:
w = record.pop(expected_warning)
except AssertionError:
# no warning raised
return
else:
for w in record.list:
if w.category is expected_warning and re.compile(match).search(w.message.args[0]):
break
else:
return

msg = "A warning" if expected_warning is None else f"`{expected_warning.__name__}`"
raise AssertionError(f"{msg} was raised: {w}")
7 changes: 5 additions & 2 deletions tests/tests_examples_app/public/test_multi_node.py
@@ -1,5 +1,6 @@
import os
import sys
from unittest import mock

import pytest
from tests_examples_app.public import _PATH_EXAMPLES
Expand All @@ -17,7 +18,8 @@ def on_before_run_once(self):


@pytest.mark.skip(reason="flaky")
def test_multi_node_example(monkeypatch):
@mock.patch("lightning_app.components.multi_node.base.is_running_in_cloud", return_value=True)
def test_multi_node_example(_, monkeypatch):
monkeypatch.chdir(os.path.join(_PATH_EXAMPLES, "app_multi_node"))
command_line = [
"app.py",
Expand Down Expand Up @@ -50,7 +52,8 @@ def on_before_run_once(self):
],
)
@pytest.mark.skipif(sys.platform == "win32", reason="flaky")
def test_multi_node_examples(app_name, monkeypatch):
@mock.patch("lightning_app.components.multi_node.base.is_running_in_cloud", return_value=True)
def test_multi_node_examples(_, app_name, monkeypatch):
monkeypatch.chdir(os.path.join(_PATH_EXAMPLES, "app_multi_node"))
command_line = [
app_name,
Expand Down