Update save_on_cpu and checkpointing to work with functorch wrapped tensors #89166
Conversation
🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89166

Note: Links to docs will display an error until the docs builds have been completed.

❌ 7 Failures. As of commit a02b774, the following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: d20523b6ba4f05d236621450d38e2686d10c24e4
Pull Request resolved: #89166
Update on "Update save_on_cpu and checkpointing to work with functorch wrapped tensors"

This approach saves the inner-most tensor. As we unwrap, we also append to a list of rewrap functions that capture the necessary information to restore the metadata on the original wrappers.

Other prototypes:
- #89159 (alternate approach that saves the outer-most tensor instead and unwraps the necessary number of layers during unpack; the issue is that we cannot tell when we are saving the outer-most tensor)
- #88976 (same approach as this PR, but in cpp, unfinished)

TODO:
- verify in tests that we are actually saving the correct number of tensors
- try a non-zero bdim
- make that assert more precise
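The unwrap/rewrap scheme described above can be sketched with toy wrappers. This is a hedged, pure-Python illustration (the `Wrapper` class, `unwrap_all`, and `rewrap_all` names are hypothetical, not functorch APIs): unwrap down to the inner-most value while collecting one rewrap function per layer, then replay them in reverse order at unpack time.

```python
# Toy sketch of "save the inner-most tensor, collect rewrap fns on the way down".
# Wrapper stands in for a functorch TensorWrapper; `level` is the metadata
# each rewrap function must restore.

class Wrapper:
    def __init__(self, value, level):
        self.value = value
        self.level = level

def unwrap_all(obj):
    # Peel wrapper layers outer-most first, remembering how to rebuild each one.
    rewraps = []
    while isinstance(obj, Wrapper):
        level = obj.level
        rewraps.append(lambda inner, level=level: Wrapper(inner, level))
        obj = obj.value
    return obj, rewraps

def rewrap_all(inner, rewraps):
    # Replay in reverse: inner-most layer is rebuilt first.
    for rewrap in reversed(rewraps):
        inner = rewrap(inner)
    return inner

w = Wrapper(Wrapper(3.0, level=1), level=2)
inner, rewraps = unwrap_all(w)     # inner == 3.0, two rewrap fns collected
restored = rewrap_all(inner, rewraps)
```

The per-layer closures are what let unpack restore metadata without keeping the original wrappers (and their payloads) alive.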
```python
            inner.size(),
            dtype=inner.dtype,
            layout=inner.layout,
            pin_memory=(torch.cuda.is_available() and not inner.is_sparse))
```
`inner` -> `unwrapped`
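The allocation in the hunk above mirrors what `save_on_cpu` does in its pack hook: stage the saved tensor into a CPU buffer (pinned when CUDA is available) and move it back on unpack. A runnable sketch of that pattern using the public `torch.autograd.graph.saved_tensors_hooks` API, assuming no functorch wrapping is involved:

```python
import torch

def pack_to_cpu(tensor):
    # Allocate a CPU buffer matching the saved tensor; only pin it when a CUDA
    # device is present (pinning is pointless otherwise) and the layout allows it.
    packed = torch.empty(
        tensor.size(),
        dtype=tensor.dtype,
        layout=tensor.layout,
        pin_memory=(torch.cuda.is_available() and not tensor.is_sparse))
    packed.copy_(tensor)
    return (tensor.device, packed)

def unpack_from_cpu(packed):
    # Restore the saved tensor to its original device for the backward pass.
    device, tensor = packed
    return tensor.to(device)

x = torch.ones(3, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    y = x.sin()  # sin saves its input for backward; the pack hook offloads it
y.sum().backward()  # x.grad == cos(x)
```

This is the behavior the PR needs to preserve when the tensor reaching the pack hook is a functorch `TensorWrapper` rather than a plain tensor.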
```python
# If saved tensor logic properly detached, we shouldn't have to unwrap and rewrap
# here at all. The rewrapping here is implicit due to lifting.
assert torch._C._functorch.maybe_get_level(ret) > original_level or original_level == -1, (
    "lifting should happen unless we've left transforms entirely")
```
Not sure this is true actually, if we consider the case:

```
grad(
    forward
    grad(
        forward   (tensor packed here)
        backward  (or here)
    backward      (but unpacked here)
```
Lifting shouldn't happen, but I haven't seen the assert triggered.
@zou3519 In case you started to read this: I plan to rewrite the checkpointing logic, which may actually turn out to be a lot simpler than save_on_cpu (which is what I had mostly been thinking about before this). From Alban's comments offline:
For this case, let's say we're doing:

```python
def f(x):
    y = x.sin()
    z = y.sin()
    return z

def g(x):
    return checkpoint(grad_1(f), x)

grad_0(g)(x)
```

grad_1 executes something like:

If we "early stop checkpointing", what happens in grad_0? Do `gy`, `recomputed_y` get recomputed, or are they saved?
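The trade-off in that question (recompute on unpack versus save) can be made concrete with a toy recompute-on-unpack tape. This is a hedged pure-Python sketch with a hypothetical `RecomputeTape` class, not torch autograd: pack stores only an index, and unpack reruns the forward function to regenerate the intermediate.

```python
import math

class RecomputeTape:
    """Toy saved-tensor tape: pack drops the value, unpack recomputes it."""
    def __init__(self, fn, inp):
        self.fn = fn         # forward function to rerun during unpack
        self.inp = inp       # input kept alive so recomputation is possible
        self.recomputes = 0  # counts how often the forward was rerun

    def pack(self, index):
        # Store only which intermediate we'll need, not the value itself.
        return index

    def unpack(self, index):
        # "Checkpointing": rerun the forward and pick out the intermediate.
        self.recomputes += 1
        return self.fn(self.inp)[index]

def f(x):
    y = math.sin(x)
    z = math.sin(y)
    return [y, z]  # intermediates a backward pass might need

tape = RecomputeTape(f, 0.5)
saved = tape.pack(0)               # backward would normally save y here
recomputed_y = tape.unpack(saved)  # regenerated by rerunning f, not loaded
```

The open question above is essentially where `unpack` runs relative to the nested `grad` levels: whether the outer `grad_0` triggers another recomputation of intermediates like `recomputed_y`, or sees values that were already saved.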
Update on "Update save_on_cpu and checkpointing to work with functorch wrapped tensors"

Design doc: https://docs.google.com/document/d/1OX5__xKsZP-natgEnsRrK4gfD0wSN9WS6j4kekfn-IA/edit

This approach saves the inner-most tensor. As we unwrap, we also append to a list of rewrap functions that capture the necessary information to restore the metadata on the original wrappers. This PR tries to do most things in Python, but there are probably some APIs that could exist (or maybe already exist) that could simplify this PR.

- This PR does very weird things to stash autograd metadata:
  - The rewrap function needs to capture autograd metadata so that the higher-order graphs don't get disconnected. We reuse TensorWrapper to do this, but in a way that is careful not to save the original TensorWrapper's data:
    - During packing, we do a clone on the original TensorWrapper, then replace the value_ with an empty tensor, so this new dataless TensorWrapper gets captured by the rewrap fn instead
    - During unpacking, when we run the rewrap fn, we just replace the value_ with the value we desire (either the recomputed value or the value that was previously offloaded)
    - The API exposed to replace value_ is set_data!
- There doesn't seem to be a reliable way to uniquely identify a tensor, since id() gets reused; using data_ptr helps, but it is also not enough sometimes. In this PR, I'm also using the first element of the Tensor to get a test to pass.

Unanswered questions:
- Why did we need to enable grad mode while packing (and where was it disabled)?

Other prototypes:
- #89159 (alternate approach that saves the outer-most tensor instead and unwraps the necessary number of layers during unpack; the issue is that we cannot tell when we are saving the outer-most tensor)
- #88976 (same approach as this PR, but in cpp, unfinished)

TODO:
- verify in tests that we are actually saving the correct number of tensors
- try a non-zero bdim
- make that assert more precise
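The "dataless TensorWrapper" trick described above can be modeled with a small pure-Python sketch. All names here are hypothetical stand-ins: `Wrapper` plays the role of TensorWrapper and its `set_data` method the role of the real `set_data` API; no functorch internals are used.

```python
# Toy model of the stash-metadata-without-data scheme: packing clones the
# wrapper's metadata but drops its payload; unpacking fills the payload back in.

class Wrapper:
    def __init__(self, value, level):
        self.value = value  # stands in for value_
        self.level = level  # transform/autograd metadata we must preserve

    def set_data(self, value):
        self.value = value

def pack(wrapper):
    # Clone the wrapper's metadata but drop its data, so the closure below
    # captures a "dataless" wrapper instead of keeping the payload alive.
    stash = Wrapper(value=None, level=wrapper.level)

    def rewrap(new_value):
        # At unpack time, fill the stash with the recomputed or offloaded value.
        stash.set_data(new_value)
        return stash

    return wrapper.value, rewrap

w = Wrapper(value=42.0, level=1)
inner, rewrap = pack(w)   # only the inner payload and a dataless stash survive
restored = rewrap(inner)  # e.g. the value just reloaded from CPU
```

The key property the sketch illustrates is that the closure never holds the original payload, only the metadata needed to reconstitute the wrapper.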
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as stale.