Skip to content

Commit

Permalink
Add rudimentary mypy type checking
Browse files Browse the repository at this point in the history
Add a very lax mypy configuration, add it to tox -e linting, and
fix/ignore the few errors that come up. The idea is to get it running
before diving in too much.

This enables:

- Progressively adding type annotations and enabling more strict
  options, which will improve the codebase (IMO).

- Annotating the public API in-line, and eventually exposing it to
  library users who use type checkers (with a py.typed file).

Though, none of this is done yet.

Refs pytest-dev#3342.
  • Loading branch information
bluetech committed Jul 8, 2019
1 parent 60a358f commit ee614dd
Show file tree
Hide file tree
Showing 31 changed files with 102 additions and 45 deletions.
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -35,6 +35,7 @@ env/
.tox
.cache
.pytest_cache
.mypy_cache
.coverage
.coverage.*
coverage.xml
Expand Down
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
Expand Up @@ -28,6 +28,7 @@ repos:
hooks:
- id: flake8
language_version: python3
additional_dependencies: [flake8-typing-imports]
- repo: https://github.com/asottile/reorder_python_imports
rev: v1.4.0
hooks:
Expand All @@ -42,6 +43,15 @@ repos:
rev: v1.4.0
hooks:
- id: rst-backticks
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.711
hooks:
- id: mypy
name: mypy (src)
files: ^src/
- id: mypy
name: mypy (testing)
files: ^testing/
- repo: local
hooks:
- id: rst
Expand Down
2 changes: 1 addition & 1 deletion bench/bench.py
Expand Up @@ -6,7 +6,7 @@
import pstats

script = sys.argv[1:] if len(sys.argv) > 1 else ["empty.py"]
stats = cProfile.run("pytest.cmdline.main(%r)" % script, "prof")
cProfile.run("pytest.cmdline.main(%r)" % script, "prof")
p = pstats.Stats("prof")
p.strip_dirs()
p.sort_stats("cumulative")
Expand Down
8 changes: 8 additions & 0 deletions setup.cfg
Expand Up @@ -61,3 +61,11 @@ ignore =

[devpi:upload]
formats = sdist.tgz,bdist_wheel

[mypy]
ignore_missing_imports = True
no_implicit_optional = True
strict_equality = True
warn_redundant_casts = True
warn_return_any = True
warn_unused_configs = True
3 changes: 2 additions & 1 deletion src/_pytest/_argcomplete.py
Expand Up @@ -56,6 +56,7 @@
import os
import sys
from glob import glob
from typing import Optional


class FastFilesCompleter:
Expand Down Expand Up @@ -91,7 +92,7 @@ def __call__(self, prefix, **kwargs):
import argcomplete.completers
except ImportError:
sys.exit(-1)
filescompleter = FastFilesCompleter()
filescompleter = FastFilesCompleter() # type: Optional[FastFilesCompleter]

def try_argcomplete(parser):
argcomplete.autocomplete(parser, always_complete_options=False)
Expand Down
13 changes: 7 additions & 6 deletions src/_pytest/_code/code.py
Expand Up @@ -33,7 +33,8 @@ def __init__(self, rawcode):
def __eq__(self, other):
return self.raw == other.raw

__hash__ = None
# Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore

def __ne__(self, other):
return not self == other
Expand Down Expand Up @@ -188,11 +189,11 @@ def path(self):
""" path to the source code """
return self.frame.code.path

def getlocals(self):
@property
def locals(self):
""" locals of underlaying frame """
return self.frame.f_locals

locals = property(getlocals, None, None, "locals of underlaying frame")

def getfirstlinesource(self):
return self.frame.code.firstlineno

Expand Down Expand Up @@ -255,11 +256,11 @@ def __str__(self):
line = "???"
return " File %r:%d in %s\n %s\n" % (fn, self.lineno + 1, name, line)

@property
def name(self):
""" co_name of underlaying code """
return self.frame.code.raw.co_name

name = property(name, None, None, "co_name of underlaying code")


