Skip to content

Commit

Permalink
Remove redundant find_unused_parameters=False in Lite (#16026)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Dec 19, 2022
1 parent 9421e87 commit f407c9c
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 60 deletions.
2 changes: 1 addition & 1 deletion src/lightning_fabric/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def _check_strategy_and_fallback(self) -> None:
# TODO this logic should apply to both str and object config
strategy_flag = "" if isinstance(self._strategy_flag, Strategy) else self._strategy_flag

if strategy_flag in ("ddp_spawn", "ddp_spawn_find_unused_parameters_false") and (
if strategy_flag == "ddp_spawn" and (
TorchElasticEnvironment.detect()
or KubeflowEnvironment.detect()
or SLURMEnvironment.detect()
Expand Down
17 changes: 0 additions & 17 deletions src/lightning_fabric/strategies/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@

# Strategy names grouped under the "fork" flavour of DDP.
# NOTE(review): presumably every name listed here is registered with `start_method="fork"` — confirm against
# `register_strategies` before relying on it.
_DDP_FORK_ALIASES = (
"ddp_fork",
"ddp_fork_find_unused_parameters_false",
"ddp_notebook",
"ddp_notebook_find_unused_parameters_false",
)


Expand Down Expand Up @@ -177,21 +175,6 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
start_method=start_method,
)

entries = (
("ddp_find_unused_parameters_false", "popen"),
("ddp_spawn_find_unused_parameters_false", "spawn"),
("ddp_fork_find_unused_parameters_false", "fork"),
("ddp_notebook_find_unused_parameters_false", "fork"),
)
for name, start_method in entries:
strategy_registry.register(
name,
cls,
description=f"DDP strategy with `find_unused_parameters` as False and `start_method={start_method!r}`",
find_unused_parameters=False,
start_method=start_method,
)

def _setup_distributed(self) -> None:
self._set_world_ranks()
rank_zero_only.rank = self.global_rank
Expand Down
13 changes: 0 additions & 13 deletions src/lightning_fabric/strategies/fairscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,24 +103,11 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
    """Register the sharded DDP strategy variants under their string aliases.

    Four names are registered, all pointing at ``cls``:

    * ``"ddp_sharded_find_unused_parameters_false"`` -- passes ``find_unused_parameters=False``.
    * ``"ddp_sharded"`` -- no extra init arguments.
    * ``"ddp_sharded_spawn_find_unused_parameters_false"`` -- passes ``find_unused_parameters=False`` and
      ``start_method="spawn"``.
    * ``"ddp_sharded_spawn"`` -- passes ``start_method="spawn"``.

    Args:
        strategy_registry: The registry object whose ``register`` method receives each entry.
    """
    strategy_registry.register(
        "ddp_sharded_find_unused_parameters_false",
        cls,
        description="DDP Sharded Strategy with `find_unused_parameters` as False",
        find_unused_parameters=False,
    )
    strategy_registry.register(
        "ddp_sharded",
        cls,
        # `cls` is already the class object; `cls.__class__.__name__` would yield the metaclass name
        # (e.g. "type") instead of the strategy's own name.
        description=cls.__name__,
    )
    strategy_registry.register(
        "ddp_sharded_spawn_find_unused_parameters_false",
        cls,
        description="DDP Spawn Sharded Strategy with `find_unused_parameters` as False",
        find_unused_parameters=False,
        start_method="spawn",
    )
    strategy_registry.register("ddp_sharded_spawn", cls, description=cls.__name__, start_method="spawn")


Expand Down
17 changes: 17 additions & 0 deletions tests/tests_fabric/strategies/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import MagicMock, Mock

import pytest
Expand Down Expand Up @@ -62,3 +63,19 @@ def test_ddp_no_backward_sync():
pass

module.no_sync.assert_called_once()


@mock.patch("lightning_lite.strategies.ddp.DistributedDataParallel")
def test_ddp_extra_kwargs(ddp_mock):
    """Check that keyword arguments given to ``DDPStrategy`` reach the ``DistributedDataParallel`` wrapper.

    Two scenarios are exercised against the patched wrapper: one without extra keyword arguments and one with
    ``find_unused_parameters=True``. In both, the wrapper must be called with the module, ``device_ids=None``
    and exactly the extra keyword arguments that were given to the strategy.
    """
    module = torch.nn.Linear(1, 1)

    for extra_kwargs in ({}, {"find_unused_parameters": True}):
        # Start every scenario from a clean call history.
        ddp_mock.reset_mock()

        cpu_devices = [torch.device("cpu"), torch.device("cpu")]
        strategy = DDPStrategy(parallel_devices=cpu_devices, **extra_kwargs)
        strategy.setup_module(module)

        ddp_mock.assert_called_with(module=module, device_ids=None, **extra_kwargs)
15 changes: 0 additions & 15 deletions tests/tests_fabric/strategies/test_fairscale_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,3 @@ def test_fairscale_multi_process_checkpoint_state_consolidation(with_fairscale_o
weights is identical to the saved one."""
lite = ShardedSaveAndLoad(strategy=strategy, accelerator=accelerator, devices=2)
lite.run(tmpdir, with_fairscale_oss=with_fairscale_oss)


@pytest.mark.parametrize(
    "strategy, expected_find_unused_parameters",
    [
        ("ddp_sharded", None),
        ("ddp_sharded_find_unused_parameters_false", False),
        ("ddp_sharded_spawn", None),
        ("ddp_sharded_spawn_find_unused_parameters_false", False),
    ],
)
def test_fairscale_find_unused_parameters_from_registry(strategy, expected_find_unused_parameters):
    """Check the ``find_unused_parameters`` setting of strategies selected by their registry name.

    When an expectation is given, the strategy's ``_ddp_kwargs`` must hold ``find_unused_parameters`` set to
    ``False``. When the expectation is ``None``, only the construction of ``BoringLite`` is exercised.
    """
    lite = BoringLite(strategy=strategy)

    if expected_find_unused_parameters is None:
        # Nothing further to verify for the plain aliases.
        return

    ddp_kwargs = lite._strategy._ddp_kwargs
    assert ddp_kwargs["find_unused_parameters"] is False
6 changes: 0 additions & 6 deletions tests/tests_fabric/strategies/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ def __init__(self, param1, param2):

def test_available_strategies_in_registry():
expected = {
"ddp_sharded_find_unused_parameters_false",
"ddp_sharded",
"ddp_find_unused_parameters_false",
"ddp",
"deepspeed",
"deepspeed_stage_1",
Expand All @@ -54,14 +52,10 @@ def test_available_strategies_in_registry():
"deepspeed_stage_3",
"deepspeed_stage_3_offload",
"deepspeed_stage_3_offload_nvme",
"ddp_sharded_spawn_find_unused_parameters_false",
"ddp_sharded_spawn",
"ddp_spawn",
"ddp_fork",
"ddp_notebook",
"ddp_spawn_find_unused_parameters_false",
"ddp_fork_find_unused_parameters_false",
"ddp_notebook_find_unused_parameters_false",
"single_tpu",
"tpu_spawn",
"xla",
Expand Down
10 changes: 2 additions & 8 deletions tests/tests_fabric/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_strategy_choice_ddp_on_cpu():

def _test_strategy_choice_ddp_and_cpu(ddp_strategy_class):
connector = _Connector(
strategy=ddp_strategy_class(find_unused_parameters=True),
strategy=ddp_strategy_class(),
accelerator="cpu",
devices=2,
)
Expand Down Expand Up @@ -379,9 +379,7 @@ def test_invalid_strategy_choice():
["strategy", "strategy_class"],
[
("ddp_spawn", DDPStrategy),
("ddp_spawn_find_unused_parameters_false", DDPStrategy),
("ddp", DDPStrategy),
("ddp_find_unused_parameters_false", DDPStrategy),
],
)
def test_strategy_choice_cpu_str(strategy, strategy_class):
Expand All @@ -394,9 +392,7 @@ def test_strategy_choice_cpu_str(strategy, strategy_class):
["strategy", "strategy_class"],
[
("ddp_spawn", DDPStrategy),
("ddp_spawn_find_unused_parameters_false", DDPStrategy),
("ddp", DDPStrategy),
("ddp_find_unused_parameters_false", DDPStrategy),
("dp", DataParallelStrategy),
("ddp_sharded", DDPShardedStrategy),
("ddp_sharded_spawn", DDPShardedStrategy),
Expand Down Expand Up @@ -780,9 +776,7 @@ def test_precision_selection_amp_ddp(strategy, devices, is_custom_plugin, plugin
assert isinstance(connector.precision, plugin_cls)


@pytest.mark.parametrize(
["strategy", "strategy_cls"], [("DDP", DDPStrategy), ("DDP_FIND_UNUSED_PARAMETERS_FALSE", DDPStrategy)]
)
@pytest.mark.parametrize(["strategy", "strategy_cls"], [("DDP", DDPStrategy), ("Ddp", DDPStrategy)])
def test_strategy_str_passed_being_case_insensitive(strategy, strategy_cls):
connector = _Connector(strategy=strategy)
assert isinstance(connector.strategy, strategy_cls)
Expand Down

0 comments on commit f407c9c

Please sign in to comment.