Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type support and mypy checking #100

Merged
merged 22 commits into from Oct 5, 2020
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
136 changes: 78 additions & 58 deletions model_bakery/baker.py
@@ -1,4 +1,5 @@
from os.path import dirname, join
from typing import Any, Callable, Dict, Iterator, List, Optional, Type, Union, cast

from django.apps import apps
from django.conf import settings
Expand All @@ -10,13 +11,15 @@
FileField,
ForeignKey,
ManyToManyField,
Model,
OneToOneField,
)
from django.db.models.base import ModelBase
from django.db.models.fields.proxy import OrderWrt
from django.db.models.fields.related import (
ReverseManyToOneDescriptor as ForeignRelatedObjectsDescriptor,
)
from django.db.models.fields.reverse_related import ManyToOneRel, OneToOneRel

from . import generators, random_gen
from .exceptions import (
Expand All @@ -39,18 +42,18 @@
MAX_MANY_QUANTITY = 5


def _valid_quantity(quantity):
def _valid_quantity(quantity: Optional[Union[str, int]]) -> bool:
return quantity is not None and (not isinstance(quantity, int) or quantity < 1)


def make(
_model,
_quantity=None,
make_m2m=False,
_save_kwargs=None,
_refresh_after_create=False,
_create_files=False,
**attrs
_model: str,
_quantity: Optional[int] = None,
make_m2m: bool = False,
_save_kwargs: Optional[Dict] = None,
_refresh_after_create: bool = False,
_create_files: bool = False,
**attrs: Any
):
"""Create a persisted instance from a given model its associated models.

Expand All @@ -76,7 +79,7 @@ def make(
)


def prepare(_model, _quantity=None, _save_related=False, **attrs):
def prepare(_model: str, _quantity=None, _save_related=False, **attrs) -> Model:
"""Create but do not persist an instance from a given model.

It fill the fields with random values or you can specify which
Expand All @@ -95,7 +98,7 @@ def prepare(_model, _quantity=None, _save_related=False, **attrs):
return baker.prepare(_save_related=_save_related, **attrs)


def _recipe(name):
def _recipe(name: str) -> Any:
app, recipe_name = name.rsplit(".", 1)
return import_from_str(".".join((app, "baker_recipes", recipe_name)))

Expand All @@ -113,10 +116,10 @@ def prepare_recipe(baker_recipe_name, _quantity=None, _save_related=False, **new
class ModelFinder(object):
"""Encapsulates all the logic for finding a model to Baker."""

_unique_models = None
_ambiguous_models = None
_unique_models = None # type: Optional[Dict[str, Type[Model]]]
_ambiguous_models = None # type: Optional[List[str]]

def get_model(self, name):
def get_model(self, name: str) -> Type[Model]:
"""Get a model.

Args:
Expand All @@ -140,26 +143,26 @@ def get_model(self, name):

return model

def get_model_by_name(self, name):
def get_model_by_name(self, name: str) -> Optional[Type[Model]]:
"""Get a model by name.

If a model with that name exists in more than one app, raises
AmbiguousModelName.
"""
name = name.lower()

if self._unique_models is None:
if self._unique_models is None or self._ambiguous_models is None:
self._populate()

if name in self._ambiguous_models:
if name in cast(List, self._ambiguous_models):
raise AmbiguousModelName(
"%s is a model in more than one app. "
'Use the form "app.model".' % name.title()
)

return self._unique_models.get(name)
return cast(Dict, self._unique_models).get(name)

def _populate(self):
def _populate(self) -> None:
"""Cache models for faster self._get_model."""
unique_models = {}
ambiguous_models = []
Expand All @@ -180,14 +183,14 @@ def _populate(self):
self._unique_models = unique_models


def is_iterator(value):
def is_iterator(value: Any) -> bool:
if not hasattr(value, "__iter__"):
return False

return hasattr(value, "__next__")


def _custom_baker_class():
def _custom_baker_class() -> Optional[Type]:
"""Return the specified custom baker class.

Returns:
Expand Down Expand Up @@ -217,36 +220,43 @@ def _custom_baker_class():


class Baker(object):
attr_mapping = {}
type_mapping = None
attr_mapping = {} # type: Dict[str, Any]
type_mapping = {} # type: Dict

# Note: we're using one finder for all Baker instances to avoid
# rebuilding the model cache for every make_* or prepare_* call.
finder = ModelFinder()

@classmethod
def create(cls, _model, make_m2m=False, create_files=False):
def create(
cls, _model: str, make_m2m: bool = False, create_files: bool = False
) -> "Baker":
"""Create the baker class defined by the `BAKER_CUSTOM_CLASS` setting."""
baker_class = _custom_baker_class() or cls
return baker_class(_model, make_m2m, create_files)

def __init__(self, _model, make_m2m=False, create_files=False):
def __init__(
self,
_model: Union[str, Type[ModelBase]],
make_m2m: bool = False,
create_files: bool = False,
) -> None:
self.make_m2m = make_m2m
self.create_files = create_files
self.m2m_dict = {}
self.iterator_attrs = {}
self.model_attrs = {}
self.rel_attrs = {}
self.rel_fields = []
self.m2m_dict = {} # type: Dict[str, List]
self.iterator_attrs = {} # type: Dict[str, Iterator]
self.model_attrs = {} # type: Dict[str, Any]
self.rel_attrs = {} # type: Dict[str, Any]
self.rel_fields = [] # type: List[str]

if isinstance(_model, ModelBase):
self.model = _model
else:
if isinstance(_model, str):
self.model = self.finder.get_model(_model)
else:
self.model = _model

self.init_type_mapping()

def init_type_mapping(self):
def init_type_mapping(self) -> None:
self.type_mapping = generators.get_type_mapping()
generators_from_settings = getattr(settings, "BAKER_CUSTOM_FIELDS_GEN", {})
for k, v in generators_from_settings.items():
Expand All @@ -256,10 +266,10 @@ def init_type_mapping(self):

def make(
self,
_save_kwargs=None,
_refresh_after_create=False,
_save_kwargs: Optional[Dict[str, Any]] = None,
_refresh_after_create: bool = False,
_from_manager=None,
**attrs
**attrs: Any
):
"""Create and persist an instance of the model associated with Baker instance."""
params = {
Expand All @@ -272,14 +282,16 @@ def make(
params.update(attrs)
return self._make(**params)

def prepare(self, _save_related=False, **attrs):
def prepare(self, _save_related=False, **attrs: Any) -> Model:
"""Create, but do not persist, an instance of the associated model."""
return self._make(commit=False, commit_related=_save_related, **attrs)

def get_fields(self):
def get_fields(self) -> Any:
return self.model._meta.fields + self.model._meta.many_to_many

def get_related(self):
def get_related(
self,
) -> List[Union[ManyToOneRel, OneToOneRel]]:
return [r for r in self.model._meta.related_objects if not r.many_to_many]

def _make(
Expand All @@ -289,8 +301,8 @@ def _make(
_save_kwargs=None,
_refresh_after_create=False,
_from_manager=None,
**attrs
):
**attrs: Any
) -> Model:
_save_kwargs = _save_kwargs or {}

self._clean_attrs(attrs)
Expand Down Expand Up @@ -336,21 +348,23 @@ def _make(

return instance

def m2m_value(self, field):
def m2m_value(self, field: ManyToManyField) -> List[Any]:
if field.name in self.rel_fields:
return self.generate_value(field)
if not self.make_m2m or field.null and not field.fill_optional:
return []
return self.generate_value(field)

def instance(self, attrs, _commit, _save_kwargs, _from_manager):
def instance(
self, attrs: Dict[str, Any], _commit, _save_kwargs, _from_manager
) -> Model:
one_to_many_keys = {}
for k in tuple(attrs.keys()):
field = getattr(self.model, k, None)
if isinstance(field, ForeignRelatedObjectsDescriptor):
one_to_many_keys[k] = attrs.pop(k)

instance = self.model(**attrs)
instance = self.model(**attrs) # type: Model
# m2m only works for persisted instances
if _commit:
instance.save(**_save_kwargs)
Expand All @@ -367,19 +381,20 @@ def instance(self, attrs, _commit, _save_kwargs, _from_manager):

return instance

def create_by_related_name(self, instance, related):
def create_by_related_name(
self, instance: Model, related: Union[ManyToOneRel, OneToOneRel]
) -> None:
rel_name = related.get_accessor_name()
if rel_name not in self.rel_fields:
return

kwargs = filter_rel_attrs(rel_name, **self.rel_attrs)
kwargs[related.field.name] = instance
kwargs["_model"] = related.field.model

make(**kwargs)
make(related.field.model, **kwargs)

def _clean_attrs(self, attrs):
def is_rel_field(x):
def _clean_attrs(self, attrs: Dict[str, Any]) -> None:
def is_rel_field(x: str):
return "__" in x

self.fill_in_optional = attrs.pop("_fill_optional", False)
Expand All @@ -401,7 +416,7 @@ def is_rel_field(x):
x.split("__")[0] for x in self.rel_attrs.keys() if is_rel_field(x)
]

def _skip_field(self, field):
def _skip_field(self, field: Field) -> bool:
from django.contrib.contenttypes.fields import GenericRelation

# check for fill optional argument
Expand Down Expand Up @@ -444,7 +459,7 @@ def _skip_field(self, field):

return False

def _handle_one_to_many(self, instance, attrs):
def _handle_one_to_many(self, instance: Model, attrs: Dict[str, Any]):
for k, v in attrs.items():
manager = getattr(instance, k)

Expand All @@ -454,7 +469,7 @@ def _handle_one_to_many(self, instance, attrs):
# for many-to-many relationships the bulk keyword argument doesn't exist
manager.set(v, clear=True)

def _handle_m2m(self, instance):
def _handle_m2m(self, instance: Model):
for key, values in self.m2m_dict.items():
for value in values:
if not value.pk:
Expand All @@ -473,10 +488,12 @@ def _handle_m2m(self, instance):
}
make(through_model, **base_kwargs)

def _remote_field(self, field):
def _remote_field(
self, field: Union[ForeignKey, OneToOneField]
) -> Union[OneToOneRel, ManyToOneRel]:
return field.remote_field

def generate_value(self, field, commit=True):
def generate_value(self, field: Field, commit: bool = True) -> Any:
"""Call the associated generator with a field passing all required args.

Generator Resolution Precedence Order:
Expand Down Expand Up @@ -517,20 +534,23 @@ def generate_value(self, field, commit=True):

if not commit:
generator = getattr(generator, "prepare", generator)

return generator(**generator_attrs)


def get_required_values(generator, field):
def get_required_values(
generator: Callable, field: Field
) -> Dict[str, Union[bool, int, str, List[Callable]]]:
"""Get required values for a generator from the field.

If required value is a function, calls it with field as argument. If
required value is a string, simply fetch the value from the field
and return.
"""
# FIXME: avoid abbreviations
rt = {}
rt = {} # type: Dict[str, Any]
if hasattr(generator, "required"):
for item in generator.required:
for item in generator.required: # type: ignore[attr-defined]

if callable(item): # baker can deal with the nasty hacking too!
key, value = item(field)
Expand All @@ -549,7 +569,7 @@ def get_required_values(generator, field):
return rt


def filter_rel_attrs(field_name, **rel_attrs):
def filter_rel_attrs(field_name: str, **rel_attrs) -> Dict[str, Any]:
clean_dict = {}

for k, v in rel_attrs.items():
Expand Down