fix can't instantiate abstract class
[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

fix
awaelchli committed Dec 16, 2022
1 parent dd8c8d5 commit 61bb90e
Showing 1 changed file with 13 additions and 10 deletions.
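For context, the error this commit fixes comes from Python's abstract base class machinery: LightningLite declares run as an abstract method, so calling LightningLite(...) directly raises "TypeError: Can't instantiate abstract class". The fix below moves the test logic into a subclass that implements run. A minimal sketch of the pattern, assuming only that run is marked with @abstractmethod (the Lite class here is a stand-in, not the real import):

    from abc import ABC, abstractmethod

    class Lite(ABC):  # stand-in for the real LightningLite
        @abstractmethod
        def run(self, *args, **kwargs):
            ...

    class FusedTest(Lite):  # subclass that provides run, as in this commit
        def run(self, fused=False):
            return fused

    # Lite()  # would raise: TypeError: Can't instantiate abstract class Lite
    FusedTest().run(fused=True)  # works once run is implemented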
23 changes: 13 additions & 10 deletions tests/tests_lite/plugins/precision/test_native_amp_integration.py

@@ -74,17 +74,15 @@ def test_native_mixed_precision(accelerator, precision, expected_dtype):
     lite.run()
 
 
-@RunIf(min_torch="1.13", min_cuda_gpus=1)
-def test_native_mixed_precision_fused_optimizer_parity():
-    def run(fused=False):
+class FusedTest(LightningLite):
+    def run(self, fused=False):
         seed_everything(1234)
-        lite = LightningLite(accelerator="cuda", precision=16, devices=1)
 
-        model = nn.Linear(10, 10).to(lite.device)  # TODO: replace with individual setup_model call
+        model = nn.Linear(10, 10).to(self.device)  # TODO: replace with individual setup_model call
         optimizer = torch.optim.Adam(model.parameters(), lr=1.0, fused=fused)
 
-        model, optimizer = lite.setup(model, optimizer)
-        assert isinstance(lite._precision.scaler, torch.cuda.amp.GradScaler)
+        model, optimizer = self.setup(model, optimizer)
+        assert isinstance(self._precision.scaler, torch.cuda.amp.GradScaler)
 
         data = torch.randn(10, 10, device="cuda")
         target = torch.randn(10, 10, device="cuda")
@@ -94,13 +92,18 @@ def run(fused=False):
             optimizer.zero_grad()
             output = model(data)
             loss = (output - target).abs().sum()
-            lite.backward(loss)
+            self.backward(loss)
             optimizer.step()
             losses.append(loss.detach())
         return torch.stack(losses), model.parameters()
 
-    losses, params = run(fused=False)
-    losses_fused, params_fused = run(fused=True)
+
+@RunIf(min_torch="1.13", min_cuda_gpus=1)
+def test_native_mixed_precision_fused_optimizer_parity():
+    lite = FusedTest(accelerator="cuda", precision=16, devices=1)
+    losses, params = lite.run(fused=False)
+    lite = FusedTest(accelerator="cuda", precision=16, devices=1)
+    losses_fused, params_fused = lite.run(fused=True)
 
     # Both the regular and the fused version of Adam produce the same losses and model weights
     torch.testing.assert_close(losses, losses_fused)
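A note on the RunIf(min_torch="1.13", min_cuda_gpus=1) guard: the fused=True flag of torch.optim.Adam was added in PyTorch 1.13 and runs the parameter update as fused CUDA kernels, so all parameters must live on a CUDA device. A minimal sketch of the flag in isolation, assuming a CUDA-capable machine and PyTorch >= 1.13:

    import torch
    import torch.nn as nn

    model = nn.Linear(10, 10).to("cuda")
    # fused=True fuses the Adam update into CUDA kernels; CPU parameters
    # are not supported, which is why the test is gated on a GPU.
    optimizer = torch.optim.Adam(model.parameters(), lr=1.0, fused=True)

    loss = model(torch.randn(2, 10, device="cuda")).sum()
    loss.backward()
    optimizer.step()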
