
Validating ONNX model fails for GPT-J #607

Closed · Eichhof opened this issue Dec 18, 2022 · 40 comments · Fixed by #609

Labels: bug (Something isn't working)

Comments

Eichhof commented Dec 18, 2022

System Info

Optimum: 1.5.1
Python: 3.10.4
Platform: Windows 10

Who can help?

@lewtun @michaelbenayoun

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I installed optimum with pip install optimum[onnxruntime-gpu]. Then I ran python -m optimum.exporters.onnx --task causal-lm-with-past --model EleutherAI/gpt-j-6B gptj_onnx/ to export GPT-J to ONNX. The output of this call is as follows:

Some weights of the model checkpoint at EleutherAI/gpt-j-6B were not used when initializing GPTJModel: ['lm_head.weight', 'lm_head.bias']
- This IS expected if you are initializing GPTJModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPTJModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Framework not specified. Using pt to export to ONNX.
Using framework PyTorch: 1.12.1
Overriding 2 configuration item(s)
        - use_cache -> True
        - pad_token_id -> 0
C:\Users\myUsername\Anaconda3\envs\huggingface\lib\site-packages\transformers\models\gptj\modeling_gptj.py:597: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if batch_size <= 0:
C:\Users\myUsername\Anaconda3\envs\huggingface\lib\site-packages\transformers\models\gptj\modeling_gptj.py:177: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
Validating ONNX model...
        -[✓] ONNX model output names match reference model (present.22.value, present.15.key, present.15.value, present.25.value, present.9.value, present.26.value, present.8.value, present.13.key, present.27.key, present.6.value, present.7.value, present.12.value, present.24.key, present.1.value, present.4.key, logits, present.10.key, present.9.key, present.16.key, present.0.key, present.19.key, present.21.key, present.4.value, present.23.value, present.3.key, present.17.key, present.6.key, present.21.value, present.22.key, present.18.key, present.11.key, present.10.value, present.14.value, present.0.value, present.13.value, present.14.key, present.5.value, present.2.value, present.16.value, present.24.value, present.25.key, present.27.value, present.8.key, present.7.key, present.19.value, present.20.key, present.26.key, present.18.value, present.23.key, present.11.value, present.2.key, present.5.key, present.3.value, present.1.key, present.20.value, present.17.value, present.12.key)
        - Validating ONNX Model output "logits":
                -[✓] (2, 16, 50400) matches (2, 16, 50400)
                -[x] values not close enough, max diff: 3.2901763916015625e-05 (atol: 1e-05)
        - Validating ONNX Model output "present.0.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.0.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.1.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.1.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.2.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.2.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.3.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.3.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.4.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.4.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.5.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.5.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.6.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.6.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.7.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[x] values not close enough, max diff: 2.6702880859375e-05 (atol: 1e-05)
        - Validating ONNX Model output "present.7.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.8.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.8.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.9.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[x] values not close enough, max diff: 2.09808349609375e-05 (atol: 1e-05)
        - Validating ONNX Model output "present.9.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[x] values not close enough, max diff: 1.3589859008789062e-05 (atol: 1e-05)
        - Validating ONNX Model output "present.10.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.10.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[x] values not close enough, max diff: 1.2636184692382812e-05 (atol: 1e-05)
        - Validating ONNX Model output "present.11.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[x] values not close enough, max diff: 2.574920654296875e-05 (atol: 1e-05)
        - Validating ONNX Model output "present.11.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.12.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.12.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.13.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[x] values not close enough, max diff: 1.71661376953125e-05 (atol: 1e-05)
        - Validating ONNX Model output "present.13.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.14.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.14.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.15.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.15.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.16.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.16.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.17.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[x] values not close enough, max diff: 1.71661376953125e-05 (atol: 1e-05)
        - Validating ONNX Model output "present.17.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.18.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.18.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.19.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.19.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.20.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.20.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.21.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.21.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.22.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.22.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.23.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.23.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.24.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.24.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.25.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.25.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[x] values not close enough, max diff: 1.1682510375976562e-05 (atol: 1e-05)
        - Validating ONNX Model output "present.26.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.26.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.27.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
        - Validating ONNX Model output "present.27.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 1e-05)
An error occured, but the model was saved at: gptj_onnx/model.onnx

Expected behavior

Validation of ONNX model should succeed.

Eichhof added the bug label on Dec 18, 2022
mht-sharma (Contributor) commented:
Hi @Eichhof, this model requires a higher atol value for validation. Could you try setting --atol=1e-4 and exporting again?
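For reference, the full export command with the relaxed tolerance would be along these lines:

python -m optimum.exporters.onnx --task causal-lm-with-past --atol=1e-4 --model EleutherAI/gpt-j-6B gptj_onnx/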

fxmarty (Collaborator) commented Dec 19, 2022

Fixed in #609

Eichhof (Author) commented Dec 19, 2022

Thank you very much @mht-sharma.
With --atol=1e-4 the same error occurs (see below). Is this then not a real problem and can it be ignored, apart from, say, slightly lower model accuracy?

In addition, at the beginning of the output (see above) it says Some weights of the model checkpoint at EleutherAI/gpt-j-6B were not used when initializing GPTJModel: ['lm_head.weight', 'lm_head.bias']. Is this a problem?

Finally, why is --for-ort not recognized and what is the purpose of this flag?

Validating ONNX model...
        -[✓] ONNX model output names match reference model (present.10.key, present.16.value, present.8.key, present.10.value, present.27.value, present.12.key, present.7.key, present.12.value, present.11.key, present.25.value, present.5.value, present.21.value, present.17.key, present.4.value, present.15.value, present.0.value, present.6.key, present.20.value, present.2.value, present.9.value, present.13.value, present.22.key, present.3.key, present.27.key, present.2.key, present.20.key, present.5.key, present.8.value, present.6.value, present.17.value, logits, present.16.key, present.9.key, present.23.key, present.21.key, present.3.value, present.24.key, present.14.key, present.7.value, present.18.key, present.19.value, present.22.value, present.24.value, present.26.key, present.26.value, present.25.key, present.14.value, present.1.value, present.19.key, present.18.value, present.15.key, present.13.key, present.23.value, present.0.key, present.4.key, present.11.value, present.1.key)
        - Validating ONNX Model output "logits":
                -[✓] (2, 16, 50400) matches (2, 16, 50400)
                -[x] values not close enough, max diff: 0.0001621246337890625 (atol: 0.0001)
        - Validating ONNX Model output "present.0.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.0.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.1.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.1.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.2.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.2.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.3.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.3.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.4.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.4.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.5.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.5.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.6.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.6.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.7.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.7.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.8.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.8.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.9.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.9.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.10.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.10.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.11.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.11.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.12.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.12.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.13.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.13.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.14.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.14.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.15.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.15.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.16.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.16.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.17.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.17.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.18.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.18.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.19.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.19.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.20.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.20.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.21.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.21.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.22.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.22.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.23.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.23.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.24.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.24.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.25.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.25.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.26.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.26.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.27.key":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
        - Validating ONNX Model output "present.27.value":
                -[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
                -[✓] all values close (atol: 0.0001)
An error occured, but the model was saved at: gptj_onnx/model.onnx

mht-sharma (Contributor) commented:
Hi @Eichhof, the inputs for validation are generated randomly, so sometimes the model is sensitive to particular inputs, which results in this error. You could run the command again and the validation may succeed.

A difference of <= 1e-4 is generally acceptable. However, if the error is higher, it needs to be looked into.

The --for-ort flag exists for models like encoder-decoders, where we need to export the model into multiple ONNX files for inference with ONNX Runtime (ORT).

fxmarty (Collaborator) commented Dec 19, 2022

Note that --for-ort was added in #497 and is hence not yet in the stable release! We will do a release this week.

Eichhof (Author) commented Dec 19, 2022

Thank you very much for the information @mht-sharma and @fxmarty .

Do you have a comment regarding the other warning Some weights of the model checkpoint at EleutherAI/gpt-j-6B were not used when initializing GPTJModel: ['lm_head.weight', 'lm_head.bias']? I don't know if that is critical for the model.

Regarding --for-ort, do I need this for inference of GPT-J?

fxmarty reopened this Dec 20, 2022
Eichhof (Author) commented Dec 21, 2022

Any ideas regarding the unused weights and --for-ort?

fxmarty (Collaborator) commented Dec 22, 2022

Hi @Eichhof, sorry for the late reply. If you are working on language modeling, where you do text generation, I would advise using --for-ort. For now, it exports two models:

For now, this is the only way we have to use past key values during decoding. @JingyaHuang did a very nice PR, #587, to merge the two models into one, thus avoiding duplicating the memory use. You may be interested in having a look, but it is not yet integrated with ORTModelForXX.

We are doing a release today, so --for-ort (and better documentation) will be included.

About the weights, I'm not sure; I'll have a look ASAP.

Eichhof (Author) commented Dec 22, 2022

Thank you very much @fxmarty .

When will it be integrated into ORTModelForXX? Right now, does duplicating memory usage mean that ORTModelForXX needs double the GPU memory (i.e., 28 GB of VRAM)? Or do you mean CPU memory? If it is GPU memory, that would be a problem, because I only have a 24 GB GPU.

It would be great if you could look into the weights problem.

fxmarty (Collaborator) commented Dec 23, 2022

Hi @Eichhof , I agree it's a huge issue, I think it's high priority.

You may want to have a look at today's release notes, notably the section "Experimental support to merge ONNX decoder with/without past key values": https://github.com/huggingface/optimum/releases/tag/v1.6.0 . We'll gradually improve the documentation to reflect the new features (notably on the export side).

Eichhof (Author) commented Dec 26, 2022

Thank you @fxmarty . I will have a look at the new release and the experimental support for the merging. I hope that the merging works for GPT-J.

hivaze (Contributor) commented Dec 28, 2022

Hey @Eichhof, I implemented support for past keys/values in the decoder on my own. My version does not require two models to be loaded into memory; I think that is a terrible idea, since many decoders are very large. I also got rid of many bugs I found in the ORTModelForCausalLM class while trying to use it.

You can check out my version here; unfortunately, I'm not planning to open a PR:
https://github.com/hivaze/optimum/blob/main/optimum/onnxruntime/modeling_decoder.py

Eichhof (Author) commented Dec 29, 2022

Thank you very much @hivaze . That sounds very interesting. I will give it a try. How can I use your code? Do I only have to replace the original script with your script?

Eichhof (Author) commented Dec 29, 2022

@fxmarty I was trying to merge the models with the PR from @JingyaHuang, but I'm getting the error "Detected 1 oom-kill event(s) in StepId=5333998.batch. Some of your processes may have been killed by the cgroup out-of-memory handler." even though I was running it with 150 GB of memory on a cluster.

hivaze (Contributor) commented Dec 29, 2022

Thank you very much @hivaze . That sounds very interesting. I will give it a try. How can I use your code? Do I only have to replace the original script with your script?

Yes, you just need to copy the file and replace the usual modeling_decoder.py with mine. Or you can rework the imports in my file so that you can use it as a standalone script outside the library (there are relative imports; they just need to be made absolute). Also remove the extra print statement from the forward() method; it's only there for debugging.

If you want to use the model with the cache for generation, you can just call the ORTModelForCausalLM.from_pretrained method, where model_id is the path to a folder and file_name is the .onnx model file (model.onnx by default). The folder should also contain a config.json file, the Hugging Face pretrained config. If your model uses the cache, you also need to pass cached_version=True here.

After this you can use the generate method as usual, without worrying about anything else. The use_cache parameter in generate() will be True by default when cached_version=True.

Eichhof (Author) commented Dec 29, 2022

Thank you very much @hivaze .

Or you can redo the imports in my file so that you can use it as a plug-in script outside the library (there are relative imports, they just need to be made absolute).

I don't get it. What do you mean exactly?

If you want to use model with cache for generation, you can just call ORTModelForCausalLM.from_pretrained method where model_id is a path to a folder with file_name which is a file .onnx model (model.onnx by default). And folder should also contain a file config.json - HuggingFace pretrained config.

After running python -m optimum.exporters.onnx --task causal-lm-with-past --for-ort --model EleutherAI/gpt-j-6B gptj_onnx/, I have the files

config.json
decoder_model.onnx
decoder_model.onnx_data
decoder_with_past_model.onnx
decoder_with_past_model.onnx_data

I guess all these files are necessary?

hivaze (Contributor) commented Dec 30, 2022

After running python -m optimum.exporters.onnx --task causal-lm-with-past --for-ort --model EleutherAI/gpt-j-6B gptj_onnx/, I have the files

config.json
decoder_model.onnx
decoder_model.onnx_data
decoder_with_past_model.onnx
decoder_with_past_model.onnx_data

I guess all these files are necessary?

You need only these three files if you want to use cached keys/values in inference:

config.json
decoder_with_past_model.onnx
decoder_with_past_model.onnx_data

After copying the script you need to call ORTModelForCausalLM.from_pretrained(model_id='yourfolder', file_name='decoder_with_past_model.onnx', provider='CUDAExecutionProvider', cached_version=True), and you will get your ORTModelForCausalLM instance. Then you will be able to use an optimized generate().

hivaze (Contributor) commented Dec 30, 2022

Of course, I think this should all be discussed in a separate issue or even a PR, but you would be doing me a great favor if you help me test it; then maybe I can open a PR.

fxmarty (Collaborator) commented Dec 30, 2022

@hivaze Thanks for working on this! I was wondering how you handle the case where there aren't any past key values yet? The original motivation for introducing these two models was to handle the special first-pass case, which results in models with different inputs:

decoder_model.onnx
(screenshot of the decoder_model.onnx input signature)

decoder_with_past_model.onnx
(screenshot of the decoder_with_past_model.onnx input signature, including the past_key_values inputs)

Looking back, I think it was the easiest solution to implement at the time, even if of course it's really not good memory-wise. So I'm looking forward to supporting only a single ONNX for the decoder! cc @JingyaHuang

hivaze (Contributor) commented Dec 30, 2022

@hivaze Thanks for working on this! I was wondering how you handle the case where there aren't any past key values yet? The original motivation for introducing these two models was to handle the special first-pass case, which results in models with different inputs:

Oh, it's a trick with generating a fake cache. We only need to generate (with randn) a cache of keys and values for each layer, as if there were a past text of length 1. Then we can safely use the attention mask to mask out this fake cache position with a zero. As far as I have checked, this method really works and does not affect the output of the model when forwarding real tokens.

I do not consider this a proper solution to the problem; I just found through experiments that it works (tested only on GPT-J).
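As a rough illustration of that trick (a sketch only; the cache layout of 28 layers, 16 heads and head dimension 256 comes from the GPT-J export logs above, and input names such as past_key_values.{i}.key are illustrative rather than taken from hivaze's code):

import numpy as np

batch_size, num_heads, head_dim, num_layers = 1, 16, 256, 28

# Random one-token cache for every layer; it will never actually be attended to.
fake_past = {}
for i in range(num_layers):
    fake_past[f"past_key_values.{i}.key"] = np.random.randn(batch_size, num_heads, 1, head_dim).astype(np.float32)
    fake_past[f"past_key_values.{i}.value"] = np.random.randn(batch_size, num_heads, 1, head_dim).astype(np.float32)

# Real tokens, plus an attention mask with a leading zero that masks out the fake past position.
input_ids = np.array([[15496, 11, 616]], dtype=np.int64)
attention_mask = np.concatenate(
    [np.zeros((batch_size, 1), dtype=np.int64), np.ones_like(input_ids)],
    axis=1,
)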

fxmarty (Collaborator) commented Feb 16, 2023

@hivaze @Eichhof Hope you do well.

A PR to use a single ONNX for both the without-past and with-past cases has been merged: #647. It should drastically reduce memory usage at inference.

The follow-up steps are listed in #784.

fxmarty closed this as completed Feb 16, 2023
hivaze (Contributor) commented Feb 16, 2023

@fxmarty Thank you! However, can you explain in a few words the key idea of the PR's changes? It's not clear to me whether it is now possible to generate text using only the decoder that takes past_key_values as inputs.

fxmarty (Collaborator) commented Feb 16, 2023

The strategy is to take decoder_model.onnx and decoder_with_past_model.onnx and merge them into a single ONNX with an If node that dispatches to the without-past or with-past branch depending on a flag passed as input, while the weights are shared between both branches.

In the first pass, dummy past key values must be passed (they will simply not be used).

The decoder_model_merged.onnx (which will be the default in the ONNX export) now looks like:

(screenshot of the merged ONNX graph with the If node)

To be honest, this is a bit of a hack, and there should be a cleaner solution than this: https://discuss.huggingface.co/t/how-does-the-onnx-exporter-work-for-generationmodel-with-past-key-value/31316/8?u=fxmarty
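In practice the dispatch flag is handled for you once the merged export is loaded through ORTModelForCausalLM; generation then works as usual. A minimal sketch, assuming the gptj_onnx directory contains the merged decoder produced by the new export:

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
# use_cache=True makes generation re-use past key values; the merged ONNX internally
# switches between the without-past and with-past branches via the If node.
model = ORTModelForCausalLM.from_pretrained("gptj_onnx", use_cache=True)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_length=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))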

Eichhof (Author) commented Feb 16, 2023

Thanks for the update. Sounds great!

How can I use a single ONNX without/with past key values for GPT-J?

When loading the ONNX model exported from GPT-J, it takes more than 10 minutes until the model is loaded (loading a Hugging Face GPT-J model takes around 10 s). In addition, once the ONNX model is loaded, it takes around 2 GB of GPU memory and 55 GB of CPU memory. In comparison, the Hugging Face GPT-J model takes 14 GB of GPU memory and around 10 GB of CPU memory. Why is that?

hivaze (Contributor) commented Feb 17, 2023

In comparison, Huggingface GPT-J model is taking 14 GB of GPU

Hey Eichhof, it seems you're loading the fp32 version of your GPT-J ONNX model. You need to convert your .onnx model to fp16 and then load it. Moreover, ONNX will always take up somewhat more space in GPU memory, because it uses a static graph, unlike PyTorch.

Of course, the problem may also be related to the new mechanism; I have not tried it yet. But I hope the information is still useful for you :)
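For reference, one possible way to do that offline conversion (not necessarily the tool hivaze used) is the float16 helper from the onnxconverter-common package; a rough sketch, noting that a >2 GB model such as GPT-J must be saved with external data, and that the --fp16 export option discussed later in this thread makes the manual step unnecessary:

import onnx
from onnxconverter_common import float16

# Load the fp32 export; the external weights in decoder_with_past_model.onnx_data
# are picked up automatically as long as they sit next to the .onnx file.
model_fp32 = onnx.load("gptj_onnx/decoder_with_past_model.onnx")

# Convert weights and ops to float16 while keeping the model inputs/outputs in float32.
model_fp16 = float16.convert_float_to_float16(model_fp32, keep_io_types=True)

onnx.save(
    model_fp16,
    "gptj_onnx_fp16/decoder_with_past_model.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,
)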

fxmarty (Collaborator) commented Feb 17, 2023

There is a PR open to export in fp16 with a --device cuda --fp16 option in the export: #749

Overall I've found ONNX Runtime to be a bit painful to use on GPU, with TensorRT support being limited (see this), but let's hope it gets better.

I still have to test #647 on large models on GPU to see the memory usage. I will keep you updated here!

Eichhof (Author) commented Feb 17, 2023

@hivaze I'm using provider_options=dict(trt_fp16_enable=1). Do I also have to export to ONNX with fp16?

@fxmarty Is the fp16 export already available on the main branch? Should I rather use CUDAExecutionProvider instead of TensorrtExecutionProvider on GPU?

Can I also quantize my GPT-J ONNX model so that it uses less memory? I have read here about ORTQuantizer for applying dynamic quantization.
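For reference, dynamic quantization with ORTQuantizer would look roughly like the sketch below (the file name comes from the export listed earlier in the thread); note that dynamic int8 quantization mainly targets CPU execution, so it may not reduce GPU memory:

from optimum.onnxruntime import ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig

# Point the quantizer at one specific ONNX file from the export directory.
quantizer = ORTQuantizer.from_pretrained("gptj_onnx", file_name="decoder_model.onnx")

# Dynamic int8 quantization preset for AVX512-VNNI CPUs; pick the preset matching your hardware.
qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)

quantizer.quantize(save_dir="gptj_onnx_quantized", quantization_config=qconfig)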

Eichhof (Author) commented Feb 21, 2023

@fxmarty Do you already have an estimate of when the PR will be ready for exporting ONNX with fp16?

fxmarty (Collaborator) commented Feb 23, 2023

There are these two PRs, #807 & #749, that will allow using fp16 along with a single ONNX decoder (handling both the without-past and with-past cases).

The first one should be merged soon and included in the next release.

Eichhof (Author) commented Feb 24, 2023

@fxmarty I'm a bit confused now. #749 is for exporting to ONNX with fp16, but what is #807 for? Do I need both?

fxmarty (Collaborator) commented Feb 24, 2023

The two take different paths:

Eichhof (Author) commented Feb 25, 2023

@fxmarty Thank you for the detailed explanations. So using only one of them, not both, is necessary; is that the correct understanding? How can I use #807 (once it is released)?

fxmarty (Collaborator) commented Feb 27, 2023

Hi @Eichhof ,

Yes, the two methods are exclusive.

Once released (tomorrow or on Wednesday), #807 will be usable in the optimum-cli export onnx CLI as --optimize O1 up to --optimize O4, the latter running the ONNX Runtime conversion to fp16.

Eichhof (Author) commented Feb 28, 2023

@fxmarty I just saw that both are merged now. Is there any difference between using optimum-cli export onnx --optimize O4 and python -m optimum.exporters.onnx --atol=1e-4 --for-ort --task causal-lm-with-past --fp16 --model EleutherAI/gpt-j-6B gptj_onnx/ in terms of the exported model and runtime? Is a model exported with optimum-cli export onnx --optimize O4 even more optimized?

fxmarty (Collaborator) commented Feb 28, 2023

The argument --for-ort is not necessary anymore - it is the default now.

Exporting with optimum-cli export onnx --task causal-lm-with-past --fp16 --model EleutherAI/gpt-j-6B gptj_onnx/, you may be able to use the resulting model with TensorRT in float16.

Exporting with optimum-cli export onnx --task causal-lm-with-past --optimize O4 --model EleutherAI/gpt-j-6B gptj_onnx/, you will only be able to use ONNX Runtime CUDAExecutionProvider.

By the way, with ONNX Runtime float16 conversion, there is an issue I haven't been able to solve yet specifically with GPT-J, so for now this architecture is not tested:

@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY))
@require_torch_gpu
@require_vision
@slow
@pytest.mark.gpu_test
@pytest.mark.run_slow
def test_exporters_cli_pytorch_with_O4_optimization(
    self, test_name: str, model_type: str, model_name: str, task: str, monolith: bool, no_post_process: bool
):
    # TODO: optimization for codegen not supported until https://github.com/microsoft/onnxruntime/pull/14751 is released
    if model_type == "codegen":
        self.skipTest("codegen not supported")
    # TODO: investigate why gptj with past is the only failing one (as in ORTOptimizer)
    if model_type == "gptj" and (task is None or "-with-past" in task):
        self.skipTest("Test failing with Shape mismatch attempting to re-use buffer")
    # TODO: disable due to a bug in PyTorch: https://github.com/pytorch/pytorch/issues/95377
    if model_type == "yolos":
        self.skipTest("Export on cuda device fails for yolos due to a bug in PyTorch")
    try:
        self._onnx_export(model_name, task, monolith, no_post_process, optimization_level="O4", device="cuda")
    except subprocess.CalledProcessError as e:
        if (
            "Tried to use ORTOptimizer for the model type" in e.stderr
            or "doesn't support the graph optimization" in e.stderr
        ):
            self.skipTest("unsupported model type in ORTOptimizer")
        else:
            raise e

Further reading: #785 (comment) (and the following answers). It could be a bug in ONNX Runtime.

Reference: https://huggingface.co/docs/optimum/onnxruntime/package_reference/configuration#optimum.onnxruntime.AutoOptimizationConfig.with_optimization_level
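For architectures where ORTOptimizer does support the graph optimizations, the programmatic counterpart of the --optimize levels looks roughly like this (a sketch based on the linked documentation; the gpt2_onnx directory is a hypothetical export, since GPT-J with past is the failing case discussed above):

from optimum.onnxruntime import ORTModelForCausalLM, ORTOptimizer
from optimum.onnxruntime.configuration import AutoOptimizationConfig

model = ORTModelForCausalLM.from_pretrained("gpt2_onnx")  # hypothetical export directory
optimizer = ORTOptimizer.from_pretrained(model)

# O2 enables basic and extended fused-operator optimizations; O4 additionally converts to fp16 (GPU only).
optimization_config = AutoOptimizationConfig.O2()

optimizer.optimize(save_dir="gpt2_onnx_optimized", optimization_config=optimization_config)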

fxmarty (Collaborator) commented Feb 28, 2023

Tracking the issue in #800.

So for now I would recommend using --fp16.
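A sketch of loading the --fp16 export with the TensorRT execution provider (the provider_options below mirror the trt_fp16_enable flag mentioned earlier in this thread; building the TensorRT engine on the first run can take a while, and the limitations linked above apply):

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

# Load the fp16 ONNX export with TensorRT and its fp16 kernels enabled.
model = ORTModelForCausalLM.from_pretrained(
    "gptj_onnx",
    provider="TensorrtExecutionProvider",
    provider_options={"trt_fp16_enable": True},
)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_length=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))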

Eichhof (Author) commented Mar 1, 2023

So I will try to use optimum-cli export onnx --task causal-lm-with-past --fp16 --model EleutherAI/gpt-j-6B gptj_onnx/ with TensorRT in float16. Do you have an estimate of the decrease in inference time when using this? By how much will GPU memory consumption go up compared to not using ONNX? The current memory consumption is 14 GB in fp16.

fxmarty (Collaborator) commented Mar 1, 2023

@Eichhof I'll try and get back to you.

Eichhof (Author) commented Mar 7, 2023

@fxmarty Do you have any news regarding the decrease in response time and memory?

fxmarty (Collaborator) commented Mar 21, 2023

Hi @Eichhof, I ran a short test with CUDAExecutionProvider.

The model is exported with: optimum-cli export onnx --model EleutherAI/gpt-j-6B --optimize O3 --device cuda --fp16 gptj_onnx

Here's the result:

  • PyTorch model load time: 134.29 s
  • ORT model load time: 153.83 s
  • PyTorch latency: 1.350 s
  • ORT latency: 1.320 s

As for memory, it appears ONNX Runtime's CUDAExecutionProvider is still quite poor memory-wise: microsoft/onnxruntime#14526 (comment)

Scripts:

from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer
import time

model_id = "gptj_onnx"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

print("loading model")
start = time.time()
model = ORTModelForCausalLM.from_pretrained(model_id, provider="CUDAExecutionProvider")
print(f"Loading took: {time.time() - start:.2f} s")

prompt = "ORT fast or slow"
inp = tokenizer(prompt, return_tensors="pt").to("cuda")

# warmup
res = model.generate(**inp, num_beams=1, min_length=50, max_length=50)

n_batch = 20
start = time.time()
for i in range(n_batch):
    res = model.generate(**inp, num_beams=1, min_length=50, max_length=50)
end = time.time()
ort_time = end - start
print(f"ORT: {ort_time / n_batch:.3f} s")

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time

model_id = "EleutherAI/gpt-j-6B"
tokenizer = AutoTokenizer.from_pretrained(model_id)

print("loading model")
start = time.time()
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
print(f"Loading took: {time.time() - start:.2f} s")

prompt = "ORT fast or slow"
inp = tokenizer(prompt, return_tensors="pt").to("cuda")

# warmup
res = model.generate(**inp, num_beams=1, min_length=50, max_length=50)

n_batch = 20
with torch.inference_mode():
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for i in range(n_batch):
        res = model.generate(**inp, num_beams=1, min_length=50, max_length=50)
    end_event.record()
    torch.cuda.synchronize()
    pt_time = start_event.elapsed_time(end_event) * 1e-3
    print(f"PT: {pt_time / n_batch:.3f} s")
