Skip to content

Commit

Permalink
[GmlSt] Implement common interface to get loop destination for 3 gml_…
Browse files Browse the repository at this point in the history
…st loop ops

PiperOrigin-RevId: 484774823
  • Loading branch information
tensorflower-gardener committed Oct 29, 2022
1 parent 345477a commit 269e683
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 11 deletions.
Expand Up @@ -300,13 +300,17 @@ def GMLST_LoopOp : GMLST_Op<"loop", [
return "iterator_types";
}


/// Return whether the loop dimension is parallel or not.
bool isParallelDimension(unsigned dim) {
IteratorTypeAttr attr =
this->getIteratorTypes()[dim].cast<IteratorTypeAttr>();
return attr.getValue() == utils::IteratorType::parallel;
}

/// Return the destinations for a gml_st.loop op.
ValueRange getLoopLikeOpInits() {
return getOutputs();
}
}];

let hasCanonicalizer = 1;
Expand Down
Expand Up @@ -204,7 +204,10 @@ def GMLST_ParallelOp : GMLST_LoopLikeOp<"parallel", []> {
"nullptr">:$bodyBuilderFn)>,
];

let extraClassDeclaration = extraBaseClassDeclaration;
let extraClassDeclaration = extraBaseClassDeclaration # [{
/// Return the destinations for a gml_st.parallel op.
ValueRange getLoopLikeOpInits();
}];
}

def GMLST_ForOp : GMLST_LoopLikeOp<"for", []> {
Expand Down Expand Up @@ -304,6 +307,11 @@ def GMLST_ForOp : GMLST_LoopLikeOp<"for", []> {
return getOperation()->getOpOperand(
getNumControlOperands() + opResult.getResultNumber());
}

/// Return the destinations for a gml_st.for op.
ValueRange getLoopLikeOpInits() {
return getOutputs();
}
}];

let hasCanonicalizer = 1;
Expand Down
Expand Up @@ -697,6 +697,10 @@ ParseResult ParallelOp::parse(OpAsmParser &parser, OperationState &result) {
return parseLoopLikeOp<ParallelOp>(parser, result);
}

ValueRange ParallelOp::getLoopLikeOpInits() {
return getTerminator().getDsts();
}

//===----------------------------------------------------------------------===//
// ForOp
//===----------------------------------------------------------------------===//
Expand Down
Expand Up @@ -409,14 +409,6 @@ ForOp vectorizeLoopLikeOp(ForOp op, BlockAndValueMapping &bvm,
});
}

// Returns the destinations for a gml_st.parallel op.
ValueRange getLoopLikeOpInits(ParallelOp op) {
return op.getTerminator().getDsts();
}

// Returns the destinations for a gml_st.for op.
ValueRange getLoopLikeOpInits(ForOp op) { return op.getOutputs(); }

template <typename LoopLikeOp>
struct LoopLikeOpVectorizationPattern : public OpRewritePattern<LoopLikeOp> {
LoopLikeOpVectorizationPattern(MLIRContext *context,
Expand Down Expand Up @@ -465,7 +457,7 @@ struct LoopLikeOpVectorizationPattern : public OpRewritePattern<LoopLikeOp> {
auto vectorLoopLikeOp = vectorizeLoopLikeOp(op, bvm, rewriter);
bvm.map(op.getResults(), vectorLoopLikeOp.getResults());

convertVectorResultsToTensor(op->getResults(), getLoopLikeOpInits(op), bvm,
convertVectorResultsToTensor(op->getResults(), op.getLoopLikeOpInits(), bvm,
rewriter);
SmallVector<Value, 1> mappedResults = llvm::to_vector<1>(llvm::map_range(
op.getResults(), [&](Value v) { return bvm.lookupOrDefault(v); }));
Expand Down

0 comments on commit 269e683

Please sign in to comment.