Prevent module full_backward_hook from erroring in double backward (#88357)

Also clarifies documentation to say "execute if and only if gradients wrt outputs are computed" (previously, "execute every time gradients wrt inputs are computed")

See https://docs.google.com/document/d/1tFZKYdsSzRBJ7Di7SWt8X8fSg-E3eiUPwomMF10UyhM/edit for more details on the question: should module full_backward_hooks be called every time the gradients wrt module inputs are computed, or only when the "backward for the module" has been computed?
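
For illustration, a minimal sketch of the resulting behavior (it mirrors the test added below; variable names are illustrative): the hook fires when gradients with respect to the module's outputs are computed, but is skipped, without erroring, during double backward, where only gradients with respect to its inputs flow.

import torch

loss_fn = torch.nn.MSELoss()
calls = []
# Full backward hooks receive (module, grad_input, grad_output); list.append
# returns None, so the computed gradients are left unchanged.
loss_fn.register_full_backward_hook(lambda m, gi, go: calls.append(1))

x = torch.rand(1, requires_grad=True)
y = torch.rand_like(x)
loss = loss_fn(x, y)

(grad_x,) = torch.autograd.grad(loss, x, create_graph=True)  # hook fires once
torch.autograd.grad(grad_x, x)  # double backward: no error, hook not called again
assert len(calls) == 1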

Fixes #88312

Pull Request resolved: #88357
Approved by: https://github.com/albanD
soulitzer authored and pytorchmergebot committed Nov 16, 2022
1 parent 0581331 commit 6b521bb
Showing 3 changed files with 31 additions and 11 deletions.
19 changes: 19 additions & 0 deletions test/test_autograd.py
@@ -6638,6 +6638,25 @@ def forward(self, x):
         gc.collect()
         self.assertIsNone(ref_())

+    def test_full_backward_hook_double_backward(self):
+        x = torch.rand(1, requires_grad=True)
+        y = torch.rand_like(x)
+
+        func = torch.nn.MSELoss()
+        counter = [0]
+
+        def hook(module, grad_input, grad_output):
+            counter[0] += 1
+
+        func.register_full_backward_hook(hook)
+
+        f = func(x, y)
+
+        (gradx_f,) = torch.autograd.grad(f, x, create_graph=True)
+        self.assertEqual(counter[0], 1)
+        _ = torch.autograd.grad(gradx_f, x)
+        # We should not error, and counter should not be incremented
+        self.assertEqual(counter[0], 1)

     def test_input_buffer_accum(self):
         leaf = torch.rand(2, 2, requires_grad=True)
12 changes: 8 additions & 4 deletions torch/nn/modules/module.py
@@ -307,8 +307,10 @@ def register_module_full_backward_hook(
 This adds global state to the `nn.module` module
 and it is only intended for debugging/profiling purposes.
-The hook will be called every time the gradients with respect to module
-inputs are computed. The hook should have the following signature::
+The hook will be called every time the gradients with respect to a module
+are computed, i.e. the hook will execute if and only if the gradients with
+respect to module outputs are computed. The hook should have the following
+signature::
 hook(module, grad_input, grad_output) -> Tensor or None
@@ -1197,8 +1199,10 @@ def register_full_backward_hook(
 ) -> RemovableHandle:
 r"""Registers a backward hook on the module.
-The hook will be called every time the gradients with respect to module
-inputs are computed. The hook should have the following signature::
+The hook will be called every time the gradients with respect to a module
+are computed, i.e. the hook will execute if and only if the gradients with
+respect to module outputs are computed. The hook should have the following
+signature::
 hook(module, grad_input, grad_output) -> tuple(Tensor) or None
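
For reference, a hedged usage sketch of the documented hook signature (the Linear module, shapes, and gradient scaling here are illustrative and not part of this commit):

import torch
import torch.nn as nn

lin = nn.Linear(3, 2)

def scale_grad_input(module, grad_input, grad_output):
    # grad_input and grad_output are tuples of Tensors (entries may be None).
    # Returning a tuple of the same length replaces grad_input; returning None
    # leaves the computed gradients unchanged.
    return tuple(g * 0.5 if g is not None else None for g in grad_input)

handle = lin.register_full_backward_hook(scale_grad_input)

inp = torch.randn(4, 3, requires_grad=True)
lin(inp).sum().backward()  # hook runs: gradients w.r.t. the module's outputs are computed
handle.remove()

With the change above, the same hook simply does not run for a backward pass in which no gradients with respect to the module's outputs are computed.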
11 changes: 4 additions & 7 deletions torch/utils/hooks.py
@@ -99,13 +99,10 @@ def _unpack_none(self, indices, values):
     def _set_user_hook(self, grad_fn):
         def hook(grad_input, _):
             if self.grad_outputs is None:
-                raise RuntimeError("Module backward hook for grad_input is called before "
-                                   "the grad_output one. This happens because the gradient "
-                                   "in your nn.Module flows to the Module's input without "
-                                   "passing through the Module's output. Make sure that the "
-                                   "output depends on the input and that the loss is computed "
-                                   "based on the output.")
-
+                # This happens because the gradient in your nn.Module flows to
+                # the Module's input without passing through the Module's
+                # output, e.g. when you're doing double backward.
+                return
             res = self._pack_with_none(self.input_tensors_index, grad_input, self.n_inputs)

             for hook in self.user_hooks:
