diff --git a/CHANGELOG.md b/CHANGELOG.md index d0eacf0d..6dd8c7c1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Fix bulk_create not working with multi-database setup [PR #252](https://github.com/model-bakers/model_bakery/pull/252) - Conditionally support NullBooleanField, it's under deprecation and will be removed in Django 4.0 [PR #25](https://github.com/model-bakers/model_bakery/pull/250) - Fix Django max version pin in requirements file [PR #251](https://github.com/model-bakers/model_bakery/pull/251) +- Improve type hinting to return the correct type depending on `_quantity` usage [PR #261](https://github.com/model-bakers/model_bakery/pull/261) ### Removed diff --git a/model_bakery/baker.py b/model_bakery/baker.py index b0f7cef8..48b154a3 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -3,6 +3,7 @@ Any, Callable, Dict, + Generic, Iterator, List, Optional, @@ -10,6 +11,7 @@ TypeVar, Union, cast, + overload, ) from django.apps import apps @@ -25,7 +27,6 @@ Model, OneToOneField, ) -from django.db.models.base import ModelBase from django.db.models.fields.proxy import OrderWrt from django.db.models.fields.related import ( ReverseManyToOneDescriptor as ForeignRelatedObjectsDescriptor, @@ -57,11 +58,42 @@ def _valid_quantity(quantity: Optional[Union[str, int]]) -> bool: return quantity is not None and (not isinstance(quantity, int) or quantity < 1) -SpecificModelType = TypeVar("SpecificModelType", bound=Model) +M = TypeVar("M", bound=Model) +NewM = TypeVar("NewM", bound=Model) +@overload def make( - _model: Union[str, Type[SpecificModelType]], + _model: Union[str, Type[M]], + _quantity: None = None, + make_m2m: bool = False, + _save_kwargs: Optional[Dict] = None, + _refresh_after_create: bool = False, + _create_files: bool = False, + _using: str = "", + _bulk_create: bool = False, + **attrs: Any, +) -> M: + ... + + +@overload +def make( + _model: Union[str, Type[M]], + _quantity: int, + make_m2m: bool = False, + _save_kwargs: Optional[Dict] = None, + _refresh_after_create: bool = False, + _create_files: bool = False, + _using: str = "", + _bulk_create: bool = False, + **attrs: Any, +) -> List[M]: + ... + + +def make( + _model, _quantity: Optional[int] = None, make_m2m: bool = False, _save_kwargs: Optional[Dict] = None, @@ -69,15 +101,15 @@ def make( _create_files: bool = False, _using: str = "", _bulk_create: bool = False, - **attrs: Any -) -> Union[SpecificModelType, List[SpecificModelType]]: + **attrs: Any, +): """Create a persisted instance from a given model its associated models. It fill the fields with random values or you can specify which fields you want to define its values by yourself. """ _save_kwargs = _save_kwargs or {} - baker = Baker.create( + baker: Baker = Baker.create( _model, make_m2m=make_m2m, create_files=_create_files, _using=_using ) if _valid_quantity(_quantity): @@ -100,13 +132,35 @@ def make( ) +@overload +def prepare( + _model: Union[str, Type[M]], + _quantity: None = None, + _save_related: bool = False, + _using: str = "", + **attrs, +) -> M: + ... + + +@overload def prepare( - _model: Union[str, Type[SpecificModelType]], - _quantity=None, - _save_related=False, - _using="", - **attrs -) -> Union[SpecificModelType, List[SpecificModelType]]: + _model: Union[str, Type[M]], + _quantity: int, + _save_related: bool = False, + _using: str = "", + **attrs, +) -> List[M]: + ... + + +def prepare( + _model: Union[str, Type[M]], + _quantity: Optional[int] = None, + _save_related: bool = False, + _using: str = "", + **attrs, +): """Create but do not persist an instance from a given model. It fill the fields with random values or you can specify which @@ -128,7 +182,8 @@ def prepare( def _recipe(name: str) -> Any: app_name, recipe_name = name.rsplit(".", 1) try: - pkg = apps.get_app_config(app_name).module.__package__ + module = apps.get_app_config(app_name).module + pkg = module.__package__ if module else app_name except LookupError: pkg = app_name return import_from_str(".".join((pkg, "baker_recipes", recipe_name))) @@ -148,11 +203,11 @@ def prepare_recipe( ) -class ModelFinder(object): +class ModelFinder: """Encapsulates all the logic for finding a model to Baker.""" - _unique_models = None # type: Optional[Dict[str, Type[Model]]] - _ambiguous_models = None # type: Optional[List[str]] + _unique_models: Optional[Dict[str, Type[Model]]] = None + _ambiguous_models: Optional[List[str]] = None def get_model(self, name: str) -> Type[Model]: """Get a model. @@ -254,9 +309,9 @@ def _custom_baker_class() -> Optional[Type]: ) -class Baker(object): - attr_mapping = {} # type: Dict[str, Any] - type_mapping = {} # type: Dict +class Baker(Generic[M]): + attr_mapping: Dict[str, Any] = {} + type_mapping: Dict = {} # Note: we're using one finder for all Baker instances to avoid # rebuilding the model cache for every make_* or prepare_* call. @@ -265,35 +320,37 @@ class Baker(object): @classmethod def create( cls, - _model: Union[str, Type[ModelBase]], + _model: Union[str, Type[NewM]], make_m2m: bool = False, create_files: bool = False, _using: str = "", - ) -> "Baker": + ) -> "Baker[NewM]": """Create the baker class defined by the `BAKER_CUSTOM_CLASS` setting.""" baker_class = _custom_baker_class() or cls - return baker_class(_model, make_m2m, create_files, _using=_using) + return cast(Type[Baker[NewM]], baker_class)( + _model, make_m2m, create_files, _using=_using + ) def __init__( self, - _model: Union[str, Type[ModelBase]], + _model: Union[str, Type[M]], make_m2m: bool = False, create_files: bool = False, _using: str = "", ) -> None: self.make_m2m = make_m2m self.create_files = create_files - self.m2m_dict = {} # type: Dict[str, List] - self.iterator_attrs = {} # type: Dict[str, Iterator] - self.model_attrs = {} # type: Dict[str, Any] - self.rel_attrs = {} # type: Dict[str, Any] - self.rel_fields = [] # type: List[str] + self.m2m_dict: Dict[str, List] = {} + self.iterator_attrs: Dict[str, Iterator] = {} + self.model_attrs: Dict[str, Any] = {} + self.rel_attrs: Dict[str, Any] = {} + self.rel_fields: List[str] = [] self._using = _using if isinstance(_model, str): - self.model = self.finder.get_model(_model) + self.model = cast(Type[M], self.finder.get_model(_model)) else: - self.model = _model + self.model = cast(Type[M], _model) self.init_type_mapping() @@ -310,7 +367,7 @@ def make( _save_kwargs: Optional[Dict[str, Any]] = None, _refresh_after_create: bool = False, _from_manager=None, - **attrs: Any + **attrs: Any, ): """Create and persist an instance of the model associated with Baker instance.""" params = { @@ -323,7 +380,7 @@ def make( params.update(attrs) return self._make(**params) - def prepare(self, _save_related=False, **attrs: Any) -> Model: + def prepare(self, _save_related=False, **attrs: Any) -> M: """Create, but do not persist, an instance of the associated model.""" return self._make(commit=False, commit_related=_save_related, **attrs) @@ -342,8 +399,8 @@ def _make( _save_kwargs=None, _refresh_after_create=False, _from_manager=None, - **attrs: Any - ) -> Model: + **attrs: Any, + ) -> M: _save_kwargs = _save_kwargs or {} if self._using: _save_kwargs["using"] = self._using @@ -415,14 +472,14 @@ def m2m_value(self, field: ManyToManyField) -> List[Any]: def instance( self, attrs: Dict[str, Any], _commit, _save_kwargs, _from_manager - ) -> Model: + ) -> M: one_to_many_keys = {} for k in tuple(attrs.keys()): field = getattr(self.model, k, None) if isinstance(field, ForeignRelatedObjectsDescriptor): one_to_many_keys[k] = attrs.pop(k) - instance = self.model(**attrs) # type: Model + instance = self.model(**attrs) # m2m only works for persisted instances if _commit: instance.save(**_save_kwargs) @@ -435,7 +492,7 @@ def instance( # within its get_queryset() method (e.g. annotations) # is run. manager = getattr(self.model, _from_manager) - instance = manager.get(pk=instance.pk) + instance = cast(M, manager.get(pk=instance.pk)) return instance @@ -443,7 +500,7 @@ def create_by_related_name( self, instance: Model, related: Union[ManyToOneRel, OneToOneRel] ) -> None: rel_name = related.get_accessor_name() - if rel_name not in self.rel_fields: + if not rel_name or rel_name not in self.rel_fields: return kwargs = filter_rel_attrs(rel_name, **self.rel_attrs) @@ -657,7 +714,7 @@ def filter_rel_attrs(field_name: str, **rel_attrs) -> Dict[str, Any]: return clean_dict -def bulk_create(baker, quantity, **kwargs) -> List[Model]: +def bulk_create(baker: Baker[M], quantity: int, **kwargs) -> List[M]: """ Bulk create entries and all related FKs as well.