class Traceback(list):
""" Traceback objects encapsulate and offer higher level
Expand Down
3 changes: 2 additions & 1 deletion src/_pytest/_code/source.py
Expand Up @@ -44,7 +44,8 @@ def __eq__(self, other):
return str(self) == other
return False

__hash__ = None
# Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore

def __getitem__(self, key):
if isinstance(key, int):
Expand Down
25 changes: 15 additions & 10 deletions src/_pytest/assertion/rewrite.py
Expand Up @@ -12,6 +12,10 @@
import sys
import tokenize
import types
from typing import Dict
from typing import List
from typing import Optional
from typing import Set

import atomicwrites

Expand Down Expand Up @@ -459,39 +463,40 @@ def _fix(node, lineno, col_offset):
return node


def _get_assertion_exprs(src: bytes): # -> Dict[int, str]
def _get_assertion_exprs(src: bytes) -> Dict[int, str]:
"""Returns a mapping from {lineno: "assertion test expression"}"""
ret = {}
ret = {} # type: Dict[int, str]

depth = 0
lines = []
assert_lineno = None
seen_lines = set()
lines = [] # type: List[str]
assert_lineno = None # type: Optional[int]
seen_lines = set() # type: Set[int]

def _write_and_reset() -> None:
nonlocal depth, lines, assert_lineno, seen_lines
assert assert_lineno is not None
ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
depth = 0
lines = []
assert_lineno = None
seen_lines = set()

tokens = tokenize.tokenize(io.BytesIO(src).readline)
for tp, src, (lineno, offset), _, line in tokens:
if tp == tokenize.NAME and src == "assert":
for tp, source, (lineno, offset), _, line in tokens:
if tp == tokenize.NAME and source == "assert":
assert_lineno = lineno
elif assert_lineno is not None:
# keep track of depth for the assert-message `,` lookup
if tp == tokenize.OP and src in "([{":
if tp == tokenize.OP and source in "([{":
depth += 1
elif tp == tokenize.OP and src in ")]}":
elif tp == tokenize.OP and source in ")]}":
depth -= 1

if not lines:
lines.append(line[offset:])
seen_lines.add(lineno)
# a non-nested comma separates the expression from the message
elif depth == 0 and tp == tokenize.OP and src == ",":
elif depth == 0 and tp == tokenize.OP and source == ",":
# one line assert with message
if lineno in seen_lines and len(lines) == 1:
offset_in_trimmed = offset + len(lines[-1]) - len(line)
Expand Down
12 changes: 8 additions & 4 deletions src/_pytest/capture.py
Expand Up @@ -547,6 +547,8 @@ def __init__(self, targetfd, tmpfile=None):
self.start = lambda: None
self.done = lambda: None
else:
self.start = self._start
self.done = self._done
if targetfd == 0:
assert not tmpfile, "cannot set tmpfile with stdin"
tmpfile = open(os.devnull, "r")
Expand All @@ -568,7 +570,7 @@ def __repr__(self):
self.targetfd, getattr(self, "targetfd_save", None), self._state
)

def start(self):
def _start(self):
""" Start capturing on targetfd using memorized tmpfile. """
try:
os.fstat(self.targetfd_save)
Expand All @@ -585,7 +587,7 @@ def snap(self):
self.tmpfile.truncate()
return res

def done(self):
def _done(self):
""" stop capturing, restore streams, return original capture file,
seeked to position zero. """
targetfd_save = self.__dict__.pop("targetfd_save")
Expand Down Expand Up @@ -618,7 +620,8 @@ class FDCapture(FDCaptureBinary):
snap() produces text
"""

EMPTY_BUFFER = str()
# Ignore type because it doesn't match the type in the superclass (bytes).
EMPTY_BUFFER = str() # type: ignore

def snap(self):
res = super().snap()
Expand Down Expand Up @@ -679,7 +682,8 @@ def writeorg(self, data):


class SysCaptureBinary(SysCapture):
EMPTY_BUFFER = b""
# Ignore type because it doesn't match the type in the superclass (str).
EMPTY_BUFFER = b"" # type: ignore

def snap(self):
res = self.tmpfile.buffer.getvalue()
Expand Down
2 changes: 1 addition & 1 deletion src/_pytest/debugging.py
Expand Up @@ -74,7 +74,7 @@ class pytestPDB:

_pluginmanager = None
_config = None
_saved = []
_saved = [] # type: list
_recursive_debug = 0
_wrapped_pdb_cls = None

