Skip to content

Commit

Permalink
[mypyc] Speed up in operations for list/tuple (#9004)
Browse files Browse the repository at this point in the history
When right hand side of a in/not in operation is a literal
list/tuple, simplify it into simpler direct equality comparison
expressions and use binary and/or to join them.

Yields speedup of up to 46% in micro benchmarks.

Co-authored-by: Johan Dahlin <johan.dahlin@textual.se>
Co-authored-by: Tomer Chachamu <tomer.chachamu@gmail.com>
Co-authored-by: Xuanda Yang <th3charlie@gmail.com>
  • Loading branch information
4 people committed Sep 28, 2020
1 parent 4fb5a21 commit 8bf770d
Show file tree
Hide file tree
Showing 3 changed files with 234 additions and 3 deletions.
54 changes: 51 additions & 3 deletions mypyc/irbuild/expression.py
Expand Up @@ -4,7 +4,7 @@
and mypyc.irbuild.builder.
"""

from typing import List, Optional, Union, Callable
from typing import List, Optional, Union, Callable, cast

from mypy.nodes import (
Expression, NameExpr, MemberExpr, SuperExpr, CallExpr, UnaryExpr, OpExpr, IndexExpr,
Expand All @@ -13,7 +13,7 @@
SetComprehension, DictionaryComprehension, SliceExpr, GeneratorExpr, CastExpr, StarExpr,
Var, RefExpr, MypyFile, TypeInfo, TypeApplication, LDEF, ARG_POS
)
from mypy.types import TupleType, get_proper_type
from mypy.types import TupleType, get_proper_type, Instance

from mypyc.common import MAX_LITERAL_SHORT_INT
from mypyc.ir.ops import (
Expand Down Expand Up @@ -406,8 +406,56 @@ def transform_conditional_expr(builder: IRBuilder, expr: ConditionalExpr) -> Val


def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
# TODO: Don't produce an expression when used in conditional context
# x in (...)/[...]
# x not in (...)/[...]
if (e.operators[0] in ['in', 'not in']
and len(e.operators) == 1
and isinstance(e.operands[1], (TupleExpr, ListExpr))):
items = e.operands[1].items
n_items = len(items)
# x in y -> x == y[0] or ... or x == y[n]
# x not in y -> x != y[0] and ... and x != y[n]
# 16 is arbitrarily chosen to limit code size
if 1 < n_items < 16:
if e.operators[0] == 'in':
bin_op = 'or'
cmp_op = '=='
else:
bin_op = 'and'
cmp_op = '!='
lhs = e.operands[0]
mypy_file = builder.graph['builtins'].tree
assert mypy_file is not None
bool_type = Instance(cast(TypeInfo, mypy_file.names['bool'].node), [])
exprs = []
for item in items:
expr = ComparisonExpr([cmp_op], [lhs, item])
builder.types[expr] = bool_type
exprs.append(expr)

or_expr = exprs.pop(0) # type: Expression
for expr in exprs:
or_expr = OpExpr(bin_op, or_expr, expr)
builder.types[or_expr] = bool_type
return builder.accept(or_expr)
# x in [y]/(y) -> x == y
# x not in [y]/(y) -> x != y
elif n_items == 1:
if e.operators[0] == 'in':
cmp_op = '=='
else:
cmp_op = '!='
e.operators = [cmp_op]
e.operands[1] = items[0]
# x in []/() -> False
# x not in []/() -> True
elif n_items == 0:
if e.operators[0] == 'in':
return builder.false()
else:
return builder.true()

# TODO: Don't produce an expression when used in conditional context
# All of the trickiness here is due to support for chained conditionals
# (`e1 < e2 > e3`, etc). `e1 < e2 > e3` is approximately equivalent to
# `e1 < e2 and e2 > e3` except that `e2` is only evaluated once.
Expand Down
64 changes: 64 additions & 0 deletions mypyc/test-data/irbuild-tuple.test
Expand Up @@ -181,3 +181,67 @@ L2:
r2 = CPySequenceTuple_GetItem(nt, 2)
r3 = unbox(int, r2)
return r3


[case testTupleOperatorIn]
def f(i: int) -> bool:
return i in [1, 2, 3]
[out]
def f(i):
i :: int
r0, r1, r2 :: bool
r3 :: native_int
r4, r5, r6, r7 :: bool
r8 :: native_int
r9, r10, r11, r12 :: bool
r13 :: native_int
r14, r15, r16 :: bool
L0:
r3 = i & 1
r4 = r3 == 0
if r4 goto L1 else goto L2 :: bool
L1:
r5 = i == 2
r2 = r5
goto L3
L2:
r6 = CPyTagged_IsEq_(i, 2)
r2 = r6
L3:
if r2 goto L4 else goto L5 :: bool
L4:
r1 = r2
goto L9
L5:
r8 = i & 1
r9 = r8 == 0
if r9 goto L6 else goto L7 :: bool
L6:
r10 = i == 4
r7 = r10
goto L8
L7:
r11 = CPyTagged_IsEq_(i, 4)
r7 = r11
L8:
r1 = r7
L9:
if r1 goto L10 else goto L11 :: bool
L10:
r0 = r1
goto L15
L11:
r13 = i & 1
r14 = r13 == 0
if r14 goto L12 else goto L13 :: bool
L12:
r15 = i == 6
r12 = r15
goto L14
L13:
r16 = CPyTagged_IsEq_(i, 6)
r12 = r16
L14:
r0 = r12
L15:
return r0
119 changes: 119 additions & 0 deletions mypyc/test-data/run-lists.test
Expand Up @@ -149,3 +149,122 @@ def test_slicing() -> None:
assert s[1:long_int] == ["o", "o", "b", "a", "r"]
assert s[long_int:] == []
assert s[-long_int:-1] == ["f", "o", "o", "b", "a"]

[case testOperatorInExpression]

def tuple_in_int0(i: int) -> bool:
return i in []

def tuple_in_int1(i: int) -> bool:
return i in (1,)

def tuple_in_int3(i: int) -> bool:
return i in (1, 2, 3)

def tuple_not_in_int0(i: int) -> bool:
return i not in []

def tuple_not_in_int1(i: int) -> bool:
return i not in (1,)

def tuple_not_in_int3(i: int) -> bool:
return i not in (1, 2, 3)

def tuple_in_str(s: "str") -> bool:
return s in ("foo", "bar", "baz")

def tuple_not_in_str(s: "str") -> bool:
return s not in ("foo", "bar", "baz")

def list_in_int0(i: int) -> bool:
return i in []

def list_in_int1(i: int) -> bool:
return i in (1,)

def list_in_int3(i: int) -> bool:
return i in (1, 2, 3)

def list_not_in_int0(i: int) -> bool:
return i not in []

def list_not_in_int1(i: int) -> bool:
return i not in (1,)

def list_not_in_int3(i: int) -> bool:
return i not in (1, 2, 3)

def list_in_str(s: "str") -> bool:
return s in ("foo", "bar", "baz")

def list_not_in_str(s: "str") -> bool:
return s not in ("foo", "bar", "baz")

def list_in_mixed(i: object):
return i in [[], (), "", 0, 0.0, False, 0j, {}, set(), type]

[file driver.py]

from native import *

assert not tuple_in_int0(0)
assert not tuple_in_int1(0)
assert tuple_in_int1(1)
assert not tuple_in_int3(0)
assert tuple_in_int3(1)
assert tuple_in_int3(2)
assert tuple_in_int3(3)
assert not tuple_in_int3(4)

assert tuple_not_in_int0(0)
assert tuple_not_in_int1(0)
assert not tuple_not_in_int1(1)
assert tuple_not_in_int3(0)
assert not tuple_not_in_int3(1)
assert not tuple_not_in_int3(2)
assert not tuple_not_in_int3(3)
assert tuple_not_in_int3(4)

assert tuple_in_str("foo")
assert tuple_in_str("bar")
assert tuple_in_str("baz")
assert not tuple_in_str("apple")
assert not tuple_in_str("pie")
assert not tuple_in_str("\0")
assert not tuple_in_str("")

assert not list_in_int0(0)
assert not list_in_int1(0)
assert list_in_int1(1)
assert not list_in_int3(0)
assert list_in_int3(1)
assert list_in_int3(2)
assert list_in_int3(3)
assert not list_in_int3(4)

assert list_not_in_int0(0)
assert list_not_in_int1(0)
assert not list_not_in_int1(1)
assert list_not_in_int3(0)
assert not list_not_in_int3(1)
assert not list_not_in_int3(2)
assert not list_not_in_int3(3)
assert list_not_in_int3(4)

assert list_in_str("foo")
assert list_in_str("bar")
assert list_in_str("baz")
assert not list_in_str("apple")
assert not list_in_str("pie")
assert not list_in_str("\0")
assert not list_in_str("")

assert list_in_mixed(0)
assert list_in_mixed([])
assert list_in_mixed({})
assert list_in_mixed(())
assert list_in_mixed(False)
assert list_in_mixed(0.0)
assert not list_in_mixed([1])
assert not list_in_mixed(object)
assert list_in_mixed(type)

0 comments on commit 8bf770d

Please sign in to comment.