From 6f1a2ea6efc47b64a832d323642cacec5d2d800f Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Fri, 7 May 2021 12:27:33 -0700 Subject: [PATCH] Replaces all the use of the argument before it is replaced The bug will be manifested only when the model has non-quantized inputs. If the input can be quantized, the input should have only one user and the user must be the quantize op. PiperOrigin-RevId: 372606255 Change-Id: I946da53cc431d8f873238f9454983d8a9c9393b6 --- .../mlir/lite/tests/modify_io_nodes.mlir | 52 +++++++++++++++++-- .../mlir/lite/transforms/modify_io_nodes.cc | 4 ++ 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/tests/modify_io_nodes.mlir b/tensorflow/compiler/mlir/lite/tests/modify_io_nodes.mlir index 32713012ad4ef6..144ab70baca6e4 100644 --- a/tensorflow/compiler/mlir/lite/tests/modify_io_nodes.mlir +++ b/tensorflow/compiler/mlir/lite/tests/modify_io_nodes.mlir @@ -2,7 +2,7 @@ // RUN: tf-opt %s -tfl-modify-io-nodes -tfl-test-io-types="int8,int8" | FileCheck --check-prefix=INT8 %s // RUN: tf-opt %s -tfl-modify-io-nodes -tfl-test-io-types="uint8,uint8" | FileCheck --check-prefix=UINT8 %s -func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x401408xf32> { +func @modified(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x401408xf32> { %cst = constant dense<[1, 401408]> : tensor<2xi32> %0 = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform> %1 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>> @@ -13,7 +13,7 @@ func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x401408xf32> { %6 = "tfl.dequantize"(%5) : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408xf32> return %6 : tensor<1x401408xf32> -// CHECK-LABEL: func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x401408xf32> { +// 
CHECK-LABEL: func @modified(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x401408xf32> { // CHECK-NEXT: %[[shape:.*]] = constant dense<[1, 401408]> : tensor<2xi32> // CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform> // CHECK-NEXT: %[[cst1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>> @@ -24,7 +24,7 @@ func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x401408xf32> { // CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[softmax]]) : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408xf32> // CHECK-NEXT: return %[[dq]] : tensor<1x401408xf32> -// INT8-LABEL: @main(%arg0: tensor<1x224x224x3x!quant.uniform>) -> tensor<1x401408x!quant.uniform> { +// INT8-LABEL: @modified(%arg0: tensor<1x224x224x3x!quant.uniform>) -> tensor<1x401408x!quant.uniform> { // INT8-NEXT: %[[shape:.*]] = constant dense<[1, 401408]> : tensor<2xi32> // INT8-NEXT: %[[cst1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>> // INT8-NEXT: %[[cst2:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform> @@ -33,7 +33,7 @@ func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x401408xf32> { // INT8-NEXT: %[[softmax:.*]] = "tfl.softmax"(%[[reshape]]) {beta = 1.000000e+00 : f32} : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> // INT8-NEXT: return %[[softmax]] : tensor<1x401408x!quant.uniform> -// UINT8-LABEL: func @main(%arg0: tensor<1x224x224x3x!quant.uniform>) -> tensor<1x401408x!quant.uniform> { +// UINT8-LABEL: func @modified(%arg0: tensor<1x224x224x3x!quant.uniform>) -> 
tensor<1x401408x!quant.uniform> { // UINT8-NEXT: %[[shape:.*]] = constant dense<[1, 401408]> : tensor<2xi32> // UINT8-NEXT: %[[q:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<1x224x224x3x!quant.uniform>} : (tensor<1x224x224x3x!quant.uniform>) -> tensor<1x224x224x3x!quant.uniform> // UINT8-NEXT: %[[cst1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>> @@ -44,3 +44,47 @@ func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x401408xf32> { // UINT8-NEXT: %[[dq:.*]] = "tfl.quantize"(%[[softmax]]) {qtype = tensor<1x401408x!quant.uniform>} : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> // UINT8-NEXT: return %[[dq]] : tensor<1x401408x!quant.uniform> } + +func @not_modified(%arg0: tensor, %arg1: tensor<1x224x224x3xf32>) -> (tensor<1x401408xf32>, tensor<1x224x224x3xf32>) { + %cst = constant dense<[1, 401408]> : tensor<2xi32> + %0 = "tfl.quantize"(%arg1) {qtype = tensor<1x224x224x3x!quant.uniform>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform> + %1 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>> + %2 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform> + %3 = "tfl.conv_2d"(%0, %1, %2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3x!quant.uniform>, tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>>, tensor<32x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> + %4 = "tfl.reshape"(%3, %cst) : (tensor<1x112x112x32x!quant.uniform>, tensor<2xi32>) -> tensor<1x401408x!quant.uniform> + %5 = "tfl.softmax"(%4) 
{beta = 1.000000e+00 : f32} : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> + %6 = "tfl.dequantize"(%5) : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408xf32> + return %6, %arg1 : tensor<1x401408xf32>, tensor<1x224x224x3xf32> + +// CHECK-LABEL: func @not_modified(%arg0: tensor, %arg1: tensor<1x224x224x3xf32>) -> (tensor<1x401408xf32>, tensor<1x224x224x3xf32>) { +// CHECK-NEXT: %[[shape:.*]] = constant dense<[1, 401408]> : tensor<2xi32> +// CHECK-NEXT: %[[q:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x224x224x3x!quant.uniform>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform> +// CHECK-NEXT: %[[cst1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>> +// CHECK-NEXT: %[[cst2:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform> +// CHECK-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%[[q]], %[[cst1]], %[[cst2]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3x!quant.uniform>, tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>>, tensor<32x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> +// CHECK-NEXT: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x112x112x32x!quant.uniform>, tensor<2xi32>) -> tensor<1x401408x!quant.uniform> +// CHECK-NEXT: %[[softmax:.*]] = "tfl.softmax"(%[[reshape]]) {beta = 1.000000e+00 : f32} : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> +// CHECK-NEXT: %[[dq:.*]] = "tfl.dequantize"(%[[softmax]]) : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408xf32> +// CHECK-NEXT: return %[[dq]], %arg1 : tensor<1x401408xf32>, tensor<1x224x224x3xf32> + +// INT8-LABEL: @not_modified(%arg0: tensor, %arg1: 
tensor<1x224x224x3xf32>) -> (tensor<1x401408x!quant.uniform>, tensor<1x224x224x3xf32>) { +// INT8-NEXT: %[[shape:.*]] = constant dense<[1, 401408]> : tensor<2xi32> +// INT8-NEXT: %[[q:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x224x224x3x!quant.uniform>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform> +// INT8-NEXT: %[[cst1:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>> +// INT8-NEXT: %[[cst2:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform> +// INT8-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%[[q]], %[[cst1]], %[[cst2]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3x!quant.uniform>, tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>>, tensor<32x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> +// INT8-NEXT: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x112x112x32x!quant.uniform>, tensor<2xi32>) -> tensor<1x401408x!quant.uniform> +// INT8-NEXT: %[[softmax:.*]] = "tfl.softmax"(%[[reshape]]) {beta = 1.000000e+00 : f32} : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> +// INT8-NEXT: return %[[softmax]], %arg1 : tensor<1x401408x!quant.uniform>, tensor<1x224x224x3xf32> + +// UINT8-LABEL: func @not_modified(%arg0: tensor, %arg1: tensor<1x224x224x3xf32>) -> (tensor<1x401408x!quant.uniform>, tensor<1x224x224x3xf32>) { +// UINT8-NEXT: %[[shape:.*]] = constant dense<[1, 401408]> : tensor<2xi32> +// UINT8-NEXT: %[[q:.*]] = "tfl.quantize"(%arg1) {qtype = tensor<1x224x224x3x!quant.uniform>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform> +// UINT8-NEXT: %[[cst1:.*]] = "tfl.pseudo_qconst"() {qtype = 
tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>> +// UINT8-NEXT: %[[cst2:.*]] = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform> +// UINT8-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%[[q]], %[[cst1]], %[[cst2]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x224x224x3x!quant.uniform>, tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216>>, tensor<32x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> +// UINT8-NEXT: %[[reshape:.*]] = "tfl.reshape"(%[[conv]], %[[shape]]) : (tensor<1x112x112x32x!quant.uniform>, tensor<2xi32>) -> tensor<1x401408x!quant.uniform> +// UINT8-NEXT: %[[softmax:.*]] = "tfl.softmax"(%[[reshape]]) {beta = 1.000000e+00 : f32} : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> +// UINT8-NEXT: %[[dq:.*]] = "tfl.quantize"(%[[softmax]]) {qtype = tensor<1x401408x!quant.uniform>} : (tensor<1x401408x!quant.uniform>) -> tensor<1x401408x!quant.uniform> +// UINT8-NEXT: return %[[dq]], %arg1 : tensor<1x401408x!quant.uniform>, tensor<1x224x224x3xf32> +} diff --git a/tensorflow/compiler/mlir/lite/transforms/modify_io_nodes.cc b/tensorflow/compiler/mlir/lite/transforms/modify_io_nodes.cc index 53ac0b051e1490..bcfca0690ec1af 100644 --- a/tensorflow/compiler/mlir/lite/transforms/modify_io_nodes.cc +++ b/tensorflow/compiler/mlir/lite/transforms/modify_io_nodes.cc @@ -135,7 +135,11 @@ LogicalResult ModifyIONodesPass::ModifyInputNodes( quantize_op.erase(); } } else { + // `arg` has multiple uses or the user isn't a quantize op (so we couldn't + // rewrite it to a different type). Make a copy of the `arg` and replace + // its uses. new_arg = block.addArgument(arg_type); + arg.replaceAllUsesWith(new_arg); } block.eraseArgument(0); }