Update save_on_cpu and checkpointing to work with functorch wrapped tensors #89166
Closed

Changes from 3 commits

Commits (4):
- 4a23049 Update save_on_cpu to dedupe tensors and work with functorch wrapped … (soulitzer)
- 62eb14a Update on "Update save_on_cpu to dedupe tensors and work with functor… (soulitzer)
- 91d4539 Update on "Update save_on_cpu to dedupe tensors and work with functor… (soulitzer)
- a02b774 Update on "Update save_on_cpu and checkpointing to work with functorc… (soulitzer)
@@ -14,6 +14,31 @@
    "allow_mutation_on_saved_tensors",
]

def _get_tid(t) -> Tuple[int, int, int, int]:
    # Returns a unique identifier for a particular version of a tensor. This is
    # currently used in saved_tensor_hook context managers such as save_on_cpu()
    # and allow_mutation_on_saved_tensors().
    #
    # id() and t._version are not sufficient because id() sometimes gets reused,
    # so we also compare t.data_ptr() and the first value in the tensor.
    # Technically this does not guarantee correctness either, but it makes
    # things somewhat safer at least.
    #
    # TODO: support tensors that don't have storage
    item = t.item() if t.ndim == 0 else t[(0,) * t.ndim].item()
    return (id(t), t.data_ptr(), t._version, item)

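For illustration only (my example, not part of the patch; it assumes the _get_tid helper just defined is in scope): an in-place update bumps t._version and changes the first element, so the same tensor object gets a new tid for its new version.

import torch

t = torch.arange(4.0)
tid_before = _get_tid(t)
t.add_(1)                        # in-place op bumps t._version to 1
tid_after = _get_tid(t)
assert tid_before != tid_after   # same id/data_ptr, new version and new item
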
def _get_sid(t) -> Tuple[int, int]:
    # Returns a tuple that uniquely identifies a tensor's storage
    #
    # NB: two tensors that share the same storage may have different
    # sid if their storage offset is different
    return (t.data_ptr(), t._version)

class _Handle():
    pass

class saved_tensors_hooks():
    """Context-manager that sets a pair of pack / unpack hooks for saved tensors.

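A small illustration of the NB in _get_sid above (my example, not part of the patch; it assumes _get_sid from this diff is in scope): a view taken at a nonzero offset shares its base's storage but reports a different sid, because data_ptr() points at the first element the view sees.

import torch

t = torch.arange(6.0)
v = t[2:]                          # shares t's storage at a different offset
assert t.data_ptr() != v.data_ptr()
assert _get_sid(t) != _get_sid(v)  # data_ptr differs, so the sid differs
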
@@ -84,6 +109,80 @@ def __enter__(self):
    def __exit__(self, *args: Any):
        torch._C._autograd._pop_saved_tensors_default_hooks()

def _is_grad_wrapper(wrapped):
    # There's probably a better way to check this
    current_level = torch._C._functorch.maybe_get_level(wrapped)
    if current_level == -1:
        return False
    try:
        unwrapped = torch._C._functorch._unwrap_for_grad(wrapped, current_level)
        assert wrapped is not unwrapped
        return True
    except Exception:
        return False

def _unwrap(tensor):
    current_level = torch._C._functorch.maybe_get_level(tensor)
    assert current_level >= 1

    # Assume the wrapper is either a grad wrapper or a batched wrapper
    if _is_grad_wrapper(tensor):
        unwrapped = torch._C._functorch._unwrap_for_grad(tensor, current_level)
        assert tensor is not unwrapped
        inner_level = torch._C._functorch.maybe_get_level(unwrapped)

        # Clone the TensorWrapper and replace its inner tensor with an empty
        # tensor. This is so that we can capture the autograd metadata without
        # capturing the data as well.
        # TODO: figure out a way to do this without actually cloning the data.
        with torch.autograd.enable_grad():
            # Why do we need to reenable grad?
            captured_wrapper = tensor.clone()
        new_inner = torch.empty_like(unwrapped)
        # Undo lifting
        new_inner = _functorch_unwrap_to_level_no_rewrap(new_inner, inner_level)
        captured_wrapper = _functorch_unwrap_to_level_no_rewrap(captured_wrapper, current_level)
        captured_wrapper.data = new_inner

        def rewrap_fn(new_value):
            nonlocal captured_wrapper
            captured_wrapper.data = new_value
            return captured_wrapper
    else:
        bdim = torch._C._functorch.maybe_get_bdim(tensor)
        unwrapped = torch._C._functorch.get_unwrapped(tensor)

        def rewrap_fn(new_value):
            return torch._C._functorch._add_batch_dim(new_value, bdim, current_level)

    return unwrapped, rewrap_fn

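As a rough probe of the private functorch helpers used in the batched branch above (illustration only, not part of the patch; these are internal APIs that may change between releases, and the commented values are what I'd expect on a recent build): inside a single vmap the input is a BatchedTensor at level 1 carrying a batch dim, while outside any transform the level is -1.

import torch

def f(x):
    level = torch._C._functorch.maybe_get_level(x)  # 1 inside a single vmap
    bdim = torch._C._functorch.maybe_get_bdim(x)    # the batch dim, here 0
    inner = torch._C._functorch.get_unwrapped(x)    # the full (3, 4) tensor
    print(level, bdim, inner.shape)
    return x.sum()

torch.vmap(f)(torch.randn(3, 4))                    # f is traced once
print(torch._C._functorch.maybe_get_level(torch.randn(2)))  # -1: no transform
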
def _functorch_unwrap_to_level_no_rewrap(tensor: torch.Tensor, target_level: int) -> torch.Tensor:
    current_level = torch._C._functorch.maybe_get_level(tensor)
    while current_level > target_level:
        tensor = torch._C._functorch._unwrap_for_grad(tensor, current_level)
        current_level = torch._C._functorch.maybe_get_level(tensor)
    assert current_level == target_level, (current_level, target_level)
    return tensor

# It might be better to do more of this in cpp:
# https://github.com/pytorch/pytorch/pull/88976
def _functorch_unwrap_to_level(tensor: torch.Tensor, target_level: int) -> Tuple[torch.Tensor, List[Callable]]:
    assert target_level != 0, "level 0 is not supported, you should pass -1 instead"
    current_level = torch._C._functorch.maybe_get_level(tensor)
    assert current_level >= target_level, (current_level, target_level)
    rewrap_fns = []
    for _ in range(max(current_level, 0), max(target_level, 0), -1):
        current_level = torch._C._functorch.maybe_get_level(tensor)
        if current_level == target_level:
            # Sometimes wrappers can skip levels
            break
        tensor, rewrap_fn = _unwrap(tensor)
        rewrap_fns.append(rewrap_fn)

    result_level = torch._C._functorch.maybe_get_level(tensor)
    assert result_level == target_level, (result_level, target_level)
    return tensor, rewrap_fns

class save_on_cpu(saved_tensors_hooks):
    """Context-manager under which tensors saved by the forward pass will be

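To make the contract of _functorch_unwrap_to_level concrete, here is a hypothetical round trip (my sketch, not part of the patch: it assumes this branch is applied, and uses torch.func.grad, which is functorch.grad on older builds): inside a grad transform the input arrives wrapped, the helper peels it down to a plain tensor, and replaying rewrap_fns in reverse rebuilds the wrappers, mirroring what pack_to_cpu/unpack_from_cpu below do.

import torch

def f(x):
    level = torch._C._functorch.maybe_get_level(x)
    assert level >= 1  # x arrives wrapped inside grad
    unwrapped, rewrap_fns = _functorch_unwrap_to_level(x, -1)
    assert torch._C._functorch.maybe_get_level(unwrapped) == -1
    restored = unwrapped
    for rewrap_fn in reversed(rewrap_fns):
        restored = rewrap_fn(restored)
    assert torch._C._functorch.maybe_get_level(restored) == level
    return (x * x).sum()

torch.func.grad(f)(torch.randn(3))  # gradient is 2 * x
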
@@ -127,24 +226,82 @@ class save_on_cpu(saved_tensors_hooks):

    """
    def __init__(self, pin_memory=False):
        # We use weak references here to make sure that the only owning references
        # to any tensors that are offloaded have lifetimes linked to that of
        # the graph
        self.unwrapped_copies: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self.tid_to_weakhandle: weakref.WeakValueDictionary = weakref.WeakValueDictionary()

        def pack_to_cpu(tensor):
-            if not pin_memory:
-                return (tensor.device, tensor.cpu())
            device = tensor.device

            # NB: Always unwrap the outer layer, which is guaranteed to be a TensorWrapper
            # or a vanilla tensor. Unwrapping when the level is -1 is a no-op (I think).
            level = torch._C._functorch.maybe_get_level(tensor)
            inner = torch._C._functorch._unwrap_for_grad(tensor, level)
            # NB: the level after unwrapping isn't always level - 1, so query it again
            level = torch._C._functorch.maybe_get_level(inner)

            # Unwrap all the way, but remember how to restore the wrappers,
            # including their autograd metadata
            unwrapped, rewrap_fns = _functorch_unwrap_to_level(inner, -1)
            tid = _get_tid(unwrapped)

            if tid in self.tid_to_weakhandle:
                handle = self.tid_to_weakhandle[tid]
            else:
                if not pin_memory:
                    unwrapped_copy = unwrapped.cpu()
                else:
                    unwrapped_copy = torch.empty(
                        inner.size(),
                        dtype=inner.dtype,
                        layout=inner.layout,
                        pin_memory=(torch.cuda.is_available() and not inner.is_sparse))
                    unwrapped_copy.copy_(inner)

                handle = _Handle()
                self.unwrapped_copies[handle] = unwrapped_copy
                self.tid_to_weakhandle[tid] = handle

-            packed = torch.empty(
-                tensor.size(),
-                dtype=tensor.dtype,
-                layout=tensor.layout,
-                pin_memory=(torch.cuda.is_available() and not tensor.is_sparse))
-            packed.copy_(tensor)
-            return (tensor.device, packed)
            return (device, handle, rewrap_fns, level)

        def unpack_from_cpu(packed):
-            device, tensor = packed
-            return tensor.to(device, non_blocking=pin_memory)
            device, handle, rewrap_fns, original_level = packed
            ret = self.unwrapped_copies[handle]
            for rewrap_fn in reversed(rewrap_fns):
                ret = rewrap_fn(ret)

            new_level = torch._C._functorch.maybe_get_level(ret)
            assert new_level == original_level, (new_level, original_level)

            ret = ret.to(device, non_blocking=pin_memory)
            # Ideally we would work entirely with unwrapped tensors during backward,
            # but for a grad transform the wrapping level is the same as that of
            # forward. I think that should not be the case. (Can we fix this?)
            #
            # If saved tensor logic properly detached, we shouldn't have to unwrap
            # and rewrap here at all. The rewrapping here is implicit due to lifting.
            assert torch._C._functorch.maybe_get_level(ret) > original_level or original_level == -1, (
                "lifting should happen unless we've left transforms entirely")

Review comment: Not sure this is true actually, if we consider the case; lifting shouldn't happen, but I haven't seen the assert triggered.

            return ret

        super().__init__(pack_to_cpu, unpack_from_cpu)

    def __enter__(self):
        torch._C._autograd._push_saved_tensors_default_hooks(self.pack_hook, self.unpack_hook)
        return self

    def __exit__(self, *args: Any):
        # We cannot clear here because that would be bc-breaking: people sometimes
        # run forward under save_on_cpu but exit the ctx before running backward.
        # Note that this behavior is inconsistent with that of the
        # allow_mutation_on_saved ctx, which requires that backward also be run
        # inside the ctx.
        #
        # self.unwrapped_copies.clear()
        # self.tid_to_weakhandle.clear()
        torch._C._autograd._pop_saved_tensors_default_hooks()

@contextlib.contextmanager
def disable_saved_tensors_hooks(error_message):

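Putting the hooks together, a usage sketch of the behaviors pack_to_cpu and __exit__ above are designed around (my example, not part of the patch; it assumes a CUDA build for the offload round trip): a tensor saved by several ops is offloaded once thanks to tid deduplication, and backward still works after the context has exited because __exit__ deliberately does not clear the dictionaries.

import torch
from torch.autograd.graph import save_on_cpu

a = torch.randn(100, device="cuda", requires_grad=True)
with save_on_cpu(pin_memory=True):
    # `a` is saved by both sin and exp; with deduplication only one
    # offloaded CPU copy is kept alive by the graph.
    out = (a.sin() * a.exp()).sum()
out.backward()  # runs outside the ctx; offloaded copies must still be alive
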
@@ -289,15 +446,6 @@ def __setstate__(self, state):
    # - if the clone exists, the tensor must've been modified in-place
    _allow_mutation_on_saved_tensors_enabled = False

-def _get_tid(t) -> Tuple[int, int, int]:
-    return (id(t), t.data_ptr(), t._version)
-
-def _get_sid(t) -> Tuple[int, int]:
-    return (t.data_ptr(), t._version)
-
-class _Handle():
-    pass

class _swap_with_cloned(saved_tensors_hooks):
    def __init__(self, ctx):
        def pack_hook(t):
Review comment: inner -> unwrapped