Skip to content

Commit

Permalink
[TIR] Allow IndexMap applied to arguments with different dtypes (#13085)
Browse files Browse the repository at this point in the history
* [TIR] Allow IndexMap applied to arguments with different dtypes

* address comments

* Add SubstituteWithDataTypeLegalization
  • Loading branch information
vinx13 committed Oct 19, 2022
1 parent 687ef5b commit 458ca81
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 4 deletions.
26 changes: 26 additions & 0 deletions include/tvm/tir/stmt_functor.h
Expand Up @@ -409,6 +409,32 @@ inline T Substitute(T input, const std::unordered_map<const VarNode*, PrimExpr>&
return Substitute(std::move(input), vmap);
}

/*!
* \brief Substitute the vars specified by vmap and legalize data types after substitution.
* \param stmt The source statement to be substituted.
* \param vmap Callback queried with a variable; it returns the replacement expression if
*        re-mapping is needed, otherwise an undefined (empty) Optional.
*
* Unlike `Substitute`, this allows the substitution to change the data type of the expression.
*
* \sa Substitute
* \return The statement after substitution.
*/
TVM_DLL Stmt SubstituteWithDataTypeLegalization(Stmt stmt,
std::function<Optional<PrimExpr>(const Var&)> vmap);

/*!
* \brief Substitute the vars specified by vmap and legalize data types after substitution.
* \param expr The source expression to be substituted.
* \param vmap Callback queried with a variable; it returns the replacement expression if
*        re-mapping is needed, otherwise an undefined (empty) Optional.
*
* Unlike `Substitute`, this allows the substitution to change the data type of the expression.
*
* \sa Substitute
* \return The expression after substitution.
*/
TVM_DLL PrimExpr SubstituteWithDataTypeLegalization(
PrimExpr expr, std::function<Optional<PrimExpr>(const Var&)> vmap);

/*!
* \brief Recursively visit the IR in pre DFS order node, apply fvisit.
* If fvisit returns false, it won't visit the children of the node.
Expand Down
25 changes: 21 additions & 4 deletions src/tir/ir/index_map.cc
Expand Up @@ -162,9 +162,11 @@ Array<PrimExpr> IndexMapNode::MapIndices(const Array<PrimExpr>& indices,
analyzer = &local_analyzer;
}

Array<PrimExpr> output = final_indices.Map(
[&](PrimExpr index) { return analyzer->Simplify(Substitute(std::move(index), vmap)); });

Array<PrimExpr> output = final_indices.Map([&](PrimExpr index) {
PrimExpr result = SubstituteWithDataTypeLegalization(
std::move(index), [&](const Var& var) { return vmap.Get(var); });
return analyzer->Simplify(result);
});
return output;
}

Expand Down Expand Up @@ -218,6 +220,21 @@ Array<Range> IndexMapNode::MapRanges(const Array<Range>& ranges, arith::Analyzer
analyzer->Simplify(int_set.max() - int_set.min() + 1)));
}
}
auto output_dtype = [&]() {
int max_bits = 0;
for (const auto& range : ranges) {
max_bits = std::max(max_bits, range->extent.dtype().bits());
}
return DataType::Int(max_bits);
}();
output.MutateByApply([&](const Range& range) {
if (range->min.dtype() != output_dtype || range->extent.dtype() != output_dtype) {
return Range::FromMinExtent(cast(output_dtype, range->min),
cast(output_dtype, range->extent));
} else {
return range;
}
});
return output;
}

Expand All @@ -227,7 +244,7 @@ Array<PrimExpr> IndexMapNode::MapShape(const Array<PrimExpr>& shape,

Array<Range> ranges;
for (auto& dim : shape) {
ranges.push_back(Range(0, dim));
ranges.push_back(Range(make_zero(dim.dtype()), dim));
}
Array<Range> mapped = MapRanges(std::move(ranges), analyzer);

Expand Down
89 changes: 89 additions & 0 deletions src/tir/ir/stmt_functor.cc
Expand Up @@ -809,6 +809,95 @@ void PreOrderVisit(const ObjectRef& stmt_or_expr,
}
}

/*!
 * \brief Variable substitution driven by a caller-supplied callback, implemented as a
 *        DataTypeLegalizer subclass.
 *
 * The callback is consulted for every visited Var, for the data variable of each buffer
 * referenced by a BufferLoad/BufferStore, and for AttrStmt nodes whose `node` is a Var.
 */
class IRSubstituteWithDataTypeLegalization : public DataTypeLegalizer {
 public:
  /*!
   * \brief Construct the mutator.
   * \param vmap Callback returning the replacement for a variable, or an undefined Optional
   *        when the variable should be left untouched.
   */
  explicit IRSubstituteWithDataTypeLegalization(std::function<Optional<PrimExpr>(const Var&)> vmap)
      : vmap_(std::move(vmap)) {}  // sink parameter: move instead of copying the std::function

  /*! \brief Replace a variable by its mapped expression, if the callback provides one. */
  PrimExpr VisitExpr_(const VarNode* op) final {
    Var var = GetRef<Var>(op);
    auto ret = vmap_(var);
    if (ret.defined()) {
      return ret.value();
    }
    // Explicit move: the returned local is a Var while the return type is PrimExpr.
    return std::move(var);
  }

  /*! \brief Mutate the load's sub-expressions, then swap in the remapped buffer if any. */
  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    return VisitBufferAccess(std::move(node));
  }

  /*! \brief Mutate the store's sub-expressions, then swap in the remapped buffer if any. */
  Stmt VisitStmt_(const BufferStoreNode* op) final {
    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    return VisitBufferAccess(std::move(node));
  }

  /*!
   * \brief Point a BufferLoad/BufferStore at the remapped version of its buffer.
   * \param node The already-mutated access node.
   * \return The node, copied-on-write only when the buffer actually changed.
   */
  template <typename Node>
  Node VisitBufferAccess(Node node) {
    Buffer new_buf = GetRemappedBuffer(node->buffer);

    if (!new_buf.same_as(node->buffer)) {
      auto writer = node.CopyOnWrite();
      writer->buffer = new_buf;
    }

    return node;
  }

  /*!
   * \brief Return the buffer to use in place of `buf`, memoized per BufferNode.
   * \param buf The buffer referenced by the access being visited.
   * \return `buf` itself when its data var is not remapped to a different object, otherwise
   *         a buffer whose `data` is the mapped variable.
   */
  Buffer GetRemappedBuffer(Buffer buf) {
    // Key on the original node: CopyOnWrite below may make `buf` point to a new node.
    auto key = buf.get();
    auto it = buf_remap_.find(key);
    if (it != buf_remap_.end()) {
      return it->second;
    }

    auto new_buffer_var = vmap_(buf->data);
    if (new_buffer_var.defined() && !new_buffer_var.value().same_as(buf->data)) {
      auto writer = buf.CopyOnWrite();
      // The replacement for a buffer data pointer is downcast to Var.
      writer->data = Downcast<Var>(new_buffer_var);
    }

    // Cache the result (changed or not) so every access to this buffer gets the same object.
    buf_remap_[key] = buf;
    return buf;
  }

  /*! \brief Mutate the AttrStmt, then remap its `node` when that node is a mapped Var. */
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    // NOTE(review): dispatches to StmtExprMutator directly rather than to the immediate
    // base class DataTypeLegalizer -- confirm this is intended.
    Stmt ret = StmtExprMutator::VisitStmt_(op);
    op = ret.as<AttrStmtNode>();
    // remap var node in attr
    if (const auto* var_node = op->node.as<VarNode>()) {
      if (auto mapped_var = vmap_(GetRef<Var>(var_node))) {
        return AttrStmt(mapped_var, op->attr_key, op->value, op->body);
      }
    }
    return ret;
  }

 private:
  // Caller provided function that defines the variables to be remapped.
  std::function<Optional<PrimExpr>(const Var&)> vmap_;

  /*! \brief Generated map to track buffers being remapped.
   *
   * If a `Var BufferNode::data` is remapped, then all buffers
   * containing that data pointer should also be remapped.  This map
   * is used to track buffer modifications, and ensure all instances
   * of a buffer are replaced by the same modified buffer object.
   */
  std::unordered_map<const BufferNode*, Buffer> buf_remap_;
};

