diff --git a/docs/api_reference.md b/docs/api_reference.md index 5d77d3dc..5579eb6d 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -47,7 +47,7 @@ def case_hi(): - `id`: the custom pytest id that should be used when this case is active. Replaces the deprecated `@case_name` decorator from v1. If no id is provided, the id is generated from case functions by removing their prefix, see [`@parametrize_with_cases(prefix='case_')`](#parametrize_with_cases). - `tags`: custom tags to be used for filtering in [`@parametrize_with_cases(has_tags)`](#parametrize_with_cases). Replaces the deprecated `@case_tags` and `@target` decorators. - + - `marks`: optional pytest marks to add on the case. Note that decorating the function directly with the mark also works, and if marks are provided in both places they are merged. @@ -280,7 +280,7 @@ Note that `@parametrize_with_cases` collection and parameter creation steps are ```python # Collect all cases -cases_funs = get_all_cases(f, cases=cases, prefix=prefix, +cases_funs = get_all_cases(f, cases=cases, prefix=prefix, glob=glob, has_tag=has_tag, filter=filter) # Transform the various functions found @@ -335,16 +335,32 @@ Note that you can get the same contents directly by using the [`current_cases`]( ### `get_all_cases` ```python -def get_all_cases(parametrization_target: Callable, +def get_all_cases(parametrization_target: Callable = None, cases: Union[Callable, Type, ModuleRef] = None, prefix: str = 'case_', glob: str = None, has_tag: Union[str, Iterable[str]] = None, filter: Callable[[Callable], bool] = None ) -> List[Callable]: +``` +Collect all cases as used with [`@parametrize_with_cases`](#parametrize_with_cases). See [`@parametrize_with_cases`](#parametrize_with_cases) for more details on the parameters. +This can be used to lists all desired cases for a given `parametrization_target` (a test function or a fixture) which may be convenient for debugging purposes. + +```python +# Get the cases for f that are defined in the current file +cases = get_all_cases(f, cases=".") + +# Get the cases from cases_xyz.py or test_xyz_cases.py +import test.test_xyz +xyz_cases = get_all_cases(test.test_xyz) + +# Can be used to filter explict cases, in which case no parametrization_target is needed +filtered_cases = get_all_cases(cases=[case_1, case_2, case_3], has_tag=["banana"]) ``` -Lists all desired cases for a given `parametrization_target` (a test function or a fixture). This function may be convenient for debugging purposes. See [`@parametrize_with_cases`](#parametrize_with_cases) for details on the parameters. + - If using a `cases` argument that requires module information, such as `"."` `AUTO` or a relative module like `".xyz"`, the value of `parametrization_target` will be used to to determine the context. + If `None` or simply left empty, it will use the module from which `get_all_cases` was called. + You can pass an explicit module object or a function, in which case the module in which it's defined will be used. ### `get_parametrize_args` diff --git a/src/pytest_cases/case_parametrizer_new.py b/src/pytest_cases/case_parametrizer_new.py index e0908205..1ac9ec43 100644 --- a/src/pytest_cases/case_parametrizer_new.py +++ b/src/pytest_cases/case_parametrizer_new.py @@ -37,7 +37,7 @@ from .fixture_core1_unions import USED, NOT_USED from .fixture_core2 import CombinedFixtureParamValue, fixture -from .fixture__creation import check_name_available, CHANGE +from .fixture__creation import check_name_available, get_caller_module, CHANGE from .fixture_parametrize_plus import fixture_ref, _parametrize_plus, FixtureParamAlternative, ParamAlternative, \ SingleParamAlternative, MultiParamAlternative, FixtureRefItem @@ -205,19 +205,20 @@ def _glob_name_filter(case_fun): return _glob_name_filter -def get_all_cases(parametrization_target, # type: Callable - cases=None, # type: Union[Callable, Type, ModuleRef] - prefix=CASE_PREFIX_FUN, # type: str - glob=None, # type: str - has_tag=None, # type: Union[str, Iterable[str]] - filter=None # type: Callable[[Callable], bool] # noqa +def get_all_cases(parametrization_target=None, # type: Callable + cases=None, # type: Union[Callable, Type, ModuleRef] + prefix=CASE_PREFIX_FUN, # type: str + glob=None, # type: str + has_tag=None, # type: Union[str, Iterable[str]] + filter=None # type: Callable[[Callable], bool] # noqa ): # type: (...) -> List[Callable] """ Lists all desired cases for a given `parametrization_target` (a test function or a fixture). This function may be convenient for debugging purposes. See `@parametrize_with_cases` for details on the parameters. - :param parametrization_target: a test function + :param parametrization_target: either an explicit module object or a function or None. If it's a function, it will + use the module it is defined in. If None is given, it will just get the module it was called from. :param cases: a case function, a class containing cases, a module or a module name string (relative module names accepted). Or a list of such items. You may use `THIS_MODULE` or `'.'` to include current module. `AUTO` (default) means that the module named `test__cases.py` will be loaded, where `test_.py` is @@ -266,7 +267,16 @@ def get_all_cases(parametrization_target, # type: Callable filters += (filter,) # parent package - caller_module_name = getattr(parametrization_target, '__module__', None) + if parametrization_target is None: + parametrization_target = get_caller_module() + + if ismodule(parametrization_target): + caller_module_name = parametrization_target.__name__ + elif callable(parametrization_target): + caller_module_name = getattr(parametrization_target, '__module__', None) + else: + raise ValueError("Can't handle parametrization_target=%s" % parametrization_target) + parent_pkg_name = '.'.join(caller_module_name.split('.')[:-1]) if caller_module_name is not None else None # start collecting all cases @@ -293,6 +303,7 @@ def get_all_cases(parametrization_target, # type: Callable elif c is THIS_MODULE or c == '.': c = caller_module_name + new_cases = extract_cases_from_module(c, package_name=parent_pkg_name, case_fun_prefix=prefix) cases_funs += new_cases @@ -635,31 +646,41 @@ def _get_fixture_cases(module_or_class # type: Union[ModuleType, Type] return cache, imported_fixtures_list -def import_default_cases_module(f): +def import_default_cases_module(context): """ - Implements the `module=AUTO` behaviour of `@parameterize_cases`: based on the decorated test function `f`, - it finds its containing module name "test_.py" and then tries to import the python module - "test__cases.py". + Implements the `module=AUTO` behaviour of `@parameterize_cases`: based on the context + passed in. This can either a object or a decorated test function in which + case it finds its containing module name "test_.py" and then tries to import + the python module "test__cases.py". - If the module is not found it looks for the alternate file `cases_.py`. + If "test__cases.py" module is not found it looks for the alternate + file `cases_.py`. - :param f: the decorated test function + :param f: the decorated test function or a module :return: """ + if ismodule(context): + module_name = context.__name__ + elif hasattr(context, "__module__"): + module_name = context.__module__ + else: + raise ValueError("Can't get module from context %s" % context) + # First try `test__cases.py` - cases_module_name1 = "%s_cases" % f.__module__ + cases_module_name1 = "%s_cases" % module_name + try: cases_module = import_module(cases_module_name1) except ModuleNotFoundError: # Then try `case_.py` - parts = f.__module__.split('.') + parts = module_name.split('.') assert parts[-1][0:5] == 'test_' cases_module_name2 = "%s.cases_%s" % ('.'.join(parts[:-1]), parts[-1][5:]) try: cases_module = import_module(cases_module_name2) except ModuleNotFoundError: # Nothing worked - raise ValueError("Error importing test cases module to parametrize function %r: unable to import AUTO " + raise ValueError("Error importing test cases module to parametrize %r: unable to import AUTO " "cases module %r nor %r. Maybe you wish to import cases from somewhere else ? In that case" "please specify `cases=...`." % (f, cases_module_name1, cases_module_name2)) diff --git a/tests/cases/issues/issue_258/__init__.py b/tests/cases/issues/issue_258/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/cases/issues/issue_258/cases.py b/tests/cases/issues/issue_258/cases.py new file mode 100644 index 00000000..932d04a8 --- /dev/null +++ b/tests/cases/issues/issue_258/cases.py @@ -0,0 +1,12 @@ +# Imported explicitly or with ".cases" +from pytest_cases import case + + +@case +def case_1(): + return "hello ." + + +@case +def case_2(): + return "hi ." diff --git a/tests/cases/issues/issue_258/cases_issue_258.py b/tests/cases/issues/issue_258/cases_issue_258.py new file mode 100644 index 00000000..94aae46c --- /dev/null +++ b/tests/cases/issues/issue_258/cases_issue_258.py @@ -0,0 +1,12 @@ +# Import with AUTO +from pytest_cases import case + + +@case +def case_1(): + return "hello AUTO" + + +@case +def case_2(): + return "hi AUTO" diff --git a/tests/cases/issues/issue_258/cases_other.py b/tests/cases/issues/issue_258/cases_other.py new file mode 100644 index 00000000..9459361f --- /dev/null +++ b/tests/cases/issues/issue_258/cases_other.py @@ -0,0 +1,13 @@ +# Used by passing the corresponding module `test_other` to `get_all_cases` +# `get_all_cases(test_other)` +from pytest_cases import case + + +@case +def case_1(): + return "hello cases_other" + + +@case +def case_2(): + return "hi cases_other" diff --git a/tests/cases/issues/issue_258/test_issue_258.py b/tests/cases/issues/issue_258/test_issue_258.py new file mode 100644 index 00000000..bb1fa37b --- /dev/null +++ b/tests/cases/issues/issue_258/test_issue_258.py @@ -0,0 +1,80 @@ +from pytest_cases import (AUTO, case, get_all_cases, parametrize, + parametrize_with_cases) + + +@case(tags=["a", "banana"]) +def case_1(): + return "a_banana" + + +@case(tags=["a"]) +def case_2(): + return "a" + + +@case(tags=["b", "banana"]) +def case_3(): + return "b_banana" + + +@case(tags=["b"]) +def case_4(): + return "b" + + +all_cases = get_all_cases(cases=[case_1, case_2, case_3, case_4]) +a_cases = get_all_cases(cases=all_cases, has_tag="a") +b_cases = get_all_cases(cases=all_cases, has_tag="b") +banana_cases = get_all_cases(cases=a_cases + b_cases, has_tag=["banana"]) + + +# Test behaviour with explicit cases and no parametrization target +################################################################## +@parametrize_with_cases("word", cases=all_cases) +def test_all(word): + assert word in ["a", "a_banana", "b", "b_banana"] + + +@parametrize_with_cases("word", cases=a_cases) +def test_a(word): + assert "a" in word + + +@parametrize_with_cases("word", cases=b_cases) +def test_b(word): + assert "b" in word + + +@parametrize_with_cases("word", cases=banana_cases) +def test_banana(word): + assert "banana" in word + + +# Test behaviour with string module ref and AUTO and no parametrization target +############################################################################## +def test_this_module_cases(): + this_module_cases = get_all_cases(cases=".") + assert set(this_module_cases) == {case_1, case_2, case_3, case_4} + + +def test_relative_module_cases(): + relative_import_cases = get_all_cases(cases=".cases") + assert {"hello .", "hi ."} == {f() for f in relative_import_cases} + + +def test_auto_cases(): + auto_import_cases = get_all_cases(cases=AUTO) + assert {"hello AUTO", "hi AUTO"} == {f() for f in auto_import_cases} + + +# Test behaviour with an explicit module parametrization target +############################################################### +from tests.cases.issues.issue_258 import test_other +def test_module_parametrization_auto(): + cases_other_cases = get_all_cases(test_other, cases=AUTO) + assert {"hello cases_other", "hi cases_other"} == {f() for f in cases_other_cases} + + +def test_module_parametrization_this_module(): + test_other_cases = get_all_cases(test_other, cases='.') + assert {"hello test_other", "hi test_other"} == {f() for f in test_other_cases} diff --git a/tests/cases/issues/issue_258/test_other.py b/tests/cases/issues/issue_258/test_other.py new file mode 100644 index 00000000..b7c1c206 --- /dev/null +++ b/tests/cases/issues/issue_258/test_other.py @@ -0,0 +1,11 @@ +from pytest_cases import case + + +@case +def case_1(): + return "hello test_other" + + +@case +def case_2(): + return "hi test_other"