Skip to content

Commit

Permalink
Replace addfinalizer uses with yield fixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
bluetech committed Oct 30, 2023
1 parent 967618e commit d071ff7
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 28 deletions.
3 changes: 2 additions & 1 deletion docs/database.rst
Original file line number Diff line number Diff line change
Expand Up @@ -528,4 +528,5 @@ Put this in ``conftest.py``::
@pytest.fixture
def db_access_without_rollback_and_truncate(request, django_db_setup, django_db_blocker):
django_db_blocker.unblock()
request.addfinalizer(django_db_blocker.restore)
yield
django_db_blocker.restore()
38 changes: 23 additions & 15 deletions pytest_django/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def django_db_setup(
django_db_keepdb: bool,
django_db_createdb: bool,
django_db_modify_db_settings: None,
) -> None:
) -> Generator[None, None, None]:
"""Top level fixture to ensure test databases are available"""
from django.test.utils import setup_databases, teardown_databases

Expand All @@ -127,7 +127,9 @@ def django_db_setup(
**setup_databases_args
)

def teardown_database() -> None:
yield

if not django_db_keepdb:
with django_db_blocker.unblock():
try:
teardown_databases(db_cfg, verbosity=request.config.option.verbose)
Expand All @@ -138,19 +140,17 @@ def teardown_database() -> None:
)
)

if not django_db_keepdb:
request.addfinalizer(teardown_database)


@pytest.fixture()
def _django_db_helper(
request: pytest.FixtureRequest,
django_db_setup: None,
django_db_blocker,
) -> None:
) -> Generator[None, None, None]:
from django import VERSION

if is_django_unittest(request):
yield
return

marker = request.node.get_closest_marker("django_db")
Expand Down Expand Up @@ -183,7 +183,6 @@ def _django_db_helper(
)

django_db_blocker.unblock()
request.addfinalizer(django_db_blocker.restore)

import django.db
import django.test
Expand Down Expand Up @@ -233,13 +232,20 @@ def tearDownClass(cls) -> None:
super(django.test.TestCase, cls).tearDownClass()

PytestDjangoTestCase.setUpClass()
if VERSION >= (4, 0):
request.addfinalizer(PytestDjangoTestCase.doClassCleanups)
request.addfinalizer(PytestDjangoTestCase.tearDownClass)

test_case = PytestDjangoTestCase(methodName="__init__")
test_case._pre_setup()
request.addfinalizer(test_case._post_teardown)

yield

test_case._post_teardown()

PytestDjangoTestCase.tearDownClass()

if VERSION >= (4, 0):
PytestDjangoTestCase.doClassCleanups()

django_db_blocker.restore()


def validate_django_db(marker) -> _DjangoDb:
Expand Down Expand Up @@ -547,12 +553,12 @@ def live_server(request: pytest.FixtureRequest):
) or "localhost"

server = live_server_helper.LiveServer(addr)
request.addfinalizer(server.stop)
return server
yield server
server.stop()


@pytest.fixture(autouse=True, scope="function")
def _live_server_helper(request: pytest.FixtureRequest) -> None:
def _live_server_helper(request: pytest.FixtureRequest) -> Generator[None, None, None]:
"""Helper to make live_server work, internal to pytest-django.
This helper will dynamically request the transactional_db fixture
Expand All @@ -568,13 +574,15 @@ def _live_server_helper(request: pytest.FixtureRequest) -> None:
It will also override settings only for the duration of the test.
"""
if "live_server" not in request.fixturenames:
yield
return

request.getfixturevalue("transactional_db")

live_server = request.getfixturevalue("live_server")
live_server._live_server_modified_settings.enable()
request.addfinalizer(live_server._live_server_modified_settings.disable)
yield
live_server._live_server_modified_settings.disable()


@contextmanager
Expand Down
24 changes: 14 additions & 10 deletions pytest_django/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def get_order_number(test: pytest.Item) -> int:


@pytest.fixture(autouse=True, scope="session")
def django_test_environment(request: pytest.FixtureRequest) -> None:
def django_test_environment(request: pytest.FixtureRequest) -> Generator[None, None, None]:
"""
Ensure that Django is loaded and has its testing environment setup.
Expand All @@ -487,7 +487,11 @@ def django_test_environment(request: pytest.FixtureRequest) -> None:
debug = _get_boolean_value(debug_ini, "django_debug_mode", False)

setup_test_environment(debug=debug)
request.addfinalizer(teardown_test_environment)
yield
teardown_test_environment()

else:
yield


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -587,7 +591,7 @@ def django_mail_dnsname() -> str:


@pytest.fixture(autouse=True, scope="function")
def _django_set_urlconf(request: pytest.FixtureRequest) -> None:
def _django_set_urlconf(request: pytest.FixtureRequest) -> Generator[None, None, None]:
"""Apply the @pytest.mark.urls marker, internal to pytest-django."""
marker = request.node.get_closest_marker("urls")
if marker:
Expand All @@ -601,14 +605,14 @@ def _django_set_urlconf(request: pytest.FixtureRequest) -> None:
clear_url_caches()
set_urlconf(None)

def restore() -> None:
django.conf.settings.ROOT_URLCONF = original_urlconf
# Copy the pattern from
# https://github.com/django/django/blob/main/django/test/signals.py#L152
clear_url_caches()
set_urlconf(None)
yield

request.addfinalizer(restore)
if marker:
django.conf.settings.ROOT_URLCONF = original_urlconf
# Copy the pattern from
# https://github.com/django/django/blob/main/django/test/signals.py#L152
clear_url_caches()
set_urlconf(None)


@pytest.fixture(autouse=True, scope="session")
Expand Down
7 changes: 5 additions & 2 deletions tests/test_database.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Generator

import pytest
from django.db import connection, transaction

Expand Down Expand Up @@ -188,9 +190,10 @@ def test_fixture_clean(self, all_dbs: None) -> None:
assert Item.objects.count() == 0

@pytest.fixture
def fin(self, request: pytest.FixtureRequest, all_dbs: None) -> None:
def fin(self, request: pytest.FixtureRequest, all_dbs: None) -> Generator[None, None, None]:
# This finalizer must be able to access the database
request.addfinalizer(lambda: Item.objects.create(name="spam"))
yield
Item.objects.create(name="spam")

def test_fin(self, fin: None) -> None:
# Check finalizer has db access (teardown will fail if not)
Expand Down

0 comments on commit d071ff7

Please sign in to comment.