
make torch.amp.autocast more generic #125103

Closed · wants to merge 12 commits

Conversation

guangyey
Collaborator

@guangyey guangyey commented Apr 27, 2024

Stack from ghstack (oldest at bottom):

Motivation

As discussed in #124479, torch.amp.autocast cannot be completely equivalent to torch.cuda.amp.autocast and torch.cpu.amp.autocast, because torch.amp.autocast does not apply the backend-specific default dtype (torch.bfloat16 for CPU and torch.float16 for CUDA). We would like torch.amp.autocast to be more generic so that developers and customers can write device-agnostic code; there are not enough reasons to add a device-specific autocast torch.xxx.amp.autocast for each device backend.

Solution

When None is passed as dtype, we use torch.get_autocast_dtype to obtain the backend-specific default dtype. Meanwhile, torch.get_autocast_dtype also needs to be supported in the JIT path for backward compatibility.
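
A minimal sketch of this fallback, assuming current torch.get_autocast_dtype semantics and mirroring the two-line change quoted later in this review; the helper name resolve_autocast_dtype is illustrative and not part of the PR:

```python
import torch

def resolve_autocast_dtype(device_type: str, dtype=None):
    # If the caller does not specify a dtype, fall back to the backend's
    # registered autocast default (float16 for "cuda", bfloat16 for "cpu").
    if dtype is None:
        dtype = torch.get_autocast_dtype(device_type)
    return dtype

print(resolve_autocast_dtype("cpu"))  # torch.bfloat16 on a stock build
```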

Additional Context

With this PR, torch.amp.autocast(device_type='cuda') is equivalent to torch.cuda.amp.autocast.
Two new UTs cover this change in the eager and JIT paths respectively.
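
A device-agnostic usage sketch (assuming a build where CUDA may or may not be present); with no dtype argument, the backend default is picked automatically:

```python
import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 8).to(device_type)
x = torch.randn(4, 8, device=device_type)

# No dtype argument: float16 is used on CUDA and bfloat16 on CPU, matching
# torch.cuda.amp.autocast / torch.cpu.amp.autocast respectively.
with torch.amp.autocast(device_type=device_type):
    y = model(x)

print(y.dtype)  # torch.float16 on CUDA, torch.bfloat16 on CPU
```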

cc @mcarilli @ptrblck @leslie-fang-intel @jgong5 @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng


pytorch-bot bot commented Apr 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125103

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e11d24b with merge base 5007312:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

guangyey added a commit that referenced this pull request Apr 27, 2024
ghstack-source-id: 5dbe659705ae42831596b9eb4126d126b261f01d
Pull Request resolved: #125103
@guangyey guangyey changed the title make torch.amp.autocast more generic [WIP] make torch.amp.autocast more generic Apr 27, 2024
@guangyey guangyey marked this pull request as draft April 27, 2024 15:42
@guangyey guangyey added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 27, 2024
guangyey added a commit that referenced this pull request Apr 27, 2024
ghstack-source-id: 1cbc46833c8bcb946e0eff23f4311e766631d9c2
Pull Request resolved: #125103
guangyey added a commit that referenced this pull request Apr 28, 2024
ghstack-source-id: 0066d18809b6652072307eb5727fa10c9b14d564
Pull Request resolved: #125103
@pytorch-bot pytorch-bot bot added the release notes: jit release notes category label Apr 28, 2024
guangyey added a commit that referenced this pull request Apr 28, 2024
ghstack-source-id: 4e580070ec0ec666d1a2929ac588f91c6ac7af70
Pull Request resolved: #125103
guangyey added a commit that referenced this pull request May 6, 2024
ghstack-source-id: cd7a600e63177d5d4a7590a8c62c51cff7e10243
Pull Request resolved: #125103
@guangyey guangyey added the topic: improvements topic category label May 6, 2024
guangyey added a commit that referenced this pull request May 6, 2024
ghstack-source-id: 31ac5e57dfcbf817bdbe1e45af5c21cb08d43050
Pull Request resolved: #125103
@guangyey guangyey changed the title [WIP] make torch.amp.autocast more generic make torch.amp.autocast more generic May 6, 2024
@guangyey guangyey marked this pull request as ready for review May 6, 2024 05:52
@@ -600,28 +600,20 @@ def save_global_state(self, out=None):
)
global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())

def autocast_specific_backend(
Collaborator Author

code improvements.

Comment on lines +210 to +211
if dtype is None:
    dtype = torch.get_autocast_dtype(device_type)
Collaborator

is it covered by existing UTs?

Collaborator Author

Add two new UTs to cover this change in the eager and JIT paths respectively.

Collaborator

We should update the doc to mention the new default value for this arg?

Collaborator Author

Updated.

guangyey added a commit that referenced this pull request May 7, 2024
ghstack-source-id: 62a264b2a9033119fee7ebd49a93b6252ba2c89f
Pull Request resolved: #125103
guangyey added a commit that referenced this pull request May 7, 2024
ghstack-source-id: 62b35fb066c0ee740b0220e56bc153b861ff0c6e
Pull Request resolved: #125103
@guangyey guangyey requested a review from jgong5 May 7, 2024 07:24
guangyey added a commit that referenced this pull request May 7, 2024
ghstack-source-id: f4516ff524822c7994a009c6973ae7d118642b02
Pull Request resolved: #125103
@guangyey
Collaborator Author

guangyey commented May 7, 2024

@albanD This PR intends to make torch.amp.autocast more generic. Developers can use it to write device-agnostic code instead of using torch.cuda.amp.autocast or torch.cpu.amp.autocast. Does that sound reasonable?

) if torch.amp.is_autocast_available(device) else contextlib.nullcontext()
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
recompute_context:
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
Collaborator

So we gather and restore both the CPU context and another device's context here?
This makes this code a bit weird. But sounds fair. We definitely don't want to change the behavior here.

cc @soulitzer in case this is something you want to clean up for AC in general in a follow-up, now that we have the nice API

Collaborator Author

We don't change the behavior here; we just use torch.amp.autocast to make this code more generic and leave the logic as it is.

Collaborator

yep perfect!
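
For reference, a minimal sketch of the pattern discussed in this thread, assuming the surrounding checkpoint code provides autocast kwargs for both CPU and the active device; the helper name and arguments are illustrative placeholders, not the actual implementation:

```python
import contextlib
import torch

def make_autocast_contexts(device: str, device_autocast_kwargs: dict, cpu_autocast_kwargs: dict):
    # Build the device autocast context only if autocast is available for that
    # backend; otherwise fall back to a no-op context manager.
    device_ctx = (
        torch.amp.autocast(device_type=device, **device_autocast_kwargs)
        if torch.amp.is_autocast_available(device)
        else contextlib.nullcontext()
    )
    # The CPU autocast state is restored alongside the device state, matching
    # the behavior preserved by this PR.
    cpu_ctx = torch.amp.autocast(device_type="cpu", **cpu_autocast_kwargs)
    return device_ctx, cpu_ctx

# Usage sketch: both contexts are entered around recomputation.
dev_ctx, cpu_ctx = make_autocast_contexts(
    "cpu", {"enabled": True}, {"enabled": True, "dtype": torch.bfloat16}
)
with dev_ctx, cpu_ctx:
    pass  # recompute_fn(...) would run here
```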

guangyey added a commit that referenced this pull request May 7, 2024
ghstack-source-id: 9f6516b793ad060ba57b52b1ba3bfcbe3077cd60
Pull Request resolved: #125103
Collaborator

@albanD albanD left a comment

nit in doc, sounds good otherwise.

@@ -191,7 +191,9 @@ def forward(self, x):
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
enabled(bool, optional): Whether autocasting should be enabled in the region.
Default: ``True``
dtype(torch_dtype, optional): Whether to use torch.float16 or torch.bfloat16.
dtype(torch_dtype, optional): Data type for ops run in autocast. It uses the default value
(``torch.float16`` for CUDA and ``torch.bfloat16`` for CPU, by default), given by
Collaborator

Suggested change
(``torch.float16`` for CUDA and ``torch.bfloat16`` for CPU, by default), given by
(``torch.float16`` for CUDA and ``torch.bfloat16`` for CPU), given by

Collaborator Author

updated.

guangyey added a commit that referenced this pull request May 8, 2024
ghstack-source-id: ccfcbf2a1d43c618076ca9d883fab4d462dd2632
Pull Request resolved: #125103
guangyey added a commit that referenced this pull request May 8, 2024
ghstack-source-id: c05a9e0166e9fdfc6fc3284f876f96ba236e3939
Pull Request resolved: #125103
@guangyey
Collaborator Author

guangyey commented May 8, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

facebook-github-bot pushed a commit to pytorch/benchmark that referenced this pull request May 9, 2024

X-link: pytorch/pytorch#125103
Approved by: https://github.com/albanD, https://github.com/jgong5, https://github.com/gujinghui

Reviewed By: izaitsevfb

Differential Revision: D57138276

fbshipit-source-id: 17f883924e43f68dd6836d99b06fe8a47cfccbf6
Projects
Status: Done

6 participants