Skip to content

Commit

Permalink
refactor: adjust type hintings to satisfy mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
skarzi committed May 4, 2022
1 parent c63914b commit 773f7b5
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 35 deletions.
5 changes: 5 additions & 0 deletions pyproject.toml
Expand Up @@ -7,6 +7,7 @@ line-length = 120
target-version = ['py37', 'py38', 'py39', 'py310']

[tool.mypy]
exclude = ['docs/']
allow_redefinition = false
check_untyped_defs = true
disallow_untyped_decorators = true
Expand All @@ -26,3 +27,7 @@ warn_unreachable = true
warn_no_return = true
pretty = true
show_error_codes = true

[[tool.mypy.overrides]]
module = ["tests.*"]
disallow_untyped_decorators = false
5 changes: 4 additions & 1 deletion pytest_factoryboy/codegen.py
Expand Up @@ -11,6 +11,7 @@
from dataclasses import field, dataclass
from functools import lru_cache
from types import ModuleType
from typing import Any

import mako.template
from appdirs import AppDirs
Expand All @@ -26,7 +27,7 @@
class FixtureDef:
name: str
function_name: typing.Literal["model_fixture", "attr_fixture", "factory_fixture", "subfactory_fixture"]
function_kwargs: dict = field(default_factory=dict)
function_kwargs: dict[str, Any] = field(default_factory=dict)
deps: list[str] = field(default_factory=list)
related: list[str] = field(default_factory=list)

Expand Down Expand Up @@ -122,7 +123,9 @@ def make_module(code: str, module_name: str, package_name: str) -> ModuleType:
tmp_module_path.write_text(code)
name = f"{package_name}.{module_name}"
spec = importlib.util.spec_from_file_location(name, tmp_module_path)
assert spec # NOTE: satisfy `mypy`
mod = importlib.util.module_from_spec(spec)
assert spec.loader # NOTE: satisfy `mypy`
spec.loader.exec_module(mod)
return mod

Expand Down
2 changes: 2 additions & 0 deletions pytest_factoryboy/compat.py
Expand Up @@ -2,6 +2,8 @@
import sys
import pathlib

__all__ = ("PostGenerationContext", "path_with_stem")

try:
from factory.declarations import PostGenerationContext
except ImportError: # factory_boy < 3.2.0
Expand Down
45 changes: 23 additions & 22 deletions pytest_factoryboy/fixture.py
Expand Up @@ -14,17 +14,17 @@

from .codegen import make_fixture_model_module, FixtureDef
from .compat import PostGenerationContext
from typing import TYPE_CHECKING, overload
from typing_extensions import Protocol
from typing import TYPE_CHECKING, overload, cast
from typing_extensions import Protocol, TypeAlias

if TYPE_CHECKING:
from typing import Any, Callable, TypeVar
from _pytest.fixtures import FixtureRequest
from _pytest.fixtures import SubRequest, FixtureFunction
from factory.builder import BuildStep
from factory.declarations import PostGeneration
from factory.declarations import PostGenerationContext

FactoryType = type[factory.Factory]
FactoryType: TypeAlias = factory.Factory
T = TypeVar("T")
F = TypeVar("F", bound=FactoryType)

Expand All @@ -37,9 +37,9 @@ class DeferredFunction:
name: str
factory: FactoryType
is_related: bool
function: Callable[[FixtureRequest], Any]
function: Callable[[SubRequest], Any]

def __call__(self, request: FixtureRequest) -> Any:
def __call__(self, request: SubRequest) -> Any:
return self.function(request)


Expand All @@ -51,7 +51,7 @@ def __call__(self, factory_class: F, _name: str | None = None, **kwargs: Any) ->


