Refine Bf16 test for deepspeed (#17734)
* Refine BF16 check in CPU/GPU

* Fixes

* Renames
sgugger committed Jun 16, 2022
1 parent f8c8f4d commit 90c8c01
Showing 3 changed files with 29 additions and 15 deletions.
2 changes: 2 additions & 0 deletions src/transformers/utils/__init__.py
@@ -125,6 +125,8 @@
     is_tokenizers_available,
     is_torch_available,
     is_torch_bf16_available,
+    is_torch_bf16_cpu_available,
+    is_torch_bf16_gpu_available,
     is_torch_cuda_available,
     is_torch_fx_available,
     is_torch_fx_proxy,
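The two new helpers are re-exported from transformers.utils, so callers can probe CPU and GPU bfloat16 support independently. A minimal sketch of calling them (assuming a transformers install that includes this commit):

from transformers.utils import is_torch_bf16_cpu_available, is_torch_bf16_gpu_available

# Each helper returns a plain bool, so it can gate configuration directly.
print("bf16 on CPU:", is_torch_bf16_cpu_available())
print("bf16 on GPU:", is_torch_bf16_gpu_available())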
36 changes: 24 additions & 12 deletions src/transformers/utils/import_utils.py
@@ -272,7 +272,7 @@ def is_torch_cuda_available():
         return False


-def is_torch_bf16_available():
+def is_torch_bf16_gpu_available():
     if not is_torch_available():
         return False
@@ -288,30 +288,42 @@ def is_torch_bf16_available():
     # 4. torch.autocast exists
     # XXX: one problem here is that it may give invalid results on mixed gpus setup, so it's
     # really only correct for the 0th gpu (or currently set default device if different from 0)
-    is_torch_gpu_bf16_available = True
-    is_torch_cpu_bf16_available = True
     if version.parse(torch.__version__) < version.parse("1.10"):
-        is_torch_gpu_bf16_available = False
-        is_torch_cpu_bf16_available = False
+        return False

     if torch.cuda.is_available() and torch.version.cuda is not None:
         if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
-            is_torch_gpu_bf16_available = False
+            return False
         if int(torch.version.cuda.split(".")[0]) < 11:
-            is_torch_gpu_bf16_available = False
+            return False
         if not hasattr(torch.cuda.amp, "autocast"):
-            is_torch_gpu_bf16_available = False
+            return False
     else:
-        is_torch_gpu_bf16_available = False
+        return False
+
+    return True
+
+
+def is_torch_bf16_cpu_available():
+    if not is_torch_available():
+        return False
+
+    import torch
+
+    if version.parse(torch.__version__) < version.parse("1.10"):
+        return False

     # checking CPU
     try:
         # multiple levels of AttributeError depending on the pytorch version so do them all in one check
         _ = torch.cpu.amp.autocast
     except AttributeError:
-        is_torch_cpu_bf16_available = False
-
-    return is_torch_cpu_bf16_available or is_torch_gpu_bf16_available
+        return False
+
+    return True
+
+
+def is_torch_bf16_available():
+    return is_torch_bf16_cpu_available() or is_torch_bf16_gpu_available()


 def is_torch_tf32_available():
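Taken together, the hunk above splits the old combined check in two: the GPU path still requires PyTorch >= 1.10, a CUDA build of torch, CUDA >= 11, and compute capability >= 8 (Ampere or newer), while the CPU path only needs PyTorch >= 1.10 with torch.cpu.amp.autocast present. A sketch of how a caller might pick a dtype with the split helpers (the select_dtype function is illustrative, not part of the library):

import torch

from transformers.utils import is_torch_bf16_cpu_available, is_torch_bf16_gpu_available


def select_dtype(use_cuda: bool) -> torch.dtype:
    # Prefer bf16 where the target device supports it; otherwise fall back to fp32.
    if use_cuda and is_torch_bf16_gpu_available():
        return torch.bfloat16
    if not use_cuda and is_torch_bf16_cpu_available():
        return torch.bfloat16
    return torch.float32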
6 changes: 3 additions & 3 deletions tests/deepspeed/test_deepspeed.py
@@ -42,7 +42,7 @@
     slow,
 )
 from transformers.trainer_utils import get_last_checkpoint, set_seed
-from transformers.utils import WEIGHTS_NAME, is_torch_bf16_available
+from transformers.utils import WEIGHTS_NAME, is_torch_bf16_gpu_available


 if is_torch_available():
@@ -129,7 +129,7 @@ def get_launcher(distributed=False):
 BF16 = "bf16"

 stages = [ZERO2, ZERO3]
-if is_torch_bf16_available():
+if is_torch_bf16_gpu_available():
     dtypes = [FP16, BF16]
 else:
     dtypes = [FP16]
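With the dtype list settled, the rest of the module pairs it with the ZeRO stages to parameterize the tests. A minimal sketch of that cross-product pattern (the grid construction shown here is illustrative, not quoted from the file):

import itertools

# Each (ZeRO stage, dtype) pair becomes one test case, so bf16 cases only
# exist when the GPU check above passed.
params = list(itertools.product(stages, dtypes))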
@@ -920,7 +920,7 @@ def test_resume_train_not_from_ds_checkpoint(self, stage, dtype):
     @require_torch_multi_gpu
     @parameterized.expand(["bf16", "fp16", "fp32"])
     def test_inference(self, dtype):
-        if dtype == "bf16" and not is_torch_bf16_available():
+        if dtype == "bf16" and not is_torch_bf16_gpu_available():
             self.skipTest("test requires bfloat16 hardware support")

         # this is just inference, so no optimizer should be loaded
