Skip to content

Commit

Permalink
[Arith] Handle bitwise_and with power of 2 in modular set (#12272)
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Aug 2, 2022
1 parent 7b72c4e commit bc91978
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 12 deletions.
13 changes: 13 additions & 0 deletions src/arith/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(co
// used for index calculation.
if (op->op.same_as(tir::builtin::shift_right())) {
return VisitRightShift(op);
} else if (op->op.same_as(tir::builtin::bitwise_and())) {
return VisitBitwiseAnd(op);
} else {
return Everything();
}
Expand All @@ -274,6 +276,17 @@ class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(co
return Everything();
}

Entry VisitBitwiseAnd(const CallNode* op) {
Entry b = VisitExpr(op->args[1]);
if (b.is_const()) {
int shift;
if (is_const_power_of_two_integer(Integer(b.base + 1), &shift)) {
return ModByConst(op->args[0], static_cast<int64_t>(1) << shift, true);
}
}
return Everything();
}

private:
/*! \brief pointer to parent. */
Analyzer* parent_{nullptr};
Expand Down
32 changes: 20 additions & 12 deletions tests/python/unittest/test_arith_modular_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
from tvm import te


Expand Down Expand Up @@ -200,18 +201,25 @@ def test_let():
x = te.var("x")
y = te.var("y")
m = analyzer.modular_set(tvm.tir.Let(x, y * 10, x + 1))
m.coeff = 10
m.base = 1
assert m.coeff == 10
assert m.base == 1


def test_bitwise_and():
analyzer = tvm.arith.Analyzer()
x = te.var("x")
y = te.var("y")

# RHS of bitwise_and is 2^p - 1
m = analyzer.modular_set((x * 16 + y * 4) & 31)
assert m.coeff == 4
assert m.base == 0

# arbitrary RHS
m = analyzer.modular_set((x * 16 + y * 4) & 17)
assert m.coeff == 1
assert m.base == 0


if __name__ == "__main__":
test_let()
test_cast()
test_add_sub()
test_mul()
test_div_shift()
test_floormod()
test_min_max_select()
test_mix_index()
test_constraint_scope()
test_intersect()
tvm.testing.main()

0 comments on commit bc91978

Please sign in to comment.