Update on "Prevent module full_backward_hook from erroring in double …
Browse files Browse the repository at this point in the history
…backward"


See https://docs.google.com/document/d/1tFZKYdsSzRBJ7Di7SWt8X8fSg-E3eiUPwomMF10UyhM/edit for more details on the question: should module full_backward_hooks be called every time the gradients with respect to module inputs are computed, or only when the "backward for the module" has been computed?

Fixes #88312


[ghstack-poisoned]
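
For illustration only (not part of the commit), here is a minimal sketch of the double-backward pattern from #88312 that this change is meant to keep from erroring; the module and hook names are assumptions:

import torch
import torch.nn as nn

model = nn.Linear(2, 2)

def hook(module, grad_input, grad_output):
    # Returning None leaves grad_input unchanged.
    pass

model.register_full_backward_hook(hook)

x = torch.randn(1, 2, requires_grad=True)
out = model(x).pow(2).sum()

# First backward keeps the graph so it can be differentiated again.
(gx,) = torch.autograd.grad(out, x, create_graph=True)

# Double backward: with this fix, the full backward hook should no
# longer error when the gradient graph is differentiated a second time.
gx.sum().backward()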
soulitzer committed Nov 16, 2022
2 parents eb2dfbd + 00ad317 commit 9918e2e
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions torch/nn/modules/module.py
@@ -307,8 +307,10 @@ def register_module_full_backward_hook(
         This adds global state to the `nn.module` module
         and it is only intended for debugging/profiling purposes.
-    The hook will be called every time the gradients with respect to module
-    inputs are computed. The hook should have the following signature::
+    The hook will be called every time the gradients with respect to a module
+    are computed. Note that the hook will not execute if the gradients with
+    respect to module outputs are not computed. The hook should have the following
+    signature::
         hook(module, grad_input, grad_output) -> Tensor or None
@@ -1197,8 +1199,10 @@ def register_full_backward_hook(
     ) -> RemovableHandle:
         r"""Registers a backward hook on the module.
-        The hook will be called every time the gradients with respect to module
-        inputs are computed. The hook should have the following signature::
+        The hook will be called every time the gradients with respect to a module
+        are computed. Note that the hook will not execute if the gradients with
+        respect to module outputs are not computed. The hook should have the following
+        signature::
             hook(module, grad_input, grad_output) -> tuple(Tensor) or None
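As a hedged sketch of the documented behavior above (illustrative names, not taken from the diff): the hook fires when gradients with respect to the module's outputs are computed, and stays silent when they are not.

import torch
import torch.nn as nn

lin = nn.Linear(2, 2)
handle = lin.register_full_backward_hook(
    lambda mod, grad_input, grad_output: print("hook fired")
)

x = torch.randn(1, 2, requires_grad=True)
y = lin(x)

# Prints "hook fired": this backward computes gradients with respect
# to lin's output.
y.sum().backward()

# Silent: this loss never touches lin's output, so no gradients with
# respect to the module's outputs are computed and the hook does not run.
(x * 2.0).sum().backward()

# register_full_backward_hook returns a RemovableHandle for deregistration.
handle.remove()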
