[AMP] Use generic autocast in example, specify dtype (#79579)
Summary:
CC mruberry ptrblck

Pull Request resolved: #79579
Approved by: https://github.com/mruberry, https://github.com/ngimel

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/eff74ed7bdb97693535c202a669bc24dbd124641

Reviewed By: malfet

Differential Revision: D37278895

fbshipit-source-id: d8e0fbd76bd29e8d3e3624cd769c7866d4566f29
eqy authored and facebook-github-bot committed Jun 20, 2022
1 parent c7a596e commit f7e72d0
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions docs/source/notes/amp_examples.rst
@@ -6,17 +6,17 @@ CUDA Automatic Mixed Precision examples
.. currentmodule:: torch.cuda.amp

Ordinarily, "automatic mixed precision training" means training with
-:class:`torch.cuda.amp.autocast` and :class:`torch.cuda.amp.GradScaler` together.
+:class:`torch.autocast` and :class:`torch.cuda.amp.GradScaler` together.

-Instances of :class:`torch.cuda.amp.autocast` enable autocasting for chosen regions.
+Instances of :class:`torch.autocast` enable autocasting for chosen regions.
Autocasting automatically chooses the precision for GPU operations to improve performance
while maintaining accuracy.

Instances of :class:`torch.cuda.amp.GradScaler` help perform the steps of
gradient scaling conveniently. Gradient scaling improves convergence for networks with ``float16``
gradients by minimizing gradient underflow, as explained :ref:`here<gradient-scaling>`.

-:class:`torch.cuda.amp.autocast` and :class:`torch.cuda.amp.GradScaler` are modular.
+:class:`torch.autocast` and :class:`torch.cuda.amp.GradScaler` are modular.
In the samples below, each is used as its individual documentation suggests.

(Samples here are illustrative. See the
@@ -42,7 +42,7 @@ Typical Mixed Precision Training
optimizer.zero_grad()

# Runs the forward pass with autocasting.
-            with autocast():
+            with autocast(device_type='cuda', dtype=torch.float16):
output = model(input)
loss = loss_fn(output, target)
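
For context, the training loop this hunk edits looks roughly as follows once the change
is applied (``model``, ``optimizer``, ``epochs``, ``data``, and ``loss_fn`` are
placeholders, not the verbatim file contents)::

    import torch
    from torch.cuda.amp import GradScaler

    # Creates a GradScaler once at the beginning of training.
    scaler = GradScaler()

    for epoch in epochs:
        for input, target in data:
            optimizer.zero_grad()

            # Runs the forward pass under the generic autocast context,
            # with the low-precision dtype given explicitly.
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                output = model(input)
                loss = loss_fn(output, target)

            # Scales the loss and calls backward() on the scaled loss
            # to produce scaled gradients.
            scaler.scale(loss).backward()

            # Unscales gradients and calls (or, if they contain infs/NaNs,
            # skips) optimizer.step().
            scaler.step(optimizer)

            # Updates the scale factor for the next iteration.
            scaler.update()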

@@ -87,7 +87,7 @@ Calling ``scaler.unscale_(optimizer)`` before clipping enables you to clip unscaled gradients
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
-            with autocast():
+            with autocast(device_type='cuda', dtype=torch.float16):
output = model(input)
loss = loss_fn(output, target)
scaler.scale(loss).backward()
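
The clipping example is cut off at the hunk boundary; it typically continues by unscaling
before clipping, roughly as follows (``max_norm`` is an assumed placeholder)::

            # Unscales the gradients of optimizer's assigned params in-place,
            # so clipping operates on gradients at their true magnitude.
            scaler.unscale_(optimizer)

            # Gradients are already unscaled, so clip as usual.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

            # scaler.step() notices that unscale_ was already called for this
            # optimizer and does not unscale again before optimizer.step().
            scaler.step(optimizer)
            scaler.update()
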
@@ -140,7 +140,7 @@ where you called :meth:`step<step>` for a full effective batch::

for epoch in epochs:
for i, (input, target) in enumerate(data):
-            with autocast():
+            with autocast(device_type='cuda', dtype=torch.float16):
output = model(input)
loss = loss_fn(output, target)
loss = loss / iters_to_accumulate
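
Past the hunk boundary, the accumulation example typically back-propagates the scaled,
normalized loss every iteration and steps only once per ``iters_to_accumulate``
iterations, roughly::

            # Accumulates scaled gradients.
            scaler.scale(loss).backward()

            if (i + 1) % iters_to_accumulate == 0:
                # unscale_ may be called here if unscaled gradients need to be
                # inspected or clipped before stepping.
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
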
@@ -205,7 +205,7 @@ Here's how that looks for the same L2 penalty::
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
-            with autocast():
+            with autocast(device_type='cuda', dtype=torch.float16):
output = model(input)
loss = loss_fn(output, target)

@@ -220,7 +220,7 @@ Here's how that looks for the same L2 penalty::
grad_params = [p * inv_scale for p in scaled_grad_params]

# Computes the penalty term and adds it to the loss
-            with autocast():
+            with autocast(device_type='cuda', dtype=torch.float16):
grad_norm = 0
for grad in grad_params:
grad_norm += grad.pow(2).sum()
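
The tail of the penalty example is also collapsed; it typically folds the norm into the
loss inside the same autocast region, then scales the final backward pass, roughly::

                grad_norm = grad_norm.sqrt()
                loss = loss + grad_norm

            # Applies scaling to the backward call as usual;
            # leaf gradients accumulate correctly scaled.
            scaler.scale(loss).backward()

            # step() and update() proceed as usual.
            scaler.step(optimizer)
            scaler.update()
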
@@ -256,7 +256,7 @@ after all optimizers used this iteration have been stepped::
for input, target in data:
optimizer0.zero_grad()
optimizer1.zero_grad()
-            with autocast():
+            with autocast(device_type='cuda', dtype=torch.float16):
output0 = model0(input)
output1 = model1(input)
loss0 = loss_fn(2 * output0 + 3 * output1, target)
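
The remainder of this example (collapsed above) drives both optimizers through the one
scaler; a rough sketch, assuming a second loss ``loss1`` is computed inside the same
autocast block::

            scaler.scale(loss0).backward(retain_graph=True)
            scaler.scale(loss1).backward()

            # Explicit unscaling is optional and per-optimizer, e.g. to inspect
            # or modify the gradients owned by optimizer0.
            scaler.unscale_(optimizer0)

            scaler.step(optimizer0)
            scaler.step(optimizer1)

            # update() is called once, after all optimizers used this
            # iteration have been stepped.
            scaler.update()
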
@@ -303,7 +303,7 @@ The autocast state is propagated in each one and the following will work::
dp_model = nn.DataParallel(model)

# Sets autocast in the main thread
-    with autocast():
+    with autocast(device_type='cuda', dtype=torch.float16):
# dp_model's internal threads will autocast.
output = dp_model(input)
# loss_fn also autocast
@@ -341,9 +341,9 @@ autocast compatibility if any function
In all cases, if you're importing the function and can't alter its definition, a safe fallback
is to disable autocast and force execution in ``float32`` (or ``dtype``) at any points of use where errors occur::

-    with autocast():
+    with autocast(device_type='cuda', dtype=torch.float16):
...
-        with autocast(enabled=False):
+        with autocast(device_type='cuda', dtype=torch.float16, enabled=False):
output = imported_function(input1.float(), input2.float())

If you're the function's author (or can alter its definition) a better solution is to use the
@@ -373,7 +373,7 @@ Now ``MyMM`` can be invoked anywhere, without disabling autocast or manually casting inputs

mymm = MyMM.apply

-    with autocast():
+    with autocast(device_type='cuda', dtype=torch.float16):
output = mymm(input1, input2)
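
``MyMM``'s definition is collapsed by the diff view; per the surrounding section it is a
``torch.autograd.Function`` whose ``forward``/``backward`` are wrapped with
``torch.cuda.amp.custom_fwd`` and ``torch.cuda.amp.custom_bwd``. A sketch of that shape::

    import torch
    from torch.cuda.amp import custom_fwd, custom_bwd

    class MyMM(torch.autograd.Function):
        @staticmethod
        @custom_fwd
        def forward(ctx, a, b):
            ctx.save_for_backward(a, b)
            return a.mm(b)

        @staticmethod
        @custom_bwd
        def backward(ctx, grad):
            a, b = ctx.saved_tensors
            # Gradients of a.mm(b) with respect to a and b.
            return grad.mm(b.t()), a.t().mm(grad)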

Functions that need a particular ``dtype``
@@ -401,6 +401,6 @@ Now ``MyFloat32Func`` can be invoked anywhere, without manually disabling autocast

func = MyFloat32Func.apply

-    with autocast():
+    with autocast(device_type='cuda', dtype=torch.float16):
# func will run in float32, regardless of the surrounding autocast state
output = func(input)
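
``MyFloat32Func`` (also collapsed above) follows the same pattern but passes
``cast_inputs=torch.float32`` to ``custom_fwd``, which casts floating-point CUDA tensor
inputs to ``float32`` and runs ``forward`` with autocast disabled. A sketch with
placeholder bodies::

    import torch
    from torch.cuda.amp import custom_fwd, custom_bwd

    class MyFloat32Func(torch.autograd.Function):
        @staticmethod
        @custom_fwd(cast_inputs=torch.float32)
        def forward(ctx, input):
            ctx.save_for_backward(input)
            # ... float32-only computation goes here ...
            return input.clone()  # placeholder body

        @staticmethod
        @custom_bwd
        def backward(ctx, grad):
            (input,) = ctx.saved_tensors
            # ... corresponding float32 backward goes here ...
            return grad.clone()  # placeholder body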
