From 011f8cae2b90a2b6d7f7dcb10d491b05d00ef012 Mon Sep 17 00:00:00 2001
From: soulitzer
Date: Wed, 2 Nov 2022 16:08:25 -0400
Subject: [PATCH] Update on "Prevent module full_backward_hook from erroring in double backward"

Fixes https://github.com/pytorch/pytorch/issues/88312

[ghstack-poisoned]
---
 test/test_autograd.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/test/test_autograd.py b/test/test_autograd.py
index e19ea54cd423..34bb15fed76b 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -6640,6 +6640,25 @@ def forward(self, x):
         gc.collect()
         self.assertIsNone(ref_())
 
+    def test_full_backward_hook_double_backward(self):
+        x = torch.rand(1, requires_grad=True)
+        y = torch.rand_like(x)
+
+        func = torch.nn.MSELoss()
+        counter = [0]
+
+        def hook(module, grad_input, grad_output):
+            counter[0] += 1
+
+        func.register_full_backward_hook(hook)
+
+        f = func(x, y)
+
+        (gradx_f,) = torch.autograd.grad(f, x, create_graph=True)
+        self.assertEqual(counter[0], 1)
+        _ = torch.autograd.grad(gradx_f, x)
+        # We should not error, and counter should not be incremented
+        self.assertEqual(counter[0], 1)
 
     def test_input_buffer_accum(self):
         leaf = torch.rand(2, 2, requires_grad=True)
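
For context, here is a minimal standalone sketch (not part of the patch) of the scenario the new test exercises: a full backward hook registered on a module should fire during the first backward pass, and differentiating the resulting graph a second time (double backward) should neither error nor invoke the hook again. Names like `loss_fn` and `hook` are illustrative only.

```python
import torch

loss_fn = torch.nn.MSELoss()

def hook(module, grad_input, grad_output):
    # Expected to run once, during the first backward pass.
    print("full backward hook fired")

loss_fn.register_full_backward_hook(hook)

x = torch.rand(1, requires_grad=True)
y = torch.rand_like(x)

loss = loss_fn(x, y)
# First backward pass: the hook fires here.
(grad_x,) = torch.autograd.grad(loss, x, create_graph=True)
# Double backward: with this fix, no error is raised and the hook does not fire again.
torch.autograd.grad(grad_x, x)
```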