diff --git a/CHANGELOG.md b/CHANGELOG.md index f30ca6f9..e98be269 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Added ### Changed +- Extend type hints in `model_bakery.recipe` module, make `Recipe` class generic [PR #292](https://github.com/model-bakers/model_bakery/pull/292) ### Removed diff --git a/model_bakery/_types.py b/model_bakery/_types.py new file mode 100644 index 00000000..fac40eb1 --- /dev/null +++ b/model_bakery/_types.py @@ -0,0 +1,6 @@ +from typing import TypeVar + +from django.db.models import Model + +M = TypeVar("M", bound=Model) +NewM = TypeVar("NewM", bound=Model) diff --git a/model_bakery/baker.py b/model_bakery/baker.py index 48b154a3..6046fa9a 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -8,7 +8,6 @@ List, Optional, Type, - TypeVar, Union, cast, overload, @@ -34,6 +33,7 @@ from django.db.models.fields.reverse_related import ManyToOneRel, OneToOneRel from . import generators, random_gen +from ._types import M, NewM from .exceptions import ( AmbiguousModelName, CustomBakerNotFound, @@ -58,10 +58,6 @@ def _valid_quantity(quantity: Optional[Union[str, int]]) -> bool: return quantity is not None and (not isinstance(quantity, int) or quantity < 1) -M = TypeVar("M", bound=Model) -NewM = TypeVar("NewM", bound=Model) - - @overload def make( _model: Union[str, Type[M]], diff --git a/model_bakery/recipe.py b/model_bakery/recipe.py index 0c979c41..93f7ee28 100644 --- a/model_bakery/recipe.py +++ b/model_bakery/recipe.py @@ -1,9 +1,21 @@ import itertools -from typing import Any, Dict, List, Type, Union, cast +from typing import ( + Any, + Dict, + Generic, + List, + Optional, + Type, + TypeVar, + Union, + cast, + overload, +) from django.db.models import Model from . import baker +from ._types import M from .exceptions import RecipeNotFound from .utils import ( # NoQA: Enable seq to be imported from recipes get_calling_module, @@ -13,14 +25,16 @@ finder = baker.ModelFinder() -class Recipe(object): - def __init__(self, _model: Union[str, Type[Model]], **attrs) -> None: +class Recipe(Generic[M]): + _T = TypeVar("_T", bound="Recipe[M]") + + def __init__(self, _model: Union[str, Type[M]], **attrs: Any) -> None: self.attr_mapping = attrs self._model = _model # _iterator_backups will hold values of the form (backup_iterator, usable_iterator). self._iterator_backups = {} # type: Dict[str, Any] - def _mapping(self, _using, new_attrs: Dict[str, Any]) -> Dict[str, Any]: + def _mapping(self, _using: str, new_attrs: Dict[str, Any]) -> Dict[str, Any]: _save_related = new_attrs.get("_save_related", True) _quantity = new_attrs.get("_quantity") if _quantity is None: @@ -66,23 +80,99 @@ def _mapping(self, _using, new_attrs: Dict[str, Any]) -> Dict[str, Any]: mapping.update(rel_fields_attrs) return mapping - def make(self, _using="", **attrs: Any) -> Union[Model, List[Model]]: - return baker.make(self._model, _using=_using, **self._mapping(_using, attrs)) - - def prepare(self, _using="", **attrs: Any) -> Union[Model, List[Model]]: - defaults = {"_save_related": False} + @overload + def make( + self, + _quantity: None = None, + make_m2m: bool = False, + _refresh_after_create: bool = False, + _create_files: bool = False, + _using: str = "", + _bulk_create: bool = False, + _save_kwargs: Optional[Dict[str, Any]] = None, + **attrs: Any, + ) -> M: + ... + + @overload + def make( + self, + _quantity: int, + make_m2m: bool = False, + _refresh_after_create: bool = False, + _create_files: bool = False, + _using: str = "", + _bulk_create: bool = False, + _save_kwargs: Optional[Dict[str, Any]] = None, + **attrs: Any, + ) -> List[M]: + ... + + def make( + self, + _quantity: Optional[int] = None, + make_m2m: bool = False, + _refresh_after_create: bool = False, + _create_files: bool = False, + _using: str = "", + _bulk_create: bool = False, + _save_kwargs: Optional[Dict[str, Any]] = None, + **attrs: Any, + ) -> Union[M, List[M]]: + defaults = { + "_quantity": _quantity, + "make_m2m": make_m2m, + "_refresh_after_create": _refresh_after_create, + "_create_files": _create_files, + "_bulk_create": _bulk_create, + "_save_kwargs": _save_kwargs, + } + defaults.update(attrs) + return baker.make(self._model, _using=_using, **self._mapping(_using, defaults)) + + @overload + def prepare( + self, + _quantity: None = None, + _save_related: bool = False, + _using: str = "", + **attrs: Any, + ) -> M: + ... + + @overload + def prepare( + self, + _quantity: int, + _save_related: bool = False, + _using: str = "", + **attrs: Any, + ) -> List[M]: + ... + + def prepare( + self, + _quantity: Optional[int] = None, + _save_related: bool = False, + _using: str = "", + **attrs: Any, + ) -> Union[M, List[M]]: + defaults = { + "_quantity": _quantity, + "_save_related": _save_related, + } defaults.update(attrs) return baker.prepare( self._model, _using=_using, **self._mapping(_using, defaults) ) - def extend(self, **attrs) -> "Recipe": + def extend(self: _T, **attrs: Any) -> _T: attr_mapping = self.attr_mapping.copy() attr_mapping.update(attrs) return type(self)(self._model, **attr_mapping) -def _load_recipe_from_calling_module(recipe: str) -> Recipe: +def _load_recipe_from_calling_module(recipe: str) -> Recipe[Model]: """Load `Recipe` from the string attribute given from the calling module. Args: @@ -94,15 +184,15 @@ def _load_recipe_from_calling_module(recipe: str) -> Recipe: """ recipe = getattr(get_calling_module(2), recipe) if recipe: - return cast(Recipe, recipe) + return cast(Recipe[Model], recipe) else: raise RecipeNotFound -class RecipeForeignKey(object): +class RecipeForeignKey(Generic[M]): """A `Recipe` to use for making ManyToOne and OneToOne related objects.""" - def __init__(self, recipe: Recipe, one_to_one: bool) -> None: + def __init__(self, recipe: Recipe[M], one_to_one: bool) -> None: if isinstance(recipe, Recipe): self.recipe = recipe self.one_to_one = one_to_one @@ -111,8 +201,8 @@ def __init__(self, recipe: Recipe, one_to_one: bool) -> None: def foreign_key( - recipe: Union[Recipe, str], one_to_one: bool = False -) -> RecipeForeignKey: + recipe: Union[Recipe[M], str], one_to_one: bool = False +) -> RecipeForeignKey[M]: """Return a `RecipeForeignKey`. Return the callable, so that the associated `_model` will not be created @@ -130,12 +220,12 @@ def foreign_key( # Probably not in another module, so load it from calling module recipe = _load_recipe_from_calling_module(cast(str, recipe)) - return RecipeForeignKey(cast(Recipe, recipe), one_to_one) + return RecipeForeignKey(cast(Recipe[M], recipe), one_to_one) -class related(object): # FIXME - def __init__(self, *args) -> None: - self.related = [] # type: List[Recipe] +class related(Generic[M]): # FIXME + def __init__(self, *args: Union[str, Recipe[M]]) -> None: + self.related = [] # type: List[Recipe[M]] for recipe in args: if isinstance(recipe, Recipe): self.related.append(recipe) @@ -148,6 +238,6 @@ def __init__(self, *args) -> None: else: raise TypeError("Not a recipe") - def make(self) -> List[Union[Model, List[Model]]]: + def make(self) -> List[Union[M, List[M]]]: """Persist objects to m2m relation.""" return [m.make() for m in self.related] diff --git a/tests/test_baker.py b/tests/test_baker.py index b4ac5176..231dfea8 100644 --- a/tests/test_baker.py +++ b/tests/test_baker.py @@ -125,7 +125,7 @@ def test_abstract_model_subclass_creation(self): instance = baker.make(models.SubclassOfAbstract) assert isinstance(instance, models.SubclassOfAbstract) assert isinstance(instance, models.AbstractModel) - assert isinstance(instance.name, type(u"")) + assert isinstance(instance.name, type("")) assert len(instance.name) == 30 assert isinstance(instance.height, int) diff --git a/tests/test_filling_fields.py b/tests/test_filling_fields.py index 11589556..83a760e2 100644 --- a/tests/test_filling_fields.py +++ b/tests/test_filling_fields.py @@ -277,8 +277,10 @@ def test_filling_content_type_field(self): def test_iteratively_filling_generic_foreign_key_field(self): """ - Ensures private_fields are included in Baker.get_fields(), otherwise - calling next() when a GFK is in iterator_attrs would be bypassed. + Ensures private_fields are included in ``Baker.get_fields()``. + + Otherwise, calling ``next()`` when a GFK is in ``iterator_attrs`` + would be bypassed. """ objects = baker.make(models.Profile, _quantity=2) dummies = baker.make( diff --git a/tests/test_recipes.py b/tests/test_recipes.py index 74c8037a..45d8780a 100644 --- a/tests/test_recipes.py +++ b/tests/test_recipes.py @@ -48,8 +48,7 @@ def test_import_seq_from_recipe(): def test_import_recipes(): - """Test imports works both for full import paths and for - `app_name.recipe_name` strings.""" + """Test imports works both for full import paths and for `app_name.recipe_name` strings.""" assert baker.prepare_recipe("generic.dog"), baker.prepare_recipe( "tests.generic.dog" )