Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Test Case Class specified in Django settings #937

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
104 changes: 60 additions & 44 deletions pytest_django/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import partial

import pytest
from django.utils.module_loading import import_string

from . import live_server_helper
from .django_compat import is_django_unittest
Expand All @@ -19,7 +20,6 @@
_DjangoDbDatabases = Optional[Union["Literal['__all__']", Iterable[str]]]
_DjangoDb = Tuple[bool, bool, _DjangoDbDatabases]


__all__ = [
"django_db_setup",
"db",
Expand All @@ -42,6 +42,18 @@
]


def import_from_string(val, setting_name):
"""
Attempt to import a class from a string representation.
"""
try:
return import_string(val)
except ImportError as e:
msg = "Could not import '%s' for API setting '%s'. %s: %s." \
% (val, setting_name, e.__class__.__name__, e)
raise ImportError(msg)


@pytest.fixture(scope="session")
def django_db_modify_db_settings_tox_suffix() -> None:
skip_if_no_django()
Expand All @@ -64,15 +76,15 @@ def django_db_modify_db_settings_xdist_suffix(request) -> None:

@pytest.fixture(scope="session")
def django_db_modify_db_settings_parallel_suffix(
django_db_modify_db_settings_tox_suffix: None,
django_db_modify_db_settings_xdist_suffix: None,
django_db_modify_db_settings_tox_suffix: None,
django_db_modify_db_settings_xdist_suffix: None,
) -> None:
skip_if_no_django()


@pytest.fixture(scope="session")
def django_db_modify_db_settings(
django_db_modify_db_settings_parallel_suffix: None,
django_db_modify_db_settings_parallel_suffix: None,
) -> None:
skip_if_no_django()

Expand All @@ -94,13 +106,13 @@ def django_db_createdb(request) -> bool:

@pytest.fixture(scope="session")
def django_db_setup(
request,
django_test_environment: None,
django_db_blocker,
django_db_use_migrations: bool,
django_db_keepdb: bool,
django_db_createdb: bool,
django_db_modify_db_settings: None,
request,
django_test_environment: None,
django_db_blocker,
django_db_use_migrations: bool,
django_db_keepdb: bool,
django_db_createdb: bool,
django_db_modify_db_settings: None,
) -> None:
"""Top level fixture to ensure test databases are available"""
from django.test.utils import setup_databases, teardown_databases
Expand Down Expand Up @@ -136,11 +148,12 @@ def teardown_database() -> None:


def _django_db_fixture_helper(
request,
django_db_blocker,
transactional: bool = False,
reset_sequences: bool = False,
request,
django_db_blocker,
transactional: bool = False,
reset_sequences: bool = False,
) -> None:

if is_django_unittest(request):
return

Expand All @@ -155,13 +168,16 @@ def _django_db_fixture_helper(
django_db_blocker.unblock()
request.addfinalizer(django_db_blocker.restore)

import django.test
import django.db

if transactional:
test_case_class = django.test.TransactionTestCase
test_case_classname = request.config.getvalue("transaction_testcase_class") or os.getenv(
"DJANGO_TRANSACTION_TEST_CASE_CLASS"
) or "django.test.TransactionTestCase"
else:
test_case_class = django.test.TestCase
test_case_classname = request.config.getvalue("testcase_class") or os.getenv(
"DJANGO_TEST_CASE_CLASS"
) or "django.test.TestCase"

test_case_class = import_string(test_case_classname)

_reset_sequences = reset_sequences

Expand Down Expand Up @@ -223,9 +239,9 @@ def _set_suffix_to_test_databases(suffix: str) -> None:

@pytest.fixture(scope="function")
def db(
request,
django_db_setup: None,
django_db_blocker,
request,
django_db_setup: None,
django_db_blocker,
) -> None:
"""Require a django test database.

Expand All @@ -243,8 +259,8 @@ def db(
if "django_db_reset_sequences" in request.fixturenames:
request.getfixturevalue("django_db_reset_sequences")
if (
"transactional_db" in request.fixturenames
or "live_server" in request.fixturenames
"transactional_db" in request.fixturenames
or "live_server" in request.fixturenames
):
request.getfixturevalue("transactional_db")
else:
Expand All @@ -253,9 +269,9 @@ def db(

@pytest.fixture(scope="function")
def transactional_db(
request,
django_db_setup: None,
django_db_blocker,
request,
django_db_setup: None,
django_db_blocker,
) -> None:
"""Require a django test database with transaction support.

Expand All @@ -276,9 +292,9 @@ def transactional_db(

@pytest.fixture(scope="function")
def django_db_reset_sequences(
request,
django_db_setup: None,
django_db_blocker,
request,
django_db_setup: None,
django_db_blocker,
) -> None:
"""Require a transactional test database with sequence reset support.

Expand Down Expand Up @@ -332,9 +348,9 @@ def django_username_field(django_user_model) -> str:

@pytest.fixture()
def admin_user(
db: None,
django_user_model,
django_username_field: str,
db: None,
django_user_model,
django_username_field: str,
):
"""A Django admin user.

Expand Down Expand Up @@ -363,8 +379,8 @@ def admin_user(

@pytest.fixture()
def admin_client(
db: None,
admin_user,
db: None,
admin_user,
) -> "django.test.client.Client":
"""A Django test client logged in as an admin user."""
from django.test.client import Client
Expand Down Expand Up @@ -496,11 +512,11 @@ def _live_server_helper(request) -> None:

@contextmanager
def _assert_num_queries(
config,
num: int,
exact: bool = True,
connection=None,
info=None,
config,
num: int,
exact: bool = True,
connection=None,
info=None,
) -> Generator["django.test.utils.CaptureQueriesContext", None, None]:
from django.test.utils import CaptureQueriesContext

Expand Down Expand Up @@ -547,9 +563,9 @@ def django_assert_max_num_queries(pytestconfig):

@contextmanager
def _capture_on_commit_callbacks(
*,
using: Optional[str] = None,
execute: bool = False
*,
using: Optional[str] = None,
execute: bool = False
):
from django.db import DEFAULT_DB_ALIAS, connections
from django.test import TestCase
Expand Down
11 changes: 11 additions & 0 deletions pytest_django/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,17 @@ def pytest_addoption(parser) -> None:
default=None,
help="Address and port for the live_server fixture.",
)
group.addoption(
"--testcase-class",
default=None,
help="The base TestCase class to patch for use with django. Useful for hypothesis users",
)
group.addoption(
"--transaction-testcase-class",
default=None,
help="The base TransactionTestCase class to patch for use with django. "
"Useful for hypothesis users",
)
parser.addini(
SETTINGS_MODULE_ENV, "Django settings module to use by pytest-django."
)
Expand Down