Update on "Prevent module full_backward_hook from erroring in double …
Browse files Browse the repository at this point in the history
…backward"


See https://docs.google.com/document/d/1tFZKYdsSzRBJ7Di7SWt8X8fSg-E3eiUPwomMF10UyhM/edit for more details on the question: 'should module full_backward_hooks be called every time the gradients wrt module inputs are computed, or should module full_backward_hooks only be called when the "backward for the module" has been computed?'

Fixes #88312


[ghstack-poisoned]
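To make the question above concrete, here is a minimal sketch of the double-backward pattern at issue. The module, tensors, and hook below are illustrative stand-ins, not taken from the commit; the actual repro lives in #88312.

import torch
import torch.nn as nn

# Illustrative module and no-op hook; the real repro is in issue #88312.
linear = nn.Linear(2, 2)
linear.register_full_backward_hook(lambda mod, grad_in, grad_out: None)

x = torch.randn(1, 2, requires_grad=True)
out = linear(x)

# First backward pass: compute gradients wrt the module input while
# keeping the graph so we can differentiate through them again.
(grad_x,) = torch.autograd.grad(out.sum(), x, create_graph=True)

# Double backward: differentiate a function of grad_x wrt the parameters.
# This is the kind of call that could previously trip the full backward hook.
grad_x.sum().backward()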
soulitzer committed Nov 16, 2022
2 parents 9918e2e + 19e11f3 commit 84c05bc
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torch/nn/modules/module.py
@@ -308,8 +308,8 @@ def register_module_full_backward_hook(
     and it is only intended for debugging/profiling purposes.
 
     The hook will be called every time the gradients with respect to a module
-    are computed. Note that the hook will not execute if the gradients with
-    respect to module outputs are not computed. The hook should have the following
+    are computed, i.e. the hook will execute if and only if the gradients with
+    respect to module outputs are computed. The hook should have the following
     signature::
 
         hook(module, grad_input, grad_output) -> Tensor or None
@@ -1200,8 +1200,8 @@ def register_full_backward_hook(
         r"""Registers a backward hook on the module.
 
         The hook will be called every time the gradients with respect to a module
-        are computed. Note that the hook will not execute if the gradients with
-        respect to module outputs are not computed. The hook should have the following
+        are computed, i.e. the hook will execute if and only if the gradients with
+        respect to module outputs are computed. The hook should have the following
         signature::
 
             hook(module, grad_input, grad_output) -> tuple(Tensor) or None
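A minimal sketch of the clarified docstring semantics, assuming a throwaway nn.Linear module: the hook runs exactly when gradients with respect to the module's outputs are computed.

import torch
import torch.nn as nn

linear = nn.Linear(2, 2)

def hook(module, grad_input, grad_output):
    # Fires iff gradients wrt the module's outputs are computed.
    print("full backward hook fired")

linear.register_full_backward_hook(hook)

x = torch.randn(1, 2, requires_grad=True)
out = linear(x)

# This backward computes gradients wrt `out`, so the hook fires.
out.sum().backward()

# This backward never touches the module's output, so the hook stays silent.
y = (x * 2).sum()
y.backward()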
