[mlir][emitc] Support scalar MemRef types #92684

Draft
wants to merge 1 commit into main
Conversation

aniragil
Contributor

Extend MemRefToEmitC to map zero-ranked memrefs into scalar C variables
and replace the use of emitc.{variable, assign} in SCFToEmitC with
memref.{alloca, load, store}, leaving it to the memref dialect lowering
to handle memory allocation and accesses.

@llvmbot
Collaborator

llvmbot commented May 19, 2024

@llvm/pr-subscribers-mlir-emitc

@llvm/pr-subscribers-mlir

Author: Gil Rapaport (aniragil)

Changes

Extend MemRefToEmitC to map zero-ranked memrefs into scalar C variables
and replace the use of emitc.{variable, assign} in SCFToEmitC with
memref.{alloca, load, store}, leaving it to the memref dialect lowering
to handle memory allocation and accesses.
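
As a rough sketch of the intended C output (names invented, not taken from the patch), an array memref versus a zero-ranked memref would emit something like:

  // Hypothetical shape of the emitted C code.
  void example(float v, int i, int j) {
    // memref<4x8xf32> maps to !emitc.array<4x8xf32>, i.e. a C array
    // accessed through subscripts.
    float buf[4][8];
    buf[i][j] = v;

    // With this patch, memref<f32> (rank 0) maps to the element type
    // itself, i.e. an ordinary scalar C variable.
    float s;
    s = v;
  }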


Patch is 27.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/92684.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.td (+1-1)
  • (modified) mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp (+29-19)
  • (modified) mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp (+29-11)
  • (modified) mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir (-8)
  • (modified) mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir (+15-3)
  • (modified) mlir/test/Conversion/SCFToEmitC/for.mlir (+27-27)
  • (modified) mlir/test/Conversion/SCFToEmitC/if.mlir (+8-6)
  • (modified) mlir/test/Target/Cpp/for.mlir (+44-26)
  • (modified) mlir/test/Target/Cpp/if.mlir (+14)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index e6d678dc1b12b..18d7abe2d707c 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -976,7 +976,7 @@ def ConvertParallelLoopToGpu : Pass<"convert-parallel-loops-to-gpu"> {
 def SCFToEmitC : Pass<"convert-scf-to-emitc"> {
   let summary = "Convert SCF dialect to EmitC dialect, maintaining structured"
                 " control flow";
-  let dependentDialects = ["emitc::EmitCDialect"];
+  let dependentDialects = ["emitc::EmitCDialect", "memref::MemRefDialect"];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index e0c421741b305..2d2c8f988fdc5 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -90,6 +91,12 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
     if (isa_and_present<UnitAttr>(initialValue))
       initialValue = {};
 
+    // If converted type is a scalar, extract the splatted initial value.
+    if (initialValue && !isa<emitc::ArrayType>(resultTy)) {
+      auto elementsAttr = llvm::cast<ElementsAttr>(initialValue);
+      initialValue = elementsAttr.getSplatValue<Attribute>();
+    }
+
     rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
         op, operands.getSymName(), resultTy, initialValue, externSpecifier,
         staticSpecifier, operands.getConstant());
@@ -116,6 +123,19 @@ struct ConvertGetGlobal final
   }
 };
 
+template <typename T>
+static Value getMemoryAccess(Value memref, Location loc,
+                             typename T::Adaptor operands,
+                             ConversionPatternRewriter &rewriter) {
+  // If MemRef is an array, access location using array subscripts.
+  if (auto arrayValue = dyn_cast<TypedValue<emitc::ArrayType>>(memref))
+    return rewriter.create<emitc::SubscriptOp>(loc, arrayValue,
+                                               operands.getIndices());
+
+  // MemRef is a scalar, access location using variable's name.
+  return memref;
+}
+
 struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -128,20 +148,13 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
       return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
     }
 
-    auto arrayValue =
-        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
-    if (!arrayValue) {
-      return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
-    }
-
-    auto subscript = rewriter.create<emitc::SubscriptOp>(
-        op.getLoc(), arrayValue, operands.getIndices());
-
+    Value lvalue = getMemoryAccess<memref::LoadOp>(
+        operands.getMemref(), op.getLoc(), operands, rewriter);
     auto noInit = emitc::OpaqueAttr::get(getContext(), "");
     auto var =
         rewriter.create<emitc::VariableOp>(op.getLoc(), resultTy, noInit);
 
-    rewriter.create<emitc::AssignOp>(op.getLoc(), var, subscript);
+    rewriter.create<emitc::AssignOp>(op.getLoc(), var, lvalue);
     rewriter.replaceOp(op, var);
     return success();
   }
@@ -153,15 +166,10 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
   LogicalResult
   matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
                   ConversionPatternRewriter &rewriter) const override {
-    auto arrayValue =
-        dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
-    if (!arrayValue) {
-      return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
-    }
+    Value lvalue = getMemoryAccess<memref::StoreOp>(
+        operands.getMemref(), op.getLoc(), operands, rewriter);
 
-    auto subscript = rewriter.create<emitc::SubscriptOp>(
-        op.getLoc(), arrayValue, operands.getIndices());
-    rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
+    rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, lvalue,
                                                  operands.getValue());
     return success();
   }
@@ -172,13 +180,15 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
   typeConverter.addConversion(
       [&](MemRefType memRefType) -> std::optional<Type> {
         if (!memRefType.hasStaticShape() ||
-            !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0) {
+            !memRefType.getLayout().isIdentity()) {
           return {};
         }
         Type convertedElementType =
             typeConverter.convertType(memRefType.getElementType());
         if (!convertedElementType)
           return {};
+        if (memRefType.getRank() == 0)
+          return convertedElementType;
         return emitc::ArrayType::get(memRefType.getShape(),
                                      convertedElementType);
       });
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
index 367142a520742..e4aa10f4cf208 100644
--- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -14,11 +14,13 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/Passes.h"
@@ -63,21 +65,31 @@ static SmallVector<Value> createVariablesForResults(T op,
 
   for (OpResult result : op.getResults()) {
     Type resultType = result.getType();
-    emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
-    emitc::VariableOp var =
-        rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
+    SmallVector<OpFoldResult> dimensions; // Zero rank for scalar memref.
+    memref::AllocaOp var =
+        rewriter.create<memref::AllocaOp>(loc, dimensions, resultType);
     resultVariables.push_back(var);
   }
 
   return resultVariables;
 }
 
-// Create a series of assign ops assigning given values to given variables at
+// Create a series of load ops reading the values of given variables at
+// the current insertion point of given rewriter.
+static SmallVector<Value> readValues(SmallVector<Value> &variables,
+                                     PatternRewriter &rewriter, Location loc) {
+  SmallVector<Value> values;
+  for (Value var : variables)
+    values.push_back(rewriter.create<memref::LoadOp>(loc, var).getResult());
+  return values;
+}
+
+// Create a series of store ops assigning given values to given variables at
 // the current insertion point of given rewriter.
 static void assignValues(ValueRange values, SmallVector<Value> &variables,
                          PatternRewriter &rewriter, Location loc) {
   for (auto [value, var] : llvm::zip(values, variables))
-    rewriter.create<emitc::AssignOp>(loc, var, value);
+    rewriter.create<memref::StoreOp>(loc, value, var);
 }
 
 static void lowerYield(SmallVector<Value> &resultVariables,
@@ -100,8 +112,6 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
 
   // Create an emitc::variable op for each result. These variables will be
   // assigned to by emitc::assign ops within the loop body.
-  SmallVector<Value> resultVariables =
-      createVariablesForResults(forOp, rewriter);
   SmallVector<Value> iterArgsVariables =
       createVariablesForResults(forOp, rewriter);
 
@@ -115,18 +125,25 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
   // Erase the auto-generated terminator for the lowered for op.
   rewriter.eraseOp(loweredBody->getTerminator());
 
+  IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint();
+  rewriter.setInsertionPointToEnd(loweredBody);
+  SmallVector<Value> iterArgsValues =
+      readValues(iterArgsVariables, rewriter, loc);
+  rewriter.restoreInsertionPoint(ip);
+
   SmallVector<Value> replacingValues;
   replacingValues.push_back(loweredFor.getInductionVar());
-  replacingValues.append(iterArgsVariables.begin(), iterArgsVariables.end());
+  replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
 
   rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
   lowerYield(iterArgsVariables, rewriter,
              cast<scf::YieldOp>(loweredBody->getTerminator()));
 
   // Copy iterArgs into results after the for loop.
-  assignValues(iterArgsVariables, resultVariables, rewriter, loc);
+  SmallVector<Value> resultValues =
+      readValues(iterArgsVariables, rewriter, loc);
 
-  rewriter.replaceOp(forOp, resultVariables);
+  rewriter.replaceOp(forOp, resultValues);
   return success();
 }
 
@@ -169,6 +186,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
 
   auto loweredIf =
       rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false);
+  SmallVector<Value> resultValues = readValues(resultVariables, rewriter, loc);
 
   Region &loweredThenRegion = loweredIf.getThenRegion();
   lowerRegion(thenRegion, loweredThenRegion);
@@ -178,7 +196,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
     lowerRegion(elseRegion, loweredElseRegion);
   }
 
-  rewriter.replaceOp(ifOp, resultVariables);
+  rewriter.replaceOp(ifOp, resultValues);
   return success();
 }
 
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
index 89dafa7529ed5..0cb33034680d8 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
@@ -33,13 +33,5 @@ func.func @non_identity_layout() {
 
 // -----
 
-func.func @zero_rank() {
-  // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
-  %0 = memref.alloca() : memref<f32>
-  return
-}
-
-// -----
-
 // expected-error@+1 {{failed to legalize operation 'memref.global'}}
 memref.global "nested" constant @nested_global : memref<3x7xf32>
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index bc40ef48268eb..aafaf9810711b 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -4,11 +4,15 @@
 // CHECK-SAME:  %[[v:.*]]: f32, %[[i:.*]]: index, %[[j:.*]]: index
 func.func @memref_store(%v : f32, %i: index, %j: index) {
   // CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
+  // CHECK: %[[SCALAR:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
   %0 = memref.alloca() : memref<4x8xf32>
+  %s = memref.alloca() : memref<f32>
 
   // CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, index, index) -> f32
   // CHECK: emitc.assign %[[v]] : f32 to %[[SUBSCRIPT:.*]] : f32
   memref.store %v, %0[%i, %j] : memref<4x8xf32>
+  // CHECK: emitc.assign %[[v]] : f32 to %[[SCALAR]] : f32
+  memref.store %v, %s[] : memref<f32>
   return
 }
 
@@ -16,16 +20,21 @@ func.func @memref_store(%v : f32, %i: index, %j: index) {
 
 // CHECK-LABEL: memref_load
 // CHECK-SAME:  %[[i:.*]]: index, %[[j:.*]]: index
-func.func @memref_load(%i: index, %j: index) -> f32 {
+func.func @memref_load(%i: index, %j: index) -> (f32, f32) {
   // CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
   %0 = memref.alloca() : memref<4x8xf32>
+  // CHECK: %[[SCALAR:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+  %s = memref.alloca() : memref<f32>
 
   // CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, index, index) -> f32
   // CHECK: %[[VAR:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
   // CHECK: emitc.assign %[[LOAD]] : f32 to %[[VAR]] : f32
   %1 = memref.load %0[%i, %j] : memref<4x8xf32>
-  // CHECK: return %[[VAR]] : f32
-  return %1 : f32
+  // CHECK: %[[VAR_S:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
+  // CHECK: emitc.assign %[[SCALAR]] : f32 to %[[VAR_S]] : f32
+  %sv = memref.load %s[] : memref<f32>
+  // CHECK: return %[[VAR]], %[[VAR_S]] : f32, f32
+  return %1, %sv : f32, f32
 }
 
 // -----
@@ -38,10 +47,13 @@ module @globals {
   // CHECK: emitc.global extern @public_global : !emitc.array<3x7xf32>
   memref.global @uninitialized_global : memref<3x7xf32> = uninitialized
   // CHECK: emitc.global extern @uninitialized_global : !emitc.array<3x7xf32>
+  memref.global "private" constant @internal_global_scalar : memref<f32> = dense<4.0>
+  // CHECK: emitc.global static const @internal_global_scalar : f32 = 4.000000e+00
 
   func.func @use_global() {
     // CHECK: emitc.get_global @public_global : !emitc.array<3x7xf32>
     %0 = memref.get_global @public_global : memref<3x7xf32>
+    %1 = memref.get_global @internal_global_scalar : memref<f32>
     return
   }
 }
diff --git a/mlir/test/Conversion/SCFToEmitC/for.mlir b/mlir/test/Conversion/SCFToEmitC/for.mlir
index 7f90310af2189..d4e211b7a5950 100644
--- a/mlir/test/Conversion/SCFToEmitC/for.mlir
+++ b/mlir/test/Conversion/SCFToEmitC/for.mlir
@@ -47,20 +47,20 @@ func.func @for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> (f32, f32)
 // CHECK-SAME:      %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) -> (f32, f32) {
 // CHECK-NEXT:    %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-NEXT:    %[[VAL_4:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-NEXT:    %[[VAL_5:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
-// CHECK-NEXT:    %[[VAL_6:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
-// CHECK-NEXT:    %[[VAL_7:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
-// CHECK-NEXT:    %[[VAL_8:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
-// CHECK-NEXT:    emitc.assign %[[VAL_3]] : f32 to %[[VAL_7]] : f32
-// CHECK-NEXT:    emitc.assign %[[VAL_4]] : f32 to %[[VAL_8]] : f32
-// CHECK-NEXT:    emitc.for %[[VAL_9:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
-// CHECK-NEXT:      %[[VAL_10:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32
-// CHECK-NEXT:      emitc.assign %[[VAL_10]] : f32 to %[[VAL_7]] : f32
-// CHECK-NEXT:      emitc.assign %[[VAL_10]] : f32 to %[[VAL_8]] : f32
+// CHECK-NEXT:    %[[VAL_5:.*]] = memref.alloca() : memref<f32>
+// CHECK-NEXT:    %[[VAL_6:.*]] = memref.alloca() : memref<f32>
+// CHECK-NEXT:    memref.store %[[VAL_3]], %[[VAL_5]][] : memref<f32>
+// CHECK-NEXT:    memref.store %[[VAL_4]], %[[VAL_6]][] : memref<f32>
+// CHECK-NEXT:    emitc.for %[[VAL_7:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-NEXT:      %[[VAL_8:.*]] = memref.load %[[VAL_5]][] : memref<f32>
+// CHECK-NEXT:      %[[VAL_9:.*]] = memref.load %[[VAL_6]][] : memref<f32>
+// CHECK-NEXT:      %[[VAL_10:.*]] = arith.addf %[[VAL_8]], %[[VAL_9]] : f32
+// CHECK-NEXT:      memref.store %[[VAL_10]], %[[VAL_5]][] : memref<f32>
+// CHECK-NEXT:      memref.store %[[VAL_10]], %[[VAL_6]][] : memref<f32>
 // CHECK-NEXT:    }
-// CHECK-NEXT:    emitc.assign %[[VAL_7]] : f32 to %[[VAL_5]] : f32
-// CHECK-NEXT:    emitc.assign %[[VAL_8]] : f32 to %[[VAL_6]] : f32
-// CHECK-NEXT:    return %[[VAL_5]], %[[VAL_6]] : f32, f32
+// CHECK-NEXT:    %[[VAL_11:.*]] = memref.load %[[VAL_5]][] : memref<f32>
+// CHECK-NEXT:    %[[VAL_12:.*]] = memref.load %[[VAL_6]][] : memref<f32>
+// CHECK-NEXT:    return %[[VAL_11]], %[[VAL_12]] : f32, f32
 // CHECK-NEXT:  }
 
 func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 {
@@ -77,20 +77,20 @@ func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32
 // CHECK-LABEL: func.func @nested_for_yield(
 // CHECK-SAME:      %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) -> f32 {
 // CHECK-NEXT:    %[[VAL_3:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-NEXT:    %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
-// CHECK-NEXT:    %[[VAL_5:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
-// CHECK-NEXT:    emitc.assign %[[VAL_3]] : f32 to %[[VAL_5]] : f32
-// CHECK-NEXT:    emitc.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
-// CHECK-NEXT:      %[[VAL_7:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
-// CHECK-NEXT:      %[[VAL_8:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
-// CHECK-NEXT:      emitc.assign %[[VAL_5]] : f32 to %[[VAL_8]] : f32
-// CHECK-NEXT:      emitc.for %[[VAL_9:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
-// CHECK-NEXT:        %[[VAL_10:.*]] = arith.addf %[[VAL_8]], %[[VAL_8]] : f32
-// CHECK-NEXT:        emitc.assign %[[VAL_10]] : f32 to %[[VAL_8]] : f32
+// CHECK-NEXT:    %[[VAL_4:.*]] = memref.alloca() : memref<f32>
+// CHECK-NEXT:    memref.store %[[VAL_3]], %[[VAL_4]][] : memref<f32>
+// CHECK-NEXT:    emitc.for %[[VAL_5:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-NEXT:      %[[VAL_6:.*]] = memref.load %[[VAL_4]][] : memref<f32>
+// CHECK-NEXT:      %[[VAL_7:.*]] = memref.alloca() : memref<f32>
+// CHECK-NEXT:      memref.store %[[VAL_6]], %[[VAL_7]][] : memref<f32>
+// CHECK-NEXT:      emitc.for %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK-NEXT:        %[[VAL_9:.*]] = memref.load %[[VAL_7]][] : memref<f32>
+// CHECK-NEXT:        %[[VAL_10:.*]] = arith.addf %[[VAL_9]], %[[VAL_9]] : f32
+// CHECK-NEXT:        memref.store %[[VAL_10]], %[[VAL_7]][] : memref<f32>
 // CHECK-NEXT:      }
-// CHECK-NEXT:      emitc.assign %[[VAL_8]] : f32 to %[[VAL_7]] : f32
-// CHECK-NEXT:      emitc.assign %[[VAL_7]] : f32 to %[[VAL_5]] : f32
+// CHECK-NEXT:      %[[VAL_11:.*]] = memref.load %[[VAL_7]][] : memref<f32>
+// CHECK-NEXT:      memref.store %[[VAL_11]], %[[VAL_4]][] : memref<f32>
 // CHECK-NEXT:    }
-// CHECK-NEXT:    emitc.assign %[[VAL_5]] : f32 to %[[VAL_4]] : f32
-// CHECK-NEXT:    return %[[VAL_4]] : f32
+// CHECK-NEXT:    %[[VAL_12:.*]] = memref.load %[[VAL_4]][] : memref<f32>
+// CHECK-NEXT:    return %[[VAL_12]] : f32
 // CHECK-NEXT:  }
diff --git a/mlir/test/Conversion/SCFToEmitC/if.mlir b/mlir/test/Conversion/SCFToEmitC/if.mlir
index afc9abc761eb4..0753d9eda3283 100644
--- a/mlir/test/Conversion/SCFToEmitC/if.mlir
+++ b/mlir/test/Conversion/SCFToEmitC/if.mlir
@@ -53,18 +53,20 @@ func.func @test_if_yield(%arg0: i1, %arg1: f32) {
 // CHECK-SAME:                           %[[VAL_0:.*]]: i1,
 // CHECK-SAME:                           %[[VAL_1:.*]]: f32) {
 // CHECK-NEXT:    %[[VAL_2:.*]] = arith.constant 0 : i8
-// CHECK-NEXT:    %[[VAL_3:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
-// CHECK-NEXT:    %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f64
+// CHECK-NEXT:    %[[VAL_3:.*]] = memref.alloca() : memref<i32>
+// CHECK-NEXT:    %[[VAL_4:.*]] = memref.alloca() : memref<f64>
 // CHECK-NEXT:    emitc.if %[[VAL_0]] {
 // CHECK-NEXT:      %[[VAL_5:.*]] = emitc.call_opaque "func_true_1"(%[[VAL_1]]) : (f32) -> i32
 // CHECK-NEXT:      %[[VAL_6:.*]] = emitc.call_opaque "func_true_2"(%[[VAL_1]]) : (f32) -> f64
-// CHECK-NEXT:      emitc.assign %[[VAL_5]] : i32 to %[[VAL_3]] : i32
-// CHECK-NEXT:      emitc.assign %[[VAL_6]] : f64 to %[[VAL_4]] : f64
+// CHECK-NEXT:      memref.store %[[VAL_5]], %[[VAL_3]][] : memref<i32>
+// CHECK-NEXT:      memref.store %[[VAL_6]], %[[VAL_4]][] : memref<f64>
 // CHECK-NEXT:    } else {
 // CHECK-NEXT:      %[[VAL_7:.*]] = emitc.call_opaque "func_false_1"(%[[VAL_1]]) : (f32) -> i32
 // CHECK-NEXT:      %[[VAL_8:.*]] = emitc.call_opaque "func_false_2"(%[[VAL_1]]) : (f32) -> f64
-// CHECK-NEXT:      emitc.assign %[[VAL_7]] : i32 to %[[VAL_3]] : i32
-// CHECK-NEXT:      emitc.assign %[[VAL_8]] : f64 to %[[VAL_4]] : f64
+// CHECK-NEXT:      memref.store %[[VAL_7]], %[[VAL_3]][] : memref<i32>
+// CHECK-NEXT: ...
[truncated]

@aniragil
Contributor Author

This patch hopefully simplifies #91475 by consolidating the memory semantics of the SCF lowering into the memref dialect.

@mgehre-amd
Contributor

mgehre-amd commented May 21, 2024

I'm not convinced that this is the canonical way to handle rank-0 memrefs. For example, when you have a rank-0 memref as a function argument and the function body writes into it, you cannot turn the argument into a scalar.

There are alternative ways to lower this (e.g. promoting rank-0 memrefs to rank-1, or turning rank-0 memrefs into pointers to scalars), so this doesn't seem to be the canonical choice.
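
As a rough C-level sketch of these alternatives (names invented, not from the comment):

  // Alternative 1: promote rank-0 to rank-1, i.e. a single-element array.
  void promoted_to_rank1(float v) {
    float s[1];
    s[0] = v;
  }

  // Alternative 2: model the rank-0 memref as a pointer to a scalar.
  void pointer_to_scalar(float *s, float v) {
    *s = v;
  }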

Converting rank-0 memrefs to scalars sounds like a transform that is not specific to emitc, so I'm more in favor of having this as a separate transform (in the memref dialect) instead of merging it into the memref-to-emitc conversion pass.

@aniragil
Contributor Author

I'm not convinced that this is the canonical way to handle rank-0 memrefs. For example, when you have a rank-0 memref as a function argument and the function body writes into it, you cannot turn the argument into a scalar.

Good catch! Function arguments completely slipped my mind. We could perhaps temporarily exclude them using the signature conversion mechanism, but it's better to have a complete design first. For now we can use rank-1 memrefs in the SCF lowering pass.

While passing variables by reference in C requires converting them to pointers or arrays, C++ has native support for it. Both representations differ from the way C/C++ programmers define and use scalars within functions, and it would be good to emit natural C/C++ code where possible. There are also other cases where different C variants have distinct natural constructs, e.g. TupleType (lowers to std::tuple in C++ but requires a struct in C), VectorType (OpenCL C, CUDA, GCC vector extensions) and address spaces (OpenCL C, CUDA).
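
A rough sketch of that difference (names invented): a writable rank-0 memref argument needs a by-reference form, which looks different in C and C++ and in both cases differs from a plain local scalar.

  // C: the argument must become a pointer (or single-element array) so
  // that stores are visible to the caller.
  void write_c(float *out, float v) { *out = v; }

  // C++: a reference expresses the same thing without pointer syntax.
  void write_cpp(float &out, float v) { out = v; }

  // Neither matches how a local scalar is declared and used in a function:
  void local_scalar(float v) {
    float s = v;
    (void)s;
  }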

There are alternative ways to lower this (e.g. promoting rank-0 memrefs to rank-1, or turning rank-0 memrefs into pointers to scalars)

Modeling variables consistently using !emitc.ptr was actually proposed as "Solution 1" in the variables RFC. This patch aimed to follow Simon's suggestion of using natural C variable syntax.
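
A rough sketch of the two styles in emitted C (names invented, not from the RFC):

  // Solution 1: every variable modeled through !emitc.ptr, so every access
  // goes through a dereference.
  void pointer_style(float v) {
    float storage;
    float *x = &storage;
    *x = v;
    float y = *x + 1.0f;
    (void)y;
  }

  // Natural C variable syntax, which this patch aims to preserve.
  void variable_style(float v) {
    float x;
    x = v;
    float y = x + 1.0f;
    (void)y;
  }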

Converting rank-0 memrefs to scalars sounds like a transform that is not specific to emitc, so I'm more in favor of having this as a separate transform (in the memref dialect) instead of merging it into the memref-to-emitc conversion pass.

Doing Mem2Reg/SROA within the memref dialect where possible is indeed desirable and not limited to rank-0 memrefs. Regardless, rank-0 memrefs model the concept of scalar variables, which also lowers naturally and directly to LLVM IR. For example, the following

memref.global @scalar : memref<f32> = dense<1.0>
memref.global @array : memref<1xf32> = dense<1.0>

gets lowered to

  llvm.mlir.global external @scalar(1.000000e+00 : f32) {addr_space = 0 : i32} : f32
  llvm.mlir.global external @array(dense<1.000000e+00> : tensor<1xf32>) {addr_space = 0 : i32} : !llvm.array<1 x f32>

which translates to

@scalar = global float 1.000000e+00
@array = global [1 x float] [float 1.000000e+00]

aniragil marked this pull request as draft May 26, 2024 08:11
aniragil requested a review from marbre May 26, 2024 08:11
@aniragil
Contributor Author

SCF if/for lowering using 1-D memrefs is up for review as #93371.
Turned this PR into a draft to serve as a basis for discussion.
