make torch.amp.autocast more generic #125103
```diff
@@ -207,6 +207,8 @@ def __init__(
             raise ValueError(
                 f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
             )
+        if dtype is None:
+            dtype = torch.get_autocast_dtype(device_type)
         if torch._jit_internal.is_scripting():
             self._enabled = enabled
             self.device = device_type
```

Comment on lines +213 to +214:
- Reviewer: Is this covered by existing UTs?
- Author: Added two new UTs to cover this change in the eager and JIT paths respectively.
- Reviewer: We should update the doc to mention the new default value for this arg.
- Author: Updated.
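With this change, omitting `dtype` resolves to the device's default autocast dtype via `torch.get_autocast_dtype(device_type)`, which is the new documented default for the argument. A minimal usage sketch of the resulting behavior — the shapes are illustrative, and `torch.bfloat16` as the CPU default matches stock PyTorch:

```python
import torch

# dtype is omitted, so it now falls back to torch.get_autocast_dtype("cpu"),
# which is torch.bfloat16 by default (torch.float16 on CUDA).
with torch.amp.autocast(device_type="cpu"):
    a = torch.randn(8, 8)
    b = torch.randn(8, 8)
    out = torch.mm(a, b)  # mm is autocast-eligible on CPU

print(out.dtype)  # torch.bfloat16, the CPU autocast default

# Equivalent explicit form of what the new default does:
with torch.amp.autocast(device_type="cpu", dtype=torch.get_autocast_dtype("cpu")):
    out = torch.mm(a, b)
```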
```diff
@@ -288,11 +288,10 @@ def backward(ctx, *args):
                 set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
             detached_inputs = detach_variable(tuple(inputs))

-            device_autocast_ctx = device_module.amp.autocast(
-                **ctx.device_autocast_kwargs
+            device_autocast_ctx = torch.amp.autocast(
+                device_type=ctx.device, **ctx.device_autocast_kwargs
             ) if torch.amp.is_autocast_available(ctx.device) else contextlib.nullcontext()
-            with torch.enable_grad(), device_autocast_ctx, \
-                    torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
+            with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
                 outputs = ctx.run_function(*detached_inputs)

             if isinstance(outputs, torch.Tensor):
```
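Both checkpoint hunks in this file make the same substitution: the per-backend `device_module.amp.autocast` lookup becomes the device-agnostic `torch.amp.autocast(device_type=...)`, guarded by `torch.amp.is_autocast_available`. A sketch of that pattern in isolation — `autocast_or_noop` is a hypothetical helper for illustration, not something this PR adds:

```python
import contextlib
import torch

def autocast_or_noop(device_type: str, **autocast_kwargs):
    # Prefer the device-agnostic torch.amp.autocast; fall back to a no-op
    # context manager for backends without autocast support, as the diff does.
    if torch.amp.is_autocast_available(device_type):
        return torch.amp.autocast(device_type=device_type, **autocast_kwargs)
    return contextlib.nullcontext()

# The same call serves "cpu", "cuda", "xpu", ... with no need to resolve a
# per-backend module such as torch.cuda.amp or torch.cpu.amp first.
with autocast_or_noop("cpu", dtype=torch.bfloat16):
    out = torch.mm(torch.randn(4, 4), torch.randn(4, 4))
```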
```diff
@@ -1395,11 +1394,10 @@ def recompute_fn(*inputs):
             if had_device_in_fwd:
                 set_device_states(fwd_devices, fwd_device_states)

-            device_autocast_ctx = device_module.amp.autocast(
-                **device_autocast_kwargs
+            device_autocast_ctx = torch.amp.autocast(
+                device_type=device, **device_autocast_kwargs
             ) if torch.amp.is_autocast_available(device) else contextlib.nullcontext()
-            with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
-                    recompute_context:
+            with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
                 fn(*args, **kwargs)

         new_frame = _CheckpointFrame(
```

Comment on this hunk:
- Reviewer: How do we gather and restore both the CPU context and another device's context here? cc @soulitzer in case this is something you want to clean up for AC in general in a follow-up, now that we have the nice API.
- Author: We don't change the behavior here; we just switch to the generic `torch.amp.autocast` API.
- Reviewer: yep, perfect!
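For context on the thread above: checkpoint records the autocast state at forward time and re-enters both the device and the CPU autocast contexts during recomputation, which is what the rewritten `with` statement does. A small end-to-end sketch of the behavior being preserved — the toy module and shapes are illustrative, not from the PR:

```python
import torch
from torch.utils.checkpoint import checkpoint

lin = torch.nn.Linear(16, 16)
x = torch.randn(2, 16, requires_grad=True)

with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
    # use_reentrant=False exercises the recompute_fn path changed above.
    out = checkpoint(lin, x, use_reentrant=False)

# Backward triggers recomputation, which re-enters the recorded autocast
# contexts (device and CPU), so recomputed activations match the forward pass.
out.sum().backward()
```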
General comment:
- Reviewer: Code improvements.