From 23a59ed006c40e70b98b86d7d77fb9c9a197d646 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Wed, 30 Nov 2022 15:42:17 -0500
Subject: [PATCH] Add some warning for Dynamo and enable TF32 when it's set
 (#20515)

---
 src/transformers/training_args.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 1de1e85337f071..c8c0a4588888c6 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1148,6 +1148,15 @@ def __post_init__(self):
                 " (`--bf16_full_eval`) can only be used on CUDA or CPU devices."
             )
 
+        if self.framework == "pt" and is_torch_available() and self.torchdynamo is not None:
+            if is_torch_tf32_available():
+                if self.tf32 is None and not self.fp16 or self.bf16:
+                    logger.info("Setting TF32 in CUDA backends to speedup torchdynamo.")
+                    torch.backends.cuda.matmul.allow_tf32 = True
+            else:
+                logger.warning(
+                    "The speedups for torchdynamo mostly come with GPU Ampere or higher, which is not detected here."
+                )
         if self.framework == "pt" and is_torch_available() and self.tf32 is not None:
             if self.tf32:
                 if is_torch_tf32_available():
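
For context, the added block boils down to: when `torchdynamo` is requested, enable TF32 matmuls on an Ampere-or-newer GPU (unless the user has already made an explicit precision choice), and warn when no such GPU is detected. Below is a minimal, self-contained sketch of that behaviour; the function name `maybe_enable_tf32_for_dynamo`, the standalone flag arguments, and the compute-capability check standing in for `is_torch_tf32_available()` are illustrative assumptions, not part of the patch or the library.

import logging
from typing import Optional

import torch

logger = logging.getLogger(__name__)


def maybe_enable_tf32_for_dynamo(
    torchdynamo: Optional[str], tf32: Optional[bool], fp16: bool, bf16: bool
) -> None:
    """Sketch of the TF32 handling added in this patch (names here are illustrative)."""
    if torchdynamo is None:
        return
    # Rough stand-in for `is_torch_tf32_available()`: TF32 needs an Ampere-class
    # (compute capability >= 8.0) GPU and a recent enough PyTorch build.
    tf32_available = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    if tf32_available:
        # Operator precedence kept exactly as in the patch:
        # (tf32 is None and not fp16) or bf16
        if tf32 is None and not fp16 or bf16:
            logger.info("Setting TF32 in CUDA backends to speedup torchdynamo.")
            torch.backends.cuda.matmul.allow_tf32 = True
    else:
        logger.warning(
            "The speedups for torchdynamo mostly come with GPU Ampere or higher, which is not detected here."
        )


# Mirrors e.g. `TrainingArguments(torchdynamo="eager")` with default precision settings.
maybe_enable_tf32_for_dynamo(torchdynamo="eager", tf32=None, fp16=False, bf16=False)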