diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 647d4b38c3ad2..99e136e8b6494 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -148,6 +148,7 @@ pass_library(delete_c_identity_op_pass inference)
 pass_library(preln_residual_bias_fuse_pass inference)
 pass_library(delete_fill_constant_op_pass inference)
 pass_library(constant_folding_pass inference)
+pass_library(auto_mixed_precision_pass inference)
 pass_library(simplify_with_basic_ops_pass base)
 pass_library(fc_elementwise_layernorm_fuse_pass base)
 pass_library(skip_layernorm_fuse_pass base)
diff --git a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
new file mode 100644
index 0000000000000..bc034301989b0
--- /dev/null
+++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
@@ -0,0 +1,745 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"
+
+#include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/phi/common/bfloat16.h"
+#include "paddle/phi/common/float16.h"
+#include "paddle/phi/common/place.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/errors.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+namespace {
+
+using VarType = AutoMixedPrecisionPass::VarType;
+
+bool PhiKernelSupportPrecision(
+    const std::string& op_type,
+    phi::Backend backend,
+    phi::DataType data_type,
+    phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
+  const auto& kernels = phi::KernelFactory::Instance().kernels();
+  if (kernels.count(op_type) == 0) {
+    return false;
+  }
+  phi::KernelKey kernel_key(backend, layout, data_type);
+  return phi::KernelFactory::Instance().HasKernel(op_type, kernel_key);
+}
+
+bool GpuKernelSupportPrecision(
+    const std::string& op_type,
+    phi::DataType precision,
+    phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
+  auto phi_op_type = phi::TransToPhiKernelName(op_type);
+  bool support = PhiKernelSupportPrecision(
+      phi_op_type, phi::Backend::GPU, precision, layout);
+  support |= PhiKernelSupportPrecision(
+      phi_op_type, phi::Backend::GPUDNN, precision, layout);
+
+  if (!support) {
+    const auto& all_kernels = framework::OperatorWithKernel::AllOpKernels();
+    auto it = all_kernels.find(op_type);
+    if (it != all_kernels.end()) {
+      for (const auto& kern_pair : it->second) {
+        if (platform::is_gpu_place(kern_pair.first.place_) &&
+            kern_pair.first.data_type_ ==
+                framework::TransToProtoVarType(precision)) {
+          support = true;
+          break;
+        }
+      }
+    }
+  }
+  return support;
+}
+
+inline bool VarNodeHasDtype(Node* var_node) {
+  auto type = var_node->Var()->GetType();
+  return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) ||
+         (type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) ||
+         (type == VarType::VOCAB);
+}
+
+inline bool IsFloatType(VarType::Type type) {
+  return (type == VarType::FP64) || (type == VarType::FP32);
+}
+
+inline bool IsHalfType(VarType::Type type) {
+  return (type == VarType::FP16) || (type == VarType::BF16);
+}
+
+}  // namespace
+
+void DoInsertCastOp(Graph* graph,
+                    Node* var_node,
+                    Node* op_node,
+                    VarType::Type from_type,
+                    VarType::Type to_type,
+                    framework::BlockDesc* block_desc,
+                    int* suffix,
+                    std::unordered_map<Node*, Node*>* cache) {
+  if (from_type == to_type) return;
+
+  auto update_cast_desc = [&](framework::OpDesc& desc,
+                              const std::string& x_name,
+                              const std::string& out_name,
+                              const int in_dtype,
+                              const int out_dtype) {
+    desc.SetType("cast");
+    desc.SetInput("X", {x_name});
+    desc.SetOutput("Out", {out_name});
+    desc.SetAttr("in_dtype", in_dtype);
+    desc.SetAttr("out_dtype", out_dtype);
+    desc.SetAttr("use_mkldnn", false);
+    desc.SetAttr("with_quant_attr", false);
+    desc.Flush();
+  };
+
+  if (cache->count(var_node) == 0) {
+    // insert cast op between var_node and op_node
+    std::string cast_input_name = var_node->Var()->Name();
+    std::string cast_output_name =
+        var_node->Var()->Name() + "_cast.tmp_" + std::to_string((*suffix)++);
+    framework::OpDesc cast_op_desc(block_desc);
+    update_cast_desc(cast_op_desc,
+                     cast_input_name,
+                     cast_output_name,
+                     static_cast<int>(from_type),
+                     static_cast<int>(to_type));
+    auto* cast_op_node = graph->CreateOpNode(&cast_op_desc);
+    auto* cast_output_vardesc = block_desc->Var(cast_output_name);
+    cast_output_vardesc->SetPersistable(false);
+    cast_output_vardesc->SetDataType(to_type);
+    cast_output_vardesc->SetShape(var_node->Var()->GetShape());
+    auto* cast_output_node = graph->CreateVarNode(cast_output_vardesc);
+    IR_NODE_LINK_TO(cast_op_node, cast_output_node);
+    (*cache)[var_node] = cast_output_node;
+  }
+  op_node->Op()->Rename(var_node->Name(), cache->at(var_node)->Name());
+  IR_NODE_LINK_TO(var_node, cache->at(var_node)->inputs[0]);
+  IR_NODE_LINK_TO(cache->at(var_node), op_node);
+
+  IR_NODE_UNLINK(var_node, op_node);
+}
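+
+// Note on DoInsertCastOp: for a hypothetical var "x" (fp32) feeding an op
+// that will run at fp16, the rewrite is:
+//
+//   before: x(fp32) -> op
+//   after:  x(fp32) -> cast -> x_cast.tmp_0(fp16) -> op
+//
+// The cache maps a source var node to the output var of its cast op, so a
+// var consumed by several ops needing the same conversion is cast only once
+// and the cast output is reused; only the graph links are updated per
+// consumer.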
+
+bool OpSupportPrecision(const std::string& op_type,
+                        phi::Backend backend,
+                        phi::DataType precision,
+                        const std::unordered_set<std::string>& black_list) {
+  bool support = false;
+  if (black_list.count(op_type) == 0) {
+    if (backend == phi::Backend::GPU) {
+      support = GpuKernelSupportPrecision(op_type, precision);
+    } else {
+      PADDLE_THROW(paddle::platform::errors::InvalidArgument(
+          "Currently, only the GPU backend is supported."));
+    }
+  }
+  return support;
+}
+
+// The set of ops that support fp16 calculation but are kept at fp32 by
+// default: they are numerically dangerous in fp16, slower than their fp32
+// kernels, or their error may also be observed in downstream ops.
+void AutoMixedPrecisionPass::SetDefaultBlacklist() const {
+  black_list_.insert({
+      // numerically-dangerous
+      "acos",
+      "asin",
+      "cosh",
+      "tan",
+      "exp",
+      "expm1",
+      "square",
+      "log",
+      "log2",
+      "log10",
+      "log1p",
+      "logsumexp",
+      "mean",
+      "rsqrt",
+      "sum",
+      "cos_sim",
+      "softmax",
+      "softmax_with_cross_entropy",
+      "sigmoid_cross_entropy_with_logits",
+      "c_softmax_with_cross_entropy",
+      "cross_entropy",
+      "cross_entropy2",
+      // slower than fp32
+      "conv2d_transpose",
+      // keeping fp32 avoids returning inf when the sum exceeds 65504 (the
+      // largest finite fp16 value)
+      "reduce_sum",
+  });
+}
+
+void AutoMixedPrecisionPass::Init(Graph* graph) const {
+  bool enable_gpu_mixed = Get<bool>("enable_gpu_mixed");
+  if (enable_gpu_mixed) {
+    backend_ = phi::Backend::GPU;
+  }
+
+  skip_pass_ = !enable_gpu_mixed;
+
+  low_precision_ =
+      static_cast<phi::DataType>(Get<int>("mixed_precision_mode"));
+
+  black_list_ = Get<std::unordered_set<std::string>>("mixed_black_list");
+  SetDefaultBlacklist();
+  VLOG(4) << "black_list has:";
+  for (const auto& name : black_list_) {
+    VLOG(4) << " - " << name;
+  }
+
+  keep_io_types_ = true;
+  if (Has("keep_io_types")) {
+    keep_io_types_ = Get<bool>("keep_io_types");
+  }
+
+  auto graph_size = graph->SubGraphsSize();
+  VLOG(4) << "graph size: " << graph_size;
+  subgraphes_.resize(graph_size);
+  all_op_nodes_.resize(graph_size);
+
+  for (size_t i = 0; i < graph_size; i++) {
+    subgraphes_[i] = graph->GetSubGraph(i);
+    all_op_nodes_[i] = TopologySortOperations(*subgraphes_[i]);
+    VLOG(4) << "subgraph " << i << " has " << all_op_nodes_[i].size()
+            << " op nodes";
+    for (auto* var_node : subgraphes_[i]->Nodes()) {
+      if (!var_node->IsVar()) continue;
+
+      auto var_name = var_node->Var()->Name();
+      if (real_vars_.count(var_name) == 0) {
+        real_vars_[var_name] = var_node;
+        VLOG(4) << var_name << " is in graph " << i;
+      }
+    }
+  }
+}
+
+void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(graph,
+                          platform::errors::PreconditionNotMet(
+                              "During the auto_mixed_precision_pass, the graph "
+                              "should not be nullptr."));
+  PADDLE_ENFORCE_EQ(graph->IsMainGraph(),
+                    true,
+                    platform::errors::PreconditionNotMet(
+                        "During the auto_mixed_precision_pass, the graph "
+                        "should be the main graph."));
+
+  FusePassBase::Init("auto_mixed_precision", graph);
+
+  Init(graph);
+  VLOG(4) << "Init done";
+
+  if (skip_pass_) {
+    VLOG(3) << "Skip auto_mixed_precision_pass.";
+    return;
+  }
+
+  SetOpUniqueType();
+  VLOG(4) << "SetOpUniqueType done";
+  GetOpPrecision();
+  VLOG(4) << "GetOpPrecision done";
+  UpdateOpPrecision();
+  VLOG(4) << "UpdateOpPrecision done";
+  SetVarPrecision();
+  VLOG(4) << "SetVarPrecision done";
+  ConvertWeightsData();
+  VLOG(4) << "ConvertWeightsData done";
+  ProcessOpWithDtypeAttr();
+  VLOG(4) << "ProcessOpWithDtypeAttr done";
+  InsertCastOp();
+  VLOG(4) << "InsertCastOp done";
+  RestoreOpOriginType();
+  VLOG(4) << "RestoreOpOriginType done";
+}
+
+void AutoMixedPrecisionPass::SetOpUniqueType() const {
+  int suffix = 0;
+  for (const auto& nodes : all_op_nodes_) {
+    for (auto* op_node : nodes) {
+      auto op_type = op_node->Op()->Type();
+
+      if (op_type == "feed" || op_type == "fetch") continue;
+
+      std::string unique_type = op_type + "_" + std::to_string(suffix++);
+      op_original_type_[unique_type] = op_type;
+      op_node->Op()->SetType(unique_type);
+      op_node->Op()->Flush();
+      VLOG(4) << "change op type: " << op_type << " ---> " << unique_type;
+    }
+  }
+}
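+
+// Note on SetOpUniqueType: e.g. two matmul ops become "matmul_0" and
+// "matmul_1", so the precision decision recorded in op_run_low_precision_
+// is tracked per op instance instead of per op type. RestoreOpOriginType
+// undoes the renaming at the end of the pass.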
+
+void AutoMixedPrecisionPass::RestoreOpOriginType() const {
+  for (const auto& nodes : all_op_nodes_) {
+    for (auto* op_node : nodes) {
+      auto op_type = op_node->Op()->Type();
+      op_node->Op()->SetType(GetOpOriginalType(op_type));
+      op_node->Op()->Flush();
+      VLOG(4) << "restore op type: " << op_type << " ---> "
+              << op_node->Op()->Type();
+    }
+  }
+}
+
+inline std::string AutoMixedPrecisionPass::GetOpOriginalType(
+    const std::string& op_type) const {
+  if (op_original_type_.count(op_type)) {
+    return op_original_type_.at(op_type);
+  }
+  return op_type;
+}
+
+void AutoMixedPrecisionPass::ProcessOpWithDtypeAttr() const {
+  for (const auto& nodes : all_op_nodes_) {
+    for (auto* op_node : nodes) {
+      auto op_type = op_node->Op()->Type();
+      if (op_run_low_precision_.count(op_type) == 0) continue;
+
+      if (op_node->Op()->HasAttr("dtype")) {
+        auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
+        if (IsFloatType(static_cast<VarType::Type>(dtype))) {
+          op_node->Op()->SetAttr(
+              "dtype",
+              static_cast<int>(
+                  framework::TransToProtoVarType(low_precision_)));
+          op_node->Op()->Flush();
+          VLOG(4) << "process op with dtype attr: " << op_type << " ( "
+                  << dtype << " ---> " << static_cast<int>(low_precision_)
+                  << " )";
+        }
+      }
+      if (op_node->Op()->HasAttr("out_dtype")) {
+        auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
+        if (IsFloatType(static_cast<VarType::Type>(out_dtype))) {
+          op_node->Op()->SetAttr(
+              "out_dtype",
+              static_cast<int>(
+                  framework::TransToProtoVarType(low_precision_)));
+          op_node->Op()->Flush();
+          VLOG(4) << "process op with out_dtype attr: " << op_type << " ( "
+                  << out_dtype << " ---> " << static_cast<int>(low_precision_)
+                  << " )";
+        }
+      }
+    }
+  }
+}
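+
+// Note on ProcessOpWithDtypeAttr: ops such as fill_constant decide their
+// output dtype from a "dtype"/"out_dtype" attribute at runtime. Since
+// SetVarPrecision runs earlier in ApplyImpl and has already rewritten the
+// output vars of low-precision ops, a float-valued attribute left at fp32
+// would no longer match the var it produces, so it is rewritten here too.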
+
+void AutoMixedPrecisionPass::GetOpPrecision() const {
+  for (const auto& nodes : all_op_nodes_) {
+    for (auto* op_node : nodes) {
+      auto op_type = op_node->Op()->Type();
+      bool support_low_precision = true;
+      if (GetOpOriginalType(op_type) == "feed" ||
+          GetOpOriginalType(op_type) == "fetch") {
+        support_low_precision = !keep_io_types_;
+      } else {
+        support_low_precision = OpSupportPrecision(
+            GetOpOriginalType(op_type), backend_, low_precision_, black_list_);
+      }
+
+      if (op_node->Op()->HasAttr("dtype")) {
+        auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
+        support_low_precision =
+            support_low_precision &&
+            IsFloatType(static_cast<VarType::Type>(dtype));
+      } else if (op_node->Op()->HasAttr("out_dtype")) {
+        auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
+        support_low_precision =
+            support_low_precision &&
+            IsFloatType(static_cast<VarType::Type>(out_dtype));
+      } else {
+        // if the op's input and output vars are not dense tensors, the op
+        // should not run at low precision.
+        for (auto* in_var_node : op_node->inputs) {
+          CHECK_EQ(in_var_node->IsVar(), true);
+          auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
+          if (real_in_var_node->Var()->Persistable()) continue;
+
+          support_low_precision =
+              support_low_precision &&
+              (real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR);
+        }
+
+        for (auto* out_var_node : op_node->outputs) {
+          CHECK_EQ(out_var_node->IsVar(), true);
+          auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
+          if (real_out_var_node->Var()->Persistable()) continue;
+
+          support_low_precision =
+              support_low_precision &&
+              (real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR);
+        }
+      }
+
+      if (support_low_precision) {
+        op_run_low_precision_.insert(op_type);
+        VLOG(4) << "support precision: " << op_type << " runs at low precision";
+      } else {
+        VLOG(4) << "support precision: " << op_type
+                << " does not run at low precision";
+      }
+    }
+  }
+}
+
+void AutoMixedPrecisionPass::UpdateOpPrecision() const {
+  std::unordered_set<std::string> vars_should_not_low_precision;
+
+  // var name -> all the ops that write (produce) the var
+  std::unordered_map<std::string, std::vector<Node*>> var_input_ops;
+
+  auto GetVarInputOps = [&] {
+    for (const auto& nodes : all_op_nodes_) {
+      for (auto* op_node : nodes) {
+        auto op_type = op_node->Op()->Type();
+
+        if (GetOpOriginalType(op_type) == "fetch") continue;
+        if (op_node->Op()->HasAttr("sub_block")) continue;
+
+        for (auto* var_node : op_node->outputs) {
+          CHECK_EQ(var_node->IsVar(), true);
+          if (var_node->Var()->Persistable()) continue;
+          if (!VarNodeHasDtype(var_node)) continue;
+
+          var_input_ops[var_node->Var()->Name()].push_back(op_node);
+          VLOG(4) << "var input ops: " << var_node->Var()->Name()
+                  << " is output of " << op_type;
+        }
+
+        // The select_input op's input vars should not be converted to low
+        // precision, so an op whose output var feeds a select_input op
+        // should not run at low precision either.
+        if (GetOpOriginalType(op_node->Op()->Type()) == "select_input") {
+          for (auto* in_var_node : op_node->inputs) {
+            CHECK_EQ(in_var_node->IsVar(), true);
+            if (in_var_node->Var()->Persistable()) continue;
+            if (!VarNodeHasDtype(in_var_node)) continue;
+
+            vars_should_not_low_precision.insert(in_var_node->Var()->Name());
+          }
+        }
+      }
+    }
+  };
+  GetVarInputOps();
+
+  bool precision_updated = false;
+  do {
+    precision_updated = false;
+    for (const auto& nodes : all_op_nodes_) {
+      for (auto* op_node : nodes) {
+        if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) continue;
+
+        for (auto* out_var_node : op_node->outputs) {
+          CHECK_EQ(out_var_node->IsVar(), true);
+          if (!VarNodeHasDtype(out_var_node)) continue;
+
+          auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
+          if (real_out_var_node->Var()->Persistable()) continue;
+
+          bool not_run_low_precision = false;
+          const auto& input_op_nodes =
+              var_input_ops[real_out_var_node->Var()->Name()];
+          if (vars_should_not_low_precision.count(
+                  real_out_var_node->Var()->Name())) {
+            not_run_low_precision = true;
+          } else {
+            for (auto* node : input_op_nodes) {
+              if (op_run_low_precision_.count(node->Op()->Type()) == 0) {
+                not_run_low_precision = true;
+                break;
+              }
+            }
+          }
+          if (not_run_low_precision) {
+            op_run_low_precision_.erase(op_node->Op()->Type());
+            precision_updated = true;
+            VLOG(4) << op_node->Op()->Type()
+                    << " should not run at low precision.";
+            break;
+          }
+        }
+      }
+    }
+  } while (precision_updated);
+}
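+
+// Note on UpdateOpPrecision: the do-while iterates to a fixed point. E.g.
+// if a fp16-capable op1 and a fp32-only op2 both produce the same var v,
+// op1 is demoted to fp32 so that v keeps a single dtype; that demotion can
+// in turn demote further ops that share an output var with op1.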
+
+// Special ops whose weights should not be converted to low precision.
+bool AutoMixedPrecisionPass::InputVarsNotConvert(
+    Node* op_node, const std::string& var_name) const {
+  auto* op_desc = op_node->Op();
+  if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
+    auto vecs = op_desc->Input("Bias");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
+    vecs = op_desc->Input("Mean");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
+    vecs = op_desc->Input("Scale");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
+    vecs = op_desc->Input("Variance");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
+  } else if (GetOpOriginalType(op_desc->Type()) == "fused_multi_transformer") {
+    auto vecs = op_desc->Input("LnScale");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
+    vecs = op_desc->Input("LnBias");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
+    vecs = op_desc->Input("FFNLnScale");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
+    vecs = op_desc->Input("FFNLnBias");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+bool AutoMixedPrecisionPass::OutputVarsNotConvert(
+    Node* op_node, const std::string& var_name) const {
+  auto* op_desc = op_node->Op();
+  // batch_norm's input and output (variance and mean) are the same.
+  if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
+    auto vecs = op_desc->Output("MeanOut");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
+    vecs = op_desc->Output("VarianceOut");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
+    vecs = op_desc->Output("SavedMean");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
+    vecs = op_desc->Output("SavedVariance");
+    if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
+      return true;
+    }
+  }
+  return false;
+}
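+
+// Note on SetVarPrecision (below): persistable inputs (weights) of
+// low-precision ops are converted in place and recorded, their data being
+// rewritten later by ConvertWeightsData, while outputs of low-precision ops
+// are simply set to the low dtype; any remaining producer/consumer dtype
+// mismatches are reconciled afterwards by InsertCastOp.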
+
+void AutoMixedPrecisionPass::SetVarPrecision() const {
+  for (const auto& nodes : all_op_nodes_) {
+    for (auto* op_node : nodes) {
+      if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) {
+        continue;
+      }
+
+      if (GetOpOriginalType(op_node->Op()->Type()) != "feed") {
+        for (auto* in_var_node : op_node->inputs) {
+          CHECK_EQ(in_var_node->IsVar(), true);
+
+          auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
+          auto in_var_name = real_in_var_node->Var()->Name();
+
+          if (!IsFloatType(real_in_var_node->Var()->GetDataType())) continue;
+          if (!VarNodeHasDtype(real_in_var_node)) continue;
+          if (InputVarsNotConvert(op_node, in_var_name)) continue;
+
+          if (real_in_var_node->Var()->Persistable()) {
+            real_in_var_node->Var()->SetDataType(
+                framework::TransToProtoVarType(low_precision_));
+            vars_convert_to_low_precision_.insert(in_var_name);
+          }
+        }
+      }
+
+      if (GetOpOriginalType(op_node->Op()->Type()) != "fetch") {
+        for (auto* out_var_node : op_node->outputs) {
+          CHECK_EQ(out_var_node->IsVar(), true);
+
+          auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
+          auto out_var_name = real_out_var_node->Var()->Name();
+
+          if (!IsFloatType(real_out_var_node->Var()->GetDataType())) continue;
+          if (!VarNodeHasDtype(real_out_var_node)) continue;
+          if (OutputVarsNotConvert(op_node, out_var_name)) continue;
+
+          real_out_var_node->Var()->SetDataType(
+              framework::TransToProtoVarType(low_precision_));
+          if (real_out_var_node->Var()->Persistable()) {
+            vars_convert_to_low_precision_.insert(out_var_name);
+          }
+        }
+      }
+    }
+  }
+
+  // Process vars with the same name across subgraphs: vars sharing a name
+  // must have the same data type.
+  for (auto* subgraph : subgraphes_) {
+    for (auto* var_node : subgraph->Nodes()) {
+      if (!var_node->IsVar() || !var_node->Var()->Persistable()) continue;
+      if (!VarNodeHasDtype(var_node)) continue;
+
+      auto var_name = var_node->Var()->Name();
+      if (vars_convert_to_low_precision_.count(var_name)) {
+        var_node->Var()->SetDataType(
+            framework::TransToProtoVarType(low_precision_));
+      }
+    }
+  }
+}
+
+void AutoMixedPrecisionPass::ConvertWeightsData() const {
+  auto* scope = param_scope();
+  PADDLE_ENFORCE_NOT_NULL(scope,
+                          platform::errors::PreconditionNotMet(
+                              "During the auto_mixed_precision_pass, the scope "
+                              "should not be null."));
+
+  auto var_names = scope->LocalVarNames();
+  for (const auto& var_name : var_names) {
+    if (vars_convert_to_low_precision_.count(var_name)) {
+      VLOG(4) << var_name << "'s data type was converted to low precision";
+
+      auto* var = scope->FindLocalVar(var_name);
+      CHECK_EQ(var->IsType<phi::DenseTensor>(), true);
+
+      auto* origin_tensor = var->GetMutable<phi::DenseTensor>();
+
+      phi::DenseTensor low_precision_tensor;
+      low_precision_tensor.Resize(origin_tensor->dims());
+      low_precision_tensor.set_type(low_precision_);
+
+      if (low_precision_ == phi::DataType::FLOAT16) {
+        auto* low_precision_data =
+            low_precision_tensor.mutable_data<phi::dtype::float16>(
+                phi::CPUPlace{});
+        for (int64_t i = 0; i < origin_tensor->numel(); i++) {
+          if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
+            auto* origin_data = origin_tensor->data<double>();
+            low_precision_data[i] =
+                static_cast<phi::dtype::float16>(origin_data[i]);
+          } else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
+            auto* origin_data = origin_tensor->data<float>();
+            low_precision_data[i] =
+                static_cast<phi::dtype::float16>(origin_data[i]);
+          }
+        }
+      } else if (low_precision_ == phi::DataType::BFLOAT16) {
+        auto* half_data =
+            low_precision_tensor.mutable_data<phi::dtype::bfloat16>(
+                phi::CPUPlace{});
+        for (int64_t i = 0; i < origin_tensor->numel(); i++) {
+          if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
+            auto* origin_data = origin_tensor->data<double>();
+            half_data[i] = static_cast<phi::dtype::bfloat16>(origin_data[i]);
+          } else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
+            auto* origin_data = origin_tensor->data<float>();
+            half_data[i] = static_cast<phi::dtype::bfloat16>(origin_data[i]);
+          }
+        }
+      }
+      origin_tensor->clear();
+      paddle::framework::TensorCopySync(
+          low_precision_tensor, phi::CPUPlace{}, origin_tensor);
+    }
+  }
+}
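+
+// The element-wise loops above convert weights on the CPU: the converted
+// copy is built in a temporary tensor, the original buffer is cleared, and
+// TensorCopySync writes the low-precision data back into the scope var, so
+// weights are already fp16/bf16 before the executor first runs the model.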
+
+void AutoMixedPrecisionPass::InsertCastOp() const {
+  int suffix = 0;
+  std::unordered_map<Node*, Node*> cache;
+
+  for (size_t i = 0; i < all_op_nodes_.size(); i++) {
+    auto* block_desc = all_op_nodes_[i][0]->Op()->Block();
+    CHECK_NOTNULL(block_desc);
+    for (auto* op_node : all_op_nodes_[i]) {
+      auto op_type = op_node->Op()->Type();
+
+      if (GetOpOriginalType(op_type) == "feed") continue;
+      if (op_node->Op()->HasAttr("sub_block")) continue;
+
+      VLOG(4) << "process op: " << op_type << " run low precision: "
+              << op_run_low_precision_.count(op_type);
+
+      auto inputs = op_node->inputs;
+      for (auto* in_var_node : inputs) {
+        if (!in_var_node->IsVar()) continue;
+        if (!VarNodeHasDtype(in_var_node)) continue;
+        if (in_var_node->Var()->Persistable()) continue;
+
+        auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
+
+        auto in_var_type = real_in_var_node->Var()->GetDataType();
+
+        VLOG(4) << "process var: " << real_in_var_node->Var()->Name()
+                << " with type " << in_var_type;
+
+        if (IsFloatType(in_var_type) && op_run_low_precision_.count(op_type)) {
+          DoInsertCastOp(subgraphes_[i],
+                         in_var_node,
+                         op_node,
+                         in_var_type,
+                         framework::TransToProtoVarType(low_precision_),
+                         block_desc,
+                         &suffix,
+                         &cache);
+        } else if (IsHalfType(in_var_type) &&
+                   op_run_low_precision_.count(op_type) == 0) {
+          DoInsertCastOp(subgraphes_[i],
+                         in_var_node,
+                         op_node,
+                         in_var_type,
+                         VarType::FP32,
+                         block_desc,
+                         &suffix,
+                         &cache);
+        }
+      }
+
+      // Special op.
+      // fused_multi_transformer's input(CacheKV) and output(CacheKVOut) vars
+      // have the same name.
+      if (GetOpOriginalType(op_type) == "fused_multi_transformer") {
+        auto cache_kv_inputs = op_node->Op()->Input("CacheKV");
+        auto cache_kv_outputs = op_node->Op()->Output("CacheKVOut");
+        CHECK_EQ(cache_kv_inputs.size(), cache_kv_outputs.size());
+        for (size_t i = 0; i < cache_kv_inputs.size(); ++i) {
+          op_node->Op()->RenameOutput(cache_kv_outputs[i], cache_kv_inputs[i]);
+        }
+      }
+    }
+  }
+  VLOG(4) << "insert number of cast op: " << cache.size();
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(auto_mixed_precision_pass,
+              paddle::framework::ir::AutoMixedPrecisionPass);
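+
+// Usage note: this pass is driven from the inference analysis side (see the
+// ir_pass_manager.cc change in this patch), which supplies the
+// "enable_gpu_mixed", "mixed_precision_mode" and "mixed_black_list"
+// attributes; convert_to_mixed_precision.cc additionally sets
+// "keep_io_types" before applying the pass.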
diff --git a/paddle/fluid/framework/ir/auto_mixed_precision_pass.h b/paddle/fluid/framework/ir/auto_mixed_precision_pass.h
new file mode 100644
index 0000000000000..578d47282b76d
--- /dev/null
+++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.h
@@ -0,0 +1,109 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/node.h"
+#include "paddle/phi/common/backend.h"
+#include "paddle/phi/common/data_type.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class AutoMixedPrecisionPass : public FusePassBase {
+ public:
+  using VarType = framework::proto::VarType;
+
+ public:
+  AutoMixedPrecisionPass() = default;
+  ~AutoMixedPrecisionPass() = default;
+
+ protected:
+  void ApplyImpl(Graph* graph) const override;
+
+ private:
+  void Init(Graph* graph) const;
+
+  void SetDefaultBlacklist() const;
+
+  void SetOpUniqueType() const;
+
+  void RestoreOpOriginType() const;
+
+  inline std::string GetOpOriginalType(const std::string& op_type) const;
+
+  void GetOpPrecision() const;
+
+  void UpdateOpPrecision() const;
+
+  void InsertCastOp() const;
+
+  void ProcessOpWithDtypeAttr() const;
+
+  bool InputVarsNotConvert(Node* op_node, const std::string& var_name) const;
+
+  bool OutputVarsNotConvert(Node* op_node, const std::string& var_name) const;
+
+  void SetVarPrecision() const;
+
+  void ConvertWeightsData() const;
+
+ private:
+  mutable bool skip_pass_{false};
+
+  mutable bool keep_io_types_{false};
+  // float16 or bfloat16 now
+  mutable phi::DataType low_precision_{phi::DataType::FLOAT16};
+
+  mutable phi::Backend backend_{phi::Backend::GPU};
+
+  mutable std::unordered_set<std::string> black_list_;
+
+  // subgraph id -> pointer to subgraph
+  mutable std::vector<Graph*> subgraphes_;
+  // var name -> real var node
+  mutable std::unordered_map<std::string, Node*> real_vars_;
+  // subgraph id -> all op nodes in subgraph
+  mutable std::vector<std::vector<Node*>> all_op_nodes_;
+  // op's unique type -> the op's origin type
+  mutable std::unordered_map<std::string, std::string> op_original_type_;
+  // op's unique type -> whether the op runs at low precision
+  mutable std::unordered_set<std::string> op_run_low_precision_;
+
+  mutable std::unordered_set<std::string> vars_convert_to_low_precision_;
+};
+
+bool OpSupportPrecision(const std::string& op_type,
+                        phi::Backend backend,
+                        phi::DataType precision,
+                        const std::unordered_set<std::string>& black_list);
+
+void DoInsertCastOp(Graph* graph,
+                    Node* var_node,
+                    Node* op_node,
+                    proto::VarType::Type from_type,
+                    proto::VarType::Type to_type,
+                    framework::BlockDesc* block_desc,
+                    int* suffix,
+                    std::unordered_map<Node*, Node*>* cache);
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/delete_fill_constant_op_pass.cc b/paddle/fluid/framework/ir/delete_fill_constant_op_pass.cc
index cd5cbf150b3a3..582b9389e0ffc 100644
--- a/paddle/fluid/framework/ir/delete_fill_constant_op_pass.cc
+++ b/paddle/fluid/framework/ir/delete_fill_constant_op_pass.cc
@@ -29,6 +29,11 @@ void FillConstData(LoDTensor* out_t, T value) {
 }
 
 void DeleteFillConstantOpPass::ApplyImpl(ir::Graph* graph) const {
+  bool with_dynamic_shape = Get<bool>("with_dynamic_shape");
+  // Not supported with dynamic shape.
+  if (with_dynamic_shape) {
+    return;
+  }
   FusePassBase::Init("delete_fill_constant_op_pass", graph);
   GraphPatternDetector detector;
   auto fill_constant_op =
diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc
index 6946fb6d7d9ee..5143ccfe4531c 100644
--- a/paddle/fluid/framework/ir/graph.cc
+++ b/paddle/fluid/framework/ir/graph.cc
@@ -75,7 +75,6 @@ Graph::Graph(const ProgramDesc &program,
     }
   } else {
     auto var_nodes = InitFromProgram(program_, start_op_index, end_op_index);
-    ResolveHazard(var_nodes);
   }
 }
 
@@ -88,7 +87,6 @@ Graph::Graph(const BlockDesc &block,
                          const int64_t end_op_index)
     : main_graph_(main_graph) {
   auto var_nodes = InitFromBlock(block, start_op_index, end_op_index);
-  ResolveHazard(var_nodes);
 }
 
 // TODO(levi): delete this interface after when we can convert all
diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc
index 84a9a3b74c0a2..b1d550c54b4e0 100644
--- a/paddle/fluid/framework/ir/graph_test.cc
+++ b/paddle/fluid/framework/ir/graph_test.cc
@@ -130,86 +130,6 @@ TEST(GraphTest, Basic) {
   ASSERT_EQ(nodes.size(), 5UL);
 }
 
-TEST(GraphTest, WriteAfterRead) {
-  // void Test() {
-  ProgramDesc prog;
-  auto *op = prog.MutableBlock(0)->AppendOp();
-  op->SetType("sum");
-  op->SetInput("X", {"a"});
-  op->SetOutput("Out", {"b"});
-  op->SetAttr("op_role", 1);
-
-  op = prog.MutableBlock(0)->AppendOp();
-  op->SetType("dummy");
-  op->SetInput("X", {"c"});
-  op->SetOutput("Out", {"a"});
-  op->SetAttr("op_role", 1);
-
-  prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
-  prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
-  prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
-
-  std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
-  ir::Node *control_dep1 = nullptr;
-  ir::Node *control_dep2 = nullptr;
-  for (ir::Node *n : g->Nodes()) {
-    if (n->Name() == "sum") {
-      ASSERT_EQ(n->outputs[0]->Name(), "b");
-      ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
-      control_dep1 = n->outputs[1];
-      ASSERT_EQ(n->outputs.size(), 2UL);
-    }
-    if (n->Name() == "dummy") {
-      ASSERT_EQ(n->inputs[0]->Name(), "c");
-      ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
-      control_dep2 = n->inputs[1];
-      ASSERT_EQ(n->inputs.size(), 2UL);
-    }
-  }
-  ASSERT_EQ(control_dep1, control_dep2);
-}
-
-TEST(GraphTest, WriteAfterWrite) {
-  // void Test() {
-  ProgramDesc prog;
-  auto *op = prog.MutableBlock(0)->AppendOp();
-  op->SetType("sum");
-  op->SetInput("X", {"a"});
-  op->SetOutput("Out", {"b"});
-  op->SetAttr("op_role", 1);
-
-  op = prog.MutableBlock(0)->AppendOp();
-  op->SetType("dummy");
-  op->SetInput("X", {"c"});
-  op->SetOutput("Out", {"b"});
-  op->SetAttr("op_role", 1);
-
-  prog.MutableBlock(0)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
-  prog.MutableBlock(0)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
-  prog.MutableBlock(0)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
-
-  std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
-  ir::Node *control_dep1 = nullptr;
-  ir::Node *control_dep2 = nullptr;
-  for (ir::Node *n : g->Nodes()) {
-    if (n->Name() == "sum") {
-      ASSERT_EQ(n->outputs[0]->Name(), "b");
-      ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
-      ASSERT_EQ(n->outputs.size(), 2UL);
-      control_dep1 = n->outputs[1];
-    }
-    if (n->Name() == "dummy") {
-      ASSERT_EQ(n->inputs[0]->Name(), "c");
-      ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
-      control_dep2 = n->inputs[1];
-      ASSERT_EQ(n->inputs.size(), 2UL);
-    }
-  }
-  ASSERT_NE(control_dep1, nullptr);
-  ASSERT_NE(control_dep2, nullptr);
-  ASSERT_EQ(control_dep1, control_dep2);
-}
-
 TEST(GraphTest, TestException) {
   ProgramDesc prog;
   std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
@@ -350,12 +270,13 @@ TEST(GraphTest, TestMultiBlock) {
   op = prog.MutableBlock(1)->AppendOp();
   op->SetType("dummy");
   op->SetInput("X", {"c"});
-  op->SetOutput("Out", {"a"});
+  op->SetOutput("Out", {"d"});
   op->SetAttr("op_role", 1);
 
   prog.MutableBlock(1)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
   prog.MutableBlock(1)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
   prog.MutableBlock(1)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
+  prog.MutableBlock(1)->Var("d")->SetType(proto::VarType::LOD_TENSOR);
 
   // Set contents in block_2.
   op = prog.MutableBlock(2)->AppendOp();
@@ -367,12 +288,13 @@ TEST(GraphTest, TestMultiBlock) {
   op = prog.MutableBlock(2)->AppendOp();
   op->SetType("dummy");
   op->SetInput("X", {"c"});
-  op->SetOutput("Out", {"b"});
+  op->SetOutput("Out", {"d"});
   op->SetAttr("op_role", 1);
 
   prog.MutableBlock(2)->Var("a")->SetType(proto::VarType::LOD_TENSOR);
   prog.MutableBlock(2)->Var("b")->SetType(proto::VarType::LOD_TENSOR);
   prog.MutableBlock(2)->Var("c")->SetType(proto::VarType::LOD_TENSOR);
+  prog.MutableBlock(2)->Var("d")->SetType(proto::VarType::LOD_TENSOR);
 
   // Step2: Convert program into graph, 3 blocks corresponding 3 sub_graphs.
   std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
@@ -399,45 +321,29 @@ TEST(GraphTest, TestMultiBlock) {
 
   // Check contents in sub_graph_1.
   const ir::Graph *g1 = g->GetSubGraph(1);
-  ir::Node *control_dep1 = nullptr;
-  ir::Node *control_dep2 = nullptr;
   for (ir::Node *n : g1->Nodes()) {
     if (n->Name() == "sum") {
       ASSERT_EQ(n->outputs[0]->Name(), "b");
-      ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
-      control_dep1 = n->outputs[1];
-      ASSERT_EQ(n->outputs.size(), 2UL);
+      ASSERT_EQ(n->outputs.size(), 1UL);
     }
     if (n->Name() == "dummy") {
       ASSERT_EQ(n->inputs[0]->Name(), "c");
-      ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
-      control_dep2 = n->inputs[1];
-      ASSERT_EQ(n->inputs.size(), 2UL);
+      ASSERT_EQ(n->inputs.size(), 1UL);
     }
   }
-  ASSERT_EQ(control_dep1, control_dep2);
 
   // Check contents in sub_graph_2.
   const ir::Graph *g2 = g->GetSubGraph(2);
-  control_dep1 = nullptr;
-  control_dep2 = nullptr;
   for (ir::Node *n : g2->Nodes()) {
     if (n->Name() == "sum") {
       ASSERT_EQ(n->outputs[0]->Name(), "b");
-      ASSERT_TRUE(ir::IsControlDepVar(*n->outputs[1]));
-      ASSERT_EQ(n->outputs.size(), 2UL);
-      control_dep1 = n->outputs[1];
+      ASSERT_EQ(n->outputs.size(), 1UL);
     }
     if (n->Name() == "dummy") {
       ASSERT_EQ(n->inputs[0]->Name(), "c");
-      ASSERT_TRUE(ir::IsControlDepVar(*n->inputs[1]));
-      control_dep2 = n->inputs[1];
-      ASSERT_EQ(n->inputs.size(), 2UL);
+      ASSERT_EQ(n->inputs.size(), 1UL);
     }
   }
-  ASSERT_NE(control_dep1, nullptr);
-  ASSERT_NE(control_dep2, nullptr);
-  ASSERT_EQ(control_dep1, control_dep2);
 
   // Step3: Clone graph.
   std::shared_ptr<ir::Graph> clone_g = g->Clone();
diff --git a/paddle/fluid/framework/ir/multi_batch_merge_pass.cc b/paddle/fluid/framework/ir/multi_batch_merge_pass.cc
index 7b203125681c5..13465610e47fd 100644
--- a/paddle/fluid/framework/ir/multi_batch_merge_pass.cc
+++ b/paddle/fluid/framework/ir/multi_batch_merge_pass.cc
@@ -331,8 +331,6 @@ void BatchMergePass::ApplyImpl(ir::Graph* graph) const {
       copy_node(node);
     }
   }
-
-  result.ResolveHazard(created);
 }
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/naive_executor.cc b/paddle/fluid/framework/naive_executor.cc
index 651ac23e52fe1..3a85d30386cd5 100644
--- a/paddle/fluid/framework/naive_executor.cc
+++ b/paddle/fluid/framework/naive_executor.cc
@@ -183,5 +183,6 @@ void NaiveExecutor::ResetTrtOps(int num) {
   }
 #endif
 }
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc
index 9615100f32ad3..4aadb34d7b354 100644
--- a/paddle/fluid/inference/analysis/analyzer.cc
+++ b/paddle/fluid/inference/analysis/analyzer.cc
@@ -38,8 +38,7 @@ void Analyzer::RunAnalysis(Argument *argument) {
     if (!disable_logs) {
       string::PrettyLogH1("--- Running analysis [%s]", pass);
     }
-    if (!argument->enable_analysis_optim() && pass == "ir_analysis_pass")
-      continue;
+    if (!argument->enable_ir_optim() && pass == "ir_analysis_pass") continue;
 
     auto *ptr = PassRegistry::Global().Retreive(pass);
     PADDLE_ENFORCE_NOT_NULL(ptr,
diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc
index 1df8d06dd89ca..3f5be92f5a3e6 100644
--- a/paddle/fluid/inference/analysis/analyzer_tester.cc
+++ b/paddle/fluid/inference/analysis/analyzer_tester.cc
@@ -31,7 +31,7 @@ TEST(Analyzer, analysis_without_tensorrt) {
   Argument argument;
   argument.SetDisableLogs(false);
   argument.SetModelDir(FLAGS_inference_model_dir);
-  argument.SetEnableAnalysisOptim(false);
+  argument.SetEnableIrOptim(false);
   argument.SetUseGPU(false);
   argument.SetAnalysisPasses({"ir_graph_build_pass",
                               "ir_analysis_pass",
@@ -44,7 +44,7 @@ TEST(Analyzer, analysis_with_tensorrt) {
   Argument argument;
   argument.SetDisableLogs(false);
-  argument.SetEnableAnalysisOptim(false);
+  argument.SetEnableIrOptim(false);
   argument.SetTensorRtMaxBatchSize(3);
   argument.SetTensorRtWorkspaceSize(1 << 20);
   argument.SetModelDir(FLAGS_inference_model_dir);
diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h
index 2f077b8631a23..2a4ce0d6492b0 100644
--- a/paddle/fluid/inference/analysis/argument.h
+++ b/paddle/fluid/inference/analysis/argument.h
@@ -42,8 +42,6 @@ namespace paddle {
 namespace inference {
 namespace analysis {
 
-using framework::ir::Graph;
-
 #ifdef PADDLE_WITH_MKLDNN
 using VarQuantScale =
     std::unordered_map<std::string, std::pair<bool, LoDTensor>>;
@@ -148,7 +146,7 @@ struct Argument {
   DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string);
   DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool);
   DECL_ARGUMENT_FIELD(optim_cache_dir, OptimCacheDir, std::string);
-  DECL_ARGUMENT_FIELD(enable_analysis_optim, EnableAnalysisOptim, bool);
+  DECL_ARGUMENT_FIELD(enable_ir_optim, EnableIrOptim, bool);
 
   // For JITLayer
   DECL_ARGUMENT_FIELD(skip_load_params, SkipLoadParams, bool);
@@ -362,6 +360,8 @@ struct Argument {
   DECL_ARGUMENT_FIELD(mixed_black_list,
                       MixedBlackList,
                       std::unordered_set<std::string>);
+  DECL_ARGUMENT_FIELD(enable_gpu_mixed, EnableGPUMixed, bool);
+  DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int);
 
  private:
   std::unordered_set<std::string> valid_fields_;
diff --git a/paddle/fluid/inference/analysis/helper.h b/paddle/fluid/inference/analysis/helper.h
index e8d719ddb659d..e891da8e6d19f 100644
--- a/paddle/fluid/inference/analysis/helper.h
+++ b/paddle/fluid/inference/analysis/helper.h
@@ -153,25 +153,6 @@ T &GetFromScope(const framework::Scope &scope, const std::string &name) {
   return *var->GetMutable<T>();
 }
 
-static framework::proto::ProgramDesc LoadProgramDesc(
-    const std::string &model_path) {
-  std::ifstream fin(model_path, std::ios::in | std::ios::binary);
-  PADDLE_ENFORCE_EQ(
-      fin.is_open(),
-      true,
-      platform::errors::NotFound(
-          "Cannot open file %s, please confirm whether the file exists",
-          model_path));
-  fin.seekg(0, std::ios::end);
-  std::string buffer(fin.tellg(), ' ');
-  fin.seekg(0, std::ios::beg);
-  fin.read(&buffer[0], buffer.size());
-  fin.close();
-  framework::proto::ProgramDesc program_desc;
-  program_desc.ParseFromString(buffer);
-  return program_desc;
-}
-
 static bool FileExists(const std::string &filepath) {
   std::ifstream file(filepath);
   bool exists = file.is_open();
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc
index b275302741147..f994667df80bb 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.cc
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -27,6 +27,7 @@
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/inference/analysis/argument.h"
 #include "paddle/fluid/string/pretty_log.h"
+#include "paddle/phi/core/errors.h"
 
 namespace paddle {
 namespace inference {
@@ -36,15 +37,6 @@ using string::PrettyLogEndl;
 using string::Style;
 
 IRPassManager::IRPassManager(Argument *argument) {
-  ARGUMENT_CHECK_FIELD(argument, main_program);
-  graph_ = std::unique_ptr<Graph>(new Graph(argument->main_program()));
-  if (argument->Has("scope")) {
-    auto *scope_ptr = argument->scope_ptr();
-    PADDLE_ENFORCE_NOT_NULL(scope_ptr,
-                            platform::errors::PreconditionNotMet(
-                                "The scope ptr should not be nullptr."));
-    graph_->SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr);
-  }
   disable_logs_ = argument->disable_logs();
 
   ARGUMENT_CHECK_FIELD(argument, ir_analysis_passes);
@@ -95,10 +87,14 @@ void IRPassManager::CreatePasses(Argument *argument,
           argument->tensorrt_tuned_dynamic_shape();
       pass->Set("with_dynamic_shape", new bool(with_dynamic_shape));
 
+      // mixed precision related
       pass->Set("model_precision", new int(argument->model_precision()));
       pass->Set(
           "mixed_black_list",
          new std::unordered_set<std::string>(argument->mixed_black_list()));
+      pass->Set("enable_gpu_mixed", new bool(argument->enable_gpu_mixed()));
+      pass->Set("mixed_precision_mode",
+                new int(argument->mixed_precision_mode()));
 
       if (pass_name == "graph_viz_pass") {
         std::string optim_cache_dir = argument->optim_cache_dir();
@@ -302,42 +298,18 @@ void IRPassManager::CreatePasses(Argument *argument,
 }
 
 std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) {
-  if (passes_.empty()) {
-    return graph;
-  }
   PADDLE_ENFORCE_NOT_NULL(
-      graph.get(),
-      platform::errors::PreconditionNotMet("Graph cannot be NULL."));
+      graph.get(), platform::errors::InvalidArgument("Graph cannot be null."));
   // Apply all the passes
   for (const auto &pass : passes_) {
     if (pass->Type() != "graph_viz_pass" && !disable_logs_) {
       PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass->Type());
     }
-    // delete_fill_constant_op_pass is not apply under trt dynamic shape
-    if (pass->Type() == "delete_fill_constant_op_pass") {
-      bool use_dynamic = pass->Get<bool>("with_dynamic_shape");
-      if (use_dynamic) continue;
-    }
     graph.reset(pass->Apply(graph.release()));
   }
   return graph;
 }
 
-framework::proto::ProgramDesc IRPassManager::AcquireProgram(
-    std::unique_ptr<Graph> *graph, ProgramDesc *program) const {
-  auto pass =
-      framework::ir::PassRegistry::Instance().Get("graph_to_program_pass");
-
-  // Direct using ProgramDesc desc(argument->main_program()) may cause
-  // incomplete copies of information.
-  ProgramDesc desc;
-  desc.CopyFrom(*program->Proto());
-  pass->SetNotOwned("program", &desc);
-  auto *the_graph = graph->release();
-  graph->reset(pass->Apply(the_graph));
-  return *desc.Proto();
-}
-
 }  // namespace analysis
 }  // namespace inference
 }  // namespace paddle
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.h b/paddle/fluid/inference/analysis/ir_pass_manager.h
index 9f9a5fc347123..c56d3d40f54de 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.h
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.h
@@ -48,15 +48,9 @@ class IRPassManager final {
 
   std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph);
 
-  framework::proto::ProgramDesc AcquireProgram(std::unique_ptr<Graph> *graph,
-                                               ProgramDesc *program) const;
-
-  framework::ir::Graph &graph() const { return *graph_; }
-
  private:
   void CreatePasses(Argument *argument,
                     const std::vector<std::string> &passes);
 
-  std::unique_ptr<Graph> graph_;
   std::vector<std::unique_ptr<framework::ir::Pass>> passes_;
   bool disable_logs_{false};
 };
diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
index 559d869c758aa..5021336df490d 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -94,14 +94,14 @@ void OutputProcess(framework::ir::Graph *graph,
                               backend,
                               precision,
                               blacklist)) {
-        AddCastOp(graph,
-                  var_node,
-                  next_op,
-                  framework::proto::VarType::FP32,
-                  to_type,
-                  &suffix,
-                  block_desc,
-                  &var_to_cast_op_map);
+        InsertCastOp(graph,
+                     var_node,
+                     next_op,
+                     framework::proto::VarType::FP32,
+                     to_type,
+                     block_desc,
+                     &suffix,
+                     &var_to_cast_op_map);
         var_node->Var()->SetDataType(framework::proto::VarType::FP32);
       }
     }
diff --git a/paddle/fluid/inference/analysis/passes/CMakeLists.txt b/paddle/fluid/inference/analysis/passes/CMakeLists.txt
index 126e2500c4890..96121601cb6fd 100644
--- a/paddle/fluid/inference/analysis/passes/CMakeLists.txt
+++ b/paddle/fluid/inference/analysis/passes/CMakeLists.txt
@@ -13,7 +13,7 @@ cc_library(
 cc_library(
   convert_to_mixed_precision
   SRCS convert_to_mixed_precision.cc
-  DEPS analysis_pass ir_graph_build_pass)
+  DEPS analysis_pass ir_graph_build_pass auto_mixed_precision_pass)
 cc_library(
   ir_params_sync_among_devices_pass
   SRCS ir_params_sync_among_devices_pass.cc
@@ -30,17 +30,6 @@ cc_library(
   inference_op_replace_pass
   SRCS inference_op_replace_pass.cc
   DEPS analysis_pass graph_to_program_pass)
-if(WITH_TESTING)
-  cc_library(
-    ir_graph_clean_pass
-    SRCS ir_graph_clean_pass.cc
-    DEPS analysis_pass gtest)
-else()
-  cc_library(
-    ir_graph_clean_pass
-    SRCS ir_graph_clean_pass.cc
-    DEPS analysis_pass)
-endif()
 
 cc_library(
   analysis_passes
@@ -52,8 +41,7 @@ cc_library(
     memory_optim_pass
     convert_to_mixed_precision
     inference_op_replace_pass
-    ir_graph_to_program_pass
-    ir_graph_clean_pass)
+    ir_graph_to_program_pass)
 
 set(analysis_deps
     ${analysis_deps}
     analysis_passes
     subgraph_detector
diff --git a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
index 789865a52882f..f1939fc8b328b 100644
--- a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
+++ b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
@@ -14,807 +14,88 @@
 
 #include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"
 
-#include <algorithm>
-#include <iostream>
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <unordered_set>
-#include <utility>
-
-#include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/executor.h"
-#include "paddle/fluid/framework/framework.pb.h"
-#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
-#include "paddle/fluid/framework/ir/node.h"
-#include "paddle/fluid/framework/program_desc.h"
-#include "paddle/fluid/framework/scope.h"
-#include "paddle/fluid/framework/var_desc.h"
-#include "paddle/fluid/inference/analysis/argument.h"
-#include "paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h"
 #include "paddle/fluid/inference/io.h"
-#include "paddle/phi/common/bfloat16.h"
-#include "paddle/phi/common/data_type.h"
-#include "paddle/phi/common/float16.h"
-#include "paddle/phi/common/layout.h"
-#include "paddle/phi/common/place.h"
-#include "paddle/phi/core/tensor_meta.h"
-
-using namespace paddle::framework;  // NOLINT
+#include "paddle/phi/common/backend.h"
 
 namespace paddle {
 namespace inference {
 namespace analysis {
 
-namespace {
-bool PhiKernelSupportPrecision(
-    const std::string& op_type,
+ConvertToMixedPrecisionPass::ConvertToMixedPrecisionPass(
+    const std::string& model_file,
+    const std::string& params_file,
+    const std::string& mixed_model_file,
+    const std::string& mixed_params_file,
+    phi::DataType mixed_precision,
     phi::Backend backend,
-    phi::DataType data_type,
-    phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
-  auto kernels = phi::KernelFactory::Instance().kernels();
-  if (kernels.find(op_type) == kernels.end()) {
-    return false;
-  }
-  phi::KernelKey kernel_key(backend, layout, data_type);
-  return phi::KernelFactory::Instance().HasKernel(op_type, kernel_key);
-}
-
-bool GpuKernelSupportPrecision(
-    const std::string& op_type,
-    phi::DataType data_type,
-    phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
-  auto phi_op_type = phi::TransToPhiKernelName(op_type);
-  bool res = PhiKernelSupportPrecision(
-      phi_op_type, phi::Backend::GPU, data_type, layout);
-  res |= PhiKernelSupportPrecision(
-      phi_op_type, phi::Backend::GPUDNN, data_type, layout);
-
-  if (!res) {
-    auto& all_kernels = OperatorWithKernel::AllOpKernels();
-    auto it = all_kernels.find(op_type);
-    if (it != all_kernels.end()) {
-      for (auto& kern_pair : it->second) {
-        if (platform::is_gpu_place(kern_pair.first.place_) &&
-            kern_pair.first.data_type_ == framework::proto::VarType::FP16) {
-          res = true;
-        }
-      }
-    }
-  }
-  return res;
-}
-
-class ConvertToMixedPrecisionPass {
- public:
-  explicit ConvertToMixedPrecisionPass(
-      const std::string& model_file,
-      const std::string& params_file,
-      const std::string& mixed_model_file,
-      const std::string& mixed_params_file,
-      phi::DataType mixed_precision,
-      phi::Backend backend,
-      bool keep_io_types,
-      std::unordered_set<std::string> black_list)
-      : model_file_(model_file),
-        params_file_(params_file),
-        mixed_model_file_(mixed_model_file),
-        mixed_params_file_(mixed_params_file),
-        mixed_precision_(mixed_precision),
-        backend_(backend),
-        keep_io_types_(keep_io_types),
-        black_list_(black_list),
-        place_(paddle::CPUPlace()),
-        executor_(place_) {
-    black_list_.insert("assign");
-    black_list_.insert("fill_constant");
-    black_list_.insert("assign_value");
-    black_list_.insert("eye");
-    black_list_.insert("fill_any_like");
-    black_list_.insert("fill_constant_batch_size_like");
-  }
-  void Run();
-
- private:
-  void LoadAndPrepare();
-  inline bool NodeVarHasDtype(framework::ir::Node* node);
-  void ConvertAllFp64ToFp32(framework::ir::Graph* graph);
-  void FixCastAttr(framework::ir::Graph* graph);
-  void SaveMixedModel();
-  void ConvertTensorDtype(int block_idx);
-  void ProcessInputNode(bool support_precision,
-                        ir::Node* in_node,
-                        ir::Node* op_node,
-                        int* suffix,
-                        framework::BlockDesc* block_desc,
-                        framework::proto::VarType::Type to_type,
-                        int block_idx);
-
-  void ProcessOutputNode(int block_idx,
-                         ir::Node* var_node,
-                         framework::proto::VarType::Type to_type);
-  inline bool IsFloatVarType(framework::proto::VarType::Type type);
-
-  bool OutShouldNotConvert(ir::Node* var_node);
-  // Just process special cases for weights conversion.
-  bool WeightsShouldNotConvert(ir::Node* var_node);
-
-  // To support multi block, we need to consider a lot of special cases.
-  // Return Node* which first appers in block.
-  framework::ir::Node* GetRealNode(int block_idx, framework::ir::Node* node);
-  void FindVarsInMultiBlock();
-  inline bool VarIsMultiPrecisionOpsOut(int block_idx,
-                                        framework::ir::Node* op_node);
-
- private:
-  // A trick. Patch for strange op, which input name equal to output name, such
-  // as `fused_multi_transformer`
-  void PatchForStrangeOp();
-
- private:
-  std::string model_file_;
-  std::string params_file_;
-  std::string mixed_model_file_;
-  std::string mixed_params_file_;
-  phi::DataType mixed_precision_;
-  phi::Backend backend_;
-  bool keep_io_types_;
-  std::unordered_set<std::string> black_list_;
-  paddle::CPUPlace place_;
-  framework::Executor executor_;
-  framework::Scope scope_;
-
-  std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map_;
-  std::unordered_map<std::string,
-                     std::pair<framework::proto::VarType::Type, int>>
-      vars_in_multi_block_map_;
-  std::vector<std::unordered_map<std::string, std::vector<std::string>>>
-      vars_appear_multi_in_one_block_;
-  int suffix_{0};
-
-  std::unique_ptr<framework::ProgramDesc> program_desc_{nullptr};
-  std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
-  std::vector<framework::ir::Graph*> graphes_;
-};
-
-framework::ir::Node* ConvertToMixedPrecisionPass::GetRealNode(
-    int block_idx, framework::ir::Node* node) {
-  if (vars_in_multi_block_map_.count(node->Name())) {
-    int var_origin_block_id = vars_in_multi_block_map_.at(node->Name()).second;
-    if (block_idx != var_origin_block_id) {
-      auto graph = graphes_[var_origin_block_id];
-      for (auto nd : graph->Nodes()) {
-        if (nd->Name() == node->Name()) {
-          return nd;
-        }
-      }
-    }
-  }
-
-  return node;
-}
-
-inline bool ConvertToMixedPrecisionPass::NodeVarHasDtype(
-    framework::ir::Node* node) {
-  if (node->IsVar() &&
-      (node->Var()->GetType() ==
-           paddle::framework::proto::VarType::SELECTED_ROWS ||
-       node->Var()->GetType() ==
-           paddle::framework::proto::VarType::LOD_TENSOR ||
-       node->Var()->GetType() ==
-           paddle::framework::proto::VarType::LOD_TENSOR_ARRAY ||
-       node->Var()->GetType() == paddle::framework::proto::VarType::STRINGS ||
-       node->Var()->GetType() == paddle::framework::proto::VarType::VOCAB)) {
-    return true;
-  }
-
-  return false;
-}
-
-// op1(fp32) -> var1, op2(fp16) -> var1
-// if and only if op1 and op2 both support fp16, we convert op1 and op2's
-// precision.
-inline bool ConvertToMixedPrecisionPass::VarIsMultiPrecisionOpsOut(
-    int block_idx, framework::ir::Node* op_node) {
-  CHECK_EQ(op_node->IsOp(), true);
-  bool ret{false};
-
-  for (auto* out : op_node->outputs) {
-    auto* real_node = GetRealNode(block_idx, out);
-    if (!real_node->Var()->Persistable() &&
-        vars_appear_multi_in_one_block_[block_idx].count(out->Name())) {
-      for (auto op_type :
-           vars_appear_multi_in_one_block_[block_idx].at(out->Name())) {
-        if (OpSupportPrecision(
-                op_type, backend_, mixed_precision_, black_list_)) {
-          ret = true;
-          VLOG(2) << out->Name()
-                  << " is multi precision op's out, so we skip convert to fp16";
-          break;
-        }
-      }
-    }
-    if (ret) break;
-  }
-  return ret;
-}
-
-void ConvertToMixedPrecisionPass::ProcessInputNode(
-    bool support_precision,
-    ir::Node* in_node,
-    ir::Node* op_node,
-    int* suffix,
-    framework::BlockDesc* block_desc,
-    framework::proto::VarType::Type to_type,
-    int block_idx) {
-  auto* real_node = GetRealNode(block_idx, in_node);
-  if (!NodeVarHasDtype(real_node)) return;
-  auto graph = graphes_[block_idx];
-  bool is_main_block = block_idx == 0;
-  auto* in_var = real_node->Var();
-  auto in_var_type = in_var->GetDataType();
-  auto prev_type = in_var_type;
-  bool is_in_multi_block = vars_in_multi_block_map_.count(in_var->Name());
-
-  if (!is_main_block && is_in_multi_block) {
-    in_var_type = vars_in_multi_block_map_.at(in_var->Name()).first;
-  }
-  if (support_precision) {
-    if (in_var->Persistable() &&
-        in_var_type == framework::proto::VarType::FP32) {
-      if (WeightsShouldNotConvert(in_node)) return;
-      in_var->SetDataType(to_type);
-      in_var_type = to_type;
-      VLOG(3) << " in_node name " << in_var->Name() << " from " << prev_type
-              << " to " << to_type;
-    } else if (!in_var->Persistable() && IsFloatVarType(in_var_type) &&
-               in_var_type != to_type) {
-      AddCastOp(graph,
-                in_node,
-                op_node,
-                in_var_type,
-                to_type,
-                suffix,
-                block_desc,
-                &cast_map_);
-      VLOG(3) << " in_node name " << in_var->Name() << "(" << prev_type
-              << ") to " << cast_map_[in_node]->Name() << "(" << to_type << ")";
-    }
-  } else {
-    if (!in_var->Persistable() && IsFloatVarType(in_var_type) &&
-        in_var_type != to_type) {
-      AddCastOp(graph,
-                in_node,
-                op_node,
-                in_var_type,
-                to_type,
-                suffix,
-                block_desc,
-                &cast_map_);
-      VLOG(3) << " in_node name " << in_var->Name() << "(" << prev_type
-              << ") to " << cast_map_[in_node]->Name() << "(" << to_type << ")";
-    }
+    bool keep_io_types,
+    const std::unordered_set<std::string>& black_list)
+    : model_file_(model_file),
+      params_file_(params_file),
+      mixed_model_file_(mixed_model_file),
+      mixed_params_file_(mixed_params_file),
+      mixed_precision_(mixed_precision),
+      backend_(backend),
+      keep_io_types_(keep_io_types),
+      black_list_(black_list) {
+  if (mixed_precision_ != phi::DataType::FLOAT16 &&
+      mixed_precision_ != phi::DataType::BFLOAT16) {
+    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
+        "mixed_precision currently does not support dtype %d; only fp16 and "
+        "bf16 are supported.",
+        static_cast<int>(mixed_precision_)));
   }
+  if (backend_ != phi::Backend::GPU) {
+    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
+        "mixed_precision currently does not support place %d; only gpu is "
+        "supported.",
+        static_cast<int>(backend_)));
   }
 }
 
-// Just process special cases.
-bool ConvertToMixedPrecisionPass::OutShouldNotConvert(ir::Node* var_node) {
-  auto op_node = var_node->inputs[0];
-  auto* op_desc = op_node->Op();
-
-  // batch_norm's input and output (variance and mean) are the same.
-  if (op_desc->Type() == "batch_norm") {
-    auto vecs = op_desc->Output("MeanOut");
-    if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
-      return true;
-    }
-    vecs = op_desc->Output("VarianceOut");
-    if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
-      return true;
-    }
-    vecs = op_desc->Output("SavedMean");
-    if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
-      return true;
-    }
-    vecs = op_desc->Output("SavedVariance");
-    if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
-      return true;
-    }
-  }
-
-  return false;
-}
+void ConvertToMixedPrecisionPass::LoadModel() {
+  framework::Executor exe{platform::CPUPlace{}};
 
-bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(ir::Node* var_node) {
-  auto op_nodes = var_node->outputs;
-  for (auto* op_node : op_nodes) {
-    auto* op_desc = op_node->Op();
-    // batch_norm op's bias, mean, scale and variance just be float32, so we
-    // can not convert the dtype.
-    if (op_desc->Type() == "batch_norm") {
-      auto vecs = op_desc->Input("Bias");
-      if (std::find(vecs.begin(), vecs.end(), var_node->Name()) !=
-          vecs.end()) {
-        return true;
-      }
-      vecs = op_desc->Input("Mean");
-      if (std::find(vecs.begin(), vecs.end(), var_node->Name()) !=
-          vecs.end()) {
-        return true;
-      }
-      vecs = op_desc->Input("Scale");
-      if (std::find(vecs.begin(), vecs.end(), var_node->Name()) !=
-          vecs.end()) {
-        return true;
-      }
-      vecs = op_desc->Input("Variance");
-      if (std::find(vecs.begin(), vecs.end(), var_node->Name()) !=
-          vecs.end()) {
-        return true;
-      }
-    } else if (op_desc->Type() == "fused_multi_transformer") {
-      auto vecs = op_desc->Input("LnScale");
-      if (std::find(vecs.begin(), vecs.end(), var_node->Name()) !=
-          vecs.end()) {
-        return true;
-      }
-
-      vecs = op_desc->Input("LnBias");
-      if (std::find(vecs.begin(), vecs.end(), var_node->Name()) !=
-          vecs.end()) {
-        return true;
-      }
-
-      vecs = op_desc->Input("FFNLnScale");
-      if (std::find(vecs.begin(), vecs.end(), var_node->Name()) !=
-          vecs.end()) {
-        return true;
-      }
-
-      vecs = op_desc->Input("FFNLnBias");
-      if (std::find(vecs.begin(), vecs.end(), var_node->Name()) !=
-          vecs.end()) {
-        return true;
-      }
-    }
-  }
-
-  return false;
-}
-
-inline bool ConvertToMixedPrecisionPass::IsFloatVarType(
-    framework::proto::VarType::Type type) {
-  if (type == framework::proto::VarType::FP16 ||
-      type == framework::proto::VarType::FP32 ||
-      type == framework::proto::VarType::BF16)
-    return true;
-  return false;
-}
-
-void ConvertToMixedPrecisionPass::LoadAndPrepare() {
-  program_desc_ =
-      inference::Load(&executor_, &scope_, model_file_, params_file_);
+  auto program_desc = inference::Load(&exe, &scope_, model_file_, params_file_);
   main_graph_ = std::unique_ptr<framework::ir::Graph>(
-      new framework::ir::Graph(*program_desc_));
-
-  // Remove all control var
-  IrInferCleanGraphPass pass;
-  Argument arg;
-  arg.SetMainGraphNotOwned(main_graph_.get());
-  pass.Run(&arg);
-
-  vars_appear_multi_in_one_block_.resize(program_desc_->Size());
-  FindVarsInMultiBlock();
-}
-
-void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() {
-  std::vector<std::set<std::string>> block_var_names_set(program_desc_->Size());
-  for (size_t i = 0; i < program_desc_->Size(); ++i) {
-    for (auto op : program_desc_->Block(i).AllOps()) {
-      auto in_names = op->InputArgumentNames();
-      block_var_names_set[i].insert(in_names.begin(), in_names.end());
-      auto out_names = op->OutputArgumentNames();
-      if (op->HasAttr("sub_block") == false) {
-        for (auto& n : out_names) {
-          if (block_var_names_set[i].count(n)) {
-            vars_appear_multi_in_one_block_[i][n].push_back(op->Type());
-          }
-        }
-      }
-      block_var_names_set[i].insert(out_names.begin(), out_names.end());
-    }
-  }
-
-  for (size_t i = 0; i < program_desc_->Size() - 1; ++i) {
-    for (size_t j = i + 1; j < program_desc_->Size(); ++j) {
-      std::set<std::string> vars_in_multi_block;
-      std::set_intersection(
-          block_var_names_set[i].begin(),
-          block_var_names_set[i].end(),
-          block_var_names_set[j].begin(),
-          block_var_names_set[j].end(),
-          std::inserter(vars_in_multi_block, vars_in_multi_block.begin()));
-
-      for (auto name : vars_in_multi_block) {
-        vars_in_multi_block_map_.emplace(
-            name, std::make_pair(framework::proto::VarType::FP32, i));
-      }
-    }
-  }
-}
-
-void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
-    framework::ir::Graph* graph) {
-  auto op_nodes = framework::ir::TopologySortOperations(*graph);
-  for (auto* op_node : op_nodes) {
-    if (!op_node->IsOp()) continue;
-    auto op_type = op_node->Op()->Type();
-    if (op_type == "feed" || op_type == "fetch") continue;
-
-    if (op_type == "fill_constant") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
-    } else if (op_type == "assign_value") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
-    } else if (op_type == "eye") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
-    } else if (op_type == "fill_any_like") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
-    } else if (op_type == "cast") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "in_dtype", static_cast<int>(framework::proto::VarType::FP32));
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "out_dtype", static_cast<int>(framework::proto::VarType::FP32));
-    }
-
-    auto inputs = op_node->inputs;
-    for (auto* in_node : inputs) {
-      auto* in_var = in_node->Var();
-      if (!in_var->Persistable() &&
-          in_var->GetDataType() == framework::proto::VarType::FP64) {
-        in_var->SetDataType(framework::proto::VarType::FP32);
-      }
-    }
-  }
+      new framework::ir::Graph(*program_desc));
+  main_graph_->SetNotOwned(framework::ir::kParamScopeAttr, &scope_);
 }
 
 void ConvertToMixedPrecisionPass::Run() {
-  LoadAndPrepare();
+  LoadModel();
 
-  for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
-    auto graph = main_graph_->GetSubGraph(i);
-    graphes_.push_back(graph);
-    VLOG(2) << " -------- handle subgraph " << i << ", has "
-            << graph->Nodes().size() << " nodes --------";
+  framework::ir::AutoMixedPrecisionPass pass;
+  pass.Set("mixed_precision_mode", new int{static_cast<int>(mixed_precision_)});
+  pass.Set("mixed_black_list",
+           new std::unordered_set<std::string>{black_list_});
+  pass.Set("enable_gpu_mixed", new bool{true});
+  pass.Set("keep_io_types", new bool{keep_io_types_});
 
-    ConvertAllFp64ToFp32(graph);
-    ConvertTensorDtype(i);
-    FixCastAttr(graph);
-
-    // A trick
-    PatchForStrangeOp();
-
-    CHECK_EQ(ir::VarDescIsConsistency(*graph), true);
-  }
+  pass.Apply(main_graph_.get());
 
   SaveMixedModel();
 }
 
-void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
-  auto graph = graphes_[block_idx];
-  framework::proto::VarType::Type to_type;
-  if (mixed_precision_ == phi::DataType::FLOAT16) {
-    to_type = framework::proto::VarType::FP16;
-  } else if (mixed_precision_ == phi::DataType::BFLOAT16) {
-    to_type = framework::proto::VarType::BF16;
-  } else {
-    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
-        "mixed_precision currently not supported dtype %d, we now only "
-        "support fp16 and bf16.",
-        static_cast<int>(mixed_precision_)));
-  }
-
-  auto op_nodes = framework::ir::TopologySortOperations(*graph);
-  auto* block_desc = op_nodes[0]->Op()->Block();
-  int num_low_precision = 0;
-  std::vector<framework::ir::Node*> output_nodes;
-
-  for (auto* op_node : op_nodes) {
-    if (!op_node->IsOp()) continue;
-    auto op_type = op_node->Op()->Type();
-    VLOG(3) << "-------------------- op_type " << op_type << ", phi_type "
-            << phi::TransToPhiKernelName(op_type);
-    // 1. set input dtype.
-    if (op_type == "feed") {
-      auto feed_var = op_node->outputs[0]->Var();
-      if (!keep_io_types_ &&
-          feed_var->GetDataType() == framework::proto::VarType::FP32) {
-        feed_var->SetDataType(to_type);
-      }
-    } else if (op_type == "fetch") {
-      auto* fetch_var = op_node->inputs[0];
-      output_nodes.push_back(fetch_var);
-      continue;
-    } else if (op_type == "cast") {
-      continue;
-    }
-
-    else if (op_node->Op()->HasAttr("sub_block")) {  // NOLINT
-      // sub_block op's output dtype should be same as input dtype, if have the
-      // same name.
-      std::unordered_map<std::string, framework::ir::Node*> in_name_to_node;
-      for (auto* in : op_node->inputs) {
-        auto* real_node = GetRealNode(block_idx, in);
-        if (NodeVarHasDtype(real_node)) {
-          in_name_to_node[in->Name()] = in;
-        }
-      }
-
-      for (auto out : op_node->outputs) {
-        auto* real_node = GetRealNode(block_idx, out);
-        if (NodeVarHasDtype(real_node)) {
-          if (in_name_to_node.count(out->Name()))
-            real_node->Var()->SetDataType(
-                in_name_to_node[out->Name()]->Var()->GetDataType());
-        }
-      }
-
-      continue;
-    }
-
-    // 2. if op support fp16/bf16 and not in blacklist.
-    //      - cast weight to fp16/bf16.
-    //      - add cast op if the input dtype is not fp16/bf16.
-    //      - set output dtype.
-    //
-    // If a var(op's out var) appears multiple times in a block, we should not
-    // convert to fp16.
-    else if (black_list_.count(op_type) == 0 &&  // NOLINT
-             !VarIsMultiPrecisionOpsOut(block_idx, op_node)) {
-      bool support_precision =
-          OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_);
-
-      // if op not has float input, we will not choose the low precision kernel.
-      {
-        bool has_float_input{false};
-        for (auto in_node : op_node->inputs) {
-          auto* real_node = GetRealNode(block_idx, in_node);
-          if (real_node->Var()->GetDataType() == proto::VarType::FP16 ||
-              real_node->Var()->GetDataType() == proto::VarType::FP32 ||
-              real_node->Var()->GetDataType() == proto::VarType::FP64 ||
-              real_node->Var()->GetDataType() == proto::VarType::BF16) {
-            has_float_input = true;
-            break;
-          }
-        }
-        if (!has_float_input) {
-          support_precision = false;
-          VLOG(2) << " op doesn't has float input, just skip.";
-        }
-      }
-      VLOG(2) << " support low precision " << support_precision;
-
-      if (support_precision) {
-        VLOG(2) << " process input nodes:";
-        ++num_low_precision;
-        auto inputs = op_node->inputs;
-
-        // Just for paddle's terriable case: op's input and output has the same
-        // name.
-        std::unordered_map<std::string, std::string> names_map;
-        for (auto out_node : op_node->outputs) {
-          for (auto in_node : op_node->inputs) {
-            if (out_node->Name() == in_node->Name()) {
-              names_map[out_node->Name()] = in_node->Name();
-            }
-          }
-        }
-
-        // Process inputs.
-        for (auto* in_node : inputs) {
-          ProcessInputNode(
-              true, in_node, op_node, &suffix_, block_desc, to_type, block_idx);
-          if (names_map.count(in_node->Name()) && cast_map_.count(in_node)) {
-            names_map[in_node->Name()] = cast_map_[in_node]->Name();
-          }
-        }
-        VLOG(2) << " process output nodes:";
-        // Process outputs.
-        for (auto* out_node : op_node->outputs) {
-          ProcessOutputNode(block_idx, out_node, to_type);
-        }
-      } else {
-        auto inputs = op_node->inputs;
-        for (auto* in_node : inputs) {
-          ProcessInputNode(false,
-                           in_node,
-                           op_node,
-                           &suffix_,
-                           block_desc,
-                           framework::proto::VarType::FP32,
-                           block_idx);
-        }
-      }
-    }
-
-    // 3. check op not support fp16/bf16 or in blacklist.
-    //      - add cast op if the input dtype is not fp32.
-    else {  // NOLINT
-      VLOG(3) << "not to run fp16 op_type: " << op_type;
-      auto ins = op_node->inputs;
-      for (auto* in_node : ins) {
-        auto* in_var = in_node->Var();
-        if (in_var->GetDataType() == to_type) {
-          AddCastOp(graph,
-                    in_node,
-                    op_node,
-                    to_type,
-                    framework::proto::VarType::FP32,
-                    &suffix_,
-                    block_desc,
-                    &cast_map_);
-          VLOG(3) << "-- " << in_node->Name() << "(" << to_type << ") to "
-                  << cast_map_[in_node]->Name() << "("
-                  << framework::proto::VarType::FP32 << ")";
-        }
-      }
-    }
-  }
-
-  // 4. if output_op's dtype is not compatible to output dtype, then just
-  // insert cast.
-  for (auto* node : output_nodes) {
-    ir::Node* fetch_op{nullptr};
-    for (auto* op_node : node->outputs) {
-      if (op_node->IsOp() && op_node->Op()->Type() == "fetch") {
-        fetch_op = op_node;
-      }
-    }
-    CHECK_NOTNULL(fetch_op);
-    auto var = node->Var();
-    if (keep_io_types_ && var->GetDataType() == to_type) {
-      // fp16/bf16 -> fp32.
-      AddCastOp(graph,
-                node,
-                fetch_op,
-                to_type,
-                framework::proto::VarType::FP32,
-                &suffix_,
-                block_desc,
-                &cast_map_);
-    } else if (!keep_io_types_ &&
-               var->GetDataType() == framework::proto::VarType::FP32) {
-      // fp32 -> fp16/bf16
-      AddCastOp(graph,
-                node,
-                fetch_op,
-                framework::proto::VarType::FP32,
-                to_type,
-                &suffix_,
-                block_desc,
-                &cast_map_);
-    }
-  }
-
-  for (auto node : graph->Nodes()) {
-    auto* real_node = GetRealNode(block_idx, node);
-    if (!NodeVarHasDtype(real_node)) continue;
-
-    if (vars_in_multi_block_map_.count(real_node->Name()) &&
-        vars_in_multi_block_map_.at(real_node->Name()).second == block_idx) {
-      vars_in_multi_block_map_.at(real_node->Name()).first =
-          real_node->Var()->GetDataType();
-    }
-  }
-
-  if (num_low_precision)
-    LOG(INFO) << "---  detected " << num_low_precision
-              << " low precision ops in " << block_idx << " subgraph";
-}
-
-// We modify op's input output precision, and we need to fix cast op in_dtype
-// and out_dtype attribute.
-void ConvertToMixedPrecisionPass::FixCastAttr(framework::ir::Graph* graph) {
-  auto op_nodes = framework::ir::TopologySortOperations(*graph);
-  for (auto* op_node : op_nodes) {
-    if (!op_node->IsOp()) continue;
-    auto op_type = op_node->Op()->Type();
-    if (op_type != "cast") continue;
-    auto input = op_node->inputs[0];
-    auto output = op_node->outputs[0];
-    op_node->Op()->SetAttr("in_dtype",
-                           static_cast<int>(input->Var()->GetDataType()));
-    op_node->Op()->SetAttr("out_dtype",
-                           static_cast<int>(output->Var()->GetDataType()));
-  }
-}
-
 void ConvertToMixedPrecisionPass::SaveMixedModel() {
   framework::ProgramDesc mixed_program_desc;
   framework::ir::GraphToProgram(*main_graph_, &mixed_program_desc);
 
-  paddle::CPUPlace place;
   auto parameters = scope_.LocalVarNames();
   std::sort(parameters.begin(), parameters.end());
 
-  std::unordered_set<std::string> weights_should_be_fp32;
-  for (auto* node : main_graph_->Nodes()) {
-    if (!(node->IsVar())) continue;
-    if (NodeVarHasDtype(node)) {
-      if (node->Var()->Persistable() &&
-          node->Var()->GetDataType() ==
-              paddle::framework::proto::VarType::FP32) {
-        VLOG(2) << "weights keep to fp32: " << node->Name();
-        weights_should_be_fp32.insert(node->Name());
-      }
-    }
-  }
-
-#define CONVERT_TENSOR_DTYPE(DTYPE, dtype)                                   \
-  mixed_tensor.set_type(DTYPE);                                              \
-  auto* mixed_data = mixed_tensor.mutable_data<dtype>(platform::CPUPlace()); \
-  for (int i = 0; i < t->numel(); i++) {                                     \
-    mixed_data[i] = static_cast<dtype>(data[i]);                             \
-  }                                                                          \
-  t->clear();                                                                \
-  paddle::framework::TensorCopySync(mixed_tensor, place, t)
-
-  for (const auto& param_name : parameters) {
-    auto* var = scope_.FindLocalVar(param_name);
-    if (var->IsType<phi::DenseTensor>()) {
-      auto* t = var->GetMutable<phi::DenseTensor>();
-      if (t->dtype() != phi::DataType::FLOAT32) continue;
-      phi::DenseTensor mixed_tensor;
-      mixed_tensor.Resize(t->dims());
-      auto* data = t->mutable_data<float>(platform::CPUPlace());
-      if (mixed_precision_ == phi::DataType::FLOAT16 &&
-          !weights_should_be_fp32.count(param_name)) {
-        CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::FLOAT16,
-                             phi::dtype::float16);
-      } else if (mixed_precision_ == phi::DataType::BFLOAT16 &&
-                 !weights_should_be_fp32.count(param_name)) {
-        CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::BFLOAT16,
-                             phi::dtype::bfloat16);
-      }
-    }
-  }
-
-#undef CONVERT_TENSOR_DTYPE
-
   auto SerializeParams = [&]() -> std::string {
     std::ostringstream os;
     phi::CPUContext ctx;
     for (const auto& param : parameters) {
-      VLOG(3) << "Serialize param: " << param;
       PADDLE_ENFORCE_NOT_NULL(
           scope_.FindVar(param),
           platform::errors::NotFound(
              "Block should already have a '%s' variable",
              param));
-      auto* tensor = scope_.FindVar(param)->GetMutable<framework::LoDTensor>();
+      auto* tensor = scope_.FindVar(param)->GetMutable<phi::DenseTensor>();
       framework::SerializeToStream(os, *tensor, ctx);
     }
     return os.str();
@@ -831,96 +112,42 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
   StrToBinary(mixed_params_file_, SerializeParams());
 }
 
-void ConvertToMixedPrecisionPass::PatchForStrangeOp() {
-  for (auto* graph : graphes_) {
-    for (auto op_node : framework::ir::TopologySortOperations(*graph)) {
-      if (op_node->Name() == "fused_multi_transformer") {
-        auto cache_kv_inputs = op_node->Op()->Input("CacheKV");
-        auto cache_kv_outputs = op_node->Op()->Output("CacheKVOut");
-        CHECK_EQ(cache_kv_inputs.size(), cache_kv_outputs.size());
-        for (size_t i = 0; i < cache_kv_inputs.size(); ++i) {
-          op_node->Op()->RenameOutput(cache_kv_outputs[i], cache_kv_inputs[i]);
-        }
-      }
-    }
-  }
+bool OpSupportPrecision(const std::string& op_type,
+                        phi::Backend backend,
+                        phi::DataType precision,
+                        const std::unordered_set<std::string>& black_list) {
+  return framework::ir::OpSupportPrecision(
+      op_type, backend, precision, black_list);
 }
 
-}  // namespace
-
-void AddCastOp(
+void InsertCastOp(
     framework::ir::Graph* graph,
-    framework::ir::Node* node,
-    framework::ir::Node* next_op,
+    framework::ir::Node* var_node,
+    framework::ir::Node* op_node,
     framework::proto::VarType::Type from_type,
     framework::proto::VarType::Type to_type,
-    int* suffix,
     framework::BlockDesc* block_desc,
-    std::unordered_map<framework::ir::Node*, framework::ir::Node*>* map) {
-  auto update_cast_desc = [&](framework::OpDesc& desc,
-                              const std::string& x_name,
-                              const std::string& out_name,
-                              const int in_dtype,
-                              const int out_dtype) {
-    desc.SetType("cast");
-    desc.SetInput("X", {x_name});
-    desc.SetOutput("Out", {out_name});
-    desc.SetAttr("in_dtype", in_dtype);
-    desc.SetAttr("out_dtype", out_dtype);
-    desc.SetAttr("use_mkldnn", false);
-    desc.SetAttr("with_quant_attr", false);
-    desc.Flush();
-  };
-
-  if (map->count(node) == 0) {
-    // insert cast op before node.
-    std::string cast_input_name = node->Var()->Name();
-    std::string cast_output_name =
-        node->Var()->Name() + "_cast.tmp_" + std::to_string((*suffix)++);
-    CHECK_NOTNULL(block_desc);
-    framework::OpDesc cast_op_desc(block_desc);
-    update_cast_desc(cast_op_desc,
-                     cast_input_name,
-                     cast_output_name,
-                     static_cast<int>(from_type),
-                     static_cast<int>(to_type));
-    auto* cast_op_node = graph->CreateOpNode(&cast_op_desc);
-    auto* cast_output_vardesc = block_desc->Var(cast_output_name);
-    cast_output_vardesc->SetPersistable(false);
-    cast_output_vardesc->SetDataType(to_type);
-    cast_output_vardesc->SetShape(node->Var()->GetShape());
-    auto* cast_output_node = graph->CreateVarNode(cast_output_vardesc);
-    IR_NODE_LINK_TO(cast_op_node, cast_output_node);
-    (*map)[node] = cast_output_node;
-  }
-  next_op->Op()->Rename(node->Name(), map->at(node)->Name());
-  IR_NODE_LINK_TO(node, map->at(node)->inputs[0]);
-  IR_NODE_LINK_TO(map->at(node), next_op);
-}
-
-bool OpSupportPrecision(const std::string& op_type,
-                        phi::Backend backend,
-                        phi::DataType precision,
-                        const std::unordered_set<std::string>& blacklist) {
-  auto phi_op_type = phi::TransToPhiKernelName(op_type);
-  bool support_precision = false;
-  if (blacklist.count(op_type) == 0) {
-    if (backend == phi::Backend::GPU)
-      support_precision = GpuKernelSupportPrecision(op_type, precision);
-    else
-      support_precision =
-          PhiKernelSupportPrecision(phi_op_type, backend, precision);
-  }
-  return support_precision;
-}
-
-void ConvertToMixedPrecision(const std::string& model_file,
-                             const std::string& params_file,
-                             const std::string& mixed_model_file,
-                             const std::string& mixed_params_file,
-                             phi::DataType mixed_precision,
-                             phi::Backend backend,
-                             bool keep_io_types,
-                             std::unordered_set<std::string> black_list) {
+    int* suffix,
+    std::unordered_map<framework::ir::Node*, framework::ir::Node*>* visited) {
+  framework::ir::DoInsertCastOp(graph,
+                                var_node,
+                                op_node,
+                                from_type,
+                                to_type,
+                                block_desc,
+                                suffix,
+                                visited);
+}
+
+void ConvertToMixedPrecision(
+    const std::string& model_file,
+    const std::string& params_file,
+    const std::string& mixed_model_file,
+    const std::string& mixed_params_file,
+    phi::DataType mixed_precision,
+    phi::Backend backend,
+    bool keep_io_types,
+    const std::unordered_set<std::string>& black_list) {
   ConvertToMixedPrecisionPass pass(model_file,
                                    params_file,
                                    mixed_model_file,
diff --git a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h
index 3b763a4420ed0..3a1e5fbb30a21 100644
--- a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h
+++ b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h
@@ -15,14 +15,12 @@
 #pragma once
 
 #include <string>
-#include <unordered_map>
 #include <unordered_set>
 
 #include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/graph_helper.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/scope.h"
 #include "paddle/phi/common/backend.h"
 #include "paddle/phi/common/data_type.h"
 
@@ -30,20 +28,52 @@ namespace paddle {
 namespace inference {
 namespace analysis {
 
+class ConvertToMixedPrecisionPass {
+ public:
+  explicit ConvertToMixedPrecisionPass(
+      const std::string& model_file,
+      const std::string& params_file,
+      const std::string& mixed_model_file,
+      const std::string& mixed_params_file,
+      phi::DataType mixed_precision,
+      phi::Backend backend,
+      bool keep_io_types,
+      const std::unordered_set<std::string>& black_list);
+
+  void Run();
+
+ private:
+  void LoadModel();
+  void SaveMixedModel();
+
+ private:
+  std::string model_file_;
+  std::string params_file_;
+  std::string mixed_model_file_;
+  std::string mixed_params_file_;
+  phi::DataType mixed_precision_;
+  phi::Backend backend_;
+  bool keep_io_types_;
+  std::unordered_set<std::string> black_list_;
+
+  framework::Scope scope_;
+  std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
+};
+
 bool OpSupportPrecision(const std::string& op_type,
                         phi::Backend backend,
                         phi::DataType precision,
-                        const std::unordered_set<std::string>& blacklist);
+                        const std::unordered_set<std::string>& black_list);
 
-void AddCastOp(
+void InsertCastOp(
     framework::ir::Graph* graph,
-    framework::ir::Node* node,
-    framework::ir::Node* next_op,
+    framework::ir::Node* var_node,
+    framework::ir::Node* op_node,
     framework::proto::VarType::Type from_type,
     framework::proto::VarType::Type to_type,
-    int* suffix,
     framework::BlockDesc* block_desc,
-    std::unordered_map<framework::ir::Node*, framework::ir::Node*>* map);
+    int* suffix,
+    std::unordered_map<framework::ir::Node*, framework::ir::Node*>* visited);
 
 void ConvertToMixedPrecision(const std::string& model_file,
                              const std::string& params_file,
@@ -51,8 +81,8 @@ void ConvertToMixedPrecision(const std::string& model_file,
                              const std::string& mixed_params_file,
                              phi::DataType mixed_precision,
                              phi::Backend backend,
-                             bool keep_io_types = true,
-                             std::unordered_set<std::string> black_list = {});
+                             bool keep_io_types,
+                             const std::unordered_set<std::string>& black_list);
 
 }  // namespace analysis
 }  // namespace inference
diff --git a/paddle/fluid/inference/analysis/passes/inference_op_replace_pass.cc b/paddle/fluid/inference/analysis/passes/inference_op_replace_pass.cc
index ed45ec3301d1d..126d16933fd82 100644
--- a/paddle/fluid/inference/analysis/passes/inference_op_replace_pass.cc
+++ b/paddle/fluid/inference/analysis/passes/inference_op_replace_pass.cc
@@ -40,7 +40,7 @@ void InferenceOpReplacePass::RunImpl(Argument* argument) {
 }
 
 std::string InferenceOpReplacePass::repr() const {
-  return "inference-op-replace-pass";
+  return "inference_op_replace_pass";
 }
 
 }  // namespace analysis
diff --git a/paddle/fluid/inference/analysis/passes/ir_analysis_pass.cc b/paddle/fluid/inference/analysis/passes/ir_analysis_pass.cc
index 53398a69536b9..12b18ac53e368 100644
--- a/paddle/fluid/inference/analysis/passes/ir_analysis_pass.cc
+++ b/paddle/fluid/inference/analysis/passes/ir_analysis_pass.cc
@@ -105,7 +105,7 @@ void IrAnalysisPass::CollectFusionStatis(Argument* argument) {
                  framework::ir::kFuseStatisAttr));
 }
 
-std::string IrAnalysisPass::repr() const { return "ir-analysis-pass"; }
+std::string IrAnalysisPass::repr() const { return "ir_analysis_pass"; }
 
 }  // namespace analysis
 }  // namespace inference
diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc
index e07eaa64615c8..df0ffc534b71c 100644
--- a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc
+++ b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc
@@ -64,7 +64,8 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
                             "set."));
   }
 
-  auto graph = std::unique_ptr<Graph>(new Graph(argument->main_program()));
+  auto graph = std::unique_ptr<framework::ir::Graph>(
+      new framework::ir::Graph(argument->main_program()));
   argument->SetMainGraph(graph.release());
   auto *scope_ptr = argument->scope_ptr();
   PADDLE_ENFORCE_NOT_NULL(scope_ptr,
@@ -125,7 +126,7 @@ std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
   }
 }
 
-std::string IrGraphBuildPass::repr() const { return "ir-graph-build-pass"; }
+std::string IrGraphBuildPass::repr() const { return "ir_graph_build_pass"; }
 
 }  // namespace analysis
 }  // namespace inference
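For readers following the API refactor above: the offline converter is now a small class behind the ConvertToMixedPrecision free function, and keep_io_types/black_list lost their defaults. A minimal usage sketch, with hypothetical placeholder file paths (the argument order is the one declared in convert_to_mixed_precision.h):

    #include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"

    void RunOfflineFp16Conversion() {
      // Loads the fp32 model, applies AutoMixedPrecisionPass, casts eligible
      // fp32 weights, and writes the mixed-precision model back to disk.
      paddle::inference::analysis::ConvertToMixedPrecision(
          "model.pdmodel",         // model_file (placeholder path)
          "model.pdiparams",       // params_file
          "mixed.pdmodel",         // mixed_model_file
          "mixed.pdiparams",       // mixed_params_file
          phi::DataType::FLOAT16,  // BFLOAT16 is the other accepted precision
          phi::Backend::GPU,       // the pass constructor only accepts GPU
          true,                    // keep_io_types: keep feed/fetch in fp32
          {});                     // black_list: ops to keep in fp32
    }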
diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.cc b/paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.cc
deleted file mode 100644
index 6c18c62563716..0000000000000
--- a/paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.cc
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h"
-
-#include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
-#include "paddle/fluid/framework/ir/node.h"
-
-namespace paddle {
-namespace inference {
-namespace analysis {
-
-void IrInferCleanGraphPass::RunImpl(Argument* argument) {
-  auto& graph = argument->main_graph();
-  auto is_valid_node = [](framework::ir::Node* x) {
-    return x && IsControlDepVar(*x) && x->IsVar() && !x->Var();
-  };
-
-  std::unordered_set<const framework::ir::Node*> invalid_nodes;
-  int valid_op = 0;
-  for (auto* node : graph.Nodes()) {
-    PADDLE_ENFORCE_NOT_NULL(node,
-                            platform::errors::PreconditionNotMet(
-                                "The node should not be nullptr."));
-    if (is_valid_node(node)) {
-      invalid_nodes.insert(node);
-    } else if (node->IsOp()) {
-      ++valid_op;
-    }
-  }
-
-  GraphSafeRemoveNodes(&graph, invalid_nodes);
-}
-
-}  // namespace analysis
-}  // namespace inference
-}  // namespace paddle
diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h b/paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h
deleted file mode 100644
index a4d60e91e8455..0000000000000
--- a/paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h
+++ /dev/null
@@ -1,37 +0,0 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include <string>
-#include <unordered_set>
-
-#include "paddle/fluid/inference/analysis/analysis_pass.h"
-
-namespace paddle {
-namespace inference {
-namespace analysis {
-
-struct Argument;
-
-class IrInferCleanGraphPass : public AnalysisPass {
- public:
-  void RunImpl(Argument *argument) override;
-
-  std::string repr() const override { return "ir_graph_clean_pass"; }
-};
-
-}  // namespace analysis
-}  // namespace inference
-}  // namespace paddle
diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.cc b/paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.cc
index 999fb4ad8d764..3d86f7bf399a9 100644
--- a/paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.cc
+++ b/paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.cc
@@ -31,7 +31,7 @@ void IrGraphToProgramPass::RunImpl(Argument *argument) {
                new int(argument->memory_optim_sort_kind()));
   }
 
-  std::unique_ptr<Graph> graph(argument->main_graph_ptr());
+  std::unique_ptr<framework::ir::Graph> graph(argument->main_graph_ptr());
 
   // Direct using ProgramDesc desc(argument->main_program()) may cause
   // incomplete copies of information.
diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.h b/paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.h
index 5b20667d62ab6..8e90eb0e20d57 100644
--- a/paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.h
+++ b/paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.h
@@ -28,7 +28,7 @@ class IrGraphToProgramPass : public AnalysisPass {
  public:
   void RunImpl(Argument *argument) override;
 
-  std::string repr() const override { return "ir-graph-to-param-pass"; }
+  std::string repr() const override { return "ir_graph_to_param_pass"; }
 };
 
 }  // namespace analysis
diff --git a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
index 5ec9ca03fafc3..8961cbb5b6e47 100644
--- a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
+++ b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
@@ -169,7 +169,7 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
 }
 
 std::string IrParamsSyncAmongDevicesPass::repr() const {
-  return "ir-params-sync-among-devices-pass";
+  return "ir_params_sync_among_devices_pass";
 }
 
 }  // namespace analysis
diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc
index 775b61e9494ee..63aaa7d97967a 100644
--- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc
+++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.cc
@@ -295,7 +295,7 @@ void UpdateOpDescsByReuse(
   }
 }
 
-std::string MemoryOptimizePass::repr() const { return "memory optimize pass"; }
+std::string MemoryOptimizePass::repr() const { return "memory_optimize_pass"; }
 
 void MemoryOptimizePass::RunImpl(Argument* argument) {
   // Memory optimization.
diff --git a/paddle/fluid/inference/analysis/passes/passes.cc b/paddle/fluid/inference/analysis/passes/passes.cc
index 19aab1a948dd2..cd65757d08f3f 100644
--- a/paddle/fluid/inference/analysis/passes/passes.cc
+++ b/paddle/fluid/inference/analysis/passes/passes.cc
@@ -18,7 +18,6 @@
 #include "paddle/fluid/inference/analysis/passes/inference_op_replace_pass.h"
 #include "paddle/fluid/inference/analysis/passes/ir_analysis_pass.h"
 #include "paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h"
-#include "paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h"
 #include "paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.h"
 #include "paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h"
 #include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
@@ -34,8 +33,6 @@ PassRegistry::PassRegistry() {
                   std::unique_ptr<AnalysisPass>(new IrAnalysisPass));
   passes_.emplace("ir_graph_build_pass",
                   std::unique_ptr<AnalysisPass>(new IrGraphBuildPass));
-  passes_.emplace("ir_graph_clean_pass",
-                  std::unique_ptr<AnalysisPass>(new IrInferCleanGraphPass));
   passes_.emplace("memory_optimize_pass",
                   std::unique_ptr<AnalysisPass>(new MemoryOptimizePass));
   passes_.emplace(
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 58573db22af45..87c622cf50905 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -85,15 +85,29 @@ void AnalysisConfig::SetModel(const std::string &prog_file_path,
 
   Update();
 }
+
 void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
-                                  int device_id) {
+                                  int device_id,
+                                  Precision precision_mode) {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   use_gpu_ = true;
   memory_pool_init_size_mb_ = memory_pool_init_size_mb;
   FLAGS_initial_gpu_memory_in_mb = memory_pool_init_size_mb_;
   gpu_device_id_ = device_id;
+  mixed_precision_mode_ = precision_mode;
+  if (precision_mode == Precision::kFloat32) {
+    // default
+  } else if (precision_mode == Precision::kHalf ||
+             precision_mode == Precision::kBf16) {
+    enable_gpu_mixed_ = true;
+  } else {
+    LOG(ERROR)
+        << "The Paddle-GPU inference currently only supports "
+           "float32/float16/bfloat16 precision. Please check the parameters "
+           "you specified in EnableUseGpu or enable_use_gpu function.";
+  }
 #else
-  LOG(ERROR) << "Please compile with gpu to EnableGpu()";
+  LOG(ERROR) << "Please use PaddlePaddle with GPU version.";
   use_gpu_ = false;
 #endif
 
@@ -279,7 +293,7 @@ void AnalysisConfig::LoadIpuConfig(const std::string &config_path) {
 
     if (ipu_config_mapper_.find(key) == ipu_config_mapper_.end()) {
       PADDLE_THROW(platform::errors::InvalidArgument(
-          "invalid key {} in IPU config", key));
+          "invalid key %s in IPU config: ", key));
     }
     switch (ipu_config_mapper_.at(key)) {
      case ipu_config_code::ipu_device_num:
@@ -315,7 +329,7 @@ void AnalysisConfig::LoadIpuConfig(const std::string &config_path) {
 
      default:
        PADDLE_THROW(platform::errors::InvalidArgument(
-            "invalid key {} in IPU config", key));
+            "invalid key %s in IPU config", key));
        break;
    }
  }
@@ -372,8 +386,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   CP_MEMBER(gpu_device_id_);
   CP_MEMBER(memory_pool_init_size_mb_);
 
-  // Mixed related.
+  // Mixed precision related.
   CP_MEMBER(mixed_black_list_);
+  CP_MEMBER(enable_gpu_mixed_);
+  CP_MEMBER(mixed_precision_mode_);
 
   CP_MEMBER(enable_memory_optim_);
   // TensorRT related.
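The EnableUseGpu overload above is the user-visible switch for GPU mixed-precision inference. A minimal sketch of the intended call pattern (the model path is a placeholder; Precision is the existing AnalysisConfig enum):

    paddle::AnalysisConfig config;
    config.SetModel("model.pdmodel", "model.pdiparams");
    // The third argument is new: kHalf or kBf16 sets enable_gpu_mixed_,
    // while kFloat32 (the default) preserves the previous behavior.
    config.EnableUseGpu(512 /* pool MB */, 0 /* device id */,
                        paddle::AnalysisConfig::Precision::kHalf);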
@@ -740,13 +756,7 @@ void AnalysisConfig::Update() {
        ((use_custom_device() ^ pass_builder_->use_custom_device()))) {
     if (use_gpu()) {
       pass_builder_.reset(new GpuPassStrategy);
-
-      if (use_tensorrt_) {
-        // Append after the Affine_channel_conv_fuse pass.
-        pass_builder()->InsertPass(3, "tensorrt_subgraph_pass");
-      }
     } else if (use_ipu()) {
-      VLOG(1) << "IpuPassStrategy has been used for new.";
       pass_builder_.reset(new IpuPassStrategy);
     } else if (use_xpu()) {
       PADDLE_ENFORCE_EQ(
@@ -946,9 +956,6 @@ void AnalysisConfig::Update() {
             "but did not have the option -DWITH_CUSTOM_DEVICE compiled."));
 #endif
   }
-  if (ir_debug_) {
-    pass_builder()->TurnOnDebug();
-  }
 }
 
 std::string AnalysisConfig::SerializeInfoCache() {
@@ -960,6 +967,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
   ss << calibration_file_path_;
 
   ss << use_gpu_;
+  ss << enable_gpu_mixed_;
   ss << use_external_stream_;
   ss << exec_stream_;
   ss << use_fc_padding_;
@@ -1167,6 +1175,7 @@ std::string AnalysisConfig::Summary() {
   os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"});
   if (use_gpu_) {
     os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
+    os.InsertRow({"enable_gpu_mixed", std::to_string(enable_gpu_mixed_)});
     os.InsertRow({"memory_pool_init_size",
                   std::to_string(memory_pool_init_size_mb_) + "MB"});
     os.InsertRow(
@@ -1360,7 +1369,7 @@ bool AnalysisConfig::trt_allow_build_at_runtime() {
   return trt_allow_build_at_runtime_;
 }
 
-void AnalysisConfig::Exp_SetBlackListOpsForMixedModel(
+void AnalysisConfig::Exp_DisableMixedPrecisionOps(
     const std::unordered_set<std::string> &black_list) {
   mixed_black_list_ = black_list;
 }
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 8ec16dad3c19d..7bd14ca05ecdd 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -1065,7 +1065,7 @@ void AnalysisPredictor::PrepareArgument() {
   argument_.SetUseGPU(config_.use_gpu());
   argument_.SetUseFcPadding(config_.use_fc_padding());
   argument_.SetGPUDeviceId(config_.gpu_device_id());
-  argument_.SetEnableAnalysisOptim(config_.enable_ir_optim_);
+  argument_.SetEnableIrOptim(config_.enable_ir_optim_);
   argument_.SetEnableMemoryOptim(config_.enable_memory_optim());
   argument_.SetModelFromMemory(config_.model_from_memory_);
   // Analyze inference_program
@@ -1210,53 +1210,57 @@ void AnalysisPredictor::PrepareArgument() {
   }
 #endif
 
-  auto passes = config_.pass_builder()->AllPasses();
+  auto *pass_builder = config_.pass_builder();
   if (model_precision_ != phi::DataType::FLOAT32) {
     LOG(INFO) << "Model is mixed precision type with " << model_precision_
               << ", we will use a new PassStrategy. Note that only the GPU "
                  "backend is supported for now.";
-    passes.clear();
+    pass_builder->ClearPasses();
+    const auto &deleted_passes = pass_builder->GetAllDeletedPasses();
     if (config_.tensorrt_engine_enabled()) {
       for (const auto &pass : kTrtLowerPrecisionPasses) {
-        passes.push_back(pass);
+        if (deleted_passes.count(pass)) continue;
+        pass_builder->AppendPass(pass);
       }
     } else if (config_.use_gpu()) {
       for (const auto &pass : kGpuLowerPrecisionPasses) {
-        passes.push_back(pass);
+        if (deleted_passes.count(pass)) continue;
+        pass_builder->AppendPass(pass);
       }
     }
+  }
 
-    const auto &deleted_passes = config_.pass_builder()->GetAllDeletedPasses();
-    for (const auto &it : deleted_passes) {
-      auto iterator = std::find(passes.begin(), passes.end(), it);
-      if (iterator != passes.end()) {
-        passes.erase(iterator);
-      }
+  if (!config_.ir_optim()) {
+    argument_.SetEnableIrOptim(false);
+    if (config_.enable_gpu_mixed_) {
+      argument_.SetEnableIrOptim(true);
+      pass_builder->ClearPasses();
+      pass_builder->AppendPass("auto_mixed_precision_pass");
+      LOG(INFO)
+          << "This model run in Paddle-GPU mixed precision mode with no ir "
+             "optimization.";
+    } else {
+      LOG(INFO) << "ir_optim is turned off, no IR pass will be executed.";
     }
-
+  } else {
     if (config_.ir_debug_) {
-      auto it = std::begin(passes);
-      while (it != std::end(passes)) {
-        if (*it != "graph_viz_pass") {
-          it = passes.insert(it + 1, "graph_viz_pass");
-        } else {
-          ++it;
-        }
-      }
+      pass_builder->TurnOnDebug();
+    }
+    if (config_.enable_gpu_mixed_) {
+      LOG(INFO) << "This model run in Paddle-GPU mixed precision mode.";
     }
-  }
-
-  if (!config_.ir_optim()) {
-    passes.clear();
-    LOG(INFO) << "ir_optim is turned off, no IR pass will be executed";
   }
 
   argument_.SetDisableLogs(config_.glog_info_disabled());
-  argument_.SetIrAnalysisPasses(passes);
-  argument_.SetAnalysisPasses(config_.pass_builder()->AnalysisPasses());
+  argument_.SetIrAnalysisPasses(pass_builder->AllPasses());
+  argument_.SetAnalysisPasses(pass_builder->AnalysisPasses());
   argument_.SetScopeNotOwned(scope_.get());
 
   // mixed precison.
   argument_.SetModelPrecision(static_cast<int>(model_precision_));
   argument_.SetMixedBlackList(config_.mixed_black_list_);
+  argument_.SetEnableGPUMixed(config_.enable_gpu_mixed_);
+  argument_.SetMixedPrecisionMode(static_cast<int>(
+      paddle::ConvertPrecision(config_.mixed_precision_mode_)));
 }
 
 // NOTE All the members in AnalysisConfig should be copied to Argument.
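As the PrepareArgument rewrite above shows, disabling IR optimization no longer disables mixed precision: when enable_gpu_mixed_ is set, the pass list collapses to auto_mixed_precision_pass alone. A hedged sketch of that configuration path (placeholder model paths):

    paddle::AnalysisConfig config;
    config.SetModel("model.pdmodel", "model.pdiparams");
    config.EnableUseGpu(512, 0, paddle::AnalysisConfig::Precision::kHalf);
    // With ir_optim off and GPU mixed precision on, PrepareArgument re-enables
    // IR analysis with exactly one pass: "auto_mixed_precision_pass".
    config.SwitchIrOptim(false);
    auto predictor = paddle::CreatePaddlePredictor(config);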
@@ -2107,7 +2111,9 @@ std::unique_ptr<PaddlePredictor> AnalysisPredictor::Clone(void *stream) {
   }
   x->predictor_stream_ = stream;
   x->Init(scope_, inference_program_);
+#ifdef PADDLE_WITH_TENSORRT
   x->executor_->ResetTrtOps(++AnalysisPredictor::clone_num_);
+#endif
   return std::unique_ptr<PaddlePredictor>(x);
 }
 
diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.cc b/paddle/fluid/inference/api/mkldnn_quantizer.cc
index bca2cde0fc2c6..293236b111630 100755
--- a/paddle/fluid/inference/api/mkldnn_quantizer.cc
+++ b/paddle/fluid/inference/api/mkldnn_quantizer.cc
@@ -604,10 +604,8 @@ void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
   if (predictor_.config_.ir_debug_) builder->TurnOnDebug();
   auto passes = builder->AllPasses();
   predictor_.argument_.SetIrAnalysisPasses(passes);
-  predictor_.argument_.SetAnalysisPasses({"ir_graph_clean_pass",
-                                          "ir_analysis_pass",
-                                          "memory_optimize_pass",
-                                          "ir_graph_to_program_pass"});
+  predictor_.argument_.SetAnalysisPasses(
+      {"ir_analysis_pass", "memory_optimize_pass", "ir_graph_to_program_pass"});
   predictor_.argument_.SetQuantVarScales(scales_);
 }
 
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index 4a71d0966256e..7dfc8d1df41de 100755
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -247,8 +247,12 @@ struct PD_INFER_DECL AnalysisConfig {
   ///
   /// \param memory_pool_init_size_mb initial size of the GPU memory pool in MB.
   /// \param device_id device_id the GPU card to use (default is 0).
+  /// \param precision the precision used in Paddle-GPU inference.
   ///
-  void EnableUseGpu(uint64_t memory_pool_init_size_mb, int device_id = 0);
+  void EnableUseGpu(uint64_t memory_pool_init_size_mb,
+                    int device_id = 0,
+                    Precision precision_mode = Precision::kFloat32);
+
   ///
   /// \brief Turn off GPU.
   ///
@@ -967,7 +971,7 @@ struct PD_INFER_DECL AnalysisConfig {
   /// interface is in the experimental stage and may change in the future. Note
   /// that the blacklist must be the same as the model conversion blacklist.
   ///
-  void Exp_SetBlackListOpsForMixedModel(
+  void Exp_DisableMixedPrecisionOps(
       const std::unordered_set<std::string>& black_list);
 
   void SetApplyOptim(bool value) { apply_optim_ = value; }
@@ -987,13 +991,15 @@ struct PD_INFER_DECL AnalysisConfig {
   mutable std::string params_file_;
   mutable std::string calibration_file_path_;
 
-  // Mixed precision.
+  // Mixed precision related.
+  Precision mixed_precision_mode_{Precision::kFloat32};
   std::unordered_set<std::string> mixed_black_list_;
 
   // GPU related.
   bool use_gpu_{false};
   int gpu_device_id_{0};
   uint64_t memory_pool_init_size_mb_{100};  // initial size is 100MB.
+  bool enable_gpu_mixed_{false};
   bool thread_local_stream_{false};
 
   bool use_cudnn_{false};
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 30cf4273e11ba..0a478a2d2c8ae 100755
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -227,9 +227,10 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
         "conv_elementwise_add_fuse_pass",      //
 #endif                                         //
         "transpose_flatten_concat_fuse_pass",  //
-        "constant_folding_pass",
+        "constant_folding_pass",               //
         // following pass should be located in the last, since it will
         // work on all fused ops.
+        "auto_mixed_precision_pass",           //
         "runtime_context_cache_pass"
   });
 
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h
index cd97382785395..c8083e87dd8f0 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.h
+++ b/paddle/fluid/inference/api/paddle_pass_builder.h
@@ -115,7 +115,6 @@ class PD_INFER_DECL PaddlePassBuilder {
   /// \cond Protected
   std::vector<std::string> analysis_passes_{
       {"ir_graph_build_pass",
-       "ir_graph_clean_pass",
        "ir_analysis_pass",
        "ir_params_sync_among_devices_pass",
        "adjust_cudnn_workspace_size_pass",
diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h
index ffbd593d304df..17d8bb35b29d0 100644
--- a/paddle/fluid/inference/tensorrt/engine.h
+++ b/paddle/fluid/inference/tensorrt/engine.h
@@ -294,15 +294,6 @@ class TensorRTEngine {
   nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); }
 
   nvinfer1::IExecutionContext* context() {
-#ifndef PADDLE_WITH_TESTING
-    PADDLE_ENFORCE_GT(
-        predictor_id_per_thread,
-        -1,
-        platform::errors::InvalidArgument(
-            "thread local var predictor_id_per_thread must be "
-            "initialized to >= 0, but now predictor_id_per_thread = %d",
-            predictor_id_per_thread));
-#endif
     std::unique_lock<std::mutex> lock(mutex_);
     if (infer_context_.find(predictor_id_per_thread) == infer_context_.end()) {
       PADDLE_ENFORCE_NOT_NULL(
@@ -329,15 +320,6 @@ class TensorRTEngine {
 
   int GetProfileIndex() {
     if (max_profile_num_ > 1) {
-#ifndef PADDLE_WITH_TESTING
-      PADDLE_ENFORCE_GT(
-          predictor_id_per_thread,
-          -1,
-          platform::errors::InvalidArgument(
-              "thread local var predictor_id_per_thread must be "
-              "initialized to >= 0, but now predictor_id_per_thread = %d",
-              predictor_id_per_thread));
-#endif
       std::unique_lock<std::mutex> lock(mutex_);
       return profile_index_[predictor_id_per_thread];
     } else {
@@ -356,15 +338,6 @@ class TensorRTEngine {
         infer_engine_,
         platform::errors::InvalidArgument(
             "You should build engine first and then set the context."));
-#ifndef PADDLE_WITH_TESTING
-    PADDLE_ENFORCE_GT(
-        predictor_id_per_thread,
-        -1,
-        platform::errors::InvalidArgument(
-            "thread local var predictor_id_per_thread must be "
-            "initialized to >= 0, but now predictor_id_per_thread = %d",
-            predictor_id_per_thread));
-#endif
     std::unique_lock<std::mutex> lock(mutex_);
     infer_context_[predictor_id_per_thread].reset(nullptr);
     infer_context_.erase(predictor_id_per_thread);
diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index 77831167ddd5d..f8650ef366e15 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -416,6 +416,9 @@ download_result(${ERNIE_INSTALL_DIR} "Ernie_result.txt.tar.gz"
 if(WITH_GPU)
   inference_analysis_api_test(test_analyzer_ernie ${ERNIE_INSTALL_DIR}
                               analyzer_ernie_tester.cc)
+  inference_analysis_api_test(gpu_ernie_half_test ${ERNIE_INSTALL_DIR}
+                              gpu_ernie_half_test.cc)
+  set_tests_properties(gpu_ernie_half_test PROPERTIES TIMEOUT 60)
 endif()
 inference_analysis_api_int8_test(test_analyzer_ernie_int8 ${ERNIE_INSTALL_DIR}
                                  analyzer_ernie_int8_tester.cc)
diff --git a/paddle/fluid/inference/tests/api/gpu_ernie_half_test.cc b/paddle/fluid/inference/tests/api/gpu_ernie_half_test.cc
new file mode 100644
index 0000000000000..6354ee47a18f6
--- /dev/null
+++ b/paddle/fluid/inference/tests/api/gpu_ernie_half_test.cc
@@ -0,0 +1,294 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/inference/api/paddle_inference_api.h"
+#include "paddle/fluid/inference/tests/api/tester_helper.h"
+
+namespace paddle {
+namespace inference {
+
+using paddle::PaddleTensor;
+
+template <typename T>
+void GetValueFromStream(std::stringstream *ss, T *t) {
+  (*ss) >> (*t);
+}
+
+template <>
+void GetValueFromStream<std::string>(std::stringstream *ss, std::string *t) {
+  *t = ss->str();
+}
+
+// Split string to vector
+template <typename T>
+void Split(const std::string &line, char sep, std::vector<T> *v) {
+  std::stringstream ss;
+  T t;
+  for (auto c : line) {
+    if (c != sep) {
+      ss << c;
+    } else {
+      GetValueFromStream<T>(&ss, &t);
+      v->push_back(std::move(t));
+      ss.str({});
+      ss.clear();
+    }
+  }
+
+  if (!ss.str().empty()) {
+    GetValueFromStream<T>(&ss, &t);
+    v->push_back(std::move(t));
+    ss.str({});
+    ss.clear();
+  }
+}
+
+// Parse tensor from string
+template <typename T>
+bool ParseTensor(const std::string &field, paddle::PaddleTensor *tensor) {
+  std::vector<std::string> data;
+  Split(field, ':', &data);
+  if (data.size() < 2) return false;
+
+  std::string shape_str = data[0];
+
+  std::vector<int> shape;
+  Split(shape_str, ' ', &shape);
+
+  std::string mat_str = data[1];
+
+  std::vector<T> mat;
+  Split(mat_str, ' ', &mat);
+
+  tensor->shape = shape;
+  auto size =
+      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
+      sizeof(T);
+  tensor->data.Resize(size);
+  std::copy(mat.begin(), mat.end(), static_cast<T *>(tensor->data.data()));
+  tensor->dtype = GetPaddleDType<T>();
+
+  return true;
+}
+
+// Parse input tensors from string
+bool ParseLine(const std::string &line,
+               std::vector<paddle::PaddleTensor> *tensors) {
+  std::vector<std::string> fields;
+  Split(line, ';', &fields);
+
+  tensors->clear();
+  tensors->reserve(4);
+
+  int i = 0;
+  auto input_name = FLAGS_ernie_large ? "eval_placeholder_" : "placeholder_";
+  for (; i < 3; i++) {
+    paddle::PaddleTensor temp;
+    ParseTensor<int64_t>(fields[i], &temp);
+    temp.name = input_name + std::to_string(i);
+    tensors->push_back(temp);
+  }
+
+  // input_mask
+  paddle::PaddleTensor input_mask;
+  ParseTensor<float>(fields[i], &input_mask);
+  input_mask.name = input_name + std::to_string(i);
+  tensors->push_back(input_mask);
+
+  return true;
+}
+
+bool LoadInputData(std::vector<std::vector<paddle::PaddleTensor>> *inputs,
+                   int batch_size = 1) {
+  if (FLAGS_infer_data.empty()) {
+    LOG(ERROR) << "please set input data path";
+    return false;
+  }
+
+  std::ifstream fin(FLAGS_infer_data);
+  std::string line;
+  int sample = 0;
+
+  // The unit-test dataset only have 10 samples, each sample have 5 feeds.
+  while (std::getline(fin, line)) {
+    std::vector<paddle::PaddleTensor> feed_data;
+    ParseLine(line, &feed_data);
+    inputs->push_back(std::move(feed_data));
+    sample++;
+    if (!FLAGS_test_all_data && sample == batch_size) break;
+  }
+  LOG(INFO) << "number of samples: " << sample;
+  return true;
+}
+
+// Compare results
+TEST(Ernie_gpu_fp16_no_ir, compare_results) {
+  AnalysisConfig config;
+  config.SetModel(FLAGS_infer_model);
+  config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kHalf);
+  config.SwitchIrOptim(false);
+
+  auto predictor = CreatePaddlePredictor(config);
+
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  LoadInputData(&input_slots_all);
+
+  std::ifstream fin(FLAGS_refer_result);
+  std::string line;
+  std::vector<float> ref;
+
+  while (std::getline(fin, line)) {
+    Split(line, ' ', &ref);
+  }
+
+  std::vector<PaddleTensor> outputs;
+  for (size_t i = 0; i < input_slots_all.size(); i++) {
+    outputs.clear();
+    predictor->Run(input_slots_all[i], &outputs);
+
+    auto output = outputs.front();
+    size_t outputs_size = 1;
+    for (auto dim : output.shape) {
+      outputs_size *= dim;
+    }
+    float *result = reinterpret_cast<float *>(output.data.data());
+    for (size_t j = 0; j < outputs_size; ++j) {
+      EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-2);
+    }
+  }
+}
+
+// Compare results
+TEST(Ernie_gpu_fp16_with_ir, compare_results) {
+  AnalysisConfig config;
+  config.SetModel(FLAGS_infer_model);
+  config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kHalf);
+  config.SwitchIrOptim(true);
+  // The fc_fuse_pass has diff, which will be repaired later.
+  config.pass_builder()->DeletePass("fc_fuse_pass");
+  // There is a problem with the model itself, which has nothing to do with
+  // constant_folding_pass.
+  config.pass_builder()->DeletePass("constant_folding_pass");
+
+  auto predictor = CreatePaddlePredictor(config);
+
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  LoadInputData(&input_slots_all);
+
+  std::ifstream fin(FLAGS_refer_result);
+  std::string line;
+  std::vector<float> ref;
+
+  while (std::getline(fin, line)) {
+    Split(line, ' ', &ref);
+  }
+
+  std::vector<PaddleTensor> outputs;
+  for (size_t i = 0; i < input_slots_all.size(); i++) {
+    outputs.clear();
+    predictor->Run(input_slots_all[i], &outputs);
+
+    auto output = outputs.front();
+    size_t outputs_size = 1;
+    for (auto dim : output.shape) {
+      outputs_size *= dim;
+    }
+    float *result = reinterpret_cast<float *>(output.data.data());
+    for (size_t j = 0; j < outputs_size; ++j) {
+      EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-2);
+    }
+  }
+}
+
+// Compare results
+TEST(Ernie_gpu_bf16_no_ir, compare_results) {
+  AnalysisConfig config;
+  config.SetModel(FLAGS_infer_model);
+  config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kBf16);
+  config.SwitchIrOptim(false);
+
+  auto predictor = CreatePaddlePredictor(config);
+
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  LoadInputData(&input_slots_all);
+
+  std::ifstream fin(FLAGS_refer_result);
+  std::string line;
+  std::vector<float> ref;
+
+  while (std::getline(fin, line)) {
+    Split(line, ' ', &ref);
+  }
+
+  std::vector<PaddleTensor> outputs;
+  for (size_t i = 0; i < input_slots_all.size(); i++) {
+    outputs.clear();
+    predictor->Run(input_slots_all[i], &outputs);
+
+    auto output = outputs.front();
+    size_t outputs_size = 1;
+    for (auto dim : output.shape) {
+      outputs_size *= dim;
+    }
+    float *result = reinterpret_cast<float *>(output.data.data());
+    for (size_t j = 0; j < outputs_size; ++j) {
+      EXPECT_NEAR(ref[i * outputs_size + j], result[j], 7e-2);
+    }
+  }
+}
+
+// Compare results
+TEST(Ernie_gpu_bf16_with_ir, compare_results) {
+  AnalysisConfig config;
+  config.SetModel(FLAGS_infer_model);
+  config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kBf16);
+  config.SwitchIrOptim(true);
+  // The fc_fuse_pass has diff, which will be repaired later.
+  config.pass_builder()->DeletePass("fc_fuse_pass");
+  // There is a problem with the model itself, which has nothing to do with
+  // constant_folding_pass.
+  config.pass_builder()->DeletePass("constant_folding_pass");
+
+  auto predictor = CreatePaddlePredictor(config);
+
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  LoadInputData(&input_slots_all);
+
+  std::ifstream fin(FLAGS_refer_result);
+  std::string line;
+  std::vector<float> ref;
+
+  while (std::getline(fin, line)) {
+    Split(line, ' ', &ref);
+  }
+
+  std::vector<PaddleTensor> outputs;
+  for (size_t i = 0; i < input_slots_all.size(); i++) {
+    outputs.clear();
+    predictor->Run(input_slots_all[i], &outputs);
+
+    auto output = outputs.front();
+    size_t outputs_size = 1;
+    for (auto dim : output.shape) {
+      outputs_size *= dim;
+    }
+    float *result = reinterpret_cast<float *>(output.data.data());
+    for (size_t j = 0; j < outputs_size; ++j) {
+      EXPECT_NEAR(ref[i * outputs_size + j], result[j], 7e-2);
+    }
+  }
+}
+
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tests/api/paddle_infer_api_test.cc b/paddle/fluid/inference/tests/api/paddle_infer_api_test.cc
index 8cff649b97092..9029cefc9a424 100644
--- a/paddle/fluid/inference/tests/api/paddle_infer_api_test.cc
+++ b/paddle/fluid/inference/tests/api/paddle_infer_api_test.cc
@@ -1,4 +1,4 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -12,15 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include
 #include
 #include
-#include
-#include
-
 #include "gflags/gflags.h"
-#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
+#include "paddle/fluid/inference/tests/api/tester_helper.h"
 
 namespace paddle_infer {
diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt
index a77659ba99d47..0a59caae2bbe8 100755
--- a/paddle/fluid/pybind/CMakeLists.txt
+++ b/paddle/fluid/pybind/CMakeLists.txt
@@ -262,10 +262,6 @@ if(WITH_PYTHON)
     list(APPEND OP_FUNCTION_GENERETOR_DEPS cncl_context)
   endif()
 
-  if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
-    list(APPEND OP_FUNCTION_GENERETOR_DEPS ${PYTHON_LIBRARIES})
-  endif()
-
   add_executable(op_function_generator op_function_generator.cc)
   target_link_libraries(op_function_generator ${OP_FUNCTION_GENERETOR_DEPS})
   add_executable(eager_legacy_op_function_generator
@@ -605,13 +601,4 @@ if(WITH_PYTHON)
     get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
     target_link_libraries(${SHARD_LIB_NAME} ${os_dependency_modules})
     add_dependencies(${SHARD_LIB_NAME} op_function_generator_cmd)
-
-    if(APPLE)
-      string(REGEX REPLACE ".+/(.+)" "\\1" PYTHON_LIBRARY_NAME
-                           ${PYTHON_LIBRARIES})
-      # target_link_libraries(${SHARD_LIB_NAME} "-Wl,-rpath,${PYTHON_LIBRARY_NAME}")
-    else()
-      target_link_libraries(${SHARD_LIB_NAME} ${PYTHON_LIBRARIES})
-    endif()
-
   endif()
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index 75cdf86aaf211..d7b28aab2301a 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -642,7 +642,8 @@ void BindAnalysisConfig(py::module *m) {
       .def("enable_use_gpu",
           &AnalysisConfig::EnableUseGpu,
           py::arg("memory_pool_init_size_mb"),
-          py::arg("device_id") = 0)
+          py::arg("device_id") = 0,
+          py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32)
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
       .def("set_exec_stream",
           [](AnalysisConfig &self, phi::CUDAStream &stream) {
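With the binding change above, the same switch becomes reachable from Python: enable_use_gpu now accepts an optional precision_mode keyword defaulting to float32, mirroring the C++ EnableUseGpu overload (for example, config.enable_use_gpu(512, 0, PrecisionType.Half), assuming the existing PrecisionType alias exported by the inference module).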