make torch.amp.autocast more generic
ghstack-source-id: 62a264b2a9033119fee7ebd49a93b6252ba2c89f
Pull Request resolved: #125103
guangyey committed May 7, 2024
1 parent 5033d3b commit eb8c451
Showing 9 changed files with 86 additions and 52 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -227,6 +227,7 @@ namespace c10 {
_(aten, is_autocast_enabled) \
_(aten, is_autocast_cpu_enabled) \
_(aten, is_autocast_xla_enabled) \
_(aten, get_autocast_dtype) \
FORALL_ATEN_BASE_SYMBOLS(_) \
_(onnx, Add) \
_(onnx, Concat) \
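The newly interned `get_autocast_dtype` symbol lets the JIT refer to the device-generic dtype query that the rest of this commit wires up. A minimal sketch of the Python-level query it backs, assuming a build that contains this change:

```python
import torch

# Query the per-device-type autocast dtype through the generic API.
print(torch.get_autocast_dtype("cpu"))   # torch.bfloat16 unless it was changed
print(torch.get_autocast_dtype("cuda"))  # torch.float16 unless it was changed
```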
49 changes: 24 additions & 25 deletions benchmarks/dynamo/common.py
@@ -2052,31 +2052,30 @@ def setup_amp(self, current_device=None):

devices = [current_device] if current_device else self.args.devices
if self.args.amp:
if devices == ["cuda"]:
# AMP training can lead to small loss values which can underflow
# gradient values, resulting in zero gradients. To solve this
# problem, PyTorch introduces GradScaler. GradScaler is a stateful
# structure that scales the loss values to prevent underflow. Loss
# values are big at the beginning of training (therefore not
# requiring scaling), while loss values tend to be small as the network
# starts getting better (requiring scaling). GradScaler manages all
# of this fine-tuning, checking whether gradients are turning to inf
# and discarding such batches.

# Since we are not running long iterations, the default init_scale
# of 65536 is going to turn all gradients to inf. Therefore, we just
# use an init_scale of 2.0 for benchmarking purposes.

# Disabling GradScaler because
# 1) The benchmark setup runs 2 iterations of fwd-bwd, so it is not useful.
# 2) The current setup shares the grad_scaler between the eager and dynamo
#    models, which is bad as GradScaler has state and can adjust the scaling
#    factor between the eager and dynamo runs, making accuracy checks
#    harder.
# self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
self.autocast = torch.cuda.amp.autocast
if devices == ["cpu"]:
self.autocast = torch.cpu.amp.autocast
# AMP training can lead to small loss values which can underflow
# gradient values, resulting in zero gradients. To solve this
# problem, PyTorch introduces GradScaler. GradScaler is a stateful
# structure that scales the loss values to prevent underflow. Loss
# values are big at the beginning of training (therefore not
# requiring scaling), while loss values tend to be small as the network
# starts getting better (requiring scaling). GradScaler manages all
# of this fine-tuning, checking whether gradients are turning to inf
# and discarding such batches.

# Since we are not running long iterations, the default init_scale
# of 65536 is going to turn all gradients to inf. Therefore, we just
# use an init_scale of 2.0 for benchmarking purposes.

# Disabling GradScaler because
# 1) The benchmark setup runs 2 iterations of fwd-bwd, so it is not useful.
# 2) The current setup shares the grad_scaler between the eager and dynamo
#    models, which is bad as GradScaler has state and can adjust the scaling
#    factor between the eager and dynamo runs, making accuracy checks
#    harder.
# self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
self.autocast = functools.partial(
torch.amp.autocast, device_type=devices[0]
)
if self.args.amp_dtype:
amp_dtype = (
torch.float16
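A hedged sketch of the pattern adopted above: binding `device_type` once with `functools.partial` gives a single code path for any backend instead of branching on the device string (the device and dtype choices below are illustrative):

```python
import functools

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# One generic constructor replaces the torch.cuda.amp / torch.cpu.amp branches.
autocast = functools.partial(torch.amp.autocast, device_type=device)

with autocast(dtype=torch.float16 if device == "cuda" else torch.bfloat16):
    x = torch.randn(8, 8, device=device)
    y = x @ x  # eligible ops run in the reduced-precision dtype
```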
9 changes: 9 additions & 0 deletions test/test_autocast.py
@@ -244,6 +244,15 @@ def test_autocast_disabled_with_fp32_dtype(self):
with torch.autocast(device_type="cpu", dtype=torch.float32, enabled=False):
_ = torch.ones(10)

def test_generic_autocast(self):
for op_with_args in self.autocast_lists.torch_16:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
with torch.amp.autocast(device_type="cpu"):
generic_autocast_output = getattr(torch, op)(*args, **maybe_kwargs)
with torch.cpu.amp.autocast():
cpu_autocast_output = getattr(torch, op)(*args, **maybe_kwargs)
self.assertEqual(generic_autocast_output, cpu_autocast_output)


class CustomLinear(torch.autograd.Function):
@staticmethod
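The new test checks that the generic entry point matches the legacy per-device one on CPU; a minimal standalone version of the same comparison (the op choice is illustrative):

```python
import torch

a = torch.randn(4, 4)

with torch.amp.autocast(device_type="cpu"):
    generic = torch.mm(a, a)

with torch.cpu.amp.autocast():
    legacy = torch.mm(a, a)

# Both context managers should lower mm to the same autocast dtype on CPU.
assert generic.dtype == legacy.dtype
assert torch.equal(generic, legacy)
```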
17 changes: 17 additions & 0 deletions test/test_jit_autocast.py
@@ -33,6 +33,23 @@ def tearDown(self):
torch._C._jit_set_autocast_mode(self.old_value)
super().tearDown()

@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_generic_jit_autocast(self):
@torch.jit.script
def fn_cuda(a, b):
with autocast():
x = torch.mm(a, b)
y = torch.sum(x)
return x, y

@torch.jit.script
def fn_generic(a, b):
with torch.amp.autocast(device_type='cpu'):
x = torch.mm(a, b)
y = torch.sum(x)
return x, y
self.assertEqual(fn_cuda(self.a_fp32, self.b_fp32), fn_generic(self.a_fp32, self.b_fp32))

@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_minimal(self):
@torch.jit.script
16 changes: 4 additions & 12 deletions torch/_dynamo/output_graph.py
@@ -599,28 +599,20 @@ def save_global_state(self, out=None):
)
global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())

def autocast_specific_backend(
device_type: str, func: Callable[[str, Any], None]
):
def decorator(value):
return func(device_type, value)

return decorator

global_state["autocast_enabled"] = (
autocast_specific_backend("cuda", torch.set_autocast_enabled),
functools.partial(torch.set_autocast_enabled, "cuda"),
torch.is_autocast_enabled("cuda"),
)
global_state["autocast_cpu_enabled"] = (
autocast_specific_backend("cpu", torch.set_autocast_enabled),
functools.partial(torch.set_autocast_enabled, "cpu"),
torch.is_autocast_enabled("cpu"),
)
global_state["autocast_gpu_dtype"] = (
autocast_specific_backend("cuda", torch.set_autocast_dtype),
functools.partial(torch.set_autocast_dtype, "cuda"),
torch.get_autocast_dtype("cuda"),
)
global_state["autocast_cpu_dtype"] = (
autocast_specific_backend("cpu", torch.set_autocast_dtype),
functools.partial(torch.set_autocast_dtype, "cpu"),
torch.get_autocast_dtype("cpu"),
)
global_state["autocast_cache_enabled"] = (
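The refactor above swaps a one-off decorator factory for `functools.partial`; both produce a callable with the device string pre-bound, as this small sketch illustrates (the device choice is illustrative):

```python
import functools

import torch

def autocast_specific_backend(device_type, func):
    # The old helper: a closure that forwards device_type plus one value.
    def decorator(value):
        return func(device_type, value)
    return decorator

old_setter = autocast_specific_backend("cpu", torch.set_autocast_enabled)
new_setter = functools.partial(torch.set_autocast_enabled, "cpu")

old_setter(True)
assert torch.is_autocast_enabled("cpu")
new_setter(False)
assert not torch.is_autocast_enabled("cpu")
```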
3 changes: 2 additions & 1 deletion torch/amp/autocast_mode.py
@@ -207,11 +207,12 @@ def __init__(
raise ValueError(
f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
)
if dtype is None:
dtype = torch.get_autocast_dtype(device_type)
if torch._jit_internal.is_scripting():
self._enabled = enabled
self.device = device_type
self.fast_dtype = dtype
# TODO: support get_autocast_gpu/cpu_dtype
assert dtype is not None
return
self.device = device_type
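With this change, omitting `dtype` makes the context manager fall back to the device's current autocast dtype via `torch.get_autocast_dtype`, including on the scripting path; a minimal illustration:

```python
import torch

# No dtype given: falls back to torch.get_autocast_dtype("cpu"),
# which is torch.bfloat16 unless it has been changed.
with torch.amp.autocast(device_type="cpu"):
    out = torch.mm(torch.randn(4, 4), torch.randn(4, 4))
print(out.dtype)

# An explicit dtype is still honored as-is.
with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
    pass
```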
14 changes: 8 additions & 6 deletions torch/csrc/jit/passes/autocast.cpp
@@ -96,17 +96,19 @@ c10::optional<AutocastScope> parseAutocast(
use.user->s(attr::name) == "fast_dtype") {
// Search for `prim::SetAttr[name="fast_dtype"]`
auto ret = constant_as<c10::ScalarType>(use.user->input(1));
TORCH_CHECK(
ret.has_value() && ret.value() != c10::ScalarType::Undefined,
"Autocast dtype argument must be a constant and defined");
dtype = ret.value();
if (ret.has_value()) {
dtype = ret.value();
}
}
}
TORCH_CHECK(enabled.has_value(), "Autocast missing _enabled attribute");
TORCH_CHECK(!device.empty(), "Autocast missing device attribute");
if (dtype == c10::ScalarType::Undefined) {
dtype = at::autocast::get_autocast_dtype(c10::Device(device).type());
}
TORCH_CHECK(
dtype != c10::ScalarType::Undefined,
"Autocast missing fast_dtype attribute");
TORCH_CHECK(!device.empty(), "Autocast missing device attribute");
"Autocast has invalid fast_dtype attribute");
if (device == "cuda") {
scope.context.gpu_enabled = enabled.value();
scope.context.gpu_scalar_type = dtype;
15 changes: 15 additions & 0 deletions torch/csrc/jit/runtime/register_prim_ops.cpp
@@ -815,6 +815,21 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
push(stack, enabled);
},
aliasAnalysisConservative()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA(
"aten::get_autocast_dtype(str device_type) -> ScalarType"),
[](Stack& stack) {
#if defined BUILD_LITE_INTERPRETER || defined C10_MOBILE
// autocast is not supported.
at::ScalarType dtype = at::ScalarType::Undefined;
#else
at::DeviceType device_type =
at::Device(pop(stack).toStringRef()).type();
at::ScalarType dtype = at::autocast::get_autocast_dtype(device_type);
#endif
push(stack, dtype);
},
aliasAnalysisConservative()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("prim::Uninitialized() -> Any"),
unInitialized,
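The new `aten::get_autocast_dtype` op exists so that the dtype fallback above can be compiled by TorchScript; a hedged sketch of scripted code that exercises it (the op choice is illustrative, and lite-interpreter/mobile builds return Undefined as the guard above notes):

```python
import torch

@torch.jit.script
def mm_under_default_autocast(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # dtype is omitted, so the scripted __init__ resolves it at runtime
    # through the new aten::get_autocast_dtype("cpu") op.
    with torch.amp.autocast(device_type="cpu"):
        return torch.mm(a, b)

out = mm_under_default_autocast(torch.randn(8, 8), torch.randn(8, 8))
```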
14 changes: 6 additions & 8 deletions torch/utils/checkpoint.py
@@ -288,11 +288,10 @@ def backward(ctx, *args):
set_device_states(ctx.fwd_devices, ctx.fwd_device_states)
detached_inputs = detach_variable(tuple(inputs))

device_autocast_ctx = device_module.amp.autocast(
**ctx.device_autocast_kwargs
device_autocast_ctx = torch.amp.autocast(
device_type=ctx.device, **ctx.device_autocast_kwargs
) if torch.amp.is_autocast_available(ctx.device) else contextlib.nullcontext()
with torch.enable_grad(), device_autocast_ctx, \
torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs): # type: ignore[attr-defined]
outputs = ctx.run_function(*detached_inputs)

if isinstance(outputs, torch.Tensor):
@@ -1395,11 +1394,10 @@ def recompute_fn(*inputs):
if had_device_in_fwd:
set_device_states(fwd_devices, fwd_device_states)

device_autocast_ctx = device_module.amp.autocast(
**device_autocast_kwargs
device_autocast_ctx = torch.amp.autocast(
device_type=device, **device_autocast_kwargs
) if torch.amp.is_autocast_available(device) else contextlib.nullcontext()
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
recompute_context:
with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined]
fn(*args, **kwargs)

new_frame = _CheckpointFrame(
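Checkpoint recomputation now rebuilds the saved autocast state through the device-generic API; a hedged standalone sketch of that pattern (the device and kwargs below are illustrative stand-ins for the values the checkpoint code saves):

```python
import contextlib

import torch

device = "cpu"  # stand-in for ctx.device / the inferred checkpoint device
device_autocast_kwargs = {"enabled": True, "dtype": torch.bfloat16}  # stand-in for saved state

device_autocast_ctx = (
    torch.amp.autocast(device_type=device, **device_autocast_kwargs)
    if torch.amp.is_autocast_available(device)
    else contextlib.nullcontext()
)

with device_autocast_ctx:
    pass  # the recomputed forward pass runs here
```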
