Skip to content

Commit

Permalink
Improve type hinting (#261)
Browse files Browse the repository at this point in the history
* Improve type hinting
`make`/`prepare`  now return the correct type depending on `_quantity`

* Add changelog

* Don't use postponed eval to allow for Py3.6 compatibility

* Typing improvements fixing all outstanding mypy issues

* Rename typevars
  • Loading branch information
SmileyChris committed Dec 13, 2021
1 parent b643563 commit fea8c21
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 39 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- Fix bulk_create not working with multi-database setup [PR #252](https://github.com/model-bakers/model_bakery/pull/252)
- Conditionally support NullBooleanField, it's under deprecation and will be removed in Django 4.0 [PR #25](https://github.com/model-bakers/model_bakery/pull/250)
- Fix Django max version pin in requirements file [PR #251](https://github.com/model-bakers/model_bakery/pull/251)
- Improve type hinting to return the correct type depending on `_quantity` usage [PR #261](https://github.com/model-bakers/model_bakery/pull/261)

### Removed

Expand Down
135 changes: 96 additions & 39 deletions model_bakery/baker.py
Expand Up @@ -3,13 +3,15 @@
Any,
Callable,
Dict,
Generic,
Iterator,
List,
Optional,
Type,
TypeVar,
Union,
cast,
overload,
)

from django.apps import apps
Expand All @@ -25,7 +27,6 @@
Model,
OneToOneField,
)
from django.db.models.base import ModelBase
from django.db.models.fields.proxy import OrderWrt
from django.db.models.fields.related import (
ReverseManyToOneDescriptor as ForeignRelatedObjectsDescriptor,
Expand Down Expand Up @@ -57,27 +58,58 @@ def _valid_quantity(quantity: Optional[Union[str, int]]) -> bool:
return quantity is not None and (not isinstance(quantity, int) or quantity < 1)


SpecificModelType = TypeVar("SpecificModelType", bound=Model)
M = TypeVar("M", bound=Model)
NewM = TypeVar("NewM", bound=Model)


@overload
def make(
_model: Union[str, Type[SpecificModelType]],
_model: Union[str, Type[M]],
_quantity: None = None,
make_m2m: bool = False,
_save_kwargs: Optional[Dict] = None,
_refresh_after_create: bool = False,
_create_files: bool = False,
_using: str = "",
_bulk_create: bool = False,
**attrs: Any,
) -> M:
...


@overload
def make(
_model: Union[str, Type[M]],
_quantity: int,
make_m2m: bool = False,
_save_kwargs: Optional[Dict] = None,
_refresh_after_create: bool = False,
_create_files: bool = False,
_using: str = "",
_bulk_create: bool = False,
**attrs: Any,
) -> List[M]:
...


def make(
_model,
_quantity: Optional[int] = None,
make_m2m: bool = False,
_save_kwargs: Optional[Dict] = None,
_refresh_after_create: bool = False,
_create_files: bool = False,
_using: str = "",
_bulk_create: bool = False,
**attrs: Any
) -> Union[SpecificModelType, List[SpecificModelType]]:
**attrs: Any,
):
"""Create a persisted instance from a given model its associated models.
It fill the fields with random values or you can specify which
fields you want to define its values by yourself.
"""
_save_kwargs = _save_kwargs or {}
baker = Baker.create(
baker: Baker = Baker.create(
_model, make_m2m=make_m2m, create_files=_create_files, _using=_using
)
if _valid_quantity(_quantity):
Expand All @@ -100,13 +132,35 @@ def make(
)


@overload
def prepare(
_model: Union[str, Type[M]],
_quantity: None = None,
_save_related: bool = False,
_using: str = "",
**attrs,
) -> M:
...


@overload
def prepare(
_model: Union[str, Type[SpecificModelType]],
_quantity=None,
_save_related=False,
_using="",
**attrs
) -> Union[SpecificModelType, List[SpecificModelType]]:
_model: Union[str, Type[M]],
_quantity: int,
_save_related: bool = False,
_using: str = "",
**attrs,
) -> List[M]:
...


def prepare(
_model: Union[str, Type[M]],
_quantity: Optional[int] = None,
_save_related: bool = False,
_using: str = "",
**attrs,
):
"""Create but do not persist an instance from a given model.
It fill the fields with random values or you can specify which
Expand All @@ -128,7 +182,8 @@ def prepare(
def _recipe(name: str) -> Any:
app_name, recipe_name = name.rsplit(".", 1)
try:
pkg = apps.get_app_config(app_name).module.__package__
module = apps.get_app_config(app_name).module
pkg = module.__package__ if module else app_name
except LookupError:
pkg = app_name
return import_from_str(".".join((pkg, "baker_recipes", recipe_name)))
Expand All @@ -148,11 +203,11 @@ def prepare_recipe(
)


class ModelFinder(object):
class ModelFinder:
"""Encapsulates all the logic for finding a model to Baker."""

_unique_models = None # type: Optional[Dict[str, Type[Model]]]
_ambiguous_models = None # type: Optional[List[str]]
_unique_models: Optional[Dict[str, Type[Model]]] = None
_ambiguous_models: Optional[List[str]] = None

def get_model(self, name: str) -> Type[Model]:
"""Get a model.
Expand Down Expand Up @@ -254,9 +309,9 @@ def _custom_baker_class() -> Optional[Type]:
)


class Baker(object):
attr_mapping = {} # type: Dict[str, Any]
type_mapping = {} # type: Dict
class Baker(Generic[M]):
attr_mapping: Dict[str, Any] = {}
type_mapping: Dict = {}

# Note: we're using one finder for all Baker instances to avoid
# rebuilding the model cache for every make_* or prepare_* call.
Expand All @@ -265,35 +320,37 @@ class Baker(object):
@classmethod
def create(
cls,
_model: Union[str, Type[ModelBase]],
_model: Union[str, Type[NewM]],
make_m2m: bool = False,
create_files: bool = False,
_using: str = "",
) -> "Baker":
) -> "Baker[NewM]":
"""Create the baker class defined by the `BAKER_CUSTOM_CLASS` setting."""
baker_class = _custom_baker_class() or cls
return baker_class(_model, make_m2m, create_files, _using=_using)
return cast(Type[Baker[NewM]], baker_class)(
_model, make_m2m, create_files, _using=_using
)

