make torch.amp.autocast more generic
ghstack-source-id: 31ac5e57dfcbf817bdbe1e45af5c21cb08d43050
Pull Request resolved: #125103
guangyey committed May 6, 2024
1 parent 68a1f78 commit a4f128e
Showing 6 changed files with 52 additions and 45 deletions.
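The net effect of the commit: call sites no longer need the per-backend entry points (torch.cuda.amp.autocast, torch.cpu.amp.autocast); torch.amp.autocast takes a device_type, and the dtype may now be omitted and resolved via the new get_autocast_dtype query. A minimal before/after sketch (the device string and dtype are illustrative and assume a CUDA-capable setup):

```python
import torch

# Before this change, callers reached for per-backend context managers:
with torch.cuda.amp.autocast(dtype=torch.float16):
    ...

# After it, one generic entry point covers every backend; dtype may be
# omitted and falls back to torch.get_autocast_dtype(device_type).
with torch.amp.autocast(device_type="cuda"):
    ...
```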
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -227,6 +227,7 @@ namespace c10 {
_(aten, is_autocast_enabled) \
_(aten, is_autocast_cpu_enabled) \
_(aten, is_autocast_xla_enabled) \
_(aten, get_autocast_dtype) \
FORALL_ATEN_BASE_SYMBOLS(_) \
_(onnx, Add) \
_(onnx, Concat) \
49 changes: 24 additions & 25 deletions benchmarks/dynamo/common.py
@@ -2052,31 +2052,30 @@ def setup_amp(self, current_device=None):

devices = [current_device] if current_device else self.args.devices
if self.args.amp:
if devices == ["cuda"]:
# AMP training can lead to small loss values which can underflow
# gradient values, resulting in zero gradients. To solve this
# problem, PyTorch introduces GradScaler. GradScaler is a stateful
# structure, that scales the loss values to prevent underflow. Loss
# values are big at the beginning of training (therefore not
# requiring scaling), while loss value tends to be small as network
# starts getting better (requiring scaling). GradScaler manages all
# of this fine tuning, checking the gradients are turning to inf,
# discarding such batches.

# Since we are not running a long iteration, default value of
# init_scale 65536 is going to turn all gradients to inf. Therefore,
# we just use an init_scale of 2.0 for benchmarking purposes.

# Disabling Gradscaler because
# 1) Benchmark setup runs 2 iterations of fwd-bwd. So, not useful.
# 2) Current setup shares grad_scaler for eager and dynamo model,
# which is bad as Gradscaler has state and can adjust the scaling
# factor between eager and dynamo run, making accuracy check
# harder.
# self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
self.autocast = torch.cuda.amp.autocast
if devices == ["cpu"]:
self.autocast = torch.cpu.amp.autocast
# AMP training can lead to small loss values which can underflow
# gradient values, resulting in zero gradients. To solve this
# problem, PyTorch introduces GradScaler. GradScaler is a stateful
# structure, that scales the loss values to prevent underflow. Loss
# values are big at the beginning of training (therefore not
# requiring scaling), while loss value tends to be small as network
# starts getting better (requiring scaling). GradScaler manages all
# of this fine tuning, checking the gradients are turning to inf,
# discarding such batches.

# Since we are not running a long iteration, default value of
# init_scale 65536 is going to turn all gradients to inf. Therefore,
# we just use an init_scale of 2.0 for benchmarking purposes.

# Disabling Gradscaler because
# 1) Benchmark setup runs 2 iterations of fwd-bwd. So, not useful.
# 2) Current setup shares grad_scaler for eager and dynamo model,
# which is bad as Gradscaler has state and can adjust the scaling
# factor between eager and dynamo run, making accuracy check
# harder.
# self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
self.autocast = functools.partial(
torch.amp.autocast, device_type=devices[0]
)
if self.args.amp_dtype:
amp_dtype = (
torch.float16
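The benchmark harness now builds a single device-agnostic autocast factory instead of branching on the device string. A short sketch of the same pattern; the "cuda" value stands in for devices[0] in the harness and could equally be "cpu":

```python
import functools
import torch

# Pre-bind the device so every later call only supplies per-run options.
autocast = functools.partial(torch.amp.autocast, device_type="cuda")

with autocast(dtype=torch.float16):
    ...  # model forward/backward under mixed precision
```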
16 changes: 4 additions & 12 deletions torch/_dynamo/output_graph.py
@@ -600,28 +600,20 @@ def save_global_state(self, out=None):
)
global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())

def autocast_specific_backend(
device_type: str, func: Callable[[str, Any], None]
):
def decorator(value):
return func(device_type, value)

return decorator

global_state["autocast_enabled"] = (
autocast_specific_backend("cuda", torch.set_autocast_enabled),
functools.partial(torch.set_autocast_enabled, "cuda"),
torch.is_autocast_enabled("cuda"),
)
global_state["autocast_cpu_enabled"] = (
autocast_specific_backend("cpu", torch.set_autocast_enabled),
functools.partial(torch.set_autocast_enabled, "cpu"),
torch.is_autocast_enabled("cpu"),
)
global_state["autocast_gpu_dtype"] = (
autocast_specific_backend("cuda", torch.set_autocast_dtype),
functools.partial(torch.set_autocast_dtype, "cuda"),
torch.get_autocast_dtype("cuda"),
)
global_state["autocast_cpu_dtype"] = (
autocast_specific_backend("cpu", torch.set_autocast_dtype),
functools.partial(torch.set_autocast_dtype, "cpu"),
torch.get_autocast_dtype("cpu"),
)
global_state["autocast_cache_enabled"] = (
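The bespoke autocast_specific_backend decorator is replaced by functools.partial over the device-aware setters, paired with the new device-aware getters. A condensed sketch of the save/restore pattern the graph uses (dictionary keys and the "cuda" device are illustrative):

```python
import functools
import torch

# Each piece of autocast state is saved as a (restore_fn, value) pair;
# functools.partial pre-binds the device string, so restore_fn(value)
# puts that backend back into its recorded state.
saved_state = {
    "autocast_enabled": (
        functools.partial(torch.set_autocast_enabled, "cuda"),
        torch.is_autocast_enabled("cuda"),
    ),
    "autocast_gpu_dtype": (
        functools.partial(torch.set_autocast_dtype, "cuda"),
        torch.get_autocast_dtype("cuda"),
    ),
}

# ... code that may flip autocast settings runs here ...

for restore_fn, value in saved_state.values():
    restore_fn(value)
```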
2 changes: 2 additions & 0 deletions torch/amp/autocast_mode.py
@@ -207,6 +207,8 @@ def __init__(
raise ValueError(
f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
)
if dtype is None:
dtype = torch.get_autocast_dtype(device_type)
if torch._jit_internal.is_scripting():
self._enabled = enabled
self.device = device_type
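With this default in place, constructing autocast without an explicit dtype resolves to whatever torch.get_autocast_dtype reports for that backend. A small sketch; the printed dtype assumes the stock CPU autocast default (bfloat16) has not been changed:

```python
import torch

with torch.amp.autocast(device_type="cpu"):  # dtype intentionally omitted
    a = torch.randn(4, 4)
    b = torch.mm(a, a)

print(b.dtype)  # torch.bfloat16 under the default CPU autocast dtype
```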
15 changes: 15 additions & 0 deletions torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -815,6 +815,21 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
push(stack, enabled);
},
aliasAnalysisConservative()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA(
"aten::get_autocast_dtype(str device_type) -> ScalarType"),
[](Stack& stack) {
#if defined BUILD_LITE_INTERPRETER || defined C10_MOBILE
// autocast is not supported.
at::ScalarType dtype = at::ScalarType::Undefined;
#else
at::DeviceType device_type =
at::Device(pop(stack).toStringRef()).type();
at::ScalarType dtype = at::autocast::get_autocast_dtype(device_type);
#endif
push(stack, dtype);
},
aliasAnalysisConservative()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("prim::Uninitialized() -> Any"),
unInitialized,
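Registering aten::get_autocast_dtype exposes the per-device dtype query to the JIT runtime, which is what keeps scripted autocast (and the new dtype default in autocast_mode.py) working. The eager-mode equivalent, as a hedged illustration; the printed values assume the stock defaults and may differ if a backend's autocast dtype has been changed:

```python
import torch

# The same query the new operator performs, from eager Python.
print(torch.get_autocast_dtype("cpu"))   # torch.bfloat16 by default
print(torch.get_autocast_dtype("cuda"))  # torch.float16 by default
```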
14 changes: 6 additions & 8 deletions torch/utils/checkpoint.py
@@ -288,11 +288,10 @@ def backward(ctx, *args):
set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
detached_inputs = detach_variable(tuple(inputs))

device_autocast_ctx = device_module.amp.autocast(
**ctx.device_autocast_kwargs
device_autocast_ctx = torch.amp.autocast(
device_type=ctx.device, **ctx.device_autocast_kwargs
) if torch.amp.is_autocast_available(ctx.device) else contextlib.nullcontext()
with torch.enable_grad(), device_autocast_ctx, \
torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]
outputs = ctx.run_function(*detached_inputs)

if isinstance(outputs, torch.Tensor):
@@ -1395,11 +1394,10 @@ def recompute_fn(*inputs):
if had_device_in_fwd:
set_device_states(fwd_devices, fwd_device_states)

device_autocast_ctx = device_module.amp.autocast(
**device_autocast_kwargs
device_autocast_ctx = torch.amp.autocast(
device_type=device, **device_autocast_kwargs
) if torch.amp.is_autocast_available(device) else contextlib.nullcontext()
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
recompute_context:
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
fn(*args, **kwargs)

new_frame = _CheckpointFrame(
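Both checkpoint paths now build the context through the generic torch.amp.autocast, guarded by torch.amp.is_autocast_available. A minimal sketch of that guard as a helper; the helper name and the kwargs are illustrative (the real kwargs come from the saved ctx.device_autocast_kwargs):

```python
import contextlib
import torch

def make_autocast_ctx(device_type: str, **autocast_kwargs):
    # Mirror the guard in checkpoint's backward/recompute paths: use the
    # generic torch.amp.autocast when the backend supports autocast,
    # otherwise fall back to a no-op context.
    if torch.amp.is_autocast_available(device_type):
        return torch.amp.autocast(device_type=device_type, **autocast_kwargs)
    return contextlib.nullcontext()

with make_autocast_ctx("cuda", dtype=torch.float16):
    pass  # recomputation would run here
```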
