Add double buffer removal pass
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10638 from jaro-sevcik:host-allocation-custom-call 0f78007475921031549c8555ad56ed8903efe6cb
PiperOrigin-RevId: 617588813
farzinhoushmand authored and tensorflower-gardener committed Apr 25, 2024
1 parent ea52231 commit b3eabf3
Showing 27 changed files with 1,408 additions and 134 deletions.
11 changes: 11 additions & 0 deletions tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -2652,6 +2652,17 @@ func.func @FuseAddWithFullyConnectedWithQuantizedWeight(%arg: tensor<2x512xf32>)
// CHECK: tfl.add
}

// CHECK-LABEL: @FuseBatchMatMulAndTransposeWithQuantizedWeight
func.func @FuseBatchMatMulAndTransposeWithQuantizedWeight(%arg: tensor<1x2xf32>) -> tensor<1x3xf32> {
%cst_3 = arith.constant dense<[1, 0]> : tensor<2xi32>
%79 = "tfl.pseudo_qconst"() {qtype = tensor<3x2x!quant.uniform<i8<-127:127>:f32:0, {2.378620e-03,2.848260e-03,2.545190e-03}>>, value = dense<10> : tensor<3x2xi8>} : () -> tensor<3x2x!quant.uniform<i8<-127:127>:f32:0, {2.378620e-03,2.848260e-03,2.545190e-03}>>
%80 = "tfl.transpose"(%79, %cst_3) : (tensor<3x2x!quant.uniform<i8<-127:127>:f32:0, {2.378620e-03,2.848260e-03,2.545190e-03}>>, tensor<2xi32>) -> tensor<2x3x!quant.uniform<i8<-127:127>:f32:1, {2.378620e-03,2.848260e-03,2.545190e-03}>>
%81 = "tfl.batch_matmul"(%arg, %80) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} : (tensor<1x2xf32>, tensor<2x3x!quant.uniform<i8<-127:127>:f32:1, {2.378620e-03,2.848260e-03,2.545190e-03}>>) -> tensor<1x3xf32>
func.return %81 : tensor<1x3xf32>

// CHECK: tfl.fully_connected
}

// CHECK-LABEL: @FuseAddWithFullyConnectedNoBias
// Note: Currently not fused.
func.func @FuseAddWithFullyConnectedNoBias(%arg: tensor<2x512xf32>) -> tensor<2x1024xf32> {
3 changes: 0 additions & 3 deletions tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
@@ -73,9 +73,6 @@ def CreateTFCastToInt32Op : NativeCodeCall<
def CreateInt32ConstOrCast : NativeCodeCall<
"CreateInt32ConstOrCast($0, $_loc, $_builder)">;

def CreateNoneValue : NativeCodeCall<
"$_builder.create<TFL::NoValueOp>($0.getLoc(), $_builder.getUnitAttr())">;

// Creates an int32 constant op from an integer attribute $0.
def CreateInt32ConstOpFromIntAttr
: NativeCodeCall<"$_builder.create<TF::ConstOp>($_loc, DenseElementsAttr::get(RankedTensorType::get({}, $_builder.getI32Type()), {static_cast<int32_t>($0.cast<IntegerAttr>().getInt())}))">;
14 changes: 14 additions & 0 deletions tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
@@ -1562,6 +1562,20 @@ def FuseTransposeAfterBatchMatmul : Pat<
),
[(AreLastTwoDimsTransposed $perm_value)]>;

// Fuse a redundant RHS TFL_TransposeOp into TFL_BatchMatMulOp if the RHS is a
// constant tensor of rank 2.
def FuseTransposeIntoBatchMatMulRHS: Pat<
(TFL_BatchMatMulOp $lhs,
(TFL_TransposeOp (TFL_QConstOp:$input $_, $_), (Arith_ConstantOp:$perm_value $p0)),
$adj_x, $adj_y, $asymmetric_quantize_inputs),
(TFL_FullyConnectedOp
$lhs,
$input, (CreateNoneValue $lhs), TFL_AF_None, TFL_FCWO_Default,
ConstBoolAttrTrue, $asymmetric_quantize_inputs),
[(HasRank<2> $input),
(AreLastTwoDimsTransposed $perm_value),
(IsBoolAttrEqual<"false"> $adj_y)]>;
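
Why the transpose can be folded away: tfl.fully_connected keeps its weights in [num_units, input_depth] layout and computes x times transpose(W), so for a rank-2 constant W with adj_y = false, batch_matmul(x, transpose(W)) is the same computation with the transpose made explicit. A minimal standalone C++ check of that identity (a sketch, not part of the change; plain loops, no TFLite dependency, with shapes mirroring the 1x2 * 2x3 test added above):

    #include <cassert>
    #include <cstdio>

    int main() {
      // x: 1x2 input; w: 3x2 constant weights in fully_connected's [out, in] layout.
      const float x[1][2] = {{1.0f, 2.0f}};
      const float w[3][2] = {{1, 2}, {3, 4}, {5, 6}};

      // Matched IR: transpose w to 2x3, then batch_matmul (1x2 * 2x3 -> 1x3).
      float wt[2][3];
      for (int o = 0; o < 3; ++o)
        for (int k = 0; k < 2; ++k) wt[k][o] = w[o][k];
      float bmm[1][3] = {};
      for (int o = 0; o < 3; ++o)
        for (int k = 0; k < 2; ++k) bmm[0][o] += x[0][k] * wt[k][o];

      // Rewritten IR: fully_connected(x, w) with no bias,
      // i.e. y[i][o] = sum_k x[i][k] * w[o][k].
      float fc[1][3] = {};
      for (int o = 0; o < 3; ++o)
        for (int k = 0; k < 2; ++k) fc[0][o] += x[0][k] * w[o][k];

      for (int o = 0; o < 3; ++o) assert(bmm[0][o] == fc[0][o]);
      std::puts("batch_matmul(x, transpose(w)) == fully_connected(x, w)");
      return 0;
    }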

// Replace conv-->transpose-->add with conv-->add-->transpose
// The bias needs only reshape (i.e. ReshapeNCHWBiasToNHWC) and not transpose
// because the bias's shape simply changes from NxCx1x1 to Nx1x1xC.
3 changes: 3 additions & 0 deletions tensorflow/compiler/mlir/lite/utils/utils.td
@@ -19,6 +19,9 @@ include "mlir/IR/OpBase.td"
include "mlir/Dialect/Func/IR/FuncOps.td"
include "mlir/IR/PatternBase.td"

def CreateNoneValue : NativeCodeCall<
"$_builder.create<TFL::NoValueOp>($0.getLoc(), $_builder.getUnitAttr())">;

// Returns shape of a ranked tensor.
// if called without a ranked tensor it will fail.
def GetShape: NativeCodeCall<"GetShape($0)">;
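The CreateNoneValue helper added to utils.td above is the one deleted from legalize_patterns.td earlier in this commit; hoisting it into the shared utils.td lets the new FuseTransposeIntoBatchMatMulRHS pattern in optimize_patterns.td use it to materialize the absent bias operand. As a rough guide to the TableGen placeholders, a use such as (CreateNoneValue $lhs) expands to approximately the following in the generated pattern code (a sketch; the variable names are assumed):

    // $_builder is the pattern rewriter and $0 is the value bound to $lhs;
    // the result is a `none` value standing in for the missing bias.
    mlir::Value noneBias =
        rewriter.create<TFL::NoValueOp>(lhs.getLoc(), rewriter.getUnitAttr());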
75 changes: 37 additions & 38 deletions third_party/stablehlo/temporary.patch
@@ -446,7 +446,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/CMakeLists.txt b/stablehlo
diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp
--- stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp
+++ stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp
@@ -0,0 +1,506 @@
@@ -0,0 +1,505 @@
+/* Copyright 2023 The StableHLO Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
@@ -521,7 +521,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+ // reduce_window_i1
+ SmallVector<ShapedType> inputTypes;
+ for (auto [index, input] : llvm::enumerate(inputs)) {
+ auto inputType = input.getType().dyn_cast<ShapedType>();
+ auto inputType = dyn_cast<ShapedType>(input.getType());
+ inputTypes.push_back(inputType);
+ if (!inputType)
+ return op_.emitError()
@@ -531,7 +531,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+ // reduce_window_i2
+ SmallVector<ShapedType> initValueTypes;
+ for (auto [index, initValue] : llvm::enumerate(initValues)) {
+ auto initValueType = initValue.getType().dyn_cast<ShapedType>();
+ auto initValueType = dyn_cast<ShapedType>(initValue.getType());
+ initValueTypes.push_back(initValueType);
+ if (!initValueType || !initValueType.hasRank() ||
+ initValueType.getRank() != 0)
@@ -543,7 +543,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+ // reduce_window_i3...reduce_window_i7
+ auto checkRank = [&](StringRef name, int64_t index, Value dynamicAttr,
+ int64_t expectedRank) -> LogicalResult {
+ auto type = dynamicAttr.getType().dyn_cast<ShapedType>();
+ auto type = dyn_cast<ShapedType>(dynamicAttr.getType());
+ if (!type || !type.hasRank() || type.getRank() != expectedRank ||
+ !type.getElementType().isIntOrIndex()) {
+ if (index < 0) index += op_->getNumOperands();
@@ -562,7 +562,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+ return failure();
+
+ // reduce_window_i7
+ auto paddingType = getPadding().getType().dyn_cast<ShapedType>();
+ auto paddingType = dyn_cast<ShapedType>(getPadding().getType());
+ if (!paddingType || !paddingType.hasRank() || paddingType.getRank() != 2 ||
+ paddingType.getDimSize(1) != 2 ||
+ !paddingType.getElementType().isIntOrIndex())
@@ -598,7 +598,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+ // verify them in that case, that seems like too much at this point.
+ auto checkShape = [&](StringRef name, int64_t index, Value dynamicAttr,
+ ArrayRef<int64_t> expectedShape) -> LogicalResult {
+ auto type = dynamicAttr.getType().cast<ShapedType>();
+ auto type = cast<ShapedType>(dynamicAttr.getType());
+ if (type.getShape() != expectedShape) {
+ if (index < 0) index += op_->getNumOperands();
+ return op_.emitError()
@@ -622,7 +622,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+ // reduce_window_c13
+ if (op_.getCalledComputations().size() != 1)
+ return op_.emitError() << "expects called_computations to have 1 element";
+ auto bodyAttr = op_.getCalledComputations()[0].cast<FlatSymbolRefAttr>();
+ auto bodyAttr = cast<FlatSymbolRefAttr>(op_.getCalledComputations()[0]);
+ auto bodyFunc =
+ op_->getParentOfType<ModuleOp>().lookupSymbol<func::FuncOp>(bodyAttr);
+ if (!bodyFunc)
@@ -644,7 +644,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+ SmallVector<ShapedType> resultTypes;
+ std::optional<ArrayRef<int64_t>> resultShape;
+ for (auto result : results) {
+ auto resultType = result.getType().dyn_cast<ShapedType>();
+ auto resultType = dyn_cast<ShapedType>(result.getType());
+ resultTypes.push_back(resultType);
+ if (!resultType) return op_.emitError() << "expects results to be tensors";
+
@@ -683,32 +683,32 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+}
+
+TypedValue<ShapedType> DynamicReduceWindowOpAdaptor::getWindowDimensions() {
+ return op_.getInputs()[op_.getInputs().size() - 5]
+ .cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(
+ op_.getInputs()[op_.getInputs().size() - 5]);
+}
+
+TypedValue<ShapedType> DynamicReduceWindowOpAdaptor::getWindowStrides() {
+ return op_.getInputs()[op_.getInputs().size() - 4]
+ .cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(
+ op_.getInputs()[op_.getInputs().size() - 4]);
+}
+
+TypedValue<ShapedType> DynamicReduceWindowOpAdaptor::getBaseDilations() {
+ return op_.getInputs()[op_.getInputs().size() - 3]
+ .cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(
+ op_.getInputs()[op_.getInputs().size() - 3]);
+}
+
+TypedValue<ShapedType> DynamicReduceWindowOpAdaptor::getWindowDilations() {
+ return op_.getInputs()[op_.getInputs().size() - 2]
+ .cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(
+ op_.getInputs()[op_.getInputs().size() - 2]);
+}
+
+TypedValue<ShapedType> DynamicReduceWindowOpAdaptor::getPadding() {
+ return op_.getInputs()[op_.getInputs().size() - 1]
+ .cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(
+ op_.getInputs()[op_.getInputs().size() - 1]);
+}
+
+Region& DynamicReduceWindowOpAdaptor::getBody() {
+ auto bodyAttr = op_.getCalledComputations()[0].cast<FlatSymbolRefAttr>();
+ auto bodyAttr = cast<FlatSymbolRefAttr>(op_.getCalledComputations()[0]);
+ auto bodyFunc =
+ op_->getParentOfType<ModuleOp>().lookupSymbol<func::FuncOp>(bodyAttr);
+ return bodyFunc.getBody();
@@ -758,20 +758,20 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+ auto output = op_.getResults()[1];
+
+ // dynamic_rng_bit_generator_i1
+ if (!rngAlgorithmAttr.isa<RngAlgorithmAttr>())
+ if (!isa<RngAlgorithmAttr>(rngAlgorithmAttr))
+ return op_.emitError()
+ << "expects a #stablehlo<rng_algorithm ...> rng_algorithm";
+
+ // dynamic_rng_bit_generator_i2
+ // TODO(#643): Clarify supported types for RngBitGeneratorOp.
+ auto initialStateType = initialState.getType().dyn_cast<ShapedType>();
+ auto initialStateType = dyn_cast<ShapedType>(initialState.getType());
+ if (!initialStateType || !initialStateType.getElementType().isIntOrFloat())
+ return op_.emitError()
+ << "expects initial_state (operand #0) "
+ << "to be a tensor of integer or floating-point type";
+
+ // dynamic_rng_bit_generator_i3
+ auto outputShapeType = outputShape.getType().dyn_cast<ShapedType>();
+ auto outputShapeType = dyn_cast<ShapedType>(outputShape.getType());
+ if (!outputShapeType || !outputShapeType.hasRank() ||
+ outputShapeType.getRank() != 1 ||
+ !outputShapeType.getElementType().isIntOrIndex())
@@ -781,14 +781,14 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+
+ // dynamic_rng_bit_generator_o1
+ // TODO(#643): Clarify supported types for RngBitGeneratorOp.
+ auto outputStateType = outputState.getType().dyn_cast<ShapedType>();
+ auto outputStateType = dyn_cast<ShapedType>(outputState.getType());
+ if (!outputStateType || !outputStateType.getElementType().isIntOrFloat())
+ return op_.emitError()
+ << "expects output_state (result #0) "
+ << "to be a tensor of integer or floating-point type";
+
+ // dynamic_rng_bit_generator_o2
+ auto outputType = output.getType().dyn_cast<ShapedType>();
+ auto outputType = dyn_cast<ShapedType>(output.getType());
+ if (!outputType || !outputType.getElementType().isIntOrFloat())
+ return op_.emitError()
+ << "expects output (result #1) "
@@ -812,25 +812,24 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+}
+
+RngAlgorithm DynamicRngBitGeneratorOpAdaptor::getRngAlgorithm() {
+ return op_->getDiscardableAttr("rng_algorithm")
+ .cast<RngAlgorithmAttr>()
+ return cast<RngAlgorithmAttr>(op_->getDiscardableAttr("rng_algorithm"))
+ .getValue();
+}
+
+TypedValue<ShapedType> DynamicRngBitGeneratorOpAdaptor::getInitialState() {
+ return op_.getInputs()[0].cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(op_.getInputs()[0]);
+}
+
+TypedValue<ShapedType> DynamicRngBitGeneratorOpAdaptor::getOutputShape() {
+ return op_.getInputs()[1].cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(op_.getInputs()[1]);
+}
+
+TypedValue<ShapedType> DynamicRngBitGeneratorOpAdaptor::getOutputState() {
+ return op_.getResults()[0].cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(op_.getResults()[0]);
+}
+
+TypedValue<ShapedType> DynamicRngBitGeneratorOpAdaptor::getOutput() {
+ return op_.getResults()[1].cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(op_.getResults()[1]);
+}
+
+std::optional<DynamicRngBitGeneratorOpAdaptor> getDynamicRngBitGeneratorOp(
@@ -864,7 +863,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+ auto indices = op_.getResults()[1];
+
+ // dynamic_top_k_i1
+ auto operandType = operand.getType().dyn_cast<ShapedType>();
+ auto operandType = dyn_cast<ShapedType>(operand.getType());
+ if (!operandType || !operandType.hasRank() || operandType.getRank() < 1 ||
+ !operandType.getElementType().isIntOrFloat())
+ return op_.emitError()
@@ -873,15 +872,15 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+ << "of rank at least 1";
+
+ // dynamic_top_k_i2
+ auto kType = k.getType().dyn_cast<ShapedType>();
+ auto kType = dyn_cast<ShapedType>(k.getType());
+ if (!kType || !kType.hasRank() || kType.getRank() != 0 ||
+ !kType.getElementType().isIntOrIndex())
+ return op_.emitError()
+ << "expects k (operand #1) "
+ << "to be a 0-dimensional tensor of integer or index type";
+
+ // dynamic_top_k_o1
+ auto valuesType = values.getType().dyn_cast<ShapedType>();
+ auto valuesType = dyn_cast<ShapedType>(values.getType());
+ if (!valuesType || !valuesType.hasRank() || valuesType.getRank() < 1 ||
+ !valuesType.getElementType().isIntOrFloat())
+ return op_.emitError()
@@ -890,7 +889,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+ << "of rank at least 1";
+
+ // dynamic_top_k_o2
+ auto indicesType = indices.getType().dyn_cast<ShapedType>();
+ auto indicesType = dyn_cast<ShapedType>(indices.getType());
+ if (!indicesType || !indicesType.hasRank() || indicesType.getRank() < 1 ||
+ !indicesType.getElementType().isSignlessInteger(32))
+ return op_.emitError() << "expects indices (result #1) "
@@ -930,19 +929,19 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+}
+
+TypedValue<ShapedType> DynamicTopKOpAdaptor::getOperand() {
+ return op_.getInputs()[0].cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(op_.getInputs()[0]);
+}
+
+TypedValue<ShapedType> DynamicTopKOpAdaptor::getK() {
+ return op_.getInputs()[1].cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(op_.getInputs()[1]);
+}
+
+TypedValue<ShapedType> DynamicTopKOpAdaptor::getValues() {
+ return op_.getResults()[0].cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(op_.getResults()[0]);
+}
+
+TypedValue<ShapedType> DynamicTopKOpAdaptor::getIndices() {
+ return op_.getResults()[1].cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(op_.getResults()[1]);
+}
+
+std::optional<DynamicTopKOpAdaptor> getDynamicTopKOp(CustomCallOp op) {
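Most of the StablehloOps.cpp hunks in this patch are a mechanical cast-style migration: member casts on mlir::Type and mlir::Attribute (for example value.getType().dyn_cast<ShapedType>() or attr.cast<FlatSymbolRefAttr>()) become the free functions isa<...>(...), cast<...>(...), and dyn_cast<...>(...), which is the direction MLIR/LLVM have been moving. A minimal sketch of the before/after shape, not taken from the patch (the includes and helper name are illustrative):

    #include "mlir/IR/BuiltinTypes.h"  // ShapedType
    #include "mlir/IR/Value.h"
    #include "mlir/Support/LLVM.h"     // pulls llvm::isa/cast/dyn_cast into mlir::

    using namespace mlir;

    // Returns the ShapedType of `v`, or a null ShapedType if `v` is not shaped.
    static ShapedType getShapedTypeOrNull(Value v) {
      // Old style, removed by this patch:   v.getType().dyn_cast<ShapedType>()
      // New style, introduced by this patch:
      return dyn_cast<ShapedType>(v.getType());
    }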
