
Update save_on_cpu and checkpointing to work with functorch wrapped tensors #89166

Closed
wants to merge 4 commits

Conversation

soulitzer
Contributor

@soulitzer soulitzer commented Nov 16, 2022

Stack from ghstack (oldest at bottom):

Design doc: https://docs.google.com/document/d/1OX5__xKsZP-natgEnsRrK4gfD0wSN9WS6j4kekfn-IA/edit

This approach saves the inner-most tensor. As we unwrap, we also append to a list of rewrap functions that capture the necessary information to restore the metadata on the original wrappers. This PR tries to do most things in Python, but there are probably some APIs that could exist (or maybe already exist) that could simplify this PR.

  • This PR does very weird things to stash autograd metadata:
    • The rewrap function needs to capture autograd metadata so that the higher-order graphs don't get disconnected. We reuse TensorWrapper to do this, but in a way that is careful not to save the original TensorWrapper's data.
    • During packing, we clone the original TensorWrapper and replace its value_ with an empty tensor, so that the rewrap fn captures this new dataless TensorWrapper instead.
    • During unpacking, when we run the rewrap fn, we simply replace value_ with the value we want (either the recomputed value or the value that was previously offloaded).
    • The API exposed to replace value_ is set_data!
  • There doesn't seem to be a reliable way to uniquely identify a tensor: id() gets reused, and while data_ptr helps, it is sometimes not enough either. In this PR, I'm also using the first element of the tensor to get a test to pass.
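The id()-reuse problem mentioned above can be shown with a plain-Python sketch (illustrative only, not the PR's code): CPython may hand a freed object's id() to a new object, so a dedup cache keyed by id() must also hold a strong reference to keep its keys valid.

```python
# Illustrative dedup cache: keeping a strong reference to the keyed object
# prevents CPython from reusing its id() while the cache entry is alive.
class DedupCache:
    def __init__(self):
        self._entries = {}  # id(obj) -> (obj, payload); obj is kept alive on purpose

    def add(self, obj, payload):
        key = id(obj)
        if key not in self._entries:
            self._entries[key] = (obj, payload)
        return key

    def lookup(self, obj):
        entry = self._entries.get(id(obj))
        return entry[1] if entry is not None else None

cache = DedupCache()
t = [1.0, 2.0, 3.0]  # stand-in for a tensor
cache.add(t, "saved-once")
assert cache.lookup(t) == "saved-once"  # same object dedupes to one entry
```

This sidesteps id() reuse only for as long as the cache lives; a real saved-tensor dedup scheme also has to cope with views sharing a data_ptr, which is why the PR falls back to extra heuristics.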

Unanswered questions:

  • Why did we need to enable grad mode while packing (and where was it disabled)?

Other prototypes:

  • #89159 (alternate approach that saves the outer-most tensor instead and unwraps the necessary number of layers during unpack; the issue is that we cannot tell when we are saving the outer-most tensor)
  • #88976 (same approach as this PR, but in cpp, unfinished)

TODO:

  • verify in tests that we are actually saving the correct number of tensors
  • try a non-zero bdim
  • make that assert more precise
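The unwrap-and-rewrap scheme described above can be sketched in plain Python. The Wrapper class and function names here are illustrative stand-ins for functorch's TensorWrapper internals, not the PR's actual code:

```python
# Minimal stand-in for a functorch-style wrapper holding a value and a level.
class Wrapper:
    def __init__(self, value, level):
        self.value = value
        self.level = level

def pack(t):
    """Unwrap to the inner-most value, recording how to rebuild each layer."""
    rewrap_fns = []
    while isinstance(t, Wrapper):
        level = t.level  # metadata needed to restore this layer
        rewrap_fns.append(lambda inner, lvl=level: Wrapper(inner, lvl))
        t = t.value
    return t, rewrap_fns

def unpack(packed):
    """Rebuild the wrappers inner-most first, restoring the original nesting."""
    inner, rewrap_fns = packed
    for fn in reversed(rewrap_fns):
        inner = fn(inner)
    return inner

nested = Wrapper(Wrapper(3.0, level=1), level=2)
restored = unpack(pack(nested))
assert (restored.level, restored.value.level, restored.value.value) == (2, 1, 3.0)
```

In the real PR each rewrap fn also carries autograd metadata (via the dataless TensorWrapper clone described above), not just the level.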

@pytorch-bot

pytorch-bot bot commented Nov 16, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89166

Note: Links to docs will display an error until the docs builds have been completed.

❌ 7 Failures

As of commit a02b774:

The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

soulitzer added a commit that referenced this pull request Nov 16, 2022
…tensors

ghstack-source-id: d20523b6ba4f05d236621450d38e2686d10c24e4
Pull Request resolved: #89166
soulitzer added a commit that referenced this pull request Nov 17, 2022
…tensors

ghstack-source-id: 72ef77660f0f6546efe0093cc52a2e00d50e598f
Pull Request resolved: #89166
@soulitzer soulitzer added the topic: not user facing topic category label Nov 17, 2022
@soulitzer soulitzer marked this pull request as draft November 17, 2022 06:41
soulitzer added a commit that referenced this pull request Nov 17, 2022
…tensors

ghstack-source-id: b6fa3b4d555ac0c2d6e5a285532927403c4a0824
Pull Request resolved: #89166
@soulitzer soulitzer changed the title Update save_on_cpu to dedupe tensors and work with functorch wrapped tensors Update save_on_cpu and checkpointing to work with functorch wrapped tensors Nov 17, 2022
```python
inner.size(),
dtype=inner.dtype,
layout=inner.layout,
pin_memory=(torch.cuda.is_available() and not inner.is_sparse))
```
soulitzer commented:

inner -> unwrapped
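The snippet under review allocates the CPU-side buffer for offloading. A hedged sketch of the surrounding save_on_cpu-style pack/unpack logic (the function and variable names here are assumptions, not the PR's exact code) might look like:

```python
import torch

def pack_to_cpu(unwrapped):
    # Allocate a CPU buffer, pinning memory only when CUDA is available
    # (pinned memory enables faster, asynchronous device-to-host copies).
    cpu_copy = torch.empty(
        unwrapped.size(),
        dtype=unwrapped.dtype,
        layout=unwrapped.layout,
        pin_memory=(torch.cuda.is_available() and not unwrapped.is_sparse))
    cpu_copy.copy_(unwrapped)
    return (unwrapped.device, cpu_copy)

def unpack_from_cpu(packed):
    # Move the saved tensor back to its original device for backward.
    device, cpu_copy = packed
    return cpu_copy.to(device, non_blocking=True)

x = torch.randn(4)
device, stored = pack_to_cpu(x)
assert stored.device.type == "cpu"
assert torch.equal(unpack_from_cpu((device, stored)), x)
```

This mirrors the structure of torch.autograd.graph.save_on_cpu; the PR's version additionally has to unwrap functorch layers before reaching the plain tensor passed here.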

```python
# If saved tensor logic properly detached, we shouldn't have to unwrap and rewrap
# here at all. The rewrapping here is implicit due to lifting.
assert torch._C._functorch.maybe_get_level(ret) > original_level or original_level == -1, (
    "lifting should happen unless we've left transforms entirely")
```

@soulitzer soulitzer Nov 17, 2022


Not sure this is actually true, if we consider the case:

```
grad(
    forward
        grad(
            forward   (tensor packed here)
            backward  (or here)
        )
    backward          (but unpacked here)
)
```

Lifting shouldn't happen here, but I haven't seen the assert triggered.

@soulitzer

@zou3519 In case you've started to read this: I actually plan to rewrite the checkpointing logic, which may turn out to be a lot simpler than save_on_cpu (which is what I had mostly been thinking about before this).

From Alban's comments offline:

  • The checkpoint implementation may be able to be simplified (unlike save_on_cpu, we don't really need to worry about recomputing both grad(t) and t); this should get rid of most of the unwrapping + wrapping logic.
  • We need to consider the case where we checkpoint a function that calls backward inside (we just need to early-stop when the size of the storage reaches the size of what we expect to be saved), or else it would loop infinitely.
  • He agrees that re-enabling functorch modes if backward is done later should be done, and should be straightforward.
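The early-stop idea in the second bullet can be sketched in plain Python (no torch; `save` stands in for the pack hook, and all names are illustrative):

```python
# Stop recomputation as soon as it has produced as many saved values as the
# original forward stored; running further (e.g. into a nested backward that
# itself triggers recomputation) is what would otherwise loop forever.
class _StopRecomputation(Exception):
    """Raised to abort recomputation once everything needed is restored."""

def recompute_with_early_stop(fn, expected_count):
    storage = []
    def save(value):
        storage.append(value)
        if len(storage) >= expected_count:
            raise _StopRecomputation
    try:
        fn(save)
    except _StopRecomputation:
        pass  # expected: we have everything backward needs
    return storage

def forward(save):
    save("y")      # y = x.sin()
    save("z")      # z = y.sin()
    save("never")  # past this point: e.g. a nested backward call

assert recompute_with_early_stop(forward, 2) == ["y", "z"]
```

Raising an exception from inside the pack hook is one way to unwind out of an arbitrary user function at exactly the right point; the trade-off is that the recompute function must not swallow the sentinel exception itself.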

@zou3519 zou3519 self-requested a review November 17, 2022 20:14
@zou3519
Contributor

zou3519 commented Nov 18, 2022

need to consider the case where we checkpoint a function that calls backward inside (just need to early-stop when the size of the storage reaches the size of what we expect to be saved) or else it would loop infinitely

For this case, let's say we're doing

```python
def f(x):
    y = x.sin()
    z = y.sin()
    return z

def g(x):
    return checkpoint(grad_1(f), x)

grad_0(g)(x)
```

grad_1 executes something like:

```python
y = x.sin()
z = y.sin()
# ---
recomputed_y = x.sin()
gy = recomputed_y.cos()
gx = gy * x.cos()
```

If we "early stop checkpointing", what happens in grad_0? Do gy, recomputed_y get recomputed, or are they saved?
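As a quick sanity check on the hand-written trace above (a pure-Python stand-in using `math`, not functorch), the trace computes d/dx sin(sin(x)) = cos(sin(x)) * cos(x), which we can verify against a finite difference:

```python
import math

def f(x):
    return math.sin(math.sin(x))

def grad_1_f(x):
    # Mirrors the trace above: recompute y, then apply the chain rule.
    recomputed_y = math.sin(x)
    gy = math.cos(recomputed_y)  # dz/dy = cos(y)
    gx = gy * math.cos(x)        # dy/dx = cos(x)
    return gx

x = 0.7
finite_diff = (f(x + 1e-6) - f(x - 1e-6)) / 2e-6
assert abs(grad_1_f(x) - finite_diff) < 1e-6
```

The open question in the comment is orthogonal to this arithmetic: it is about whether grad_0 treats gy and recomputed_y as saved or recomputed once early stopping is in play.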

soulitzer added a commit that referenced this pull request Nov 28, 2022
…tensors

ghstack-source-id: f863257dc87892fd4ab9cc2351a72b87a46a1231
Pull Request resolved: #89166
@github-actions

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jan 27, 2023
@github-actions github-actions bot closed this Feb 26, 2023
@facebook-github-bot facebook-github-bot deleted the gh/soulitzer/151/head branch June 8, 2023 18:47