Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][bufferization] Allow mixed static/dynamic shapes in materialize_in_destination op #92681

Merged

Conversation

matthias-springer
Copy link
Member

This commit relaxes the verifier of bufferization.materialize_in_destination such that mixed static/dynamic dimensions are allowed for the source and destination operands. E.g., tensor<5xf32> and tensor<?xf32> are now compatible, but it is assumed that the dynamic dimension is 5 at runtime.

This commit fixes #91265.

…ze_in_destination` op

This commit relaxes the verifier of `bufferization.materialize_in_destination` such that mixed static/dynamic dimensions are allowed for the source and destination operands. E.g., `tensor<5xf32>` and `tensor<?xf32>` are now compatible, but it is assumed that the dynamic dimension is `5` at runtime.

This commit fixes #91265.
@llvmbot llvmbot added mlir mlir:bufferization Bufferization infrastructure labels May 19, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented May 19, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-bufferization

Author: Matthias Springer (matthias-springer)

Changes

This commit relaxes the verifier of bufferization.materialize_in_destination such that mixed static/dynamic dimensions are allowed for the source and destination operands. E.g., tensor&lt;5xf32&gt; and tensor&lt;?xf32&gt; are now compatible, but it is assumed that the dynamic dimension is 5 at runtime.

This commit fixes #91265.


Full diff: https://github.com/llvm/llvm-project/pull/92681.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+4-5)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+18)
  • (modified) mlir/test/Dialect/Bufferization/invalid.mlir (+10-3)
  • (modified) mlir/test/Dialect/Bufferization/ops.mlir (+5-2)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 4f609ddff9a41..1c70a4b8df925 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -217,8 +217,7 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [
 
 def Bufferization_MaterializeInDestinationOp
     : Bufferization_Op<"materialize_in_destination",
-        [AllShapesMatch<["source", "dest"]>,
-         AllElementTypesMatch<["source", "dest"]>,
+        [AllElementTypesMatch<["source", "dest"]>,
          BufferizableOpInterface, DestinationStyleOpInterface,
          DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
          DeclareOpInterfaceMethods<SubsetOpInterface,
@@ -239,9 +238,9 @@ def Bufferization_MaterializeInDestinationOp
     memref, `source` materializes in `dest`, which is already a buffer. The op
     has no results in that case.
 
-    `source`, `dest` and `result` (if present) must have the same shape and
-    element type. If the op has a result, the types of `result` and `dest` must
-    match exactly (e.g., including any tensor encodings).
+    `source`, `dest` and `result` (if present) must have the same runtime shape
+    and element type. If the op has a result, the types of `result` and `dest`
+    must match exactly (e.g., including any tensor encodings).
 
     By default, this op bufferizes to a memcpy from the future buffer of the
     `source` tensor to the future buffer of the `dest` tensor or to the `dest`
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 0acb0c24ab313..3b7b412842bfb 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -686,6 +686,24 @@ LogicalResult MaterializeInDestinationOp::verify() {
   if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
     return emitOpError("'writable' must be specified if and only if the "
                        "destination is of memref type");
+  TensorType srcType = getSource().getType();
+  ShapedType destType = cast<ShapedType>(getDest().getType());
+  if (srcType.hasRank() != destType.hasRank())
+    return emitOpError("source/destination shapes are incompatible");
+  if (srcType.hasRank()) {
+    if (srcType.getRank() != destType.getRank())
+      return emitOpError("rank mismatch between source and destination shape");
+    for (auto [src, dest] :
+         llvm::zip(srcType.getShape(), destType.getShape())) {
+      if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
+        // Cannot verify dynamic dimension size. Assume that that they match at
+        // runtime.
+        continue;
+      }
+      if (src != dest)
+        return emitOpError("source/destination shapes are incompatible");
+    }
+  }
   return success();
 }
 
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index 4ebdb0a8f0490..2c8807b66de74 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -43,9 +43,16 @@ func.func @invalid_writable_on_op() {
 
 // -----
 
-func.func @invalid_materialize_in_destination(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
-  // expected-error @below{{failed to verify that all of {source, dest} have same shape}}
-  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
+func.func @invalid_materialize_in_destination(%arg0: tensor<4xf32>, %arg1: tensor<5xf32>) {
+  // expected-error @below{{source/destination shapes are incompatible}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<4xf32>, tensor<5xf32>) -> tensor<5xf32>
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination(%arg0: tensor<5x5xf32>, %arg1: tensor<5xf32>) {
+  // expected-error @below{{rank mismatch between source and destination shape}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<5x5xf32>, tensor<5xf32>) -> tensor<5xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir
index d4bda0632189d..ad4a66c1b7978 100644
--- a/mlir/test/Dialect/Bufferization/ops.mlir
+++ b/mlir/test/Dialect/Bufferization/ops.mlir
@@ -59,12 +59,15 @@ func.func @test_dealloc_tensor_op(%arg0: tensor<4xi32>) {
 }
 
 // CHECK-LABEL: func @test_materialize_in_destination_op
-func.func @test_materialize_in_destination_op(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: memref<?xf32, 3>)
-    -> tensor<?xf32> {
+func.func @test_materialize_in_destination_op(
+    %arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: memref<?xf32, 3>,
+    %arg4: tensor<5xf32>) -> tensor<?xf32> {
   // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   %1 = bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, memref<?xf32, 3>) -> ()
   bufferization.materialize_in_destination %arg0 in restrict writable %arg2 : (tensor<?xf32>, memref<?xf32, 3>) -> ()
+  // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
+  %2 = bufferization.materialize_in_destination %arg0 in %arg4 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
   return %1 : tensor<?xf32>
 }
 

Copy link
Contributor

@christopherbate christopherbate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@matthias-springer matthias-springer merged commit 9d4b20a into main Jun 1, 2024
7 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/materialize_in_dest_shape branch June 1, 2024 10:04
HendrikHuebner pushed a commit to HendrikHuebner/llvm-project that referenced this pull request Jun 2, 2024
…ze_in_destination` op (llvm#92681)

This commit relaxes the verifier of
`bufferization.materialize_in_destination` such that mixed
static/dynamic dimensions are allowed for the source and destination
operands. E.g., `tensor<5xf32>` and `tensor<?xf32>` are now compatible,
but it is assumed that the dynamic dimension is `5` at runtime.

This commit fixes llvm#91265.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Tensor Dialect canonicalizer FoldTensorCastProducerOp can produce invalid IR
3 participants