
Commit f0fdfea

delete unnecessary notes

RichardWooSJTU committed Nov 8, 2022
1 parent 084c630 commit f0fdfea
Showing 1 changed file with 1 addition and 12 deletions.
13 changes: 1 addition & 12 deletions paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.cc
@@ -37,7 +37,7 @@ MultiTransformerLayerPattern::operator()(bool enable_int8,
                                          bool is_decoder) {
   std::string fused_multi_transformer_name =
       enable_int8 ? "fused_multi_transformer_int8" : "fused_multi_transformer";
-  // This map is used to store node_reprs
+
   std::unordered_map<std::string, std::string> node_reprs;
 
   // x0 and src_mask are unique inputs of the subgraph
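For context: node_reprs is the pass's per-layer bookkeeping table. Each key such as "fuse_op_0" maps to a globally unique PDNode name, which is what lets BuildFusion retrieve the matched node for any layer later by key. A minimal self-contained sketch of that idea, using plain STL in place of Paddle's PDNodeName helper (the scope name and layer count below are illustrative, not taken from the pass):

#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, std::string> node_reprs;
  const std::string name_scope = "fuse_multi_transformer_layer";  // assumed scope name
  const int num_fused_op = 2;  // hypothetical number of layers to match

  for (int i = 0; i < num_fused_op; ++i) {
    const std::string key = "fuse_op_" + std::to_string(i);
    // Stands in for PDNodeName(name_scope_, repr_, id_, key): any scheme
    // producing a globally unique pattern-node name works here.
    node_reprs[key] = name_scope + "/" + key;
  }

  // A later stage (BuildFusion in the real pass) looks nodes up by key.
  std::cout << node_reprs.at("fuse_op_1") << std::endl;
  return 0;
}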
@@ -48,24 +48,20 @@ MultiTransformerLayerPattern::operator()(bool enable_int8,
                       ->AsInput();
 
   for (int i = 0; i < num_fused_op; ++i) {
-    // fused_multi_transformer op
     auto fuse_op_repr =
         PDNodeName(name_scope_, repr_, id_, "fuse_op_" + std::to_string(i));
     node_reprs["fuse_op_" + std::to_string(i)] = fuse_op_repr;
     auto* fused_multi_transformer =
         pattern->NewNode(fuse_op_repr)
             ->assert_is_op(fused_multi_transformer_name);
 
-    // fused_multi_transformer output
     auto out_repr =
         PDNodeName(name_scope_, repr_, id_, "out_" + std::to_string(i));
     node_reprs["out_" + std::to_string(i)] = out_repr;
     auto* out = pattern->NewNode(out_repr)->assert_is_op_output(
         fused_multi_transformer_name, "Out");
 
-    // Links
     if (is_decoder) {
-      // shape and shape out
       auto shape_repr =
           PDNodeName(name_scope_, repr_, id_, "shape_" + std::to_string(i));
       node_reprs["shape_" + std::to_string(i)] = shape_repr;
@@ -79,7 +75,6 @@ MultiTransformerLayerPattern::operator()(bool enable_int8,
 
       shape->LinksFrom({src_mask}).LinksTo({shape_out});
 
-      // slice and slice out
       auto slice_repr =
           PDNodeName(name_scope_, repr_, id_, "slice_" + std::to_string(i));
       node_reprs["slice_" + std::to_string(i)] = slice_repr;
@@ -96,15 +91,13 @@ MultiTransformerLayerPattern::operator()(bool enable_int8,
       fused_multi_transformer->LinksFrom({x0, src_mask, slice_out})
           .LinksTo({out});
     } else {
-      // catch_kv
       auto cache_kv_repr =
           PDNodeName(name_scope_, repr_, id_, "cache_kv_" + std::to_string(i));
       node_reprs["cache_kv_" + std::to_string(i)] = cache_kv_repr;
       auto* cache_kv = pattern->NewNode(cache_kv_repr);
       cache_kv->assert_is_op_input(fused_multi_transformer_name, "CacheKV");
       cache_kv->AsInput();
 
-      // fill constant op is only valid in encoder
       auto fill_const_op_repr =
           PDNodeName(name_scope_, repr_, id_, "fill_op_" + std::to_string(i));
       node_reprs["fill_op_" + std::to_string(i)] = fill_const_op_repr;
@@ -203,7 +196,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
   std::vector<VariableNameMap> fuse_op_output_var_name_maps;
 
   for (int i = 0; i < num_fuse_op; ++i) {
-    // fused_multi_transformer op
     PDNode* fuse_op_pdnode =
         multi_layer_pattern.PatternBase::pattern->RetrieveNode(
             node_reprs["fuse_op_" + std::to_string(i)]);
@@ -213,15 +205,13 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
     fuse_op_input_var_name_maps.emplace_back(fuse_op_node->Op()->Inputs());
     fuse_op_output_var_name_maps.emplace_back(fuse_op_node->Op()->Outputs());
 
-    // fused_multi_transformer output
     PDNode* out_pdnode =
         multi_layer_pattern.PatternBase::pattern->RetrieveNode(
             node_reprs["out_" + std::to_string(i)]);
     out_nodes.push_back(subgraph.at(out_pdnode));
 
     // fill_const op uses x0 as input
     if (!is_decoder && i != 0) {
-      // fill constant op
       PDNode* fill_op_pdnode =
           multi_layer_pattern.PatternBase::pattern->RetrieveNode(
               node_reprs["fill_op_" + std::to_string(i)]);
@@ -303,7 +293,6 @@ int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
   std::unordered_set<const Node*> marked_fuse_op_nodes(
       fuse_op_nodes.begin() + 1, fuse_op_nodes.end());
 
-  // Delete shape/slice op in decoder subgraph
   if (is_decoder) {
     marked_fuse_op_nodes.insert(unused_nodes.begin(), unused_nodes.end());
   }
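The final cleanup keeps the first fused op as the merged survivor and marks every later one for deletion; in the decoder case the now-redundant shape/slice helper nodes join the marked set. A compact sketch of that marking rule (the pass presumably hands the resulting set to Paddle's GraphSafeRemoveNodes helper, which the visible hunks do not show):

#include <unordered_set>
#include <vector>

struct Node { int id = 0; };

// Keep fuse_op_nodes[0]; mark the remaining fused ops, plus any
// decoder-only shape/slice nodes, for removal from the graph.
std::unordered_set<const Node*> MarkForRemoval(
    const std::vector<Node*>& fuse_op_nodes,
    const std::vector<Node*>& unused_nodes,
    bool is_decoder) {
  std::unordered_set<const Node*> marked(fuse_op_nodes.begin() + 1,
                                         fuse_op_nodes.end());
  if (is_decoder) {
    marked.insert(unused_nodes.begin(), unused_nodes.end());
  }
  return marked;
}

int main() {
  std::vector<Node> storage(3);
  std::vector<Node*> fuse_ops = {&storage[0], &storage[1], &storage[2]};
  std::vector<Node*> unused;
  auto marked = MarkForRemoval(fuse_ops, unused, /*is_decoder=*/false);
  return static_cast<int>(marked.size());  // 2: ops 1 and 2 are marked
}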
