Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better type hints for model_bakery.recipe #292

Merged
merged 5 commits into from Mar 31, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions model_bakery/_types.py
@@ -0,0 +1,6 @@
from typing import TypeVar
from django.db.models import Model


M = TypeVar("M", bound=Model)
NewM = TypeVar("NewM", bound=Model)
5 changes: 1 addition & 4 deletions model_bakery/baker.py
Expand Up @@ -44,6 +44,7 @@
)
from .utils import seq # NoQA: enable seq to be imported from baker
from .utils import import_from_str
from ._types import M, NewM

recipes = None

Expand All @@ -58,10 +59,6 @@ def _valid_quantity(quantity: Optional[Union[str, int]]) -> bool:
return quantity is not None and (not isinstance(quantity, int) or quantity < 1)


M = TypeVar("M", bound=Model)
NewM = TypeVar("NewM", bound=Model)


@overload
def make(
_model: Union[str, Type[M]],
Expand Down
133 changes: 112 additions & 21 deletions model_bakery/recipe.py
@@ -1,5 +1,16 @@
import itertools
from typing import Any, Dict, List, Type, Union, cast
from typing import (
Any,
Dict,
Generic,
List,
Optional,
Type,
TypeVar,
Union,
cast,
overload,
)

from django.db.models import Model

Expand All @@ -9,18 +20,22 @@
get_calling_module,
seq,
)
from ._types import M


finder = baker.ModelFinder()


class Recipe(object):
def __init__(self, _model: Union[str, Type[Model]], **attrs) -> None:
class Recipe(Generic[M]):
_T = TypeVar("_T", bound="Recipe[M]")

def __init__(self, _model: Union[str, Type[M]], **attrs: Any) -> None:
self.attr_mapping = attrs
self._model = _model
# _iterator_backups will hold values of the form (backup_iterator, usable_iterator).
self._iterator_backups = {} # type: Dict[str, Any]

def _mapping(self, _using, new_attrs: Dict[str, Any]) -> Dict[str, Any]:
def _mapping(self, _using: str, new_attrs: Dict[str, Any]) -> Dict[str, Any]:
_save_related = new_attrs.get("_save_related", True)
_quantity = new_attrs.get("_quantity")
if _quantity is None:
Expand Down Expand Up @@ -66,23 +81,99 @@ def _mapping(self, _using, new_attrs: Dict[str, Any]) -> Dict[str, Any]:
mapping.update(rel_fields_attrs)
return mapping

def make(self, _using="", **attrs: Any) -> Union[Model, List[Model]]:
return baker.make(self._model, _using=_using, **self._mapping(_using, attrs))

def prepare(self, _using="", **attrs: Any) -> Union[Model, List[Model]]:
defaults = {"_save_related": False}
@overload
def make(
self,
_quantity: None = None,
make_m2m: bool = False,
_refresh_after_create: bool = False,
_create_files: bool = False,
_using: str = "",
_bulk_create: bool = False,
_save_kwargs: Optional[Dict[str, Any]] = None,
**attrs: Any,
) -> M:
...

@overload
def make(
self,
_quantity: int,
make_m2m: bool = False,
_refresh_after_create: bool = False,
_create_files: bool = False,
_using: str = "",
_bulk_create: bool = False,
_save_kwargs: Optional[Dict[str, Any]] = None,
**attrs: Any,
) -> List[M]:
...

def make(
self,
_quantity: Optional[int] = None,
make_m2m: bool = False,
_refresh_after_create: bool = False,
_create_files: bool = False,
_using: str = "",
_bulk_create: bool = False,
_save_kwargs: Optional[Dict[str, Any]] = None,
**attrs: Any,
) -> Union[M, List[M]]:
defaults = {
"_quantity": _quantity,
"make_m2m": make_m2m,
"_refresh_after_create": _refresh_after_create,
"_create_files": _create_files,
"_bulk_create": _bulk_create,
"_save_kwargs": _save_kwargs,
}
defaults.update(attrs)
return baker.make(self._model, _using=_using, **self._mapping(_using, defaults))

