[XLA:GPU] Emit matrix-vector multiplication as GemmFusion
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10638 from jaro-sevcik:host-allocation-custom-call 0f78007475921031549c8555ad56ed8903efe6cb
PiperOrigin-RevId: 619001044
anlunx authored and tensorflower-gardener committed Apr 25, 2024
1 parent 612cea9 commit 73c1c2e
Showing 23 changed files with 375 additions and 134 deletions.
2 changes: 1 addition & 1 deletion tensorflow/core/tfrt/saved_model/saved_model_util.cc
@@ -367,7 +367,7 @@ absl::Status DeserializeAoTMlirModule(
}

CallableOptions CombineSignatureDefs(
const proto2::Map<std::string, SignatureDef>& signature_defs) {
const google::protobuf::Map<std::string, SignatureDef>& signature_defs) {
CallableOptions callable_options;
for (const auto& sig_iter : signature_defs) {
const auto& signature_def = sig_iter.second;
2 changes: 1 addition & 1 deletion tensorflow/core/tfrt/saved_model/saved_model_util.h
@@ -142,7 +142,7 @@ absl::Status DeserializeAoTMlirModule(
mlir::OwningOpRef<mlir::ModuleOp>* mlir_module);

CallableOptions CombineSignatureDefs(
const proto2::Map<std::string, SignatureDef>& signature_defs);
const google::protobuf::Map<std::string, SignatureDef>& signature_defs);

void RegisterTfrtDialectsForAot(mlir::DialectRegistry& registry);

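The only change to CombineSignatureDefs is spelling out google::protobuf::Map in place of the proto2::Map alias; the declaration is otherwise unchanged. A hypothetical caller might look like the sketch below (the enclosing namespace, include paths, function name, and signature key are illustrative assumptions, not part of this commit):

#include <string>

#include "google/protobuf/map.h"
#include "tensorflow/core/protobuf/config.pb.h"      // tensorflow::CallableOptions
#include "tensorflow/core/protobuf/meta_graph.pb.h"  // tensorflow::SignatureDef
#include "tensorflow/core/tfrt/saved_model/saved_model_util.h"

// Builds a protobuf map keyed by signature name and combines the entries
// into a single CallableOptions via the declaration shown above.
tensorflow::CallableOptions CombineAllSignatures() {
  google::protobuf::Map<std::string, tensorflow::SignatureDef> signature_defs;
  signature_defs["serving_default"] = tensorflow::SignatureDef();
  // Namespace assumed from this directory (tensorflow::tfrt_stub).
  return tensorflow::tfrt_stub::CombineSignatureDefs(signature_defs);
}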
75 changes: 37 additions & 38 deletions third_party/stablehlo/temporary.patch
@@ -446,7 +446,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/CMakeLists.txt b/stablehlo
diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp
--- stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp
+++ stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp
@@ -0,0 +1,506 @@
@@ -0,0 +1,505 @@
+/* Copyright 2023 The StableHLO Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
@@ -521,7 +521,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+ // reduce_window_i1
+ SmallVector<ShapedType> inputTypes;
+ for (auto [index, input] : llvm::enumerate(inputs)) {
+ auto inputType = input.getType().dyn_cast<ShapedType>();
+ auto inputType = dyn_cast<ShapedType>(input.getType());
+ inputTypes.push_back(inputType);
+ if (!inputType)
+ return op_.emitError()
@@ -531,7 +531,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+ // reduce_window_i2
+ SmallVector<ShapedType> initValueTypes;
+ for (auto [index, initValue] : llvm::enumerate(initValues)) {
+ auto initValueType = initValue.getType().dyn_cast<ShapedType>();
+ auto initValueType = dyn_cast<ShapedType>(initValue.getType());
+ initValueTypes.push_back(initValueType);
+ if (!initValueType || !initValueType.hasRank() ||
+ initValueType.getRank() != 0)
@@ -543,7 +543,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+ // reduce_window_i3...reduce_window_i7
+ auto checkRank = [&](StringRef name, int64_t index, Value dynamicAttr,
+ int64_t expectedRank) -> LogicalResult {
+ auto type = dynamicAttr.getType().dyn_cast<ShapedType>();
+ auto type = dyn_cast<ShapedType>(dynamicAttr.getType());
+ if (!type || !type.hasRank() || type.getRank() != expectedRank ||
+ !type.getElementType().isIntOrIndex()) {
+ if (index < 0) index += op_->getNumOperands();
@@ -562,7 +562,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+ return failure();
+
+ // reduce_window_i7
+ auto paddingType = getPadding().getType().dyn_cast<ShapedType>();
+ auto paddingType = dyn_cast<ShapedType>(getPadding().getType());
+ if (!paddingType || !paddingType.hasRank() || paddingType.getRank() != 2 ||
+ paddingType.getDimSize(1) != 2 ||
+ !paddingType.getElementType().isIntOrIndex())
@@ -598,7 +598,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+ // verify them in that case, that seems like too much at this point.
+ auto checkShape = [&](StringRef name, int64_t index, Value dynamicAttr,
+ ArrayRef<int64_t> expectedShape) -> LogicalResult {
+ auto type = dynamicAttr.getType().cast<ShapedType>();
+ auto type = cast<ShapedType>(dynamicAttr.getType());
+ if (type.getShape() != expectedShape) {
+ if (index < 0) index += op_->getNumOperands();
+ return op_.emitError()
@@ -622,7 +622,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+ // reduce_window_c13
+ if (op_.getCalledComputations().size() != 1)
+ return op_.emitError() << "expects called_computations to have 1 element";
+ auto bodyAttr = op_.getCalledComputations()[0].cast<FlatSymbolRefAttr>();
+ auto bodyAttr = cast<FlatSymbolRefAttr>(op_.getCalledComputations()[0]);
+ auto bodyFunc =
+ op_->getParentOfType<ModuleOp>().lookupSymbol<func::FuncOp>(bodyAttr);
+ if (!bodyFunc)
@@ -644,7 +644,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+ SmallVector<ShapedType> resultTypes;
+ std::optional<ArrayRef<int64_t>> resultShape;
+ for (auto result : results) {
+ auto resultType = result.getType().dyn_cast<ShapedType>();
+ auto resultType = dyn_cast<ShapedType>(result.getType());
+ resultTypes.push_back(resultType);
+ if (!resultType) return op_.emitError() << "expects results to be tensors";
+
@@ -683,32 +683,32 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+}
+
+TypedValue<ShapedType> DynamicReduceWindowOpAdaptor::getWindowDimensions() {
+ return op_.getInputs()[op_.getInputs().size() - 5]
+ .cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(
+ op_.getInputs()[op_.getInputs().size() - 5]);
+}
+
+TypedValue<ShapedType> DynamicReduceWindowOpAdaptor::getWindowStrides() {
+ return op_.getInputs()[op_.getInputs().size() - 4]
+ .cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(
+ op_.getInputs()[op_.getInputs().size() - 4]);
+}
+
+TypedValue<ShapedType> DynamicReduceWindowOpAdaptor::getBaseDilations() {
+ return op_.getInputs()[op_.getInputs().size() - 3]
+ .cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(
+ op_.getInputs()[op_.getInputs().size() - 3]);
+}
+
+TypedValue<ShapedType> DynamicReduceWindowOpAdaptor::getWindowDilations() {
+ return op_.getInputs()[op_.getInputs().size() - 2]
+ .cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(
+ op_.getInputs()[op_.getInputs().size() - 2]);
+}
+
+TypedValue<ShapedType> DynamicReduceWindowOpAdaptor::getPadding() {
+ return op_.getInputs()[op_.getInputs().size() - 1]
+ .cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(
+ op_.getInputs()[op_.getInputs().size() - 1]);
+}
+
+Region& DynamicReduceWindowOpAdaptor::getBody() {
+ auto bodyAttr = op_.getCalledComputations()[0].cast<FlatSymbolRefAttr>();
+ auto bodyAttr = cast<FlatSymbolRefAttr>(op_.getCalledComputations()[0]);
+ auto bodyFunc =
+ op_->getParentOfType<ModuleOp>().lookupSymbol<func::FuncOp>(bodyAttr);
+ return bodyFunc.getBody();
@@ -758,20 +758,20 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+ auto output = op_.getResults()[1];
+
+ // dynamic_rng_bit_generator_i1
+ if (!rngAlgorithmAttr.isa<RngAlgorithmAttr>())
+ if (!isa<RngAlgorithmAttr>(rngAlgorithmAttr))
+ return op_.emitError()
+ << "expects a #stablehlo<rng_algorithm ...> rng_algorithm";
+
+ // dynamic_rng_bit_generator_i2
+ // TODO(#643): Clarify supported types for RngBitGeneratorOp.
+ auto initialStateType = initialState.getType().dyn_cast<ShapedType>();
+ auto initialStateType = dyn_cast<ShapedType>(initialState.getType());
+ if (!initialStateType || !initialStateType.getElementType().isIntOrFloat())
+ return op_.emitError()
+ << "expects initial_state (operand #0) "
+ << "to be a tensor of integer or floating-point type";
+
+ // dynamic_rng_bit_generator_i3
+ auto outputShapeType = outputShape.getType().dyn_cast<ShapedType>();
+ auto outputShapeType = dyn_cast<ShapedType>(outputShape.getType());
+ if (!outputShapeType || !outputShapeType.hasRank() ||
+ outputShapeType.getRank() != 1 ||
+ !outputShapeType.getElementType().isIntOrIndex())
@@ -781,14 +781,14 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+
+ // dynamic_rng_bit_generator_o1
+ // TODO(#643): Clarify supported types for RngBitGeneratorOp.
+ auto outputStateType = outputState.getType().dyn_cast<ShapedType>();
+ auto outputStateType = dyn_cast<ShapedType>(outputState.getType());
+ if (!outputStateType || !outputStateType.getElementType().isIntOrFloat())
+ return op_.emitError()
+ << "expects output_state (result #0) "
+ << "to be a tensor of integer or floating-point type";
+
+ // dynamic_rng_bit_generator_o2
+ auto outputType = output.getType().dyn_cast<ShapedType>();
+ auto outputType = dyn_cast<ShapedType>(output.getType());
+ if (!outputType || !outputType.getElementType().isIntOrFloat())
+ return op_.emitError()
+ << "expects output (result #1) "
@@ -812,25 +812,24 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+}
+
+RngAlgorithm DynamicRngBitGeneratorOpAdaptor::getRngAlgorithm() {
+ return op_->getDiscardableAttr("rng_algorithm")
+ .cast<RngAlgorithmAttr>()
+ return cast<RngAlgorithmAttr>(op_->getDiscardableAttr("rng_algorithm"))
+ .getValue();
+}
+
+TypedValue<ShapedType> DynamicRngBitGeneratorOpAdaptor::getInitialState() {
+ return op_.getInputs()[0].cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(op_.getInputs()[0]);
+}
+
+TypedValue<ShapedType> DynamicRngBitGeneratorOpAdaptor::getOutputShape() {
+ return op_.getInputs()[1].cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(op_.getInputs()[1]);
+}
+
+TypedValue<ShapedType> DynamicRngBitGeneratorOpAdaptor::getOutputState() {
+ return op_.getResults()[0].cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(op_.getResults()[0]);
+}
+
+TypedValue<ShapedType> DynamicRngBitGeneratorOpAdaptor::getOutput() {
+ return op_.getResults()[1].cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(op_.getResults()[1]);
+}
+
+std::optional<DynamicRngBitGeneratorOpAdaptor> getDynamicRngBitGeneratorOp(
@@ -864,7 +863,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+ auto indices = op_.getResults()[1];
+
+ // dynamic_top_k_i1
+ auto operandType = operand.getType().dyn_cast<ShapedType>();
+ auto operandType = dyn_cast<ShapedType>(operand.getType());
+ if (!operandType || !operandType.hasRank() || operandType.getRank() < 1 ||
+ !operandType.getElementType().isIntOrFloat())
+ return op_.emitError()
@@ -873,15 +872,15 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+ << "of rank at least 1";
+
+ // dynamic_top_k_i2
+ auto kType = k.getType().dyn_cast<ShapedType>();
+ auto kType = dyn_cast<ShapedType>(k.getType());
+ if (!kType || !kType.hasRank() || kType.getRank() != 0 ||
+ !kType.getElementType().isIntOrIndex())
+ return op_.emitError()
+ << "expects k (operand #1) "
+ << "to be a 0-dimensional tensor of integer or index type";
+
+ // dynamic_top_k_o1
+ auto valuesType = values.getType().dyn_cast<ShapedType>();
+ auto valuesType = dyn_cast<ShapedType>(values.getType());
+ if (!valuesType || !valuesType.hasRank() || valuesType.getRank() < 1 ||
+ !valuesType.getElementType().isIntOrFloat())
+ return op_.emitError()
@@ -890,7 +889,7 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+ << "of rank at least 1";
+
+ // dynamic_top_k_o2
+ auto indicesType = indices.getType().dyn_cast<ShapedType>();
+ auto indicesType = dyn_cast<ShapedType>(indices.getType());
+ if (!indicesType || !indicesType.hasRank() || indicesType.getRank() < 1 ||
+ !indicesType.getElementType().isSignlessInteger(32))
+ return op_.emitError() << "expects indices (result #1) "
@@ -930,19 +929,19 @@ diff --ruN a/stablehlo/stablehlo/experimental/dialect/StablehloOps.cpp b/stableh
+}
+
+TypedValue<ShapedType> DynamicTopKOpAdaptor::getOperand() {
+ return op_.getInputs()[0].cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(op_.getInputs()[0]);
+}
+
+TypedValue<ShapedType> DynamicTopKOpAdaptor::getK() {
+ return op_.getInputs()[1].cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(op_.getInputs()[1]);
+}
+
+TypedValue<ShapedType> DynamicTopKOpAdaptor::getValues() {
+ return op_.getResults()[0].cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(op_.getResults()[0]);
+}
+
+TypedValue<ShapedType> DynamicTopKOpAdaptor::getIndices() {
+ return op_.getResults()[1].cast<TypedValue<ShapedType>>();
+ return cast<TypedValue<ShapedType>>(op_.getResults()[1]);
+}
+
+std::optional<DynamicTopKOpAdaptor> getDynamicTopKOp(CustomCallOp op) {
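The bulk of the temporary.patch update mechanically migrates the deprecated member-style casts (value.getType().dyn_cast<T>(), .cast<T>(), .isa<T>()) to the free-function forms provided by LLVM/MLIR casting support. A minimal, self-contained sketch of the new style is below; the function name and exact include set are illustrative assumptions, not code from this commit:

#include "llvm/Support/Casting.h"
#include "mlir/IR/BuiltinTypes.h"  // mlir::ShapedType
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"     // brings llvm::dyn_cast into namespace mlir

// Old style (deprecated): value.getType().dyn_cast<mlir::ShapedType>()
// New style (as used after this patch): mlir::dyn_cast<mlir::ShapedType>(...)
static int64_t RankOrMinusOne(mlir::Value value) {
  if (auto shaped = mlir::dyn_cast<mlir::ShapedType>(value.getType()))
    if (shaped.hasRank()) return shaped.getRank();
  return -1;  // Not a shaped type, or unranked.
}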
