diff --git a/test/test_autograd.py b/test/test_autograd.py
index 53ddb8f67dccf7f..6e26f67f6dc34da 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -6795,6 +6795,20 @@ def inplace_double(x):
         # not leaf, not output
         test(lambda: (1 + torch.randn(5, requires_grad=True)), False)
 
+    def test_saved_variable_saved_original_inplace_detach(self):
+        # Detaching a tensor that is saved as input raises
+        a = torch.tensor(1., requires_grad=True).clone()
+        b = a.sin()
+        a.detach_()
+        with self.assertRaisesRegex(RuntimeError, "Trying to use a saved tensor that has been detached"):
+            b.backward()
+
+        # Detaching a tensor that is saved as output is OK
+        a = torch.tensor(1., requires_grad=True).clone()
+        b = a.exp()
+        a.detach_()
+        b.backward()
+
     def test_saved_variable_packing_unpacking_did_not_save_original_with_hooks(self):
         # Tests that packing/unpacking a SavedVariable works correctly with user-defined hooks
         # The saved_original / did_not_save_original distinction corresponds to the `save_original`
diff --git a/torch/csrc/autograd/saved_variable.cpp b/torch/csrc/autograd/saved_variable.cpp
index a2e0f05b63943ba..d438205e8947fc8 100644
--- a/torch/csrc/autograd/saved_variable.cpp
+++ b/torch/csrc/autograd/saved_variable.cpp
@@ -144,7 +144,16 @@ Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
       : grad_fn_;
 
   if (!is_leaf_ && !grad_fn) {
-    TORCH_INTERNAL_ASSERT(saved_for, "No grad_fn for non-leaf saved tensor");
+    // This issue was introduced when we added logic to save the original
+    // because now we rely on data_.grad_fn(), but can be unreliable if the
+    // autograd_meta of that saved tensor is cleared with an in-place detach.
+    // As a simple fix, we choose to disallow that behavior here even though
+    // it makes behavior inconsistent depending on whether you are saving
+    // input or output.
+    TORCH_CHECK(
+        saved_for,
+        "Trying to use a saved tensor that has been detached in-place, i.e. with .detach_(). "
+        "This is not supported, please use out-of-place `.detach()` instead");
     grad_fn = std::move(saved_for);
   }
 