Skip to content

Commit

Permalink
Allow arbitrary attributes in BaseReport, and fix check_untyped_defs …
Browse files Browse the repository at this point in the history
…in test_runner
  • Loading branch information
bluetech committed Jan 1, 2020
1 parent a10b03e commit 6175738
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
8 changes: 7 additions & 1 deletion src/_pytest/reports.py
@@ -1,5 +1,6 @@
from io import StringIO
from pprint import pprint
from typing import Any
from typing import List
from typing import Optional
from typing import Tuple
Expand Down Expand Up @@ -41,9 +42,14 @@ class BaseReport:
sections = [] # type: List[Tuple[str, str]]
nodeid = None # type: str

def __init__(self, **kw):
def __init__(self, **kw: Any) -> None:
self.__dict__.update(kw)

if False: # TYPE_CHECKING
# Can have arbitrary fields given to __init__().
def __getattr__(self, key: str) -> Any:
raise NotImplementedError()

def toterminal(self, out) -> None:
if hasattr(self, "node"):
out.line(getslaveinfoline(self.node)) # type: ignore
Expand Down
25 changes: 18 additions & 7 deletions testing/test_runner.py
Expand Up @@ -2,6 +2,9 @@
import os
import sys
import types
from typing import Dict
from typing import List
from typing import Tuple

import py

Expand All @@ -16,6 +19,9 @@
from _pytest.outcomes import OutcomeException
from _pytest.outcomes import Skipped

if False: # TYPE_CHECKING
from typing import Type


class TestSetupState:
def test_setup(self, testdir) -> None:
Expand Down Expand Up @@ -464,18 +470,22 @@ class TestClass(object):
assert res[1].name == "TestClass"


reporttypes = [reports.BaseReport, reports.TestReport, reports.CollectReport]
reporttypes = [
reports.BaseReport,
reports.TestReport,
reports.CollectReport,
] # type: List[Type[reports.BaseReport]]


@pytest.mark.parametrize(
"reporttype", reporttypes, ids=[x.__name__ for x in reporttypes]
)
def test_report_extra_parameters(reporttype) -> None:
def test_report_extra_parameters(reporttype: "Type[reports.BaseReport]") -> None:
if hasattr(inspect, "signature"):
args = list(inspect.signature(reporttype.__init__).parameters.keys())[1:]
else:
args = inspect.getargspec(reporttype.__init__)[0][1:]
basekw = dict.fromkeys(args, [])
basekw = dict.fromkeys(args, []) # type: Dict[str, List[object]]
report = reporttype(newthing=1, **basekw)
assert report.newthing == 1

Expand Down Expand Up @@ -857,11 +867,11 @@ def test_makereport_getsource_dynamic_code(testdir, monkeypatch) -> None:

original_findsource = inspect.findsource

def findsource(obj, *args, **kwargs):
def findsource(obj):
# Can be triggered by dynamically created functions
if obj.__name__ == "foo":
raise IndexError()
return original_findsource(obj, *args, **kwargs)
return original_findsource(obj)

monkeypatch.setattr(inspect, "findsource", findsource)

Expand Down Expand Up @@ -901,6 +911,7 @@ def runtest(self):
pass
# Check that exception info is stored on sys
assert sys.last_type is IndexError
assert isinstance(sys.last_value, IndexError)
assert sys.last_value.args[0] == "TEST"
assert sys.last_traceback

Expand All @@ -913,7 +924,7 @@ def runtest(self):


def test_current_test_env_var(testdir, monkeypatch) -> None:
pytest_current_test_vars = []
pytest_current_test_vars = [] # type: List[Tuple[str, str]]
monkeypatch.setattr(
sys, "pytest_current_test_vars", pytest_current_test_vars, raising=False
)
Expand Down Expand Up @@ -1026,5 +1037,5 @@ def func():
"Perhaps you meant to use a mark?"
)
with pytest.raises(TypeError) as excinfo:
OutcomeException(func)
OutcomeException(func) # type: ignore
assert str(excinfo.value) == expected

0 comments on commit 6175738

Please sign in to comment.