// Statement overload: runs IRSubstituteWithDataTypeLegalization over `stmt`.
// `vmap` is taken by value, so it is moved into the mutator rather than copied again.
Stmt SubstituteWithDataTypeLegalization(Stmt stmt,
                                        std::function<Optional<PrimExpr>(const Var&)> vmap) {
  return IRSubstituteWithDataTypeLegalization(std::move(vmap))(std::move(stmt));
}

// Expression overload: runs IRSubstituteWithDataTypeLegalization over `expr`.
// `vmap` is taken by value, so it is moved into the mutator rather than copied again.
PrimExpr SubstituteWithDataTypeLegalization(PrimExpr expr,
                                            std::function<Optional<PrimExpr>(const Var&)> vmap) {
  return IRSubstituteWithDataTypeLegalization(std::move(vmap))(std::move(expr));
}

TVM_REGISTER_GLOBAL("tir.IRTransform").set_body_typed(IRTransform);

TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) {
Expand Down
10 changes: 10 additions & 0 deletions tests/python/unittest/test_index_map.py
Expand Up @@ -21,6 +21,7 @@
import tvm.testing
from tvm.ir import assert_structural_equal
from tvm.tir import IndexMap, IntImm, floordiv, floormod
from tvm.runtime import const


def assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None:
Expand All @@ -41,6 +42,9 @@ def test_index_mapping():
assert_structural_equal(index_map.map_indices([3]), [0, 3])
assert_structural_equal(index_map.map_indices([4]), [1, 0])
assert_structural_equal(index_map.map_indices([42]), [10, 2])
assert_structural_equal(
index_map.map_indices([const(42, "int64")]), [const(10, "int64"), const(2, "int64")]
)


def test_shape_mapping():
Expand All @@ -50,6 +54,12 @@ def test_shape_mapping():
assert_structural_equal(index_map.map_shape([16]), [4, 4])

assert_structural_equal(index_map.map_shape([14]), [4, 4])
assert_structural_equal(
index_map.map_shape([const(16, "int64")]), [const(4, "int64"), const(4, "int64")]
)
assert_structural_equal(
index_map.map_shape([const(14, "int64")]), [const(4, "int64"), const(4, "int64")]
)


def test_inverse():
Expand Down
35 changes: 35 additions & 0 deletions tests/python/unittest/test_tir_schedule_transform_layout.py
Expand Up @@ -376,6 +376,41 @@ def test_transform_block_layout_fail_mixed_iter_type(use_block_name):
)


