diff --git a/test/test_autograd.py b/test/test_autograd.py
index 7f548ff4630d846..373a57c4a885a4d 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -6488,6 +6488,25 @@ def forward(self, x):
         gc.collect()
         self.assertIsNone(ref_())
 
+    def test_full_backward_hook_double_backward(self):
+        x = torch.rand(1, requires_grad=True)
+        y = torch.rand_like(x)
+
+        func = torch.nn.MSELoss()
+        counter = [0]
+
+        def hook(module, grad_input, grad_output):
+            counter[0] += 1
+
+        func.register_full_backward_hook(hook)
+
+        f = func(x, y)
+
+        (gradx_f,) = torch.autograd.grad(f, x, create_graph=True)
+        self.assertEqual(counter[0], 1)
+        _ = torch.autograd.grad(gradx_f, x)
+        # We should not error, and counter should not be incremented
+        self.assertEqual(counter[0], 1)
 
     def test_input_buffer_accum(self):
         leaf = torch.rand(2, 2, requires_grad=True)
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index b150c1f1eb2a6d6..1508b6d15108538 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -190,8 +190,10 @@ def register_module_full_backward_hook(
         This adds global state to the `nn.module` module
         and it is only intended for debugging/profiling purposes.
 
-    The hook will be called every time the gradients with respect to module
-    inputs are computed. The hook should have the following signature::
+    The hook will be called every time the gradients with respect to a module
+    are computed, i.e. the hook will execute if and only if the gradients with
+    respect to module outputs are computed. The hook should have the following
+    signature::
 
         hook(module, grad_input, grad_output) -> Tensor or None
 
@@ -1015,8 +1017,10 @@ def register_full_backward_hook(
     ) -> RemovableHandle:
         r"""Registers a backward hook on the module.
 
-        The hook will be called every time the gradients with respect to module
-        inputs are computed. The hook should have the following signature::
+        The hook will be called every time the gradients with respect to a module
+        are computed, i.e. the hook will execute if and only if the gradients with
+        respect to module outputs are computed. The hook should have the following
+        signature::
 
             hook(module, grad_input, grad_output) -> tuple(Tensor) or None
 
diff --git a/torch/utils/hooks.py b/torch/utils/hooks.py
index 2179e7a5a7d6a7a..14f369ec92385cd 100644
--- a/torch/utils/hooks.py
+++ b/torch/utils/hooks.py
@@ -98,13 +98,10 @@ def _unpack_none(self, indices, values):
     def _set_user_hook(self, grad_fn):
         def hook(grad_input, _):
             if self.grad_outputs is None:
-                raise RuntimeError("Module backward hook for grad_input is called before "
-                                   "the grad_output one. This happens because the gradient "
-                                   "in your nn.Module flows to the Module's input without "
-                                   "passing through the Module's output. Make sure that the "
-                                   "output depends on the input and that the loss is computed "
-                                   "based on the output.")
-
+                # This happens because the gradient in your nn.Module flows to
+                # the Module's input without passing through the Module's
+                # output, e.g. when you're doing double backward.
+                return
             res = self._pack_with_none(self.input_tensors_index, grad_input, self.n_inputs)
 
             for hook in self.user_hooks:
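
For reference, a minimal standalone sketch of the behavior the updated docstring describes and the new test exercises (not part of the patch; `torch.nn.MSELoss` is used here to match the test, but any module with `register_full_backward_hook` would do): the hook fires when gradients with respect to the module's outputs are computed, and is silently skipped, rather than raising, during the second pass of a double backward.

```python
import torch

loss_fn = torch.nn.MSELoss()
calls = [0]

def hook(module, grad_input, grad_output):
    # Runs only when gradients w.r.t. the module's outputs are computed.
    calls[0] += 1

loss_fn.register_full_backward_hook(hook)

x = torch.rand(1, requires_grad=True)
y = torch.rand_like(x)
loss = loss_fn(x, y)

(grad_x,) = torch.autograd.grad(loss, x, create_graph=True)
print(calls[0])  # 1: the first backward computes gradients w.r.t. the module output

# Double backward: the gradient flows to the module's input without passing
# through the module's output, so the hook is skipped instead of raising.
torch.autograd.grad(grad_x, x)
print(calls[0])  # still 1
```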