Skip to content

Commit

Permalink
#11692 Cleanup some of trial runner (#11693)
Browse files Browse the repository at this point in the history
  • Loading branch information
exarkun committed Sep 27, 2022
2 parents 47f4763 + e5bb8f9 commit 0929f3f
Show file tree
Hide file tree
Showing 9 changed files with 246 additions and 236 deletions.
Empty file.
7 changes: 3 additions & 4 deletions src/twisted/python/modules.py
Expand Up @@ -64,7 +64,6 @@

# let's try to keep path imports to a minimum...
from os.path import dirname, split as splitpath
from typing import cast

from zope.interface import Interface, implementer

Expand Down Expand Up @@ -262,7 +261,7 @@ def __init__(self, name, onObject, loaded, pythonValue):
@param loaded: always True, for now
@param pythonValue: the value of the attribute we're pointing to.
"""
self.name = name
self.name: str = name
self.onObject = onObject
self._loaded = loaded
self.pythonValue = pythonValue
Expand Down Expand Up @@ -318,7 +317,7 @@ def __init__(self, name, filePath, pathEntry):
"""
_name = nativeString(name)
assert not _name.endswith(".__init__")
self.name = _name
self.name: str = _name
self.filePath = filePath
self.parentPath = filePath.parent()
self.pathEntry = pathEntry
Expand Down Expand Up @@ -397,7 +396,7 @@ def __eq__(self, other: object) -> bool:
PythonModules with the same name are equal.
"""
if isinstance(other, PythonModule):
return cast(bool, other.name == self.name)
return other.name == self.name
return NotImplemented

def walkModules(self, importPackages=False):
Expand Down
2 changes: 1 addition & 1 deletion src/twisted/python/reflect.py
Expand Up @@ -348,7 +348,7 @@ def filenameToModuleName(fn):
return modName


def qual(clazz):
def qual(clazz: Type[object]) -> str:
"""
Return full import path of a class.
"""
Expand Down
39 changes: 16 additions & 23 deletions src/twisted/scripts/trial.py
Expand Up @@ -10,15 +10,19 @@
import random
import sys
import time
import trace
import warnings
from typing import NoReturn, Optional, Type

from twisted import plugin
from twisted.application import app
from twisted.internet import defer
from twisted.python import failure, reflect, usage
from twisted.python.filepath import FilePath
from twisted.python.reflect import namedModule
from twisted.trial import itrial, reporter, runner
from twisted.trial import itrial, runner
from twisted.trial._dist.disttrial import DistTrialRunner
from twisted.trial.unittest import TestSuite

# Yea, this is stupid. Leave it for command-line compatibility for a
# while, though.
Expand Down Expand Up @@ -231,8 +235,7 @@ class _BasicOptions:
],
)

fallbackReporter = reporter.TreeReporter
tracer = None
tracer: Optional[trace.Trace] = None

def __init__(self):
self["tests"] = []
Expand Down Expand Up @@ -275,8 +278,6 @@ def opt_coverage(self):
Generate coverage information in the coverage file in the
directory specified by the temp-directory option.
"""
import trace

self.tracer = trace.Trace(count=1, trace=0)
sys.settrace(self.tracer.globaltrace)
self["coverage"] = True
Expand Down Expand Up @@ -474,10 +475,6 @@ class Options(_BasicOptions, usage.Options, app.ReactorSelectionMixin):
_workerFlags = ["disablegc", "force-gc", "coverage"]
_workerParameters = ["recursionlimit", "reactor", "without-module"]

fallbackReporter = reporter.TreeReporter
extra = None
tracer = None

