From 99edeff12714f8c609b65d7957081870dda42106 Mon Sep 17 00:00:00 2001 From: Changyu Gao Date: Mon, 5 Dec 2022 15:48:25 -0800 Subject: [PATCH] Implement _compute_intra_grad_corr_mean for gradient computation (#1095) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix gradient accumulation Add ``is_scaled_loss`` flag to support both scaled / unscaled loss Fix ``test_grad_accum`` and ``test_set_num_gradients_to_accumulate`` * Add a method to scale grad for grad_accum using unscaled loss - Revert the changes in the `step` method - Add a method `scale_grad_by_num_grads_to_accum` to handle gradient accumulation using unscaled loss more explicitly - Add gradient tests * Implement _compute_corr_mean_between_grads * Improve tests and comments * Use ubuntu-20.04 instead of latest Use ubuntu-20.04 to fix the `arch x64 not found` issue [Version 3.10 with arch x64 not found actions/setup-python#401](https://github.com/actions/setup-python/issues/401) * Switch flake8 from gitlab to github Flake8 was moved to GitHub See discussions https://www.reddit.com/r/Python/comments/yvfww8/flake8_took_down_the_gitlab_repository_in_favor/ * Fix scikit-learn package * Update PyTorch versions * Resolve comments from Min * Minor fix * Disable broken tests for new versions of PyTorch --- .circleci/config.yml | 14 ++--- .github/workflows/pre-commit.yml | 2 +- .pre-commit-config.yaml | 2 +- .../fair_dev/testing/golden_testing_data.py | 14 +++++ fairscale/optim/adascale.py | 44 ++++++++++++++ requirements-dev.txt | 3 +- .../checkpoint/test_checkpoint_activations.py | 4 ++ tests/nn/data_parallel/test_fsdp_memory.py | 4 ++ .../test_fsdp_multiple_forward_checkpoint.py | 4 ++ tests/optim/test_ddp_adascale.py | 57 +++++++++++++++++++ 10 files changed, 137 insertions(+), 11 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 9a46c9107..2ffdc716a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -100,16 +100,16 @@ install_dep_pytorch_lts: &install_dep_pytorch_lts # most recent stable version install_dep_pytorch_stable: &install_dep_pytorch_stable - run: - name: Install Dependencies with torch 1.12.0 + name: Install Dependencies with torch 1.13.0 command: | # check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip - if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.11 && exit 0; fi + if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.13 && exit 0; fi # start installing - pip install --progress-bar off torch==1.12.0+cu113 torchvision==0.13.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 + pip install --progress-bar off torch==1.13.0 torchvision==0.14.0 --extra-index-url https://download.pytorch.org/whl/cu113 pip install --progress-bar off -r requirements-dev.txt pip install --progress-bar off -r requirements-benchmarks.txt python -c 'import torch; print("Torch version:", torch.__version__)' - python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "12"], f"wrong torch version {torch.__version__}"' + python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "13"], f"wrong torch version {torch.__version__}"' python -m torch.utils.collect_env wget -O /home/circleci/venv/check_version.py https://raw.githubusercontent.com/min-xu-ai/check_verion/main/check_version.py @@ -118,13 +118,13 @@ install_dep_pytorch_nightly: &install_dep_pytorch_nightly - run: - name: Install 
Dependencies with a torch nightly preview build command: | # check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip - if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.12 && exit 0; fi + if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.14 && exit 0; fi # start installing - pip install --pre torch==1.13.0.dev20220825+cu113 torchvision==0.14.0.dev20220825+cu113 --extra-index-url https://download.pytorch.org/whl/nightly/cu113 + pip install --pre torch==1.14.0.dev20221121+cu117 torchvision==0.15.0.dev20221121+cu117 --extra-index-url https://download.pytorch.org/whl/nightly/cu117 pip install --progress-bar off -r requirements-dev.txt pip install --progress-bar off -r requirements-benchmarks.txt python -c 'import torch; print("Torch version:", torch.__version__)' - python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "13"], f"wrong torch version {torch.__version__}"' + python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "14"], f"wrong torch version {torch.__version__}"' python -m torch.utils.collect_env wget -O /home/circleci/venv/check_version.py https://raw.githubusercontent.com/min-xu-ai/check_verion/main/check_version.py diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 3c6c1f286..4341b9286 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -7,7 +7,7 @@ on: jobs: pre-commit: - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 strategy: matrix: # make sure python versions are consistent with those used in .circleci/config.yml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5409d581f..c7238e5b8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: hooks: - id: black -- repo: https://gitlab.com/pycqa/flake8 +- repo: https://github.com/PyCQA/flake8 rev: 4.0.1 hooks: - id: flake8 diff --git a/fairscale/fair_dev/testing/golden_testing_data.py b/fairscale/fair_dev/testing/golden_testing_data.py index eb5c06a40..3c0dbb5d0 100644 --- a/fairscale/fair_dev/testing/golden_testing_data.py +++ b/fairscale/fair_dev/testing/golden_testing_data.py @@ -47,3 +47,17 @@ "expected_bias_grad": [1.0, 1.0], }, ] + +corr_mean_test_data = [ + { + "inputs": [ + [[1.0, 0.0, 2.0], [2.0, 0.0, 1.0]], + [[0.0, 1.0, 2.0], [2.0, 1.0, 0]], + [[3.0, 1.0, 2.0], [2.0, 1.0, -1.0]], + ], + "expected_grad": [[1.5, 0.0, 1.5], [1.0, 1.0, 1.0], [2.5, 1.0, 0.5]], + # expected Pearson correlation of the two micro-batches + "expected_corr": [0.5, -1.0, 0.327327], + "expected_cos_similarity": [float("nan"), 0.8165, 0.8433], + } +] diff --git a/fairscale/optim/adascale.py b/fairscale/optim/adascale.py index e20bbd20e..523630c64 100644 --- a/fairscale/optim/adascale.py +++ b/fairscale/optim/adascale.py @@ -387,6 +387,49 @@ def _update_avg(self, name: str, value: np.ndarray, factor: float) -> None: else: self._state[name] = factor * self._state[name] + (1.0 - factor) * value + def _gather_flat_grad(self) -> torch.Tensor: + """ + Helper function for gathering all gradients into a single vector. + Duplicated from torch.optim.lbfgs. + """ + + def _to_flat_view(p: torch.Tensor) -> torch.Tensor: + """ + Local helper function for _gather_flat_grad. + Returns the parameter's gradient flattened to a 1-D tensor (zeros if the gradient is None, densified if sparse). 
+ """ + if p.grad is None: + return p.new(p.numel()).zero_() # type: ignore + elif p.grad.is_sparse: # type: ignore + return p.grad.to_dense().view(-1) + else: + return p.grad.view(-1) + + views = [_to_flat_view(p) for param_group in self._optimizer.param_groups for p in param_group["params"]] + return torch.cat(views, 0) + + def _compute_intra_grad_corr_mean(self) -> torch.Tensor: + """ + Helper function for computing the mean pairwise correlation among the per-GPU (micro-batch) gradients. + This should be called under the `model.no_sync()` context. + """ + assert self._world_size > 1, "Only for distributed training" + flat_grad = self._gather_flat_grad() + corr_mean = torch.tensor(0.0).cuda() + if dist.get_rank() == 0: + size = flat_grad.numel() + gathered_tensors = [torch.zeros(size, device=0) for _ in range(self._world_size)] + dist.gather(flat_grad, gather_list=gathered_tensors, dst=0) + # the following requires torch 1.10+ + corr = torch.stack(gathered_tensors).corrcoef() # type: ignore + # pick out the upper triangular part of the correlation matrix + corr = corr[torch.triu(torch.ones_like(corr), diagonal=1) == 1] + corr_mean = corr.mean() + else: + dist.gather(flat_grad, gather_list=None, dst=0) + dist.broadcast(corr_mean, src=0) + return corr_mean + def _backward_hook(self, pg_idx: int, grad: torch.Tensor) -> None: # This method should be invoked once for each parameter during the # backward pass, before gradients are synchronized between world_size. @@ -449,6 +492,7 @@ def _final_callback(self) -> None: return # Since self._local_grad_sqr is FP32, sum shouldn't overflow. + # This vector has a length equal to the # of param_groups, so it is small, but we # use async to hide the all_reduce latency, esp when # of nodes is large. work = None diff --git a/requirements-dev.txt b/requirements-dev.txt index 31c9183ec..93c02565e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -28,11 +28,10 @@ pynvml == 8.0.4 # For mypy typing. It is important to have a fixed version. Otherwise, you # may run into mypy errors out differently for different versions. -# Using 1.21.5 for now because py3.7 only has up to 1.21.5, not 1.22.x. numpy == 1.22.0 # For layerwise gradient scaler -sklearn >= 0.0 +scikit-learn == 1.1.3 # For weigit. These are actually user requirements, not developer requirements. 
# However, due to the experimental nature of weigit, we don't expose to the diff --git a/tests/nn/checkpoint/test_checkpoint_activations.py b/tests/nn/checkpoint/test_checkpoint_activations.py index ef04612b4..b5491a1a4 100644 --- a/tests/nn/checkpoint/test_checkpoint_activations.py +++ b/tests/nn/checkpoint/test_checkpoint_activations.py @@ -83,6 +83,10 @@ def forward(self, x): @pytest.mark.parametrize("device", ["cpu", "cuda"]) +@pytest.mark.skipif( + torch_version() >= (1, 13, 0), + reason="mem_peak behavior changed for torch 1.13 and above", +) def test_basic(device): if "cuda" in device and not torch.cuda.is_available(): pytest.skip("test requires a GPU") diff --git a/tests/nn/data_parallel/test_fsdp_memory.py b/tests/nn/data_parallel/test_fsdp_memory.py index b465e2500..bb5fc27b1 100644 --- a/tests/nn/data_parallel/test_fsdp_memory.py +++ b/tests/nn/data_parallel/test_fsdp_memory.py @@ -162,6 +162,10 @@ def cmp(results, expected): @pytest.mark.timeout(120) @pytest.mark.parametrize("ckpt", ["no_ckpt", "ckpt"]) @pytest.mark.parametrize("fsdp", ["ddp", "fsdp", "fsdp_amp_default", "fsdp_amp_compute_dtype32"]) +@pytest.mark.skipif( + torch_version() >= (1, 14, 0), + reason="Tests broke in PyTorch pre-release version 1.14", +) def test_fsdp_memory(fsdp, ckpt): expected = { ("ddp", "no_ckpt"): { diff --git a/tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py b/tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py index 1f30f3c73..801218a48 100644 --- a/tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py +++ b/tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py @@ -303,6 +303,10 @@ def _get_cached_results( @pytest.mark.parametrize("wrap_bn", ["auto_wrap_bn", "no_auto_wrap_bn"]) @pytest.mark.parametrize("model_type", ["model1", "model2"]) @pytest.mark.parametrize("bn_type", ["bn", "sync_bn"]) +@pytest.mark.skipif( + torch_version() >= (1, 14, 0), + reason="Tests broke in PyTorch pre-release version 1.14", +) def test_multiple_forward_checkpoint(precision, flatten, wrap_bn, model_type, bn_type): mixed_precision = precision == "mixed" flatten = flatten == "flatten" diff --git a/tests/optim/test_ddp_adascale.py b/tests/optim/test_ddp_adascale.py index 561827a04..8e234fdc5 100644 --- a/tests/optim/test_ddp_adascale.py +++ b/tests/optim/test_ddp_adascale.py @@ -35,6 +35,7 @@ from fairscale.fair_dev.testing.golden_testing_data import adascale_test_data from fairscale.fair_dev.testing.testing import skip_if_single_gpu +from fairscale.internal import torch_version from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP from fairscale.nn.data_parallel import ShardedDataParallel as SDP from fairscale.optim import OSS, AdaScale @@ -152,3 +153,59 @@ def test_grad_accum(): temp_file_name = tempfile.mkstemp()[1] mp.spawn(_test_grad_accum_func, args=(world_size, temp_file_name), nprocs=world_size, join=True) + + +def _test_corr_mean_func(rank, world_size, tempfile_name, test_case): + _dist_init(rank, world_size, tempfile_name, backend="gloo") # Covers gloo + + model = Linear(3, 1, bias=False) + model.to("cuda") + model = DDP(model, device_ids=[rank]) + optim = AdaScale(SGD(model.parameters(), lr=0.1)) + results = [] + last_grad = None + for i, in_data in enumerate(test_case["inputs"]): + # use no_sync so we can access the unreduced gradients + with model.no_sync(): + in_data = Tensor(in_data[rank]).cuda() + out = model(in_data) + out.sum().backward() + results.append(optim._compute_intra_grad_corr_mean().item()) + # sync gradients manually
+ for p in model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.SUM) + # divide by world size + p.grad.data.div_(world_size) + grad = optim._gather_flat_grad() + assert np.allclose(grad.cpu(), test_case["expected_grad"][i]) + optim.step() + if last_grad is not None: + # compute cosine similarity; golden values are rounded to 4 decimal places, hence the atol below + cos_similarity = torch.dot(grad, last_grad) / (grad.norm() * last_grad.norm()) + assert np.allclose(cos_similarity.cpu(), test_case["expected_cos_similarity"][i], atol=1e-4) + last_grad = grad + optim.zero_grad() + assert np.allclose(results, test_case["expected_corr"]), results + + dist.destroy_process_group() + + +@skip_if_single_gpu +@pytest.mark.skipif( + torch_version() < (1, 10, 0), + reason="torch.corrcoef is available only in torch 1.10 or higher", +) +def test_corr_mean(): + """ + Test _compute_intra_grad_corr_mean and _gather_flat_grad using ddp.no_sync(). + We also demonstrate how the cosine similarity between consecutive gradients can be computed using _gather_flat_grad. + """ + world_size = 2 + temp_file_name = tempfile.mkstemp()[1] + + from fairscale.fair_dev.testing.golden_testing_data import corr_mean_test_data + + test_case = corr_mean_test_data[0] + + mp.spawn(_test_corr_mean_func, args=(world_size, temp_file_name, test_case), nprocs=world_size, join=True)
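
Usage note (not part of the patch): the sketch below distills the pattern exercised by _test_corr_mean_func into a single training step that also measures the intra-batch gradient correlation. It assumes a DDP-wrapped model and an AdaScale-wrapped SGD optimizer as in the test above; the training_step helper, its arguments, and the toy .sum() loss are illustrative assumptions only. _compute_intra_grad_corr_mean is a private AdaScale helper that requires world_size > 1 and torch >= 1.10, and must be reached by every rank while the gradients are still unsynchronized, hence the no_sync() block and the manual all_reduce afterwards.

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    from fairscale.optim import AdaScale


    def training_step(model: DDP, optim: AdaScale, batch: torch.Tensor, world_size: int) -> float:
        # Run forward/backward under no_sync() so each rank keeps its own local gradient.
        with model.no_sync():
            loss = model(batch).sum()  # stand-in for a real loss
            loss.backward()
            # Collective call: every rank must execute this line.
            corr_mean = optim._compute_intra_grad_corr_mean().item()

        # no_sync() skipped DDP's gradient averaging, so average the gradients
        # manually before stepping, mirroring what the test does.
        for p in model.parameters():
            if p.grad is not None:
                dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
                p.grad.div_(world_size)

        optim.step()
        optim.zero_grad()
        return corr_mean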
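
For reference, the golden values in corr_mean_test_data can be reproduced in a single process: with Linear(3, 1, bias=False) and out.sum().backward(), each rank's weight gradient equals its input row, so expected_corr is the upper-triangular mean of torch.corrcoef over the two per-rank gradients, and expected_grad is their rank-average. A small sanity-check sketch (not part of the patch), assuming torch >= 1.10 for torch.corrcoef:

    import torch

    # Per-step, per-rank inputs copied from corr_mean_test_data["inputs"].
    inputs = [
        [[1.0, 0.0, 2.0], [2.0, 0.0, 1.0]],
        [[0.0, 1.0, 2.0], [2.0, 1.0, 0.0]],
        [[3.0, 1.0, 2.0], [2.0, 1.0, -1.0]],
    ]

    for step, per_rank in enumerate(inputs):
        grads = torch.tensor(per_rank)              # rows are the per-rank gradients, shape (world_size, numel)
        corr = torch.corrcoef(grads)                # requires torch 1.10+
        # mean of the strictly upper-triangular entries, as in _compute_intra_grad_corr_mean
        upper = corr[torch.triu(torch.ones_like(corr), diagonal=1) == 1]
        print(step, round(upper.mean().item(), 6))  # 0.5, -1.0, 0.327327 -> expected_corr
        print(step, grads.mean(dim=0).tolist())     # matches expected_grad[step]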