From 61bb90e47547e8d2c096c0ca8a06f4fc2e632efc Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Fri, 16 Dec 2022 13:52:17 +0100
Subject: [PATCH] fix can't instantiate abstract class

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix
---
 .../precision/test_native_amp_integration.py | 23 +++++++++++--------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/tests/tests_lite/plugins/precision/test_native_amp_integration.py b/tests/tests_lite/plugins/precision/test_native_amp_integration.py
index cd927229cd8f8..cc111ca1aaa7a 100644
--- a/tests/tests_lite/plugins/precision/test_native_amp_integration.py
+++ b/tests/tests_lite/plugins/precision/test_native_amp_integration.py
@@ -74,17 +74,15 @@ def test_native_mixed_precision(accelerator, precision, expected_dtype):
     lite.run()
 
 
-@RunIf(min_torch="1.13", min_cuda_gpus=1)
-def test_native_mixed_precision_fused_optimizer_parity():
-    def run(fused=False):
+class FusedTest(LightningLite):
+    def run(self, fused=False):
         seed_everything(1234)
-        lite = LightningLite(accelerator="cuda", precision=16, devices=1)
-        model = nn.Linear(10, 10).to(lite.device)  # TODO: replace with individual setup_model call
+        model = nn.Linear(10, 10).to(self.device)  # TODO: replace with individual setup_model call
         optimizer = torch.optim.Adam(model.parameters(), lr=1.0, fused=fused)
-        model, optimizer = lite.setup(model, optimizer)
-        assert isinstance(lite._precision.scaler, torch.cuda.amp.GradScaler)
+        model, optimizer = self.setup(model, optimizer)
+        assert isinstance(self._precision.scaler, torch.cuda.amp.GradScaler)
 
         data = torch.randn(10, 10, device="cuda")
         target = torch.randn(10, 10, device="cuda")
@@ -94,13 +92,18 @@ def run(fused=False):
             optimizer.zero_grad()
             output = model(data)
             loss = (output - target).abs().sum()
-            lite.backward(loss)
+            self.backward(loss)
             optimizer.step()
             losses.append(loss.detach())
         return torch.stack(losses), model.parameters()
 
-    losses, params = run(fused=False)
-    losses_fused, params_fused = run(fused=True)
+
+@RunIf(min_torch="1.13", min_cuda_gpus=1)
+def test_native_mixed_precision_fused_optimizer_parity():
+    lite = FusedTest(accelerator="cuda", precision=16, devices=1)
+    losses, params = lite.run(fused=False)
+    lite = FusedTest(accelerator="cuda", precision=16, devices=1)
+    losses_fused, params_fused = lite.run(fused=True)
 
     # Both the regular and the fused version of Adam produce the same losses and model weights
     torch.testing.assert_close(losses, losses_fused)
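
Context for the change above, not spelled out in the commit message: LightningLite declares run() as an abstract method, so the removed line "lite = LightningLite(accelerator="cuda", precision=16, devices=1)" raises a TypeError ("can't instantiate abstract class"). The patch therefore moves the test body into a concrete subclass (FusedTest) and instantiates a fresh instance of it for each optimizer variant. Below is a minimal sketch of that subclass-and-run pattern; it is not part of the patch, the import path is assumed for the lightning_lite package of that release, and the class name DemoLite and the SGD optimizer are illustrative only.

# Sketch only: the LightningLite subclass pattern the patch adopts.
# Assumption: `run` is abstract on LightningLite, so instantiating the base
# class directly raises TypeError ("Can't instantiate abstract class ...").
import torch
from torch import nn
from lightning_lite import LightningLite  # import path assumed for this era


class DemoLite(LightningLite):  # illustrative name, not from the patch
    def run(self):
        # device/setup/backward are the same Lite hooks the test above uses
        model = nn.Linear(10, 10).to(self.device)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        model, optimizer = self.setup(model, optimizer)
        loss = model(torch.randn(10, 10, device=self.device)).sum()
        self.backward(loss)  # Lite's backward hook; with AMP enabled it applies gradient scaling
        optimizer.step()


if __name__ == "__main__":
    DemoLite(accelerator="cpu", devices=1).run()  # the subclass is instantiable; the base class is not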