MatMulNBits + Add fusion #20587

Merged
merged 30 commits into from May 16, 2024
edgchen1 commented May 7, 2024

Description

  • Add MatMulNBits Bias input
  • Add graph transformer to fuse MatMulNBits + Add
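
As a rough illustration of what such a fusion pass does, here is a minimal sketch on a toy graph. Everything in it is hypothetical: the `Node` class, the `fuse_matmulnbits_add` function, the 3-input toy layout for MatMulNBits, and the matching conditions are assumptions made for illustration, not the transformer implemented in this PR.

```python
from dataclasses import dataclass


@dataclass
class Node:
    """Toy graph node (hypothetical; not an ONNX Runtime type)."""
    op_type: str
    inputs: list
    output: str


def fuse_matmulnbits_add(nodes, constants):
    """Fold `MatMulNBits -> Add(constant)` into MatMulNBits with a bias.

    Assumed conditions for this sketch: the MatMulNBits node has no bias yet
    (exactly 3 inputs in this toy layout), its output is read by exactly one
    node, that node is an Add, and the Add's other input is a constant.
    A real pass would also need to check shapes and graph outputs.
    """
    # Map each tensor name to the nodes that read it.
    consumers = {}
    for node in nodes:
        for name in node.inputs:
            consumers.setdefault(name, []).append(node)

    removed_ids = set()
    for node in nodes:
        if node.op_type != "MatMulNBits" or len(node.inputs) != 3:
            continue
        users = consumers.get(node.output, [])
        if len(users) != 1 or users[0].op_type != "Add":
            continue
        add = users[0]
        others = [name for name in add.inputs if name != node.output]
        if len(others) != 1 or others[0] not in constants:
            continue
        node.inputs.append(others[0])  # the Add operand becomes the bias
        node.output = add.output       # MatMulNBits now produces the Add's output
        removed_ids.add(id(add))

    return [node for node in nodes if id(node) not in removed_ids]


graph = [
    Node("MatMulNBits", ["x", "w_q", "w_scales"], "mm_out"),
    Node("Add", ["mm_out", "b"], "y"),       # constant bias: fusable
    Node("Add", ["y", "residual"], "z"),     # non-constant operand: kept
]
fused = fuse_matmulnbits_add(graph, constants={"w_q", "w_scales", "b"})
for node in fused:
    print(node.op_type, node.inputs, "->", node.output)
```

In this toy graph the first Add is absorbed and the second one stays, because its other operand is not a constant.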

Motivation and Context

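Improve performance: fusing a following Add into MatMulNBits as a bias input removes a separate Add kernel invocation for each fused pair.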

Measurements

Measured with a Phi-2 int4 model using onnxruntime_perf_test on an x64 machine.

With ORT profiling enabled:

Baseline:

------ Top CPU Kernel Times ------
                  name  duration   pct  count  cumulative_pct  cumulative_dur
           MatMulNBits   3458639 74.28  19300           74.28         3458639
                   Add    479686 10.30  19600           84.59         3938325
    MultiHeadAttention    207680  4.46   3200           89.05         4146005
       RotaryEmbedding    190082  4.08   6400           93.13         4336087
              FastGelu    125999  2.71   3200           95.84         4462086
SkipLayerNormalization    105812  2.27   3200           98.11         4567898
             Unsqueeze     17475  0.38    900           98.48         4585373
                Gather     13065  0.28    600           98.76         4598438
                Concat     10644  0.23    400           98.99         4609082
                 Where      9325  0.20    400           99.19         4618407
                 Shape      7823  0.17    400           99.36         4626230
                Expand      4330  0.09    200           99.45         4630560
                  Cast      4079  0.09    200           99.54         4634639
                 Equal      3962  0.09    200           99.63         4638601
    LayerNormalization      3344  0.07    100           99.70         4641945
                 Slice      2482  0.05    100           99.75         4644427
                   Sub      2253  0.05    100           99.80         4646680
                 Range      1914  0.04    100           99.84         4648594
               Reshape      1892  0.04    100           99.88         4650486
                  Less      1851  0.04    100           99.92         4652337
               Squeeze      1836  0.04    100           99.96         4654173
       ConstantOfShape      1817  0.04    100          100.00         4655990

Updated:

------ Top CPU Kernel Times ------
                  name  duration   pct  count  cumulative_pct  cumulative_dur
           MatMulNBits   3446419 81.21  19300           81.21         3446419
    MultiHeadAttention    203500  4.80   3200           86.00         3649919
       RotaryEmbedding    194088  4.57   6400           90.58         3844007
              FastGelu    123723  2.92   3200           93.49         3967730
SkipLayerNormalization    104530  2.46   3200           95.96         4072260
                   Add     83163  1.96   3500           97.92         4155423
             Unsqueeze     17609  0.41    900           98.33         4173032
                Gather     13474  0.32    600           98.65         4186506
                Concat     10648  0.25    400           98.90         4197154
                 Where      9162  0.22    400           99.12         4206316
                 Shape      7687  0.18    400           99.30         4214003
                Expand      4265  0.10    200           99.40         4218268
                  Cast      4213  0.10    200           99.50         4222481
                 Equal      4040  0.10    200           99.59         4226521
    LayerNormalization      3324  0.08    100           99.67         4229845
                 Slice      2497  0.06    100           99.73         4232342
                   Sub      2228  0.05    100           99.78         4234570
               Reshape      1909  0.04    100           99.83         4236479
                 Range      1880  0.04    100           99.87         4238359
                  Less      1855  0.04    100           99.91         4240214
               Squeeze      1844  0.04    100           99.96         4242058
       ConstantOfShape      1804  0.04    100          100.00         4243862
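
The main differences between the two profiles can be tallied directly from the tables above. The sketch below uses only the figures listed there. The variable names are illustrative, and durations are in whatever unit the profiler reports.

```python
# (duration, count) pairs copied from the baseline and updated tables.
baseline = {"MatMulNBits": (3458639, 19300), "Add": (479686, 19600)}
updated = {"MatMulNBits": (3446419, 19300), "Add": (83163, 3500)}
baseline_total = 4655990  # final cumulative_dur, baseline
updated_total = 4243862   # final cumulative_dur, updated

add_nodes_removed = baseline["Add"][1] - updated["Add"][1]
add_time_saved = baseline["Add"][0] - updated["Add"][0]
matmul_delta = updated["MatMulNBits"][0] - baseline["MatMulNBits"][0]
total_saved = baseline_total - updated_total

print("Add invocations removed:", add_nodes_removed)
print("Add kernel time saved:", add_time_saved)
print("MatMulNBits kernel time change:", matmul_delta)
print("Total kernel time saved:", total_saved)
print(f"Relative reduction in total kernel time: {total_saved / baseline_total:.2%}")
```

MatMulNBits time does not increase in this run even though it now also applies the bias, so the saved time comes almost entirely from the removed Add invocations.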

Average inferences/sec without profiling enabled:
Baseline: 29.6533
Updated: 30.521
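
For reference, the relative throughput change implied by these two numbers can be worked out as follows (a small sketch; the variable names are illustrative):

```python
baseline_ips = 29.6533  # average inferences/sec, baseline
updated_ips = 30.521    # average inferences/sec, updated

speedup = updated_ips / baseline_ips
print(f"Speedup: {speedup:.4f}x ({speedup - 1:.2%} more inferences/sec)")
```

This works out to roughly a 3% throughput improvement for this model and machine.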

edgchen1 marked this pull request as ready for review May 10, 2024 00:56
edgchen1 requested a review from a team as a code owner May 10, 2024 00:56
edgchen1 requested a review from skottmckay May 10, 2024 01:03
import sys

import numpy as np
import onnx

Check notice

Code scanning / CodeQL

Module is imported with 'import' and 'import from' Note

Module 'onnx' is imported with both 'import' and 'import from'.
Module 'onnxruntime.test.onnx' is imported with both 'import' and 'import from'.
yufenglee previously approved these changes May 15, 2024
skottmckay left a comment

:shipit:

edgchen1 merged commit e81c867 into main May 16, 2024
93 of 96 checks passed
edgchen1 deleted the edgchen1/matmul_nbits_bias branch May 16, 2024 18:01
3 participants