diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index 658e047f5f2a0..000fd134911da 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a `configure_layout` method to the `LightningWork` which can be used to control how the work is handled in the layout of a parent flow ([#15926](https://github.com/Lightning-AI/lightning/pull/15926)) +- Added automatic conversion of list and dict of works and flows to structures ([#15961](https://github.com/Lightning-AI/lightning/pull/15961)) + ### Changed diff --git a/src/lightning_app/core/flow.py b/src/lightning_app/core/flow.py index de079238658a2..a79794bac3d20 100644 --- a/src/lightning_app/core/flow.py +++ b/src/lightning_app/core/flow.py @@ -142,6 +142,14 @@ def __setattr__(self, name: str, value: Any) -> None: if name in self._works and value != getattr(self, name): raise AttributeError(f"Cannot set attributes as the work can't be changed once defined: {name}") + if isinstance(value, (list, dict)) and value: + _type = (LightningFlow, LightningWork, List, Dict) + if isinstance(value, list) and all(isinstance(va, _type) for va in value): + value = List(*value) + + if isinstance(value, dict) and all(isinstance(va, _type) for va in value.values()): + value = Dict(**value) + if isinstance(value, LightningFlow): self._flows.add(name) _set_child_name(self, value, name) @@ -163,10 +171,10 @@ def __setattr__(self, name: str, value: Any) -> None: value._register_cloud_compute() elif isinstance(value, (Dict, List)): - value._backend = self._backend self._structures.add(name) _set_child_name(self, value, name) - if self._backend: + if getattr(self, "_backend", None) is not None: + value._backend = self._backend for flow in value.flows: LightningFlow._attach_backend(flow, self._backend) for work in value.works: diff --git a/tests/tests_app/structures/test_structures.py b/tests/tests_app/structures/test_structures.py index 9c7f492370635..3346da5a858fc 100644 --- a/tests/tests_app/structures/test_structures.py +++ b/tests/tests_app/structures/test_structures.py @@ -518,3 +518,29 @@ def __init__(self): LightningApp(flow) # wrap in app to init all component names assert flow.list_structure[0].name == "root.list_structure.0" assert flow.dict_structure["dict_child"].name == "root.dict_structure.dict_child" + + +class FlowWiStructures(LightningFlow): + def __init__(self): + super().__init__() + + self.ws = [EmptyFlow(), EmptyFlow()] + + self.ws1 = {"a": EmptyFlow(), "b": EmptyFlow()} + + self.ws2 = { + "a": EmptyFlow(), + "b": EmptyFlow(), + "c": List(EmptyFlow(), EmptyFlow()), + "d": Dict(**{"a": EmptyFlow()}), + } + + def run(self): + pass + + +def test_flow_without_structures(): + + flow = FlowWiStructures() + assert isinstance(flow.ws, List) + assert isinstance(flow.ws1, Dict)