diff --git a/src/xdist/dsession.py b/src/xdist/dsession.py index 65907074..9671b2fd 100644 --- a/src/xdist/dsession.py +++ b/src/xdist/dsession.py @@ -15,6 +15,7 @@ from xdist.scheduler import LoadGroupScheduling from xdist.scheduler import LoadScheduling from xdist.scheduler import LoadScopeScheduling +from xdist.scheduler import Scheduling from xdist.scheduler import WorkStealingScheduling from xdist.workermanage import NodeManager @@ -97,17 +98,21 @@ def pytest_collection(self): return True @pytest.hookimpl(trylast=True) - def pytest_xdist_make_scheduler(self, config, log): + def pytest_xdist_make_scheduler(self, config, log) -> Scheduling | None: dist = config.getvalue("dist") - schedulers = { - "each": EachScheduling, - "load": LoadScheduling, - "loadscope": LoadScopeScheduling, - "loadfile": LoadFileScheduling, - "loadgroup": LoadGroupScheduling, - "worksteal": WorkStealingScheduling, - } - return schedulers[dist](config, log) + if dist == "each": + return EachScheduling(config, log) + if dist == "load": + return LoadScheduling(config, log) + if dist == "loadscope": + return LoadScopeScheduling(config, log) + if dist == "loadfile": + return LoadFileScheduling(config, log) + if dist == "loadgroup": + return LoadGroupScheduling(config, log) + if dist == "worksteal": + return WorkStealingScheduling(config, log) + return None @pytest.hookimpl def pytest_runtestloop(self): diff --git a/src/xdist/scheduler/__init__.py b/src/xdist/scheduler/__init__.py index 54be9ad6..b4894732 100644 --- a/src/xdist/scheduler/__init__.py +++ b/src/xdist/scheduler/__init__.py @@ -3,4 +3,5 @@ from xdist.scheduler.loadfile import LoadFileScheduling as LoadFileScheduling from xdist.scheduler.loadgroup import LoadGroupScheduling as LoadGroupScheduling from xdist.scheduler.loadscope import LoadScopeScheduling as LoadScopeScheduling +from xdist.scheduler.protocol import Scheduling as Scheduling from xdist.scheduler.worksteal import WorkStealingScheduling as WorkStealingScheduling diff --git a/src/xdist/scheduler/each.py b/src/xdist/scheduler/each.py index dab0ff8a..47f7add3 100644 --- a/src/xdist/scheduler/each.py +++ b/src/xdist/scheduler/each.py @@ -103,6 +103,9 @@ def mark_test_complete(self, node, item_index, duration=0): def mark_test_pending(self, item): raise NotImplementedError() + def remove_pending_tests_from_node(self, node, indices): + raise NotImplementedError() + def remove_node(self, node): # KeyError if we didn't get an add_node() yet pending = self.node2pending.pop(node) diff --git a/src/xdist/scheduler/load.py b/src/xdist/scheduler/load.py index bf1316b0..fb9bdfdc 100644 --- a/src/xdist/scheduler/load.py +++ b/src/xdist/scheduler/load.py @@ -160,6 +160,9 @@ def mark_test_pending(self, item): for node in self.node2pending: self.check_schedule(node) + def remove_pending_tests_from_node(self, node, indices): + raise NotImplementedError() + def check_schedule(self, node, duration=0): """Maybe schedule new items on the node. diff --git a/src/xdist/scheduler/loadscope.py b/src/xdist/scheduler/loadscope.py index 076840cc..7c66ed51 100644 --- a/src/xdist/scheduler/loadscope.py +++ b/src/xdist/scheduler/loadscope.py @@ -244,6 +244,9 @@ def mark_test_complete(self, node, item_index, duration=0): def mark_test_pending(self, item): raise NotImplementedError() + def remove_pending_tests_from_node(self, node, indices): + raise NotImplementedError() + def _assign_work_unit(self, node): """Assign a work unit to a node.""" assert self.workqueue diff --git a/src/xdist/scheduler/protocol.py b/src/xdist/scheduler/protocol.py new file mode 100644 index 00000000..0435d15b --- /dev/null +++ b/src/xdist/scheduler/protocol.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from typing import Protocol +from typing import Sequence + +from xdist.workermanage import WorkerController + + +class Scheduling(Protocol): + @property + def nodes(self) -> list[WorkerController]: ... + + @property + def collection_is_completed(self) -> bool: ... + + @property + def tests_finished(self) -> bool: ... + + @property + def has_pending(self) -> bool: ... + + def add_node(self, node: WorkerController) -> None: ... + + def add_node_collection( + self, + node: WorkerController, + collection: Sequence[str], + ) -> None: ... + + def mark_test_complete( + self, + node: WorkerController, + item_index: int, + duration: float = 0, + ) -> None: ... + + def mark_test_pending(self, item: str) -> None: ... + + def remove_pending_tests_from_node( + self, + node: WorkerController, + indices: Sequence[int], + ) -> None: ... + + def remove_node(self, node: WorkerController) -> str | None: ... + + def schedule(self) -> None: ...