diff --git a/CHANGELOG.md b/CHANGELOG.md index 1178b6ee..31e02afe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,8 +8,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## [Unreleased](https://github.com/model-bakers/model_bakery/tree/master) ### Added -- Support to django 3.1 `JSONField`[PR #85](https://github.com/model-bakers/model_bakery/pull/85) and [PR #106](https://github.com/model-bakers/model_bakery/pull/106) +- Support to django 3.1 `JSONField` [PR #85](https://github.com/model-bakers/model_bakery/pull/85) and [PR #106](https://github.com/model-bakers/model_bakery/pull/106) - [dev] Changelog reminder (GitHub action) +- Added type annotations [PR #100](https://github.com/model-bakers/model_bakery/pull/100) ### Changed - [dev] CI switched to GitHub Actions diff --git a/model_bakery/baker.py b/model_bakery/baker.py index 515d3b70..d84bebf5 100644 --- a/model_bakery/baker.py +++ b/model_bakery/baker.py @@ -1,4 +1,5 @@ from os.path import dirname, join +from typing import Any, Callable, Dict, Iterator, List, Optional, Type, Union, cast from django.apps import apps from django.conf import settings @@ -10,6 +11,7 @@ FileField, ForeignKey, ManyToManyField, + Model, OneToOneField, ) from django.db.models.base import ModelBase @@ -17,6 +19,7 @@ from django.db.models.fields.related import ( ReverseManyToOneDescriptor as ForeignRelatedObjectsDescriptor, ) +from django.db.models.fields.reverse_related import ManyToOneRel, OneToOneRel from . import generators, random_gen from .exceptions import ( @@ -39,18 +42,18 @@ MAX_MANY_QUANTITY = 5 -def _valid_quantity(quantity): +def _valid_quantity(quantity: Optional[Union[str, int]]) -> bool: return quantity is not None and (not isinstance(quantity, int) or quantity < 1) def make( - _model, - _quantity=None, - make_m2m=False, - _save_kwargs=None, - _refresh_after_create=False, - _create_files=False, - **attrs + _model: str, + _quantity: Optional[int] = None, + make_m2m: bool = False, + _save_kwargs: Optional[Dict] = None, + _refresh_after_create: bool = False, + _create_files: bool = False, + **attrs: Any ): """Create a persisted instance from a given model its associated models. @@ -76,7 +79,7 @@ def make( ) -def prepare(_model, _quantity=None, _save_related=False, **attrs): +def prepare(_model: str, _quantity=None, _save_related=False, **attrs) -> Model: """Create but do not persist an instance from a given model. It fill the fields with random values or you can specify which @@ -95,7 +98,7 @@ def prepare(_model, _quantity=None, _save_related=False, **attrs): return baker.prepare(_save_related=_save_related, **attrs) -def _recipe(name): +def _recipe(name: str) -> Any: app, recipe_name = name.rsplit(".", 1) return import_from_str(".".join((app, "baker_recipes", recipe_name))) @@ -113,10 +116,10 @@ def prepare_recipe(baker_recipe_name, _quantity=None, _save_related=False, **new class ModelFinder(object): """Encapsulates all the logic for finding a model to Baker.""" - _unique_models = None - _ambiguous_models = None + _unique_models = None # type: Optional[Dict[str, Type[Model]]] + _ambiguous_models = None # type: Optional[List[str]] - def get_model(self, name): + def get_model(self, name: str) -> Type[Model]: """Get a model. Args: @@ -140,7 +143,7 @@ def get_model(self, name): return model - def get_model_by_name(self, name): + def get_model_by_name(self, name: str) -> Optional[Type[Model]]: """Get a model by name. If a model with that name exists in more than one app, raises @@ -148,18 +151,18 @@ def get_model_by_name(self, name): """ name = name.lower() - if self._unique_models is None: + if self._unique_models is None or self._ambiguous_models is None: self._populate() - if name in self._ambiguous_models: + if name in cast(List, self._ambiguous_models): raise AmbiguousModelName( "%s is a model in more than one app. " 'Use the form "app.model".' % name.title() ) - return self._unique_models.get(name) + return cast(Dict, self._unique_models).get(name) - def _populate(self): + def _populate(self) -> None: """Cache models for faster self._get_model.""" unique_models = {} ambiguous_models = [] @@ -180,14 +183,14 @@ def _populate(self): self._unique_models = unique_models -def is_iterator(value): +def is_iterator(value: Any) -> bool: if not hasattr(value, "__iter__"): return False return hasattr(value, "__next__") -def _custom_baker_class(): +def _custom_baker_class() -> Optional[Type]: """Return the specified custom baker class. Returns: @@ -217,36 +220,43 @@ def _custom_baker_class(): class Baker(object): - attr_mapping = {} - type_mapping = None + attr_mapping = {} # type: Dict[str, Any] + type_mapping = {} # type: Dict # Note: we're using one finder for all Baker instances to avoid # rebuilding the model cache for every make_* or prepare_* call. finder = ModelFinder() @classmethod - def create(cls, _model, make_m2m=False, create_files=False): + def create( + cls, _model: str, make_m2m: bool = False, create_files: bool = False + ) -> "Baker": """Create the baker class defined by the `BAKER_CUSTOM_CLASS` setting.""" baker_class = _custom_baker_class() or cls return baker_class(_model, make_m2m, create_files) - def __init__(self, _model, make_m2m=False, create_files=False): + def __init__( + self, + _model: Union[str, Type[ModelBase]], + make_m2m: bool = False, + create_files: bool = False, + ) -> None: self.make_m2m = make_m2m self.create_files = create_files - self.m2m_dict = {} - self.iterator_attrs = {} - self.model_attrs = {} - self.rel_attrs = {} - self.rel_fields = [] + self.m2m_dict = {} # type: Dict[str, List] + self.iterator_attrs = {} # type: Dict[str, Iterator] + self.model_attrs = {} # type: Dict[str, Any] + self.rel_attrs = {} # type: Dict[str, Any] + self.rel_fields = [] # type: List[str] - if isinstance(_model, ModelBase): - self.model = _model - else: + if isinstance(_model, str): self.model = self.finder.get_model(_model) + else: + self.model = _model self.init_type_mapping() - def init_type_mapping(self): + def init_type_mapping(self) -> None: self.type_mapping = generators.get_type_mapping() generators_from_settings = getattr(settings, "BAKER_CUSTOM_FIELDS_GEN", {}) for k, v in generators_from_settings.items(): @@ -256,10 +266,10 @@ def init_type_mapping(self): def make( self, - _save_kwargs=None, - _refresh_after_create=False, + _save_kwargs: Optional[Dict[str, Any]] = None, + _refresh_after_create: bool = False, _from_manager=None, - **attrs + **attrs: Any ): """Create and persist an instance of the model associated with Baker instance.""" params = { @@ -272,14 +282,16 @@ def make( params.update(attrs) return self._make(**params) - def prepare(self, _save_related=False, **attrs): + def prepare(self, _save_related=False, **attrs: Any) -> Model: """Create, but do not persist, an instance of the associated model.""" return self._make(commit=False, commit_related=_save_related, **attrs) - def get_fields(self): + def get_fields(self) -> Any: return self.model._meta.fields + self.model._meta.many_to_many - def get_related(self): + def get_related( + self, + ) -> List[Union[ManyToOneRel, OneToOneRel]]: return [r for r in self.model._meta.related_objects if not r.many_to_many] def _make( @@ -289,8 +301,8 @@ def _make( _save_kwargs=None, _refresh_after_create=False, _from_manager=None, - **attrs - ): + **attrs: Any + ) -> Model: _save_kwargs = _save_kwargs or {} self._clean_attrs(attrs) @@ -336,21 +348,23 @@ def _make( return instance - def m2m_value(self, field): + def m2m_value(self, field: ManyToManyField) -> List[Any]: if field.name in self.rel_fields: return self.generate_value(field) if not self.make_m2m or field.null and not field.fill_optional: return [] return self.generate_value(field) - def instance(self, attrs, _commit, _save_kwargs, _from_manager): + def instance( + self, attrs: Dict[str, Any], _commit, _save_kwargs, _from_manager + ) -> Model: one_to_many_keys = {} for k in tuple(attrs.keys()): field = getattr(self.model, k, None) if isinstance(field, ForeignRelatedObjectsDescriptor): one_to_many_keys[k] = attrs.pop(k) - instance = self.model(**attrs) + instance = self.model(**attrs) # type: Model # m2m only works for persisted instances if _commit: instance.save(**_save_kwargs) @@ -367,19 +381,20 @@ def instance(self, attrs, _commit, _save_kwargs, _from_manager): return instance - def create_by_related_name(self, instance, related): + def create_by_related_name( + self, instance: Model, related: Union[ManyToOneRel, OneToOneRel] + ) -> None: rel_name = related.get_accessor_name() if rel_name not in self.rel_fields: return kwargs = filter_rel_attrs(rel_name, **self.rel_attrs) kwargs[related.field.name] = instance - kwargs["_model"] = related.field.model - make(**kwargs) + make(related.field.model, **kwargs) - def _clean_attrs(self, attrs): - def is_rel_field(x): + def _clean_attrs(self, attrs: Dict[str, Any]) -> None: + def is_rel_field(x: str): return "__" in x self.fill_in_optional = attrs.pop("_fill_optional", False) @@ -401,7 +416,7 @@ def is_rel_field(x): x.split("__")[0] for x in self.rel_attrs.keys() if is_rel_field(x) ] - def _skip_field(self, field): + def _skip_field(self, field: Field) -> bool: from django.contrib.contenttypes.fields import GenericRelation # check for fill optional argument @@ -444,7 +459,7 @@ def _skip_field(self, field): return False - def _handle_one_to_many(self, instance, attrs): + def _handle_one_to_many(self, instance: Model, attrs: Dict[str, Any]): for k, v in attrs.items(): manager = getattr(instance, k) @@ -454,7 +469,7 @@ def _handle_one_to_many(self, instance, attrs): # for many-to-many relationships the bulk keyword argument doesn't exist manager.set(v, clear=True) - def _handle_m2m(self, instance): + def _handle_m2m(self, instance: Model): for key, values in self.m2m_dict.items(): for value in values: if not value.pk: @@ -473,10 +488,12 @@ def _handle_m2m(self, instance): } make(through_model, **base_kwargs) - def _remote_field(self, field): + def _remote_field( + self, field: Union[ForeignKey, OneToOneField] + ) -> Union[OneToOneRel, ManyToOneRel]: return field.remote_field - def generate_value(self, field, commit=True): + def generate_value(self, field: Field, commit: bool = True) -> Any: """Call the associated generator with a field passing all required args. Generator Resolution Precedence Order: @@ -517,10 +534,13 @@ def generate_value(self, field, commit=True): if not commit: generator = getattr(generator, "prepare", generator) + return generator(**generator_attrs) -def get_required_values(generator, field): +def get_required_values( + generator: Callable, field: Field +) -> Dict[str, Union[bool, int, str, List[Callable]]]: """Get required values for a generator from the field. If required value is a function, calls it with field as argument. If @@ -528,9 +548,9 @@ def get_required_values(generator, field): and return. """ # FIXME: avoid abbreviations - rt = {} + rt = {} # type: Dict[str, Any] if hasattr(generator, "required"): - for item in generator.required: + for item in generator.required: # type: ignore[attr-defined] if callable(item): # baker can deal with the nasty hacking too! key, value = item(field) @@ -549,7 +569,7 @@ def get_required_values(generator, field): return rt -def filter_rel_attrs(field_name, **rel_attrs): +def filter_rel_attrs(field_name: str, **rel_attrs) -> Dict[str, Any]: clean_dict = {} for k, v in rel_attrs.items(): diff --git a/model_bakery/generators.py b/model_bakery/generators.py index 2dc9b204..8ecfca66 100644 --- a/model_bakery/generators.py +++ b/model_bakery/generators.py @@ -1,4 +1,5 @@ from decimal import Decimal +from typing import Any, Callable, Dict, Optional, Type, Union from django.db.backends.base.operations import BaseDatabaseOperations from django.db.models import ( @@ -97,7 +98,7 @@ DateTimeRangeField = None -def _make_integer_gen_by_range(field_type): +def _make_integer_gen_by_range(field_type: Any) -> Callable: min_int, max_int = BaseDatabaseOperations.integer_field_ranges[field_type.__name__] def gen_integer(): @@ -134,7 +135,7 @@ def gen_integer(): FileField: random_gen.gen_file_field, ImageField: random_gen.gen_image_field, DurationField: random_gen.gen_interval, -} +} # type: Dict[Type, Callable] if ArrayField: default_mapping[ArrayField] = random_gen.gen_array @@ -177,7 +178,7 @@ def gen_integer(): # Add GIS fields -def get_type_mapping(): +def get_type_mapping() -> Dict[Type, Callable]: from django.contrib.contenttypes.models import ContentType from .gis import default_gis_mapping @@ -192,9 +193,9 @@ def get_type_mapping(): user_mapping = {} -def add(field, func): +def add(field: str, func: Optional[Union[Callable, str]]) -> None: user_mapping[import_from_str(field)] = import_from_str(func) -def get(field): +def get(field: Any) -> Optional[Callable]: return user_mapping.get(field) diff --git a/model_bakery/random_gen.py b/model_bakery/random_gen.py index b062042e..c7788091 100644 --- a/model_bakery/random_gen.py +++ b/model_bakery/random_gen.py @@ -11,11 +11,17 @@ import string import warnings +from datetime import date, datetime, time, timedelta from decimal import Decimal from os.path import abspath, dirname, join from random import choice, randint, random, uniform +from typing import Any, Callable, List, Optional, Tuple, Union +from uuid import UUID -from model_bakery.timezone import now +from django.core.files.base import ContentFile +from django.db.models import Field, Model + +from .timezone import now MAX_LENGTH = 300 # Using sys.maxint here breaks a bunch of tests when running against a @@ -23,27 +29,25 @@ MAX_INT = 100000000000 -def get_content_file(content, name): - from django.core.files.base import ContentFile - +def get_content_file(content: bytes, name: str) -> ContentFile: return ContentFile(content, name=name) -def gen_file_field(): +def gen_file_field() -> ContentFile: name = "mock_file.txt" file_path = abspath(join(dirname(__file__), name)) with open(file_path, "rb") as f: return get_content_file(f.read(), name=name) -def gen_image_field(): +def gen_image_field() -> ContentFile: name = "mock_img.jpeg" file_path = abspath(join(dirname(__file__), name)) with open(file_path, "rb") as f: return get_content_file(f.read(), name=name) -def gen_from_list(a_list): +def gen_from_list(a_list: Union[List[str], range]) -> Callable: """Make sure all values of the field are generated from a list. Examples: @@ -60,7 +64,7 @@ def gen_from_list(a_list): # -- DEFAULT GENERATORS -- -def gen_from_choices(choices): +def gen_from_choices(choices: List) -> Callable: choice_list = [] for value, label in choices: if isinstance(label, (list, tuple)): @@ -71,16 +75,16 @@ def gen_from_choices(choices): return gen_from_list(choice_list) -def gen_integer(min_int=-MAX_INT, max_int=MAX_INT): +def gen_integer(min_int: int = -MAX_INT, max_int: int = MAX_INT) -> int: return randint(min_int, max_int) -def gen_float(): +def gen_float() -> float: return random() * gen_integer() -def gen_decimal(max_digits, decimal_places): - def num_as_str(x): +def gen_decimal(max_digits: int, decimal_places: int) -> Decimal: + def num_as_str(x: int): return "".join([str(randint(0, 9)) for _ in range(x)]) if decimal_places: @@ -91,41 +95,41 @@ def num_as_str(x): return Decimal(num_as_str(max_digits)) -gen_decimal.required = ["max_digits", "decimal_places"] +gen_decimal.required = ["max_digits", "decimal_places"] # type: ignore[attr-defined] -def gen_date(): +def gen_date() -> date: return now().date() -def gen_datetime(): +def gen_datetime() -> datetime: return now() -def gen_time(): +def gen_time() -> time: return now().time() -def gen_string(max_length): +def gen_string(max_length: int) -> str: return str("".join(choice(string.ascii_letters) for _ in range(max_length))) -gen_string.required = ["max_length"] +gen_string.required = ["max_length"] # type: ignore[attr-defined] -def gen_slug(max_length): +def gen_slug(max_length: int) -> str: valid_chars = string.ascii_letters + string.digits + "_-" return str("".join(choice(valid_chars) for _ in range(max_length))) -gen_slug.required = ["max_length"] +gen_slug.required = ["max_length"] # type: ignore[attr-defined] -def gen_text(): +def gen_text() -> str: return gen_string(MAX_LENGTH) -def gen_boolean(): +def gen_boolean() -> bool: return choice((True, False)) @@ -133,28 +137,28 @@ def gen_null_boolean(): return choice((True, False, None)) -def gen_url(): +def gen_url() -> str: return str("http://www.%s.com/" % gen_string(30)) -def gen_email(): +def gen_email() -> str: return "%s@example.com" % gen_string(10) -def gen_ipv6(): +def gen_ipv6() -> str: return ":".join(format(randint(1, 65535), "x") for _ in range(8)) -def gen_ipv4(): +def gen_ipv4() -> str: return ".".join(str(randint(1, 255)) for _ in range(4)) -def gen_ipv46(): +def gen_ipv46() -> str: ip_gen = choice([gen_ipv4, gen_ipv6]) return ip_gen() -def gen_ip(protocol, default_validators): +def gen_ip(protocol: str, default_validators: List[Callable]) -> str: from django.core.exceptions import ValidationError protocol = (protocol or "").lower() @@ -183,17 +187,15 @@ def gen_ip(protocol, default_validators): return generator() -gen_ip.required = ["protocol", "default_validators"] +gen_ip.required = ["protocol", "default_validators"] # type: ignore[attr-defined] -def gen_byte_string(max_length=16): +def gen_byte_string(max_length: int = 16) -> bytes: generator = (randint(0, 255) for x in range(max_length)) return bytes(generator) -def gen_interval(interval_key="milliseconds", offset=0): - from datetime import timedelta - +def gen_interval(interval_key: str = "milliseconds", offset: int = 0) -> timedelta: interval = gen_integer() + offset kwargs = {interval_key: interval} return timedelta(**kwargs) @@ -210,7 +212,7 @@ def gen_content_type(): return ContentType() -def gen_uuid(): +def gen_uuid() -> UUID: import uuid return uuid.uuid4() @@ -228,14 +230,14 @@ def gen_hstore(): return {} -def _fk_model(field): +def _fk_model(field: Field) -> Tuple[str, Optional[Model]]: try: return ("model", field.related_model) except AttributeError: return ("model", field.related.parent_model) -def _prepare_related(model, **attrs): +def _prepare_related(model: str, **attrs: Any) -> Union[Model, List[Model]]: from .baker import prepare return prepare(model, **attrs) @@ -247,8 +249,8 @@ def gen_related(model, **attrs): return make(model, **attrs) -gen_related.required = [_fk_model] -gen_related.prepare = _prepare_related +gen_related.required = [_fk_model] # type: ignore[attr-defined] +gen_related.prepare = _prepare_related # type: ignore[attr-defined] def gen_m2m(model, **attrs): @@ -257,7 +259,7 @@ def gen_m2m(model, **attrs): return make(model, _quantity=MAX_MANY_QUANTITY, **attrs) -gen_m2m.required = [_fk_model] +gen_m2m.required = [_fk_model] # type: ignore[attr-defined] # GIS generators @@ -327,7 +329,7 @@ def gen_geometry_collection(): ) -def gen_pg_numbers_range(number_cast=int): +def gen_pg_numbers_range(number_cast: Callable[[int], Any]) -> Callable: def gen_range(): from psycopg2._range import NumericRange diff --git a/model_bakery/recipe.py b/model_bakery/recipe.py index c36ec97a..d0c429df 100644 --- a/model_bakery/recipe.py +++ b/model_bakery/recipe.py @@ -1,5 +1,8 @@ import inspect import itertools +from typing import Any, Dict, List, Union, cast + +from django.db.models import Model from . import baker from .exceptions import RecipeNotFound @@ -9,13 +12,13 @@ class Recipe(object): - def __init__(self, _model, **attrs): + def __init__(self, _model: str, **attrs) -> None: self.attr_mapping = attrs self._model = _model # _iterator_backups will hold values of the form (backup_iterator, usable_iterator). - self._iterator_backups = {} + self._iterator_backups = {} # type: Dict[str, Any] - def _mapping(self, new_attrs): + def _mapping(self, new_attrs: Dict[str, Any]) -> Dict[str, Any]: _save_related = new_attrs.get("_save_related", True) rel_fields_attrs = dict((k, v) for k, v in new_attrs.items() if "__" in k) new_attrs = dict((k, v) for k, v in new_attrs.items() if "__" not in k) @@ -50,22 +53,22 @@ def _mapping(self, new_attrs): mapping.update(rel_fields_attrs) return mapping - def make(self, **attrs): + def make(self, **attrs: Any) -> Union[Model, List[Model]]: return baker.make(self._model, **self._mapping(attrs)) - def prepare(self, **attrs): + def prepare(self, **attrs: Any) -> Union[Model, List[Model]]: defaults = {"_save_related": False} defaults.update(attrs) return baker.prepare(self._model, **self._mapping(defaults)) - def extend(self, **attrs): + def extend(self, **attrs) -> "Recipe": attr_mapping = self.attr_mapping.copy() attr_mapping.update(attrs) return type(self)(self._model, **attr_mapping) class RecipeForeignKey(object): - def __init__(self, recipe): + def __init__(self, recipe: Union[str, Recipe]) -> None: if isinstance(recipe, Recipe): self.recipe = recipe elif isinstance(recipe, str): @@ -73,14 +76,14 @@ def __init__(self, recipe): caller_module = inspect.getmodule(frame[0]) recipe = getattr(caller_module, recipe) if recipe: - self.recipe = recipe + self.recipe = cast(Recipe, recipe) else: raise RecipeNotFound else: raise TypeError("Not a recipe") -def foreign_key(recipe): +def foreign_key(recipe: Union[Recipe, str]) -> RecipeForeignKey: """Return a `RecipeForeignKey`. Return the callable, so that the associated `_model` will not be created @@ -90,8 +93,8 @@ def foreign_key(recipe): class related(object): # FIXME - def __init__(self, *args): - self.related = [] + def __init__(self, *args) -> None: + self.related = [] # type: List[Recipe] for recipe in args: if isinstance(recipe, Recipe): self.related.append(recipe) @@ -106,6 +109,6 @@ def __init__(self, *args): else: raise TypeError("Not a recipe") - def make(self): + def make(self) -> List[Union[Model, List[Model]]]: """Persist objects to m2m relation.""" return [m.make() for m in self.related] diff --git a/model_bakery/timezone.py b/model_bakery/timezone.py index 94667e1a..9641050b 100644 --- a/model_bakery/timezone.py +++ b/model_bakery/timezone.py @@ -15,12 +15,12 @@ def now(): return datetime.now() -def smart_datetime(*args): +def smart_datetime(*args) -> datetime: value = datetime(*args) return tz_aware(value) -def tz_aware(d): +def tz_aware(d: datetime) -> datetime: value = d if settings.USE_TZ: value = d.replace(tzinfo=utc) diff --git a/model_bakery/utils.py b/model_bakery/utils.py index ee89393b..2288f8e7 100644 --- a/model_bakery/utils.py +++ b/model_bakery/utils.py @@ -2,11 +2,12 @@ import importlib import itertools import warnings +from typing import Any, Callable, Optional, Union from .timezone import tz_aware -def import_from_str(import_string): +def import_from_str(import_string: Optional[Union[Callable, str]]) -> Any: """Import an object defined as import if it is an string. If `import_string` follows the format `path.to.module.object_name`, diff --git a/requirements_dev.txt b/requirements_dev.txt index a87b0be1..5d9d3c59 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -3,6 +3,7 @@ black==20.8b1 flake8==3.8.3 isort==5.5.0 +mypy==0.782 pillow==7.2.0 pip-tools==5.3.1 pre-commit==2.7.1 diff --git a/setup.cfg b/setup.cfg index 82f4cda8..acf9fc4a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,3 +20,7 @@ line_length=88 [pydocstyle] add_ignore = D1 match-dir = (?!test|docs|\.).* + +[mypy] +ignore_missing_imports=True +disallow_untyped_calls=True diff --git a/tests/generic/forms.py b/tests/generic/forms.py index e70f98da..d82cf324 100644 --- a/tests/generic/forms.py +++ b/tests/generic/forms.py @@ -1,4 +1,5 @@ from django.forms import ModelForm + from tests.generic.models import DummyGenericIPAddressFieldModel diff --git a/tests/test_extending_bakery.py b/tests/test_extending_bakery.py index ed5cd4a4..a266f201 100644 --- a/tests/test_extending_bakery.py +++ b/tests/test_extending_bakery.py @@ -1,4 +1,5 @@ import pytest + from model_bakery import baker from model_bakery.exceptions import CustomBakerNotFound, InvalidCustomBaker from model_bakery.random_gen import gen_from_list diff --git a/tests/test_utils.py b/tests/test_utils.py index b34295b8..3efeb763 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,5 @@ import pytest + from model_bakery.utils import import_from_str from tests.generic.models import User diff --git a/tox.ini b/tox.ini index fc4b3135..24a58882 100644 --- a/tox.ini +++ b/tox.ini @@ -6,13 +6,14 @@ envlist = isort pydocstyle black + mypy [gh-actions] python = 3.5: py35 3.6: py36 3.7: py37 - 3.8: py38,flake8,isort,pydocstyle,black + 3.8: py38,flake8,isort,pydocstyle,black,mypy [testenv] setenv = @@ -48,3 +49,8 @@ commands=isort model_bakery --check-only {posargs} deps=black basepython=python3 commands=black . --check + +[testenv:mypy] +deps=mypy +basepython=python3 +commands=python -m mypy .