Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type hinting #261

Merged
merged 5 commits into from Dec 13, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- Fix bulk_create not working with multi-database setup [PR #252](https://github.com/model-bakers/model_bakery/pull/252)
- Conditionally support NullBooleanField, it's under deprecation and will be removed in Django 4.0 [PR #25](https://github.com/model-bakers/model_bakery/pull/250)
- Fix Django max version pin in requirements file [PR #251](https://github.com/model-bakers/model_bakery/pull/251)
- Improve type hinting to return the correct type depending on `_quantity` usage [PR #261](https://github.com/model-bakers/model_bakery/pull/261)

### Removed

Expand Down
135 changes: 96 additions & 39 deletions model_bakery/baker.py
Expand Up @@ -3,13 +3,15 @@
Any,
Callable,
Dict,
Generic,
Iterator,
List,
Optional,
Type,
TypeVar,
Union,
cast,
overload,
)

from django.apps import apps
Expand All @@ -25,7 +27,6 @@
Model,
OneToOneField,
)
from django.db.models.base import ModelBase
from django.db.models.fields.proxy import OrderWrt
from django.db.models.fields.related import (
ReverseManyToOneDescriptor as ForeignRelatedObjectsDescriptor,
Expand Down Expand Up @@ -57,27 +58,58 @@ def _valid_quantity(quantity: Optional[Union[str, int]]) -> bool:
return quantity is not None and (not isinstance(quantity, int) or quantity < 1)


SpecificModelType = TypeVar("SpecificModelType", bound=Model)
_M = TypeVar("_M", bound=Model)
_NewM = TypeVar("_NewM", bound=Model)
SmileyChris marked this conversation as resolved.
Show resolved Hide resolved


@overload
def make(
_model: Union[str, Type[SpecificModelType]],
_model: Union[str, Type[_M]],
_quantity: None = None,
make_m2m: bool = False,
_save_kwargs: Optional[Dict] = None,
_refresh_after_create: bool = False,
_create_files: bool = False,
_using: str = "",
_bulk_create: bool = False,
**attrs: Any,
) -> _M:
...


@overload
def make(
_model: Union[str, Type[_M]],
_quantity: int,
make_m2m: bool = False,
_save_kwargs: Optional[Dict] = None,
_refresh_after_create: bool = False,
_create_files: bool = False,
_using: str = "",
_bulk_create: bool = False,
**attrs: Any,
) -> List[_M]:
...


def make(
_model,
_quantity: Optional[int] = None,
make_m2m: bool = False,
_save_kwargs: Optional[Dict] = None,
_refresh_after_create: bool = False,
_create_files: bool = False,
_using: str = "",
_bulk_create: bool = False,
**attrs: Any
) -> Union[SpecificModelType, List[SpecificModelType]]:
**attrs: Any,
):
"""Create a persisted instance from a given model its associated models.

It fill the fields with random values or you can specify which
fields you want to define its values by yourself.
"""
_save_kwargs = _save_kwargs or {}
baker = Baker.create(
baker: Baker = Baker.create(
_model, make_m2m=make_m2m, create_files=_create_files, _using=_using
)
if _valid_quantity(_quantity):
Expand All @@ -100,13 +132,35 @@ def make(
)


@overload
def prepare(
_model: Union[str, Type[_M]],
_quantity: None = None,
_save_related: bool = False,
_using: str = "",
**attrs,
) -> _M:
...


@overload
def prepare(
_model: Union[str, Type[SpecificModelType]],
_quantity=None,
_save_related=False,
_using="",
**attrs
) -> Union[SpecificModelType, List[SpecificModelType]]:
_model: Union[str, Type[_M]],
_quantity: int,
_save_related: bool = False,
_using: str = "",
**attrs,
) -> List[_M]:
...


def prepare(
_model: Union[str, Type[_M]],
_quantity: Optional[int] = None,
_save_related: bool = False,
_using: str = "",
**attrs,
):
"""Create but do not persist an instance from a given model.

It fill the fields with random values or you can specify which
Expand All @@ -128,7 +182,8 @@ def prepare(
def _recipe(name: str) -> Any:
app_name, recipe_name = name.rsplit(".", 1)
try:
pkg = apps.get_app_config(app_name).module.__package__
module = apps.get_app_config(app_name).module
pkg = module.__package__ if module else app_name
except LookupError:
pkg = app_name
return import_from_str(".".join((pkg, "baker_recipes", recipe_name)))
Expand All @@ -148,11 +203,11 @@ def prepare_recipe(
)


class ModelFinder(object):
class ModelFinder:
"""Encapsulates all the logic for finding a model to Baker."""

_unique_models = None # type: Optional[Dict[str, Type[Model]]]
_ambiguous_models = None # type: Optional[List[str]]
_unique_models: Optional[Dict[str, Type[Model]]] = None
_ambiguous_models: Optional[List[str]] = None

def get_model(self, name: str) -> Type[Model]:
"""Get a model.
Expand Down Expand Up @@ -254,9 +309,9 @@ def _custom_baker_class() -> Optional[Type]:
)


class Baker(object):
attr_mapping = {} # type: Dict[str, Any]
type_mapping = {} # type: Dict
class Baker(Generic[_M]):
attr_mapping: Dict[str, Any] = {}
type_mapping: Dict = {}

# Note: we're using one finder for all Baker instances to avoid
# rebuilding the model cache for every make_* or prepare_* call.
Expand All @@ -265,35 +320,37 @@ class Baker(object):
@classmethod
def create(
cls,
_model: Union[str, Type[ModelBase]],
_model: Union[str, Type[_NewM]],
make_m2m: bool = False,
create_files: bool = False,
_using: str = "",
) -> "Baker":
) -> "Baker[_NewM]":
"""Create the baker class defined by the `BAKER_CUSTOM_CLASS` setting."""
baker_class = _custom_baker_class() or cls
return baker_class(_model, make_m2m, create_files, _using=_using)
return cast(Type[Baker[_NewM]], baker_class)(
_model, make_m2m, create_files, _using=_using
)

def __init__(
self,
_model: Union[str, Type[ModelBase]],
_model: Union[str, Type[_M]],
make_m2m: bool = False,
create_files: bool = False,
_using: str = "",
) -> None:
self.make_m2m = make_m2m
self.create_files = create_files
self.m2m_dict = {} # type: Dict[str, List]
self.iterator_attrs = {} # type: Dict[str, Iterator]
self.model_attrs = {} # type: Dict[str, Any]
self.rel_attrs = {} # type: Dict[str, Any]
self.rel_fields = [] # type: List[str]
self.m2m_dict: Dict[str, List] = {}
self.iterator_attrs: Dict[str, Iterator] = {}
self.model_attrs: Dict[str, Any] = {}
self.rel_attrs: Dict[str, Any] = {}
self.rel_fields: List[str] = []
self._using = _using

if isinstance(_model, str):
self.model = self.finder.get_model(_model)
self.model = cast(Type[_M], self.finder.get_model(_model))
else:
self.model = _model
self.model = cast(Type[_M], _model)

self.init_type_mapping()

Expand All @@ -310,7 +367,7 @@ def make(
_save_kwargs: Optional[Dict[str, Any]] = None,
_refresh_after_create: bool = False,
_from_manager=None,
**attrs: Any
**attrs: Any,
):
"""Create and persist an instance of the model associated with Baker instance."""
params = {
Expand All @@ -323,7 +380,7 @@ def make(
params.update(attrs)
return self._make(**params)

def prepare(self, _save_related=False, **attrs: Any) -> Model:
def prepare(self, _save_related=False, **attrs: Any) -> _M:
"""Create, but do not persist, an instance of the associated model."""
return self._make(commit=False, commit_related=_save_related, **attrs)

Expand All @@ -342,8 +399,8 @@ def _make(
_save_kwargs=None,
_refresh_after_create=False,
_from_manager=None,
**attrs: Any
) -> Model:
**attrs: Any,
) -> _M:
_save_kwargs = _save_kwargs or {}
if self._using:
_save_kwargs["using"] = self._using
Expand Down Expand Up @@ -415,14 +472,14 @@ def m2m_value(self, field: ManyToManyField) -> List[Any]:

def instance(
self, attrs: Dict[str, Any], _commit, _save_kwargs, _from_manager
) -> Model:
) -> _M:
one_to_many_keys = {}
for k in tuple(attrs.keys()):
field = getattr(self.model, k, None)
if isinstance(field, ForeignRelatedObjectsDescriptor):
one_to_many_keys[k] = attrs.pop(k)

instance = self.model(**attrs) # type: Model
instance = self.model(**attrs)
# m2m only works for persisted instances
if _commit:
instance.save(**_save_kwargs)
Expand All @@ -435,15 +492,15 @@ def instance(
# within its get_queryset() method (e.g. annotations)
# is run.
manager = getattr(self.model, _from_manager)
instance = manager.get(pk=instance.pk)
instance = cast(_M, manager.get(pk=instance.pk))

return instance

def create_by_related_name(
self, instance: Model, related: Union[ManyToOneRel, OneToOneRel]
) -> None:
rel_name = related.get_accessor_name()
if rel_name not in self.rel_fields:
if not rel_name or rel_name not in self.rel_fields:
return

kwargs = filter_rel_attrs(rel_name, **self.rel_attrs)
Expand Down Expand Up @@ -657,7 +714,7 @@ def filter_rel_attrs(field_name: str, **rel_attrs) -> Dict[str, Any]:
return clean_dict


def bulk_create(baker, quantity, **kwargs) -> List[Model]:
def bulk_create(baker: Baker[_M], quantity: int, **kwargs) -> List[_M]:
"""
Bulk create entries and all related FKs as well.

Expand Down