diff --git a/src/lightning_app/components/training.py b/src/lightning_app/components/training.py index 4618b5aa9e9cb..6d3c86eb50374 100644 --- a/src/lightning_app/components/training.py +++ b/src/lightning_app/components/training.py @@ -147,33 +147,28 @@ def __init__( the ServableModule API """ super().__init__() - self.ws = structures.List() - self.has_initialized = False self.script_path = script_path self.script_args = script_args self.num_nodes = num_nodes - self._cloud_compute = cloud_compute # TODO: Add support for cloudCompute self.sanity_serving = sanity_serving self._script_runner = script_runner self._script_runner_kwargs = script_runner_kwargs - def run(self, **run_kwargs): - if not self.has_initialized: - for node_rank in range(self.num_nodes): - self.ws.append( - self._script_runner( - script_path=self.script_path, - script_args=self.script_args, - cloud_compute=self._cloud_compute, - node_rank=node_rank, - sanity_serving=self.sanity_serving, - num_nodes=self.num_nodes, - **self._script_runner_kwargs, - ) + self.ws = structures.List() + for node_rank in range(self.num_nodes): + self.ws.append( + self._script_runner( + script_path=self.script_path, + script_args=self.script_args, + cloud_compute=cloud_compute, + node_rank=node_rank, + sanity_serving=self.sanity_serving, + num_nodes=self.num_nodes, + **self._script_runner_kwargs, ) + ) - self.has_initialized = True - + def run(self, **run_kwargs): for work in self.ws: if all(w.internal_ip for w in self.ws): internal_urls = [(w.internal_ip, w.port) for w in self.ws]