From 3b909b8644dcc443ef3bea98ec240d855bc794cd Mon Sep 17 00:00:00 2001 From: wufeisheng Date: Tue, 1 Nov 2022 07:14:40 +0000 Subject: [PATCH 1/6] add fuse_multi_transformer_layer_pass --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../ir/fuse_multi_transformer_layer_pass.cc | 322 ++++++++++++++++++ .../ir/fuse_multi_transformer_layer_pass.h | 60 ++++ .../fused_multi_transformer_decoder_pass.cc | 3 + .../fused_multi_transformer_encoder_pass.cc | 3 + paddle/fluid/framework/ir/pass.h | 4 + .../inference/api/paddle_pass_builder.cc | 1 + 7 files changed, 394 insertions(+) create mode 100644 paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc create mode 100644 paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.h diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 279ab07ff31b0..ed1ea67dc4105 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -107,6 +107,7 @@ pass_library(skip_layernorm_fuse_pass base) pass_library(multihead_matmul_fuse_pass inference) pass_library(fused_multi_transformer_encoder_pass inference) pass_library(fused_multi_transformer_decoder_pass inference) +pass_library(fuse_multi_transformer_layer_pass inference) pass_library(adaptive_pool2d_convert_global_pass inference) pass_library(unsqueeze2_eltwise_fuse_pass inference) pass_library(yolo_box_fuse_pass inference) diff --git a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc new file mode 100644 index 0000000000000..5209d286ec3d7 --- /dev/null +++ b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc @@ -0,0 +1,322 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.h" + +#include + +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +std::unordered_map +MultiTransformerLayerPattern::operator()(bool enable_int8, + int num_fused_op, + bool is_decoder) { + std::string fused_multi_transformer_name = + enable_int8 ? 
"fused_multi_transformer_int8" : "fused_multi_transformer"; + // This map is used to store node_reprs, 3 * i names will be inserted + // cache_kv0_{i}, cache_kv1_{i}, fill_constant_batch_size_like_{i} + std::unordered_map node_reprs; + + VLOG(0) << "num in pattern = " << num_fused_op; + // x0 and src_mask is unqiue input of subgraph + auto* x0 = pattern->NewNode(x0_repr()); + x0->assert_is_op_input(fused_multi_transformer_name, "X")->AsInput(); + + auto* src_mask = pattern->NewNode(src_mask_repr()); + src_mask->assert_is_op_input(fused_multi_transformer_name, "SrcMask") + ->AsInput(); + + for (int i = 0; i < num_fused_op; ++i) { + // fused_multi_transformer op + auto fuse_op_repr = + PDNodeName(name_scope_, repr_, id_, "fuse_op_" + std::to_string(i)); + node_reprs["fuse_op_" + std::to_string(i)] = fuse_op_repr; + auto* fused_multi_transformer = + pattern->NewNode(fuse_op_repr) + ->assert_is_op(fused_multi_transformer_name); + + // fused_multi_transformer output + auto out_repr = + PDNodeName(name_scope_, repr_, id_, "out_" + std::to_string(i)); + node_reprs["out_" + std::to_string(i)] = out_repr; + auto* out = pattern->NewNode(out_repr)->assert_is_op_output( + fused_multi_transformer_name, "Out"); + + // Links + if (is_decoder) { + fused_multi_transformer->LinksFrom({x0, src_mask}).LinksTo({out}); + } else { + // catch_kv + auto cache_kv_repr = + PDNodeName(name_scope_, repr_, id_, "cache_kv_" + std::to_string(i)); + node_reprs["cache_kv_" + std::to_string(i)] = cache_kv_repr; + auto* cache_kv = pattern->NewNode(cache_kv_repr); + cache_kv->assert_is_op_input(fused_multi_transformer_name, "CacheKV"); + cache_kv->AsInput(); + + // fill constant op is only valid in encoder + auto fill_const_op_repr = + PDNodeName(name_scope_, repr_, id_, "fill_op_" + std::to_string(i)); + node_reprs["fill_op_" + std::to_string(i)] = fill_const_op_repr; + auto fill_const_op = pattern->NewNode(fill_const_op_repr) + ->assert_is_op("fill_constant_batch_size_like"); + + fused_multi_transformer->LinksFrom({x0, src_mask, cache_kv}) + .LinksTo({out}); + fill_const_op->LinksFrom({x0}).LinksTo({cache_kv}); + } + x0 = out; + } + x0->AsOutput(); + return node_reprs; +} +} // namespace patterns + +inline void MergeInput(OpDesc* op, + const std::vector& input_name_maps, + const std::string& input_name) { + std::vector tmp = input_name_maps[0].at(input_name); + for (int i = 1; i < input_name_maps.size(); ++i) { + tmp.insert(tmp.end(), + input_name_maps[i].at(input_name).begin(), + input_name_maps[i].at(input_name).end()); + } + op->SetInput(input_name, tmp); +} + +template +inline void MergeAttrs(const std::vector& ops, + const std::string& attr_name) { + std::vector res; + for (int i = 0; i < ops.size(); ++i) { + auto scale_vec = + PADDLE_GET_CONST(std::vector, ops[i]->GetAttr(attr_name)); + res.insert(res.end(), scale_vec.begin(), scale_vec.end()); + } + ops[0]->SetAttr(attr_name, res); +} + +int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, + const std::string& name_scope, + Scope* scope) const { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + VLOG(0) << "In build fusion"; + + bool enable_int8 = false; + if (graph->Has("enable_int8")) { + enable_int8 = graph->Get("enable_int8"); + } + if (!enable_int8) { + VLOG(4) + << "fuse_multi_transformer_layer_pass will match float transformer op " + "cause enable_int8 is not been set or set to false"; + } + int num_fuse_op = 0; + bool is_decoder = false; + if (graph->Has("enable_int8")) { + num_fuse_op = 
graph->Get("num_fused_multi_transformer_op"); + } + + if (graph->Has(kFusedMultiTransformerEncoderFusionCount)) { + num_fuse_op = graph->Get(kFusedMultiTransformerEncoderFusionCount); + is_decoder = false; + } else if (graph->Has(kFusedMultiTransformerDecoderFusionCount)) { + num_fuse_op = graph->Get(kFusedMultiTransformerDecoderFusionCount); + is_decoder = true; + } + if (num_fuse_op == 0) { + VLOG(4) << "fuse_multi_transformer_layer_pass will be skipped " + "cause num_fuse_op is not been set or set to 0"; + return 0; + } + if (!is_decoder) { + VLOG(4) << "fuse_multi_transformer_layer_pass will match encoder pattern " + "cause is_decoder is not been set or set to false"; + } + + patterns::MultiTransformerLayerPattern multi_layer_pattern(pattern, + name_scope); + auto node_reprs = multi_layer_pattern(enable_int8, num_fuse_op, is_decoder); + for (auto p : node_reprs) { + VLOG(0) << "key: " << p.first << " value: " << p.second; + } + + VLOG(0) << "Finish build pattern"; + int fusion_count{0}; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(0) << "handle FuseMultiTransformerLayerPass"; + VLOG(0) << "subgraph.size()" << subgraph.size(); + + GET_IR_NODE_FROM_SUBGRAPH(src_mask, src_mask, multi_layer_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(x0, x0, multi_layer_pattern); + + VLOG(0) << "Get input node"; + + std::vector fuse_op_nodes; + std::vector out_nodes; + + std::vector fuse_op_descs; + std::vector fuse_op_input_var_name_maps; + std::vector fuse_op_output_var_name_maps; + + for (int i = 0; i < num_fuse_op; ++i) { + // fused_multi_transformer op + PDNode* fuse_op_pdnode = + multi_layer_pattern.PatternBase::pattern->RetrieveNode( + node_reprs["fuse_op_" + std::to_string(i)]); + Node* fuse_op_node = subgraph.at(fuse_op_pdnode); + fuse_op_nodes.push_back(fuse_op_node); + fuse_op_descs.push_back(fuse_op_node->Op()); + fuse_op_input_var_name_maps.emplace_back(fuse_op_node->Op()->Inputs()); + fuse_op_output_var_name_maps.emplace_back(fuse_op_node->Op()->Outputs()); + + // fused_multi_transformer output + PDNode* out_pdnode = + multi_layer_pattern.PatternBase::pattern->RetrieveNode( + node_reprs["out_" + std::to_string(i)]); + out_nodes.push_back(subgraph.at(out_pdnode)); + + // fill_const op use x0 as input + if (!is_decoder && i != 0) { + // fill constant op + PDNode* fill_op_pdnode = + multi_layer_pattern.PatternBase::pattern->RetrieveNode( + node_reprs["fill_op_" + std::to_string(i)]); + Node* fill_op_node = subgraph.at(fill_op_pdnode); + fill_op_node->Op()->SetInput("Input", {x0->Name()}); + IR_NODE_UNLINK(out_nodes[i - 1], fill_op_node); + IR_NODE_LINK_TO(x0, fill_op_node); + } + } + + // Merge inputs + std::vector inputs_names = {"CacheKV", + "FFN1Bias", + "FFN1Weight", + "FFN2Bias", + "FFN2Weight", + "FFNLnBias", + "FFNLnScale", + "LnBias", + "LnScale", + "OutLinearBias", + "OutLinearW", + "QKVBias", + "QKVW"}; + if (enable_int8) { + std::vector inputs_names_int8_supp = { + "FFN1OutScale", "FFN2OutScale", "OutLinearOutScale", "QKVOutScale"}; + inputs_names.insert(inputs_names.end(), + inputs_names_int8_supp.begin(), + inputs_names_int8_supp.end()); + } + for (const auto& input_name : inputs_names) { + MergeInput(fuse_op_descs[0], fuse_op_input_var_name_maps, input_name); + } + VLOG(0) << "Finsh Merge input"; + + // Merge outputs + fuse_op_descs[0]->SetOutput( + "Out", fuse_op_output_var_name_maps[num_fuse_op - 1]["Out"]); + auto& merged_cache_kv_out_names = + fuse_op_output_var_name_maps[0]["CacheKVOut"]; + for (int i = 1; i < num_fuse_op; ++i) { 
+ const auto& out_var_names = fuse_op_output_var_name_maps[i]["CacheKVOut"]; + merged_cache_kv_out_names.insert(merged_cache_kv_out_names.end(), + out_var_names.begin(), + out_var_names.end()); + } + // for (auto out_name : output_names0["CacheKVOut"]) { + // VLOG(0) << "out_name " << out_name; + // } + fuse_op_descs[0]->SetOutput("CacheKVOut", merged_cache_kv_out_names); + + if (enable_int8) { + // Merge inputs scale + std::vector attr_names = {"qkv_in_scale", + "out_linear_in_scale", + "ffn1_in_scale", + "ffn2_in_scale"}; + for (const auto& name : attr_names) { + MergeAttrs(fuse_op_descs, name); + } + VLOG(0) << "Finsh Merge attrs"; + } + + // ReLink + // before relink, out nodes (0 -> num_layer-1) should be removed + std::unordered_set marked_out_nodes(out_nodes.begin(), + out_nodes.end() - 1); + GraphSafeRemoveNodes(graph, marked_out_nodes); + + auto& merged_inputs = fuse_op_nodes[0]->inputs; + for (int i = 1; i < num_fuse_op; ++i) { + merged_inputs.insert(merged_inputs.end(), + fuse_op_nodes[i]->inputs.begin(), + fuse_op_nodes[i]->inputs.end()); + } + + // Relink fuse op -> out + IR_NODE_UNLINK(fuse_op_nodes[num_fuse_op - 1], out_nodes[num_fuse_op - 1]); + IR_NODE_LINK_TO(fuse_op_nodes[0], out_nodes[num_fuse_op - 1]); + VLOG(0) << "Finsh relinks"; + + std::unordered_set marked_fuse_op_nodes( + fuse_op_nodes.begin() + 1, fuse_op_nodes.end()); + GraphSafeRemoveNodes(graph, marked_fuse_op_nodes); + VLOG(0) << "Finsh remove"; + ++fusion_count; + }; + + gpd(graph, handler); + return fusion_count; +} + +void FuseMultiTransformerLayerPass::ApplyImpl(Graph* graph) const { + FusePassBase::Init(name_scope_, graph); + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, + platform::errors::Fatal("During the fuse_multi_transformer_layer pass, " + "The scope should not be null.")); + int fusion_count = BuildFusion(graph, name_scope_, scope); + VLOG(0) << "fusion_count is " << fusion_count; + + // PD_THROW("IMULTILAYER"); + + AddStatis(fusion_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(fuse_multi_transformer_layer_pass, + paddle::framework::ir::FuseMultiTransformerLayerPass); diff --git a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.h b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.h new file mode 100644 index 0000000000000..339cc6815e223 --- /dev/null +++ b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.h @@ -0,0 +1,60 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct MultiTransformerLayerPattern : public PatternBase { + MultiTransformerLayerPattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, "fuse_multi_transformer_layer") {} + + std::unordered_map operator()( + bool enable_int8, int num_fused_op = 1, bool is_decoder = false); + + PATTERN_DECL_NODE(src_mask); + PATTERN_DECL_NODE(x0); +}; + +} // namespace patterns + +class FuseMultiTransformerLayerPass : public FusePassBase { + public: + FuseMultiTransformerLayerPass() {} + virtual ~FuseMultiTransformerLayerPass() {} + + protected: + void ApplyImpl(Graph* graph) const; + + const std::string name_scope_{"fuse_multi_transformer_layer"}; + + private: + int BuildFusion(Graph* graph, + const std::string& name_scope, + Scope* scope) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc index 5559499e0b4b2..e2bec03af78d4 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc +++ b/paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc @@ -1671,6 +1671,7 @@ void FusedMultiTransformerDecoderPass::ApplyImpl(Graph* graph) const { int fusion_count = BuildFusion(graph, name_scope_, scope); if (fusion_count > 0) { graph->Set(kFusedMultiTransformerDecoderPass, new bool(true)); + graph->Set(kFusedMultiTransformerDecoderFusionCount, new int(fusion_count)); } AddStatis(fusion_count); } @@ -2319,6 +2320,7 @@ void FusedMultiTransformerDecoderFuseQKVPass::ApplyImpl(Graph* graph) const { int fusion_count = BuildFusion(graph, name_scope_, scope); if (fusion_count > 0) { graph->Set(kFusedMultiTransformerDecoderFuseQKVPass, new bool(true)); + graph->Set(kFusedMultiTransformerDecoderFusionCount, new int(fusion_count)); } AddStatis(fusion_count); } @@ -3009,6 +3011,7 @@ void MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::ApplyImpl( int fusion_count = BuildFusion(graph, name_scope_, scope); if (fusion_count > 0) { graph->Set(kFusedMultiTransformerDecoderFuseQKVPass, new bool(true)); + graph->Set(kFusedMultiTransformerDecoderFusionCount, new int(fusion_count)); } AddStatis(fusion_count); } diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc index d93c29765649a..89d0e70c02c47 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc +++ b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc @@ -1835,6 +1835,7 @@ void FusedMultiTransformerEncoderPass::ApplyImpl(Graph* graph) const { int fusion_count = BuildFusion(graph, name_scope_, scope); if (fusion_count > 0) { graph->Set(kFusedMultiTransformerEncoderPass, new bool(true)); + graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(fusion_count)); } AddStatis(fusion_count); } @@ -2517,6 +2518,7 @@ void FusedMultiTransformerEncoderFuseQKVPass::ApplyImpl(Graph* graph) const { int fusion_count = BuildFusion(graph, name_scope_, scope); if (fusion_count > 0) { graph->Set(kFusedMultiTransformerEncoderFuseQKVPass, new bool(true)); + graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(fusion_count)); } AddStatis(fusion_count); } @@ -3243,6 +3245,7 @@ void 
MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::ApplyImpl( if (fusion_count > 0) { graph->Set(kMultiDevicesFusedMultiTransformerEncoderFuseQKVPass, new bool(true)); + graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(fusion_count)); } AddStatis(fusion_count); } diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index 2ed753cdeb717..e0315f0b5b741 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -59,6 +59,10 @@ constexpr char kMultiDevicesFusedMultiTransformerEncoderFuseQKVPass[] = "multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass_flag"; constexpr char kMultiDevicesFusedMultiTransformerDecoderFuseQKVPass[] = "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass_flag"; +constexpr char kFusedMultiTransformerEncoderFusionCount[] = + "fused_multi_transformer_encoder_fusion_count"; +constexpr char kFusedMultiTransformerDecoderFusionCount[] = + "fused_multi_transformer_decoder_fusion_count"; constexpr char kPrelnEmbEltwiseLayernormPass[] = "preln_embedding_eltwise_layernorm_fuse_pass_flag"; diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 59ebbb5764a56..9a7a646bc0513 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -206,6 +206,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { "fused_multi_transformer_decoder_fuse_qkv_pass", // "multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass", // "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass", // + "fuse_multi_transformer_layer_pass", // "gpu_cpu_squeeze2_matmul_fuse_pass", // "gpu_cpu_reshape2_matmul_fuse_pass", // "gpu_cpu_flatten2_matmul_fuse_pass", // From cc5e3fcce3fcccac8f9d202308084208c7afaf5b Mon Sep 17 00:00:00 2001 From: wufeisheng Date: Tue, 1 Nov 2022 09:16:36 +0000 Subject: [PATCH 2/6] fuse_multi_transformer_layer_pass need to support sub-graph --- paddle/fluid/framework/ir/pass.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index 74ad71f37da69..4ad93183996fa 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -46,7 +46,7 @@ static const std::vector support_subgraph_passes = { "fused_multi_transformer_decoder_fuse_qkv_pass", "multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass", "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass", -}; + "fuse_multi_transformer_layer_pass"}; Graph *Pass::Apply(Graph *graph) const { VLOG(10) << "start to apply pass " << Type() << " to graph"; From e152480a9f41c23cb2e6315fd142659451318ed9 Mon Sep 17 00:00:00 2001 From: wufeisheng Date: Tue, 1 Nov 2022 11:30:28 +0000 Subject: [PATCH 3/6] delete slice/shape op in decoder --- .../ir/fuse_multi_transformer_layer_pass.cc | 74 +++++++++++++++++-- 1 file changed, 68 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc index 5209d286ec3d7..09436cd5668f1 100644 --- a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc +++ b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc @@ -68,7 +68,36 @@ MultiTransformerLayerPattern::operator()(bool enable_int8, // Links if (is_decoder) { - fused_multi_transformer->LinksFrom({x0, src_mask}).LinksTo({out}); + // shape and shape out + auto shape_repr = + PDNodeName(name_scope_, 
repr_, id_, "shape_" + std::to_string(i)); + node_reprs["shape_" + std::to_string(i)] = shape_repr; + auto* shape = pattern->NewNode(shape_repr)->assert_is_op("shape"); + + auto shape_out_repr = + PDNodeName(name_scope_, repr_, id_, "shape_out_" + std::to_string(i)); + node_reprs["shape_out_" + std::to_string(i)] = shape_out_repr; + auto* shape_out = + pattern->NewNode(shape_out_repr)->assert_is_op_output("shape", "Out"); + + shape->LinksFrom({src_mask}).LinksTo({shape_out}); + + // slice and slice out + auto slice_repr = + PDNodeName(name_scope_, repr_, id_, "slice_" + std::to_string(i)); + node_reprs["slice_" + std::to_string(i)] = slice_repr; + auto* slice = pattern->NewNode(slice_repr)->assert_is_op("slice"); + + auto slice_out_repr = + PDNodeName(name_scope_, repr_, id_, "slice_out_" + std::to_string(i)); + node_reprs["slice_out_" + std::to_string(i)] = slice_out_repr; + auto* slice_out = + pattern->NewNode(slice_out_repr)->assert_is_op_output("slice", "Out"); + + slice->LinksFrom({shape_out}).LinksTo({slice_out}); + + fused_multi_transformer->LinksFrom({x0, src_mask, slice_out}) + .LinksTo({out}); } else { // catch_kv auto cache_kv_repr = @@ -155,8 +184,9 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, return 0; } if (!is_decoder) { - VLOG(4) << "fuse_multi_transformer_layer_pass will match encoder pattern " - "cause is_decoder is not been set or set to false"; + VLOG(4) << "fuse_multi_transformer_layer_pass will match encoder pattern"; + } else { + VLOG(4) << "fuse_multi_transformer_layer_pass will match decoder pattern"; } patterns::MultiTransformerLayerPattern multi_layer_pattern(pattern, @@ -173,6 +203,10 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, VLOG(0) << "handle FuseMultiTransformerLayerPass"; VLOG(0) << "subgraph.size()" << subgraph.size(); + /////////////////// + //// Get nodes //// + /////////////////// + GET_IR_NODE_FROM_SUBGRAPH(src_mask, src_mask, multi_layer_pattern); GET_IR_NODE_FROM_SUBGRAPH(x0, x0, multi_layer_pattern); @@ -182,6 +216,10 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, std::vector fuse_op_nodes; std::vector out_nodes; + std::vector unused_node_prefixes = { + "shape_", "shape_out_", "slice_", "slice_out_"}; + std::vector unused_nodes; + std::vector fuse_op_descs; std::vector fuse_op_input_var_name_maps; std::vector fuse_op_output_var_name_maps; @@ -213,9 +251,21 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, fill_op_node->Op()->SetInput("Input", {x0->Name()}); IR_NODE_UNLINK(out_nodes[i - 1], fill_op_node); IR_NODE_LINK_TO(x0, fill_op_node); + } else if (is_decoder && i != 0) { + for (const auto& unused_node_prefix : unused_node_prefixes) { + PDNode* unused_pdnode = + multi_layer_pattern.PatternBase::pattern->RetrieveNode( + node_reprs[unused_node_prefix + std::to_string(i)]); + Node* unused_node = subgraph.at(unused_pdnode); + unused_nodes.push_back(unused_node); + } } } + /////////////// + //// Merge //// + /////////////// + // Merge inputs std::vector inputs_names = {"CacheKV", "FFN1Bias", @@ -269,13 +319,15 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, } VLOG(0) << "Finsh Merge attrs"; } - - // ReLink - // before relink, out nodes (0 -> num_layer-1) should be removed + //////////////// + //// ReLink //// + //////////////// + // Before relink, out nodes (0 -> num_layer-1) should be removed std::unordered_set marked_out_nodes(out_nodes.begin(), out_nodes.end() - 1); GraphSafeRemoveNodes(graph, marked_out_nodes); + // Relink all input nodes of 
fused_multi_transformer ops to the first op auto& merged_inputs = fuse_op_nodes[0]->inputs; for (int i = 1; i < num_fuse_op; ++i) { merged_inputs.insert(merged_inputs.end(), @@ -288,8 +340,18 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, IR_NODE_LINK_TO(fuse_op_nodes[0], out_nodes[num_fuse_op - 1]); VLOG(0) << "Finsh relinks"; + ///////////////////////////// + //// Delete unused nodes //// + ///////////////////////////// + // Delete fused_multi_transformer op expect for the first one std::unordered_set marked_fuse_op_nodes( fuse_op_nodes.begin() + 1, fuse_op_nodes.end()); + + // Delete shape/slice op in decoder subgraph + if (is_decoder) { + marked_fuse_op_nodes.insert(unused_nodes.begin(), unused_nodes.end()); + } + GraphSafeRemoveNodes(graph, marked_fuse_op_nodes); VLOG(0) << "Finsh remove"; ++fusion_count; From 15ea2d70de355fcbfff5eeaee6a1f3d64482f929 Mon Sep 17 00:00:00 2001 From: wufeisheng Date: Wed, 2 Nov 2022 04:06:07 +0000 Subject: [PATCH 4/6] add unit test --- paddle/fluid/framework/ir/CMakeLists.txt | 4 + .../ir/fuse_multi_transformer_layer_pass.cc | 39 +--- ...use_multi_transformer_layer_pass_tester.cc | 199 ++++++++++++++++++ .../fluid/framework/ir/pass_tester_helper.h | 113 ++++++++++ 4 files changed, 324 insertions(+), 31 deletions(-) create mode 100644 paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index ed1ea67dc4105..952bd44377b4d 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -324,6 +324,10 @@ cc_test( test_fused_multi_transformer_decoder_pass SRCS fused_multi_transformer_decoder_pass_tester.cc DEPS fused_multi_transformer_decoder_pass) +cc_test( + test_fuse_multi_transformer_layer_pass + SRCS fuse_multi_transformer_layer_pass_tester.cc + DEPS fuse_multi_transformer_layer_pass) cc_test( test_conv_bn_fuse_pass_cc SRCS conv_bn_fuse_pass_tester.cc diff --git a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc index 09436cd5668f1..9def7f0a22743 100644 --- a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc +++ b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc @@ -129,7 +129,7 @@ inline void MergeInput(OpDesc* op, const std::vector& input_name_maps, const std::string& input_name) { std::vector tmp = input_name_maps[0].at(input_name); - for (int i = 1; i < input_name_maps.size(); ++i) { + for (size_t i = 1; i < input_name_maps.size(); ++i) { tmp.insert(tmp.end(), input_name_maps[i].at(input_name).begin(), input_name_maps[i].at(input_name).end()); @@ -141,7 +141,7 @@ template inline void MergeAttrs(const std::vector& ops, const std::string& attr_name) { std::vector res; - for (int i = 0; i < ops.size(); ++i) { + for (size_t i = 0; i < ops.size(); ++i) { auto scale_vec = PADDLE_GET_CONST(std::vector, ops[i]->GetAttr(attr_name)); res.insert(res.end(), scale_vec.begin(), scale_vec.end()); @@ -156,25 +156,19 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, auto* pattern = gpd.mutable_pattern(); VLOG(0) << "In build fusion"; + // TODO(wufeisheng): Get enable_int8 attr from graph after + // fused_multi_transformer pass with int8 merged bool enable_int8 = false; - if (graph->Has("enable_int8")) { - enable_int8 = graph->Get("enable_int8"); - } - if (!enable_int8) { - VLOG(4) - << "fuse_multi_transformer_layer_pass will match float transformer op " - "cause enable_int8 is not 
been set or set to false"; - } + int num_fuse_op = 0; bool is_decoder = false; - if (graph->Has("enable_int8")) { - num_fuse_op = graph->Get("num_fused_multi_transformer_op"); - } if (graph->Has(kFusedMultiTransformerEncoderFusionCount)) { + VLOG(0) << "encoder fusion count"; num_fuse_op = graph->Get(kFusedMultiTransformerEncoderFusionCount); is_decoder = false; } else if (graph->Has(kFusedMultiTransformerDecoderFusionCount)) { + VLOG(0) << "decoder fusion count"; num_fuse_op = graph->Get(kFusedMultiTransformerDecoderFusionCount); is_decoder = true; } @@ -280,13 +274,7 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, "OutLinearW", "QKVBias", "QKVW"}; - if (enable_int8) { - std::vector inputs_names_int8_supp = { - "FFN1OutScale", "FFN2OutScale", "OutLinearOutScale", "QKVOutScale"}; - inputs_names.insert(inputs_names.end(), - inputs_names_int8_supp.begin(), - inputs_names_int8_supp.end()); - } + for (const auto& input_name : inputs_names) { MergeInput(fuse_op_descs[0], fuse_op_input_var_name_maps, input_name); } @@ -308,17 +296,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, // } fuse_op_descs[0]->SetOutput("CacheKVOut", merged_cache_kv_out_names); - if (enable_int8) { - // Merge inputs scale - std::vector attr_names = {"qkv_in_scale", - "out_linear_in_scale", - "ffn1_in_scale", - "ffn2_in_scale"}; - for (const auto& name : attr_names) { - MergeAttrs(fuse_op_descs, name); - } - VLOG(0) << "Finsh Merge attrs"; - } //////////////// //// ReLink //// //////////////// diff --git a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc new file mode 100644 index 0000000000000..5ee30bee8d874 --- /dev/null +++ b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc @@ -0,0 +1,199 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include + +#include "paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.h" +#include "paddle/fluid/framework/ir/pass_tester_helper.h" +#include "paddle/fluid/framework/op_version_registry.h" + +#define DEF_INPUT_DATA \ + Layers layers; \ + int num_layers = 3; \ + auto* x = layers.data("x", {1, 128, 1024}); \ + auto* src_mask = layers.data("src_mask", {1, 16, 128, 128}); \ + auto* ln_scale = layers.data("ln_scale", {1024}, true); \ + auto* ln_bias = layers.data("ln_bias", {1024}, true); \ + auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); \ + auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true); \ + auto* qkv_w = layers.data("qkv_w", {3, 16, 64, 1024}, true); \ + auto* out_linear_w = layers.data("out_linear_w", {1024, 1024}, true); \ + auto* ffn1_w = layers.data("ffn1_w", {1024, 4096}, true); \ + auto* ffn2_w = layers.data("ffn2_w", {4096, 1024}, true); \ + auto* qkv_bias = layers.data("qkv_bias", {3072}, true); \ + auto* out_linear_bias = layers.data("out_linear_bias", {1024}, true); \ + auto* ffn1_bias = layers.data("ffn1_bias", {4096}, true); \ + auto* ffn2_bias = layers.data("ffn2_bias", {1024}, true); + +namespace paddle { +namespace framework { +namespace ir { + +void AddVarToScope(Scope* param_scope, + const std::string& name, + const DDim& dims) { + auto* tensor = param_scope->Var(name)->GetMutable(); + tensor->Resize(dims); + tensor->mutable_data(platform::CPUPlace()); +} + +Scope* CreateParamScope() { + auto param_scope = new Scope(); + AddVarToScope(param_scope, "ln_scale", {1024}); + AddVarToScope(param_scope, "ln_bias", {1024}); + AddVarToScope(param_scope, "ffn_ln_scale", {1024}); + AddVarToScope(param_scope, "ffn_ln_bias", {1024}); + + AddVarToScope(param_scope, "qkv_w", {3, 16, 64, 1024}); + AddVarToScope(param_scope, "out_linear_w", {1024, 1024}); + AddVarToScope(param_scope, "ffn1_w", {1024, 4096}); + AddVarToScope(param_scope, "ffn2_w", {4096, 1024}); + AddVarToScope(param_scope, "qkv_bias", {3072}); + AddVarToScope(param_scope, "out_linear_bias", {1024}); + AddVarToScope(param_scope, "ffn1_bias", {4096}); + AddVarToScope(param_scope, "ffn2_bias", {1024}); + + return param_scope; +} +TEST(FuseMultiTransformerLayerPass, encoder_fp) { + // Layers layers; + // int num_layers = 3; + // // Vars + // auto* x = layers.data("x", {1, 128, 1024}); + // auto* src_mask = layers.data("src_mask", {1, 16, 128, 128}); + + // auto* ln_scale = layers.data("ln_scale", {1024}, true); + // auto* ln_bias = layers.data("ln_bias", {1024}, true); + // auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); + // auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true); + // auto* qkv_w = layers.data("qkv_w", {3, 16, 64, 1024}, true); + // auto* out_linear_w = layers.data("out_linear_w", {1024, 1024}, true); + // auto* ffn1_w = layers.data("ffn1_w", {1024, 4096}, true); + // auto* ffn2_w = layers.data("ffn2_w", {4096, 1024}, true); + // auto* qkv_bias = layers.data("qkv_bias", {3072}, true); + // auto* out_linear_bias = layers.data("out_linear_bias", {1024}, true); + // auto* ffn1_bias = layers.data("ffn1_bias", {4096}, true); + // auto* ffn2_bias = layers.data("ffn2_bias", {1024}, true); + + DEF_INPUT_DATA + + // Layers + for (int i = 0; i < num_layers; ++i) { + std::cout << "begin to add fill const layer " << i << std::endl; + auto* cache_kv = layers.fill_constant_batch_size_like( + x, + static_cast(proto::VarType::FP32), + 0, + 1, + {2, -1, 16, 1024, 64}, + 0); + std::cout << "begin to add fused_multi_transformer layer " << i + << std::endl; + 
auto* out = layers.fused_multi_transformer(x, + cache_kv, + src_mask, + qkv_w, + qkv_bias, + out_linear_w, + out_linear_bias, + ffn1_w, + ffn1_bias, + ffn2_w, + ffn2_bias, + ln_scale, + ln_bias, + ffn_ln_scale, + ffn_ln_bias, + 0.1, + 1e-12); + + x = out; + } + std::unique_ptr graph(new ir::Graph(layers.main_program())); + graph->Set("__param_scope__", CreateParamScope()); + graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(num_layers)); + + auto pass = PassRegistry::Instance().Get("fuse_multi_transformer_layer_pass"); + if (pass.get() == nullptr) + LOG(INFO) << "get fuse_multi_transformer_layer_pass failed"; + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer"); + + PADDLE_ENFORCE_EQ( + num_nodes_after, + 1, + platform::errors::InvalidArgument( + "After the fuse_multi_transformer_layer_pass, " + "The node num in graph should be 1, but the result is %d", + num_nodes_after)); +} +TEST(FuseMultiTransformerLayerPass, decoder_fp) { + DEF_INPUT_DATA + + x = layers.data("x", {1, 1, 1024}); + auto* cache_kv = layers.data("cache_kv", {2, 1, 16, 1024, 64}, true); + src_mask = layers.data("src_mask", {1, 16, 1, 129}); + + // Layers + for (int i = 0; i < num_layers; ++i) { + auto* shape_out = layers.shape(src_mask); + auto* time_stamp = layers.slice(shape_out, {0}, {3}, {4}); + std::cout << "begin to add fused_multi_transformer layer " << i + << std::endl; + auto* out = layers.fused_multi_transformer(x, + cache_kv, + src_mask, + qkv_w, + qkv_bias, + out_linear_w, + out_linear_bias, + ffn1_w, + ffn1_bias, + ffn2_w, + ffn2_bias, + ln_scale, + ln_bias, + ffn_ln_scale, + ffn_ln_bias, + 0.1, + 1e-12, + time_stamp); + + x = out; + } + std::unique_ptr graph(new ir::Graph(layers.main_program())); + auto param_scope = CreateParamScope(); + AddVarToScope(param_scope, "cache_kv", {2, 1, 16, 1024, 64}); + graph->Set("__param_scope__", param_scope); + + graph->Set(kFusedMultiTransformerDecoderFusionCount, new int(num_layers)); + + auto pass = PassRegistry::Instance().Get("fuse_multi_transformer_layer_pass"); + if (pass.get() == nullptr) + LOG(INFO) << "get fuse_multi_transformer_layer_pass failed"; + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer"); + + PADDLE_ENFORCE_EQ( + num_nodes_after, + 1, + platform::errors::InvalidArgument( + "After the fuse_multi_transformer_layer_pass, " + "The node num in graph should be 1, but the result is %d", + num_nodes_after)); +} +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(fuse_multi_transformer_layer_pass); diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h index 3cce19e10c682..48f8cb37d60a9 100644 --- a/paddle/fluid/framework/ir/pass_tester_helper.h +++ b/paddle/fluid/framework/ir/pass_tester_helper.h @@ -528,6 +528,119 @@ struct Layers { return out; } + VarDesc* shape(VarDesc* input) { + VarDesc* out = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("shape"); + op->SetInput("Input", {input->Name()}); + op->SetOutput("Out", {out->Name()}); + return out; + } + + VarDesc* slice(VarDesc* input, + std::vector axes, + std::vector starts, + std::vector ends) { + VarDesc* out = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("slice"); + op->SetInput("Input", {input->Name()}); + op->SetOutput("Out", {out->Name()}); + op->SetAttr("axes", axes); + 
op->SetAttr("starts", starts); + op->SetAttr("ends", ends); + return out; + } + + VarDesc* fill_constant_batch_size_like(VarDesc* x, + int dtype, + int input_dim_idx, + int output_dim_idx, + std::vector shape, + float value) { + VarDesc* out = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("fill_constant_batch_size_like"); + op->SetInput("Input", {x->Name()}); + op->SetAttr("dtype", dtype); + op->SetAttr("input_dim_idx", input_dim_idx); + op->SetAttr("output_dim_idx", output_dim_idx); + op->SetAttr("shape", shape); + op->SetAttr("value", value); + op->SetOutput("Out", {out->Name()}); + return out; + } + + VarDesc* fused_multi_transformer(VarDesc* x, + VarDesc* cache_kv, + VarDesc* src_mask, + VarDesc* qkv_w, + VarDesc* qkv_bias, + VarDesc* out_linear_w, + VarDesc* out_linear_bias, + VarDesc* ffn1_w, + VarDesc* ffn1_bias, + VarDesc* ffn2_w, + VarDesc* ffn2_bias, + VarDesc* ln_scale, + VarDesc* ln_bias, + VarDesc* ffn_ln_scale, + VarDesc* ffn_ln_bias, + float epsilon, + float dropout_rate, + VarDesc* time_stamp = nullptr, + VarDesc* qkv_out_scale = nullptr, + VarDesc* out_linear_out_scale = nullptr, + VarDesc* ffn1_out_scale = nullptr, + VarDesc* ffn2_out_scale = nullptr, + std::vector qkv_in_scale = {}, + std::vector out_linear_in_scale = {}, + std::vector ffn1_in_scale = {}, + std::vector ffn2_in_scale = {}) { + VarDesc* out = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + std::string op_type = qkv_out_scale ? "fused_multi_transformer_int8" + : "fused_multi_transformer"; + op->SetType(op_type); + op->SetInput("X", {x->Name()}); + op->SetInput("CacheKV", {cache_kv->Name()}); + op->SetInput("SrcMask", {src_mask->Name()}); + op->SetInput("QKVW", {qkv_w->Name()}); + op->SetInput("QKVBias", {qkv_bias->Name()}); + op->SetInput("OutLinearW", {out_linear_w->Name()}); + op->SetInput("OutLinearBias", {out_linear_bias->Name()}); + op->SetInput("FFN1Weight", {ffn1_w->Name()}); + op->SetInput("FFN1Bias", {ffn1_bias->Name()}); + op->SetInput("FFN2Weight", {ffn2_w->Name()}); + op->SetInput("FFN2Bias", {ffn2_bias->Name()}); + op->SetInput("LnScale", {ln_scale->Name()}); + op->SetInput("LnBias", {ln_bias->Name()}); + op->SetInput("FFNLnScale", {ffn_ln_scale->Name()}); + op->SetInput("FFNLnBias", {ffn_ln_bias->Name()}); + op->SetAttr("pre_layer_norm", true); + op->SetAttr("is_test", true); + op->SetAttr("dropout_implementation", "upscale_in_train"); + op->SetAttr("dropout_rate", dropout_rate); + op->SetAttr("epsilon", epsilon); + op->SetOutput("Out", {out->Name()}); + + if (time_stamp) { + op->SetInput("TimeStep", {time_stamp->Name()}); + } + + if (qkv_out_scale) { + op->SetInput("QKVOutScale", {qkv_out_scale->Name()}); + op->SetInput("OutLinearOutScale", {out_linear_out_scale->Name()}); + op->SetInput("FFN1OutScale", {ffn1_out_scale->Name()}); + op->SetInput("FFN2OutScale", {ffn2_out_scale->Name()}); + op->SetAttr("qkv_in_scale", qkv_in_scale); + op->SetAttr("out_linear_in_scale", out_linear_in_scale); + op->SetAttr("ffn1_in_scale", ffn1_in_scale); + op->SetAttr("ffn2_in_scale", ffn2_in_scale); + } + return out; + } + void backward(std::vector targets) { // This function is designed to simulate the structure of training program, // but is constructed differently as the actual program. 
From 084c6308e725ae15e1ef4ca5cc2e2a762cbdeb03 Mon Sep 17 00:00:00 2001 From: wufeisheng Date: Thu, 3 Nov 2022 02:38:41 +0000 Subject: [PATCH 5/6] delete debug codes --- .../ir/fuse_multi_transformer_layer_pass.cc | 27 +------------------ ...use_multi_transformer_layer_pass_tester.cc | 24 ----------------- 2 files changed, 1 insertion(+), 50 deletions(-) diff --git a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc index 9def7f0a22743..90f2b0fce1fea 100644 --- a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc +++ b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc @@ -37,15 +37,12 @@ MultiTransformerLayerPattern::operator()(bool enable_int8, bool is_decoder) { std::string fused_multi_transformer_name = enable_int8 ? "fused_multi_transformer_int8" : "fused_multi_transformer"; - // This map is used to store node_reprs, 3 * i names will be inserted - // cache_kv0_{i}, cache_kv1_{i}, fill_constant_batch_size_like_{i} + // This map is used to store node_reprs std::unordered_map node_reprs; - VLOG(0) << "num in pattern = " << num_fused_op; // x0 and src_mask is unqiue input of subgraph auto* x0 = pattern->NewNode(x0_repr()); x0->assert_is_op_input(fused_multi_transformer_name, "X")->AsInput(); - auto* src_mask = pattern->NewNode(src_mask_repr()); src_mask->assert_is_op_input(fused_multi_transformer_name, "SrcMask") ->AsInput(); @@ -154,7 +151,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, Scope* scope) const { GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); - VLOG(0) << "In build fusion"; // TODO(wufeisheng): Get enable_int8 attr from graph after // fused_multi_transformer pass with int8 merged @@ -164,11 +160,9 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, bool is_decoder = false; if (graph->Has(kFusedMultiTransformerEncoderFusionCount)) { - VLOG(0) << "encoder fusion count"; num_fuse_op = graph->Get(kFusedMultiTransformerEncoderFusionCount); is_decoder = false; } else if (graph->Has(kFusedMultiTransformerDecoderFusionCount)) { - VLOG(0) << "decoder fusion count"; num_fuse_op = graph->Get(kFusedMultiTransformerDecoderFusionCount); is_decoder = true; } @@ -186,27 +180,17 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, patterns::MultiTransformerLayerPattern multi_layer_pattern(pattern, name_scope); auto node_reprs = multi_layer_pattern(enable_int8, num_fuse_op, is_decoder); - for (auto p : node_reprs) { - VLOG(0) << "key: " << p.first << " value: " << p.second; - } - VLOG(0) << "Finish build pattern"; int fusion_count{0}; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { - VLOG(0) << "handle FuseMultiTransformerLayerPass"; - VLOG(0) << "subgraph.size()" << subgraph.size(); - /////////////////// //// Get nodes //// /////////////////// GET_IR_NODE_FROM_SUBGRAPH(src_mask, src_mask, multi_layer_pattern); - GET_IR_NODE_FROM_SUBGRAPH(x0, x0, multi_layer_pattern); - VLOG(0) << "Get input node"; - std::vector fuse_op_nodes; std::vector out_nodes; @@ -278,7 +262,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, for (const auto& input_name : inputs_names) { MergeInput(fuse_op_descs[0], fuse_op_input_var_name_maps, input_name); } - VLOG(0) << "Finsh Merge input"; // Merge outputs fuse_op_descs[0]->SetOutput( @@ -291,9 +274,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, out_var_names.begin(), out_var_names.end()); } - // for (auto out_name : 
output_names0["CacheKVOut"]) { - // VLOG(0) << "out_name " << out_name; - // } fuse_op_descs[0]->SetOutput("CacheKVOut", merged_cache_kv_out_names); //////////////// @@ -315,7 +295,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, // Relink fuse op -> out IR_NODE_UNLINK(fuse_op_nodes[num_fuse_op - 1], out_nodes[num_fuse_op - 1]); IR_NODE_LINK_TO(fuse_op_nodes[0], out_nodes[num_fuse_op - 1]); - VLOG(0) << "Finsh relinks"; ///////////////////////////// //// Delete unused nodes //// @@ -330,7 +309,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, } GraphSafeRemoveNodes(graph, marked_fuse_op_nodes); - VLOG(0) << "Finsh remove"; ++fusion_count; }; @@ -346,9 +324,6 @@ void FuseMultiTransformerLayerPass::ApplyImpl(Graph* graph) const { platform::errors::Fatal("During the fuse_multi_transformer_layer pass, " "The scope should not be null.")); int fusion_count = BuildFusion(graph, name_scope_, scope); - VLOG(0) << "fusion_count is " << fusion_count; - - // PD_THROW("IMULTILAYER"); AddStatis(fusion_count); } diff --git a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc index 5ee30bee8d874..72635d1c95855 100644 --- a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc +++ b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc @@ -64,30 +64,10 @@ Scope* CreateParamScope() { return param_scope; } TEST(FuseMultiTransformerLayerPass, encoder_fp) { - // Layers layers; - // int num_layers = 3; - // // Vars - // auto* x = layers.data("x", {1, 128, 1024}); - // auto* src_mask = layers.data("src_mask", {1, 16, 128, 128}); - - // auto* ln_scale = layers.data("ln_scale", {1024}, true); - // auto* ln_bias = layers.data("ln_bias", {1024}, true); - // auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); - // auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true); - // auto* qkv_w = layers.data("qkv_w", {3, 16, 64, 1024}, true); - // auto* out_linear_w = layers.data("out_linear_w", {1024, 1024}, true); - // auto* ffn1_w = layers.data("ffn1_w", {1024, 4096}, true); - // auto* ffn2_w = layers.data("ffn2_w", {4096, 1024}, true); - // auto* qkv_bias = layers.data("qkv_bias", {3072}, true); - // auto* out_linear_bias = layers.data("out_linear_bias", {1024}, true); - // auto* ffn1_bias = layers.data("ffn1_bias", {4096}, true); - // auto* ffn2_bias = layers.data("ffn2_bias", {1024}, true); - DEF_INPUT_DATA // Layers for (int i = 0; i < num_layers; ++i) { - std::cout << "begin to add fill const layer " << i << std::endl; auto* cache_kv = layers.fill_constant_batch_size_like( x, static_cast(proto::VarType::FP32), @@ -95,8 +75,6 @@ TEST(FuseMultiTransformerLayerPass, encoder_fp) { 1, {2, -1, 16, 1024, 64}, 0); - std::cout << "begin to add fused_multi_transformer layer " << i - << std::endl; auto* out = layers.fused_multi_transformer(x, cache_kv, src_mask, @@ -147,8 +125,6 @@ TEST(FuseMultiTransformerLayerPass, decoder_fp) { for (int i = 0; i < num_layers; ++i) { auto* shape_out = layers.shape(src_mask); auto* time_stamp = layers.slice(shape_out, {0}, {3}, {4}); - std::cout << "begin to add fused_multi_transformer layer " << i - << std::endl; auto* out = layers.fused_multi_transformer(x, cache_kv, src_mask, From f0fdfeae092aaface6f8ce559e28cac9fe568f26 Mon Sep 17 00:00:00 2001 From: wufeisheng Date: Tue, 8 Nov 2022 10:35:20 +0800 Subject: [PATCH 6/6] delete unnecessary notes --- .../ir/fuse_multi_transformer_layer_pass.cc | 13 +------------ 1 file 
changed, 1 insertion(+), 12 deletions(-) diff --git a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc index 90f2b0fce1fea..4e2bca2ae2a97 100644 --- a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc +++ b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc @@ -37,7 +37,7 @@ MultiTransformerLayerPattern::operator()(bool enable_int8, bool is_decoder) { std::string fused_multi_transformer_name = enable_int8 ? "fused_multi_transformer_int8" : "fused_multi_transformer"; - // This map is used to store node_reprs + std::unordered_map node_reprs; // x0 and src_mask is unqiue input of subgraph @@ -48,7 +48,6 @@ MultiTransformerLayerPattern::operator()(bool enable_int8, ->AsInput(); for (int i = 0; i < num_fused_op; ++i) { - // fused_multi_transformer op auto fuse_op_repr = PDNodeName(name_scope_, repr_, id_, "fuse_op_" + std::to_string(i)); node_reprs["fuse_op_" + std::to_string(i)] = fuse_op_repr; @@ -56,16 +55,13 @@ MultiTransformerLayerPattern::operator()(bool enable_int8, pattern->NewNode(fuse_op_repr) ->assert_is_op(fused_multi_transformer_name); - // fused_multi_transformer output auto out_repr = PDNodeName(name_scope_, repr_, id_, "out_" + std::to_string(i)); node_reprs["out_" + std::to_string(i)] = out_repr; auto* out = pattern->NewNode(out_repr)->assert_is_op_output( fused_multi_transformer_name, "Out"); - // Links if (is_decoder) { - // shape and shape out auto shape_repr = PDNodeName(name_scope_, repr_, id_, "shape_" + std::to_string(i)); node_reprs["shape_" + std::to_string(i)] = shape_repr; @@ -79,7 +75,6 @@ MultiTransformerLayerPattern::operator()(bool enable_int8, shape->LinksFrom({src_mask}).LinksTo({shape_out}); - // slice and slice out auto slice_repr = PDNodeName(name_scope_, repr_, id_, "slice_" + std::to_string(i)); node_reprs["slice_" + std::to_string(i)] = slice_repr; @@ -96,7 +91,6 @@ MultiTransformerLayerPattern::operator()(bool enable_int8, fused_multi_transformer->LinksFrom({x0, src_mask, slice_out}) .LinksTo({out}); } else { - // catch_kv auto cache_kv_repr = PDNodeName(name_scope_, repr_, id_, "cache_kv_" + std::to_string(i)); node_reprs["cache_kv_" + std::to_string(i)] = cache_kv_repr; @@ -104,7 +98,6 @@ MultiTransformerLayerPattern::operator()(bool enable_int8, cache_kv->assert_is_op_input(fused_multi_transformer_name, "CacheKV"); cache_kv->AsInput(); - // fill constant op is only valid in encoder auto fill_const_op_repr = PDNodeName(name_scope_, repr_, id_, "fill_op_" + std::to_string(i)); node_reprs["fill_op_" + std::to_string(i)] = fill_const_op_repr; @@ -203,7 +196,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, std::vector fuse_op_output_var_name_maps; for (int i = 0; i < num_fuse_op; ++i) { - // fused_multi_transformer op PDNode* fuse_op_pdnode = multi_layer_pattern.PatternBase::pattern->RetrieveNode( node_reprs["fuse_op_" + std::to_string(i)]); @@ -213,7 +205,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, fuse_op_input_var_name_maps.emplace_back(fuse_op_node->Op()->Inputs()); fuse_op_output_var_name_maps.emplace_back(fuse_op_node->Op()->Outputs()); - // fused_multi_transformer output PDNode* out_pdnode = multi_layer_pattern.PatternBase::pattern->RetrieveNode( node_reprs["out_" + std::to_string(i)]); @@ -221,7 +212,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, // fill_const op use x0 as input if (!is_decoder && i != 0) { - // fill constant op PDNode* fill_op_pdnode = 
multi_layer_pattern.PatternBase::pattern->RetrieveNode( node_reprs["fill_op_" + std::to_string(i)]); @@ -303,7 +293,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, std::unordered_set marked_fuse_op_nodes( fuse_op_nodes.begin() + 1, fuse_op_nodes.end()); - // Delete shape/slice op in decoder subgraph if (is_decoder) { marked_fuse_op_nodes.insert(unused_nodes.begin(), unused_nodes.end()); }
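
Taken together, the six patches make fuse_multi_transformer_layer_pass collapse a stack of per-layer fused_multi_transformer (or fused_multi_transformer_int8) ops into a single op: the per-layer weight, bias and CacheKV inputs and the CacheKVOut outputs are concatenated onto the first op, its Out is rewired to the last layer's output, and the remaining ops (together with the redundant fill_constant_batch_size_like or shape/slice helpers) are removed. The standalone C++ sketch below models only that merge step with plain standard-library containers; it is illustrative, does not use Paddle's OpDesc/Graph APIs, and all variable names are made up.

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Toy stand-in for an op description: slot name -> list of variable names.
using SlotMap = std::map<std::string, std::vector<std::string>>;

// Mirrors the pass's MergeInput step: concatenate each layer's variable list
// for one slot, preserving layer order.
std::vector<std::string> MergeSlot(const std::vector<SlotMap>& layers,
                                   const std::string& slot) {
  std::vector<std::string> merged;
  for (const auto& layer : layers) {
    const auto& names = layer.at(slot);
    merged.insert(merged.end(), names.begin(), names.end());
  }
  return merged;
}

int main() {
  const int num_layers = 3;
  std::vector<SlotMap> inputs(num_layers), outputs(num_layers);
  for (int i = 0; i < num_layers; ++i) {
    const std::string s = std::to_string(i);
    inputs[i]["QKVW"] = {"qkv_w_" + s};        // per-layer weights
    inputs[i]["CacheKV"] = {"cache_kv_" + s};  // per-layer kv cache
    outputs[i]["CacheKVOut"] = {"cache_kv_out_" + s};
    outputs[i]["Out"] = {"hidden_" + s};
  }

  // The surviving (first) op ends up with every layer's inputs and cache
  // outputs, while "Out" is rewired to the last layer's output.
  SlotMap merged_in, merged_out;
  for (const char* slot : {"QKVW", "CacheKV"}) {
    merged_in[slot] = MergeSlot(inputs, slot);
  }
  merged_out["CacheKVOut"] = MergeSlot(outputs, "CacheKVOut");
  merged_out["Out"] = outputs.back().at("Out");

  for (const auto& kv : merged_in) {
    std::cout << kv.first << ":";
    for (const auto& name : kv.second) std::cout << " " << name;
    std::cout << "\n";
  }
  std::cout << "Out: " << merged_out["Out"][0] << "\n";
  return 0;
}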