Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some type annotation & check_untyped_defs fixes #6311

Merged
merged 5 commits into from Jan 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/_pytest/_code/code.py
Expand Up @@ -67,7 +67,7 @@ def __ne__(self, other):
return not self == other

@property
def path(self):
def path(self) -> Union[py.path.local, str]:
""" return a path object pointing to source code (note that it
might not point to an actually existing file). """
try:
Expand Down Expand Up @@ -335,7 +335,7 @@ def cut(
(path is None or codepath == path)
and (
excludepath is None
or not hasattr(codepath, "relto")
or not isinstance(codepath, py.path.local)
or not codepath.relto(excludepath)
)
and (lineno is None or x.lineno == lineno)
Expand Down
82 changes: 70 additions & 12 deletions src/_pytest/_code/source.py
Expand Up @@ -5,8 +5,8 @@
import textwrap
import tokenize
import warnings
from ast import PyCF_ONLY_AST as _AST_FLAG
from bisect import bisect_right
from types import CodeType
from types import FrameType
from typing import Iterator
from typing import List
Expand All @@ -18,6 +18,10 @@
import py

from _pytest.compat import overload
from _pytest.compat import TYPE_CHECKING

if TYPE_CHECKING:
from typing_extensions import Literal
bluetech marked this conversation as resolved.
Show resolved Hide resolved


class Source:
Expand Down Expand Up @@ -121,7 +125,7 @@ def getstatement(self, lineno: int) -> "Source":
start, end = self.getstatementrange(lineno)
return self[start:end]

def getstatementrange(self, lineno: int):
def getstatementrange(self, lineno: int) -> Tuple[int, int]:
""" return (start, end) tuple which spans the minimal
statement region which containing the given lineno.
"""
Expand Down Expand Up @@ -159,14 +163,36 @@ def isparseable(self, deindent: bool = True) -> bool:
def __str__(self) -> str:
return "\n".join(self.lines)

@overload
def compile(
self,
filename=None,
mode="exec",
filename: Optional[str] = ...,
mode: str = ...,
flag: "Literal[0]" = ...,
dont_inherit: int = ...,
_genframe: Optional[FrameType] = ...,
) -> CodeType:
raise NotImplementedError()

@overload # noqa: F811
def compile( # noqa: F811
self,
filename: Optional[str] = ...,
mode: str = ...,
flag: int = ...,
dont_inherit: int = ...,
_genframe: Optional[FrameType] = ...,
) -> Union[CodeType, ast.AST]:
raise NotImplementedError()

def compile( # noqa: F811
self,
filename: Optional[str] = None,
mode: str = "exec",
flag: int = 0,
dont_inherit: int = 0,
_genframe: Optional[FrameType] = None,
):
) -> Union[CodeType, ast.AST]:
""" return compiled code object. if filename is None
invent an artificial filename which displays
the source/line position of the caller frame.
Expand Down Expand Up @@ -196,8 +222,10 @@ def compile(
newex.text = ex.text
raise newex
else:
if flag & _AST_FLAG:
if flag & ast.PyCF_ONLY_AST:
assert isinstance(co, ast.AST)
return co
assert isinstance(co, CodeType)
lines = [(x + "\n") for x in self.lines]
# Type ignored because linecache.cache is private.
linecache.cache[filename] = (1, None, lines, filename) # type: ignore
Expand All @@ -209,22 +237,52 @@ def compile(
#


def compile_(source, filename=None, mode="exec", flags: int = 0, dont_inherit: int = 0):
@overload
def compile_(
source: Union[str, bytes, ast.mod, ast.AST],
filename: Optional[str] = ...,
mode: str = ...,
flags: "Literal[0]" = ...,
dont_inherit: int = ...,
) -> CodeType:
raise NotImplementedError()


@overload # noqa: F811
def compile_( # noqa: F811
source: Union[str, bytes, ast.mod, ast.AST],
filename: Optional[str] = ...,
mode: str = ...,
flags: int = ...,
dont_inherit: int = ...,
) -> Union[CodeType, ast.AST]:
raise NotImplementedError()


def compile_( # noqa: F811
source: Union[str, bytes, ast.mod, ast.AST],
filename: Optional[str] = None,
mode: str = "exec",
flags: int = 0,
dont_inherit: int = 0,
) -> Union[CodeType, ast.AST]:
""" compile the given source to a raw code object,
and maintain an internal cache which allows later
retrieval of the source code for the code object
and any recursively created code objects.
"""
if isinstance(source, ast.AST):
# XXX should Source support having AST?
return compile(source, filename, mode, flags, dont_inherit)
assert filename is not None
co = compile(source, filename, mode, flags, dont_inherit)
assert isinstance(co, (CodeType, ast.AST))
return co
_genframe = sys._getframe(1) # the caller
s = Source(source)
co = s.compile(filename, mode, flags, _genframe=_genframe)
return co
return s.compile(filename, mode, flags, _genframe=_genframe)


def getfslineno(obj):
def getfslineno(obj) -> Tuple[Union[str, py.path.local], int]:
""" Return source location (path, lineno) for the given object.
If the source cannot be determined return ("", -1).

