Skip to content

Commit

Permalink
Disable the rewriting of Sqrt into Exp and Log (#2088)
Browse files Browse the repository at this point in the history
Signed-off-by: Tung D. Le <tung@jp.ibm.com>
Co-authored-by: Charles Volzka <42243335+cjvolzka@users.noreply.github.com>
  • Loading branch information
tungld and cjvolzka committed Mar 24, 2023
1 parent df58b4d commit 4e1c970
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ void ZHighStickOp::getCanonicalizationPatterns(
results.insert<StickUnstickSameLayoutRemovalPattern>(context);
results.insert<StickUnstickDiffLayoutRemovalPattern>(context);
results.insert<ReplaceONNXLeakyReluPattern>(context);
results.insert<ReplaceONNXSqrtPattern>(context);
results.insert<ReplaceONNXReciprocalSqrtPattern>(context);
results.insert<ReshapeTransposeReshape2DTo3DSPattern>(context);
results.insert<ReshapeTransposeReshape3DSTo2DPattern>(context);
Expand Down
11 changes: 0 additions & 11 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/ZHighStick.td
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,6 @@ def ReplaceONNXLeakyReluPattern: Pat<
(SameLayout $X, $stickout)]
>;

// Since zDNN does not support Sqrt(X), we calculate it by using zDNN-supported
// operations, i.e. Exp and Log.
// Formulas: `sqrt(X) = exp(log(x)/2) = exp(0.5 * log(x))`
def ReplaceONNXSqrtPattern: Pat<
(ZHighStickOp:$stick (ONNXSqrtOp (ZHighUnstickOp $X)), $layout),
(ZHighExpOp (ZHighMulOp (ZHighLogOp $X, (returnType $X)),
(ZHighStickOp (GetConstantOfType<"0.5"> $X), $layout),
(returnType $X))),
[(IsStaticShapeTensor $X), (SameLayout $X, $stick)]
>;

// Calulation of `1/sqrt(X)` or reciprocal square root is often found in
// deep learning models, but zDNN does not support it. Thus, we rewrite it into
// zDNN-supported operations.
Expand Down
35 changes: 0 additions & 35 deletions test/mlir/accelerators/nnpa/transform/zhigh-combine.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -130,41 +130,6 @@ func.func @donot_replace_leakyrelu(%arg0 : tensor<1x104x104x128xf32, #zhigh.layo

// -----

func.func @replace_sqrt(%arg0 : tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) {
%0 = "zhigh.Unstick"(%arg0) : (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32>
%1 = "onnx.Sqrt"(%0) : (tensor<4x256x1xf32>) -> tensor<4x256x1xf32>
%2 = "zhigh.Stick"(%1) {layout = "3D"} : (tensor<4x256x1xf32>) -> tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
return %2 : tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>

// CHECK-LABEL: func.func @replace_sqrt
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>> {
// CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Log"([[PARAM_0_]]) : (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<5.000000e-01> : tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
// CHECK: [[VAR_2_:%.+]] = "zhigh.Stick"([[VAR_1_]]) {layout = "3D"} : (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
// CHECK: [[VAR_3_:%.+]] = "zhigh.Mul"([[VAR_0_]], [[VAR_2_]]) : (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>, tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
// CHECK: [[VAR_4_:%.+]] = "zhigh.Exp"([[VAR_3_]]) : (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
// CHECK: return [[VAR_4_]] : tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
// CHECK: }
}

// -----

// Do not replace square root because of unknown dimension.
// In this case, there is no static shape to create a constant of 2.
func.func @donot_replace_sqrt(%arg0 : tensor<?x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> (tensor<?x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) {
%0 = "zhigh.Unstick"(%arg0) : (tensor<?x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<?x256x1xf32>
%1 = "onnx.Sqrt"(%0) : (tensor<?x256x1xf32>) -> tensor<?x256x1xf32>
%2 = "zhigh.Stick"(%1) {layout = "3D"} : (tensor<?x256x1xf32>) -> tensor<?x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>
return %2 : tensor<?x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>

// CHECK-LABEL: func.func @donot_replace_sqrt
// CHECK: zhigh.Unstick
// CHECK: onnx.Sqrt
// CHECK: zhigh.Stick
}

// -----

func.func @replace_reciprocal_sqrt(%arg0 : tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) {
%0 = "zhigh.Unstick"(%arg0) : (tensor<4x256x1xf32, #zhigh.layout<{dataLayout = "3D"}>>) -> tensor<4x256x1xf32>
%1 = "onnx.Sqrt"(%0) : (tensor<4x256x1xf32>) -> tensor<4x256x1xf32>
Expand Down

0 comments on commit 4e1c970

Please sign in to comment.