make torch.amp.autocast more generic
ghstack-source-id: 1cbc46833c8bcb946e0eff23f4311e766631d9c2
Pull Request resolved: #125103
guangyey committed Apr 28, 2024
1 parent 63ea78a commit 345c080
Showing 4 changed files with 34 additions and 45 deletions.
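For readers skimming the diff, here is a minimal sketch (not part of the commit) of the API direction it standardizes on: the backend-specific context managers torch.cuda.amp.autocast and torch.cpu.amp.autocast give way to the single torch.amp.autocast entry point parameterized by device_type. The model and tensor names below are illustrative.

    import torch

    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.nn.Linear(8, 8).to(device_type)
    x = torch.randn(4, 8, device=device_type)

    # Old, backend-specific style: torch.cuda.amp.autocast() / torch.cpu.amp.autocast()
    # New, generic style: one context manager for any backend that supports autocast.
    # dtype may be omitted; with this commit it falls back to
    # torch.get_autocast_dtype(device_type).
    with torch.amp.autocast(device_type=device_type):
        y = model(x)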
benchmarks/dynamo/common.py (22 additions, 25 deletions)
@@ -2041,31 +2041,28 @@ def setup_amp(self, current_device=None):

         devices = [current_device] if current_device else self.args.devices
         if self.args.amp:
-            if devices == ["cuda"]:
-                # AMP training can lead to small loss values which can underflow
-                # gradient values, resulting in zero gradients. To solve this
-                # problem, PyTorch introduces GradScaler. GradScaler is a stateful
-                # structure that scales the loss values to prevent underflow. Loss
-                # values are big at the beginning of training (therefore not
-                # requiring scaling), while loss values tend to be small as the
-                # network starts getting better (requiring scaling). GradScaler
-                # manages all of this fine-tuning, checking whether gradients are
-                # turning to inf and discarding such batches.
-
-                # Since we are not running a long iteration, the default
-                # init_scale of 65536 is going to turn all gradients to inf.
-                # Therefore, we just use an init_scale of 2.0 for benchmarking purposes.
-
-                # Disabling GradScaler because
-                # 1) Benchmark setup runs 2 iterations of fwd-bwd, so it is not useful.
-                # 2) Current setup shares grad_scaler for the eager and dynamo models,
-                #    which is bad as GradScaler has state and can adjust the scaling
-                #    factor between the eager and dynamo runs, making the accuracy check
-                #    harder.
-                # self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
-                self.autocast = torch.cuda.amp.autocast
-            if devices == ["cpu"]:
-                self.autocast = torch.cpu.amp.autocast
+            # AMP training can lead to small loss values which can underflow
+            # gradient values, resulting in zero gradients. To solve this
+            # problem, PyTorch introduces GradScaler. GradScaler is a stateful
+            # structure that scales the loss values to prevent underflow. Loss
+            # values are big at the beginning of training (therefore not
+            # requiring scaling), while loss values tend to be small as the
+            # network starts getting better (requiring scaling). GradScaler
+            # manages all of this fine-tuning, checking whether gradients are
+            # turning to inf and discarding such batches.
+
+            # Since we are not running a long iteration, the default
+            # init_scale of 65536 is going to turn all gradients to inf.
+            # Therefore, we just use an init_scale of 2.0 for benchmarking purposes.
+
+            # Disabling GradScaler because
+            # 1) Benchmark setup runs 2 iterations of fwd-bwd, so it is not useful.
+            # 2) Current setup shares grad_scaler for the eager and dynamo models,
+            #    which is bad as GradScaler has state and can adjust the scaling
+            #    factor between the eager and dynamo runs, making the accuracy check
+            #    harder.
+            # self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
+            self.autocast = functools.partial(torch.amp.autocast, device_type=devices[0])
             if self.args.amp_dtype:
                 amp_dtype = (
                     torch.float16
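The functools.partial line above keeps the benchmark's existing call sites unchanged: self.autocast(...) still behaves like the old backend-specific context manager, with the device bound once. A minimal sketch of that pattern, with illustrative names (devices, amp_dtype):

    import functools

    import torch

    devices = ["cuda" if torch.cuda.is_available() else "cpu"]
    autocast = functools.partial(torch.amp.autocast, device_type=devices[0])

    # Call sites look exactly like the old torch.cuda.amp.autocast usage.
    amp_dtype = torch.float16 if devices[0] == "cuda" else torch.bfloat16
    with autocast(dtype=amp_dtype):
        a = torch.randn(8, 8, device=devices[0])
        b = torch.randn(8, 8, device=devices[0])
        loss = (a @ b).sum()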
torch/_dynamo/output_graph.py (4 additions, 12 deletions)
@@ -598,28 +598,20 @@ def save_global_state(self, out=None):
         )
         global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())
 
-        def autocast_specific_backend(
-            device_type: str, func: Callable[[str, Any], None]
-        ):
-            def decorator(value):
-                return func(device_type, value)
-
-            return decorator
-
         global_state["autocast_enabled"] = (
-            autocast_specific_backend("cuda", torch.set_autocast_enabled),
+            functools.partial(torch.set_autocast_enabled, "cuda"),
             torch.is_autocast_enabled("cuda"),
         )
         global_state["autocast_cpu_enabled"] = (
-            autocast_specific_backend("cpu", torch.set_autocast_enabled),
+            functools.partial(torch.set_autocast_enabled, "cpu"),
             torch.is_autocast_enabled("cpu"),
         )
         global_state["autocast_gpu_dtype"] = (
-            autocast_specific_backend("cuda", torch.set_autocast_dtype),
+            functools.partial(torch.set_autocast_dtype, "cuda"),
             torch.get_autocast_dtype("cuda"),
         )
         global_state["autocast_cpu_dtype"] = (
-            autocast_specific_backend("cpu", torch.set_autocast_dtype),
+            functools.partial(torch.set_autocast_dtype, "cpu"),
             torch.get_autocast_dtype("cpu"),
         )
         global_state["autocast_cache_enabled"] = (
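The change above swaps a hand-rolled closure for functools.partial: each saved entry remains a (setter, current value) pair, now built from the device-aware torch.set_autocast_* / torch.is_autocast_* / torch.get_autocast_dtype overloads the diff already uses. A rough sketch of how such pairs restore state later (the dictionary and loop are illustrative, not the Dynamo source):

    import functools

    import torch

    saved_state = {
        "autocast_enabled": (
            functools.partial(torch.set_autocast_enabled, "cuda"),
            torch.is_autocast_enabled("cuda"),
        ),
        "autocast_gpu_dtype": (
            functools.partial(torch.set_autocast_dtype, "cuda"),
            torch.get_autocast_dtype("cuda"),
        ),
    }

    # Restoring is uniform: call each bound setter with the saved value.
    for setter, value in saved_state.values():
        setter(value)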
torch/amp/autocast_mode.py (2 additions, 0 deletions)
@@ -203,6 +203,8 @@ def __init__(
         enabled: bool = True,
         cache_enabled: Optional[bool] = None,
     ):
+        if dtype is None:
+            dtype = torch.get_autocast_dtype(device_type)
         if torch._jit_internal.is_scripting():
             self._enabled = enabled
             self.device = device_type
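The two added lines give dtype a backend-aware default instead of requiring callers to pass it. A small sketch of the resulting behavior; the printed defaults are the usual ones (torch.bfloat16 for CPU, torch.float16 for CUDA) unless previously changed via torch.set_autocast_dtype:

    import torch

    # With this change, omitting dtype is equivalent to passing
    # torch.get_autocast_dtype(device_type) explicitly.
    print(torch.get_autocast_dtype("cpu"))   # typically torch.bfloat16
    print(torch.get_autocast_dtype("cuda"))  # typically torch.float16

    with torch.amp.autocast("cpu"):
        z = torch.randn(4, 4) @ torch.randn(4, 4)
    print(z.dtype)  # matmul runs in the CPU autocast dtype, e.g. torch.bfloat16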
torch/utils/checkpoint.py (6 additions, 8 deletions)
@@ -287,11 +287,10 @@ def backward(ctx, *args):
                     set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
             detached_inputs = detach_variable(tuple(inputs))
 
-            device_autocast_ctx = device_module.amp.autocast(
-                **ctx.device_autocast_kwargs
+            device_autocast_ctx = torch.amp.autocast(
+                device_type=ctx.device, **ctx.device_autocast_kwargs
             ) if torch.amp.is_autocast_available(ctx.device) else contextlib.nullcontext()
-            with torch.enable_grad(), device_autocast_ctx, \
-                    torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
+            with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
                 outputs = ctx.run_function(*detached_inputs)
 
         if isinstance(outputs, torch.Tensor):
@@ -1394,11 +1393,10 @@ def recompute_fn(*inputs):
                 if had_device_in_fwd:
                     set_device_states(fwd_devices, fwd_device_states)
 
-            device_autocast_ctx = device_module.amp.autocast(
-                **device_autocast_kwargs
+            device_autocast_ctx = torch.amp.autocast(
+                device_type=device, **device_autocast_kwargs
             ) if torch.amp.is_autocast_available(device) else contextlib.nullcontext()
-            with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
-                    recompute_context:
+            with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
                 fn(*args, **kwargs)
 
     new_frame = _CheckpointFrame(
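Both checkpoint hunks follow the same guarded, device-generic pattern: build an autocast context from torch.amp.autocast when the backend supports autocast, and otherwise fall back to a null context. A minimal standalone sketch, with device and the kwargs chosen purely for illustration:

    import contextlib

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    device_autocast_kwargs = {"enabled": True, "dtype": torch.get_autocast_dtype(device)}

    # Guarded, device-generic autocast context, as in the diff above.
    device_autocast_ctx = (
        torch.amp.autocast(device_type=device, **device_autocast_kwargs)
        if torch.amp.is_autocast_available(device)
        else contextlib.nullcontext()
    )

    with torch.enable_grad(), device_autocast_ctx:
        out = (torch.randn(4, 4, requires_grad=True) @ torch.randn(4, 4)).sum()
    out.backward()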