def test_transform_block_layout_int64_extent(use_block_name):
    """transform_block_layout on a block whose loop extents are int64.

    The expected function keeps the fused loop extent and the derived buffer
    indices as int64 values.
    """

    @T.prim_func
    def elementwise_int64_extent(
        A: T.Buffer[(T.int64(128), T.int64(128)), "float32"],
        B: T.Buffer[(T.int64(128), T.int64(128)), "float32"],
    ) -> None:
        for i, j in T.grid(T.int64(128), T.int64(128)):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0

    @T.prim_func
    def elementwise_int64_extent_transformed(
        A: T.Buffer[(T.int64(128), T.int64(128)), "float32"],
        B: T.Buffer[(T.int64(128), T.int64(128)), "float32"],
    ) -> None:
        for i in range(T.int64(16384)):
            with T.block("B"):
                vi = T.axis.remap("S", [i])
                B[vi // T.int64(128), vi % T.int64(128)] = (
                    A[vi // T.int64(128), vi % T.int64(128)] * 2.0
                )

    sch = tir.Schedule(elementwise_int64_extent, debug_mask="all")
    block = "B" if use_block_name else sch.get_block("B")
    # Fuse the two block iterators (i, j) into a single iterator i * 128 + j.
    sch.transform_block_layout(block, lambda i, j: (i * 128 + j,))
    # The leftover debugging print of get_first_structural_mismatch() was dropped;
    # the assertion below compares the same two functions.
    tvm.ir.assert_structural_equal(elementwise_int64_extent_transformed, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=elementwise_int64_extent)


class BasePaddingCompare(tvm.testing.CompareBeforeAfter):
pad_value = tvm.testing.parameter(None)

Expand Down

0 comments on commit 458ca81

Please sign in to comment.