Commit 1248671
Add int8 support in fused_multi_transformer_pass and fuse_multi_transformer_layer_pass (#48209)

* delete unnecessary shape and slice op

Co-authored-by: Your Name <you@example.com>
RichardWooSJTU and Your Name committed Nov 30, 2022
1 parent 9ff99e9 commit 1248671
Showing 26 changed files with 2,115 additions and 295 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -96,6 +96,8 @@ pass_library(shuffle_channel_detect_pass inference)
 pass_library(delete_quant_dequant_op_pass inference)
 pass_library(delete_quant_dequant_filter_op_pass inference)
 pass_library(delete_weight_dequant_linear_op_pass inference)
+pass_library(delete_weight_dequant_linear_op_encoder_pass inference)
+pass_library(delete_weight_dequant_linear_op_decoder_pass inference)
 pass_library(delete_quant_dequant_linear_op_pass inference)
 pass_library(delete_dropout_op_pass inference)
 pass_library(delete_c_identity_op_pass inference)
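The two passes registered above are, judging by their names, encoder- and decoder-specific variants of the existing delete_weight_dequant_linear_op_pass. As a usage note (not part of this commit): a pass registered with pass_library() can typically be enabled by name through the Paddle Inference C++ API. A minimal sketch, assuming the standard paddle_infer API; the model paths are hypothetical placeholders:

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  // Hypothetical model files; substitute your own exported program/params.
  config.SetModel("model.pdmodel", "model.pdiparams");
  // Passes declared with pass_library() are looked up by name; appending one
  // to the pass builder inserts it into the analysis pipeline.
  config.pass_builder()->AppendPass(
      "delete_weight_dequant_linear_op_encoder_pass");
  auto predictor = paddle_infer::CreatePredictor(config);
  return predictor != nullptr ? 0 : 1;
}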
17 changes: 15 additions & 2 deletions paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc
@@ -121,14 +121,27 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
         true,
         platform::errors::InvalidArgument(
             "Input scale tensor's place should be CPU."));
-    const float* input_scale_data = input_scale_tensor.data<float>();
-    float input_scale = input_scale_data[0];
+
+    float input_scale;
+    if (input_scale_tensor.dtype() == paddle::experimental::DataType::FLOAT32) {
+      const float* input_scale_data = input_scale_tensor.data<float>();
+      input_scale = input_scale_data[0];
+    } else if (input_scale_tensor.dtype() ==
+               paddle::experimental::DataType::FLOAT16) {
+      const phi::dtype::float16* input_scale_data =
+          input_scale_tensor.data<phi::dtype::float16>();
+      input_scale = static_cast<float>(input_scale_data[0]);
+    } else {
+      PADDLE_THROW(platform::errors::Unimplemented("%d is not supported.",
+                                                   input_scale_tensor.dtype()));
+    }
+
     int nums_any_ops = dequantize_linear_op_out->outputs.size();
     for (int i = 0; i < nums_any_ops; ++i) {
       auto* any_op_desc = dequantize_linear_op_out->outputs[i]->Op();
       any_op_desc->SetAttr("Input_scale_" + quantize_linear_op_x->Var()->Name(),
                            input_scale);
 
       // link x to any_op2
       any_op_desc->RenameInput(dequantize_linear_op_out->Var()->Name(),
                                quantize_linear_op_x->Var()->Name());
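The substantive change in this hunk is a dtype dispatch: the quantize_linear scale tensor may now hold float16 as well as float32 data, and both branches normalize the scalar to a plain float before it is recorded as an Input_scale_* attribute on the consuming ops. A self-contained sketch of the same widening pattern, using illustrative stand-ins (Tensor, DType, fp16_to_float) rather than Paddle's real types:

#include <cstdint>
#include <cstring>
#include <stdexcept>

// Illustrative stand-ins for phi::DenseTensor and its dtype tag.
enum class DType { FLOAT32, FLOAT16 };
struct Tensor {
  DType dtype;
  const void* data;  // element 0 holds the per-tensor quantization scale
};

// Stand-in for static_cast<float>(phi::dtype::float16): widen an IEEE-754
// binary16 bit pattern to float (subnormals flushed to zero for brevity).
float fp16_to_float(uint16_t h) {
  const uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;
  const uint32_t exp = (h >> 10) & 0x1Fu;
  const uint32_t mant = h & 0x3FFu;
  uint32_t bits;
  if (exp == 0) {
    bits = sign;                                        // zero / subnormal
  } else if (exp == 31) {
    bits = sign | 0x7F800000u | (mant << 13);           // inf / NaN
  } else {
    bits = sign | ((exp + 112u) << 23) | (mant << 13);  // rebias 15 -> 127
  }
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Mirrors the pass's new branch: read the scalar scale, widening FP16 to
// float, and reject any other dtype (the PADDLE_THROW case above).
float ReadInputScale(const Tensor& t) {
  switch (t.dtype) {
    case DType::FLOAT32:
      return *static_cast<const float*>(t.data);
    case DType::FLOAT16:
      return fp16_to_float(*static_cast<const uint16_t*>(t.data));
  }
  throw std::runtime_error("dtype is not supported.");
}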
