Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type support and mypy checking #100

Merged
merged 22 commits into from Oct 5, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions dev_requirements.txt
Expand Up @@ -2,6 +2,7 @@
black
flake8
isort
mypy
pillow
pre-commit
psycopg2-binary
Expand Down
81 changes: 47 additions & 34 deletions model_bakery/baker.py
@@ -1,4 +1,5 @@
from os.path import dirname, join
from typing import Any, Callable, Dict, Iterator, List, Optional, Type, Union, cast

from django.apps import apps
from django.conf import settings
Expand All @@ -10,13 +11,16 @@
FileField,
ForeignKey,
ManyToManyField,
Model,
ModelBase,
OneToOneField,
)
from django.db.models.base import ModelBase
from django.db.models.fields.proxy import OrderWrt
from django.db.models.fields.related import (
ReverseManyToOneDescriptor as ForeignRelatedObjectsDescriptor,
)
from django.db.models.fields.reverse_related import ManyToOneRel, OneToOneRel
from tests.generic.fields import CustomForeignKey

from . import generators, random_gen
from .exceptions import (
Expand All @@ -27,6 +31,7 @@
ModelNotFound,
RecipeIteratorEmpty,
)
from .random_gen import ActionGenerator
from .utils import seq # NoQA: enable seq to be imported from baker
from .utils import import_from_str

Expand All @@ -39,7 +44,7 @@
MAX_MANY_QUANTITY = 5


def _valid_quantity(quantity):
def _valid_quantity(quantity: Optional[Union[str, int]]) -> bool:
return quantity is not None and (not isinstance(quantity, int) or quantity < 1)


Expand Down Expand Up @@ -95,7 +100,7 @@ def prepare(_model, _quantity=None, _save_related=False, **attrs):
return baker.prepare(_save_related=_save_related, **attrs)


def _recipe(name):
def _recipe(name: str) -> Any:
app, recipe_name = name.rsplit(".", 1)
return import_from_str(".".join((app, "baker_recipes", recipe_name)))

