Skip to content

Commit

Permalink
Merge pull request #6311 from bluetech/type-annotations-10
Browse files Browse the repository at this point in the history
Some type annotation & check_untyped_defs fixes
  • Loading branch information
bluetech committed Jan 19, 2020
2 parents 4fb9cc3 + 3392be3 commit 44eb1f5
Show file tree
Hide file tree
Showing 8 changed files with 324 additions and 214 deletions.
4 changes: 2 additions & 2 deletions src/_pytest/_code/code.py
Expand Up @@ -67,7 +67,7 @@ def __ne__(self, other):
return not self == other

@property
def path(self):
def path(self) -> Union[py.path.local, str]:
""" return a path object pointing to source code (note that it
might not point to an actually existing file). """
try:
Expand Down Expand Up @@ -335,7 +335,7 @@ def cut(
(path is None or codepath == path)
and (
excludepath is None
or not hasattr(codepath, "relto")
or not isinstance(codepath, py.path.local)
or not codepath.relto(excludepath)
)
and (lineno is None or x.lineno == lineno)
Expand Down
82 changes: 70 additions & 12 deletions src/_pytest/_code/source.py
Expand Up @@ -5,8 +5,8 @@
import textwrap
import tokenize
import warnings
from ast import PyCF_ONLY_AST as _AST_FLAG
from bisect import bisect_right
from types import CodeType
from types import FrameType
from typing import Iterator
from typing import List
Expand All @@ -18,6 +18,10 @@
import py

from _pytest.compat import overload
from _pytest.compat import TYPE_CHECKING

if TYPE_CHECKING:
from typing_extensions import Literal


class Source:
Expand Down Expand Up @@ -121,7 +125,7 @@ def getstatement(self, lineno: int) -> "Source":
start, end = self.getstatementrange(lineno)
return self[start:end]

def getstatementrange(self, lineno: int):
def getstatementrange(self, lineno: int) -> Tuple[int, int]:
""" return (start, end) tuple which spans the minimal
statement region which containing the given lineno.
"""
Expand Down Expand Up @@ -159,14 +163,36 @@ def isparseable(self, deindent: bool = True) -> bool:
def __str__(self) -> str:
return "\n".join(self.lines)

@overload
def compile(
self,
filename=None,
mode="exec",
filename: Optional[str] = ...,
mode: str = ...,
flag: "Literal[0]" = ...,
dont_inherit: int = ...,
_genframe: Optional[FrameType] = ...,
) -> CodeType:
raise NotImplementedError()

@overload # noqa: F811
def compile( # noqa: F811
self,
filename: Optional[str] = ...,
mode: str = ...,
flag: int = ...,
dont_inherit: int = ...,
_genframe: Optional[FrameType] = ...,
) -> Union[CodeType, ast.AST]:
raise NotImplementedError()

def compile( # noqa: F811
self,
filename: Optional[str] = None,
mode: str = "exec",
flag: int = 0,
dont_inherit: int = 0,
_genframe: Optional[FrameType] = None,
):
) -> Union[CodeType, ast.AST]:
""" return compiled code object. if filename is None
invent an artificial filename which displays
the source/line position of the caller frame.
Expand Down Expand Up @@ -196,8 +222,10 @@ def compile(
newex.text = ex.text
raise newex
else:
if flag & _AST_FLAG:
if flag & ast.PyCF_ONLY_AST:
assert isinstance(co, ast.AST)
return co
assert isinstance(co, CodeType)
lines = [(x + "\n") for x in self.lines]
# Type ignored because linecache.cache is private.
linecache.cache[filename] = (1, None, lines, filename) # type: ignore
Expand All @@ -209,22 +237,52 @@ def compile(
#


def compile_(source, filename=None, mode="exec", flags: int = 0, dont_inherit: int = 0):
@overload
def compile_(
source: Union[str, bytes, ast.mod, ast.AST],
filename: Optional[str] = ...,
mode: str = ...,
flags: "Literal[0]" = ...,
dont_inherit: int = ...,
) -> CodeType:
raise NotImplementedError()


@overload # noqa: F811
def compile_( # noqa: F811
source: Union[str, bytes, ast.mod, ast.AST],
filename: Optional[str] = ...,
mode: str = ...,
flags: int = ...,
dont_inherit: int = ...,
) -> Union[CodeType, ast.AST]:
raise NotImplementedError()


def compile_( # noqa: F811
source: Union[str, bytes, ast.mod, ast.AST],
filename: Optional[str] = None,
mode: str = "exec",
flags: int = 0,
dont_inherit: int = 0,
) -> Union[CodeType, ast.AST]:
""" compile the given source to a raw code object,
and maintain an internal cache which allows later
retrieval of the source code for the code object
and any recursively created code objects.
"""
if isinstance(source, ast.AST):
# XXX should Source support having AST?
return compile(source, filename, mode, flags, dont_inherit)
assert filename is not None
co = compile(source, filename, mode, flags, dont_inherit)
assert isinstance(co, (CodeType, ast.AST))
return co
_genframe = sys._getframe(1) # the caller
s = Source(source)
co = s.compile(filename, mode, flags, _genframe=_genframe)
return co
return s.compile(filename, mode, flags, _genframe=_genframe)


def getfslineno(obj):
def getfslineno(obj) -> Tuple[Union[str, py.path.local], int]:
""" Return source location (path, lineno) for the given object.
If the source cannot be determined return ("", -1).
Expand Down Expand Up @@ -321,7 +379,7 @@ def getstatementrange_ast(
# don't produce duplicate warnings when compiling source to find ast
with warnings.catch_warnings():
warnings.simplefilter("ignore")
astnode = compile(content, "source", "exec", _AST_FLAG)
astnode = ast.parse(content, "source", "exec")

start, end = get_statement_startend2(lineno, astnode)
# we need to correct the end:
Expand Down
6 changes: 3 additions & 3 deletions src/_pytest/recwarn.py
Expand Up @@ -57,7 +57,7 @@ def deprecated_call(func=None, *args, **kwargs):

@overload
def warns(
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
expected_warning: Optional[Union["Type[Warning]", Tuple["Type[Warning]", ...]]],
*,
match: "Optional[Union[str, Pattern]]" = ...
) -> "WarningsChecker":
Expand All @@ -66,7 +66,7 @@ def warns(

@overload # noqa: F811
def warns( # noqa: F811
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
expected_warning: Optional[Union["Type[Warning]", Tuple["Type[Warning]", ...]]],
func: Callable,
*args: Any,
match: Optional[Union[str, "Pattern"]] = ...,
Expand All @@ -76,7 +76,7 @@ def warns( # noqa: F811


def warns( # noqa: F811
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
expected_warning: Optional[Union["Type[Warning]", Tuple["Type[Warning]", ...]]],
*args: Any,
match: Optional[Union[str, "Pattern"]] = None,
**kwargs: Any
Expand Down
9 changes: 8 additions & 1 deletion src/_pytest/reports.py
@@ -1,5 +1,6 @@
from io import StringIO
from pprint import pprint
from typing import Any
from typing import List
from typing import Optional
from typing import Tuple
Expand All @@ -17,6 +18,7 @@
from _pytest._code.code import ReprLocals
from _pytest._code.code import ReprTraceback
from _pytest._code.code import TerminalRepr
from _pytest.compat import TYPE_CHECKING
from _pytest.nodes import Node
from _pytest.outcomes import skip
from _pytest.pathlib import Path
Expand All @@ -41,9 +43,14 @@ class BaseReport:
sections = [] # type: List[Tuple[str, str]]
nodeid = None # type: str

def __init__(self, **kw):
def __init__(self, **kw: Any) -> None:
self.__dict__.update(kw)

if TYPE_CHECKING:
# Can have arbitrary fields given to __init__().
def __getattr__(self, key: str) -> Any:
raise NotImplementedError()

def toterminal(self, out) -> None:
if hasattr(self, "node"):
out.line(getslaveinfoline(self.node)) # type: ignore
Expand Down
16 changes: 16 additions & 0 deletions testing/code/test_source.py
Expand Up @@ -4,10 +4,13 @@
import ast
import inspect
import sys
from types import CodeType
from typing import Any
from typing import Dict
from typing import Optional

import py

import _pytest._code
import pytest
from _pytest._code import Source
Expand Down Expand Up @@ -147,6 +150,10 @@ def test_getrange(self) -> None:
assert len(x.lines) == 2
assert str(x) == "def f(x):\n pass"

def test_getrange_step_not_supported(self) -> None:
with pytest.raises(IndexError, match=r"step"):
self.source[::2]

def test_getline(self) -> None:
x = self.source[0]
assert x == "def f(x):"
Expand Down Expand Up @@ -449,6 +456,14 @@ def test_idem_compile_and_getsource() -> None:
assert src == expected


def test_compile_ast() -> None:
# We don't necessarily want to support this.
# This test was added just for coverage.
stmt = ast.parse("def x(): pass")
co = _pytest._code.compile(stmt, filename="foo.py")
assert isinstance(co, CodeType)


def test_findsource_fallback() -> None:
from _pytest._code.source import findsource

Expand Down Expand Up @@ -488,6 +503,7 @@ def f(x) -> None:

fspath, lineno = getfslineno(f)

assert isinstance(fspath, py.path.local)
assert fspath.basename == "test_source.py"
assert lineno == f.__code__.co_firstlineno - 1 # see findsource

Expand Down

0 comments on commit 44eb1f5

Please sign in to comment.