@overload
def prepare(
self,
_quantity: None = None,
_save_related: bool = False,
_using: str = "",
**attrs: Any,
) -> M:
...

@overload
def prepare(
self,
_quantity: int,
_save_related: bool = False,
_using: str = "",
**attrs: Any,
) -> List[M]:
...

def prepare(
self,
_quantity: Optional[int] = None,
_save_related: bool = False,
_using: str = "",
**attrs: Any,
) -> Union[M, List[M]]:
defaults = {
"_quantity": _quantity,
"_save_related": _save_related,
}
defaults.update(attrs)
return baker.prepare(
self._model, _using=_using, **self._mapping(_using, defaults)
)

def extend(self, **attrs) -> "Recipe":
def extend(self: _T, **attrs: Any) -> _T:
attr_mapping = self.attr_mapping.copy()
attr_mapping.update(attrs)
return type(self)(self._model, **attr_mapping)


def _load_recipe_from_calling_module(recipe: str) -> Recipe:
def _load_recipe_from_calling_module(recipe: str) -> Recipe[Model]:
"""Load `Recipe` from the string attribute given from the calling module.

Args:
Expand All @@ -94,15 +185,15 @@ def _load_recipe_from_calling_module(recipe: str) -> Recipe:
"""
recipe = getattr(get_calling_module(2), recipe)
if recipe:
return cast(Recipe, recipe)
return cast(Recipe[Model], recipe)
else:
raise RecipeNotFound


class RecipeForeignKey(object):
class RecipeForeignKey(Generic[M]):
"""A `Recipe` to use for making ManyToOne and OneToOne related objects."""

def __init__(self, recipe: Recipe, one_to_one: bool) -> None:
def __init__(self, recipe: Recipe[M], one_to_one: bool) -> None:
if isinstance(recipe, Recipe):
self.recipe = recipe
self.one_to_one = one_to_one
Expand All @@ -111,8 +202,8 @@ def __init__(self, recipe: Recipe, one_to_one: bool) -> None:


def foreign_key(
recipe: Union[Recipe, str], one_to_one: bool = False
) -> RecipeForeignKey:
recipe: Union[Recipe[M], str], one_to_one: bool = False
) -> RecipeForeignKey[M]:
"""Return a `RecipeForeignKey`.

Return the callable, so that the associated `_model` will not be created
Expand All @@ -130,12 +221,12 @@ def foreign_key(
# Probably not in another module, so load it from calling module
recipe = _load_recipe_from_calling_module(cast(str, recipe))

return RecipeForeignKey(cast(Recipe, recipe), one_to_one)
return RecipeForeignKey(cast(Recipe[M], recipe), one_to_one)


class related(object): # FIXME
def __init__(self, *args) -> None:
self.related = [] # type: List[Recipe]
class related(Generic[M]): # FIXME
def __init__(self, *args: Union[str, Recipe[M]]) -> None:
self.related = [] # type: List[Recipe[M]]
for recipe in args:
if isinstance(recipe, Recipe):
self.related.append(recipe)
Expand All @@ -148,6 +239,6 @@ def __init__(self, *args) -> None:
else:
raise TypeError("Not a recipe")

def make(self) -> List[Union[Model, List[Model]]]:
def make(self) -> List[Union[M, List[M]]]:
"""Persist objects to m2m relation."""
return [m.make() for m in self.related]
2 changes: 1 addition & 1 deletion tests/test_baker.py
Expand Up @@ -125,7 +125,7 @@ def test_abstract_model_subclass_creation(self):
instance = baker.make(models.SubclassOfAbstract)
assert isinstance(instance, models.SubclassOfAbstract)
assert isinstance(instance, models.AbstractModel)
assert isinstance(instance.name, type(u""))
assert isinstance(instance.name, type(""))
assert len(instance.name) == 30
assert isinstance(instance.height, int)

Expand Down