From d18099925fa8692cb5858b11f91310b63ab1eeb5 Mon Sep 17 00:00:00 2001 From: Tom Parker-Shemilt Date: Wed, 2 Sep 2020 23:49:22 +0100 Subject: [PATCH 01/20] Add type support and mypy checking --- dev_requirements.txt | 1 + model_bakery/baker.py | 82 +++++++++++++++----------- model_bakery/generators.py | 7 ++- model_bakery/random_gen.py | 117 +++++++++++++++++++++++++++---------- model_bakery/recipe.py | 19 +++--- model_bakery/timezone.py | 4 +- model_bakery/utils.py | 3 +- tox.ini | 5 ++ 8 files changed, 158 insertions(+), 80 deletions(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index 4716a6ae..9feac4b7 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -2,6 +2,7 @@ black flake8 isort +mypy pillow pre-commit psycopg2-binary diff --git a/model_bakery/baker.py b/model_bakery/baker.py index 515d3b70..aea9e6e7 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -1,8 +1,10 @@ from os.path import dirname, join +from typing import Any, Callable, Dict, Iterator, List, Optional, Type, Union, cast from django.apps import apps from django.conf import settings from django.contrib import contenttypes +from django.db import models from django.db.models import ( AutoField, BooleanField, @@ -10,13 +12,16 @@ FileField, ForeignKey, ManyToManyField, + Model, + ModelBase, OneToOneField, ) -from django.db.models.base import ModelBase from django.db.models.fields.proxy import OrderWrt from django.db.models.fields.related import ( ReverseManyToOneDescriptor as ForeignRelatedObjectsDescriptor, ) +from django.db.models.fields.reverse_related import ManyToOneRel, OneToOneRel +from tests.generic.fields import CustomForeignKey from . import generators, random_gen from .exceptions import ( @@ -27,6 +32,7 @@ ModelNotFound, RecipeIteratorEmpty, ) +from .random_gen import ActionGenerator from .utils import seq # NoQA: enable seq to be imported from baker from .utils import import_from_str @@ -39,7 +45,7 @@ MAX_MANY_QUANTITY = 5 -def _valid_quantity(quantity): +def _valid_quantity(quantity: Optional[Union[str, int]]) -> bool: return quantity is not None and (not isinstance(quantity, int) or quantity < 1) @@ -95,7 +101,7 @@ def prepare(_model, _quantity=None, _save_related=False, **attrs): return baker.prepare(_save_related=_save_related, **attrs) -def _recipe(name): +def _recipe(name: str) -> Any: app, recipe_name = name.rsplit(".", 1) return import_from_str(".".join((app, "baker_recipes", recipe_name))) @@ -113,8 +119,8 @@ def prepare_recipe(baker_recipe_name, _quantity=None, _save_related=False, **new class ModelFinder(object): """Encapsulates all the logic for finding a model to Baker.""" - _unique_models = None - _ambiguous_models = None + _unique_models: Optional[Dict[str, Type[models.Model]]] = None + _ambiguous_models: Optional[List[str]] = None def get_model(self, name): """Get a model. @@ -140,7 +146,7 @@ def get_model(self, name): return model - def get_model_by_name(self, name): + def get_model_by_name(self, name: str) -> Optional[Type[models.Model]]: """Get a model by name. If a model with that name exists in more than one app, raises @@ -148,18 +154,18 @@ def get_model_by_name(self, name): """ name = name.lower() - if self._unique_models is None: + if self._unique_models is None or self._ambiguous_models is None: self._populate() - if name in self._ambiguous_models: + if name in cast(List, self._ambiguous_models): raise AmbiguousModelName( "%s is a model in more than one app. " 'Use the form "app.model".' % name.title() ) - return self._unique_models.get(name) + return cast(Dict, self._unique_models).get(name) - def _populate(self): + def _populate(self) -> None: """Cache models for faster self._get_model.""" unique_models = {} ambiguous_models = [] @@ -180,14 +186,14 @@ def _populate(self): self._unique_models = unique_models -def is_iterator(value): +def is_iterator(value: Any) -> bool: if not hasattr(value, "__iter__"): return False return hasattr(value, "__next__") -def _custom_baker_class(): +def _custom_baker_class() -> Optional[Type]: """Return the specified custom baker class. Returns: @@ -217,27 +223,31 @@ def _custom_baker_class(): class Baker(object): - attr_mapping = {} - type_mapping = None + attr_mapping: Dict[str, Any] = {} + type_mapping: Dict = {} # Note: we're using one finder for all Baker instances to avoid # rebuilding the model cache for every make_* or prepare_* call. finder = ModelFinder() @classmethod - def create(cls, _model, make_m2m=False, create_files=False): + def create( + cls, _model: str, make_m2m: bool = False, create_files: bool = False + ) -> "Baker": """Create the baker class defined by the `BAKER_CUSTOM_CLASS` setting.""" baker_class = _custom_baker_class() or cls return baker_class(_model, make_m2m, create_files) - def __init__(self, _model, make_m2m=False, create_files=False): + def __init__( + self, _model: str, make_m2m: bool = False, create_files: bool = False + ) -> None: self.make_m2m = make_m2m self.create_files = create_files - self.m2m_dict = {} - self.iterator_attrs = {} - self.model_attrs = {} - self.rel_attrs = {} - self.rel_fields = [] + self.m2m_dict: Dict[str, List] = {} + self.iterator_attrs: Dict[str, Iterator] = {} + self.model_attrs: Dict[str, Any] = {} + self.rel_attrs: Dict[str, Any] = {} + self.rel_fields: List[str] = [] if isinstance(_model, ModelBase): self.model = _model @@ -246,7 +256,7 @@ def __init__(self, _model, make_m2m=False, create_files=False): self.init_type_mapping() - def init_type_mapping(self): + def init_type_mapping(self) -> None: self.type_mapping = generators.get_type_mapping() generators_from_settings = getattr(settings, "BAKER_CUSTOM_FIELDS_GEN", {}) for k, v in generators_from_settings.items(): @@ -276,10 +286,10 @@ def prepare(self, _save_related=False, **attrs): """Create, but do not persist, an instance of the associated model.""" return self._make(commit=False, commit_related=_save_related, **attrs) - def get_fields(self): + def get_fields(self) -> Any: return self.model._meta.fields + self.model._meta.many_to_many - def get_related(self): + def get_related(self,) -> List[Union[ManyToOneRel, OneToOneRel]]: return [r for r in self.model._meta.related_objects if not r.many_to_many] def _make( @@ -336,7 +346,7 @@ def _make( return instance - def m2m_value(self, field): + def m2m_value(self, field: ManyToManyField) -> List[Any]: if field.name in self.rel_fields: return self.generate_value(field) if not self.make_m2m or field.null and not field.fill_optional: @@ -363,7 +373,7 @@ def instance(self, attrs, _commit, _save_kwargs, _from_manager): # within its get_queryset() method (e.g. annotations) # is run. manager = getattr(self.model, _from_manager) - instance = manager.get(pk=instance.pk) + instance: Model = manager.get(pk=instance.pk) return instance @@ -378,7 +388,7 @@ def create_by_related_name(self, instance, related): make(**kwargs) - def _clean_attrs(self, attrs): + def _clean_attrs(self, attrs: Dict[str, Any]) -> None: def is_rel_field(x): return "__" in x @@ -401,7 +411,7 @@ def is_rel_field(x): x.split("__")[0] for x in self.rel_attrs.keys() if is_rel_field(x) ] - def _skip_field(self, field): + def _skip_field(self, field: Any) -> bool: from django.contrib.contenttypes.fields import GenericRelation # check for fill optional argument @@ -473,10 +483,12 @@ def _handle_m2m(self, instance): } make(through_model, **base_kwargs) - def _remote_field(self, field): + def _remote_field( + self, field: Union[CustomForeignKey, ForeignKey, OneToOneField] + ) -> Union[OneToOneRel, ManyToOneRel]: return field.remote_field - def generate_value(self, field, commit=True): + def generate_value(self, field: Field, commit: bool = True) -> Any: """Call the associated generator with a field passing all required args. Generator Resolution Precedence Order: @@ -520,7 +532,9 @@ def generate_value(self, field, commit=True): return generator(**generator_attrs) -def get_required_values(generator, field): +def get_required_values( + generator: Callable, field: Field +) -> Dict[str, Union[bool, int, str, List[Callable]]]: """Get required values for a generator from the field. If required value is a function, calls it with field as argument. If @@ -528,8 +542,8 @@ def get_required_values(generator, field): and return. """ # FIXME: avoid abbreviations - rt = {} - if hasattr(generator, "required"): + rt: Dict[str, Any] = {} + if isinstance(generator, ActionGenerator): for item in generator.required: if callable(item): # baker can deal with the nasty hacking too! @@ -549,7 +563,7 @@ def get_required_values(generator, field): return rt -def filter_rel_attrs(field_name, **rel_attrs): +def filter_rel_attrs(field_name: str, **rel_attrs) -> Dict[str, str]: clean_dict = {} for k, v in rel_attrs.items(): diff --git a/model_bakery/generators.py b/model_bakery/generators.py index 2dc9b204..96d725be 100644 --- a/model_bakery/generators.py +++ b/model_bakery/generators.py @@ -1,4 +1,5 @@ from decimal import Decimal +from typing import Any, Callable, Optional, Union from django.db.backends.base.operations import BaseDatabaseOperations from django.db.models import ( @@ -97,7 +98,7 @@ DateTimeRangeField = None -def _make_integer_gen_by_range(field_type): +def _make_integer_gen_by_range(field_type: Any) -> Callable: min_int, max_int = BaseDatabaseOperations.integer_field_ranges[field_type.__name__] def gen_integer(): @@ -192,9 +193,9 @@ def get_type_mapping(): user_mapping = {} -def add(field, func): +def add(field: str, func: Optional[Union[Callable, str]]) -> None: user_mapping[import_from_str(field)] = import_from_str(func) -def get(field): +def get(field: Any) -> Optional[Callable]: return user_mapping.get(field) diff --git a/model_bakery/random_gen.py b/model_bakery/random_gen.py index 87569999..b1575739 100644 --- a/model_bakery/random_gen.py +++ b/model_bakery/random_gen.py @@ -11,11 +11,29 @@ import string import warnings +from datetime import date, datetime, time, timedelta from decimal import Decimal from os.path import abspath, dirname, join from random import choice, randint, random, uniform - -from model_bakery.timezone import now +from typing import ( + Any, + Callable, + List, + Literal, + Optional, + Protocol, + Tuple, + TypeVar, + Union, + cast, + runtime_checkable, +) +from uuid import UUID + +from django.core.files.base import ContentFile +from django.db.models import Field, Model + +from .timezone import now MAX_LENGTH = 300 # Using sys.maxint here breaks a bunch of tests when running against a @@ -23,27 +41,50 @@ MAX_INT = 100000000000 -def get_content_file(content, name): - from django.core.files.base import ContentFile +# Hack to make mypy happy with attributes on functions +# See https://github.com/python/mypy/issues/2087#issuecomment-587741762 +F = TypeVar("F", bound=Callable[..., object]) + + +@runtime_checkable +class ActionGenerator(Protocol[F]): + required: Optional[ + List[Union[str, Callable[[Field], Tuple[Literal["model"], Optional[Model]]]]] + ] + prepare: Optional[Callable[..., Union[Model, List[Model]]]] + __call__: F + + +def action_generator(action: F) -> ActionGenerator[F]: + action_generator = cast(ActionGenerator[F], action) + # Make sure the cast isn't a lie. + action_generator.required = None + action_generator.prepare = None + return action_generator + +# End mypy hack + + +def get_content_file(content: bytes, name: str) -> ContentFile: return ContentFile(content, name=name) -def gen_file_field(): +def gen_file_field() -> ContentFile: name = "mock_file.txt" file_path = abspath(join(dirname(__file__), name)) with open(file_path, "rb") as f: return get_content_file(f.read(), name=name) -def gen_image_field(): +def gen_image_field() -> ContentFile: name = "mock_img.jpeg" file_path = abspath(join(dirname(__file__), name)) with open(file_path, "rb") as f: return get_content_file(f.read(), name=name) -def gen_from_list(a_list): +def gen_from_list(a_list: Union[List[str], range]) -> Callable: """Make sure all values of the field are generated from a list. Examples: @@ -60,7 +101,15 @@ def gen_from_list(a_list): # -- DEFAULT GENERATORS -- -def gen_from_choices(choices): +def gen_from_choices( + choices: Union[ + List[Tuple[str, str]], + Tuple[ + Tuple[str, Tuple[Tuple[str, str], Tuple[str, str]]], + Tuple[str, Tuple[Tuple[str, str], Tuple[str, str]]], + ], + ] +) -> Callable: choice_list = [] for value, label in choices: if isinstance(label, (list, tuple)): @@ -71,15 +120,16 @@ def gen_from_choices(choices): return gen_from_list(choice_list) -def gen_integer(min_int=-MAX_INT, max_int=MAX_INT): +def gen_integer(min_int: int = -MAX_INT, max_int: int = MAX_INT) -> int: return randint(min_int, max_int) -def gen_float(): +def gen_float() -> float: return random() * gen_integer() -def gen_decimal(max_digits, decimal_places): +@action_generator +def gen_decimal(max_digits: int, decimal_places: int) -> Decimal: def num_as_str(x): return "".join([str(randint(0, 9)) for _ in range(x)]) @@ -94,26 +144,28 @@ def num_as_str(x): gen_decimal.required = ["max_digits", "decimal_places"] -def gen_date(): +def gen_date() -> date: return now().date() -def gen_datetime(): +def gen_datetime() -> datetime: return now() -def gen_time(): +def gen_time() -> time: return now().time() -def gen_string(max_length): +@action_generator +def gen_string(max_length: int) -> str: return str("".join(choice(string.ascii_letters) for _ in range(max_length))) gen_string.required = ["max_length"] -def gen_slug(max_length): +@action_generator +def gen_slug(max_length: int) -> str: valid_chars = string.ascii_letters + string.digits + "_-" return str("".join(choice(valid_chars) for _ in range(max_length))) @@ -121,11 +173,11 @@ def gen_slug(max_length): gen_slug.required = ["max_length"] -def gen_text(): +def gen_text() -> str: return gen_string(MAX_LENGTH) -def gen_boolean(): +def gen_boolean() -> bool: return choice((True, False)) @@ -133,28 +185,29 @@ def gen_null_boolean(): return choice((True, False, None)) -def gen_url(): +def gen_url() -> str: return str("http://www.%s.com/" % gen_string(30)) -def gen_email(): +def gen_email() -> str: return "%s@example.com" % gen_string(10) -def gen_ipv6(): +def gen_ipv6() -> str: return ":".join(format(randint(1, 65535), "x") for _ in range(8)) -def gen_ipv4(): +def gen_ipv4() -> str: return ".".join(str(randint(1, 255)) for _ in range(4)) -def gen_ipv46(): +def gen_ipv46() -> str: ip_gen = choice([gen_ipv4, gen_ipv6]) return ip_gen() -def gen_ip(protocol, default_validators): +@action_generator +def gen_ip(protocol: str, default_validators: List[Callable]) -> str: from django.core.exceptions import ValidationError protocol = (protocol or "").lower() @@ -186,14 +239,12 @@ def gen_ip(protocol, default_validators): gen_ip.required = ["protocol", "default_validators"] -def gen_byte_string(max_length=16): +def gen_byte_string(max_length: int = 16) -> bytes: generator = (randint(0, 255) for x in range(max_length)) return bytes(generator) -def gen_interval(interval_key="milliseconds", offset=0): - from datetime import timedelta - +def gen_interval(interval_key: str = "milliseconds", offset: int = 0) -> timedelta: interval = gen_integer() + offset kwargs = {interval_key: interval} return timedelta(**kwargs) @@ -210,7 +261,7 @@ def gen_content_type(): return ContentType() -def gen_uuid(): +def gen_uuid() -> UUID: import uuid return uuid.uuid4() @@ -228,19 +279,20 @@ def gen_hstore(): return {} -def _fk_model(field): +def _fk_model(field: Field) -> Tuple[Literal["model"], Optional[Model]]: try: return ("model", field.related_model) except AttributeError: return ("model", field.related.parent_model) -def _prepare_related(model, **attrs): +def _prepare_related(model: str, **attrs: Any) -> Union[Model, List[Model]]: from .baker import prepare return prepare(model, **attrs) +@action_generator def gen_related(model, **attrs): from .baker import make @@ -251,6 +303,7 @@ def gen_related(model, **attrs): gen_related.prepare = _prepare_related +@action_generator def gen_m2m(model, **attrs): from .baker import MAX_MANY_QUANTITY, make @@ -312,7 +365,7 @@ def gen_geometry_collection(): return "GEOMETRYCOLLECTION ({})".format(gen_point(),) -def gen_pg_numbers_range(number_cast=int): +def gen_pg_numbers_range(number_cast: Callable[[int], Any]) -> Callable: def gen_range(): from psycopg2._range import NumericRange diff --git a/model_bakery/recipe.py b/model_bakery/recipe.py index c36ec97a..ea34fa37 100644 --- a/model_bakery/recipe.py +++ b/model_bakery/recipe.py @@ -1,5 +1,6 @@ import inspect import itertools +from typing import Any, Dict, Optional, Union, cast from . import baker from .exceptions import RecipeNotFound @@ -9,13 +10,15 @@ class Recipe(object): - def __init__(self, _model, **attrs): + def __init__(self, _model: str, **attrs) -> None: self.attr_mapping = attrs self._model = _model # _iterator_backups will hold values of the form (backup_iterator, usable_iterator). - self._iterator_backups = {} + self._iterator_backups: Dict[str, Any] = {} - def _mapping(self, new_attrs): + def _mapping( + self, new_attrs: Dict[str, Optional[Union[bool, str, int]]] + ) -> Dict[str, Any]: _save_related = new_attrs.get("_save_related", True) rel_fields_attrs = dict((k, v) for k, v in new_attrs.items() if "__" in k) new_attrs = dict((k, v) for k, v in new_attrs.items() if "__" not in k) @@ -58,14 +61,14 @@ def prepare(self, **attrs): defaults.update(attrs) return baker.prepare(self._model, **self._mapping(defaults)) - def extend(self, **attrs): + def extend(self, **attrs) -> "Recipe": attr_mapping = self.attr_mapping.copy() attr_mapping.update(attrs) return type(self)(self._model, **attr_mapping) class RecipeForeignKey(object): - def __init__(self, recipe): + def __init__(self, recipe: Union[str, int, Recipe]) -> None: if isinstance(recipe, Recipe): self.recipe = recipe elif isinstance(recipe, str): @@ -73,14 +76,14 @@ def __init__(self, recipe): caller_module = inspect.getmodule(frame[0]) recipe = getattr(caller_module, recipe) if recipe: - self.recipe = recipe + self.recipe = cast(Recipe, recipe) else: raise RecipeNotFound else: raise TypeError("Not a recipe") -def foreign_key(recipe): +def foreign_key(recipe: Union[Recipe, int, str]) -> RecipeForeignKey: """Return a `RecipeForeignKey`. Return the callable, so that the associated `_model` will not be created @@ -90,7 +93,7 @@ def foreign_key(recipe): class related(object): # FIXME - def __init__(self, *args): + def __init__(self, *args) -> None: self.related = [] for recipe in args: if isinstance(recipe, Recipe): diff --git a/model_bakery/timezone.py b/model_bakery/timezone.py index 94667e1a..9641050b 100644 --- a/model_bakery/timezone.py +++ b/model_bakery/timezone.py @@ -15,12 +15,12 @@ def now(): return datetime.now() -def smart_datetime(*args): +def smart_datetime(*args) -> datetime: value = datetime(*args) return tz_aware(value) -def tz_aware(d): +def tz_aware(d: datetime) -> datetime: value = d if settings.USE_TZ: value = d.replace(tzinfo=utc) diff --git a/model_bakery/utils.py b/model_bakery/utils.py index ee89393b..2288f8e7 100644 --- a/model_bakery/utils.py +++ b/model_bakery/utils.py @@ -2,11 +2,12 @@ import importlib import itertools import warnings +from typing import Any, Callable, Optional, Union from .timezone import tz_aware -def import_from_str(import_string): +def import_from_str(import_string: Optional[Union[Callable, str]]) -> Any: """Import an object defined as import if it is an string. If `import_string` follows the format `path.to.module.object_name`, diff --git a/tox.ini b/tox.ini index 3b367c6f..fbd137c5 100644 --- a/tox.ini +++ b/tox.ini @@ -46,3 +46,8 @@ commands=isort model_bakery --check-only {posargs} deps=black basepython=python3 commands=black . --check + +[testenv:type] +deps=mypy +basepython=python3 +commands=python -m mypy --ignore-missing-imports . From 704af6b73e940d03d1cd8702fdafda12d8f0000b Mon Sep 17 00:00:00 2001 From: Tom Parker-Shemilt Date: Wed, 2 Sep 2020 23:52:21 +0100 Subject: [PATCH 02/20] Don't need models direct import --- model_bakery/baker.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/model_bakery/baker.py b/model_bakery/baker.py index aea9e6e7..faea2d2f 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -4,7 +4,6 @@ from django.apps import apps from django.conf import settings from django.contrib import contenttypes -from django.db import models from django.db.models import ( AutoField, BooleanField, @@ -119,7 +118,7 @@ def prepare_recipe(baker_recipe_name, _quantity=None, _save_related=False, **new class ModelFinder(object): """Encapsulates all the logic for finding a model to Baker.""" - _unique_models: Optional[Dict[str, Type[models.Model]]] = None + _unique_models: Optional[Dict[str, Type[Model]]] = None _ambiguous_models: Optional[List[str]] = None def get_model(self, name): @@ -146,7 +145,7 @@ def get_model(self, name): return model - def get_model_by_name(self, name: str) -> Optional[Type[models.Model]]: + def get_model_by_name(self, name: str) -> Optional[Type[Model]]: """Get a model by name. If a model with that name exists in more than one app, raises From 91b7946becf1285dfdda108ce39b8e7974f60932 Mon Sep 17 00:00:00 2001 From: Tom Parker-Shemilt Date: Wed, 2 Sep 2020 23:57:19 +0100 Subject: [PATCH 03/20] Simplify some auto-generated types --- model_bakery/random_gen.py | 10 +--------- model_bakery/recipe.py | 4 ++-- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/model_bakery/random_gen.py b/model_bakery/random_gen.py index b1575739..f4023340 100644 --- a/model_bakery/random_gen.py +++ b/model_bakery/random_gen.py @@ -101,15 +101,7 @@ def gen_from_list(a_list: Union[List[str], range]) -> Callable: # -- DEFAULT GENERATORS -- -def gen_from_choices( - choices: Union[ - List[Tuple[str, str]], - Tuple[ - Tuple[str, Tuple[Tuple[str, str], Tuple[str, str]]], - Tuple[str, Tuple[Tuple[str, str], Tuple[str, str]]], - ], - ] -) -> Callable: +def gen_from_choices(choices: List) -> Callable: choice_list = [] for value, label in choices: if isinstance(label, (list, tuple)): diff --git a/model_bakery/recipe.py b/model_bakery/recipe.py index ea34fa37..ab363241 100644 --- a/model_bakery/recipe.py +++ b/model_bakery/recipe.py @@ -68,7 +68,7 @@ def extend(self, **attrs) -> "Recipe": class RecipeForeignKey(object): - def __init__(self, recipe: Union[str, int, Recipe]) -> None: + def __init__(self, recipe: Union[str, Recipe]) -> None: if isinstance(recipe, Recipe): self.recipe = recipe elif isinstance(recipe, str): @@ -83,7 +83,7 @@ def __init__(self, recipe: Union[str, int, Recipe]) -> None: raise TypeError("Not a recipe") -def foreign_key(recipe: Union[Recipe, int, str]) -> RecipeForeignKey: +def foreign_key(recipe: Union[Recipe, str]) -> RecipeForeignKey: """Return a `RecipeForeignKey`. Return the callable, so that the associated `_model` will not be created From e34c6daffbe98f5c57a7989c8b2e9805269a6d0b Mon Sep 17 00:00:00 2001 From: Tom Parker-Shemilt Date: Sat, 5 Sep 2020 12:42:37 +0100 Subject: [PATCH 04/20] Pull mypy settings into config --- setup.cfg | 3 +++ tox.ini | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 82f4cda8..5cbde907 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,3 +20,6 @@ line_length=88 [pydocstyle] add_ignore = D1 match-dir = (?!test|docs|\.).* + +[mypy] +ignore_missing_imports=True diff --git a/tox.ini b/tox.ini index 180cab16..1fa5905d 100644 --- a/tox.ini +++ b/tox.ini @@ -52,4 +52,4 @@ commands=black . --check [testenv:type] deps=mypy basepython=python3 -commands=python -m mypy --ignore-missing-imports . +commands=python -m mypy . From 4706fadc6867c3f14b5c4c22bc871714f83fe1bd Mon Sep 17 00:00:00 2001 From: Tom Parker-Shemilt Date: Sat, 5 Sep 2020 13:03:49 +0100 Subject: [PATCH 05/20] Disallow untyped calls --- model_bakery/baker.py | 56 +++++++++++++++++++--------------- model_bakery/generators.py | 6 ++-- model_bakery/random_gen.py | 2 +- model_bakery/recipe.py | 16 +++++----- setup.cfg | 1 + tests/generic/forms.py | 1 + tests/test_extending_bakery.py | 1 + tests/test_utils.py | 1 + 8 files changed, 47 insertions(+), 37 deletions(-) diff --git a/model_bakery/baker.py b/model_bakery/baker.py index faea2d2f..18e5968b 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -20,6 +20,7 @@ ReverseManyToOneDescriptor as ForeignRelatedObjectsDescriptor, ) from django.db.models.fields.reverse_related import ManyToOneRel, OneToOneRel + from tests.generic.fields import CustomForeignKey from . import generators, random_gen @@ -49,13 +50,13 @@ def _valid_quantity(quantity: Optional[Union[str, int]]) -> bool: def make( - _model, - _quantity=None, - make_m2m=False, - _save_kwargs=None, - _refresh_after_create=False, + _model: str, + _quantity: Optional[int] = None, + make_m2m: bool = False, + _save_kwargs: Optional[Dict] = None, + _refresh_after_create: bool = False, _create_files=False, - **attrs + **attrs: Any ): """Create a persisted instance from a given model its associated models. @@ -81,7 +82,7 @@ def make( ) -def prepare(_model, _quantity=None, _save_related=False, **attrs): +def prepare(_model: str, _quantity=None, _save_related=False, **attrs) -> Model: """Create but do not persist an instance from a given model. It fill the fields with random values or you can specify which @@ -121,7 +122,7 @@ class ModelFinder(object): _unique_models: Optional[Dict[str, Type[Model]]] = None _ambiguous_models: Optional[List[str]] = None - def get_model(self, name): + def get_model(self, name: str) -> Type[Model]: """Get a model. Args: @@ -265,10 +266,10 @@ def init_type_mapping(self) -> None: def make( self, - _save_kwargs=None, - _refresh_after_create=False, + _save_kwargs: Optional[Dict[str, Any]] = None, + _refresh_after_create: bool = False, _from_manager=None, - **attrs + **attrs: Any ): """Create and persist an instance of the model associated with Baker instance.""" params = { @@ -281,14 +282,16 @@ def make( params.update(attrs) return self._make(**params) - def prepare(self, _save_related=False, **attrs): + def prepare(self, _save_related=False, **attrs: Any) -> Model: """Create, but do not persist, an instance of the associated model.""" return self._make(commit=False, commit_related=_save_related, **attrs) def get_fields(self) -> Any: return self.model._meta.fields + self.model._meta.many_to_many - def get_related(self,) -> List[Union[ManyToOneRel, OneToOneRel]]: + def get_related( + self, + ) -> List[Union[ManyToOneRel, OneToOneRel]]: return [r for r in self.model._meta.related_objects if not r.many_to_many] def _make( @@ -298,8 +301,8 @@ def _make( _save_kwargs=None, _refresh_after_create=False, _from_manager=None, - **attrs - ): + **attrs: Any + ) -> Model: _save_kwargs = _save_kwargs or {} self._clean_attrs(attrs) @@ -352,14 +355,16 @@ def m2m_value(self, field: ManyToManyField) -> List[Any]: return [] return self.generate_value(field) - def instance(self, attrs, _commit, _save_kwargs, _from_manager): + def instance( + self, attrs: Dict[str, Any], _commit, _save_kwargs, _from_manager + ) -> Model: one_to_many_keys = {} for k in tuple(attrs.keys()): field = getattr(self.model, k, None) if isinstance(field, ForeignRelatedObjectsDescriptor): one_to_many_keys[k] = attrs.pop(k) - instance = self.model(**attrs) + instance: Model = self.model(**attrs) # m2m only works for persisted instances if _commit: instance.save(**_save_kwargs) @@ -372,23 +377,24 @@ def instance(self, attrs, _commit, _save_kwargs, _from_manager): # within its get_queryset() method (e.g. annotations) # is run. manager = getattr(self.model, _from_manager) - instance: Model = manager.get(pk=instance.pk) + instance = manager.get(pk=instance.pk) return instance - def create_by_related_name(self, instance, related): + def create_by_related_name( + self, instance: Model, related: Union[ManyToOneRel, OneToOneRel] + ) -> None: rel_name = related.get_accessor_name() if rel_name not in self.rel_fields: return kwargs = filter_rel_attrs(rel_name, **self.rel_attrs) kwargs[related.field.name] = instance - kwargs["_model"] = related.field.model - make(**kwargs) + make(related.field.model, **kwargs) def _clean_attrs(self, attrs: Dict[str, Any]) -> None: - def is_rel_field(x): + def is_rel_field(x: str): return "__" in x self.fill_in_optional = attrs.pop("_fill_optional", False) @@ -453,7 +459,7 @@ def _skip_field(self, field: Any) -> bool: return False - def _handle_one_to_many(self, instance, attrs): + def _handle_one_to_many(self, instance: Model, attrs: Dict[str, Any]): for k, v in attrs.items(): manager = getattr(instance, k) @@ -463,7 +469,7 @@ def _handle_one_to_many(self, instance, attrs): # for many-to-many relationships the bulk keyword argument doesn't exist manager.set(v, clear=True) - def _handle_m2m(self, instance): + def _handle_m2m(self, instance: Model): for key, values in self.m2m_dict.items(): for value in values: if not value.pk: @@ -562,7 +568,7 @@ def get_required_values( return rt -def filter_rel_attrs(field_name: str, **rel_attrs) -> Dict[str, str]: +def filter_rel_attrs(field_name: str, **rel_attrs) -> Dict[str, Any]: clean_dict = {} for k, v in rel_attrs.items(): diff --git a/model_bakery/generators.py b/model_bakery/generators.py index 96d725be..99932ea6 100644 --- a/model_bakery/generators.py +++ b/model_bakery/generators.py @@ -1,5 +1,5 @@ from decimal import Decimal -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, Optional, Type, Union from django.db.backends.base.operations import BaseDatabaseOperations from django.db.models import ( @@ -107,7 +107,7 @@ def gen_integer(): return gen_integer -default_mapping = { +default_mapping: Dict[Type, Callable] = { ForeignKey: random_gen.gen_related, OneToOneField: random_gen.gen_related, ManyToManyField: random_gen.gen_m2m, @@ -178,7 +178,7 @@ def gen_integer(): # Add GIS fields -def get_type_mapping(): +def get_type_mapping() -> Dict[Type, Callable]: from django.contrib.contenttypes.models import ContentType from .gis import default_gis_mapping diff --git a/model_bakery/random_gen.py b/model_bakery/random_gen.py index 800a0c95..462800b7 100644 --- a/model_bakery/random_gen.py +++ b/model_bakery/random_gen.py @@ -122,7 +122,7 @@ def gen_float() -> float: @action_generator def gen_decimal(max_digits: int, decimal_places: int) -> Decimal: - def num_as_str(x): + def num_as_str(x: int): return "".join([str(randint(0, 9)) for _ in range(x)]) if decimal_places: diff --git a/model_bakery/recipe.py b/model_bakery/recipe.py index ab363241..85606f25 100644 --- a/model_bakery/recipe.py +++ b/model_bakery/recipe.py @@ -1,6 +1,8 @@ import inspect import itertools -from typing import Any, Dict, Optional, Union, cast +from typing import Any, Dict, List, Union, cast + +from django.db.models import Model from . import baker from .exceptions import RecipeNotFound @@ -16,9 +18,7 @@ def __init__(self, _model: str, **attrs) -> None: # _iterator_backups will hold values of the form (backup_iterator, usable_iterator). self._iterator_backups: Dict[str, Any] = {} - def _mapping( - self, new_attrs: Dict[str, Optional[Union[bool, str, int]]] - ) -> Dict[str, Any]: + def _mapping(self, new_attrs: Dict[str, Any]) -> Dict[str, Any]: _save_related = new_attrs.get("_save_related", True) rel_fields_attrs = dict((k, v) for k, v in new_attrs.items() if "__" in k) new_attrs = dict((k, v) for k, v in new_attrs.items() if "__" not in k) @@ -53,10 +53,10 @@ def _mapping( mapping.update(rel_fields_attrs) return mapping - def make(self, **attrs): + def make(self, **attrs: Any) -> Union[Model, List[Model]]: return baker.make(self._model, **self._mapping(attrs)) - def prepare(self, **attrs): + def prepare(self, **attrs: Any) -> Union[Model, List[Model]]: defaults = {"_save_related": False} defaults.update(attrs) return baker.prepare(self._model, **self._mapping(defaults)) @@ -94,7 +94,7 @@ def foreign_key(recipe: Union[Recipe, str]) -> RecipeForeignKey: class related(object): # FIXME def __init__(self, *args) -> None: - self.related = [] + self.related: List[Recipe] = [] for recipe in args: if isinstance(recipe, Recipe): self.related.append(recipe) @@ -109,6 +109,6 @@ def __init__(self, *args) -> None: else: raise TypeError("Not a recipe") - def make(self): + def make(self) -> List[Union[Model, List[Model]]]: """Persist objects to m2m relation.""" return [m.make() for m in self.related] diff --git a/setup.cfg b/setup.cfg index 5cbde907..acf9fc4a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,3 +23,4 @@ match-dir = (?!test|docs|\.).* [mypy] ignore_missing_imports=True +disallow_untyped_calls=True diff --git a/tests/generic/forms.py b/tests/generic/forms.py index e70f98da..d82cf324 100644 --- a/tests/generic/forms.py +++ b/tests/generic/forms.py @@ -1,4 +1,5 @@ from django.forms import ModelForm + from tests.generic.models import DummyGenericIPAddressFieldModel diff --git a/tests/test_extending_bakery.py b/tests/test_extending_bakery.py index ed5cd4a4..a266f201 100644 --- a/tests/test_extending_bakery.py +++ b/tests/test_extending_bakery.py @@ -1,4 +1,5 @@ import pytest + from model_bakery import baker from model_bakery.exceptions import CustomBakerNotFound, InvalidCustomBaker from model_bakery.random_gen import gen_from_list diff --git a/tests/test_utils.py b/tests/test_utils.py index b34295b8..3efeb763 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,5 @@ import pytest + from model_bakery.utils import import_from_str from tests.generic.models import User From ea4f8ce8b0a942035a7afbe9c209fe0a003fd1c2 Mon Sep 17 00:00:00 2001 From: Tom Parker-Shemilt Date: Sat, 5 Sep 2020 13:48:33 +0100 Subject: [PATCH 06/20] Make type annotations 3.5 compatible --- model_bakery/baker.py | 33 ++++++++++++++++++--------------- model_bakery/generators.py | 4 ++-- model_bakery/random_gen.py | 13 ++++++------- model_bakery/recipe.py | 4 ++-- 4 files changed, 28 insertions(+), 26 deletions(-) diff --git a/model_bakery/baker.py b/model_bakery/baker.py index 18e5968b..fca55ec6 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -119,8 +119,9 @@ def prepare_recipe(baker_recipe_name, _quantity=None, _save_related=False, **new class ModelFinder(object): """Encapsulates all the logic for finding a model to Baker.""" - _unique_models: Optional[Dict[str, Type[Model]]] = None - _ambiguous_models: Optional[List[str]] = None + def __init__(self) -> None: + self._unique_models = None # type: Optional[Dict[str, Type[Model]]] + self._ambiguous_models = None # type: Optional[List[str]] def get_model(self, name: str) -> Type[Model]: """Get a model. @@ -223,8 +224,8 @@ def _custom_baker_class() -> Optional[Type]: class Baker(object): - attr_mapping: Dict[str, Any] = {} - type_mapping: Dict = {} + attr_mapping = {} # type: Dict[str, Any] + type_mapping = {} # type: Dict # Note: we're using one finder for all Baker instances to avoid # rebuilding the model cache for every make_* or prepare_* call. @@ -232,22 +233,22 @@ class Baker(object): @classmethod def create( - cls, _model: str, make_m2m: bool = False, create_files: bool = False + cls, _model: str, make_m2m: bool = False, create_files=False # type: bool ) -> "Baker": """Create the baker class defined by the `BAKER_CUSTOM_CLASS` setting.""" baker_class = _custom_baker_class() or cls return baker_class(_model, make_m2m, create_files) def __init__( - self, _model: str, make_m2m: bool = False, create_files: bool = False + self, _model: str, make_m2m: bool = False, create_files=False # type: bool ) -> None: self.make_m2m = make_m2m self.create_files = create_files - self.m2m_dict: Dict[str, List] = {} - self.iterator_attrs: Dict[str, Iterator] = {} - self.model_attrs: Dict[str, Any] = {} - self.rel_attrs: Dict[str, Any] = {} - self.rel_fields: List[str] = [] + self.m2m_dict = {} # type: Dict[str, List] + self.iterator_attrs = {} # type: Dict[str, Iterator] + self.model_attrs = {} # type: Dict[str, Any] + self.rel_attrs = {} # type: Dict[str, Any] + self.rel_fields = [] # type: List[str] if isinstance(_model, ModelBase): self.model = _model @@ -266,7 +267,7 @@ def init_type_mapping(self) -> None: def make( self, - _save_kwargs: Optional[Dict[str, Any]] = None, + _save_kwargs=None, # type: Optional[Dict[str, Any]] _refresh_after_create: bool = False, _from_manager=None, **attrs: Any @@ -364,7 +365,7 @@ def instance( if isinstance(field, ForeignRelatedObjectsDescriptor): one_to_many_keys[k] = attrs.pop(k) - instance: Model = self.model(**attrs) + instance = self.model(**attrs) # type: Model # m2m only works for persisted instances if _commit: instance.save(**_save_kwargs) @@ -493,7 +494,9 @@ def _remote_field( ) -> Union[OneToOneRel, ManyToOneRel]: return field.remote_field - def generate_value(self, field: Field, commit: bool = True) -> Any: + def generate_value( + self, field: Field, commit=True # type: bool + ) -> Any: """Call the associated generator with a field passing all required args. Generator Resolution Precedence Order: @@ -547,7 +550,7 @@ def get_required_values( and return. """ # FIXME: avoid abbreviations - rt: Dict[str, Any] = {} + rt = {} # type: Dict[str, Any] if isinstance(generator, ActionGenerator): for item in generator.required: diff --git a/model_bakery/generators.py b/model_bakery/generators.py index 99932ea6..8ecfca66 100644 --- a/model_bakery/generators.py +++ b/model_bakery/generators.py @@ -107,7 +107,7 @@ def gen_integer(): return gen_integer -default_mapping: Dict[Type, Callable] = { +default_mapping = { ForeignKey: random_gen.gen_related, OneToOneField: random_gen.gen_related, ManyToManyField: random_gen.gen_m2m, @@ -135,7 +135,7 @@ def gen_integer(): FileField: random_gen.gen_file_field, ImageField: random_gen.gen_image_field, DurationField: random_gen.gen_interval, -} +} # type: Dict[Type, Callable] if ArrayField: default_mapping[ArrayField] = random_gen.gen_array diff --git a/model_bakery/random_gen.py b/model_bakery/random_gen.py index 462800b7..4b39b9e8 100644 --- a/model_bakery/random_gen.py +++ b/model_bakery/random_gen.py @@ -19,7 +19,6 @@ Any, Callable, List, - Literal, Optional, Protocol, Tuple, @@ -48,11 +47,11 @@ @runtime_checkable class ActionGenerator(Protocol[F]): - required: Optional[ - List[Union[str, Callable[[Field], Tuple[Literal["model"], Optional[Model]]]]] - ] - prepare: Optional[Callable[..., Union[Model, List[Model]]]] - __call__: F + required = ( + None + ) # type: Optional[List[Union[str, Callable[[Field], Tuple[str, Optional[Model]]]]]] + prepare = None # type: Optional[Callable[..., Union[Model, List[Model]]]] + __call__ = None # type: F def action_generator(action: F) -> ActionGenerator[F]: @@ -271,7 +270,7 @@ def gen_hstore(): return {} -def _fk_model(field: Field) -> Tuple[Literal["model"], Optional[Model]]: +def _fk_model(field: Field) -> Tuple[str, Optional[Model]]: try: return ("model", field.related_model) except AttributeError: diff --git a/model_bakery/recipe.py b/model_bakery/recipe.py index 85606f25..d0c429df 100644 --- a/model_bakery/recipe.py +++ b/model_bakery/recipe.py @@ -16,7 +16,7 @@ def __init__(self, _model: str, **attrs) -> None: self.attr_mapping = attrs self._model = _model # _iterator_backups will hold values of the form (backup_iterator, usable_iterator). - self._iterator_backups: Dict[str, Any] = {} + self._iterator_backups = {} # type: Dict[str, Any] def _mapping(self, new_attrs: Dict[str, Any]) -> Dict[str, Any]: _save_related = new_attrs.get("_save_related", True) @@ -94,7 +94,7 @@ def foreign_key(recipe: Union[Recipe, str]) -> RecipeForeignKey: class related(object): # FIXME def __init__(self, *args) -> None: - self.related: List[Recipe] = [] + self.related = [] # type: List[Recipe] for recipe in args: if isinstance(recipe, Recipe): self.related.append(recipe) From 370a0e45fc5dfaaabc3b2519351cb6bf30a1e6ff Mon Sep 17 00:00:00 2001 From: Tom Parker-Shemilt Date: Sat, 5 Sep 2020 13:51:05 +0100 Subject: [PATCH 07/20] Remove ModelBase usage --- model_bakery/baker.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/model_bakery/baker.py b/model_bakery/baker.py index fca55ec6..161f294a 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -12,7 +12,6 @@ ForeignKey, ManyToManyField, Model, - ModelBase, OneToOneField, ) from django.db.models.fields.proxy import OrderWrt @@ -250,7 +249,7 @@ def __init__( self.rel_attrs = {} # type: Dict[str, Any] self.rel_fields = [] # type: List[str] - if isinstance(_model, ModelBase): + if isinstance(_model, Model): self.model = _model else: self.model = self.finder.get_model(_model) From f4d6c26887b65dc5b503760e1c1770cf319e4155 Mon Sep 17 00:00:00 2001 From: Tom Parker-Shemilt Date: Sat, 5 Sep 2020 14:06:28 +0100 Subject: [PATCH 08/20] Explicitly run mypy in tox --- tox.ini | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tox.ini b/tox.ini index 1fa5905d..24a58882 100644 --- a/tox.ini +++ b/tox.ini @@ -6,13 +6,14 @@ envlist = isort pydocstyle black + mypy [gh-actions] python = 3.5: py35 3.6: py36 3.7: py37 - 3.8: py38,flake8,isort,pydocstyle,black + 3.8: py38,flake8,isort,pydocstyle,black,mypy [testenv] setenv = @@ -49,7 +50,7 @@ deps=black basepython=python3 commands=black . --check -[testenv:type] +[testenv:mypy] deps=mypy basepython=python3 commands=python -m mypy . From d54890a1cf00da0712f49de74f3e63a135e40016 Mon Sep 17 00:00:00 2001 From: Tom Parker-Shemilt Date: Sat, 5 Sep 2020 14:07:06 +0100 Subject: [PATCH 09/20] Python 3.5 needs typing_extensions --- model_bakery/random_gen.py | 14 ++------------ requirements.txt | 1 + tox.ini | 1 + 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/model_bakery/random_gen.py b/model_bakery/random_gen.py index 4b39b9e8..e60fbe46 100644 --- a/model_bakery/random_gen.py +++ b/model_bakery/random_gen.py @@ -15,22 +15,12 @@ from decimal import Decimal from os.path import abspath, dirname, join from random import choice, randint, random, uniform -from typing import ( - Any, - Callable, - List, - Optional, - Protocol, - Tuple, - TypeVar, - Union, - cast, - runtime_checkable, -) +from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union, cast from uuid import UUID from django.core.files.base import ContentFile from django.db.models import Field, Model +from typing_extensions import Protocol, runtime_checkable from .timezone import now diff --git a/requirements.txt b/requirements.txt index f9105b95..d57228a2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ django>=1.11.0<3.2 +typing-extensions==3.7.4.3 diff --git a/tox.ini b/tox.ini index 24a58882..bca3ec96 100644 --- a/tox.ini +++ b/tox.ini @@ -25,6 +25,7 @@ deps = pillow pytest pytest-django + typing-extensions django111: Django>=1.11,<1.12 django20: Django==2.0 django21: Django==2.1 From 65194445ed1ef3028b03c45e52c70547db9444c1 Mon Sep 17 00:00:00 2001 From: Tom Parker-Shemilt Date: Sat, 5 Sep 2020 14:17:26 +0100 Subject: [PATCH 10/20] Correct ModelBase usage --- model_bakery/baker.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/model_bakery/baker.py b/model_bakery/baker.py index 161f294a..f60da8b9 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -14,6 +14,7 @@ Model, OneToOneField, ) +from django.db.models.base import ModelBase from django.db.models.fields.proxy import OrderWrt from django.db.models.fields.related import ( ReverseManyToOneDescriptor as ForeignRelatedObjectsDescriptor, @@ -239,7 +240,10 @@ def create( return baker_class(_model, make_m2m, create_files) def __init__( - self, _model: str, make_m2m: bool = False, create_files=False # type: bool + self, + _model: Union[str, Type[ModelBase]], + make_m2m: bool = False, + create_files: bool = False, ) -> None: self.make_m2m = make_m2m self.create_files = create_files @@ -249,7 +253,7 @@ def __init__( self.rel_attrs = {} # type: Dict[str, Any] self.rel_fields = [] # type: List[str] - if isinstance(_model, Model): + if isinstance(_model, ModelBase): self.model = _model else: self.model = self.finder.get_model(_model) From a729326e81fad4c1845dc64a27b5a3c8b3e22897 Mon Sep 17 00:00:00 2001 From: Tom Parker-Shemilt Date: Sat, 5 Sep 2020 14:22:03 +0100 Subject: [PATCH 11/20] Cope with inference from union types --- model_bakery/baker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/model_bakery/baker.py b/model_bakery/baker.py index f60da8b9..30c05807 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -253,10 +253,10 @@ def __init__( self.rel_attrs = {} # type: Dict[str, Any] self.rel_fields = [] # type: List[str] - if isinstance(_model, ModelBase): - self.model = _model - else: + if isinstance(_model, str): self.model = self.finder.get_model(_model) + else: + self.model = _model self.init_type_mapping() From 71fe58731e8c12bb4c2e5ace455e6acfd4bc81ad Mon Sep 17 00:00:00 2001 From: Tom Parker-Shemilt Date: Sat, 5 Sep 2020 14:38:23 +0100 Subject: [PATCH 12/20] Deal with new action_generator work in tests --- model_bakery/baker.py | 12 +++++++----- tests/test_extending_bakery.py | 4 +++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/model_bakery/baker.py b/model_bakery/baker.py index 30c05807..b0d97be6 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -233,7 +233,7 @@ class Baker(object): @classmethod def create( - cls, _model: str, make_m2m: bool = False, create_files=False # type: bool + cls, _model: str, make_m2m: bool = False, create_files: bool = False ) -> "Baker": """Create the baker class defined by the `BAKER_CUSTOM_CLASS` setting.""" baker_class = _custom_baker_class() or cls @@ -270,7 +270,7 @@ def init_type_mapping(self) -> None: def make( self, - _save_kwargs=None, # type: Optional[Dict[str, Any]] + _save_kwargs: Optional[Dict[str, Any]] = None, _refresh_after_create: bool = False, _from_manager=None, **attrs: Any @@ -497,9 +497,7 @@ def _remote_field( ) -> Union[OneToOneRel, ManyToOneRel]: return field.remote_field - def generate_value( - self, field: Field, commit=True # type: bool - ) -> Any: + def generate_value(self, field: Field, commit: bool = True) -> Any: """Call the associated generator with a field passing all required args. Generator Resolution Precedence Order: @@ -539,7 +537,11 @@ def generate_value( generator_attrs.update(filter_rel_attrs(field.name, **self.rel_attrs)) if not commit: + old_generator = generator generator = getattr(generator, "prepare", generator) + if generator is None: + generator = old_generator + return generator(**generator_attrs) diff --git a/tests/test_extending_bakery.py b/tests/test_extending_bakery.py index a266f201..e9999fb6 100644 --- a/tests/test_extending_bakery.py +++ b/tests/test_extending_bakery.py @@ -2,14 +2,16 @@ from model_bakery import baker from model_bakery.exceptions import CustomBakerNotFound, InvalidCustomBaker -from model_bakery.random_gen import gen_from_list +from model_bakery.random_gen import action_generator, gen_from_list from tests.generic.models import Person +@action_generator def gen_opposite(default): return not default +@action_generator def gen_age(): # forever young return 18 From f48ade21b6cc9333d20f751e5dd4af3e48556809 Mon Sep 17 00:00:00 2001 From: Tom Parker-Shemilt Date: Sat, 5 Sep 2020 14:41:36 +0100 Subject: [PATCH 13/20] Don't use test class in main code! --- model_bakery/baker.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/model_bakery/baker.py b/model_bakery/baker.py index b0d97be6..4956c57e 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -21,8 +21,6 @@ ) from django.db.models.fields.reverse_related import ManyToOneRel, OneToOneRel -from tests.generic.fields import CustomForeignKey - from . import generators, random_gen from .exceptions import ( AmbiguousModelName, @@ -493,7 +491,7 @@ def _handle_m2m(self, instance: Model): make(through_model, **base_kwargs) def _remote_field( - self, field: Union[CustomForeignKey, ForeignKey, OneToOneField] + self, field: Union[ForeignKey, OneToOneField] ) -> Union[OneToOneRel, ManyToOneRel]: return field.remote_field From 1b7a9d86066a64d82d7bf13c2427109ee2449207 Mon Sep 17 00:00:00 2001 From: Tom Parker-Shemilt Date: Sat, 5 Sep 2020 14:45:55 +0100 Subject: [PATCH 14/20] _skip_field can use Field --- model_bakery/baker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model_bakery/baker.py b/model_bakery/baker.py index 4956c57e..a812e083 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -418,7 +418,7 @@ def is_rel_field(x: str): x.split("__")[0] for x in self.rel_attrs.keys() if is_rel_field(x) ] - def _skip_field(self, field: Any) -> bool: + def _skip_field(self, field: Field) -> bool: from django.contrib.contenttypes.fields import GenericRelation # check for fill optional argument From 519eeb0486f1f76cadb05d40cc1d6ccb27f8c093 Mon Sep 17 00:00:00 2001 From: Tom Parker-Shemilt Date: Sat, 5 Sep 2020 14:48:04 +0100 Subject: [PATCH 15/20] Loosen typing-extensions requirement --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index d57228a2..77bad2c2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ django>=1.11.0<3.2 -typing-extensions==3.7.4.3 +typing-extensions From 4bc7c107ab385dba4cb44ea9159efd7084616e04 Mon Sep 17 00:00:00 2001 From: Tom Parker-Shemilt Date: Sun, 13 Sep 2020 15:01:54 +0100 Subject: [PATCH 16/20] Fix create_files attribute Co-authored-by: Bernardo Fontes --- model_bakery/baker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model_bakery/baker.py b/model_bakery/baker.py index a812e083..0edaa689 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -53,7 +53,7 @@ def make( make_m2m: bool = False, _save_kwargs: Optional[Dict] = None, _refresh_after_create: bool = False, - _create_files=False, + _create_files: bool = False, **attrs: Any ): """Create a persisted instance from a given model its associated models. From 529df43e5ae0b761cb42bf76c4de284d380cf73f Mon Sep 17 00:00:00 2001 From: Tom Parker-Shemilt Date: Sun, 13 Sep 2020 15:02:02 +0100 Subject: [PATCH 17/20] Fix unique/ambiguous_models types --- model_bakery/baker.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/model_bakery/baker.py b/model_bakery/baker.py index 0edaa689..64367309 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -117,9 +117,8 @@ def prepare_recipe(baker_recipe_name, _quantity=None, _save_related=False, **new class ModelFinder(object): """Encapsulates all the logic for finding a model to Baker.""" - def __init__(self) -> None: - self._unique_models = None # type: Optional[Dict[str, Type[Model]]] - self._ambiguous_models = None # type: Optional[List[str]] + _unique_models = None # type: Optional[Dict[str, Type[Model]]] + _ambiguous_models = None # type: Optional[List[str]] def get_model(self, name: str) -> Type[Model]: """Get a model. From c8e137ea97d0a88137fbb1d7c89decc81d026934 Mon Sep 17 00:00:00 2001 From: Tom Parker-Shemilt Date: Sun, 13 Sep 2020 15:09:51 +0100 Subject: [PATCH 18/20] Rip out action_generator --- model_bakery/baker.py | 8 ++---- model_bakery/random_gen.py | 48 ++++++---------------------------- tests/test_extending_bakery.py | 4 +-- 3 files changed, 11 insertions(+), 49 deletions(-) diff --git a/model_bakery/baker.py b/model_bakery/baker.py index 64367309..d84bebf5 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -30,7 +30,6 @@ ModelNotFound, RecipeIteratorEmpty, ) -from .random_gen import ActionGenerator from .utils import seq # NoQA: enable seq to be imported from baker from .utils import import_from_str @@ -534,10 +533,7 @@ def generate_value(self, field: Field, commit: bool = True) -> Any: generator_attrs.update(filter_rel_attrs(field.name, **self.rel_attrs)) if not commit: - old_generator = generator generator = getattr(generator, "prepare", generator) - if generator is None: - generator = old_generator return generator(**generator_attrs) @@ -553,8 +549,8 @@ def get_required_values( """ # FIXME: avoid abbreviations rt = {} # type: Dict[str, Any] - if isinstance(generator, ActionGenerator): - for item in generator.required: + if hasattr(generator, "required"): + for item in generator.required: # type: ignore[attr-defined] if callable(item): # baker can deal with the nasty hacking too! key, value = item(field) diff --git a/model_bakery/random_gen.py b/model_bakery/random_gen.py index e60fbe46..c7788091 100644 --- a/model_bakery/random_gen.py +++ b/model_bakery/random_gen.py @@ -15,12 +15,11 @@ from decimal import Decimal from os.path import abspath, dirname, join from random import choice, randint, random, uniform -from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union, cast +from typing import Any, Callable, List, Optional, Tuple, Union from uuid import UUID from django.core.files.base import ContentFile from django.db.models import Field, Model -from typing_extensions import Protocol, runtime_checkable from .timezone import now @@ -30,31 +29,6 @@ MAX_INT = 100000000000 -# Hack to make mypy happy with attributes on functions -# See https://github.com/python/mypy/issues/2087#issuecomment-587741762 -F = TypeVar("F", bound=Callable[..., object]) - - -@runtime_checkable -class ActionGenerator(Protocol[F]): - required = ( - None - ) # type: Optional[List[Union[str, Callable[[Field], Tuple[str, Optional[Model]]]]]] - prepare = None # type: Optional[Callable[..., Union[Model, List[Model]]]] - __call__ = None # type: F - - -def action_generator(action: F) -> ActionGenerator[F]: - action_generator = cast(ActionGenerator[F], action) - # Make sure the cast isn't a lie. - action_generator.required = None - action_generator.prepare = None - return action_generator - - -# End mypy hack - - def get_content_file(content: bytes, name: str) -> ContentFile: return ContentFile(content, name=name) @@ -109,7 +83,6 @@ def gen_float() -> float: return random() * gen_integer() -@action_generator def gen_decimal(max_digits: int, decimal_places: int) -> Decimal: def num_as_str(x: int): return "".join([str(randint(0, 9)) for _ in range(x)]) @@ -122,7 +95,7 @@ def num_as_str(x: int): return Decimal(num_as_str(max_digits)) -gen_decimal.required = ["max_digits", "decimal_places"] +gen_decimal.required = ["max_digits", "decimal_places"] # type: ignore[attr-defined] def gen_date() -> date: @@ -137,21 +110,19 @@ def gen_time() -> time: return now().time() -@action_generator def gen_string(max_length: int) -> str: return str("".join(choice(string.ascii_letters) for _ in range(max_length))) -gen_string.required = ["max_length"] +gen_string.required = ["max_length"] # type: ignore[attr-defined] -@action_generator def gen_slug(max_length: int) -> str: valid_chars = string.ascii_letters + string.digits + "_-" return str("".join(choice(valid_chars) for _ in range(max_length))) -gen_slug.required = ["max_length"] +gen_slug.required = ["max_length"] # type: ignore[attr-defined] def gen_text() -> str: @@ -187,7 +158,6 @@ def gen_ipv46() -> str: return ip_gen() -@action_generator def gen_ip(protocol: str, default_validators: List[Callable]) -> str: from django.core.exceptions import ValidationError @@ -217,7 +187,7 @@ def gen_ip(protocol: str, default_validators: List[Callable]) -> str: return generator() -gen_ip.required = ["protocol", "default_validators"] +gen_ip.required = ["protocol", "default_validators"] # type: ignore[attr-defined] def gen_byte_string(max_length: int = 16) -> bytes: @@ -273,25 +243,23 @@ def _prepare_related(model: str, **attrs: Any) -> Union[Model, List[Model]]: return prepare(model, **attrs) -@action_generator def gen_related(model, **attrs): from .baker import make return make(model, **attrs) -gen_related.required = [_fk_model] -gen_related.prepare = _prepare_related +gen_related.required = [_fk_model] # type: ignore[attr-defined] +gen_related.prepare = _prepare_related # type: ignore[attr-defined] -@action_generator def gen_m2m(model, **attrs): from .baker import MAX_MANY_QUANTITY, make return make(model, _quantity=MAX_MANY_QUANTITY, **attrs) -gen_m2m.required = [_fk_model] +gen_m2m.required = [_fk_model] # type: ignore[attr-defined] # GIS generators diff --git a/tests/test_extending_bakery.py b/tests/test_extending_bakery.py index e9999fb6..a266f201 100644 --- a/tests/test_extending_bakery.py +++ b/tests/test_extending_bakery.py @@ -2,16 +2,14 @@ from model_bakery import baker from model_bakery.exceptions import CustomBakerNotFound, InvalidCustomBaker -from model_bakery.random_gen import action_generator, gen_from_list +from model_bakery.random_gen import gen_from_list from tests.generic.models import Person -@action_generator def gen_opposite(default): return not default -@action_generator def gen_age(): # forever young return 18 From e6ce7c6aa83cdfea5f64dc8e14fc220a030eb8e3 Mon Sep 17 00:00:00 2001 From: Tom Parker-Shemilt Date: Sun, 13 Sep 2020 15:12:57 +0100 Subject: [PATCH 19/20] Don't need typing-extensions any more --- requirements.txt | 1 - tox.ini | 1 - 2 files changed, 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 77bad2c2..f9105b95 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1 @@ django>=1.11.0<3.2 -typing-extensions diff --git a/tox.ini b/tox.ini index bca3ec96..24a58882 100644 --- a/tox.ini +++ b/tox.ini @@ -25,7 +25,6 @@ deps = pillow pytest pytest-django - typing-extensions django111: Django>=1.11,<1.12 django20: Django==2.0 django21: Django==2.1 From 84c61dc4b62536b755c5be0a2054fe2ab7ac486a Mon Sep 17 00:00:00 2001 From: Tom Parker-Shemilt Date: Mon, 28 Sep 2020 20:16:57 +0100 Subject: [PATCH 20/20] Added changelog entry for type annotations --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1178b6ee..31e02afe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,8 +8,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## [Unreleased](https://github.com/model-bakers/model_bakery/tree/master) ### Added -- Support to django 3.1 `JSONField`[PR #85](https://github.com/model-bakers/model_bakery/pull/85) and [PR #106](https://github.com/model-bakers/model_bakery/pull/106) +- Support to django 3.1 `JSONField` [PR #85](https://github.com/model-bakers/model_bakery/pull/85) and [PR #106](https://github.com/model-bakers/model_bakery/pull/106) - [dev] Changelog reminder (GitHub action) +- Added type annotations [PR #100](https://github.com/model-bakers/model_bakery/pull/100) ### Changed - [dev] CI switched to GitHub Actions