Skip to content

Commit

Permalink
[mypyc] Fix compilation of unreachable comprehensions (#15721)
Browse files Browse the repository at this point in the history
Fixes mypyc/mypyc#816. Admittedly hacky.
  • Loading branch information
ichard26 committed Feb 28, 2024
1 parent 9f1c90a commit f19b5d3
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 2 deletions.
4 changes: 4 additions & 0 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
from mypyc.irbuild.constant_fold import constant_fold_expr
from mypyc.irbuild.for_helpers import (
comprehension_helper,
raise_error_if_contains_unreachable_names,
translate_list_comprehension,
translate_set_comprehension,
)
Expand Down Expand Up @@ -1020,6 +1021,9 @@ def transform_set_comprehension(builder: IRBuilder, o: SetComprehension) -> Valu


def transform_dictionary_comprehension(builder: IRBuilder, o: DictionaryComprehension) -> Value:
if raise_error_if_contains_unreachable_names(builder, o):
return builder.none()

d = builder.maybe_spill(builder.call_c(dict_new_op, [], o.line))
loop_params = list(zip(o.indices, o.sequences, o.condlists, o.is_async))

Expand Down
29 changes: 29 additions & 0 deletions mypyc/irbuild/for_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
from mypy.nodes import (
ARG_POS,
CallExpr,
DictionaryComprehension,
Expression,
GeneratorExpr,
Lvalue,
MemberExpr,
NameExpr,
RefExpr,
SetExpr,
TupleExpr,
Expand All @@ -28,6 +30,7 @@
IntOp,
LoadAddress,
LoadMem,
RaiseStandardError,
Register,
TupleGet,
TupleSet,
Expand Down Expand Up @@ -229,6 +232,9 @@ def set_item(item_index: Value) -> None:


def translate_list_comprehension(builder: IRBuilder, gen: GeneratorExpr) -> Value:
if raise_error_if_contains_unreachable_names(builder, gen):
return builder.none()

# Try simplest list comprehension, otherwise fall back to general one
val = sequence_from_generator_preallocate_helper(
builder,
Expand All @@ -251,7 +257,30 @@ def gen_inner_stmts() -> None:
return builder.read(list_ops)


def raise_error_if_contains_unreachable_names(
builder: IRBuilder, gen: GeneratorExpr | DictionaryComprehension
) -> bool:
"""Raise a runtime error and return True if generator contains unreachable names.
False is returned if the generator can be safely transformed without crashing.
(It may still be unreachable!)
"""
if any(isinstance(s, NameExpr) and s.node is None for s in gen.indices):
error = RaiseStandardError(
RaiseStandardError.RUNTIME_ERROR,
"mypyc internal error: should be unreachable",
gen.line,
)
builder.add(error)
return True

return False


def translate_set_comprehension(builder: IRBuilder, gen: GeneratorExpr) -> Value:
if raise_error_if_contains_unreachable_names(builder, gen):
return builder.none()

set_ops = builder.maybe_spill(builder.new_set_op([], gen.line))
loop_params = list(zip(gen.indices, gen.sequences, gen.condlists, gen.is_async))

Expand Down
6 changes: 4 additions & 2 deletions mypyc/test-data/run-misc.test
Original file line number Diff line number Diff line change
Expand Up @@ -1097,8 +1097,10 @@ B = sys.platform == 'x' and sys.foobar
C = sys.platform == 'x' and f(a, -b, 'y') > [c + e, g(y=2)]
C = sys.platform == 'x' and cast(a, b[c])
C = sys.platform == 'x' and (lambda x: y + x)
# TODO: This still doesn't work
# C = sys.platform == 'x' and (x for y in z)
C = sys.platform == 'x' and (x for y in z)
C = sys.platform == 'x' and [x for y in z]
C = sys.platform == 'x' and {x: x for y in z}
C = sys.platform == 'x' and {x for y in z}

assert not A
assert not B
Expand Down

0 comments on commit f19b5d3

Please sign in to comment.