diff --git a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc index bc034301989b0..44b41a89700de 100644 --- a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc +++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc @@ -437,6 +437,20 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const { vars_should_not_low_precision.insert(in_var_node->Var()->Name()); } } + + // when op_1 only support cpu kernel. if op_2's intput var is op_1's + // output var, then op_2 should not run half. + if (GetOpOriginalType(op_type) != "feed" && + !GpuKernelSupportPrecision(GetOpOriginalType(op_type), + phi::DataType::FLOAT32)) { + for (auto* out_var_node : op_node->outputs) { + CHECK_EQ(out_var_node->IsVar(), true); + if (out_var_node->Var()->Persistable()) continue; + if (!VarNodeHasDtype(out_var_node)) continue; + + vars_should_not_low_precision.insert(out_var_node->Var()->Name()); + } + } } } }; @@ -449,6 +463,25 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const { for (auto* op_node : nodes) { if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) continue; + for (auto* in_var_node : op_node->inputs) { + CHECK_EQ(in_var_node->IsVar(), true); + if (!VarNodeHasDtype(in_var_node)) continue; + + auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()]; + if (real_in_var_node->Var()->Persistable()) continue; + + if (vars_should_not_low_precision.count( + real_in_var_node->Var()->Name())) { + op_run_low_precision_.erase(op_node->Op()->Type()); + precision_updated = true; + VLOG(4) << op_node->Op()->Type() + << " should not run at low precision."; + break; + } + } + + if (op_run_low_precision_.count(op_node->Op()->Type()) == 0) continue; + for (auto* out_var_node : op_node->outputs) { CHECK_EQ(out_var_node->IsVar(), true); if (!VarNodeHasDtype(out_var_node)) continue;