Skip to content

Commit

Permalink
[App] Resolve bi-directional queue bug (#15642)
Browse files Browse the repository at this point in the history
(cherry picked from commit 0250c19)
  • Loading branch information
tchaton authored and Borda committed Nov 16, 2022
1 parent fafe429 commit 7981c93
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 15 deletions.
3 changes: 3 additions & 0 deletions src/lightning_app/CHANGELOG.md
Expand Up @@ -43,6 +43,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed race condition to over-write the frontend with app infos ([#15398](https://github.com/Lightning-AI/lightning/pull/15398))


- Fixed bi-directional queues sending delta with Drive Component name changes ([#15642](https://github.com/Lightning-AI/lightning/pull/15642))



## [1.8.0] - 2022-11-01

Expand Down
1 change: 1 addition & 0 deletions src/lightning_app/cli/commands/logs.py
Expand Up @@ -71,6 +71,7 @@ def _show_logs(app_name: str, components: List[str], follow: bool) -> None:
works = client.lightningwork_service_list_lightningwork(
project_id=project.project_id, app_id=apps[app_name].id
).lightningworks

app_component_names = ["flow"] + [f.name for f in apps[app_name].spec.flow_servers] + [w.name for w in works]

if not components:
Expand Down
10 changes: 3 additions & 7 deletions src/lightning_app/components/serve/python_server.py
@@ -1,5 +1,6 @@
import abc
import base64
import os
from pathlib import Path
from typing import Any, Dict, Optional

Expand All @@ -14,12 +15,6 @@
logger = Logger(__name__)


def image_to_base64(image_path):
with open(image_path, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return encoded_string.decode("UTF-8")


class _DefaultInputData(BaseModel):
payload: str

Expand All @@ -33,7 +28,8 @@ class Image(BaseModel):

@staticmethod
def _get_sample_data() -> Dict[Any, Any]:
imagepath = Path(__file__).absolute().parent / "catimage.png"
name = "lightning" + "_" + "app"
imagepath = Path(__file__.replace(f"lightning{os.sep}app", name)).absolute().parent / "catimage.png"
with open(imagepath, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return {"image": encoded_string.decode("UTF-8")}
Expand Down
18 changes: 11 additions & 7 deletions src/lightning_app/core/app.py
Expand Up @@ -24,7 +24,7 @@
from lightning_app.core.queues import BaseQueue, SingleProcessQueue
from lightning_app.core.work import LightningWork
from lightning_app.frontend import Frontend
from lightning_app.storage import Drive, Path
from lightning_app.storage import Drive, Path, Payload
from lightning_app.storage.path import _storage_root_dir
from lightning_app.utilities import frontend
from lightning_app.utilities.app_helpers import (
Expand Down Expand Up @@ -630,8 +630,16 @@ def _extract_vars_from_component_name(component_name: str, state):
else:
return None

# Note: Remove private keys
return {k: v for k, v in child["vars"].items() if not k.startswith("_")}
# Filter private keys and drives
return {
k: v
for k, v in child["vars"].items()
if (
not k.startswith("_")
and not (isinstance(v, dict) and v.get("type", None) == "__drive__")
and not (isinstance(v, (Payload, Path)))
)
}

def _send_flow_to_work_deltas(self, state) -> None:
if not self.flow_to_work_delta_queues:
Expand All @@ -652,10 +660,6 @@ def _send_flow_to_work_deltas(self, state) -> None:
if state_work is None or last_state_work is None:
continue

# Note: The flow shouldn't update path or drive manually.
last_state_work = apply_to_collection(last_state_work, (Path, Drive), lambda x: None)
state_work = apply_to_collection(state_work, (Path, Drive), lambda x: None)

deep_diff = DeepDiff(last_state_work, state_work, verbose_level=2).to_dict()

if "unprocessed" in deep_diff:
Expand Down
11 changes: 11 additions & 0 deletions src/lightning_app/testing/testing.py
Expand Up @@ -431,7 +431,18 @@ def fetch_logs(component_names: Optional[List[str]] = None) -> Generator:
project_id=project.project_id,
app_id=app_id,
).lightningworks

component_names = ["flow"] + [w.name for w in works]
else:

def add_prefix(c: str) -> str:
if c == "flow":
return c
if not c.startswith("root."):
return "root." + c
return c

component_names = [add_prefix(c) for c in component_names]

gen = _app_logs_reader(
logs_api_client=logs_api_client,
Expand Down
30 changes: 29 additions & 1 deletion tests/tests_app/utilities/test_proxies.py
Expand Up @@ -14,7 +14,7 @@

from lightning_app import LightningApp, LightningFlow, LightningWork
from lightning_app.runners import MultiProcessRuntime
from lightning_app.storage import Path
from lightning_app.storage import Drive, Path
from lightning_app.storage.path import _artifacts_path
from lightning_app.storage.requests import _GetRequest
from lightning_app.testing.helpers import _MockQueue, EmptyFlow
Expand Down Expand Up @@ -761,3 +761,31 @@ def test_bi_directional_proxy_forbidden(monkeypatch):
MultiProcessRuntime(app, start_server=False).dispatch()
assert app.stage == AppStage.FAILED
assert "A forbidden operation to update the work" in str(app.exception)


class WorkDrive(LightningFlow):
def __init__(self, drive):
super().__init__()
self.drive = drive
self.path = Path("data")

def run(self):
pass


class FlowDrive(LightningFlow):
def __init__(self):
super().__init__()
self.data = Drive("lit://data")
self.counter = 0

def run(self):
if not hasattr(self, "w"):
self.w = WorkDrive(self.data)
self.counter += 1


def test_bi_directional_proxy_filtering():
app = LightningApp(FlowDrive())
app.root.run()
assert app._extract_vars_from_component_name(app.root.w.name, app.state) == {}

0 comments on commit 7981c93

Please sign in to comment.