Use i32 indices when possible.
Also enable MLIR loop fusions pre-Ampere. It looks like something very low in the stack doesn't handle 64-bit indices; it works fine with 32-bit ones.

PiperOrigin-RevId: 609363660
jreiffers authored and tensorflower-gardener committed Feb 22, 2024
1 parent 4d1da7b commit 5a79a18
Showing 6 changed files with 79 additions and 30 deletions.
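
The gist of the change: choose the GEP index type based on how large the indexed buffer can get. A minimal standalone sketch of that decision (the enum and helper below are hypothetical illustrations, not the XLA API; the real check lives in CreateGep in lower_tensors.cc below):

#include <cstdint>
#include <limits>

// If every linear index into a buffer of `num_elements` elements fits in a
// signed 32-bit integer, emit i32 GEP indices; otherwise fall back to i64.
enum class IndexWidth { kI32, kI64 };

inline IndexWidth ChooseIndexWidth(int64_t num_elements) {
  return num_elements < std::numeric_limits<int32_t>::max() ? IndexWidth::kI32
                                                            : IndexWidth::kI64;
}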
@@ -811,10 +811,6 @@ bool IsHloConversionSupported(const HloFusionAdaptor& fusion,
}
auto cuda_compute_capability =
std::get<se::CudaComputeCapability>(compute_capability);
if (!cuda_compute_capability.IsAtLeastAmpere()) {
// Not all lowerings work with pre-ampere yet.
return false;
}

if (fusion.GetRoots().size() > 1) {
auto first_shape = fusion.GetRoots()[0].instruction().shape();
10 changes: 7 additions & 3 deletions third_party/xla/xla/service/gpu/fusions/mlir/lower_tensors.cc
@@ -11,6 +11,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <limits>
#include <memory>
#include <utility>

@@ -121,9 +122,12 @@ mlir::Value CreateGep(mlir::Operation* op,
rewriter.setInsertionPoint(op);
mlir::Value index = rewriter.create<mlir::affine::AffineApplyOp>(
tensor.getLoc(), linearize_map, indices);
// TODO(jreiffers): Use i32 if the index is sufficiently small.
index = rewriter.create<mlir::arith::IndexCastUIOp>(
tensor.getLoc(), rewriter.getI64Type(), index);
auto index_ty =
ShapeUtil::ElementsIn(byte_shape) < std::numeric_limits<int32_t>::max()
? rewriter.getI32Type()
: rewriter.getI64Type();
index = rewriter.create<mlir::arith::IndexCastUIOp>(tensor.getLoc(), index_ty,
index);

auto tensor_ptr = rewriter
.create<mlir::UnrealizedConversionCastOp>(
@@ -165,9 +165,10 @@ TEST_F(MlirFusionEmitterTest, CreateLLVMModule) {
// CHECK: define void @fusion(ptr noalias %[[IN:.*]], ptr noalias %[[OUT:.*]])
// CHECK: %[[TID:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
// CHECK: %[[EXT:.*]] = sext i32 %[[TID]] to i64
// CHECK: %[[IN_PTR:.*]] = getelementptr inbounds float, ptr %[[IN]], i64 %[[EXT]]
// CHECK: %[[TRUNC:.*]] = trunc i64 %[[EXT]] to i32
// CHECK: %[[IN_PTR:.*]] = getelementptr inbounds float, ptr %[[IN]], i32 %[[TRUNC]]
// CHECK: %[[VAL:.*]] = load float, ptr %[[IN_PTR]], align 4
// CHECK: %[[OUT_PTR:.*]] = getelementptr inbounds float, ptr %[[OUT]], i64 %[[EXT]]
// CHECK: %[[OUT_PTR:.*]] = getelementptr inbounds float, ptr %[[OUT]], i32 %[[TRUNC]]
// CHECK: store float %[[VAL]], ptr %[[OUT_PTR]], align 4
// CHECK: ret void
)"));
27 changes: 20 additions & 7 deletions third_party/xla/xla/service/gpu/fusions/mlir/simplify_affine.cc
@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <optional>
#include <utility>
@@ -118,7 +120,11 @@ struct RewriteAffineApply

RangeEvaluator range_evaluator(dim_ranges, symbol_ranges, op->getContext());
std::function<bool(mlir::AffineExpr)> can_be_lowered;
bool fits_32_bits = true;
can_be_lowered = [&](mlir::AffineExpr expr) {
auto range = range_evaluator.ComputeExpressionRange(expr);
fits_32_bits &= range.upper_bound < std::numeric_limits<int32_t>::max();

auto bin_op = llvm::dyn_cast<mlir::AffineBinaryOpExpr>(expr);
if (!bin_op) {
return true;
@@ -145,9 +151,11 @@
return rewriter.notifyMatchFailure(op,
"unable to lower the affine apply");
}

std::function<mlir::Value(mlir::AffineExpr)> lower;

mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
auto int_ty = fits_32_bits ? b.getI32Type() : b.getI64Type();
b.setInsertionPoint(op);
lower = [&](mlir::AffineExpr expr) -> mlir::Value {
if (auto bin_op = mlir::dyn_cast<mlir::AffineBinaryOpExpr>(expr)) {
@@ -169,20 +177,25 @@

switch (expr.getKind()) {
case mlir::AffineExprKind::Constant:
return b.create<mlir::arith::ConstantIndexOp>(
mlir::cast<mlir::AffineConstantExpr>(expr).getValue());
return b.create<mlir::arith::ConstantIntOp>(
mlir::cast<mlir::AffineConstantExpr>(expr).getValue(), int_ty);
case mlir::AffineExprKind::DimId:
return op.getDimOperands()[mlir::cast<mlir::AffineDimExpr>(expr)
.getPosition()];
return b.create<mlir::arith::IndexCastUIOp>(
int_ty, op.getDimOperands()[mlir::cast<mlir::AffineDimExpr>(expr)
.getPosition()]);
case mlir::AffineExprKind::SymbolId:
return op.getSymbolOperands()[mlir::cast<mlir::AffineSymbolExpr>(expr)
.getPosition()];
return b.create<mlir::arith::IndexCastUIOp>(
int_ty,
op.getSymbolOperands()[mlir::cast<mlir::AffineSymbolExpr>(expr)
.getPosition()]);
default:
ABSL_UNREACHABLE();
}
};

rewriter.replaceOp(op, lower(map.GetAffineMap().getResult(0)));
auto result = lower(map.GetAffineMap().getResult(0));
rewriter.replaceOp(
op, b.create<mlir::arith::IndexCastUIOp>(b.getIndexType(), result));
return mlir::success();
}
};
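
In the lowering above, the switch to i32 is gated on range analysis: every subexpression's upper bound (computed by RangeEvaluator) has to stay below the i32 maximum before 32-bit arithmetic is emitted. A rough standalone sketch of that predicate (FitsIn32Bits is a hypothetical name, not part of MLIR or XLA):

#include <cstdint>
#include <initializer_list>
#include <limits>

// Given the inclusive upper bounds computed for the subexpressions of an
// affine map, decide whether the whole lowering can use 32-bit arithmetic.
inline bool FitsIn32Bits(std::initializer_list<int64_t> upper_bounds) {
  for (int64_t ub : upper_bounds) {
    if (ub >= std::numeric_limits<int32_t>::max()) return false;
  }
  return true;
}

// Example, matching the @arg_ranges test further down: for
// (s0 floordiv 100 + s1 floordiv 100) with s0 in [0, 42] and s1 in [0, 1000],
// the subexpression upper bounds are 0, 10 and 10, so i32 suffices and the
// lowering emits arith.divui / arith.addi on i32.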
@@ -35,7 +35,7 @@ module {
// CHECK: func.func @tensorarg(%[[ARG0:.*]]: !llvm.ptr
// CHECK-SAME: {xla.invariant, xla.slice_index = 0 : i64}, %[[ARG1:.*]]: index) -> f32 {
// CHECK-DAG: %[[C2:.*]] = arith.constant 2.000000e+00
// CHECK-DAG: %[[IDX:.*]] = arith.index_castui %[[ARG1]] : index to i64
// CHECK-DAG: %[[IDX:.*]] = arith.index_castui %[[ARG1]] : index to i32
// CHECK-DAG: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[IDX]]]
// CHECK-DAG: %[[V2:.*]] = llvm.load %[[PTR]] invariant
// CHECK: %[[RET:.*]] = call @add(%[[C2]], %[[V2]])
@@ -72,7 +72,7 @@ module {
// CHECK: @layout(%[[ARG0:.*]]: !llvm.ptr,
// CHECK-SAME: %[[X:.*]]: index, %[[Y:.*]]: index
// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]](%[[X]], %[[Y]])
// CHECK: %[[IDX_CAST:.*]] = arith.index_castui %[[IDX]]
// CHECK: %[[IDX_CAST:.*]] = arith.index_castui %[[IDX]] : index to i32
// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[IDX_CAST]]]
// CHECK: llvm.load %[[PTR]]

@@ -110,11 +110,25 @@ module {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
// CHECK: %[[CAST:.*]] = arith.index_castui %[[I]]
// CHECK: %[[CAST:.*]] = arith.index_castui %[[I]] : index to i32
// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[CAST]]]
// CHECK: llvm.store {{.*}}, %[[PTR]]
// CHECK: %[[INBOUNDS:.*]] = arith.cmpi
// CHECK: scf.if %[[INBOUNDS]] {
// CHECK: llvm.store
// CHECK-NEXT: }
// CHECK-NEXT: return

// -----

module {
func.func @large_tensor(
%arg0: tensor<1024x1024x1024x6xf32>,
%arg1: index) -> f32 {
%v = tensor.extract %arg0[%arg1, %arg1, %arg1, %arg1] : tensor<1024x1024x1024x6xf32>
func.return %v : f32
}
}

// CHECK: @large_tensor
// CHECK: arith.index_castui {{.*}} : index to i64
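
The new @large_tensor test above pins down the fallback path: this tensor's element count exceeds the i32 range, so the index cast has to stay at i64. The arithmetic, as a standalone check:

#include <cstdint>

// tensor<1024x1024x1024x6xf32> holds 1024 * 1024 * 1024 * 6 elements, which
// is larger than the maximum value of a signed 32-bit integer, so 64-bit
// indices are required.
static_assert(1024LL * 1024 * 1024 * 6 == 6442450944LL, "element count");
static_assert(6442450944LL > int64_t{INT32_MAX}, "exceeds i32 range");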
@@ -28,34 +28,55 @@ module {
// CHECK-DAG: %[[TID_X:.*]] = gpu.thread_id x
// CHECK-DAG: %[[BID_X:.*]] = gpu.block_id x
// CHECK: scf.for %[[I:.*]] =
// CHECK: %[[BLOCK_OFFSET:.*]] = arith.muli %[[BID_X]], %[[C512]]
// CHECK: %[[THREAD_OFFSET:.*]] = arith.muli %[[TID_X]], %[[C4]]
// CHECK: %[[BID_32:.*]] = arith.index_castui %[[BID_X]] : index to i32
// CHECK: %[[BLOCK_OFFSET:.*]] = arith.muli %[[BID_32]], %[[C512]]
// CHECK: %[[TID_32:.*]] = arith.index_castui %[[TID_X]] : index to i32
// CHECK: %[[THREAD_OFFSET:.*]] = arith.muli %[[TID_32]], %[[C4]]
// CHECK: %[[OFFSET:.*]] = arith.addi %[[BLOCK_OFFSET]], %[[THREAD_OFFSET]]
// CHECK: %[[UNROLL_OFFSET:.*]] = arith.addi %[[OFFSET]], %[[I]]
// CHECK: arith.index_castui %[[UNROLL_OFFSET]] : index to i64
// CHECK: %[[I_32:.*]] = arith.index_castui %[[I]] : index to i32
// CHECK: %[[UNROLL_OFFSET:.*]] = arith.addi %[[OFFSET]], %[[I_32]]
// CHECK: %[[UNROLL_INDEX:.*]] = arith.index_castui %[[UNROLL_OFFSET]] : i32 to index
// CHECK: arith.index_castui %[[UNROLL_INDEX]] : index to i64

// -----

module {
func.func @arg_ranges(%arg0: index {xla.range = [0 : index, 42 : index]}, %arg1: index {xla.range = [0 : index, 1000 : index]}) -> index {
%0 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)>()[%arg0, %arg1]
return %0 : index
}
}

// CHECK: @arg_ranges
// CHECK-NEXT: %[[C100:.*]] = arith.constant 100
// CHECK-NEXT: %[[RET:.*]] = arith.divui %{{.*}}, %[[C100]]
// CHECK-NEXT: %[[ARG0_32:.*]] = arith.index_castui {{.*}} : index to i32
// CHECK-NEXT: %[[RET_32:.*]] = arith.divui %[[ARG0_32]], %[[C100]]
// CHECK-NEXT: %[[RET:.*]] = arith.index_castui %[[RET_32]] : i32 to index
// CHECK-NEXT: return %[[RET]]


// -----

module {
func.func @arg_ranges(%arg0: index {xla.range = [0 : index, 42 : index]}, %arg1: index {xla.range = [0 : index, 1000 : index]}) -> index {
%0 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)>()[%arg0, %arg1]
func.func @needs_i64(%arg0: index {xla.range = [0 : index, 1000000000000 : index]}, %arg1: index {xla.range = [0 : index, 10 : index]}) -> index {
%0 = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%arg0, %arg1]
return %0 : index
}
}

// -----
// CHECK: @needs_i64
// CHECK: arith.index_castui {{.*}} : index to i64
// CHECK: arith.index_castui {{.*}} : index to i64
// CHECK: arith.index_castui {{.*}} : i64 to index

// CHECK: @cant_lower
// CHECK: affine.apply
// -----

module {
func.func @cant_lower(%arg0: index {xla.range = [-10 : index, 42 : index]}, %arg1: index {xla.range = [0 : index, 1000 : index]}) -> index {
%0 = affine.apply affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)>()[%arg0, %arg1]
return %0 : index
}
}

// CHECK: @cant_lower
// CHECK: affine.apply
