make torch.amp.autocast more generic (#125103)
Summary:
# Motivation
As discussed in [#124479](pytorch/pytorch#124479), `torch.amp.autocast` cannot be completely equivalent to `torch.cuda.amp.autocast` and `torch.cpu.amp.autocast`, since `torch.amp.autocast` does not have the per-backend default `dtype` that they do (`torch.bfloat16` for CPU and `torch.float16` for CUDA). We would like `torch.amp.autocast` to be more generic so that developers and customers can write device-agnostic code, because there is not enough justification to add a device-specific `torch.xxx.amp.autocast` for every device backend.

# Solution
When `None` is passed as `dtype`, we use `torch.get_autocast_dtype` to get the default dtype for each backend. In addition, `torch.get_autocast_dtype` needs to be supported in the JIT path for backward compatibility.
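A minimal sketch of the dtype resolution this enables, assuming a PyTorch build that includes this change (the tensor sizes and the `torch.mm` call are arbitrary examples):

```python
import torch

# With dtype left unset (None), torch.amp.autocast resolves the backend
# default via torch.get_autocast_dtype: float16 on CUDA, bfloat16 on CPU.
device_type = "cuda" if torch.cuda.is_available() else "cpu"

x = torch.randn(8, 8, device=device_type)
with torch.amp.autocast(device_type=device_type):
    y = torch.mm(x, x)  # runs in the backend's default autocast dtype

print(torch.get_autocast_dtype(device_type))
print(y.dtype)  # torch.float16 on CUDA, torch.bfloat16 on CPU
```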

# Additional Context
With this PR, `torch.amp.autocast(device_type='cuda')` is equivalent to `torch.cuda.amp.autocast`.
Two new unit tests are added to cover this change in the eager and JIT paths, respectively.
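A rough sketch of the equivalence claimed above, assuming a CUDA build (`torch.cuda.amp.autocast` appears here only for comparison with the older API):

```python
import torch

if torch.cuda.is_available():
    x = torch.randn(4, 4, device="cuda")

    # Device-agnostic API; dtype is resolved from the CUDA backend default.
    with torch.amp.autocast(device_type="cuda"):
        a = torch.mm(x, x).dtype

    # CUDA-specific API with its own float16 default.
    with torch.cuda.amp.autocast():
        b = torch.mm(x, x).dtype

    assert a == b == torch.float16
```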

X-link: pytorch/pytorch#125103
Approved by: https://github.com/albanD, https://github.com/jgong5, https://github.com/gujinghui

Reviewed By: izaitsevfb

Differential Revision: D57138276

fbshipit-source-id: 17f883924e43f68dd6836d99b06fe8a47cfccbf6
guangyey authored and facebook-github-bot committed May 9, 2024
1 parent 4c7ec3a commit 798f69b
Showing 1 changed file with 24 additions and 25 deletions.
49 changes: 24 additions & 25 deletions userbenchmark/dynamo/dynamobench/common.py
@@ -2086,31 +2086,30 @@ def setup_amp(self, current_device=None):

         devices = [current_device] if current_device else self.args.devices
         if self.args.amp:
-            if devices == ["cuda"]:
-                # AMP training can lead to small loss values which can underflow
-                # gradient values, resulting in zero gradients. To solve this
-                # problem, PyTorch introduces GradScaler. GradScaler is a stateful
-                # structure that scales the loss values to prevent underflow. Loss
-                # values are big at the beginning of training (therefore not
-                # requiring scaling), while loss values tend to be small as the
-                # network starts getting better (requiring scaling). GradScaler
-                # manages all of this fine-tuning, checking whether the gradients
-                # are turning to inf and discarding such batches.
-
-                # Since we are not running long iterations, the default init_scale
-                # of 65536 is going to turn all gradients to inf. Therefore,
-                # we just use an init_scale of 2.0 for benchmarking purposes.
-
-                # Disabling GradScaler because
-                # 1) The benchmark setup runs 2 iterations of fwd-bwd, so it is not useful.
-                # 2) The current setup shares grad_scaler between the eager and dynamo
-                #    models, which is bad as GradScaler has state and can adjust the
-                #    scaling factor between the eager and dynamo runs, making the
-                #    accuracy check harder.
-                # self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
-                self.autocast = torch.cuda.amp.autocast
-            if devices == ["cpu"]:
-                self.autocast = torch.cpu.amp.autocast
+            # AMP training can lead to small loss values which can underflow
+            # gradient values, resulting in zero gradients. To solve this
+            # problem, PyTorch introduces GradScaler. GradScaler is a stateful
+            # structure that scales the loss values to prevent underflow. Loss
+            # values are big at the beginning of training (therefore not
+            # requiring scaling), while loss values tend to be small as the
+            # network starts getting better (requiring scaling). GradScaler
+            # manages all of this fine-tuning, checking whether the gradients
+            # are turning to inf and discarding such batches.
+
+            # Since we are not running long iterations, the default init_scale
+            # of 65536 is going to turn all gradients to inf. Therefore,
+            # we just use an init_scale of 2.0 for benchmarking purposes.
+
+            # Disabling GradScaler because
+            # 1) The benchmark setup runs 2 iterations of fwd-bwd, so it is not useful.
+            # 2) The current setup shares grad_scaler between the eager and dynamo
+            #    models, which is bad as GradScaler has state and can adjust the
+            #    scaling factor between the eager and dynamo runs, making the
+            #    accuracy check harder.
+            # self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
+            self.autocast = functools.partial(
+                torch.amp.autocast, device_type=devices[0]
+            )
             if self.args.amp_dtype:
                 amp_dtype = (
                     torch.float16
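For reference, a minimal sketch of the standard loss-scaling loop that the disabled `GradScaler` code above refers to; the model, data, and loss function are hypothetical, and `init_scale=2.0` mirrors the commented-out benchmark setting:

```python
import torch

def train_step_amp(model, optimizer, loss_fn, inputs, targets, scaler):
    """One AMP training step with loss scaling to avoid gradient underflow."""
    optimizer.zero_grad(set_to_none=True)
    with torch.amp.autocast(device_type="cuda"):
        loss = loss_fn(model(inputs), targets)
    scaler.scale(loss).backward()  # backprop on the scaled loss
    scaler.step(optimizer)         # skips the step if grads contain inf/nan
    scaler.update()                # adjusts the scale factor for the next step
    return loss

if torch.cuda.is_available():
    model = torch.nn.Linear(16, 1).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = torch.nn.MSELoss()
    # Small init_scale: with only a couple of benchmark iterations, the default
    # of 65536 would never be adjusted down after overflowing gradients.
    scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
    inputs = torch.randn(32, 16, device="cuda")
    targets = torch.randn(32, 1, device="cuda")
    train_step_amp(model, optimizer, loss_fn, inputs, targets, scaler)
```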
