Prevent module full_backward_hook from erroring in double backward (#88357) (#89928)

Also clarifies documentation to say "execute if and only if gradients wrt outputs are computed" (previously, "execute every time gradients wrt inputs are computed")

See https://docs.google.com/document/d/1tFZKYdsSzRBJ7Di7SWt8X8fSg-E3eiUPwomMF10UyhM/edit for more details regarding the question: 'should module full_backward_hooks be called every time the gradients wrt module inputs are computed, or only when the "backward for the module" has been computed?'

Fixes #88312
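
For illustration, a minimal sketch of the resulting behavior (it mirrors the test added in this commit; the variable names here are illustrative, not part of the change): the full backward hook fires once when gradients with respect to the module's outputs are computed, and the subsequent double backward through the input gradient neither raises nor fires the hook again.

import torch

func = torch.nn.MSELoss()
calls = []
# Full backward hook: runs iff gradients w.r.t. func's outputs are computed.
func.register_full_backward_hook(lambda module, grad_input, grad_output: calls.append(1))

x = torch.rand(1, requires_grad=True)
y = torch.rand_like(x)
loss = func(x, y)

(grad_x,) = torch.autograd.grad(loss, x, create_graph=True)  # hook fires once here
torch.autograd.grad(grad_x, x)  # double backward: previously raised, now the hook is simply skipped
print(len(calls))  # 1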

Pull Request resolved: #88357
Approved by: https://github.com/albanD

Co-authored-by: soulitzer <soulitzer@gmail.com>
weiwangmeta and soulitzer committed Dec 6, 2022
1 parent a81f9b3 commit 9c90070
Showing 3 changed files with 31 additions and 11 deletions.
19 changes: 19 additions & 0 deletions test/test_autograd.py
@@ -6488,6 +6488,25 @@ def forward(self, x):
         gc.collect()
         self.assertIsNone(ref_())
 
+    def test_full_backward_hook_double_backward(self):
+        x = torch.rand(1, requires_grad=True)
+        y = torch.rand_like(x)
+
+        func = torch.nn.MSELoss()
+        counter = [0]
+
+        def hook(module, grad_input, grad_output):
+            counter[0] += 1
+
+        func.register_full_backward_hook(hook)
+
+        f = func(x, y)
+
+        (gradx_f,) = torch.autograd.grad(f, x, create_graph=True)
+        self.assertEqual(counter[0], 1)
+        _ = torch.autograd.grad(gradx_f, x)
+        # We should not error, and counter should not be incremented
+        self.assertEqual(counter[0], 1)
 
     def test_input_buffer_accum(self):
         leaf = torch.rand(2, 2, requires_grad=True)
12 changes: 8 additions & 4 deletions torch/nn/modules/module.py
@@ -190,8 +190,10 @@ def register_module_full_backward_hook(
         This adds global state to the `nn.module` module
         and it is only intended for debugging/profiling purposes.
 
-    The hook will be called every time the gradients with respect to module
-    inputs are computed. The hook should have the following signature::
+    The hook will be called every time the gradients with respect to a module
+    are computed, i.e. the hook will execute if and only if the gradients with
+    respect to module outputs are computed. The hook should have the following
+    signature::
 
         hook(module, grad_input, grad_output) -> Tensor or None
@@ -1015,8 +1017,10 @@ def register_full_backward_hook(
     ) -> RemovableHandle:
         r"""Registers a backward hook on the module.
 
-        The hook will be called every time the gradients with respect to module
-        inputs are computed. The hook should have the following signature::
+        The hook will be called every time the gradients with respect to a module
+        are computed, i.e. the hook will execute if and only if the gradients with
+        respect to module outputs are computed. The hook should have the following
+        signature::
 
             hook(module, grad_input, grad_output) -> tuple(Tensor) or None
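
For context, a brief usage sketch of the global-hook API whose docstring is updated in the first hunk above (illustrative only, not part of this diff; the hook and variable names are made up): register_module_full_backward_hook installs a single hook that runs for every nn.Module whenever gradients with respect to that module's outputs are computed, and it returns a RemovableHandle so the global state can be undone.

import torch
from torch.nn.modules.module import register_module_full_backward_hook

def log_grad_output(module, grad_input, grad_output):
    # Runs iff gradients w.r.t. this module's outputs are computed; returning
    # None leaves grad_input untouched.
    print(type(module).__name__, [tuple(g.shape) for g in grad_output if g is not None])

handle = register_module_full_backward_hook(log_grad_output)

lin = torch.nn.Linear(3, 2)
x = torch.randn(4, 3, requires_grad=True)
lin(x).sum().backward()  # prints: Linear [(4, 2)]

handle.remove()  # the hook is process-wide state; remove it when done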
11 changes: 4 additions & 7 deletions torch/utils/hooks.py
@@ -98,13 +98,10 @@ def _unpack_none(self, indices, values):
     def _set_user_hook(self, grad_fn):
         def hook(grad_input, _):
             if self.grad_outputs is None:
-                raise RuntimeError("Module backward hook for grad_input is called before "
-                                   "the grad_output one. This happens because the gradient "
-                                   "in your nn.Module flows to the Module's input without "
-                                   "passing through the Module's output. Make sure that the "
-                                   "output depends on the input and that the loss is computed "
-                                   "based on the output.")
-
+                # This happens because the gradient in your nn.Module flows to
+                # the Module's input without passing through the Module's
+                # output, e.g. when you're doing double backward.
+                return
             res = self._pack_with_none(self.input_tensors_index, grad_input, self.n_inputs)
 
             for hook in self.user_hooks:
