Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][tensor] Implement constant folder for tensor.pad #92691

Merged
merged 1 commit into from
Jun 6, 2024

Conversation

sabauma
Copy link
Contributor

@sabauma sabauma commented May 19, 2024

Extend the folding ability of the RewriteAsConstant patterns to include tensor.pad operations on constants. The new pattern with constant fold tensor.pad operations which operate on tensor constants and have statically resolvable padding sizes/values.

%init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
%pad_value = arith.constant 0 : i32

%0 = tensor.pad %init low[1, 1] high[1, 1] {
  ^bb0(%arg1: index, %arg2: index):
    tensor.yield %pad_value : i32
} : tensor<2x2xi32> to tensor<4x4xi32>

becomes

%cst = arith.constant dense<[[0, 0, 0, 0],
                             [0, 6, 7, 0],
                             [0, 8, 9, 0],
                             [0, 0, 0, 0]]> : tensor<4x4xi32>

@llvmbot
Copy link
Collaborator

llvmbot commented May 19, 2024

@llvm/pr-subscribers-mlir-tensor

Author: Spenser Bauman (sabauma)

Changes

Extend the folding ability of the RewriteAsConstant patterns to include tensor.pad operations on constants. The new pattern with constant fold tensor.pad operations which operate on tensor constants and have statically resolvable padding sizes/values.

%init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
%pad_value = arith.constant 0 : i32

%0 = tensor.pad %init low[1, 1] high[1, 1] {
^bb0(%arg1: index, %arg2: index):
tensor.yield %pad_value : i32
} : tensor<2x2xi32> to tensor<4x4xi32>

becomes

%cst = arith.constant dense<[[0, 0, 0, 0],
[0, 6, 7, 0],
[0, 8, 9, 0],
[0, 0, 0, 0]]> : tensor<4x4xi32>


Full diff: https://github.com/llvm/llvm-project/pull/92691.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp (+150-1)
  • (modified) mlir/lib/Dialect/Utils/IndexingUtils.cpp (+1-1)
  • (modified) mlir/test/Dialect/Tensor/rewrite-as-constant.mlir (+83)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
index 11e1de543ac91..7928b206eded6 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
@@ -8,9 +8,12 @@
 //
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 
+#include "llvm/ADT/TypeSwitch.h"
+
 using namespace mlir;
 using namespace mlir::tensor;
 
@@ -45,9 +48,155 @@ struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
   }
 };
 
+/// Transform a linear index from one indexing space to another given:
+///
+/// - the shape of the source indexing space,
+/// - the strides of the target indexing space,
+/// - a linear index into the source indexing space.
+///
+/// This function is logically a sequence of linearize/delinearize over
+/// different bases but avoids allocating intermediate SmallVectors.
+int64_t transformIndexSpace(ArrayRef<int64_t> inputShape,
+                            ArrayRef<int64_t> outputStrides,
+                            int64_t srcLinearIndex) {
+  assert(inputShape.size() == outputStrides.size());
+
+  int64_t dstLinearIndex = 0;
+
+  for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
+    // Compute the index into the current dimension of the source tensor.
+    // `quotient` is the remaining linear index after accounting for the
+    // current dimension.
+    //
+    // `remainder` is the index into the source tensor for the current
+    // dimension.
+    auto [quotient, remainder] = std::div(srcLinearIndex, inputShape[dim]);
+
+    srcLinearIndex = quotient;
+
+    // Add the contribution of the current dimension to the output using the
+    // permutation map.
+    dstLinearIndex += outputStrides[dim] * remainder;
+  }
+
+  return dstLinearIndex;
+}
+
+template <typename ElemType, typename AttrType>
+Value constantFoldPadOp(PatternRewriter &rewriter, Location loc,
+                        DenseElementsAttr input, AttrType padValue,
+                        ArrayRef<int64_t> padLow, ArrayRef<int64_t> padHigh) {
+  auto inputValues = input.tryGetValues<ElemType>();
+  if (failed(inputValues))
+    return nullptr;
+
+  auto oldShape = input.getType().getShape();
+
+  // Compute the output shape of the new value.
+  auto newShape =
+      llvm::map_to_vector(llvm::zip(oldShape, padLow, padHigh),
+                          [](std::tuple<int64_t, int64_t, int64_t> pack) {
+                            auto [old, low, high] = pack;
+                            return old + low + high;
+                          });
+
+  int64_t outputSize = computeProduct(newShape);
+
+  // Fully initialize the vector with the padding value.
+  // The non-padded area will then be copied.
+  SmallVector<ElemType> values(outputSize, padValue.getValue());
+
+  // Strides for input and output are used to transform between the indexing
+  // space of the input and output tensors.
+  SmallVector<int64_t> outputStrides = computeStrides(newShape);
+
+  // The contribution of the low padding to the offset in the output tensor.
+  // This is the starting position of the source tensor within the padding
+  // tensor.
+  int64_t startingOffset = linearize(padLow, outputStrides);
+
+  // Copy values from the input tensor to the corresponding sub-region
+  // of the output tensor.
+  for (auto [inputIndex, inputValue] : llvm::enumerate(*inputValues)) {
+    auto outputIndex = transformIndexSpace(oldShape, outputStrides, inputIndex);
+    values[outputIndex + startingOffset] = inputValue;
+  }
+
+  // Create an attribute for the folded value.
+  auto newType = input.getType().clone(newShape);
+  auto newAttr = DenseElementsAttr::get(newType, values);
+
+  Operation *constantOp =
+      rewriter.getContext()
+          ->getLoadedDialect<TensorDialect>()
+          ->materializeConstant(rewriter, newAttr, newType, loc);
+
+  return constantOp ? constantOp->getResult(0) : nullptr;
+}
+
+struct PadOpToConstant final : public OpRewritePattern<PadOp> {
+  using OpRewritePattern<PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadOp padTensorOp,
+                                PatternRewriter &rewriter) const override {
+    if (padTensorOp.getNofold())
+      return rewriter.notifyMatchFailure(
+          padTensorOp, "refusing to fold nofold pad operation");
+
+    TypedValue<RankedTensorType> input = padTensorOp.getSource();
+    RankedTensorType resultType = padTensorOp.getResult().getType();
+
+    DenseElementsAttr inputAttr = nullptr;
+    if (!matchPattern(input, m_Constant(&inputAttr)))
+      return failure();
+
+    Value paddingValue = padTensorOp.getConstantPaddingValue();
+
+    // Extract the constant value used for padding or bail out.
+    Attribute paddingAttr = nullptr;
+    if (!paddingValue || !matchPattern(paddingValue, m_Constant(&paddingAttr)))
+      return rewriter.notifyMatchFailure(padTensorOp,
+                                         "unable to get constant value");
+
+    // Try to extract the constant values of the low and high padding.
+    auto lowPad = getConstantIntValues(padTensorOp.getMixedLowPad());
+    auto highPad = getConstantIntValues(padTensorOp.getMixedHighPad());
+
+    // If the padding cannot be extracted, bail out.
+    if (!lowPad || !highPad)
+      return rewriter.notifyMatchFailure(padTensorOp,
+                                         "unable to extract constant padding");
+
+    Location loc = padTensorOp.getLoc();
+
+    // Try constant folding the supported cases of integer and float values.
+    Value newOp =
+        llvm::TypeSwitch<Attribute, Value>(paddingAttr)
+            .Case([&](FloatAttr floatAttr) {
+              return constantFoldPadOp<llvm::APFloat>(
+                  rewriter, loc, inputAttr, floatAttr, *lowPad, *highPad);
+            })
+            .Case([&](IntegerAttr integerAttr) {
+              return constantFoldPadOp<llvm::APInt>(
+                  rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad);
+            })
+            .Default(Value());
+
+    if (!newOp)
+      return rewriter.notifyMatchFailure(padTensorOp,
+                                         "tensor type not supported");
+
+    if (newOp.getType() != resultType)
+      newOp = rewriter.create<tensor::CastOp>(loc, resultType, newOp);
+
+    rewriter.replaceOp(padTensorOp, newOp);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tensor::populateRewriteAsConstantPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<GenerateToConstant>(patterns.getContext());
+  patterns.add<GenerateToConstant, PadOpToConstant>(patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 4c960659d80cb..aba225be720c3 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -92,7 +92,7 @@ int64_t mlir::computeProduct(ArrayRef<int64_t> basis) {
   assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
          "basis must be nonnegative");
   if (basis.empty())
-    return 0;
+    return 1;
   return std::accumulate(basis.begin(), basis.end(), 1,
                          std::multiplies<int64_t>());
 }
diff --git a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
index 1a1cf9e407d80..406065422cd7c 100644
--- a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
+++ b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
@@ -21,3 +21,86 @@ func.func @tensor_generate_constant() -> tensor<2x3x5xf32> {
   } : tensor<2x3x5xf32>
   return %0 : tensor<2x3x5xf32>
 }
+
+// CHECK-LABEL: func @pad_of_ints(
+//       CHECK: %[[cst:.*]] = arith.constant dense<[
+//       CHECK:     [0, 0, 0, 0],
+//       CHECK:     [0, 6, 7, 0],
+//       CHECK:     [0, 8, 9, 0],
+//       CHECK:     [0, 0, 0, 0]
+//       CHECK:     ]> : tensor<4x4xi32>
+//       CHECK: %[[cast:.*]] = tensor.cast %[[cst]] : tensor<4x4xi32> to tensor<?x?xi32>
+//       CHECK: return %[[cast]]
+func.func @pad_of_ints() -> tensor<?x?xi32> {
+  %init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
+  %pad_value = arith.constant 0 : i32
+
+  %c1 = arith.constant 1 : index
+
+  %0 = tensor.pad %init low[%c1, %c1] high[%c1, %c1] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : i32
+  } : tensor<2x2xi32> to tensor<?x?xi32>
+
+  return %0 : tensor<?x?xi32>
+}
+
+// CHECK-LABEL: func @pad_of_floats(
+//       CHECK: %[[cst:.*]] = arith.constant dense<[
+//       CHECK:     [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
+//       CHECK:     [0.000000e+00, 6.000000e+00, 7.000000e+00, 0.000000e+00],
+//       CHECK:     [0.000000e+00, 8.000000e+00, 9.000000e+00, 0.000000e+00],
+//       CHECK:     [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00]
+//       CHECK:     ]> : tensor<4x4xf32>
+//       CHECK: return %[[cst]]
+
+func.func @pad_of_floats() -> tensor<4x4xf32> {
+  %init = arith.constant dense<[[6.0, 7.0], [8.0, 9.0]]> : tensor<2x2xf32>
+  %pad_value = arith.constant 0.0 : f32
+
+  %0 = tensor.pad %init low[1, 1] high[1, 1] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : f32
+  } : tensor<2x2xf32> to tensor<4x4xf32>
+
+  return %0 : tensor<4x4xf32>
+}
+
+// CHECK-LABEL: func @pad_of_ints_no_low_dims(
+//       CHECK: %[[cst:.*]] = arith.constant dense<[
+//       CHECK:     [6, 7, 0],
+//       CHECK:     [8, 9, 0],
+//       CHECK:     [0, 0, 0]
+//       CHECK:     ]> : tensor<3x3xi32>
+//       CHECK: return %[[cst]]
+func.func @pad_of_ints_no_low_dims() -> tensor<3x3xi32> {
+  %init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
+  %pad_value = arith.constant 0 : i32
+
+  %0 = tensor.pad %init low[0, 0] high[1, 1] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : i32
+  } : tensor<2x2xi32> to tensor<3x3xi32>
+
+  return %0 : tensor<3x3xi32>
+}
+
+// CHECK-LABEL: func @pad_of_ints_no_high_dims(
+//       CHECK: %[[cst:.*]] = arith.constant dense<[
+//       CHECK:     [0, 0, 0],
+//       CHECK:     [0, 6, 7],
+//       CHECK:     [0, 8, 9]
+//       CHECK:     ]> : tensor<3x3xi32>
+//       CHECK: return %[[cst]]
+func.func @pad_of_ints_no_high_dims() -> tensor<3x3xi32> {
+  %init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
+  %pad_value = arith.constant 0 : i32
+
+  %0 = tensor.pad %init low[1, 1] high[0, 0] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : i32
+  } : tensor<2x2xi32> to tensor<3x3xi32>
+
+  return %0 : tensor<3x3xi32>
+}
+

@llvmbot
Copy link
Collaborator

llvmbot commented May 19, 2024

@llvm/pr-subscribers-mlir

Author: Spenser Bauman (sabauma)

Changes

Extend the folding ability of the RewriteAsConstant patterns to include tensor.pad operations on constants. The new pattern with constant fold tensor.pad operations which operate on tensor constants and have statically resolvable padding sizes/values.

%init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
%pad_value = arith.constant 0 : i32

%0 = tensor.pad %init low[1, 1] high[1, 1] {
^bb0(%arg1: index, %arg2: index):
tensor.yield %pad_value : i32
} : tensor<2x2xi32> to tensor<4x4xi32>

becomes

%cst = arith.constant dense<[[0, 0, 0, 0],
[0, 6, 7, 0],
[0, 8, 9, 0],
[0, 0, 0, 0]]> : tensor<4x4xi32>


Full diff: https://github.com/llvm/llvm-project/pull/92691.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp (+150-1)
  • (modified) mlir/lib/Dialect/Utils/IndexingUtils.cpp (+1-1)
  • (modified) mlir/test/Dialect/Tensor/rewrite-as-constant.mlir (+83)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
index 11e1de543ac91..7928b206eded6 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
@@ -8,9 +8,12 @@
 //
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 
+#include "llvm/ADT/TypeSwitch.h"
+
 using namespace mlir;
 using namespace mlir::tensor;
 
@@ -45,9 +48,155 @@ struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
   }
 };
 
+/// Transform a linear index from one indexing space to another given:
+///
+/// - the shape of the source indexing space,
+/// - the strides of the target indexing space,
+/// - a linear index into the source indexing space.
+///
+/// This function is logically a sequence of linearize/delinearize over
+/// different bases but avoids allocating intermediate SmallVectors.
+int64_t transformIndexSpace(ArrayRef<int64_t> inputShape,
+                            ArrayRef<int64_t> outputStrides,
+                            int64_t srcLinearIndex) {
+  assert(inputShape.size() == outputStrides.size());
+
+  int64_t dstLinearIndex = 0;
+
+  for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
+    // Compute the index into the current dimension of the source tensor.
+    // `quotient` is the remaining linear index after accounting for the
+    // current dimension.
+    //
+    // `remainder` is the index into the source tensor for the current
+    // dimension.
+    auto [quotient, remainder] = std::div(srcLinearIndex, inputShape[dim]);
+
+    srcLinearIndex = quotient;
+
+    // Add the contribution of the current dimension to the output using the
+    // permutation map.
+    dstLinearIndex += outputStrides[dim] * remainder;
+  }
+
+  return dstLinearIndex;
+}
+
+template <typename ElemType, typename AttrType>
+Value constantFoldPadOp(PatternRewriter &rewriter, Location loc,
+                        DenseElementsAttr input, AttrType padValue,
+                        ArrayRef<int64_t> padLow, ArrayRef<int64_t> padHigh) {
+  auto inputValues = input.tryGetValues<ElemType>();
+  if (failed(inputValues))
+    return nullptr;
+
+  auto oldShape = input.getType().getShape();
+
+  // Compute the output shape of the new value.
+  auto newShape =
+      llvm::map_to_vector(llvm::zip(oldShape, padLow, padHigh),
+                          [](std::tuple<int64_t, int64_t, int64_t> pack) {
+                            auto [old, low, high] = pack;
+                            return old + low + high;
+                          });
+
+  int64_t outputSize = computeProduct(newShape);
+
+  // Fully initialize the vector with the padding value.
+  // The non-padded area will then be copied.
+  SmallVector<ElemType> values(outputSize, padValue.getValue());
+
+  // Strides for input and output are used to transform between the indexing
+  // space of the input and output tensors.
+  SmallVector<int64_t> outputStrides = computeStrides(newShape);
+
+  // The contribution of the low padding to the offset in the output tensor.
+  // This is the starting position of the source tensor within the padding
+  // tensor.
+  int64_t startingOffset = linearize(padLow, outputStrides);
+
+  // Copy values from the input tensor to the corresponding sub-region
+  // of the output tensor.
+  for (auto [inputIndex, inputValue] : llvm::enumerate(*inputValues)) {
+    auto outputIndex = transformIndexSpace(oldShape, outputStrides, inputIndex);
+    values[outputIndex + startingOffset] = inputValue;
+  }
+
+  // Create an attribute for the folded value.
+  auto newType = input.getType().clone(newShape);
+  auto newAttr = DenseElementsAttr::get(newType, values);
+
+  Operation *constantOp =
+      rewriter.getContext()
+          ->getLoadedDialect<TensorDialect>()
+          ->materializeConstant(rewriter, newAttr, newType, loc);
+
+  return constantOp ? constantOp->getResult(0) : nullptr;
+}
+
+struct PadOpToConstant final : public OpRewritePattern<PadOp> {
+  using OpRewritePattern<PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadOp padTensorOp,
+                                PatternRewriter &rewriter) const override {
+    if (padTensorOp.getNofold())
+      return rewriter.notifyMatchFailure(
+          padTensorOp, "refusing to fold nofold pad operation");
+
+    TypedValue<RankedTensorType> input = padTensorOp.getSource();
+    RankedTensorType resultType = padTensorOp.getResult().getType();
+
+    DenseElementsAttr inputAttr = nullptr;
+    if (!matchPattern(input, m_Constant(&inputAttr)))
+      return failure();
+
+    Value paddingValue = padTensorOp.getConstantPaddingValue();
+
+    // Extract the constant value used for padding or bail out.
+    Attribute paddingAttr = nullptr;
+    if (!paddingValue || !matchPattern(paddingValue, m_Constant(&paddingAttr)))
+      return rewriter.notifyMatchFailure(padTensorOp,
+                                         "unable to get constant value");
+
+    // Try to extract the constant values of the low and high padding.
+    auto lowPad = getConstantIntValues(padTensorOp.getMixedLowPad());
+    auto highPad = getConstantIntValues(padTensorOp.getMixedHighPad());
+
+    // If the padding cannot be extracted, bail out.
+    if (!lowPad || !highPad)
+      return rewriter.notifyMatchFailure(padTensorOp,
+                                         "unable to extract constant padding");
+
+    Location loc = padTensorOp.getLoc();
+
+    // Try constant folding the supported cases of integer and float values.
+    Value newOp =
+        llvm::TypeSwitch<Attribute, Value>(paddingAttr)
+            .Case([&](FloatAttr floatAttr) {
+              return constantFoldPadOp<llvm::APFloat>(
+                  rewriter, loc, inputAttr, floatAttr, *lowPad, *highPad);
+            })
+            .Case([&](IntegerAttr integerAttr) {
+              return constantFoldPadOp<llvm::APInt>(
+                  rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad);
+            })
+            .Default(Value());
+
+    if (!newOp)
+      return rewriter.notifyMatchFailure(padTensorOp,
+                                         "tensor type not supported");
+
+    if (newOp.getType() != resultType)
+      newOp = rewriter.create<tensor::CastOp>(loc, resultType, newOp);
+
+    rewriter.replaceOp(padTensorOp, newOp);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tensor::populateRewriteAsConstantPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<GenerateToConstant>(patterns.getContext());
+  patterns.add<GenerateToConstant, PadOpToConstant>(patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 4c960659d80cb..aba225be720c3 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -92,7 +92,7 @@ int64_t mlir::computeProduct(ArrayRef<int64_t> basis) {
   assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
          "basis must be nonnegative");
   if (basis.empty())
-    return 0;
+    return 1;
   return std::accumulate(basis.begin(), basis.end(), 1,
                          std::multiplies<int64_t>());
 }
diff --git a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
index 1a1cf9e407d80..406065422cd7c 100644
--- a/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
+++ b/mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
@@ -21,3 +21,86 @@ func.func @tensor_generate_constant() -> tensor<2x3x5xf32> {
   } : tensor<2x3x5xf32>
   return %0 : tensor<2x3x5xf32>
 }
+
+// CHECK-LABEL: func @pad_of_ints(
+//       CHECK: %[[cst:.*]] = arith.constant dense<[
+//       CHECK:     [0, 0, 0, 0],
+//       CHECK:     [0, 6, 7, 0],
+//       CHECK:     [0, 8, 9, 0],
+//       CHECK:     [0, 0, 0, 0]
+//       CHECK:     ]> : tensor<4x4xi32>
+//       CHECK: %[[cast:.*]] = tensor.cast %[[cst]] : tensor<4x4xi32> to tensor<?x?xi32>
+//       CHECK: return %[[cast]]
+func.func @pad_of_ints() -> tensor<?x?xi32> {
+  %init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
+  %pad_value = arith.constant 0 : i32
+
+  %c1 = arith.constant 1 : index
+
+  %0 = tensor.pad %init low[%c1, %c1] high[%c1, %c1] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : i32
+  } : tensor<2x2xi32> to tensor<?x?xi32>
+
+  return %0 : tensor<?x?xi32>
+}
+
+// CHECK-LABEL: func @pad_of_floats(
+//       CHECK: %[[cst:.*]] = arith.constant dense<[
+//       CHECK:     [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
+//       CHECK:     [0.000000e+00, 6.000000e+00, 7.000000e+00, 0.000000e+00],
+//       CHECK:     [0.000000e+00, 8.000000e+00, 9.000000e+00, 0.000000e+00],
+//       CHECK:     [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00]
+//       CHECK:     ]> : tensor<4x4xf32>
+//       CHECK: return %[[cst]]
+
+func.func @pad_of_floats() -> tensor<4x4xf32> {
+  %init = arith.constant dense<[[6.0, 7.0], [8.0, 9.0]]> : tensor<2x2xf32>
+  %pad_value = arith.constant 0.0 : f32
+
+  %0 = tensor.pad %init low[1, 1] high[1, 1] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : f32
+  } : tensor<2x2xf32> to tensor<4x4xf32>
+
+  return %0 : tensor<4x4xf32>
+}
+
+// CHECK-LABEL: func @pad_of_ints_no_low_dims(
+//       CHECK: %[[cst:.*]] = arith.constant dense<[
+//       CHECK:     [6, 7, 0],
+//       CHECK:     [8, 9, 0],
+//       CHECK:     [0, 0, 0]
+//       CHECK:     ]> : tensor<3x3xi32>
+//       CHECK: return %[[cst]]
+func.func @pad_of_ints_no_low_dims() -> tensor<3x3xi32> {
+  %init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
+  %pad_value = arith.constant 0 : i32
+
+  %0 = tensor.pad %init low[0, 0] high[1, 1] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : i32
+  } : tensor<2x2xi32> to tensor<3x3xi32>
+
+  return %0 : tensor<3x3xi32>
+}
+
+// CHECK-LABEL: func @pad_of_ints_no_high_dims(
+//       CHECK: %[[cst:.*]] = arith.constant dense<[
+//       CHECK:     [0, 0, 0],
+//       CHECK:     [0, 6, 7],
+//       CHECK:     [0, 8, 9]
+//       CHECK:     ]> : tensor<3x3xi32>
+//       CHECK: return %[[cst]]
+func.func @pad_of_ints_no_high_dims() -> tensor<3x3xi32> {
+  %init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
+  %pad_value = arith.constant 0 : i32
+
+  %0 = tensor.pad %init low[1, 1] high[0, 0] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : i32
+  } : tensor<2x2xi32> to tensor<3x3xi32>
+
+  return %0 : tensor<3x3xi32>
+}
+

