make torch.amp.autocast more generic #125103
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125103
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit e11d24b with merge base 5007312. This comment was automatically generated by Dr. CI and updates every 15 minutes.
```diff
@@ -600,28 +600,20 @@ def save_global_state(self, out=None):
         )
         global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())

     def autocast_specific_backend(
```
code improvements.
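For context, the bookkeeping pattern visible in the hunk above stores a `(setter, current_value)` pair per piece of global state. A minimal sketch; only the `grad_enabled` entry comes from the diff, and the restore loop is an assumption about how the saved dict is consumed:

```python
import torch

# Each entry records a setter plus the value needed to restore it later.
global_state = {}
global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())

# ... global state may be mutated while tracing/compiling ...

# Restore everything that was saved (assumed consumption of the dict).
for setter, value in global_state.values():
    setter(value)
```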
```python
        if dtype is None:
            dtype = torch.get_autocast_dtype(device_type)
```
Is it covered by existing UTs?
Added two new UTs to cover this change in the eager and JIT paths respectively.
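Not the actual test code, but a minimal sketch of what the eager-path check needs to assert (the function name is illustrative): with `dtype` omitted, CPU matmul under autocast should come out as the backend default, `torch.bfloat16`.

```python
import torch

def check_eager_default_dtype_cpu():
    # dtype is not passed, so autocast should resolve it from the backend default.
    with torch.amp.autocast(device_type="cpu"):
        x = torch.randn(4, 4)
        y = torch.mm(x, x)  # mm is autocast-eligible on CPU
    assert y.dtype == torch.bfloat16, y.dtype

check_eager_default_dtype_cpu()
```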
We should update the doc to mention the new default value for this arg?
Updated.
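For reference, the per-backend defaults the doc now mentions can be queried directly (a quick illustration, assuming a build that exposes `torch.get_autocast_dtype`, which this stack also makes available in the JIT path):

```python
import torch

# The documented defaults, queried programmatically.
print(torch.get_autocast_dtype("cpu"))   # torch.bfloat16
print(torch.get_autocast_dtype("cuda"))  # torch.float16
```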
@albanD This PR intends to make torch.amp.autocast more generic. Developers can use it to write device-agnostic code instead of the device-specific torch.xxx.amp.autocast variants.
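For example, a minimal device-agnostic sketch (the shapes are arbitrary):

```python
import torch

# One code path for both backends: with dtype unspecified, autocast picks
# torch.float16 on CUDA and torch.bfloat16 on CPU.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
with torch.amp.autocast(device_type=device_type):
    a = torch.randn(8, 8, device=device_type)
    b = torch.randn(8, 8, device=device_type)
    out = a @ b
print(out.dtype)  # float16 on CUDA, bfloat16 on CPU
```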
```diff
         ) if torch.amp.is_autocast_available(device) else contextlib.nullcontext()
-        with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
-            recompute_context:
+        with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
```
Oh, we gather and restore both the CPU context and another device's context here?
This makes the code a bit weird, but sounds fair. We definitely don't want to change the behavior here.
cc @soulitzer in case this is something you want to clean up for AC in general in a follow-up, now that we have the nice API.
We don't change the behavior here; we just use torch.amp.autocast to make the code more generic and leave the logic as it is.
yep perfect!
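For readers following along, a standalone sketch of the pattern being kept here; the kwargs dicts are illustrative stand-ins for the autocast state that checkpoint captures at save time:

```python
import contextlib
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Illustrative stand-ins for the autocast state saved at checkpoint time.
device_autocast_kwargs = {"device_type": device, "enabled": True}
cpu_autocast_kwargs = {"enabled": True}

# Re-enter the device autocast context only if autocast is available for
# that device; the CPU autocast context is restored unconditionally.
device_autocast_ctx = (
    torch.amp.autocast(**device_autocast_kwargs)
    if torch.amp.is_autocast_available(device)
    else contextlib.nullcontext()
)
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs):
    pass  # recomputation would run here
```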
Nit in the doc; sounds good otherwise.
torch/amp/autocast_mode.py (Outdated)

```diff
@@ -191,7 +191,9 @@ def forward(self, x):
         Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
         enabled(bool, optional):  Whether autocasting should be enabled in the region.
             Default: ``True``
-        dtype(torch_dtype, optional):  Whether to use torch.float16 or torch.bfloat16.
+        dtype(torch_dtype, optional):  Data type for ops run in autocast. It uses the default value
+            (``torch.float16`` for CUDA and ``torch.bfloat16`` for CPU, by default), given by
```
```diff
-            (``torch.float16`` for CUDA and ``torch.bfloat16`` for CPU, by default), given by
+            (``torch.float16`` for CUDA and ``torch.bfloat16`` for CPU), given by
```
Updated.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
X-link: pytorch/pytorch#125103
Approved by: https://github.com/albanD, https://github.com/jgong5, https://github.com/gujinghui
Reviewed By: izaitsevfb
Differential Revision: D57138276
fbshipit-source-id: 17f883924e43f68dd6836d99b06fe8a47cfccbf6
Stack from ghstack (oldest at bottom):

# Motivation
As discussed in #124479, `torch.amp.autocast` cannot be completely equivalent to `torch.cuda.amp.autocast` and `torch.cpu.amp.autocast`, since `torch.amp.autocast` lacks the per-backend default `dtype` (`torch.bfloat16` for CPU and `torch.float16` for CUDA). We would like `torch.amp.autocast` to be more generic, so that developers/customers can write device-agnostic code, because there are not enough reasons to add a device-specific autocast `torch.xxx.amp.autocast` for each device backend.

# Solution
When `None` is passed as `dtype`, we use `torch.get_autocast_dtype` to get the default dtype for each backend. Meanwhile, `torch.get_autocast_dtype` also needs to be supported in the JIT path for BC.

# Additional Context
With this PR, `torch.amp.autocast(device_type='cuda')` is equivalent to `torch.cuda.amp.autocast`. Added two new UTs to cover this change in the eager and JIT paths respectively.

cc @mcarilli @ptrblck @leslie-fang-intel @jgong5 @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng
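As a rough illustration of the JIT-path requirement (a sketch, assuming TorchScript support for the generic context manager as enabled in this stack):

```python
import torch

@torch.jit.script
def scripted_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # dtype is omitted, so the scripted path must resolve the backend
    # default via torch.get_autocast_dtype.
    with torch.amp.autocast(device_type="cpu"):
        return torch.mm(a, b)

out = scripted_mm(torch.randn(4, 4), torch.randn(4, 4))
print(out.dtype)  # torch.bfloat16
```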