Expand Down Expand Up @@ -321,7 +379,7 @@ def getstatementrange_ast(
# don't produce duplicate warnings when compiling source to find ast
with warnings.catch_warnings():
warnings.simplefilter("ignore")
astnode = compile(content, "source", "exec", _AST_FLAG)
astnode = ast.parse(content, "source", "exec")

start, end = get_statement_startend2(lineno, astnode)
# we need to correct the end:
Expand Down
6 changes: 3 additions & 3 deletions src/_pytest/recwarn.py
Expand Up @@ -57,7 +57,7 @@ def deprecated_call(func=None, *args, **kwargs):

@overload
def warns(
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
expected_warning: Optional[Union["Type[Warning]", Tuple["Type[Warning]", ...]]],
*,
match: "Optional[Union[str, Pattern]]" = ...
) -> "WarningsChecker":
Expand All @@ -66,7 +66,7 @@ def warns(

@overload # noqa: F811
def warns( # noqa: F811
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
expected_warning: Optional[Union["Type[Warning]", Tuple["Type[Warning]", ...]]],
func: Callable,
*args: Any,
match: Optional[Union[str, "Pattern"]] = ...,
Expand All @@ -76,7 +76,7 @@ def warns( # noqa: F811


def warns( # noqa: F811
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
expected_warning: Optional[Union["Type[Warning]", Tuple["Type[Warning]", ...]]],
*args: Any,
match: Optional[Union[str, "Pattern"]] = None,
**kwargs: Any
Expand Down
9 changes: 8 additions & 1 deletion src/_pytest/reports.py
@@ -1,5 +1,6 @@
from io import StringIO
from pprint import pprint
from typing import Any
from typing import List
from typing import Optional
from typing import Tuple
Expand All @@ -17,6 +18,7 @@
from _pytest._code.code import ReprLocals
from _pytest._code.code import ReprTraceback
from _pytest._code.code import TerminalRepr
from _pytest.compat import TYPE_CHECKING
from _pytest.nodes import Node
from _pytest.outcomes import skip
from _pytest.pathlib import Path
Expand All @@ -41,9 +43,14 @@ class BaseReport:
sections = [] # type: List[Tuple[str, str]]
nodeid = None # type: str

def __init__(self, **kw):
def __init__(self, **kw: Any) -> None:
self.__dict__.update(kw)

if TYPE_CHECKING:
# Can have arbitrary fields given to __init__().
def __getattr__(self, key: str) -> Any:
raise NotImplementedError()

def toterminal(self, out) -> None:
if hasattr(self, "node"):
out.line(getslaveinfoline(self.node)) # type: ignore
Expand Down
16 changes: 16 additions & 0 deletions testing/code/test_source.py
Expand Up @@ -4,10 +4,13 @@
import ast
import inspect
import sys
from types import CodeType
from typing import Any
from typing import Dict
from typing import Optional

import py

import _pytest._code
import pytest
from _pytest._code import Source
Expand Down Expand Up @@ -147,6 +150,10 @@ def test_getrange(self) -> None:
assert len(x.lines) == 2
assert str(x) == "def f(x):\n pass"

def test_getrange_step_not_supported(self) -> None:
with pytest.raises(IndexError, match=r"step"):
self.source[::2]

def test_getline(self) -> None:
x = self.source[0]
assert x == "def f(x):"
Expand Down Expand Up @@ -449,6 +456,14 @@ def test_idem_compile_and_getsource() -> None:
assert src == expected


def test_compile_ast() -> None:
# We don't necessarily want to support this.
# This test was added just for coverage.
stmt = ast.parse("def x(): pass")
co = _pytest._code.compile(stmt, filename="foo.py")
assert isinstance(co, CodeType)


def test_findsource_fallback() -> None:
from _pytest._code.source import findsource

Expand Down Expand Up @@ -488,6 +503,7 @@ def f(x) -> None:

fspath, lineno = getfslineno(f)

assert isinstance(fspath, py.path.local)
assert fspath.basename == "test_source.py"
assert lineno == f.__code__.co_firstlineno - 1 # see findsource

Expand Down