From 25a2284cb95cbcefd4c4bc9105e2caa1916fb930 Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Mon, 21 Nov 2022 14:07:13 +0000
Subject: [PATCH 1/3] Improve `LightningTrainerScript` start-up time

---
 src/lightning_app/components/training.py | 31 ++++++++++--------------
 1 file changed, 13 insertions(+), 18 deletions(-)

diff --git a/src/lightning_app/components/training.py b/src/lightning_app/components/training.py
index 4618b5aa9e9cb..6d3c86eb50374 100644
--- a/src/lightning_app/components/training.py
+++ b/src/lightning_app/components/training.py
@@ -147,33 +147,28 @@ def __init__(
                 the ServableModule API
         """
         super().__init__()
-        self.ws = structures.List()
-        self.has_initialized = False
         self.script_path = script_path
         self.script_args = script_args
         self.num_nodes = num_nodes
-        self._cloud_compute = cloud_compute  # TODO: Add support for cloudCompute
         self.sanity_serving = sanity_serving
         self._script_runner = script_runner
         self._script_runner_kwargs = script_runner_kwargs
 
-    def run(self, **run_kwargs):
-        if not self.has_initialized:
-            for node_rank in range(self.num_nodes):
-                self.ws.append(
-                    self._script_runner(
-                        script_path=self.script_path,
-                        script_args=self.script_args,
-                        cloud_compute=self._cloud_compute,
-                        node_rank=node_rank,
-                        sanity_serving=self.sanity_serving,
-                        num_nodes=self.num_nodes,
-                        **self._script_runner_kwargs,
-                    )
+        self.ws = structures.List()
+        for node_rank in range(self.num_nodes):
+            self.ws.append(
+                self._script_runner(
+                    script_path=self.script_path,
+                    script_args=self.script_args,
+                    cloud_compute=cloud_compute,
+                    node_rank=node_rank,
+                    sanity_serving=self.sanity_serving,
+                    num_nodes=self.num_nodes,
+                    **self._script_runner_kwargs,
                 )
+            )
 
-            self.has_initialized = True
-
+    def run(self, **run_kwargs):
         for work in self.ws:
             if all(w.internal_ip for w in self.ws):
                 internal_urls = [(w.internal_ip, w.port) for w in self.ws]

From 4c9832759d788d80e05cf0abdcb59a698524db44 Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Mon, 21 Nov 2022 14:22:17 +0000
Subject: [PATCH 2/3] Fix example

---
 examples/app_multi_node/train_pytorch.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/app_multi_node/train_pytorch.py b/examples/app_multi_node/train_pytorch.py
index 9599bce5bbd85..6beeac0f04b2b 100644
--- a/examples/app_multi_node/train_pytorch.py
+++ b/examples/app_multi_node/train_pytorch.py
@@ -23,7 +23,7 @@ def distributed_train(local_rank: int, main_address: str, main_port: int, num_no
     # 2. PREPARE DISTRIBUTED MODEL
     model = torch.nn.Linear(32, 2)
     device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu")
-    model = DistributedDataParallel(model, device_ids=[local_rank]).to(device)
+    model = DistributedDataParallel(model, device_ids=[local_rank] if torch.cuda.is_available() else None).to(device)
 
     # 3. SETUP LOSS AND OPTIMIZER
     criterion = torch.nn.MSELoss()
@@ -55,7 +55,7 @@ def run(self, main_address: str, main_port: int, num_nodes: int, node_rank: int)
         )
 
 
-# 32 GPUs: (8 nodes x 4 v 100)
+# 32 GPUs: (2 nodes x 4 v 100)
 compute = L.CloudCompute("gpu-fast-multi")  # 4xV100
 component = MultiNode(PyTorchDistributed, num_nodes=2, cloud_compute=compute)
 app = L.LightningApp(component)

From c1a012472cc0935e1e3bbbfd5d5c714ae4ee2a54 Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Mon, 21 Nov 2022 14:31:17 +0000
Subject: [PATCH 3/3] Revert

---
 examples/app_multi_node/train_pytorch.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/app_multi_node/train_pytorch.py b/examples/app_multi_node/train_pytorch.py
index 6beeac0f04b2b..9599bce5bbd85 100644
--- a/examples/app_multi_node/train_pytorch.py
+++ b/examples/app_multi_node/train_pytorch.py
@@ -23,7 +23,7 @@ def distributed_train(local_rank: int, main_address: str, main_port: int, num_no
     # 2. PREPARE DISTRIBUTED MODEL
     model = torch.nn.Linear(32, 2)
     device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu")
-    model = DistributedDataParallel(model, device_ids=[local_rank] if torch.cuda.is_available() else None).to(device)
+    model = DistributedDataParallel(model, device_ids=[local_rank]).to(device)
 
     # 3. SETUP LOSS AND OPTIMIZER
     criterion = torch.nn.MSELoss()
@@ -55,7 +55,7 @@ def run(self, main_address: str, main_port: int, num_nodes: int, node_rank: int)
         )
 
 
-# 32 GPUs: (2 nodes x 4 v 100)
+# 32 GPUs: (8 nodes x 4 v 100)
 compute = L.CloudCompute("gpu-fast-multi")  # 4xV100
 component = MultiNode(PyTorchDistributed, num_nodes=2, cloud_compute=compute)
 app = L.LightningApp(component)
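
Note: the start-up improvement in PATCH 1/3 amounts to creating the per-node works eagerly in __init__ rather than lazily on the first run() call, which also removes the has_initialized bookkeeping. Below is a minimal sketch of that pattern, assuming a hypothetical NodeWork/Trainer pair standing in for the real script-runner work and LightningTrainerScript component; it is an illustration of the idea, not the component's actual code.

# Minimal sketch of the eager-initialization pattern from PATCH 1/3.
# NodeWork and Trainer are hypothetical stand-ins, not the real component.
from lightning_app import LightningApp, LightningFlow, LightningWork
from lightning_app.structures import List


class NodeWork(LightningWork):
    def __init__(self, node_rank: int, num_nodes: int):
        super().__init__(parallel=True)
        self.node_rank = node_rank
        self.num_nodes = num_nodes

    def run(self):
        print(f"node {self.node_rank}/{self.num_nodes} running")


class Trainer(LightningFlow):
    def __init__(self, num_nodes: int):
        super().__init__()
        # Works are created at construction time, so machines can be requested
        # as soon as the flow exists instead of waiting for the first run().
        self.ws = List()
        for node_rank in range(num_nodes):
            self.ws.append(NodeWork(node_rank=node_rank, num_nodes=num_nodes))

    def run(self):
        # run() only dispatches; no has_initialized flag is needed anymore.
        for work in self.ws:
            work.run()


app = LightningApp(Trainer(num_nodes=2))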
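
Note: the device_ids change in PATCH 2/3 (reverted in this example file by PATCH 3/3) reflects that DistributedDataParallel only accepts a device_ids list for CUDA modules; for a CPU module it must be None. The sketch below shows that CPU-safe construction in isolation; the helper name and the nccl/gloo backend choice are illustrative assumptions, not part of the patch.

# Sketch of the CPU-safe DistributedDataParallel construction from PATCH 2/3.
# Assumes the usual env:// variables (MASTER_ADDR, MASTER_PORT, RANK,
# WORLD_SIZE) are already set, e.g. by torchrun.
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


def build_ddp_model(local_rank: int) -> DistributedDataParallel:
    use_cuda = torch.cuda.is_available()
    if not dist.is_initialized():
        # nccl requires GPUs; gloo works on CPU-only machines.
        dist.init_process_group(backend="nccl" if use_cuda else "gloo")
    device = torch.device(f"cuda:{local_rank}") if use_cuda else torch.device("cpu")
    model = torch.nn.Linear(32, 2).to(device)
    # device_ids must be None (not a list) when the module lives on CPU,
    # otherwise DDP raises an error at construction time.
    return DistributedDataParallel(model, device_ids=[local_rank] if use_cuda else None)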