Update on "Prevent module full_backward_hook from erroring in double backward"


Fixes #88312


[ghstack-poisoned]
soulitzer committed Nov 11, 2022
2 parents 59fe05b + dffe0fd commit 33b0406
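
For context, the title refers to using a module full backward hook together with a second differentiation pass. Below is a minimal sketch of that pattern, not taken from this commit; the module, hook, and tensor names are illustrative only.

import torch
import torch.nn as nn

# A full backward hook registered on a module, followed by double backward
# via create_graph=True. Names here are illustrative, not from the commit.
model = nn.Linear(3, 1)

def hook(module, grad_input, grad_output):
    # Inspect gradients flowing through the module; returning None leaves them unchanged.
    return None

model.register_full_backward_hook(hook)

x = torch.randn(2, 3, requires_grad=True)
out = model(x).sum()

# First backward with create_graph=True so the gradient itself is differentiable.
grad_x, = torch.autograd.grad(out, x, create_graph=True)

# Second (double) backward through the computed gradient.
grad_x.sum().backward()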
Showing 2 changed files with 24 additions and 1 deletion.
14 changes: 14 additions & 0 deletions test/test_autograd.py
@@ -6795,6 +6795,20 @@ def inplace_double(x):
         # not leaf, not output
         test(lambda: (1 + torch.randn(5, requires_grad=True)), False)
 
+    def test_saved_variable_saved_original_inplace_detach(self):
+        # Detaching a tensor that is saved as input raises
+        a = torch.tensor(1., requires_grad=True).clone()
+        b = a.sin()
+        a.detach_()
+        with self.assertRaisesRegex(RuntimeError, "Trying to use a saved tensor that has been detached"):
+            b.backward()
+
+        # Detaching a tensor that is saved as output is OK
+        a = torch.tensor(1., requires_grad=True).clone()
+        b = a.exp()
+        a.detach_()
+        b.backward()
+
     def test_saved_variable_packing_unpacking_did_not_save_original_with_hooks(self):
         # Tests that packing/unpacking a SavedVariable works correctly with user-defined hooks
         # The saved_original / did_not_save_original distinction corresponds to the `save_original`
11 changes: 10 additions & 1 deletion torch/csrc/autograd/saved_variable.cpp
@@ -144,7 +144,16 @@ Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
                                      : grad_fn_;
 
   if (!is_leaf_ && !grad_fn) {
-    TORCH_INTERNAL_ASSERT(saved_for, "No grad_fn for non-leaf saved tensor");
+    // This issue was introduced when we added logic to save the original
+    // because now we rely on data_.grad_fn(), which can be unreliable if the
+    // autograd_meta of that saved tensor is cleared with an in-place detach.
+    // As a simple fix, we choose to disallow that behavior here even though
+    // it makes behavior inconsistent depending on whether you are saving
+    // input or output.
+    TORCH_CHECK(
+        saved_for,
+        "Trying to use a saved tensor that has been detached in-place, i.e. with .detach_(). "
+        "This is not supported, please use out-of-place `.detach()` instead");
     grad_fn = std::move(saved_for);
   }

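
From the Python side, the change above turns an internal assert into a user-facing error when a tensor that was saved as an op's input is later detached in place. A rough sketch of the two cases, mirroring the new test; the printed message follows the TORCH_CHECK text above.

import torch

# sin's backward needs cos(input), so autograd saves the *input* tensor a;
# detaching a in place clears its autograd metadata, and unpacking it later
# now raises the error introduced above instead of an internal assert.
a = torch.tensor(1., requires_grad=True).clone()
b = a.sin()
a.detach_()
try:
    b.backward()
except RuntimeError as e:
    print(e)  # Trying to use a saved tensor that has been detached in-place ...

# exp's backward reuses its *output*, so the saved tensor is b, not a;
# detaching a in place is therefore harmless here.
a = torch.tensor(1., requires_grad=True).clone()
b = a.exp()
a.detach_()
b.backward()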