def opt_jobs(self, number):
"""
Number of local workers to run, a strictly positive integer.
Expand Down Expand Up @@ -523,21 +520,21 @@ def postOptions(self):
failure.DO_POST_MORTEM = False


def _initialDebugSetup(config):
def _initialDebugSetup(config: Options) -> None:
# do this part of debug setup first for easy debugging of import failures
if config["debug"]:
failure.startDebugMode()
if config["debug"] or config["debug-stacktraces"]:
defer.setDebugging(True)


def _getSuite(config):
def _getSuite(config: Options) -> TestSuite:
loader = _getLoader(config)
recurse = not config["no-recurse"]
return loader.loadByNames(config["tests"], recurse=recurse)


def _getLoader(config):
def _getLoader(config: Options) -> runner.TestLoader:
loader = runner.TestLoader()
if config["random"]:
randomer = random.Random()
Expand Down Expand Up @@ -584,16 +581,14 @@ class _DebuggerNotFound(Exception):
"""


def _makeRunner(config):
def _makeRunner(config: Options) -> runner._Runner:
"""
Return a trial runner class set up with the parameters extracted from
C{config}.
@return: A trial runner instance.
@rtype: L{runner.TrialRunner} or C{DistTrialRunner} depending on the
configuration.
"""
cls = runner.TrialRunner
cls: Type[runner._Runner] = runner.TrialRunner
args = {
"reporterFactory": config["reporter"],
"tracebackFormat": config["tbformat"],
Expand All @@ -606,8 +601,6 @@ def _makeRunner(config):
if config["dry-run"]:
args["mode"] = runner.TrialRunner.DRY_RUN
elif config["jobs"]:
from twisted.trial._dist.disttrial import DistTrialRunner

cls = DistTrialRunner
args["maxWorkers"] = config["jobs"]
args["workerArguments"] = config._getWorkerArguments()
Expand All @@ -632,7 +625,7 @@ def _makeRunner(config):
return cls(**args)


def run():
def run() -> NoReturn:
if len(sys.argv) == 1:
sys.argv.append("--help")
config = Options()
Expand All @@ -649,13 +642,13 @@ def run():

suite = _getSuite(config)
if config["until-failure"]:
test_result = trialRunner.runUntilFailure(suite)
testResult = trialRunner.runUntilFailure(suite)
else:
test_result = trialRunner.run(suite)
testResult = trialRunner.run(suite)
if config.tracer:
sys.settrace(None)
results = config.tracer.results()
results.write_results(
show_missing=1, summary=False, coverdir=config.coverdir().path
show_missing=True, summary=False, coverdir=config.coverdir().path
)
sys.exit(not test_result.wasSuccessful())
sys.exit(not testResult.wasSuccessful())
9 changes: 6 additions & 3 deletions src/twisted/trial/_asyncrunner.py
Expand Up @@ -10,6 +10,7 @@
import doctest
import gc
import unittest as pyunit
from typing import Iterator, Union

from zope.interface import implementer

Expand Down Expand Up @@ -160,14 +161,16 @@ def run(self, result):
components.registerAdapter(_BrokenIDTestCaseAdapter, _docTestCase, itrial.ITestCase)


def _iterateTests(testSuiteOrCase):
def _iterateTests(
testSuiteOrCase: Union[pyunit.TestCase, pyunit.TestSuite]
) -> Iterator[itrial.ITestCase]:
"""
Iterate through all of the test cases in C{testSuiteOrCase}.
"""
try:
suite = iter(testSuiteOrCase)
suite = iter(testSuiteOrCase) # type: ignore[arg-type]
except TypeError:
yield testSuiteOrCase
yield testSuiteOrCase # type: ignore[misc]
else:
for test in suite:
yield from _iterateTests(test)
44 changes: 24 additions & 20 deletions src/twisted/trial/_dist/disttrial.py
Expand Up @@ -14,7 +14,7 @@
from functools import partial
from os.path import isabs
from typing import Awaitable, Callable, Iterable, List, Sequence, TextIO, Union, cast
from unittest import TestResult, TestSuite
from unittest import TestCase, TestSuite

from attrs import define, field, frozen
from attrs.converters import default_if_none
Expand Down Expand Up @@ -287,7 +287,7 @@ class DistTrialRunner:
``False`` to run through the whole suite and report all of the results
at the end.
@ivar _stream: stream which the reporter will use.
@ivar stream: stream which the reporter will use.
@ivar _reporterFactory: the reporter class to be used.
"""
Expand All @@ -307,7 +307,8 @@ class DistTrialRunner:
converter=default_if_none(factory=_defaultReactor), # type: ignore [misc]
)
# mypy doesn't understand the converter
_stream: TextIO = field(default=None, converter=default_if_none(sys.stdout)) # type: ignore [misc]
stream: TextIO = field(default=None, converter=default_if_none(sys.stdout)) # type: ignore [misc]

_tracebackFormat: str = "default"
_realTimeErrors: bool = False
_uncleanWarnings: bool = False
Expand All @@ -320,7 +321,7 @@ def _makeResult(self) -> DistReporter:
Make reporter factory, and wrap it with a L{DistReporter}.
"""
reporter = self._reporterFactory(
self._stream, self._tracebackFormat, realtime=self._realTimeErrors
self.stream, self._tracebackFormat, realtime=self._realTimeErrors
)
if self._uncleanWarnings:
reporter = UncleanWarningsReporterWrapper(reporter)
Expand Down Expand Up @@ -365,13 +366,13 @@ async def task(case):

async def runAsync(
self,
suite: TestSuite,
suite: Union[TestCase, TestSuite],
untilFailure: bool = False,
) -> DistReporter:
"""
Spawn local worker processes and load tests. After that, run them.
@param suite: A tests suite to be run.
@param suite: A test or suite to be run.
@param untilFailure: If C{True}, continue to run the tests until they
fail.
Expand All @@ -396,7 +397,7 @@ async def runAsync(
# Announce that we're beginning. countTestCases result is preferred
# (over len(testCases)) because testCases may contain synthetic cases
# for error reporting purposes.
self._stream.write(f"Running {suite.countTestCases()} tests.\n")
self.stream.write(f"Running {suite.countTestCases()} tests.\n")

# Start the worker pool.
startedPool = await poolStarter.start(self._reactor)
Expand All @@ -410,7 +411,7 @@ async def runAndReport(n: int) -> DistReporter:
if untilFailure:
# If and only if we're running the suite more than once,
# provide a report about which run this is.
self._stream.write(f"Test Pass {n + 1}\n")
self.stream.write(f"Test Pass {n + 1}\n")

result = self._makeResult()

Expand Down Expand Up @@ -439,19 +440,14 @@ async def runAndReport(n: int) -> DistReporter:
# Shut down the worker pool.
await startedPool.join()

def run(self, suite: TestSuite, untilFailure: bool = False) -> TestResult:
"""
Run a reactor and a test suite.
@param suite: The test suite to run.
"""
def _run(self, test: Union[TestCase, TestSuite], untilFailure: bool) -> IReporter:
result: Union[Failure, DistReporter]

def capture(r):
nonlocal result
result = r

d = Deferred.fromCoroutine(self.runAsync(suite, untilFailure))
d = Deferred.fromCoroutine(self.runAsync(test, untilFailure))
d.addBoth(capture)
d.addBoth(lambda ignored: self._reactor.stop())
self._reactor.run()
Expand All @@ -464,14 +460,22 @@ def capture(r):
# certainly a DistReporter at this point.
assert isinstance(result, DistReporter)

# Unwrap the DistReporter to give the caller some regular TestResult
# Unwrap the DistReporter to give the caller some regular IReporter
# object. DistReporter isn't type annotated correctly so fix it here.
return cast(TestResult, result.original)
return cast(IReporter, result.original)

def run(self, test: Union[TestCase, TestSuite]) -> IReporter:
"""
Run a reactor and a test suite.
@param test: The test or suite to run.
"""
return self._run(test, untilFailure=False)

def runUntilFailure(self, suite):
def runUntilFailure(self, test: Union[TestCase, TestSuite]) -> IReporter:
"""
Run the tests with local worker processes until they fail.
@param suite: A tests suite to be run.
@param test: The test or suite to run.
"""
return self.run(suite, untilFailure=True)
return self._run(test, untilFailure=True)

0 comments on commit 0929f3f

Please sign in to comment.