From 9d948101b7a9591b4ffaad52c47838638dc347ab Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 2 Nov 2021 15:31:59 +1300 Subject: [PATCH 1/5] Improve type hinting `make`/`prepare` now return the correct type depending on `_quantity` --- model_bakery/baker.py | 141 ++++++++++++++++++++++++++++++------------ 1 file changed, 100 insertions(+), 41 deletions(-) diff --git a/model_bakery/baker.py b/model_bakery/baker.py index b0f7cef8..bdf37189 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -1,8 +1,11 @@ +from __future__ import annotations + from os.path import dirname, join from typing import ( Any, Callable, Dict, + Generic, Iterator, List, Optional, @@ -10,6 +13,7 @@ TypeVar, Union, cast, + overload, ) from django.apps import apps @@ -25,7 +29,6 @@ Model, OneToOneField, ) -from django.db.models.base import ModelBase from django.db.models.fields.proxy import OrderWrt from django.db.models.fields.related import ( ReverseManyToOneDescriptor as ForeignRelatedObjectsDescriptor, @@ -57,11 +60,42 @@ def _valid_quantity(quantity: Optional[Union[str, int]]) -> bool: return quantity is not None and (not isinstance(quantity, int) or quantity < 1) -SpecificModelType = TypeVar("SpecificModelType", bound=Model) +_M = TypeVar("_M", bound=Model) +_NewM = TypeVar("_NewM", bound=Model) + + +@overload +def make( + _model: Union[str, Type[_M]], + _quantity: None = None, + make_m2m: bool = False, + _save_kwargs: Optional[Dict] = None, + _refresh_after_create: bool = False, + _create_files: bool = False, + _using: str = "", + _bulk_create: bool = False, + **attrs: Any, +) -> _M: + ... + + +@overload +def make( + _model: Union[str, Type[_M]], + _quantity: int, + make_m2m: bool = False, + _save_kwargs: Optional[Dict] = None, + _refresh_after_create: bool = False, + _create_files: bool = False, + _using: str = "", + _bulk_create: bool = False, + **attrs: Any, +) -> List[_M]: + ... def make( - _model: Union[str, Type[SpecificModelType]], + _model, _quantity: Optional[int] = None, make_m2m: bool = False, _save_kwargs: Optional[Dict] = None, @@ -69,8 +103,8 @@ def make( _create_files: bool = False, _using: str = "", _bulk_create: bool = False, - **attrs: Any -) -> Union[SpecificModelType, List[SpecificModelType]]: + **attrs: Any, +): """Create a persisted instance from a given model its associated models. It fill the fields with random values or you can specify which @@ -100,13 +134,35 @@ def make( ) +@overload +def prepare( + _model: Union[str, Type[_M]], + _quantity: None = None, + _save_related: bool = False, + _using: str = "", + **attrs, +) -> _M: + ... + + +@overload def prepare( - _model: Union[str, Type[SpecificModelType]], - _quantity=None, - _save_related=False, - _using="", - **attrs -) -> Union[SpecificModelType, List[SpecificModelType]]: + _model: Union[str, Type[_M]], + _quantity: int, + _save_related: bool = False, + _using: str = "", + **attrs, +) -> List[_M]: + ... + + +def prepare( + _model: Union[str, Type[_M]], + _quantity: Optional[int] = None, + _save_related: bool = False, + _using: str = "", + **attrs, +): """Create but do not persist an instance from a given model. It fill the fields with random values or you can specify which @@ -128,7 +184,8 @@ def prepare( def _recipe(name: str) -> Any: app_name, recipe_name = name.rsplit(".", 1) try: - pkg = apps.get_app_config(app_name).module.__package__ + module = apps.get_app_config(app_name).module + pkg = module.__package__ if module else app_name except LookupError: pkg = app_name return import_from_str(".".join((pkg, "baker_recipes", recipe_name))) @@ -148,13 +205,13 @@ def prepare_recipe( ) -class ModelFinder(object): +class ModelFinder(Generic[_M]): """Encapsulates all the logic for finding a model to Baker.""" - _unique_models = None # type: Optional[Dict[str, Type[Model]]] - _ambiguous_models = None # type: Optional[List[str]] + _unique_models: Optional[Dict[str, Type[Model]]] = None + _ambiguous_models: Optional[List[str]] = None - def get_model(self, name: str) -> Type[Model]: + def get_model(self, name: str) -> Type[_M]: """Get a model. Args: @@ -167,7 +224,7 @@ def get_model(self, name: str) -> Type[Model]: try: if "." in name: app_label, model_name = name.split(".") - model = apps.get_model(app_label, model_name) + model = cast(Type[_M], apps.get_model(app_label, model_name)) else: model = self.get_model_by_name(name) except LookupError: @@ -178,7 +235,7 @@ def get_model(self, name: str) -> Type[Model]: return model - def get_model_by_name(self, name: str) -> Optional[Type[Model]]: + def get_model_by_name(self, name: str) -> Optional[Type[_M]]: """Get a model by name. If a model with that name exists in more than one app, raises @@ -254,46 +311,48 @@ def _custom_baker_class() -> Optional[Type]: ) -class Baker(object): - attr_mapping = {} # type: Dict[str, Any] - type_mapping = {} # type: Dict +class Baker(Generic[_M]): + attr_mapping: Dict[str, Any] = {} + type_mapping: Dict = {} # Note: we're using one finder for all Baker instances to avoid # rebuilding the model cache for every make_* or prepare_* call. - finder = ModelFinder() + finder: ModelFinder[_M] = ModelFinder() @classmethod def create( cls, - _model: Union[str, Type[ModelBase]], + _model: Union[str, Type[_NewM]], make_m2m: bool = False, create_files: bool = False, _using: str = "", - ) -> "Baker": + ) -> Baker[_NewM]: """Create the baker class defined by the `BAKER_CUSTOM_CLASS` setting.""" baker_class = _custom_baker_class() or cls - return baker_class(_model, make_m2m, create_files, _using=_using) + return cast(Type[Baker[_NewM]], baker_class)( + _model, make_m2m, create_files, _using=_using + ) def __init__( self, - _model: Union[str, Type[ModelBase]], + _model: Union[str, Type[_M]], make_m2m: bool = False, create_files: bool = False, _using: str = "", ) -> None: self.make_m2m = make_m2m self.create_files = create_files - self.m2m_dict = {} # type: Dict[str, List] - self.iterator_attrs = {} # type: Dict[str, Iterator] - self.model_attrs = {} # type: Dict[str, Any] - self.rel_attrs = {} # type: Dict[str, Any] - self.rel_fields = [] # type: List[str] + self.m2m_dict: Dict[str, List] = {} + self.iterator_attrs: Dict[str, Iterator] = {} + self.model_attrs: Dict[str, Any] = {} + self.rel_attrs: Dict[str, Any] = {} + self.rel_fields: List[str] = [] self._using = _using if isinstance(_model, str): self.model = self.finder.get_model(_model) else: - self.model = _model + self.model = cast(Type[_M], _model) self.init_type_mapping() @@ -310,7 +369,7 @@ def make( _save_kwargs: Optional[Dict[str, Any]] = None, _refresh_after_create: bool = False, _from_manager=None, - **attrs: Any + **attrs: Any, ): """Create and persist an instance of the model associated with Baker instance.""" params = { @@ -323,7 +382,7 @@ def make( params.update(attrs) return self._make(**params) - def prepare(self, _save_related=False, **attrs: Any) -> Model: + def prepare(self, _save_related=False, **attrs: Any) -> _M: """Create, but do not persist, an instance of the associated model.""" return self._make(commit=False, commit_related=_save_related, **attrs) @@ -342,8 +401,8 @@ def _make( _save_kwargs=None, _refresh_after_create=False, _from_manager=None, - **attrs: Any - ) -> Model: + **attrs: Any, + ) -> _M: _save_kwargs = _save_kwargs or {} if self._using: _save_kwargs["using"] = self._using @@ -415,14 +474,14 @@ def m2m_value(self, field: ManyToManyField) -> List[Any]: def instance( self, attrs: Dict[str, Any], _commit, _save_kwargs, _from_manager - ) -> Model: + ) -> _M: one_to_many_keys = {} for k in tuple(attrs.keys()): field = getattr(self.model, k, None) if isinstance(field, ForeignRelatedObjectsDescriptor): one_to_many_keys[k] = attrs.pop(k) - instance = self.model(**attrs) # type: Model + instance = self.model(**attrs) # m2m only works for persisted instances if _commit: instance.save(**_save_kwargs) @@ -435,7 +494,7 @@ def instance( # within its get_queryset() method (e.g. annotations) # is run. manager = getattr(self.model, _from_manager) - instance = manager.get(pk=instance.pk) + instance: _M = manager.get(pk=instance.pk) return instance @@ -443,7 +502,7 @@ def create_by_related_name( self, instance: Model, related: Union[ManyToOneRel, OneToOneRel] ) -> None: rel_name = related.get_accessor_name() - if rel_name not in self.rel_fields: + if not rel_name or rel_name not in self.rel_fields: return kwargs = filter_rel_attrs(rel_name, **self.rel_attrs) @@ -657,7 +716,7 @@ def filter_rel_attrs(field_name: str, **rel_attrs) -> Dict[str, Any]: return clean_dict -def bulk_create(baker, quantity, **kwargs) -> List[Model]: +def bulk_create(baker: Baker[_M], quantity: int, **kwargs) -> List[_M]: """ Bulk create entries and all related FKs as well. From b4ddc90781050e931ed7a240a51dc2a28cf206d2 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Tue, 2 Nov 2021 15:39:02 +1300 Subject: [PATCH 2/5] Add changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d0eacf0d..6dd8c7c1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Fix bulk_create not working with multi-database setup [PR #252](https://github.com/model-bakers/model_bakery/pull/252) - Conditionally support NullBooleanField, it's under deprecation and will be removed in Django 4.0 [PR #25](https://github.com/model-bakers/model_bakery/pull/250) - Fix Django max version pin in requirements file [PR #251](https://github.com/model-bakers/model_bakery/pull/251) +- Improve type hinting to return the correct type depending on `_quantity` usage [PR #261](https://github.com/model-bakers/model_bakery/pull/261) ### Removed From 7f55b55597cfabbd5c5b57d87c039fe6dbb62361 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Thu, 4 Nov 2021 11:35:02 +1300 Subject: [PATCH 3/5] Don't use postponed eval to allow for Py3.6 compatibility --- model_bakery/baker.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/model_bakery/baker.py b/model_bakery/baker.py index bdf37189..89b6fd4b 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from os.path import dirname, join from typing import ( Any, @@ -326,7 +324,7 @@ def create( make_m2m: bool = False, create_files: bool = False, _using: str = "", - ) -> Baker[_NewM]: + ) -> "Baker[_NewM]": """Create the baker class defined by the `BAKER_CUSTOM_CLASS` setting.""" baker_class = _custom_baker_class() or cls return cast(Type[Baker[_NewM]], baker_class)( From f587e8fdafda1e23329eaba3f23832abab90c11f Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Thu, 4 Nov 2021 11:41:06 +1300 Subject: [PATCH 4/5] Typing improvements fixing all outstanding mypy issues --- model_bakery/baker.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/model_bakery/baker.py b/model_bakery/baker.py index 89b6fd4b..d17dc1c1 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -109,7 +109,7 @@ def make( fields you want to define its values by yourself. """ _save_kwargs = _save_kwargs or {} - baker = Baker.create( + baker: Baker = Baker.create( _model, make_m2m=make_m2m, create_files=_create_files, _using=_using ) if _valid_quantity(_quantity): @@ -203,13 +203,13 @@ def prepare_recipe( ) -class ModelFinder(Generic[_M]): +class ModelFinder: """Encapsulates all the logic for finding a model to Baker.""" _unique_models: Optional[Dict[str, Type[Model]]] = None _ambiguous_models: Optional[List[str]] = None - def get_model(self, name: str) -> Type[_M]: + def get_model(self, name: str) -> Type[Model]: """Get a model. Args: @@ -222,7 +222,7 @@ def get_model(self, name: str) -> Type[_M]: try: if "." in name: app_label, model_name = name.split(".") - model = cast(Type[_M], apps.get_model(app_label, model_name)) + model = apps.get_model(app_label, model_name) else: model = self.get_model_by_name(name) except LookupError: @@ -233,7 +233,7 @@ def get_model(self, name: str) -> Type[_M]: return model - def get_model_by_name(self, name: str) -> Optional[Type[_M]]: + def get_model_by_name(self, name: str) -> Optional[Type[Model]]: """Get a model by name. If a model with that name exists in more than one app, raises @@ -315,7 +315,7 @@ class Baker(Generic[_M]): # Note: we're using one finder for all Baker instances to avoid # rebuilding the model cache for every make_* or prepare_* call. - finder: ModelFinder[_M] = ModelFinder() + finder = ModelFinder() @classmethod def create( @@ -348,7 +348,7 @@ def __init__( self._using = _using if isinstance(_model, str): - self.model = self.finder.get_model(_model) + self.model = cast(Type[_M], self.finder.get_model(_model)) else: self.model = cast(Type[_M], _model) @@ -492,7 +492,7 @@ def instance( # within its get_queryset() method (e.g. annotations) # is run. manager = getattr(self.model, _from_manager) - instance: _M = manager.get(pk=instance.pk) + instance = cast(_M, manager.get(pk=instance.pk)) return instance From 9fcdc48350118318aa6c98760d941f146fd7e237 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Fri, 5 Nov 2021 17:08:41 +1300 Subject: [PATCH 5/5] Rename typevars --- model_bakery/baker.py | 46 +++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/model_bakery/baker.py b/model_bakery/baker.py index d17dc1c1..48b154a3 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -58,13 +58,13 @@ def _valid_quantity(quantity: Optional[Union[str, int]]) -> bool: return quantity is not None and (not isinstance(quantity, int) or quantity < 1) -_M = TypeVar("_M", bound=Model) -_NewM = TypeVar("_NewM", bound=Model) +M = TypeVar("M", bound=Model) +NewM = TypeVar("NewM", bound=Model) @overload def make( - _model: Union[str, Type[_M]], + _model: Union[str, Type[M]], _quantity: None = None, make_m2m: bool = False, _save_kwargs: Optional[Dict] = None, @@ -73,13 +73,13 @@ def make( _using: str = "", _bulk_create: bool = False, **attrs: Any, -) -> _M: +) -> M: ... @overload def make( - _model: Union[str, Type[_M]], + _model: Union[str, Type[M]], _quantity: int, make_m2m: bool = False, _save_kwargs: Optional[Dict] = None, @@ -88,7 +88,7 @@ def make( _using: str = "", _bulk_create: bool = False, **attrs: Any, -) -> List[_M]: +) -> List[M]: ... @@ -134,28 +134,28 @@ def make( @overload def prepare( - _model: Union[str, Type[_M]], + _model: Union[str, Type[M]], _quantity: None = None, _save_related: bool = False, _using: str = "", **attrs, -) -> _M: +) -> M: ... @overload def prepare( - _model: Union[str, Type[_M]], + _model: Union[str, Type[M]], _quantity: int, _save_related: bool = False, _using: str = "", **attrs, -) -> List[_M]: +) -> List[M]: ... def prepare( - _model: Union[str, Type[_M]], + _model: Union[str, Type[M]], _quantity: Optional[int] = None, _save_related: bool = False, _using: str = "", @@ -309,7 +309,7 @@ def _custom_baker_class() -> Optional[Type]: ) -class Baker(Generic[_M]): +class Baker(Generic[M]): attr_mapping: Dict[str, Any] = {} type_mapping: Dict = {} @@ -320,20 +320,20 @@ class Baker(Generic[_M]): @classmethod def create( cls, - _model: Union[str, Type[_NewM]], + _model: Union[str, Type[NewM]], make_m2m: bool = False, create_files: bool = False, _using: str = "", - ) -> "Baker[_NewM]": + ) -> "Baker[NewM]": """Create the baker class defined by the `BAKER_CUSTOM_CLASS` setting.""" baker_class = _custom_baker_class() or cls - return cast(Type[Baker[_NewM]], baker_class)( + return cast(Type[Baker[NewM]], baker_class)( _model, make_m2m, create_files, _using=_using ) def __init__( self, - _model: Union[str, Type[_M]], + _model: Union[str, Type[M]], make_m2m: bool = False, create_files: bool = False, _using: str = "", @@ -348,9 +348,9 @@ def __init__( self._using = _using if isinstance(_model, str): - self.model = cast(Type[_M], self.finder.get_model(_model)) + self.model = cast(Type[M], self.finder.get_model(_model)) else: - self.model = cast(Type[_M], _model) + self.model = cast(Type[M], _model) self.init_type_mapping() @@ -380,7 +380,7 @@ def make( params.update(attrs) return self._make(**params) - def prepare(self, _save_related=False, **attrs: Any) -> _M: + def prepare(self, _save_related=False, **attrs: Any) -> M: """Create, but do not persist, an instance of the associated model.""" return self._make(commit=False, commit_related=_save_related, **attrs) @@ -400,7 +400,7 @@ def _make( _refresh_after_create=False, _from_manager=None, **attrs: Any, - ) -> _M: + ) -> M: _save_kwargs = _save_kwargs or {} if self._using: _save_kwargs["using"] = self._using @@ -472,7 +472,7 @@ def m2m_value(self, field: ManyToManyField) -> List[Any]: def instance( self, attrs: Dict[str, Any], _commit, _save_kwargs, _from_manager - ) -> _M: + ) -> M: one_to_many_keys = {} for k in tuple(attrs.keys()): field = getattr(self.model, k, None) @@ -492,7 +492,7 @@ def instance( # within its get_queryset() method (e.g. annotations) # is run. manager = getattr(self.model, _from_manager) - instance = cast(_M, manager.get(pk=instance.pk)) + instance = cast(M, manager.get(pk=instance.pk)) return instance @@ -714,7 +714,7 @@ def filter_rel_attrs(field_name: str, **rel_attrs) -> Dict[str, Any]: return clean_dict -def bulk_create(baker: Baker[_M], quantity: int, **kwargs) -> List[_M]: +def bulk_create(baker: Baker[M], quantity: int, **kwargs) -> List[M]: """ Bulk create entries and all related FKs as well.