Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better type hints for model_bakery.recipe #292

Merged
merged 5 commits into from Mar 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
### Added

### Changed
- Extend type hints in `model_bakery.recipe` module, make `Recipe` class generic [PR #292](https://github.com/model-bakers/model_bakery/pull/292)

### Removed

Expand Down
6 changes: 6 additions & 0 deletions model_bakery/_types.py
@@ -0,0 +1,6 @@
from typing import TypeVar

from django.db.models import Model

M = TypeVar("M", bound=Model)
NewM = TypeVar("NewM", bound=Model)
6 changes: 1 addition & 5 deletions model_bakery/baker.py
Expand Up @@ -8,7 +8,6 @@
List,
Optional,
Type,
TypeVar,
Union,
cast,
overload,
Expand All @@ -34,6 +33,7 @@
from django.db.models.fields.reverse_related import ManyToOneRel, OneToOneRel

from . import generators, random_gen
from ._types import M, NewM
from .exceptions import (
AmbiguousModelName,
CustomBakerNotFound,
Expand All @@ -58,10 +58,6 @@ def _valid_quantity(quantity: Optional[Union[str, int]]) -> bool:
return quantity is not None and (not isinstance(quantity, int) or quantity < 1)


M = TypeVar("M", bound=Model)
NewM = TypeVar("NewM", bound=Model)


@overload
def make(
_model: Union[str, Type[M]],
Expand Down
132 changes: 111 additions & 21 deletions model_bakery/recipe.py
@@ -1,9 +1,21 @@
import itertools
from typing import Any, Dict, List, Type, Union, cast
from typing import (
Any,
Dict,
Generic,
List,
Optional,
Type,
TypeVar,
Union,
cast,
overload,
)

from django.db.models import Model

from . import baker
from ._types import M
from .exceptions import RecipeNotFound
from .utils import ( # NoQA: Enable seq to be imported from recipes
get_calling_module,
Expand All @@ -13,14 +25,16 @@
finder = baker.ModelFinder()


class Recipe(object):
def __init__(self, _model: Union[str, Type[Model]], **attrs) -> None:
class Recipe(Generic[M]):
_T = TypeVar("_T", bound="Recipe[M]")

def __init__(self, _model: Union[str, Type[M]], **attrs: Any) -> None:
self.attr_mapping = attrs
self._model = _model
# _iterator_backups will hold values of the form (backup_iterator, usable_iterator).
self._iterator_backups = {} # type: Dict[str, Any]

def _mapping(self, _using, new_attrs: Dict[str, Any]) -> Dict[str, Any]:
def _mapping(self, _using: str, new_attrs: Dict[str, Any]) -> Dict[str, Any]:
_save_related = new_attrs.get("_save_related", True)
_quantity = new_attrs.get("_quantity")
if _quantity is None:
Expand Down Expand Up @@ -66,23 +80,99 @@ def _mapping(self, _using, new_attrs: Dict[str, Any]) -> Dict[str, Any]:
mapping.update(rel_fields_attrs)
return mapping

def make(self, _using="", **attrs: Any) -> Union[Model, List[Model]]:
return baker.make(self._model, _using=_using, **self._mapping(_using, attrs))

def prepare(self, _using="", **attrs: Any) -> Union[Model, List[Model]]:
defaults = {"_save_related": False}
@overload
def make(
self,
_quantity: None = None,
make_m2m: bool = False,
_refresh_after_create: bool = False,
_create_files: bool = False,
_using: str = "",
_bulk_create: bool = False,
_save_kwargs: Optional[Dict[str, Any]] = None,
**attrs: Any,
) -> M:
...

@overload
def make(
self,
_quantity: int,
make_m2m: bool = False,
_refresh_after_create: bool = False,
_create_files: bool = False,
_using: str = "",
_bulk_create: bool = False,
_save_kwargs: Optional[Dict[str, Any]] = None,
**attrs: Any,
) -> List[M]:
...

def make(
self,
_quantity: Optional[int] = None,
make_m2m: bool = False,
_refresh_after_create: bool = False,
_create_files: bool = False,
_using: str = "",
_bulk_create: bool = False,
_save_kwargs: Optional[Dict[str, Any]] = None,
**attrs: Any,
) -> Union[M, List[M]]:
defaults = {
"_quantity": _quantity,
"make_m2m": make_m2m,
"_refresh_after_create": _refresh_after_create,
"_create_files": _create_files,
"_bulk_create": _bulk_create,
"_save_kwargs": _save_kwargs,
}
defaults.update(attrs)
return baker.make(self._model, _using=_using, **self._mapping(_using, defaults))

@overload
def prepare(
self,
_quantity: None = None,
_save_related: bool = False,
_using: str = "",
**attrs: Any,
) -> M:
...

@overload
def prepare(
self,
_quantity: int,
_save_related: bool = False,
_using: str = "",
**attrs: Any,
) -> List[M]:
...

def prepare(
self,
_quantity: Optional[int] = None,
_save_related: bool = False,
_using: str = "",
**attrs: Any,
) -> Union[M, List[M]]:
defaults = {
"_quantity": _quantity,
"_save_related": _save_related,
}
defaults.update(attrs)
return baker.prepare(
self._model, _using=_using, **self._mapping(_using, defaults)
)

def extend(self, **attrs) -> "Recipe":
def extend(self: _T, **attrs: Any) -> _T:
attr_mapping = self.attr_mapping.copy()
attr_mapping.update(attrs)
return type(self)(self._model, **attr_mapping)


def _load_recipe_from_calling_module(recipe: str) -> Recipe:
def _load_recipe_from_calling_module(recipe: str) -> Recipe[Model]:
"""Load `Recipe` from the string attribute given from the calling module.
Args:
Expand All @@ -94,15 +184,15 @@ def _load_recipe_from_calling_module(recipe: str) -> Recipe:
"""
recipe = getattr(get_calling_module(2), recipe)
if recipe:
return cast(Recipe, recipe)
return cast(Recipe[Model], recipe)
else:
raise RecipeNotFound


class RecipeForeignKey(object):
class RecipeForeignKey(Generic[M]):
"""A `Recipe` to use for making ManyToOne and OneToOne related objects."""

def __init__(self, recipe: Recipe, one_to_one: bool) -> None:
def __init__(self, recipe: Recipe[M], one_to_one: bool) -> None:
if isinstance(recipe, Recipe):
self.recipe = recipe
self.one_to_one = one_to_one
Expand All @@ -111,8 +201,8 @@ def __init__(self, recipe: Recipe, one_to_one: bool) -> None:


def foreign_key(
recipe: Union[Recipe, str], one_to_one: bool = False
) -> RecipeForeignKey:
recipe: Union[Recipe[M], str], one_to_one: bool = False
) -> RecipeForeignKey[M]:
"""Return a `RecipeForeignKey`.
Return the callable, so that the associated `_model` will not be created
Expand All @@ -130,12 +220,12 @@ def foreign_key(
# Probably not in another module, so load it from calling module
recipe = _load_recipe_from_calling_module(cast(str, recipe))

return RecipeForeignKey(cast(Recipe, recipe), one_to_one)
return RecipeForeignKey(cast(Recipe[M], recipe), one_to_one)


class related(object): # FIXME
def __init__(self, *args) -> None:
self.related = [] # type: List[Recipe]
class related(Generic[M]): # FIXME
def __init__(self, *args: Union[str, Recipe[M]]) -> None:
self.related = [] # type: List[Recipe[M]]
for recipe in args:
if isinstance(recipe, Recipe):
self.related.append(recipe)
Expand All @@ -148,6 +238,6 @@ def __init__(self, *args) -> None:
else:
raise TypeError("Not a recipe")

def make(self) -> List[Union[Model, List[Model]]]:
def make(self) -> List[Union[M, List[M]]]:
"""Persist objects to m2m relation."""
return [m.make() for m in self.related]
2 changes: 1 addition & 1 deletion tests/test_baker.py
Expand Up @@ -125,7 +125,7 @@ def test_abstract_model_subclass_creation(self):
instance = baker.make(models.SubclassOfAbstract)
assert isinstance(instance, models.SubclassOfAbstract)
assert isinstance(instance, models.AbstractModel)
assert isinstance(instance.name, type(u""))
assert isinstance(instance.name, type(""))
assert len(instance.name) == 30
assert isinstance(instance.height, int)

Expand Down
6 changes: 4 additions & 2 deletions tests/test_filling_fields.py
Expand Up @@ -277,8 +277,10 @@ def test_filling_content_type_field(self):

def test_iteratively_filling_generic_foreign_key_field(self):
"""
Ensures private_fields are included in Baker.get_fields(), otherwise
calling next() when a GFK is in iterator_attrs would be bypassed.
Ensures private_fields are included in ``Baker.get_fields()``.
Otherwise, calling ``next()`` when a GFK is in ``iterator_attrs``
would be bypassed.
"""
objects = baker.make(models.Profile, _quantity=2)
dummies = baker.make(
Expand Down
3 changes: 1 addition & 2 deletions tests/test_recipes.py
Expand Up @@ -48,8 +48,7 @@ def test_import_seq_from_recipe():


def test_import_recipes():
"""Test imports works both for full import paths and for
`app_name.recipe_name` strings."""
"""Test imports works both for full import paths and for `app_name.recipe_name` strings."""
assert baker.prepare_recipe("generic.dog"), baker.prepare_recipe(
"tests.generic.dog"
)
Expand Down