Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Transforms] Dialect conversion: Add option to disable folding #92683

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

matthias-springer
Copy link
Member

This commit adds a new flag to ConversionConfig that deactivates op folding during a dialect conversion.

Op folding is problematic beause op folders may assume that the entire IR is in a valid state. (See #89770 for an example.) However, the dialect conversion driver does not guarantee that the IR is valid during a dialect conversion; it only guarantees that the IR is valid at the end of a dialect conversion. E.g., IR may be invalid after a conversion pattern application because some IR modifications (e.g., op/block replacements) are applied in a delayed fashion at the end of a dialect conversion. This makes op folders generally unsafe to use with a dialect conversion.

Note: For the same reason, it is also not safe to use non-conversion patterns with a dialect conversion. Regular rewrite patterns may assume that the entire IR is in a valid state, but conversion patterns cannot -- and developers of conversion patterns must take that into account.

This commit adds a new flag to `ConversionConfig` that deactivates op folding during a dialect conversion.

Op folding is problematic beause op folders may assume that the IR is in a valid state. (See #89770 for an example.) However, the dialect conversion driver does not guarantee that the IR is valid during a dialect conversion; it only guarantees that the IR is valid at the end of a dialect conversion. E.g., IR may be invalid after a conversion pattern application because some IR modifications (e.g., op/block replacements) are applied in a delayed fashion at the end of a dialect conversion. This makes op folders generally unsafe to use with a dialect conversion.

Note: For the same reason, it is also not safe to use non-conversion patterns with a dialect conversion. Conversion patterns can be used safely because they have an "adapter". (And conversion patterns cannot assume that the entire IR is valid in general.)
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels May 19, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented May 19, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

This commit adds a new flag to ConversionConfig that deactivates op folding during a dialect conversion.

Op folding is problematic beause op folders may assume that the entire IR is in a valid state. (See #89770 for an example.) However, the dialect conversion driver does not guarantee that the IR is valid during a dialect conversion; it only guarantees that the IR is valid at the end of a dialect conversion. E.g., IR may be invalid after a conversion pattern application because some IR modifications (e.g., op/block replacements) are applied in a delayed fashion at the end of a dialect conversion. This makes op folders generally unsafe to use with a dialect conversion.

Note: For the same reason, it is also not safe to use non-conversion patterns with a dialect conversion. Regular rewrite patterns may assume that the entire IR is in a valid state, but conversion patterns cannot -- and developers of conversion patterns must take that into account.


Full diff: https://github.com/llvm/llvm-project/pull/92683.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+10-1)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+1-1)
  • (modified) mlir/test/Transforms/test-legalizer-analysis.mlir (+1-1)
  • (modified) mlir/test/Transforms/test-legalizer-full.mlir (+1-1)
  • (added) mlir/test/Transforms/test-legalizer-no-fold.mlir (+11)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+23-19)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 83198c9b0db54..ea41e7c5ba803 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -247,7 +247,8 @@ class TypeConverter {
   /// Attempts a 1-1 type conversion, expecting the result type to be
   /// `TargetType`. Returns the converted type cast to `TargetType` on success,
   /// and a null type on conversion or cast failure.
-  template <typename TargetType> TargetType convertType(Type t) const {
+  template <typename TargetType>
+  TargetType convertType(Type t) const {
     return dyn_cast_or_null<TargetType>(convertType(t));
   }
 
@@ -1118,6 +1119,14 @@ struct ConversionConfig {
   // already been modified) and iterators into past IR state cannot be
   // represented at the moment.
   RewriterBase::Listener *listener = nullptr;
+
+  /// If set to "true", the dialect conversion driver attempts to fold
+  /// operations throughout the conversion. This is problematic because op
+  /// folders may assume that the IR is in a valid state at the beginning of
+  /// the folding process. However, the dialect conversion does not guarantee
+  /// that because some IR modifications are delayed until the end of the
+  /// conversion.
+  bool foldOps = true;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index d407d60334c70..3c684e9a208ac 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2030,7 +2030,7 @@ OperationLegalizer::legalize(Operation *op,
   // If the operation isn't legal, try to fold it in-place.
   // TODO: Should we always try to do this, even if the op is
   // already legal?
-  if (succeeded(legalizeWithFold(op, rewriter))) {
+  if (config.foldOps && succeeded(legalizeWithFold(op, rewriter))) {
     LLVM_DEBUG({
       logSuccess(logger, "operation was folded");
       logger.startLine() << logLineComment;
diff --git a/mlir/test/Transforms/test-legalizer-analysis.mlir b/mlir/test/Transforms/test-legalizer-analysis.mlir
index 19a13100159a2..829415b9af414 100644
--- a/mlir/test/Transforms/test-legalizer-analysis.mlir
+++ b/mlir/test/Transforms/test-legalizer-analysis.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -verify-diagnostics -test-legalize-mode=analysis %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="legalize-mode=analysis" -verify-diagnostics %s | FileCheck %s
 // expected-remark@-2 {{op 'builtin.module' is legalizable}}
 
 // expected-remark@+1 {{op 'func.func' is legalizable}}
diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir
index 5f1148cac6501..ea163a5767755 100644
--- a/mlir/test/Transforms/test-legalizer-full.mlir
+++ b/mlir/test/Transforms/test-legalizer-full.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -test-legalize-mode=full -split-input-file -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="legalize-mode=full" -split-input-file -verify-diagnostics %s | FileCheck %s
 
 // CHECK-LABEL: func @multi_level_mapping
 func.func @multi_level_mapping() {
diff --git a/mlir/test/Transforms/test-legalizer-no-fold.mlir b/mlir/test/Transforms/test-legalizer-no-fold.mlir
new file mode 100644
index 0000000000000..61afd72c934bc
--- /dev/null
+++ b/mlir/test/Transforms/test-legalizer-no-fold.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="fold-ops=0" %s | FileCheck %s
+
+// CHECK-LABEL: @remove_foldable_op(
+func.func @remove_foldable_op(%arg0 : i32) -> (i32) {
+  // Check that op was not folded.
+  // CHECK: "test.op_with_region_fold"
+  %0 = "test.op_with_region_fold"(%arg0) ({
+    "foo.op_with_region_terminator"() : () -> ()
+  }) : (i32) -> (i32)
+  "test.return"(%0) : (i32) -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index f9f7d4eacf948..97ef7a51aa203 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1102,7 +1102,9 @@ struct TestLegalizePatternDriver
   /// The mode of conversion to use with the driver.
   enum class ConversionMode { Analysis, Full, Partial };
 
-  TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
+  TestLegalizePatternDriver() = default;
+  TestLegalizePatternDriver(const TestLegalizePatternDriver &other)
+      : PassWrapper(other) {}
 
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<func::FuncDialect, test::TestDialect>();
@@ -1179,6 +1181,7 @@ struct TestLegalizePatternDriver
       DumpNotifications dumpNotifications;
       config.listener = &dumpNotifications;
       config.unlegalizedOps = &unlegalizedOps;
+      config.foldOps = foldOps;
       if (failed(applyPartialConversion(getOperation(), target,
                                         std::move(patterns), config))) {
         getOperation()->emitRemark() << "applyPartialConversion failed";
@@ -1197,6 +1200,7 @@ struct TestLegalizePatternDriver
       });
 
       ConversionConfig config;
+      config.foldOps = foldOps;
       DumpNotifications dumpNotifications;
       config.listener = &dumpNotifications;
       if (failed(applyFullConversion(getOperation(), target,
@@ -1212,6 +1216,7 @@ struct TestLegalizePatternDriver
     // Analyze the convertible operations.
     DenseSet<Operation *> legalizedOps;
     ConversionConfig config;
+    config.foldOps = foldOps;
     config.legalizableOps = &legalizedOps;
     if (failed(applyAnalysisConversion(getOperation(), target,
                                        std::move(patterns), config)))
@@ -1222,24 +1227,25 @@ struct TestLegalizePatternDriver
       op->emitRemark() << "op '" << op->getName() << "' is legalizable";
   }
 
-  /// The mode of conversion to use.
-  ConversionMode mode;
+  Option<bool> foldOps{
+      *this, "fold-ops",
+      llvm::cl::desc("Fold ops throughout the conversion process"),
+      llvm::cl::init(true)};
+
+  Option<TestLegalizePatternDriver::ConversionMode> mode{
+      *this, "legalize-mode",
+      llvm::cl::desc("The legalization mode to use with the test driver"),
+      llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
+      llvm::cl::values(
+          clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
+                     "analysis", "Perform an analysis conversion"),
+          clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
+                     "Perform a full conversion"),
+          clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
+                     "partial", "Perform a partial conversion"))};
 };
 } // namespace
 
-static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
-    legalizerConversionMode(
-        "test-legalize-mode",
-        llvm::cl::desc("The legalization mode to use with the test driver"),
-        llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
-        llvm::cl::values(
-            clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
-                       "analysis", "Perform an analysis conversion"),
-            clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
-                       "Perform a full conversion"),
-            clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
-                       "partial", "Perform a partial conversion")));
-
 //===----------------------------------------------------------------------===//
 // ConversionPatternRewriter::getRemappedValue testing. This method is used
 // to get the remapped value of an original value that was replaced using
@@ -1909,9 +1915,7 @@ void registerPatternsTestPass() {
   PassRegistration<TestPatternDriver>();
   PassRegistration<TestStrictPatternDriver>();
 
-  PassRegistration<TestLegalizePatternDriver>([] {
-    return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
-  });
+  PassRegistration<TestLegalizePatternDriver>();
 
   PassRegistration<TestRemappedValue>();
 

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is probably the best short term option. Removing the fold legalization entirely is more ambitious. I have no clue whether anyone actually relies on it.

Comment on lines +1124 to +1129
/// operations throughout the conversion. This is problematic because op
/// folders may assume that the IR is in a valid state at the beginning of
/// the folding process. However, the dialect conversion does not guarantee
/// that because some IR modifications are delayed until the end of the
/// conversion.
bool foldOps = true;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder whether there should be either a TODO such as "change to false in the future" (if we want to take that route) or whether the comment should note that it is true for historic reasons.

Looks funny that the majority of the paragraph discourages using the options but we default to it being true

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would go further: if we consider this unsafe, we should just deprecate this mode entirely.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do mean setting this to false by default or adding a comment that this is deprecated (or both)? We have at least one test case in the test dialect that tests the folding.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we should be setting this to false by default, folks who are broken can set it back to true, but we also document it as deprecated.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests are broken with foldOps = false:

Failed Tests (8):
  MLIR :: Conversion/AffineToStandard/lower-affine.mlir
  MLIR :: Conversion/FuncToLLVM/calling-convention.mlir
  MLIR :: Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
  MLIR :: Conversion/ShapeToStandard/shape-to-standard.mlir
  MLIR :: Conversion/VectorToLLVM/vector-to-llvm.mlir
  MLIR :: Dialect/LLVM/lower-to-llvm-e2e-with-target-tag.mlir
  MLIR :: Dialect/LLVM/lower-to-llvm-e2e-with-top-level-named-sequence.mlir
  MLIR :: Dialect/Vector/linearize.mlir

They are mostly FileCheck failures, but vector-to-llvm.mlir is actually broken. I'm busy with other stuff right now, so it might take a while. (Or if someone else wants to take this over, feel free to.)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants