Skip to content

Commit

Permalink
Apply --strict-equality special-casing for bytes and bytearray on Pyt…
Browse files Browse the repository at this point in the history
…hon 2 (#7493)

Fixes #7465

The previous fix only worked on Python 3.
  • Loading branch information
ilevkivskyi committed Sep 10, 2019
1 parent 41db9a0 commit 226a4f1
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 8 deletions.
14 changes: 9 additions & 5 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2136,8 +2136,9 @@ def dangerous_comparison(self, left: Type, right: Type,
if isinstance(left, UnionType) and isinstance(right, UnionType):
left = remove_optional(left)
right = remove_optional(right)
if (original_container and has_bytes_component(original_container) and
has_bytes_component(left)):
py2 = self.chk.options.python_version < (3, 0)
if (original_container and has_bytes_component(original_container, py2) and
has_bytes_component(left, py2)):
# We need to special case bytes and bytearray, because 97 in b'abc', b'a' in b'abc',
# b'a' in bytearray(b'abc') etc. all return True (and we want to show the error only
# if the check can _never_ be True).
Expand Down Expand Up @@ -4179,13 +4180,16 @@ def custom_equality_method(typ: Type) -> bool:
return False


def has_bytes_component(typ: Type) -> bool:
def has_bytes_component(typ: Type, py2: bool = False) -> bool:
"""Is this one of builtin byte types, or a union that contains it?"""
typ = get_proper_type(typ)
if py2:
byte_types = {'builtins.str', 'builtins.bytearray'}
else:
byte_types = {'builtins.bytes', 'builtins.bytearray'}
if isinstance(typ, UnionType):
return any(has_bytes_component(t) for t in typ.items)
if isinstance(typ, Instance) and typ.type.fullname() in {'builtins.bytes',
'builtins.bytearray'}:
if isinstance(typ, Instance) and typ.type.fullname() in byte_types:
return True
return False

Expand Down
6 changes: 6 additions & 0 deletions test-data/unit/check-expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -2066,6 +2066,12 @@ bytearray(b'abc') in b'abcde' # OK on Python 3
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-full.pyi]

[case testBytesVsByteArray_python2]
# flags: --strict-equality --py2
b'hi' in bytearray(b'hi')
[builtins_py2 fixtures/python2.pyi]
[typing fixtures/typing-full.pyi]

[case testStrictEqualityNoPromotePy3]
# flags: --strict-equality
'a' == b'a' # E: Non-overlapping equality check (left operand type: "Literal['a']", right operand type: "Literal[b'a']")
Expand Down
13 changes: 10 additions & 3 deletions test-data/unit/fixtures/python2.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Generic, Iterable, TypeVar
from typing import Generic, Iterable, TypeVar, Sequence, Iterator

class object:
def __init__(self) -> None: pass
Expand All @@ -13,9 +13,16 @@ class function: pass
class int: pass
class str: pass
class unicode: pass
class bool: pass
class bool(int): pass
class bytearray(Sequence[int]):
def __init__(self, string: str) -> None: pass
def __contains__(self, item: object) -> bool: pass
def __iter__(self) -> Iterator[int]: pass
def __getitem__(self, item: int) -> int: pass

T = TypeVar('T')
class list(Iterable[T], Generic[T]): pass
class list(Iterable[T], Generic[T]):
def __iter__(self) -> Iterator[T]: pass
def __getitem__(self, item: int) -> T: pass

# Definition of None is implicit

0 comments on commit 226a4f1

Please sign in to comment.