def __init__(
self,
_model: Union[str, Type[ModelBase]],
_model: Union[str, Type[M]],
make_m2m: bool = False,
create_files: bool = False,
_using: str = "",
) -> None:
self.make_m2m = make_m2m
self.create_files = create_files
self.m2m_dict = {} # type: Dict[str, List]
self.iterator_attrs = {} # type: Dict[str, Iterator]
self.model_attrs = {} # type: Dict[str, Any]
self.rel_attrs = {} # type: Dict[str, Any]
self.rel_fields = [] # type: List[str]
self.m2m_dict: Dict[str, List] = {}
self.iterator_attrs: Dict[str, Iterator] = {}
self.model_attrs: Dict[str, Any] = {}
self.rel_attrs: Dict[str, Any] = {}
self.rel_fields: List[str] = []
self._using = _using

if isinstance(_model, str):
self.model = self.finder.get_model(_model)
self.model = cast(Type[M], self.finder.get_model(_model))
else:
self.model = _model
self.model = cast(Type[M], _model)

self.init_type_mapping()

Expand All @@ -310,7 +367,7 @@ def make(
_save_kwargs: Optional[Dict[str, Any]] = None,
_refresh_after_create: bool = False,
_from_manager=None,
**attrs: Any
**attrs: Any,
):
"""Create and persist an instance of the model associated with Baker instance."""
params = {
Expand All @@ -323,7 +380,7 @@ def make(
params.update(attrs)
return self._make(**params)

def prepare(self, _save_related=False, **attrs: Any) -> Model:
def prepare(self, _save_related=False, **attrs: Any) -> M:
"""Create, but do not persist, an instance of the associated model."""
return self._make(commit=False, commit_related=_save_related, **attrs)

Expand All @@ -342,8 +399,8 @@ def _make(
_save_kwargs=None,
_refresh_after_create=False,
_from_manager=None,
**attrs: Any
) -> Model:
**attrs: Any,
) -> M:
_save_kwargs = _save_kwargs or {}
if self._using:
_save_kwargs["using"] = self._using
Expand Down Expand Up @@ -415,14 +472,14 @@ def m2m_value(self, field: ManyToManyField) -> List[Any]:

def instance(
self, attrs: Dict[str, Any], _commit, _save_kwargs, _from_manager
) -> Model:
) -> M:
one_to_many_keys = {}
for k in tuple(attrs.keys()):
field = getattr(self.model, k, None)
if isinstance(field, ForeignRelatedObjectsDescriptor):
one_to_many_keys[k] = attrs.pop(k)

instance = self.model(**attrs) # type: Model
instance = self.model(**attrs)
# m2m only works for persisted instances
if _commit:
instance.save(**_save_kwargs)
Expand All @@ -435,15 +492,15 @@ def instance(
# within its get_queryset() method (e.g. annotations)
# is run.
manager = getattr(self.model, _from_manager)
instance = manager.get(pk=instance.pk)
instance = cast(M, manager.get(pk=instance.pk))

return instance

def create_by_related_name(
self, instance: Model, related: Union[ManyToOneRel, OneToOneRel]
) -> None:
rel_name = related.get_accessor_name()
if rel_name not in self.rel_fields:
if not rel_name or rel_name not in self.rel_fields:
return

kwargs = filter_rel_attrs(rel_name, **self.rel_attrs)
Expand Down Expand Up @@ -657,7 +714,7 @@ def filter_rel_attrs(field_name: str, **rel_attrs) -> Dict[str, Any]:
return clean_dict


def bulk_create(baker, quantity, **kwargs) -> List[Model]:
def bulk_create(baker: Baker[M], quantity: int, **kwargs) -> List[M]:
"""
Bulk create entries and all related FKs as well.
Expand Down

0 comments on commit fea8c21

Please sign in to comment.