Expand Down
12 changes: 9 additions & 3 deletions src/_pytest/fixtures.py
Expand Up @@ -6,6 +6,8 @@
from collections import defaultdict
from collections import deque
from collections import OrderedDict
from typing import Dict
from typing import Tuple

import attr
import py
Expand All @@ -31,6 +33,9 @@
from _pytest.outcomes import fail
from _pytest.outcomes import TEST_OUTCOME

if False: # TYPE_CHECKING
from typing import Type


@attr.s(frozen=True)
class PseudoFixtureDef:
Expand All @@ -54,10 +59,10 @@ def pytest_sessionstart(session):
session._fixturemanager = FixtureManager(session)


scopename2class = {}
scopename2class = {} # type: Dict[str, Type[nodes.Node]]


scope2props = dict(session=())
scope2props = dict(session=()) # type: Dict[str, Tuple[str, ...]]
scope2props["package"] = ("fspath",)
scope2props["module"] = ("fspath", "module")
scope2props["class"] = scope2props["module"] + ("cls",)
Expand Down Expand Up @@ -960,7 +965,8 @@ class FixtureFunctionMarker:
scope = attr.ib()
params = attr.ib(converter=attr.converters.optional(tuple))
autouse = attr.ib(default=False)
ids = attr.ib(default=None, converter=_ensure_immutable_ids)
# Ignore type because of https://github.com/python/mypy/issues/6172.
ids = attr.ib(default=None, converter=_ensure_immutable_ids) # type: ignore
name = attr.ib(default=None)

def __call__(self, function):
Expand Down
3 changes: 2 additions & 1 deletion src/_pytest/mark/__init__.py
Expand Up @@ -91,7 +91,8 @@ def pytest_cmdline_main(config):
return 0


pytest_cmdline_main.tryfirst = True
# Ignore type because of https://github.com/python/mypy/issues/2087.
pytest_cmdline_main.tryfirst = True # type: ignore


def deselect_by_keyword(items, config):
Expand Down
3 changes: 2 additions & 1 deletion src/_pytest/mark/structures.py
Expand Up @@ -3,6 +3,7 @@
from collections import namedtuple
from collections.abc import MutableMapping
from operator import attrgetter
from typing import Set

import attr

Expand Down Expand Up @@ -298,7 +299,7 @@ def test_function():
on the ``test_function`` object. """

_config = None
_markers = set()
_markers = set() # type: Set[str]

def __getattr__(self, name):
if name[0] == "_":
Expand Down
3 changes: 2 additions & 1 deletion src/_pytest/nodes.py
Expand Up @@ -280,7 +280,8 @@ def _repr_failure_py(self, excinfo, style=None):
truncate_locals=truncate_locals,
)

repr_failure = _repr_failure_py
def repr_failure(self, excinfo, style=None):
return self._repr_failure_py(excinfo, style)


def get_fslocation_from_item(item):
Expand Down
12 changes: 8 additions & 4 deletions src/_pytest/outcomes.py
Expand Up @@ -70,7 +70,8 @@ def exit(msg, returncode=None):
raise Exit(msg, returncode)


exit.Exception = Exit
# Ignore type because of https://github.com/python/mypy/issues/2087.
exit.Exception = Exit # type: ignore


def skip(msg="", *, allow_module_level=False):
Expand All @@ -96,7 +97,8 @@ def skip(msg="", *, allow_module_level=False):
raise Skipped(msg=msg, allow_module_level=allow_module_level)


skip.Exception = Skipped
# Ignore type because of https://github.com/python/mypy/issues/2087.
skip.Exception = Skipped # type: ignore


def fail(msg="", pytrace=True):
Expand All @@ -111,7 +113,8 @@ def fail(msg="", pytrace=True):
raise Failed(msg=msg, pytrace=pytrace)


fail.Exception = Failed
# Ignore type because of https://github.com/python/mypy/issues/2087.
fail.Exception = Failed # type: ignore


class XFailed(Failed):
Expand All @@ -132,7 +135,8 @@ def xfail(reason=""):
raise XFailed(reason)


xfail.Exception = XFailed
# Ignore type because of https://github.com/python/mypy/issues/2087.
xfail.Exception = XFailed # type: ignore


def importorskip(modname, minversion=None, reason=None):
Expand Down

0 comments on commit ee614dd

Please sign in to comment.