@sabauma
Copy link
Contributor Author

sabauma commented May 19, 2024

@rafaelubalmw @ryan-holt-1

@joker-eph
Copy link
Collaborator

Isn't this the kind of folder that can blow up the IR?
I find constant folding on full tensors to be the kind of thing that is difficult to get right, and hardly possible without context.

@MaheshRavishankar
Copy link
Contributor

+1 to what Mehdi said. This only works in context of very small constants.

@sabauma
Copy link
Contributor Author

sabauma commented May 19, 2024

@joker-eph @MaheshRavishankar Fair enough. This definitely falls into that particular category of constant folding operation. I was modeling this on some of the folders from TOSA, but I know there are concerns with how they can bloat the IR/context.

Is the preference to not have patterns like this at all? This change does not add the new pattern to the tensor.pad canonicalizers, so it is opt-in and could be parameterized by a cost function.

@joker-eph
Copy link
Collaborator

This change does not add the new pattern to the tensor.pad canonicalizers, so it is opt-in and could be parameterized by a cost function.

Yeah it is fine!

@MaheshRavishankar
Copy link
Contributor

@joker-eph @MaheshRavishankar Fair enough. This definitely falls into that particular category of constant folding operation. I was modeling this on some of the folders from TOSA, but I know there are concerns with how they can bloat the IR/context.

Is the preference to not have patterns like this at all? This change does not add the new pattern to the tensor.pad canonicalizers, so it is opt-in and could be parameterized by a cost function.

Is it parameterized by a cost function? I dont see that in the code.

Extend the folding ability of the RewriteAsConstant patterns to include
tensor.pad operations on constants. The new pattern with constant fold
tensor.pad operations which operate on tensor constants and have
statically resolvable padding sizes/values.

  %init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
  %pad_value = arith.constant 0 : i32

  %0 = tensor.pad %init low[1, 1] high[1, 1] {
    ^bb0(%arg1: index, %arg2: index):
      tensor.yield %pad_value : i32
  } : tensor<2x2xi32> to tensor<4x4xi32>

becomes

  %cst = arith.constant dense<[[0, 0, 0, 0],
                               [0, 6, 7, 0],
                               [0, 8, 9, 0],
                               [0, 0, 0, 0]]> : tensor<4x4xi32>
@sabauma
Copy link
Contributor Author

sabauma commented Jun 5, 2024

@MaheshRavishankar I've parameterized the pattern set with a cost function, similar to how the equivalent Linalg patterns are parameterized.

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Thanks.

@sabauma sabauma merged commit a9205c5 into llvm:main Jun 6, 2024
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants