make torch.amp.autocast more generic #125103

Closed · wants to merge 12 commits
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -227,6 +227,7 @@ namespace c10 {
_(aten, is_autocast_enabled) \
_(aten, is_autocast_cpu_enabled) \
_(aten, is_autocast_xla_enabled) \
_(aten, get_autocast_dtype) \
FORALL_ATEN_BASE_SYMBOLS(_) \
_(onnx, Add) \
_(onnx, Concat) \
49 changes: 24 additions & 25 deletions benchmarks/dynamo/common.py
@@ -2052,31 +2052,30 @@ def setup_amp(self, current_device=None):

devices = [current_device] if current_device else self.args.devices
if self.args.amp:
if devices == ["cuda"]:
# AMP training can lead to small loss values which can underflow
# gradient values, resulting in zero gradients. To solve this
# problem, PyTorch introduces GradScaler. GradScaler is a stateful
# structure that scales the loss values to prevent underflow. Loss
# values are big at the beginning of training (therefore not
# requiring scaling), while loss values tend to be small as the
# network starts getting better (requiring scaling). GradScaler
# manages all of this fine-tuning, checking whether gradients are
# turning to inf and discarding such batches.

# Since we are not running long iterations, the default init_scale
# of 65536 is going to turn all gradients to inf. Therefore, we
# just use an init_scale of 2.0 for benchmarking purposes.

# Disabling GradScaler because
# 1) The benchmark setup runs 2 iterations of fwd-bwd, so it is not useful.
# 2) The current setup shares the grad_scaler between the eager and
#    dynamo models, which is bad as GradScaler has state and can adjust
#    the scaling factor between the eager and dynamo runs, making the
#    accuracy check harder.
# self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
self.autocast = torch.cuda.amp.autocast
if devices == ["cpu"]:
self.autocast = torch.cpu.amp.autocast
# AMP training can lead to small loss values which can underflow
# gradient values, resulting in zero gradients. To solve this
# problem, PyTorch introduces GradScaler. GradScaler is a stateful
# structure that scales the loss values to prevent underflow. Loss
# values are big at the beginning of training (therefore not
# requiring scaling), while loss values tend to be small as the
# network starts getting better (requiring scaling). GradScaler
# manages all of this fine-tuning, checking whether gradients are
# turning to inf and discarding such batches.

# Since we are not running long iterations, the default init_scale
# of 65536 is going to turn all gradients to inf. Therefore, we
# just use an init_scale of 2.0 for benchmarking purposes.

# Disabling GradScaler because
# 1) The benchmark setup runs 2 iterations of fwd-bwd, so it is not useful.
# 2) The current setup shares the grad_scaler between the eager and
#    dynamo models, which is bad as GradScaler has state and can adjust
#    the scaling factor between the eager and dynamo runs, making the
#    accuracy check harder.
# self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
self.autocast = functools.partial(
torch.amp.autocast, device_type=devices[0]
)
if self.args.amp_dtype:
amp_dtype = (
torch.float16
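The comment block above describes why gradient scaling exists and why it is disabled for this benchmark. As a minimal sketch (not part of this PR) of how GradScaler is normally wired into an AMP training step, using the torch.cuda.amp.GradScaler API the comment mentions; model, optimizer, and loss_fn are placeholders:

import torch

model = torch.nn.Linear(16, 4).cuda()  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

# Scales the loss so fp16 gradients don't underflow to zero; a low
# init_scale (2.0) follows the comment above, since the default 65536
# would overflow gradients to inf in a 2-iteration benchmark run.
scaler = torch.cuda.amp.GradScaler(init_scale=2.0)

x = torch.randn(8, 16, device="cuda")
target = torch.randn(8, 4, device="cuda")

with torch.amp.autocast(device_type="cuda"):
    loss = loss_fn(model(x), target)

scaler.scale(loss).backward()  # backward runs on the scaled loss
scaler.step(optimizer)         # unscales first; skips the step on inf/nan grads
scaler.update()                # adjusts the scale factor for the next iteration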
16 changes: 4 additions & 12 deletions torch/_dynamo/output_graph.py
@@ -600,28 +600,20 @@ def save_global_state(self, out=None):
)
global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())

def autocast_specific_backend(
    device_type: str, func: Callable[[str, Any], None]
):
    def decorator(value):
        return func(device_type, value)

    return decorator

Collaborator Author: code improvements.

global_state["autocast_enabled"] = (
autocast_specific_backend("cuda", torch.set_autocast_enabled),
functools.partial(torch.set_autocast_enabled, "cuda"),
torch.is_autocast_enabled("cuda"),
)
global_state["autocast_cpu_enabled"] = (
autocast_specific_backend("cpu", torch.set_autocast_enabled),
functools.partial(torch.set_autocast_enabled, "cpu"),
torch.is_autocast_enabled("cpu"),
)
global_state["autocast_gpu_dtype"] = (
autocast_specific_backend("cuda", torch.set_autocast_dtype),
functools.partial(torch.set_autocast_dtype, "cuda"),
torch.get_autocast_dtype("cuda"),
)
global_state["autocast_cpu_dtype"] = (
autocast_specific_backend("cpu", torch.set_autocast_dtype),
functools.partial(torch.set_autocast_dtype, "cpu"),
torch.get_autocast_dtype("cpu"),
)
global_state["autocast_cache_enabled"] = (
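The diff above replaces the local autocast_specific_backend decorator with functools.partial, binding the device string as the first argument of the new device-generic setters. A minimal sketch of the resulting save/restore pattern (variable names are illustrative, not the exact output_graph.py code):

import functools
import torch

# Save: pair each setter (device pre-bound via partial) with the value to restore.
saved_state = {
    "autocast_cpu_enabled": (
        functools.partial(torch.set_autocast_enabled, "cpu"),
        torch.is_autocast_enabled("cpu"),
    ),
    "autocast_cpu_dtype": (
        functools.partial(torch.set_autocast_dtype, "cpu"),
        torch.get_autocast_dtype("cpu"),
    ),
}

# Restore: every entry is invoked uniformly, regardless of device.
for setter, value in saved_state.values():
    setter(value)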
2 changes: 2 additions & 0 deletions torch/amp/autocast_mode.py
@@ -207,6 +207,8 @@ def __init__(
raise ValueError(
f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
)
if dtype is None:
dtype = torch.get_autocast_dtype(device_type)
Comment on lines +213 to +214:

Collaborator: is it covered by existing UTs?

Collaborator Author: Added two new UTs to cover this change in the eager and JIT paths respectively.

Collaborator: We should update the doc to mention the new default value for this arg.

Collaborator Author: Updated.

if torch._jit_internal.is_scripting():
self._enabled = enabled
self.device = device_type
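With the two added lines above, passing dtype=None (the new default) resolves to the device's current autocast dtype via torch.get_autocast_dtype. A small illustration; the printed dtype assumes the stock CPU autocast default of torch.bfloat16:

import torch

# dtype is omitted, so it defaults to torch.get_autocast_dtype("cpu").
with torch.amp.autocast(device_type="cpu"):
    a = torch.randn(4, 4)
    b = torch.randn(4, 4)
    print(torch.mm(a, b).dtype)  # torch.bfloat16 unless the default was changed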
15 changes: 15 additions & 0 deletions torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -815,6 +815,21 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
push(stack, enabled);
},
aliasAnalysisConservative()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA(
"aten::get_autocast_dtype(str device_type) -> ScalarType"),
[](Stack& stack) {
#if defined BUILD_LITE_INTERPRETER || defined C10_MOBILE
// autocast is not supported.
at::ScalarType dtype = at::ScalarType::Undefined;
#else
at::DeviceType device_type =
at::Device(pop(stack).toStringRef()).type();
at::ScalarType dtype = at::autocast::get_autocast_dtype(device_type);
#endif
push(stack, dtype);
},
aliasAnalysisConservative()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("prim::Uninitialized() -> Any"),
unInitialized,
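Registering aten::get_autocast_dtype as a prim op is what makes the dtype query visible to TorchScript (the JIT path the author mentions adding a UT for). A sketch of what this enables, assuming a standard non-mobile build:

import torch

@torch.jit.script
def autocast_dtype(device_type: str) -> torch.dtype:
    # Lowers to the aten::get_autocast_dtype op registered above.
    return torch.get_autocast_dtype(device_type)

print(autocast_dtype("cpu"))  # torch.bfloat16 by default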
14 changes: 6 additions & 8 deletions torch/utils/checkpoint.py
@@ -288,11 +288,10 @@ def backward(ctx, *args):
set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
detached_inputs = detach_variable(tuple(inputs))

device_autocast_ctx = device_module.amp.autocast(
**ctx.device_autocast_kwargs
device_autocast_ctx = torch.amp.autocast(
device_type=ctx.device, **ctx.device_autocast_kwargs
) if torch.amp.is_autocast_available(ctx.device) else contextlib.nullcontext()
with torch.enable_grad(), device_autocast_ctx, \
torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]
outputs = ctx.run_function(*detached_inputs)

if isinstance(outputs, torch.Tensor):
@@ -1395,11 +1394,10 @@ def recompute_fn(*inputs):
if had_device_in_fwd:
set_device_states(fwd_devices, fwd_device_states)

device_autocast_ctx = device_module.amp.autocast(
**device_autocast_kwargs
device_autocast_ctx = torch.amp.autocast(
device_type=device, **device_autocast_kwargs
) if torch.amp.is_autocast_available(device) else contextlib.nullcontext()
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
recompute_context:
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
Collaborator: So we gather and restore both the CPU context and another device's context here? This makes the code a bit weird, but sounds fair. We definitely don't want to change the behavior here.

cc @soulitzer in case this is something you want to clean up for AC in general in a follow-up, now that we have the nice API.

Collaborator Author: We don't change the behavior here; we just use torch.amp.autocast to make the code more generic and leave the logic as it is.

Collaborator: yep, perfect!

fn(*args, **kwargs)

new_frame = _CheckpointFrame(
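Both checkpoint call sites now build the context the same way: torch.amp.autocast with an explicit device_type when autocast is available, else a null context. A minimal sketch of that pattern factored into a helper (make_autocast_ctx is hypothetical, not part of the PR):

import contextlib
import torch

def make_autocast_ctx(device: str, autocast_kwargs: dict):
    # Device-generic replacement for device_module.amp.autocast(**kwargs);
    # falls back to a no-op context when the device lacks autocast support.
    if torch.amp.is_autocast_available(device):
        return torch.amp.autocast(device_type=device, **autocast_kwargs)
    return contextlib.nullcontext()

# e.g. re-entering a saved CUDA autocast state around recomputation:
with make_autocast_ctx("cuda", {"enabled": True, "dtype": torch.float16}):
    pass  # run_function(*detached_inputs) would go here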