Skip to content

Commit

Permalink
[mypyc] Optimize away some bool/bit registers (#17022)
Browse files Browse the repository at this point in the history
If a register is always used in a branch immediately after assignment,
and it isn't used for anything else, we can replace the assignment with
a branch op. This avoids some assignment ops and gotos.

This is not a very interesting optimization in general, but it will help
a lot with tagged integer operations once I refactor them to be
generated in the lowering pass (in follow-up PRs).
  • Loading branch information
JukkaL committed Mar 14, 2024
1 parent a00fcba commit a18a0db
Show file tree
Hide file tree
Showing 5 changed files with 445 additions and 14 deletions.
4 changes: 3 additions & 1 deletion mypyc/codegen/emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from mypyc.options import CompilerOptions
from mypyc.transform.copy_propagation import do_copy_propagation
from mypyc.transform.exceptions import insert_exception_handling
from mypyc.transform.flag_elimination import do_flag_elimination
from mypyc.transform.refcount import insert_ref_count_opcodes
from mypyc.transform.uninit import insert_uninit_checks

Expand Down Expand Up @@ -234,8 +235,9 @@ def compile_scc_to_ir(
insert_exception_handling(fn)
# Insert refcount handling.
insert_ref_count_opcodes(fn)
# Perform copy propagation optimization.
# Perform optimizations.
do_copy_propagation(fn, compiler_options)
do_flag_elimination(fn, compiler_options)

return modules

Expand Down
300 changes: 300 additions & 0 deletions mypyc/test-data/opt-flag-elimination.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
-- Test cases for "flag elimination" optimization. Used to optimize away
-- registers that are always used immediately after assignment as branch conditions.

[case testFlagEliminationSimple]
def c() -> bool:
return True
def d() -> bool:
return True

def f(x: bool) -> int:
if x:
b = c()
else:
b = d()
if b:
return 1
else:
return 2
[out]
def c():
L0:
return 1
def d():
L0:
return 1
def f(x):
x, r0, r1 :: bool
L0:
if x goto L1 else goto L2 :: bool
L1:
r0 = c()
if r0 goto L4 else goto L5 :: bool
L2:
r1 = d()
if r1 goto L4 else goto L5 :: bool
L3:
unreachable
L4:
return 2
L5:
return 4

[case testFlagEliminationOneAssignment]
def c() -> bool:
return True

def f(x: bool) -> int:
# Not applied here
b = c()
if b:
return 1
else:
return 2
[out]
def c():
L0:
return 1
def f(x):
x, r0, b :: bool
L0:
r0 = c()
b = r0
if b goto L1 else goto L2 :: bool
L1:
return 2
L2:
return 4

[case testFlagEliminationThreeCases]
def c(x: int) -> bool:
return True

def f(x: bool, y: bool) -> int:
if x:
b = c(1)
elif y:
b = c(2)
else:
b = c(3)
if b:
return 1
else:
return 2
[out]
def c(x):
x :: int
L0:
return 1
def f(x, y):
x, y, r0, r1, r2 :: bool
L0:
if x goto L1 else goto L2 :: bool
L1:
r0 = c(2)
if r0 goto L6 else goto L7 :: bool
L2:
if y goto L3 else goto L4 :: bool
L3:
r1 = c(4)
if r1 goto L6 else goto L7 :: bool
L4:
r2 = c(6)
if r2 goto L6 else goto L7 :: bool
L5:
unreachable
L6:
return 2
L7:
return 4

[case testFlagEliminationAssignmentNotLastOp]
def f(x: bool) -> int:
y = 0
if x:
b = True
y = 1
else:
b = False
if b:
return 1
else:
return 2
[out]
def f(x):
x :: bool
y :: int
b :: bool
L0:
y = 0
if x goto L1 else goto L2 :: bool
L1:
b = 1
y = 2
goto L3
L2:
b = 0
L3:
if b goto L4 else goto L5 :: bool
L4:
return 2
L5:
return 4

[case testFlagEliminationAssignmentNoDirectGoto]
def f(x: bool) -> int:
if x:
b = True
else:
b = False
if x:
if b:
return 1
else:
return 2
return 4
[out]
def f(x):
x, b :: bool
L0:
if x goto L1 else goto L2 :: bool
L1:
b = 1
goto L3
L2:
b = 0
L3:
if x goto L4 else goto L7 :: bool
L4:
if b goto L5 else goto L6 :: bool
L5:
return 2
L6:
return 4
L7:
return 8

[case testFlagEliminationBranchNotNextOpAfterGoto]
def f(x: bool) -> int:
if x:
b = True
else:
b = False
y = 1 # Prevents the optimization
if b:
return 1
else:
return 2
[out]
def f(x):
x, b :: bool
y :: int
L0:
if x goto L1 else goto L2 :: bool
L1:
b = 1
goto L3
L2:
b = 0
L3:
y = 2
if b goto L4 else goto L5 :: bool
L4:
return 2
L5:
return 4

[case testFlagEliminationFlagReadTwice]
def f(x: bool) -> bool:
if x:
b = True
else:
b = False
if b:
return b # Prevents the optimization
else:
return False
[out]
def f(x):
x, b :: bool
L0:
if x goto L1 else goto L2 :: bool
L1:
b = 1
goto L3
L2:
b = 0
L3:
if b goto L4 else goto L5 :: bool
L4:
return b
L5:
return 0

[case testFlagEliminationArgumentNotEligible]
def f(x: bool, b: bool) -> bool:
if x:
b = True
else:
b = False
if b:
return True
else:
return False
[out]
def f(x, b):
x, b :: bool
L0:
if x goto L1 else goto L2 :: bool
L1:
b = 1
goto L3
L2:
b = 0
L3:
if b goto L4 else goto L5 :: bool
L4:
return 1
L5:
return 0

[case testFlagEliminationFlagNotAlwaysDefined]
def f(x: bool, y: bool) -> bool:
if x:
b = True
elif y:
b = False
else:
bb = False # b not assigned here -> can't optimize
if b:
return True
else:
return False
[out]
def f(x, y):
x, y, r0, b, bb, r1 :: bool
L0:
r0 = <error> :: bool
b = r0
if x goto L1 else goto L2 :: bool
L1:
b = 1
goto L5
L2:
if y goto L3 else goto L4 :: bool
L3:
b = 0
goto L5
L4:
bb = 0
L5:
if is_error(b) goto L6 else goto L7
L6:
r1 = raise UnboundLocalError('local variable "b" referenced before assignment')
unreachable
L7:
if b goto L8 else goto L9 :: bool
L8:
return 1
L9:
return 0
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Runner for copy propagation optimization tests."""
"""Runner for IR optimization tests."""

from __future__ import annotations

Expand All @@ -8,6 +8,7 @@
from mypy.test.config import test_temp_dir
from mypy.test.data import DataDrivenTestCase
from mypyc.common import TOP_LEVEL_NAME
from mypyc.ir.func_ir import FuncIR
from mypyc.ir.pprint import format_func
from mypyc.options import CompilerOptions
from mypyc.test.testutil import (
Expand All @@ -19,13 +20,16 @@
use_custom_builtins,
)
from mypyc.transform.copy_propagation import do_copy_propagation
from mypyc.transform.flag_elimination import do_flag_elimination
from mypyc.transform.uninit import insert_uninit_checks

files = ["opt-copy-propagation.test"]

class OptimizationSuite(MypycDataSuite):
"""Base class for IR optimization test suites.
To use this, add a base class and define "files" and "do_optimizations".
"""

class TestCopyPropagation(MypycDataSuite):
files = files
base_path = test_temp_dir

def run_case(self, testcase: DataDrivenTestCase) -> None:
Expand All @@ -41,7 +45,24 @@ def run_case(self, testcase: DataDrivenTestCase) -> None:
if fn.name == TOP_LEVEL_NAME and not testcase.name.endswith("_toplevel"):
continue
insert_uninit_checks(fn)
do_copy_propagation(fn, CompilerOptions())
self.do_optimizations(fn)
actual.extend(format_func(fn))

assert_test_output(testcase, actual, "Invalid source code output", expected_output)

def do_optimizations(self, fn: FuncIR) -> None:
raise NotImplementedError


class TestCopyPropagation(OptimizationSuite):
files = ["opt-copy-propagation.test"]

def do_optimizations(self, fn: FuncIR) -> None:
do_copy_propagation(fn, CompilerOptions())


class TestFlagElimination(OptimizationSuite):
files = ["opt-flag-elimination.test"]

def do_optimizations(self, fn: FuncIR) -> None:
do_flag_elimination(fn, CompilerOptions())

0 comments on commit a18a0db

Please sign in to comment.