Add some warning for Dynamo and enable TF32 when it's set (huggingfac…
sgugger authored and miyu386 committed Feb 9, 2023
1 parent 6f36392 commit 23a59ed
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/transformers/training_args.py
@@ -1148,6 +1148,15 @@ def __post_init__(self):
                " (`--bf16_full_eval`) can only be used on CUDA or CPU devices."
            )

        if self.framework == "pt" and is_torch_available() and self.torchdynamo is not None:
            if is_torch_tf32_available():
                if self.tf32 is None and not self.fp16 or self.bf16:
                    logger.info("Setting TF32 in CUDA backends to speed up torchdynamo.")
                    torch.backends.cuda.matmul.allow_tf32 = True
            else:
                logger.warning(
                    "The speedups for torchdynamo mostly come on Ampere or newer GPUs, and no such GPU was detected here."
                )
        if self.framework == "pt" and is_torch_available() and self.tf32 is not None:
            if self.tf32:
                if is_torch_tf32_available():
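In effect, the added block does the following: when torchdynamo is enabled and the GPU supports TF32 (Ampere, compute capability 8.0, or newer), the Trainer turns on TF32 matrix multiplications in the CUDA backend, subject to the fp16/bf16/tf32 settings it checks; on older hardware it only logs a warning that Dynamo will bring little speedup. The sketch below reproduces that check outside of TrainingArguments, assuming a CUDA build of PyTorch; the helper name maybe_enable_tf32 and its keyword arguments are illustrative and not part of transformers.

import logging

import torch

logger = logging.getLogger(__name__)


def maybe_enable_tf32(use_dynamo, fp16=False, bf16=False, tf32=None):
    # Rough standalone equivalent of the check added in this commit.
    if not use_dynamo:
        return
    # TF32 requires an Ampere (compute capability 8.0) or newer GPU.
    has_tf32_hardware = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    if has_tf32_hardware:
        # Same condition as in the diff; Python groups it as (tf32 is None and not fp16) or bf16.
        if tf32 is None and not fp16 or bf16:
            logger.info("Setting TF32 in CUDA backends to speed up torchdynamo.")
            torch.backends.cuda.matmul.allow_tf32 = True
    else:
        logger.warning("TF32 is not available on this GPU, so torchdynamo will bring little speedup.")

For example, calling maybe_enable_tf32(use_dynamo=True) on an A100 flips torch.backends.cuda.matmul.allow_tf32 to True, which mirrors what the Trainer now does automatically when the torchdynamo option is set and TF32-capable hardware is present.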
