
Update save_on_cpu and checkpointing to work with functorch wrapped tensors #89166

Closed
wants to merge 4 commits

Commits on Nov 16, 2022

  1. Update save_on_cpu to dedupe tensors and work with functorch wrapped tensors
    
    [ghstack-poisoned]
    soulitzer committed Nov 16, 2022 · 4a23049

Commits on Nov 17, 2022

  1. Update on "Update save_on_cpu to dedupe tensors and work with functor…

    …ch wrapped tensors"
    
    
    This approach saves the inner-most tensor. As we unwrap, we also append to a list of rewrap functions that capture the necessary information to restore the metadata on the original wrappers. (A rough sketch of this pack/unpack pattern follows this commit entry.)
    
    Other prototypes:
    - #89159 (alternate approach that saves the outer-most tensor instead and unwraps the necessary number of layers during unpack - the issue is that we cannot tell when we are saving the outer-most tensor)
    - #88976 (same approach as this PR, but in cpp, unfinished)
    
    TODO:
    - verify in tests that we are actually saving the correct number of tensors
    - try a non-zero bdim
    - make that assert more precise
    
    
    [ghstack-poisoned]
    soulitzer committed Nov 17, 2022 · 62eb14a
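A rough, illustrative sketch of the pack/unpack pattern described in the commit message above (not the code in this PR). `is_wrapped`, `unwrap_one_layer`, and `rewrap_like` are hypothetical stand-ins for functorch's internal wrapper utilities, stubbed out here so the example runs for plain tensors:

```python
import torch

# Hypothetical helpers standing in for functorch's internal wrapper machinery.
# They are stubbed so the example runs for plain (unwrapped) tensors.
def is_wrapped(t):
    return False  # would check for BatchedTensor / TensorWrapper layers

def unwrap_one_layer(t):
    return t  # would strip one wrapper layer and return the inner tensor

def rewrap_like(template, value):
    return value  # would rebuild `template`'s wrapper layer around `value`

class save_innermost_on_cpu(torch.autograd.graph.saved_tensors_hooks):
    """Offload the inner-most tensor to CPU and remember how to rewrap it."""
    def __init__(self):
        def pack(t):
            rewrap_fns = []
            inner = t
            while is_wrapped(inner):
                outer = inner
                inner = unwrap_one_layer(inner)
                # Capture only what is needed to restore this wrapper layer,
                # not the wrapper's data itself.
                rewrap_fns.append(lambda v, outer=outer: rewrap_like(outer, v))
            return inner.to("cpu"), inner.device, rewrap_fns

        def unpack(packed):
            cpu_tensor, device, rewrap_fns = packed
            value = cpu_tensor.to(device)
            # Re-apply the wrapper layers from the inside out.
            for fn in reversed(rewrap_fns):
                value = fn(value)
            return value

        super().__init__(pack, unpack)
```

For plain tensors this behaves like `torch.autograd.graph.save_on_cpu`; the unwrap loop and rewrap list only come into play once the helpers actually detect and strip functorch wrapper layers.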
  2. Update on "Update save_on_cpu to dedupe tensors and work with functor…

    …ch wrapped tensors"
    
    
    This approach saves the inner-most tensor. As we unwrap, we also append to a list of rewrap functions that capture the necessary information to restore the metadata on the original wrappers.
    
    Other prototypes:
    - #89159 (alternate approach that saves the outer-most tensor instead and unwraps the necessary number of layers during unpack - the issue is that we cannot tell when we are saving the outer-most tensor)
    - #88976 (same approach as this PR, but in cpp, unfinished)
    
    TODO:
    - verify in tests that we are actually saving the correct number of tensors
    - try a non-zero bdim
    - make that assert more precise
    
    
    [ghstack-poisoned]
    soulitzer committed Nov 17, 2022 · 91d4539

Commits on Nov 28, 2022

  1. Update on "Update save_on_cpu and checkpointing to work with functorc…

    …h wrapped tensors"
    
    
    Design doc: https://docs.google.com/document/d/1OX5__xKsZP-natgEnsRrK4gfD0wSN9WS6j4kekfn-IA/edit
    
    This approach saves the inner-most tensor. As we unwrap, we also append to a list of rewrap functions that capture the necessary information to restore the metadata on the original wrappers. This PR tries to do most things in Python, but there are probably some APIs that could be added (or may already exist) that would simplify it.
    
    - This PR does very weird things to stash autograd metadata:
      - The rewrap function needs to capture autograd metadata so that the higher-order graphs don't get disconnected. We reuse TensorWrapper to do this, but in a way that is careful not to save the original TensorWrapper's data.
      - During packing, we clone the original TensorWrapper, then replace its value_ with an empty tensor, so this new data-less TensorWrapper is what the rewrap fn captures.
      - During unpacking, when we run the rewrap fn, we just replace the value_ with the value we want (this could be either the recomputed value or the value that was previously offloaded).
      - The API exposed to replace value_ is set_data!
    - There doesn't seem to be a reliable way to uniquely identify a tensor, since id() gets reused; using data_ptr helps, but it is also not always enough. In this PR, I'm also using the first element of the Tensor to get a test to pass. (A sketch of this identity/dedupe problem follows this commit entry.)
    
    Unanswered questions:
    - Why did we need to enable grad mode while packing (where was it disabled)?
    
    Other prototypes:
    - #89159 (alternate approach that saves the outer-most tensor instead and unwraps the necessary number of layers during unpack - the issue is that we cannot tell when we are saving the outer-most tensor)
    - #88976 (same approach as this PR, but in cpp, unfinished)
    
    TODO:
    - verify in tests that we are actually saving the correct number of tensors
    - try a non-zero bdim
    - make that assert more precise
    
    
    [ghstack-poisoned]
    soulitzer committed Nov 28, 2022 · a02b774
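As a side note on the identity problem in the last bullet of the commit message above: the sketch below is a hypothetical dedupe map, not the PR's implementation. It illustrates why `id()` alone is unreliable (CPython reuses ids once an object is collected) and how holding a `weakref` to the original tensor lets stale `(id, data_ptr)` keys be detected:

```python
import weakref
import torch

# Hypothetical dedupe map for packed tensors (not the PR's implementation).
# id() can be reused once a tensor is garbage collected, and data_ptr() can
# be reused once its storage is freed, so neither is a stable key on its own.
class TensorDedupe:
    def __init__(self):
        self._entries = {}  # (id, data_ptr) -> (weakref to tensor, payload)

    def _key(self, t):
        return (id(t), t.data_ptr())

    def put(self, t, payload):
        self._entries[self._key(t)] = (weakref.ref(t), payload)

    def get(self, t):
        entry = self._entries.get(self._key(t))
        if entry is None:
            return None
        ref, payload = entry
        if ref() is not t:
            # Key was recycled by an unrelated (now-dead) tensor: drop it.
            del self._entries[self._key(t)]
            return None
        return payload

# Example: offload each distinct tensor only once.
dedupe = TensorDedupe()
x = torch.randn(3)
if dedupe.get(x) is None:
    dedupe.put(x, x.to("cpu"))
assert dedupe.get(x) is not None
```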