diff --git a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc index 90f2b0fce1fea..4e2bca2ae2a97 100644 --- a/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc +++ b/paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc @@ -37,7 +37,7 @@ MultiTransformerLayerPattern::operator()(bool enable_int8, bool is_decoder) { std::string fused_multi_transformer_name = enable_int8 ? "fused_multi_transformer_int8" : "fused_multi_transformer"; - // This map is used to store node_reprs + std::unordered_map node_reprs; // x0 and src_mask is unqiue input of subgraph @@ -48,7 +48,6 @@ MultiTransformerLayerPattern::operator()(bool enable_int8, ->AsInput(); for (int i = 0; i < num_fused_op; ++i) { - // fused_multi_transformer op auto fuse_op_repr = PDNodeName(name_scope_, repr_, id_, "fuse_op_" + std::to_string(i)); node_reprs["fuse_op_" + std::to_string(i)] = fuse_op_repr; @@ -56,16 +55,13 @@ MultiTransformerLayerPattern::operator()(bool enable_int8, pattern->NewNode(fuse_op_repr) ->assert_is_op(fused_multi_transformer_name); - // fused_multi_transformer output auto out_repr = PDNodeName(name_scope_, repr_, id_, "out_" + std::to_string(i)); node_reprs["out_" + std::to_string(i)] = out_repr; auto* out = pattern->NewNode(out_repr)->assert_is_op_output( fused_multi_transformer_name, "Out"); - // Links if (is_decoder) { - // shape and shape out auto shape_repr = PDNodeName(name_scope_, repr_, id_, "shape_" + std::to_string(i)); node_reprs["shape_" + std::to_string(i)] = shape_repr; @@ -79,7 +75,6 @@ MultiTransformerLayerPattern::operator()(bool enable_int8, shape->LinksFrom({src_mask}).LinksTo({shape_out}); - // slice and slice out auto slice_repr = PDNodeName(name_scope_, repr_, id_, "slice_" + std::to_string(i)); node_reprs["slice_" + std::to_string(i)] = slice_repr; @@ -96,7 +91,6 @@ MultiTransformerLayerPattern::operator()(bool enable_int8, fused_multi_transformer->LinksFrom({x0, src_mask, slice_out}) .LinksTo({out}); } else { - // catch_kv auto cache_kv_repr = PDNodeName(name_scope_, repr_, id_, "cache_kv_" + std::to_string(i)); node_reprs["cache_kv_" + std::to_string(i)] = cache_kv_repr; @@ -104,7 +98,6 @@ MultiTransformerLayerPattern::operator()(bool enable_int8, cache_kv->assert_is_op_input(fused_multi_transformer_name, "CacheKV"); cache_kv->AsInput(); - // fill constant op is only valid in encoder auto fill_const_op_repr = PDNodeName(name_scope_, repr_, id_, "fill_op_" + std::to_string(i)); node_reprs["fill_op_" + std::to_string(i)] = fill_const_op_repr; @@ -203,7 +196,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, std::vector fuse_op_output_var_name_maps; for (int i = 0; i < num_fuse_op; ++i) { - // fused_multi_transformer op PDNode* fuse_op_pdnode = multi_layer_pattern.PatternBase::pattern->RetrieveNode( node_reprs["fuse_op_" + std::to_string(i)]); @@ -213,7 +205,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, fuse_op_input_var_name_maps.emplace_back(fuse_op_node->Op()->Inputs()); fuse_op_output_var_name_maps.emplace_back(fuse_op_node->Op()->Outputs()); - // fused_multi_transformer output PDNode* out_pdnode = multi_layer_pattern.PatternBase::pattern->RetrieveNode( node_reprs["out_" + std::to_string(i)]); @@ -221,7 +212,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, // fill_const op use x0 as input if (!is_decoder && i != 0) { - // fill constant op PDNode* fill_op_pdnode = multi_layer_pattern.PatternBase::pattern->RetrieveNode( node_reprs["fill_op_" + std::to_string(i)]); @@ -303,7 +293,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph, std::unordered_set marked_fuse_op_nodes( fuse_op_nodes.begin() + 1, fuse_op_nodes.end()); - // Delete shape/slice op in decoder subgraph if (is_decoder) { marked_fuse_op_nodes.insert(unused_nodes.begin(), unused_nodes.end()); }