Expand All @@ -113,8 +118,8 @@ def prepare_recipe(baker_recipe_name, _quantity=None, _save_related=False, **new
class ModelFinder(object):
"""Encapsulates all the logic for finding a model to Baker."""

_unique_models = None
_ambiguous_models = None
_unique_models: Optional[Dict[str, Type[Model]]] = None
_ambiguous_models: Optional[List[str]] = None

def get_model(self, name):
"""Get a model.
Expand All @@ -140,26 +145,26 @@ def get_model(self, name):

return model

def get_model_by_name(self, name):
def get_model_by_name(self, name: str) -> Optional[Type[Model]]:
"""Get a model by name.

If a model with that name exists in more than one app, raises
AmbiguousModelName.
"""
name = name.lower()

if self._unique_models is None:
if self._unique_models is None or self._ambiguous_models is None:
self._populate()

if name in self._ambiguous_models:
if name in cast(List, self._ambiguous_models):
raise AmbiguousModelName(
"%s is a model in more than one app. "
'Use the form "app.model".' % name.title()
)

return self._unique_models.get(name)
return cast(Dict, self._unique_models).get(name)

def _populate(self):
def _populate(self) -> None:
"""Cache models for faster self._get_model."""
unique_models = {}
ambiguous_models = []
Expand All @@ -180,14 +185,14 @@ def _populate(self):
self._unique_models = unique_models


def is_iterator(value):
def is_iterator(value: Any) -> bool:
if not hasattr(value, "__iter__"):
return False

return hasattr(value, "__next__")


def _custom_baker_class():
def _custom_baker_class() -> Optional[Type]:
"""Return the specified custom baker class.

Returns:
Expand Down Expand Up @@ -217,27 +222,31 @@ def _custom_baker_class():


class Baker(object):
attr_mapping = {}
type_mapping = None
attr_mapping: Dict[str, Any] = {}
type_mapping: Dict = {}

# Note: we're using one finder for all Baker instances to avoid
# rebuilding the model cache for every make_* or prepare_* call.
finder = ModelFinder()

@classmethod
def create(cls, _model, make_m2m=False, create_files=False):
def create(
cls, _model: str, make_m2m: bool = False, create_files: bool = False
) -> "Baker":
"""Create the baker class defined by the `BAKER_CUSTOM_CLASS` setting."""
baker_class = _custom_baker_class() or cls
return baker_class(_model, make_m2m, create_files)

def __init__(self, _model, make_m2m=False, create_files=False):
def __init__(
self, _model: str, make_m2m: bool = False, create_files: bool = False
) -> None:
self.make_m2m = make_m2m
self.create_files = create_files
self.m2m_dict = {}
self.iterator_attrs = {}
self.model_attrs = {}
self.rel_attrs = {}
self.rel_fields = []
self.m2m_dict: Dict[str, List] = {}
self.iterator_attrs: Dict[str, Iterator] = {}
self.model_attrs: Dict[str, Any] = {}
self.rel_attrs: Dict[str, Any] = {}
self.rel_fields: List[str] = []

if isinstance(_model, ModelBase):
self.model = _model
Expand All @@ -246,7 +255,7 @@ def __init__(self, _model, make_m2m=False, create_files=False):

self.init_type_mapping()

def init_type_mapping(self):
def init_type_mapping(self) -> None:
self.type_mapping = generators.get_type_mapping()
generators_from_settings = getattr(settings, "BAKER_CUSTOM_FIELDS_GEN", {})
for k, v in generators_from_settings.items():
Expand Down Expand Up @@ -276,10 +285,10 @@ def prepare(self, _save_related=False, **attrs):
"""Create, but do not persist, an instance of the associated model."""
return self._make(commit=False, commit_related=_save_related, **attrs)

def get_fields(self):
def get_fields(self) -> Any:
return self.model._meta.fields + self.model._meta.many_to_many

def get_related(self):
def get_related(self,) -> List[Union[ManyToOneRel, OneToOneRel]]:
return [r for r in self.model._meta.related_objects if not r.many_to_many]

def _make(
Expand Down Expand Up @@ -336,7 +345,7 @@ def _make(

return instance

def m2m_value(self, field):
def m2m_value(self, field: ManyToManyField) -> List[Any]:
if field.name in self.rel_fields:
return self.generate_value(field)
if not self.make_m2m or field.null and not field.fill_optional:
Expand All @@ -363,7 +372,7 @@ def instance(self, attrs, _commit, _save_kwargs, _from_manager):
# within its get_queryset() method (e.g. annotations)
# is run.
manager = getattr(self.model, _from_manager)
instance = manager.get(pk=instance.pk)
instance: Model = manager.get(pk=instance.pk)

return instance

Expand All @@ -378,7 +387,7 @@ def create_by_related_name(self, instance, related):

make(**kwargs)

def _clean_attrs(self, attrs):
def _clean_attrs(self, attrs: Dict[str, Any]) -> None:
def is_rel_field(x):
return "__" in x

Expand All @@ -401,7 +410,7 @@ def is_rel_field(x):
x.split("__")[0] for x in self.rel_attrs.keys() if is_rel_field(x)
]

def _skip_field(self, field):
def _skip_field(self, field: Any) -> bool:
from django.contrib.contenttypes.fields import GenericRelation

# check for fill optional argument
Expand Down Expand Up @@ -473,10 +482,12 @@ def _handle_m2m(self, instance):
}
make(through_model, **base_kwargs)

def _remote_field(self, field):
def _remote_field(
self, field: Union[CustomForeignKey, ForeignKey, OneToOneField]
) -> Union[OneToOneRel, ManyToOneRel]:
return field.remote_field

def generate_value(self, field, commit=True):
def generate_value(self, field: Field, commit: bool = True) -> Any:
"""Call the associated generator with a field passing all required args.

Generator Resolution Precedence Order:
Expand Down Expand Up @@ -520,16 +531,18 @@ def generate_value(self, field, commit=True):
return generator(**generator_attrs)


def get_required_values(generator, field):
def get_required_values(
generator: Callable, field: Field
) -> Dict[str, Union[bool, int, str, List[Callable]]]:
"""Get required values for a generator from the field.

If required value is a function, calls it with field as argument. If
required value is a string, simply fetch the value from the field
and return.
"""
# FIXME: avoid abbreviations
rt = {}
if hasattr(generator, "required"):
rt: Dict[str, Any] = {}
if isinstance(generator, ActionGenerator):
for item in generator.required:

if callable(item): # baker can deal with the nasty hacking too!
Expand All @@ -549,7 +562,7 @@ def get_required_values(generator, field):
return rt


def filter_rel_attrs(field_name, **rel_attrs):
def filter_rel_attrs(field_name: str, **rel_attrs) -> Dict[str, str]:
clean_dict = {}

for k, v in rel_attrs.items():
Expand Down
7 changes: 4 additions & 3 deletions model_bakery/generators.py
@@ -1,4 +1,5 @@
from decimal import Decimal
from typing import Any, Callable, Optional, Union

from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.models import (
Expand Down Expand Up @@ -97,7 +98,7 @@
DateTimeRangeField = None


def _make_integer_gen_by_range(field_type):
def _make_integer_gen_by_range(field_type: Any) -> Callable:
min_int, max_int = BaseDatabaseOperations.integer_field_ranges[field_type.__name__]

def gen_integer():
Expand Down Expand Up @@ -192,9 +193,9 @@ def get_type_mapping():
user_mapping = {}


def add(field, func):
def add(field: str, func: Optional[Union[Callable, str]]) -> None:
user_mapping[import_from_str(field)] = import_from_str(func)


def get(field):
def get(field: Any) -> Optional[Callable]:
return user_mapping.get(field)