make torch.amp.autocast more generic
ghstack-source-id: 4e580070ec0ec666d1a2929ac588f91c6ac7af70
Pull Request resolved: #125103
guangyey committed Apr 29, 2024
1 parent 63ea78a commit fc055f1
Showing 6 changed files with 47 additions and 45 deletions.
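The diffs below replace device-specific autocast entry points (torch.cuda.amp.autocast, torch.cpu.amp.autocast) with the device-generic torch.amp.autocast API. As a rough illustration of the generic call shape (not part of this commit), a minimal sketch:

import torch

# Device-specific context managers used before this change:
#   torch.cuda.amp.autocast(...)   # CUDA only
#   torch.cpu.amp.autocast(...)    # CPU only
#
# Generic form: a single context manager parameterized by device_type.
with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
    x = torch.randn(8, 8)
    y = x @ x  # matmul runs in the autocast dtype where eligible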
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -227,6 +227,7 @@ namespace c10 {
_(aten, is_autocast_enabled) \
_(aten, is_autocast_cpu_enabled) \
_(aten, is_autocast_xla_enabled) \
_(aten, get_autocast_dtype) \
FORALL_ATEN_BASE_SYMBOLS(_) \
_(onnx, Add) \
_(onnx, Concat) \
49 changes: 24 additions & 25 deletions benchmarks/dynamo/common.py
@@ -2041,31 +2041,30 @@ def setup_amp(self, current_device=None):

devices = [current_device] if current_device else self.args.devices
if self.args.amp:
if devices == ["cuda"]:
# AMP training can lead to small loss values which can underflow
# gradient values, resulting in zero gradients. To solve this
# problem, PyTorch introduces GradScaler. GradScaler is a stateful
# structure that scales the loss values to prevent underflow. Loss
# values are big at the beginning of training (therefore not
# requiring scaling), while loss values tend to be small as the network
# starts getting better (requiring scaling). GradScaler manages all
# of this fine-tuning, checking whether gradients are turning to inf and
# discarding such batches.

# Since we are not running a long iteration, the default init_scale of
# 65536 is going to turn all gradients to inf. Therefore, we just use
# an init_scale of 2.0 for benchmarking purposes.

# Disabling GradScaler because
# 1) Benchmark setup runs 2 iterations of fwd-bwd. So, not useful.
# 2) The current setup shares grad_scaler between the eager and dynamo models,
# which is bad as GradScaler has state and can adjust the scaling
# factor between the eager and dynamo runs, making the accuracy check
# harder.
# self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
self.autocast = torch.cuda.amp.autocast
if devices == ["cpu"]:
self.autocast = torch.cpu.amp.autocast
# AMP training can lead to small loss values which can underflow
# gradient values, resulting in zero gradients. To solve this
# problem, PyTorch introduces GradScaler. GradScaler is a stateful
# structure that scales the loss values to prevent underflow. Loss
# values are big at the beginning of training (therefore not
# requiring scaling), while loss values tend to be small as the network
# starts getting better (requiring scaling). GradScaler manages all
# of this fine-tuning, checking whether gradients are turning to inf and
# discarding such batches.

# Since we are not running a long iteration, the default init_scale of
# 65536 is going to turn all gradients to inf. Therefore, we just use
# an init_scale of 2.0 for benchmarking purposes.

# Disabling GradScaler because
# 1) Benchmark setup runs 2 iterations of fwd-bwd. So, not useful.
# 2) The current setup shares grad_scaler between the eager and dynamo models,
# which is bad as GradScaler has state and can adjust the scaling
# factor between the eager and dynamo runs, making the accuracy check
# harder.
# self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
self.autocast = functools.partial(
torch.amp.autocast, device_type=devices[0]
)
if self.args.amp_dtype:
amp_dtype = (
torch.float16
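In the benchmark harness above, the per-device branches collapse into a single functools.partial that binds device_type into the generic autocast constructor. A minimal sketch of the same pattern (the make_autocast helper is hypothetical, for illustration only):

import functools

import torch

def make_autocast(device_type: str):
    # Bind the device so the result is called like the old
    # device-specific managers, e.g. autocast(dtype=torch.float16).
    return functools.partial(torch.amp.autocast, device_type=device_type)

autocast = make_autocast("cpu")  # or "cuda", or any backend with autocast support
with autocast(dtype=torch.bfloat16):
    out = torch.ones(4, 4) @ torch.ones(4, 4)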
16 changes: 4 additions & 12 deletions torch/_dynamo/output_graph.py
@@ -598,28 +598,20 @@ def save_global_state(self, out=None):
)
global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())

def autocast_specific_backend(
device_type: str, func: Callable[[str, Any], None]
):
def decorator(value):
return func(device_type, value)

return decorator

global_state["autocast_enabled"] = (
autocast_specific_backend("cuda", torch.set_autocast_enabled),
functools.partial(torch.set_autocast_enabled, "cuda"),
torch.is_autocast_enabled("cuda"),
)
global_state["autocast_cpu_enabled"] = (
autocast_specific_backend("cpu", torch.set_autocast_enabled),
functools.partial(torch.set_autocast_enabled, "cpu"),
torch.is_autocast_enabled("cpu"),
)
global_state["autocast_gpu_dtype"] = (
autocast_specific_backend("cuda", torch.set_autocast_dtype),
functools.partial(torch.set_autocast_dtype, "cuda"),
torch.get_autocast_dtype("cuda"),
)
global_state["autocast_cpu_dtype"] = (
autocast_specific_backend("cpu", torch.set_autocast_dtype),
functools.partial(torch.set_autocast_dtype, "cpu"),
torch.get_autocast_dtype("cpu"),
)
global_state["autocast_cache_enabled"] = (
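The removed autocast_specific_backend closure and functools.partial both produce a single-argument setter with the device baked in; the partial is simply shorter. A sketch of the equivalence, assuming only the string-based torch.set_autocast_enabled / torch.is_autocast_enabled overloads already used in the diff above:

import functools

import torch

def autocast_specific_backend(device_type, func):
    # Old helper: wrap func so the device is fixed and only the value varies.
    def decorator(value):
        return func(device_type, value)
    return decorator

set_old = autocast_specific_backend("cpu", torch.set_autocast_enabled)
set_new = functools.partial(torch.set_autocast_enabled, "cpu")

set_new(True)
assert torch.is_autocast_enabled("cpu")
set_old(False)
assert not torch.is_autocast_enabled("cpu")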
2 changes: 2 additions & 0 deletions torch/amp/autocast_mode.py
@@ -203,6 +203,8 @@ def __init__(
enabled: bool = True,
cache_enabled: Optional[bool] = None,
):
if dtype is None:
dtype = torch.get_autocast_dtype(device_type)
if torch._jit_internal.is_scripting():
self._enabled = enabled
self.device = device_type
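With the two added lines, omitting dtype makes torch.amp.autocast fall back to the backend's current autocast dtype via torch.get_autocast_dtype instead of requiring callers to pass one. A minimal sketch (the per-backend defaults are an assumption; typically bfloat16 on CPU and float16 on CUDA):

import torch

default_cpu = torch.get_autocast_dtype("cpu")  # typically torch.bfloat16

# dtype is omitted, so the context resolves it from the backend default.
with torch.amp.autocast(device_type="cpu"):
    y = torch.randn(4, 4) @ torch.randn(4, 4)

print(default_cpu, y.dtype)  # expected to match for autocast-eligible ops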
10 changes: 10 additions & 0 deletions torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -815,6 +815,16 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
push(stack, enabled);
},
aliasAnalysisConservative()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA(
"aten::get_autocast_dtype(str device_type) -> ScalarType"),
[](Stack& stack) {
at::DeviceType device_type =
at::Device(pop(stack).toStringRef()).type();
at::ScalarType dtype = at::autocast::get_autocast_dtype(device_type);
push(stack, dtype);
},
aliasAnalysisConservative()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("prim::Uninitialized() -> Any"),
unInitialized,
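Registering aten::get_autocast_dtype(str) -> ScalarType makes the dtype lookup used in autocast_mode.py above resolvable from TorchScript as well. A hedged sketch of exercising the operator through scripting (assuming torch.get_autocast_dtype is visible to the script compiler, which is what this registration provides):

import torch

@torch.jit.script
def autocast_dtype_for(device_type: str) -> torch.dtype:
    # Lowers to the aten::get_autocast_dtype op registered above.
    return torch.get_autocast_dtype(device_type)

print(autocast_dtype_for("cpu"))   # e.g. torch.bfloat16
print(autocast_dtype_for("cuda"))  # e.g. torch.float16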
14 changes: 6 additions & 8 deletions torch/utils/checkpoint.py
@@ -287,11 +287,10 @@ def backward(ctx, *args):
set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
detached_inputs = detach_variable(tuple(inputs))

device_autocast_ctx = device_module.amp.autocast(
**ctx.device_autocast_kwargs
device_autocast_ctx = torch.amp.autocast(
device_type=ctx.device, **ctx.device_autocast_kwargs
) if torch.amp.is_autocast_available(ctx.device) else contextlib.nullcontext()
with torch.enable_grad(), device_autocast_ctx, \
torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]
outputs = ctx.run_function(*detached_inputs)

if isinstance(outputs, torch.Tensor):
@@ -1394,11 +1393,10 @@ def recompute_fn(*inputs):
if had_device_in_fwd:
set_device_states(fwd_devices, fwd_device_states)

device_autocast_ctx = device_module.amp.autocast(
**device_autocast_kwargs
device_autocast_ctx = torch.amp.autocast(
device_type=device, **device_autocast_kwargs
) if torch.amp.is_autocast_available(device) else contextlib.nullcontext()
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
recompute_context:
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
fn(*args, **kwargs)

new_frame = _CheckpointFrame(
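Both checkpoint paths now build the device autocast context from the generic API behind a torch.amp.is_autocast_available guard instead of reaching into device_module.amp. A minimal sketch of that pattern (make_device_autocast_ctx is a hypothetical helper, not part of the commit):

import contextlib

import torch

def make_device_autocast_ctx(device_type: str, autocast_kwargs: dict):
    # Same guard as in checkpoint.py: fall back to a no-op context for
    # backends without autocast support.
    if torch.amp.is_autocast_available(device_type):
        return torch.amp.autocast(device_type=device_type, **autocast_kwargs)
    return contextlib.nullcontext()

ctx = make_device_autocast_ctx("cpu", {"enabled": True, "dtype": torch.bfloat16})
with ctx:
    out = torch.randn(2, 2) @ torch.randn(2, 2)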
