Skip to content

Commit

Permalink
Better type hints for model_bakery.recipe (#292)
Browse files Browse the repository at this point in the history
* drop py2 unicode

Signed-off-by: Oleg Hoefling <oleg.hoefling@ionos.com>

* enhance model_bakery.recipe type hints

Signed-off-by: Oleg Hoefling <oleg.hoefling@ionos.com>

* resort imports

Signed-off-by: oleg.hoefling <oleg.hoefling@gmail.com>

* add changelog entry

Signed-off-by: oleg.hoefling <oleg.hoefling@gmail.com>

* fix pre-commit issues

Signed-off-by: oleg.hoefling <oleg.hoefling@gmail.com>
  • Loading branch information
hoefling committed Mar 31, 2022
1 parent 73e0dca commit ab53845
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 31 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
### Added

### Changed
- Extend type hints in `model_bakery.recipe` module, make `Recipe` class generic [PR #292](https://github.com/model-bakers/model_bakery/pull/292)

### Removed

Expand Down
6 changes: 6 additions & 0 deletions model_bakery/_types.py
@@ -0,0 +1,6 @@
from typing import TypeVar

from django.db.models import Model

M = TypeVar("M", bound=Model)
NewM = TypeVar("NewM", bound=Model)
6 changes: 1 addition & 5 deletions model_bakery/baker.py
Expand Up @@ -8,7 +8,6 @@
List,
Optional,
Type,
TypeVar,
Union,
cast,
overload,
Expand All @@ -34,6 +33,7 @@
from django.db.models.fields.reverse_related import ManyToOneRel, OneToOneRel

from . import generators, random_gen
from ._types import M, NewM
from .exceptions import (
AmbiguousModelName,
CustomBakerNotFound,
Expand All @@ -58,10 +58,6 @@ def _valid_quantity(quantity: Optional[Union[str, int]]) -> bool:
return quantity is not None and (not isinstance(quantity, int) or quantity < 1)


M = TypeVar("M", bound=Model)
NewM = TypeVar("NewM", bound=Model)


@overload
def make(
_model: Union[str, Type[M]],
Expand Down
132 changes: 111 additions & 21 deletions model_bakery/recipe.py
@@ -1,9 +1,21 @@
import itertools
from typing import Any, Dict, List, Type, Union, cast
from typing import (
Any,
Dict,
Generic,
List,
Optional,
Type,
TypeVar,
Union,
cast,
overload,
)

from django.db.models import Model

from . import baker
from ._types import M
from .exceptions import RecipeNotFound
from .utils import ( # NoQA: Enable seq to be imported from recipes
get_calling_module,
Expand All @@ -13,14 +25,16 @@
finder = baker.ModelFinder()


class Recipe(object):
def __init__(self, _model: Union[str, Type[Model]], **attrs) -> None:
class Recipe(Generic[M]):
_T = TypeVar("_T", bound="Recipe[M]")

def __init__(self, _model: Union[str, Type[M]], **attrs: Any) -> None:
self.attr_mapping = attrs
self._model = _model
# _iterator_backups will hold values of the form (backup_iterator, usable_iterator).
self._iterator_backups = {} # type: Dict[str, Any]

def _mapping(self, _using, new_attrs: Dict[str, Any]) -> Dict[str, Any]:
def _mapping(self, _using: str, new_attrs: Dict[str, Any]) -> Dict[str, Any]:
_save_related = new_attrs.get("_save_related", True)
_quantity = new_attrs.get("_quantity")
if _quantity is None:
Expand Down Expand Up @@ -66,23 +80,99 @@ def _mapping(self, _using, new_attrs: Dict[str, Any]) -> Dict[str, Any]:
mapping.update(rel_fields_attrs)
return mapping

def make(self, _using="", **attrs: Any) -> Union[Model, List[Model]]:
return baker.make(self._model, _using=_using, **self._mapping(_using, attrs))

def prepare(self, _using="", **attrs: Any) -> Union[Model, List[Model]]:
defaults = {"_save_related": False}
@overload
def make(
self,
_quantity: None = None,
make_m2m: bool = False,
_refresh_after_create: bool = False,
_create_files: bool = False,
_using: str = "",
_bulk_create: bool = False,
_save_kwargs: Optional[Dict[str, Any]] = None,
**attrs: Any,
) -> M:
...

@overload
def make(
self,
_quantity: int,
make_m2m: bool = False,
_refresh_after_create: bool = False,
_create_files: bool = False,
_using: str = "",
_bulk_create: bool = False,
_save_kwargs: Optional[Dict[str, Any]] = None,
**attrs: Any,
) -> List[M]:
...

def make(
self,
_quantity: Optional[int] = None,
make_m2m: bool = False,
_refresh_after_create: bool = False,
_create_files: bool = False,
_using: str = "",
_bulk_create: bool = False,
_save_kwargs: Optional[Dict[str, Any]] = None,
**attrs: Any,
) -> Union[M, List[M]]:
defaults = {
"_quantity": _quantity,
"make_m2m": make_m2m,
"_refresh_after_create": _refresh_after_create,
"_create_files": _create_files,
"_bulk_create": _bulk_create,
"_save_kwargs": _save_kwargs,
}
defaults.update(attrs)
return baker.make(self._model, _using=_using, **self._mapping(_using, defaults))

@overload
def prepare(
self,
_quantity: None = None,
_save_related: bool = False,
_using: str = "",
**attrs: Any,
) -> M:
...

@overload
def prepare(
self,
_quantity: int,
_save_related: bool = False,
_using: str = "",
**attrs: Any,
) -> List[M]:
...

def prepare(
self,
_quantity: Optional[int] = None,
_save_related: bool = False,
_using: str = "",
**attrs: Any,
) -> Union[M, List[M]]:
defaults = {
"_quantity": _quantity,
"_save_related": _save_related,
}
defaults.update(attrs)
return baker.prepare(
self._model, _using=_using, **self._mapping(_using, defaults)
)

def extend(self, **attrs) -> "Recipe":
def extend(self: _T, **attrs: Any) -> _T:
attr_mapping = self.attr_mapping.copy()
attr_mapping.update(attrs)
return type(self)(self._model, **attr_mapping)


def _load_recipe_from_calling_module(recipe: str) -> Recipe:
def _load_recipe_from_calling_module(recipe: str) -> Recipe[Model]:
"""Load `Recipe` from the string attribute given from the calling module.
Args:
Expand All @@ -94,15 +184,15 @@ def _load_recipe_from_calling_module(recipe: str) -> Recipe:
"""
recipe = getattr(get_calling_module(2), recipe)
if recipe:
return cast(Recipe, recipe)
return cast(Recipe[Model], recipe)
else:
raise RecipeNotFound


class RecipeForeignKey(object):
class RecipeForeignKey(Generic[M]):
"""A `Recipe` to use for making ManyToOne and OneToOne related objects."""

def __init__(self, recipe: Recipe, one_to_one: bool) -> None:
def __init__(self, recipe: Recipe[M], one_to_one: bool) -> None:
if isinstance(recipe, Recipe):
self.recipe = recipe
self.one_to_one = one_to_one
Expand All @@ -111,8 +201,8 @@ def __init__(self, recipe: Recipe, one_to_one: bool) -> None:


def foreign_key(
recipe: Union[Recipe, str], one_to_one: bool = False
) -> RecipeForeignKey:
recipe: Union[Recipe[M], str], one_to_one: bool = False
) -> RecipeForeignKey[M]:
"""Return a `RecipeForeignKey`.
Return the callable, so that the associated `_model` will not be created
Expand All @@ -130,12 +220,12 @@ def foreign_key(
# Probably not in another module, so load it from calling module
recipe = _load_recipe_from_calling_module(cast(str, recipe))

return RecipeForeignKey(cast(Recipe, recipe), one_to_one)
return RecipeForeignKey(cast(Recipe[M], recipe), one_to_one)


class related(object): # FIXME
def __init__(self, *args) -> None:
self.related = [] # type: List[Recipe]
class related(Generic[M]): # FIXME
def __init__(self, *args: Union[str, Recipe[M]]) -> None:
self.related = [] # type: List[Recipe[M]]
for recipe in args:
if isinstance(recipe, Recipe):
self.related.append(recipe)
Expand All @@ -148,6 +238,6 @@ def __init__(self, *args) -> None:
else:
raise TypeError("Not a recipe")

def make(self) -> List[Union[Model, List[Model]]]:
def make(self) -> List[Union[M, List[M]]]:
"""Persist objects to m2m relation."""
return [m.make() for m in self.related]
2 changes: 1 addition & 1 deletion tests/test_baker.py
Expand Up @@ -125,7 +125,7 @@ def test_abstract_model_subclass_creation(self):
instance = baker.make(models.SubclassOfAbstract)
assert isinstance(instance, models.SubclassOfAbstract)
assert isinstance(instance, models.AbstractModel)
assert isinstance(instance.name, type(u""))
assert isinstance(instance.name, type(""))
assert len(instance.name) == 30
assert isinstance(instance.height, int)

Expand Down
6 changes: 4 additions & 2 deletions tests/test_filling_fields.py
Expand Up @@ -277,8 +277,10 @@ def test_filling_content_type_field(self):

def test_iteratively_filling_generic_foreign_key_field(self):
"""
Ensures private_fields are included in Baker.get_fields(), otherwise
calling next() when a GFK is in iterator_attrs would be bypassed.
Ensures private_fields are included in ``Baker.get_fields()``.
Otherwise, calling ``next()`` when a GFK is in ``iterator_attrs``
would be bypassed.
"""
objects = baker.make(models.Profile, _quantity=2)
dummies = baker.make(
Expand Down
3 changes: 1 addition & 2 deletions tests/test_recipes.py
Expand Up @@ -48,8 +48,7 @@ def test_import_seq_from_recipe():


def test_import_recipes():
"""Test imports works both for full import paths and for
`app_name.recipe_name` strings."""
"""Test imports works both for full import paths and for `app_name.recipe_name` strings."""
assert baker.prepare_recipe("generic.dog"), baker.prepare_recipe(
"tests.generic.dog"
)
Expand Down

0 comments on commit ab53845

Please sign in to comment.