
[inductor][cpp] GEMM template (infra and fp32) #124021

Closed
wants to merge 60 commits

Conversation

jgong5
Collaborator

@jgong5 jgong5 commented Apr 14, 2024

Stack from ghstack (oldest at bottom):

This PR adds the Cpp template infrastructure and the initial FP32 gemm template. See RFC #125683 for more background info.

  1. Cpp template infrastructure
    Similar template abstractions to the CUTLASS template, i.e., CppTemplate, CppTemplateKernel, CppTemplateBuffer, plus the MicroGemm micro-kernel abstraction that can be used by Cpp GEMM templates.
  2. Initial FP32 gemm template
    This involves a GEMM template implementation, CppPackedGemmTemplate, that supports GEMM with a constant weight (B), requiring N to be a multiple of the register blocking while allowing static or dynamic sizes for the M (batch dim) of A. The B matrix is prepacked. This is a typical setting for inference workloads. The template handles the thread decomposition (via thread_blocking) and cache blocking (via cache_blocking). It then invokes CppMicroGemm, which handles register blocking, instruction selection, and other CPU architecture-specific optimizations. A CppMicroGemmFP32Vec micro-kernel implementation is provided for fp32 matmuls implemented with the ATen Vec abstraction. (A simplified blocking sketch follows this list.)
  3. Correctness and performance
    The changes have been validated with fp32 inference on the three benchmark suites (torchbench, huggingface, and timm_models) with both static and dynamic shapes. Since this is an initial implementation, we are still working on further performance improvements in follow-up PRs, including kernel optimizations and fusions. Compared to the ATen kernels, which are backed by MKL, perf gains are currently observed on only a select set of models. The gains are more pronounced with dynamic shapes since MKL only supports packed gemm for static shapes. Details are below; a usage sketch for enabling max-autotune on CPU follows the dynamic-shape results.
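To make the blocking hierarchy in item 2 concrete, here is a minimal NumPy sketch of how a packed GEMM can be decomposed into cache blocks and register blocks. The block sizes, function names, and the single-threaded loop structure are illustrative assumptions and do not mirror the actual CppPackedGemmTemplate codegen; in the real template, `thread_blocking` additionally partitions the outer loops across threads and the per-tile update is done by the CppMicroGemm micro-kernel on a prepacked B layout.

```python
import numpy as np

# Hypothetical block sizes; the real template derives these from the CPU's
# register width (micro-kernel) and cache sizes (cache_blocking).
Mc, Nc, Kc = 64, 64, 256   # cache-block sizes
Mr, Nr = 4, 16             # register-block (micro-kernel) sizes

def micro_gemm(a_tile, b_tile, c_tile):
    # Stand-in for the micro-kernel: an Mr x Nr register-blocked update,
    # written here as a plain matmul accumulation into a view of C.
    c_tile += a_tile @ b_tile

def packed_gemm(a, b, c):
    M, K = a.shape
    _, N = b.shape
    for mc in range(0, M, Mc):                          # cache blocking over M
        for nc in range(0, N, Nc):                      # cache blocking over N
            for kc in range(0, K, Kc):                  # cache blocking over K
                for mr in range(mc, min(mc + Mc, M), Mr):    # register blocking
                    for nr in range(nc, min(nc + Nc, N), Nr):
                        micro_gemm(
                            a[mr:mr + Mr, kc:kc + Kc],
                            b[kc:kc + Kc, nr:nr + Nr],
                            c[mr:mr + Mr, nr:nr + Nr],
                        )

M, N, K = 128, 128, 512
A = np.random.rand(M, K).astype(np.float32)
B = np.random.rand(K, N).astype(np.float32)
C = np.zeros((M, N), dtype=np.float32)
packed_gemm(A, B, C)
assert np.allclose(C, A @ B, rtol=1e-4)
```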

Static shapes

| Benchmark | torchbench | huggingface | timm_models |
|------------|-------------|--------------|--------------|
| Multi-threaded (baseline) | 1.47x | 1.36x | 1.91x |
| Multi-threaded (max-autotune) | 1.47x | 1.36x | 1.92x |
| Single-threaded (baseline) | 1.56x | 1.19x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.52x |

Key models being sped up:
drq: 1.14x
soft_act: 1.12x
cait_m36_384: 1.18x

Dynamic shapes

| Benchmark | torchbench | huggingface | timm_models |
|------------|-------------|--------------|--------------|
| Multi-threaded (baseline) | 1.43x | 1.28x | 1.85x |
| Multi-threaded (max-autotune) | 1.47x | 1.28x | 1.85x |
| Single-threaded (baseline) | 1.55x | 1.20x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.53x |

Key models being sped up:
BERT_pytorch: 1.22x
pyhpc_turbulent: 1.13x
soft_actor_critic: 1.77x
BlenderbotForCausalLM: 1.09x
cait_m36_384: 1.17x
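
For context on how the max-autotune configurations above are exercised from a user's point of view, here is a minimal usage sketch. It assumes that a `CPP` value for `torch._inductor.config.max_autotune_gemm_backends` becomes available once this stack lands, as described in RFC #125683; the exact flag value is an assumption to verify against the RFC.

```python
import torch

# Assumption: "CPP" becomes a valid GEMM autotuning backend once this stack
# lands (see RFC #125683); keeping "ATEN" in the list lets autotuning pick
# the MKL-backed kernel whenever it is faster.
torch._inductor.config.max_autotune_gemm_backends = "CPP,ATEN"

model = torch.nn.Linear(1024, 1024).eval()
x = torch.randn(32, 1024)

compiled = torch.compile(model, mode="max-autotune")
with torch.no_grad():
    y = compiled(x)
```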

cc @voznesenskym @penguinwu @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @chauhang

Differential Revision: D57585365

[ghstack-poisoned]

pytorch-bot bot commented Apr 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124021

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5f07582 with merge base 3f5b59e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jgong5 added a commit that referenced this pull request Apr 14, 2024
ghstack-source-id: 16b145bc95cd7a18b29dab08a9dca8379ef8c2f0
Pull Request resolved: #124021
@jgong5 jgong5 marked this pull request as draft April 14, 2024 14:21
[ghstack-poisoned]
jgong5 added a commit that referenced this pull request Apr 16, 2024
ghstack-source-id: 78e95234d720874d2de4d57f523821bec8b90461
Pull Request resolved: #124021
[ghstack-poisoned]
jgong5 added a commit that referenced this pull request Apr 16, 2024
ghstack-source-id: 63199103660ddb479f91c869685263c62c5c6cb2
Pull Request resolved: #124021
[ghstack-poisoned]
jgong5 added a commit that referenced this pull request Apr 17, 2024
ghstack-source-id: b06afe934b7414b32952285659a659f09cd7ae48
Pull Request resolved: #124021
[ghstack-poisoned]
jgong5 added a commit that referenced this pull request Apr 17, 2024
ghstack-source-id: 30af30579007f953540b54d0d2811a9ae6868234
Pull Request resolved: #124021
[ghstack-poisoned]
jgong5 added a commit that referenced this pull request Apr 17, 2024
ghstack-source-id: af09b73875d34785c1de8c3641671d1af81f0870
Pull Request resolved: #124021
[ghstack-poisoned]
jgong5 added a commit that referenced this pull request Apr 17, 2024
ghstack-source-id: fa01d00bdc5f4ff2d2262c846cc59520fc9a6399
Pull Request resolved: #124021
[ghstack-poisoned]
jgong5 added a commit that referenced this pull request Apr 17, 2024
ghstack-source-id: b49c92e0425f902201d8c1af11ff76b2f5007e38
Pull Request resolved: #124021
[ghstack-poisoned]
jgong5 added a commit that referenced this pull request Apr 18, 2024
ghstack-source-id: e226d12bd8a3e74c011aa684483ee00ddab56013
Pull Request resolved: #124021
[ghstack-poisoned]
jgong5 added a commit that referenced this pull request Apr 18, 2024
ghstack-source-id: f918cd8829709fa4a183555bc3c0a8c9d51604e0
Pull Request resolved: #124021
[ghstack-poisoned]
jgong5 added a commit that referenced this pull request Apr 26, 2024
ghstack-source-id: 4a203e5590639ab2c921ffaad59b86210cfe0c77
Pull Request resolved: #124021
[ghstack-poisoned]
jgong5 added a commit that referenced this pull request Apr 27, 2024
ghstack-source-id: f867714acbbdce8454c4882ac37c6cc35758c53e
Pull Request resolved: #124021
[ghstack-poisoned]
leslie-fang-intel pushed a commit to leslie-fang-intel/pytorch that referenced this pull request May 23, 2024
ghstack-source-id: 0295b702999021f6fc156909f69db661e9900637
Pull Request resolved: pytorch#124021
pytorchmergebot pushed a commit that referenced this pull request May 23, 2024
As part of #125683, this PR adds the epilogue support for c++ gemm template by reusing the c++ vector codegen on sub-slices of tensors. This is implemented by retracing the epilogue IR nodes with new ranges and offsets. The new `codegen_loop_bodies` and `codegen_functions` methods are added to c++ vector codegen for this purpose. This is leveraged by the `store_output` method of the template kernel for epilogue codegen and store to the final result.

Pull Request resolved: #126019
Approved by: https://github.com/jansel
ghstack dependencies: #124021
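
Conceptually, the epilogue support in #126019 above means the pointwise epilogue is re-evaluated per output sub-slice (tile) right where the template stores its result, instead of in a separate full-tensor pass. A small, hypothetical NumPy sketch of that idea (names like `relu_bias_epilogue` are made up; this is not the actual inductor codegen path):

```python
import numpy as np

def relu_bias_epilogue(acc_tile, bias_tile):
    # The same pointwise expression as the full-tensor epilogue, but applied
    # with the ranges/offsets of a single output tile.
    return np.maximum(acc_tile + bias_tile, 0.0)

def gemm_with_tiled_epilogue(A, B, bias, tile_n=16):
    M, K = A.shape
    _, N = B.shape
    C = np.empty((M, N), dtype=A.dtype)
    for n0 in range(0, N, tile_n):
        n1 = min(n0 + tile_n, N)
        acc = A @ B[:, n0:n1]                  # "template" part for this tile
        C[:, n0:n1] = relu_bias_epilogue(acc, bias[n0:n1])  # fused at store time
    return C

A = np.random.randn(8, 32).astype(np.float32)
B = np.random.randn(32, 48).astype(np.float32)
bias = np.random.randn(48).astype(np.float32)
ref = np.maximum(A @ B + bias, 0.0)
assert np.allclose(gemm_with_tiled_epilogue(A, B, bias), ref, atol=1e-5)
```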
pytorchmergebot pushed a commit that referenced this pull request May 23, 2024
…ue fusion (#126068)

As part of #125683, this PR adds initial bf16/fp16 gemm template support, with the micro-gemm implemented with fused type casting and fp32 computation. It doesn't provide epilogue fusion support yet; that will be added in the next PR.

Pull Request resolved: #126068
Approved by: https://github.com/jansel
ghstack dependencies: #124021, #126019
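
The "fused type casting and fp32 computation" recipe mentioned above boils down to: load low-precision inputs, upcast to fp32, accumulate in fp32, and only downcast the final result. A tiny PyTorch sketch of that numerical scheme (illustrative only, not the generated C++ micro-kernel):

```python
import torch

def bf16_gemm_fp32_compute(a_bf16: torch.Tensor, b_bf16: torch.Tensor) -> torch.Tensor:
    # Upcast the inputs (done tile-by-tile in the real micro-kernel),
    # accumulate in fp32 for accuracy, and cast only the final result back.
    acc = a_bf16.to(torch.float32) @ b_bf16.to(torch.float32)
    return acc.to(torch.bfloat16)

a = torch.randn(64, 128, dtype=torch.bfloat16)
b = torch.randn(128, 32, dtype=torch.bfloat16)
out = bf16_gemm_fp32_compute(a, b)
print(out.dtype)  # torch.bfloat16
```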
pytorchmergebot pushed a commit that referenced this pull request May 23, 2024
)

As part of #125683, this PR adds epilogue fusion support for bf16/fp16 gemms. The key changes are as follows:
1. bf16 linear with epilogue fusion of some ops was originally supported via the ATen oneDNN linear pointwise ops. To match the ATen op semantics, in-template epilogue support is added to the cpp gemm template so that we have: "gemm + in-template epilogues -> template buffer". If the template is chosen for codegen, the in-template epilogues are concatenated with the out-of-template epilogues appended during scheduling.
2. Support bf16/fp16 legalization for `codegen_loop_bodies`, which is used to generate the epilogue loops.
3. We used to leverage the in-place buffer mechanism to handle buffer reuse in the epilogue codegen, in particular for reusing the output buffers of the GEMM template and epilogues. This is not correct, since the output buffer is an "output", not an "in-place" buffer, of the template kernel itself. Now a dedicated "aliases" dict manages such buffer reuses, and the intermediate aliasing buffers are removed after codegen.
4. Add a `localize_buffer` method to `LocalBufferScope` to allow replacing a global buffer with a local one in the given inductor IR nodes. This helps the fused loops work on smaller local buffers for better data locality (a simplified sketch follows this commit message).

Pull Request resolved: #126545
Approved by: https://github.com/jansel
ghstack dependencies: #124021, #126019, #126068
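
The `localize_buffer` idea in item 4 above, rewriting the fused epilogue loops to index a small per-tile scratch buffer instead of the full global output, can be illustrated with the following hypothetical NumPy sketch (buffer names and tile sizes are made up for illustration):

```python
import numpy as np

def epilogue(x):
    return np.maximum(x, 0.0)  # e.g. a ReLU appended during scheduling

def gemm_with_local_buffer(A, B, tile_m=16, tile_n=16):
    M, K = A.shape
    _, N = B.shape
    Y = np.empty((M, N), dtype=np.float32)
    # Small scratch buffer reused across tiles: the epilogue loops are
    # rewritten ("localized") to index this buffer instead of the global Y.
    local = np.empty((tile_m, tile_n), dtype=np.float32)
    for m0 in range(0, M, tile_m):
        for n0 in range(0, N, tile_n):
            m1, n1 = min(m0 + tile_m, M), min(n0 + tile_n, N)
            local[:m1 - m0, :n1 - n0] = A[m0:m1] @ B[:, n0:n1]
            Y[m0:m1, n0:n1] = epilogue(local[:m1 - m0, :n1 - n0])
    return Y

A = np.random.randn(40, 24).astype(np.float32)
B = np.random.randn(24, 40).astype(np.float32)
assert np.allclose(gemm_with_local_buffer(A, B), np.maximum(A @ B, 0.0), atol=1e-5)
```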
@DanilBaibak
Contributor

@pytorchbot revert -m "Broken trunk" -c ignoredsignal

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request May 23, 2024
pytorchmergebot added a commit that referenced this pull request May 23, 2024
…o epilogue fusion (#126068)"

This reverts commit 31412cb.

Reverted #126068 on behalf of https://github.com/DanilBaibak due to Broken trunk ([comment](#124021 (comment)))
pytorchmergebot added a commit that referenced this pull request May 23, 2024
pytorchmergebot added a commit that referenced this pull request May 23, 2024
@pytorchmergebot
Collaborator

@jgong5 your PR has been successfully reverted.

@DanilBaibak
Contributor

Hi @jgong5! Sorry, I need to revert your PR because it broke the linter and a bunch of linux jobs. Here you can find more information.

@jgong5
Collaborator Author

jgong5 commented May 23, 2024

Hi @jgong5! Sorry, I need to revert your PR because it broke the linter and a bunch of linux jobs. Here you can find more information.

Hi @DanilBaibak, the error message you were referring to is related to another PR (#126545) higher in the stack. Was there an issue with this one that led you to revert here?

--Edit--
Ah, I can repro the problem. Thanks. Weird, it did not show up previously.

@jgong5
Collaborator Author

jgong5 commented May 24, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pytorchmergebot pushed a commit that referenced this pull request May 24, 2024
ghstack-source-id: c04bd58b262d1b71d3a9c1b8fc097c7171db3784
Pull Request resolved: #124021
pytorchmergebot pushed a commit that referenced this pull request May 24, 2024
As part of #125683, this PR adds the epilogue support for c++ gemm template by reusing the c++ vector codegen on sub-slices of tensors. This is implemented by retracing the epilogue IR nodes with new ranges and offsets. The new `codegen_loop_bodies` and `codegen_functions` methods are added to c++ vector codegen for this purpose. This is leveraged by the `store_output` method of the template kernel for epilogue codegen and store to the final result.

Pull Request resolved: #126019
Approved by: https://github.com/jansel
ghstack dependencies: #124021
pytorchmergebot pushed a commit that referenced this pull request May 24, 2024
…ue fusion (#126068)

As part of #125683, this PR adds initial bf16/fp16 gemm template support, with the micro-gemm implemented with fused type casting and fp32 computation. It doesn't provide epilogue fusion support yet; that will be added in the next PR.

Pull Request resolved: #126068
Approved by: https://github.com/jansel
ghstack dependencies: #124021, #126019
pytorchmergebot pushed a commit that referenced this pull request May 24, 2024
)

As part of #125683, this PR adds epilogue fusion support for bf16/fp16 gemms. The key changes are as follows:
1. bf16 linear with epilogue fusion of some ops was originally supported via the ATen oneDNN linear pointwise ops. To match the ATen op semantics, in-template epilogue support is added to the cpp gemm template so that we have: "gemm + in-template epilogues -> template buffer". If the template is chosen for codegen, the in-template epilogues are concatenated with the out-of-template epilogues appended during scheduling.
2. Support bf16/fp16 legalization for `codegen_loop_bodies`, which is used to generate the epilogue loops.
3. We used to leverage the in-place buffer mechanism to handle buffer reuse in the epilogue codegen, in particular for reusing the output buffers of the GEMM template and epilogues. This is not correct, since the output buffer is an "output", not an "in-place" buffer, of the template kernel itself. Now a dedicated "aliases" dict manages such buffer reuses, and the intermediate aliasing buffers are removed after codegen.
4. Add a `localize_buffer` method to `LocalBufferScope` to allow replacing a global buffer with a local one in the given inductor IR nodes. This helps the fused loops work on smaller local buffers for better data locality.

Pull Request resolved: #126545
Approved by: https://github.com/jansel
ghstack dependencies: #124021, #126019, #126068