Llama Model throwing "RuntimeError: expected scalar type BFloat16 but found Float" when using torch.compile and AMP together #30945

Open
JackCai1206 opened this issue May 21, 2024 · 8 comments
Labels
Compilation (Issues related to torchdynamo and torchinductor)

Comments

@JackCai1206

System Info

transformers 4.41.0
torch 2.3.0
GPU: NVIDIA GeForce RTX 4090, CUDA version 12.3

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
from transformers import LlamaConfig, LlamaForCausalLM, AdamW, AutoModelForCausalLM, GPT2Config
from torch.cuda.amp import autocast, GradScaler

# Configure the model
config = LlamaConfig(
    num_attention_heads=6,
    num_hidden_layers=6,
    hidden_size=384,
    intermediate_size=1536,  # Typically 4 * hidden_size
    vocab_size=30522,        # Standard vocabulary size
    max_position_embeddings=1024,
)

# config = GPT2Config(
#     n_embd=384,
#     n_head=6,
#     n_layer=6,
#     n_positions=1024,
#     n_ctx=1024,
#     n_vocab=30522,
# )

# Initialize the model
model = AutoModelForCausalLM.from_config(config, attn_implementation="eager").to('cuda')

# Compile the model (Torch 2.0 and above)
model = torch.compile(model)


# Create dummy data
batch_size = 8
sequence_length = 1024
dummy_input_ids = torch.randint(0, config.vocab_size, (batch_size, sequence_length)).to('cuda')
dummy_labels = torch.randint(0, config.vocab_size, (batch_size, sequence_length)).to('cuda')

# Set up the optimizer
optimizer = AdamW(model.parameters(), lr=1e-4)

scaler = GradScaler()

# Set the model to training mode
model.train()

# Training loop
num_epochs = 10000
for epoch in range(num_epochs):
    with autocast(dtype=torch.bfloat16, enabled=True):
        # Forward pass
        outputs = model(input_ids=dummy_input_ids, labels=dummy_labels)
        loss = outputs.loss

    # Backward pass
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}")

print("Training complete.")

Expected behavior

Running the code snippet above gives me the following error:

{
	"name": "RuntimeError",
	"message": "expected scalar type BFloat16 but found Float",
	"stack": "---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 54
     51     loss = outputs.loss
     53 # Backward pass
---> 54 scaler.scale(loss).backward()
     55 scaler.step(optimizer)
     56 scaler.update()

File ~/anaconda3/lib/python3.11/site-packages/torch/_tensor.py:525, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    515 if has_torch_function_unary(self):
    516     return handle_torch_function(
    517         Tensor.backward,
    518         (self,),
   (...)
    523         inputs=inputs,
    524     )
--> 525 torch.autograd.backward(
    526     self, gradient, retain_graph, create_graph, inputs=inputs
    527 )

File ~/anaconda3/lib/python3.11/site-packages/torch/autograd/__init__.py:267, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    262     retain_graph = create_graph
    264 # The reason we repeat the same comment below is that
    265 # some Python versions print out the first line of a multi-line function
    266 # calls in the traceback and some print out the last line
--> 267 _engine_run_backward(
    268     tensors,
    269     grad_tensors_,
    270     retain_graph,
    271     create_graph,
    272     inputs,
    273     allow_unreachable=True,
    274     accumulate_grad=True,
    275 )

File ~/anaconda3/lib/python3.11/site-packages/torch/autograd/graph.py:744, in _engine_run_backward(t_outputs, *args, **kwargs)
    742     unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    743 try:
--> 744     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    745         t_outputs, *args, **kwargs
    746     )  # Calls into the C++ engine to run the backward pass
    747 finally:
    748     if attach_logging_hooks:

File ~/anaconda3/lib/python3.11/site-packages/torch/autograd/function.py:301, in BackwardCFunction.apply(self, *args)
    295     raise RuntimeError(
    296         \"Implementing both 'backward' and 'vjp' for a custom \"
    297         \"Function is not allowed. You should only implement one \"
    298         \"of them.\"
    299     )
    300 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 301 return user_fn(self, *args)

File ~/anaconda3/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:882, in aot_dispatch_autograd.<locals>.CompiledFunction.backward(ctx, *flat_args)
    880     out = CompiledFunctionBackward.apply(*all_args)
    881 else:
--> 882     out = call_compiled_backward()
    884 # TODO: figure out how to refactor the backward properly so I can use aot_dispatch_subclass_wrapper() here.
    885 if CompiledFunction.maybe_subclass_metadata is not None:

File ~/anaconda3/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:831, in aot_dispatch_autograd.<locals>.CompiledFunction.backward.<locals>.call_compiled_backward()
    824     with tracing(saved_context), context(), track_graph_compiling(
    825         aot_config, \"backward\"
    826     ):
    827         CompiledFunction.compiled_bw = aot_config.bw_compiler(
    828             bw_module, placeholder_list
    829         )
--> 831 out = call_func_at_runtime_with_args(
    832     CompiledFunction.compiled_bw,
    833     all_args,
    834     steal_args=True,
    835     disable_amp=disable_amp,
    836 )
    838 out = functionalized_rng_runtime_epilogue(
    839     CompiledFunction.metadata, out
    840 )
    841 return tuple(out)

File ~/anaconda3/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py:113, in call_func_at_runtime_with_args(f, args, steal_args, disable_amp)
    111 with context():
    112     if hasattr(f, \"_boxed_call\"):
--> 113         out = normalize_as_list(f(args))
    114     else:
    115         # TODO: Please remove soon
    116         # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
    117         warnings.warn(
    118             \"Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. \"
    119             \"Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. \"
    120             \"See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.\"
    121         )

File ~/anaconda3/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:451, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    449 prior = set_eval_frame(callback)
    450 try:
--> 451     return fn(*args, **kwargs)
    452 finally:
    453     set_eval_frame(prior)

File ~/anaconda3/lib/python3.11/site-packages/torch/_dynamo/external_utils.py:36, in wrap_inline.<locals>.inner(*args, **kwargs)
     34 @functools.wraps(fn)
     35 def inner(*args, **kwargs):
---> 36     return fn(*args, **kwargs)

File ~/anaconda3/lib/python3.11/site-packages/torch/_inductor/codecache.py:906, in CompiledFxGraph.__call__(self, inputs)
    905 def __call__(self, inputs: List[Any]) -> Any:
--> 906     return self.get_current_callable()(inputs)

File ~/anaconda3/lib/python3.11/site-packages/torch/_inductor/compile_fx.py:784, in align_inputs_from_check_idxs.<locals>.run(new_inputs)
    782 def run(new_inputs):
    783     copy_misaligned_inputs(new_inputs, inputs_to_check)
--> 784     return model(new_inputs)

File ~/anaconda3/lib/python3.11/site-packages/torch/_inductor/codecache.py:934, in _run_from_cache(compiled_graph, inputs)
    926     assert compiled_graph.artifact_path
    927     compiled_graph.compiled_artifact = PyCodeCache.load_by_key_path(
    928         compiled_graph.cache_key,
    929         compiled_graph.artifact_path,
    930         compiled_graph.cache_linemap,
    931         compiled_graph.constants,
    932     ).call
--> 934 return compiled_graph.compiled_artifact(inputs)

File /tmp/torchinductor_zcai75/wq/cwqm67koqia7gthn65wgmhppfzrfyheocl4px7fecurpkfigigfs.py:1751, in call(args)
   1749 buf39 = reinterpret_tensor(buf34, (48, 64, 1024), (65536, 1024, 1), 0); del buf34  # reuse
   1750 # Source Nodes: [], Original ATen: [aten.bmm]
-> 1751 extern_kernels.bmm(permute_103, reinterpret_tensor(buf38, (48, 1024, 1024), (1048576, 1024, 1), 0), out=buf39)
   1752 del permute_103
   1753 buf41 = empty_strided_cuda((8, 6, 1024, 64), (393216, 65536, 64, 1), torch.bfloat16)

RuntimeError: expected scalar type BFloat16 but found Float"
}

This problem does not seem to happen for a GPT2 model: if I initialize GPT2Config instead of LlamaConfig (the commented-out code in the script), there is no such error.

@amyeroberts
Collaborator

cc @ArthurZucker

@RUFFY-369

Hi @JackCai1206, I ran your script but didn't encounter the error you mentioned for LlamaConfig; it ran smoothly for both configs. Can you check your PyTorch/CUDA compatibility? I'm on CUDA 12.2 with PyTorch 2.3 (PyTorch version (GPU?): 2.3.0+cu121 (True), Cuda compilation tools, release 12.2, V12.2.140).

@JackCai1206
Author

When I run nvidia-smi I get | NVIDIA-SMI 545.23.08 Driver Version: 545.23.08 CUDA Version: 12.3 |
and I have installed torch 2.3.0 without a "cu" suffix, which I assume is compatible with CUDA 12?

@RUFFY-369

RUFFY-369 commented May 23, 2024

@JackCai1206 CUDA has two main APIs: the runtime API and the driver API. The CUDA version you posted from nvidia-smi is the driver API version; what PyTorch reports is the runtime API version, which comes from the CUDA toolkit that is installed automatically with pip3 install torch.
Just to confirm, can you check the output of pip list | grep torch and of torch.version.cuda? If the outputs show no CUDA dependencies and None respectively, then PyTorch has to be reinstalled with CUDA dependencies.
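
For example, a quick sanity check along these lines (just a sketch) shows which build is actually installed:

import torch

print(torch.__version__)          # e.g. "2.3.0+cu121" for a CUDA-enabled wheel
print(torch.version.cuda)         # CUDA runtime version the wheel was built with, or None
print(torch.cuda.is_available())  # False if the CPU-only wheel was installed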

@ArthurZucker added the Compilation (Issues related to torchdynamo and torchinductor) label May 23, 2024
@JackCai1206
Author

Hi, thanks for the explanation! This is the output of pip list

torch                             2.3.0
torchaudio                        2.3.0
torchvision                       0.18.0

and torch cuda version

>>> import torch
>>> torch.version.cuda
'12.1'

@ArthurZucker
Collaborator

also cc @gante

@RUFFY-369

@JackCai1206 Oh! I see. What I found could be the reason for the error is this line in modeling_llama, since your model has (rotary_emb): LlamaRotaryEmbedding(). That line forces the rotary embeddings to float32, as bfloat16 loses precision on long contexts.
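
For illustration only (my own minimal sketch, not the transformers code itself): a float32 tensor can show up inside an otherwise-bfloat16 autocast region when an op is computed with autocast locally disabled, which is the kind of mixed-dtype graph the compiled backward then has to deal with.

import torch

# Assumes a CUDA GPU, as in the report above.
a = torch.randn(8, 8, device="cuda")
b = torch.randn(8, 8, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    hidden = a @ b                      # matmul runs in bfloat16 under autocast
    with torch.autocast(device_type="cuda", enabled=False):
        cos = a @ b                     # autocast disabled here -> stays float32

print(hidden.dtype, cos.dtype)          # torch.bfloat16 torch.float32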

If you want to use autocast, an alternative to try could be the Trainer class from transformers, activating autocast through the bf16=True argument of TrainingArguments.
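
A minimal sketch of that alternative (the dummy dataset and output_dir here are placeholders I made up, not taken from the original script):

import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, LlamaConfig, Trainer, TrainingArguments

config = LlamaConfig(
    num_attention_heads=6,
    num_hidden_layers=6,
    hidden_size=384,
    intermediate_size=1536,
    vocab_size=30522,
    max_position_embeddings=1024,
)
model = AutoModelForCausalLM.from_config(config, attn_implementation="eager")

# Tiny random dataset just so the example runs end to end.
dummy = Dataset.from_dict({
    "input_ids": torch.randint(0, config.vocab_size, (32, 128)).tolist(),
    "labels": torch.randint(0, config.vocab_size, (32, 128)).tolist(),
})

args = TrainingArguments(
    output_dir="llama-bf16-test",     # placeholder path
    per_device_train_batch_size=8,
    num_train_epochs=1,
    bf16=True,                        # Trainer enables bfloat16 autocast internally
    torch_compile=True,               # optional: Trainer calls torch.compile on the model
    report_to="none",
)

trainer = Trainer(model=model, args=args, train_dataset=dummy)
trainer.train()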

@JackCai1206
Author

Sounds good. Yeah, I think a warning message there could be useful.
