
Update on "Prevent module full_backward_hook from erroring in double …
Browse files Browse the repository at this point in the history
…backward"


Fixes #88312


[ghstack-poisoned]
soulitzer committed Nov 2, 2022
1 parent 5952a19 commit 011f8ca
Showing 1 changed file with 19 additions and 0 deletions.
test/test_autograd.py
@@ -6640,6 +6640,25 @@ def forward(self, x):
             gc.collect()
             self.assertIsNone(ref_())
 
+    def test_full_backward_hook_double_backward(self):
+        x = torch.rand(1, requires_grad=True)
+        y = torch.rand_like(x)
+
+        func = torch.nn.MSELoss()
+        counter = [0]
+
+        def hook(module, grad_input, grad_output):
+            counter[0] += 1
+
+        func.register_full_backward_hook(hook)
+
+        f = func(x, y)
+
+        (gradx_f,) = torch.autograd.grad(f, x, create_graph=True)
+        self.assertEqual(counter[0], 1)
+        _ = torch.autograd.grad(gradx_f, x)
+        # We should not error, and counter should not be incremented
+        self.assertEqual(counter[0], 1)
 
     def test_input_buffer_accum(self):
         leaf = torch.rand(2, 2, requires_grad=True)
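For readers skimming the diff, here is a minimal standalone sketch of the behavior this test locks in (variable names are illustrative, not from the commit): a full backward hook registered on a module should fire exactly once during the first backward pass, and differentiating through the resulting gradient (double backward) should neither raise nor fire the hook a second time.

    import torch
    import torch.nn as nn

    mod = nn.MSELoss()
    calls = []
    # Full backward hooks receive (module, grad_input, grad_output).
    mod.register_full_backward_hook(lambda m, gin, gout: calls.append(1))

    x = torch.rand(1, requires_grad=True)
    y = torch.rand_like(x)
    loss = mod(x, y)

    # First backward pass: the hook fires exactly once.
    (gx,) = torch.autograd.grad(loss, x, create_graph=True)
    assert len(calls) == 1

    # Double backward through gx: before this change it errored; it should
    # now succeed without invoking the hook again.
    torch.autograd.grad(gx.sum(), x)
    assert len(calls) == 1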