@overload
def register(
def register( # type: ignore[misc]
factory_class: None = None,
_name: str | None = None,
**kwargs: Any,
Expand Down Expand Up @@ -177,7 +177,7 @@ def register(
return factory_class


def inject_into_caller(name: str, function: Callable, locals_: dict[str, Any]) -> None:
def inject_into_caller(name: str, function: Callable[..., Any], locals_: dict[str, Any]) -> None:
"""Inject a function into the caller's locals, making sure that the function will work also within classes."""
# We need to check if the caller frame is a class, since in that case the first argument is the class itself.
# In that case, we can apply the staticmethod() decorator to the injected function, so that the first param
Expand All @@ -191,7 +191,7 @@ def inject_into_caller(name: str, function: Callable, locals_: dict[str, Any]) -
# Therefore, we can just check for __qualname__ to figure out if we are in a class, and apply the @staticmethod.
is_class_or_function = "__qualname__" in locals_
if is_class_or_function:
function = staticmethod(function)
function = staticmethod(function) # type: ignore[assignment]

locals_[name] = function

Expand Down Expand Up @@ -238,20 +238,21 @@ def is_dep(value: Any) -> bool:
]


def evaluate(request: FixtureRequest, value: LazyFixture | Any) -> Any:
def evaluate(request: SubRequest, value: LazyFixture | Any) -> Any:
"""Evaluate the declaration (lazy fixtures, etc)."""
return value.evaluate(request) if isinstance(value, LazyFixture) else value


def model_fixture(request: FixtureRequest, factory_name: str) -> Any:
def model_fixture(request: SubRequest, factory_name: str) -> Any:
"""Model fixture implementation."""
factoryboy_request = request.getfixturevalue("factoryboy_request")

# Try to evaluate as much post-generation dependencies as possible
factoryboy_request.evaluate(request)

fixture_name = str(request.fixturename)
factory_class: FactoryType = request.getfixturevalue(factory_name)
prefix = "".join((request.fixturename, SEPARATOR))
prefix = "".join((fixture_name, SEPARATOR))

# Create model fixture instance

Expand Down Expand Up @@ -279,7 +280,7 @@ class Factory(factory_class):

# Cache the instance value on pytest level so that the fixture can be resolved before the return
request._fixturedef.cached_result = (instance, 0, None)
request._fixture_defs[request.fixturename] = request._fixturedef
request._fixture_defs[fixture_name] = request._fixturedef

# Defer post-generation declarations
deferred: list[DeferredFunction] = []
Expand All @@ -289,7 +290,7 @@ class Factory(factory_class):
decl = factory_class._meta.post_declarations.declarations[attr]

if isinstance(decl, factory.RelatedFactory):
deferred.append(make_deferred_related(factory_class, request.fixturename, attr))
deferred.append(make_deferred_related(factory_class, fixture_name, attr))
else:
argname = "".join((prefix, attr))
extra = {}
Expand All @@ -309,7 +310,7 @@ class Factory(factory_class):
extra=extra,
)
deferred.append(
make_deferred_postgen(step, factory_class, request.fixturename, instance, attr, decl, postgen_context)
make_deferred_postgen(step, factory_class, fixture_name, instance, attr, decl, postgen_context)
)
factoryboy_request.defer(deferred)

Expand All @@ -329,7 +330,7 @@ def make_deferred_related(factory: FactoryType, fixture: str, attr: str) -> Defe
"""
name = SEPARATOR.join((fixture, attr))

def deferred_impl(request: FixtureRequest) -> Any:
def deferred_impl(request: SubRequest) -> Any:
return request.getfixturevalue(name)

return DeferredFunction(
Expand Down Expand Up @@ -362,7 +363,7 @@ def make_deferred_postgen(
"""
name = SEPARATOR.join((fixture, attr))

def deferred_impl(request: FixtureRequest) -> Any:
def deferred_impl(request: SubRequest) -> Any:
return declaration.call(instance, step, context)

return DeferredFunction(
Expand All @@ -373,17 +374,17 @@ def deferred_impl(request: FixtureRequest) -> Any:
)


def factory_fixture(request: FixtureRequest, factory_class: F) -> F:
def factory_fixture(request: SubRequest, factory_class: F) -> F:
"""Factory fixture implementation."""
return factory_class


def attr_fixture(request: FixtureRequest, value: T) -> T:
def attr_fixture(request: SubRequest, value: T) -> T:
"""Attribute fixture implementation."""
return value


def subfactory_fixture(request: FixtureRequest, factory_class: FactoryType) -> Any:
def subfactory_fixture(request: SubRequest, factory_class: FactoryType) -> Any:
"""SubFactory/RelatedFactory fixture implementation."""
fixture = inflection.underscore(factory_class._meta.model.__name__)
return request.getfixturevalue(fixture)
Expand All @@ -397,7 +398,7 @@ def get_caller_locals(depth: int = 2) -> dict[str, Any]:
class LazyFixture:
"""Lazy fixture."""

def __init__(self, fixture: Callable | str) -> None:
def __init__(self, fixture: FixtureFunction | str) -> None:
"""Lazy pytest fixture wrapper.
:param fixture: Fixture name or callable with dependencies.
Expand All @@ -409,7 +410,7 @@ def __init__(self, fixture: Callable | str) -> None:
else:
self.args = [self.fixture]

def evaluate(self, request: FixtureRequest) -> Any:
def evaluate(self, request: SubRequest) -> Any:
"""Evaluate the lazy fixture.
:param request: pytest request object.
Expand Down
18 changes: 9 additions & 9 deletions pytest_factoryboy/plugin.py
Expand Up @@ -9,7 +9,7 @@
if TYPE_CHECKING:
from typing import Any
from factory import Factory
from _pytest.fixtures import FixtureRequest
from _pytest.fixtures import FixtureRequest, SubRequest
from _pytest.config import PytestPluginManager
from _pytest.python import Metafunc
from _pytest.nodes import Item
Expand All @@ -29,7 +29,7 @@ def __init__(self) -> None:
self.deferred: list[list[DeferredFunction]] = []
self.results: dict[str, dict[str, Any]] = defaultdict(dict)
self.model_factories: dict[str, type[Factory]] = {}
self.in_progress: set = set()
self.in_progress: set[DeferredFunction] = set()

def defer(self, functions: list[DeferredFunction]) -> None:
"""Defer post-generation declaration execution until the end of the test setup.
Expand All @@ -39,7 +39,7 @@ def defer(self, functions: list[DeferredFunction]) -> None:
"""
self.deferred.append(functions)

def get_deps(self, request: FixtureRequest, fixture: str, deps: set[str] | None = None) -> set[str]:
def get_deps(self, request: SubRequest, fixture: str, deps: set[str] | None = None) -> set[str]:
request = request.getfixturevalue("request")

if deps is None:
Expand All @@ -54,15 +54,15 @@ def get_deps(self, request: FixtureRequest, fixture: str, deps: set[str] | None
deps.update(self.get_deps(request, argname, deps))
return deps

def get_current_deps(self, request: FixtureRequest) -> set[str]:
def get_current_deps(self, request: FixtureRequest | SubRequest) -> set[str]:
deps = set()
while hasattr(request, "_parent_request"):
if request.fixturename and request.fixturename not in getattr(request, "_fixturedefs", {}):
deps.add(request.fixturename)
request = request._parent_request
request = request._parent_request # type: ignore[union-attr]
return deps

def execute(self, request: FixtureRequest, function: DeferredFunction, deferred: list[DeferredFunction]) -> None:
def execute(self, request: SubRequest, function: DeferredFunction, deferred: list[DeferredFunction]) -> None:
"""Execute deferred function and store the result."""
if function in self.in_progress:
raise CycleDetected()
Expand All @@ -79,15 +79,15 @@ def execute(self, request: FixtureRequest, function: DeferredFunction, deferred:
deferred.remove(function)
self.in_progress.remove(function)

def after_postgeneration(self, request: FixtureRequest) -> None:
def after_postgeneration(self, request: SubRequest) -> None:
"""Call _after_postgeneration hooks."""
for model in list(self.results.keys()):
results = self.results.pop(model)
obj = request.getfixturevalue(model)
factory = self.model_factories[model]
factory._after_postgeneration(obj, create=True, results=results)

def evaluate(self, request: FixtureRequest) -> None:
def evaluate(self, request: SubRequest) -> None:
"""Finalize, run deferred post-generation actions, etc."""
while self.deferred:
try:
Expand All @@ -114,7 +114,7 @@ def pytest_runtest_call(item: Item) -> None:
"""Before the test item is called."""
# TODO: We should instead do an `if isinstance(item, Function)`.
try:
request = item._request
request = item._request # type: ignore[attr-defined]
except AttributeError:
# pytest-pep8 plugin passes Pep8Item here during tests.
return
Expand Down
13 changes: 10 additions & 3 deletions tests/test_postgen_dependencies.py
@@ -1,7 +1,7 @@
"""Test post-generation dependencies."""
from __future__ import annotations

from dataclasses import dataclass
from dataclasses import dataclass, field

import factory
import pytest
Expand All @@ -19,6 +19,12 @@ class Foo:
value: int
expected: int

bar: Bar | None = None

# NOTE: following attributes are used internally only for assertions
_create: bool | None = None
_postgeneration_results: dict[str, Any] = field(default_factory=dict)


@dataclass
class Bar:
Expand Down Expand Up @@ -59,7 +65,7 @@ def set1(foo: Foo, create: bool, value: Any, **kwargs: Any) -> str:

@classmethod
def _after_postgeneration(cls, obj: Foo, create: bool, results: dict[str, Any] | None = None) -> None:
obj._postgeneration_results = results
obj._postgeneration_results = results or {}
obj._create = create


Expand Down Expand Up @@ -111,8 +117,9 @@ def test_after_postgeneration(foo: Foo):
assert len(foo._postgeneration_results) == 2


@dataclass
class Ordered:
value = None
value: str | None = None


@register
Expand Down

0 comments on commit 773f7b5

Please sign in to comment.