diff --git a/.gitignore b/.gitignore index ffb8e8a83..0e255dff4 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ out/ .mypy_cache/ django-sources build/ -dist/ \ No newline at end of file +dist/ +pip-wheel-metadata/ \ No newline at end of file diff --git a/django-stubs/conf/__init__.pyi b/django-stubs/conf/__init__.pyi index c2bdc32dc..8f7bfa0e8 100644 --- a/django-stubs/conf/__init__.pyi +++ b/django-stubs/conf/__init__.pyi @@ -2,10 +2,14 @@ from typing import Any from django.utils.functional import LazyObject +# explicit dependency on standard settings to make it loaded +from . import global_settings + ENVIRONMENT_VARIABLE: str = ... # required for plugin to be able to distinguish this specific instance of LazySettings from others -class _DjangoConfLazyObject(LazyObject): ... +class _DjangoConfLazyObject(LazyObject): + def __getattr__(self, item: Any) -> Any: ... class LazySettings(_DjangoConfLazyObject): configured: bool diff --git a/django-stubs/conf/global_settings.pyi b/django-stubs/conf/global_settings.pyi index 292ba2134..5869a79a1 100644 --- a/django-stubs/conf/global_settings.pyi +++ b/django-stubs/conf/global_settings.pyi @@ -5,14 +5,11 @@ by the DJANGO_SETTINGS_MODULE environment variable. # This is defined here as a do-nothing function because we can't import # django.utils.translation -- that module depends on the settings. -from typing import Any, Dict, List, Optional, Pattern, Tuple, Protocol, Union, Callable, TYPE_CHECKING, Sequence +from typing import Any, Dict, List, Optional, Pattern, Protocol, Sequence, Tuple, Union #################### # CORE # #################### -if TYPE_CHECKING: - from django.db.models.base import Model - DEBUG: bool = ... # Whether the framework should propagate raw exceptions rather than catching @@ -153,7 +150,7 @@ FORCE_SCRIPT_NAME = None # ] DISALLOWED_USER_AGENTS: List[Pattern] = ... -ABSOLUTE_URL_OVERRIDES: Dict[str, Callable[[Model], str]] = ... +ABSOLUTE_URL_OVERRIDES: Dict[str, Any] = ... # List of compiled regular expression objects representing URLs that need not # be reported by BrokenLinkEmailsMiddleware. Here are a few examples: diff --git a/mypy_django_plugin/helpers.py b/mypy_django_plugin/helpers.py index d4b06b833..abdc0861b 100644 --- a/mypy_django_plugin/helpers.py +++ b/mypy_django_plugin/helpers.py @@ -4,12 +4,13 @@ from mypy.mro import calculate_mro from mypy.nodes import ( - AssignmentStmt, ClassDef, Expression, ImportedName, Lvalue, MypyFile, NameExpr, SymbolNode, TypeInfo, - SymbolTable, SymbolTableNode, Block, GDEF, MDEF, Var) -from mypy.plugin import FunctionContext, MethodContext + GDEF, MDEF, AssignmentStmt, Block, CallExpr, ClassDef, Expression, ImportedName, Lvalue, MypyFile, NameExpr, + SymbolNode, SymbolTable, SymbolTableNode, TypeInfo, Var, +) +from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext from mypy.types import ( - AnyType, Instance, NoneTyp, Type, TypeOfAny, TypeVarType, UnionType, - TupleType, TypedDictType) + AnyType, Instance, NoneTyp, TupleType, Type, TypedDictType, TypeOfAny, TypeVarType, UnionType, +) if typing.TYPE_CHECKING: from mypy.checker import TypeChecker @@ -216,6 +217,7 @@ def extract_field_setter_type(tp: Instance) -> Optional[Type]: def extract_field_getter_type(tp: Type) -> Optional[Type]: + """ Extract return type of __get__ of subclass of Field""" if not isinstance(tp, Instance): return None if tp.type.has_base(FIELD_FULLNAME): @@ -226,13 +228,12 @@ def extract_field_getter_type(tp: Type) -> Optional[Type]: return None -def get_django_metadata(model: TypeInfo) -> Dict[str, typing.Any]: - return model.metadata.setdefault('django', {}) +def get_django_metadata(model_info: TypeInfo) -> Dict[str, typing.Any]: + return model_info.metadata.setdefault('django', {}) def get_related_field_primary_key_names(base_model: TypeInfo) -> typing.List[str]: - django_metadata = get_django_metadata(base_model) - return django_metadata.setdefault('related_field_primary_keys', []) + return get_django_metadata(base_model).setdefault('related_field_primary_keys', []) def get_fields_metadata(model: TypeInfo) -> Dict[str, typing.Any]: @@ -243,6 +244,10 @@ def get_lookups_metadata(model: TypeInfo) -> Dict[str, typing.Any]: return get_django_metadata(model).setdefault('lookups', {}) +def get_related_managers_metadata(model: TypeInfo) -> Dict[str, typing.Any]: + return get_django_metadata(model).setdefault('related_managers', {}) + + def extract_explicit_set_type_of_model_primary_key(model: TypeInfo) -> Optional[Type]: """ If field with primary_key=True is set on the model, extract its __set__ type. @@ -310,7 +315,7 @@ def is_field_nullable(model: TypeInfo, field_name: str) -> bool: return get_fields_metadata(model).get(field_name, {}).get('null', False) -def is_foreign_key(t: Type) -> bool: +def is_foreign_key_like(t: Type) -> bool: if not isinstance(t, Instance): return False return has_any_of_bases(t.type, (FOREIGN_KEY_FULLNAME, ONETOONE_FIELD_FULLNAME)) @@ -366,13 +371,14 @@ def make_named_tuple(api: 'TypeChecker', fields: 'OrderedDict[str, Type]', name: return TupleType(list(fields.values()), fallback=fallback) -def make_typeddict(api: 'TypeChecker', fields: 'OrderedDict[str, Type]', required_keys: typing.Set[str]) -> Type: +def make_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, Type]', + required_keys: typing.Set[str]) -> TypedDictType: object_type = api.named_generic_type('mypy_extensions._TypedDict', []) typed_dict_type = TypedDictType(fields, required_keys=required_keys, fallback=object_type) return typed_dict_type -def make_tuple(api: 'TypeChecker', fields: typing.List[Type]) -> Type: +def make_tuple(api: 'TypeChecker', fields: typing.List[Type]) -> TupleType: implicit_any = AnyType(TypeOfAny.special_form) fallback = api.named_generic_type('builtins.tuple', [implicit_any]) return TupleType(fields, fallback=fallback) @@ -386,3 +392,52 @@ def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is descriptor_type = make_optional(descriptor_type) return descriptor_type return AnyType(TypeOfAny.unannotated) + + +def iter_over_classdefs(module_file: MypyFile) -> typing.Iterator[ClassDef]: + for defn in module_file.defs: + if isinstance(defn, ClassDef): + yield defn + + +def iter_call_assignments(klass: ClassDef) -> typing.Iterator[typing.Tuple[Lvalue, CallExpr]]: + for lvalue, rvalue in iter_over_assignments(klass): + if isinstance(rvalue, CallExpr): + yield lvalue, rvalue + + +def get_related_manager_type_from_metadata(model_info: TypeInfo, related_manager_name: str, + api: CheckerPluginInterface) -> Optional[Instance]: + related_manager_metadata = get_related_managers_metadata(model_info) + if not related_manager_metadata: + return None + + if related_manager_name not in related_manager_metadata: + return None + + manager_class_name = related_manager_metadata[related_manager_name]['manager'] + of = related_manager_metadata[related_manager_name]['of'] + of_types = [] + for of_type_name in of: + if of_type_name == 'any': + of_types.append(AnyType(TypeOfAny.implementation_artifact)) + else: + try: + of_type = api.named_generic_type(of_type_name, []) + except AssertionError: + # Internal error: attempted lookup of unknown name + of_type = AnyType(TypeOfAny.implementation_artifact) + + of_types.append(of_type) + + return api.named_generic_type(manager_class_name, of_types) + + +def get_primary_key_field_name(model_info: TypeInfo) -> Optional[str]: + for base in model_info.mro: + fields = get_fields_metadata(base) + for field_name, field_props in fields.items(): + is_primary_key = field_props.get('primary_key', False) + if is_primary_key: + return field_name + return None diff --git a/mypy_django_plugin/lookups.py b/mypy_django_plugin/lookups.py index 48b48d131..a8eb54a10 100644 --- a/mypy_django_plugin/lookups.py +++ b/mypy_django_plugin/lookups.py @@ -1,9 +1,9 @@ -import dataclasses -from typing import Union, List +from typing import List, Union +import dataclasses from mypy.nodes import TypeInfo from mypy.plugin import CheckerPluginInterface -from mypy.types import Type, Instance +from mypy.types import Instance, Type from mypy_django_plugin import helpers @@ -57,20 +57,24 @@ def resolve_lookup(api: CheckerPluginInterface, model_type_info: TypeInfo, looku return nodes +def resolve_model_pk_lookup(api: CheckerPluginInterface, model_type_info: TypeInfo) -> LookupNode: + # Primary keys are special-cased + primary_key_type = helpers.extract_primary_key_type_for_get(model_type_info) + if primary_key_type: + return FieldNode(primary_key_type) + else: + # No PK, use the get type for AutoField as PK type. + autofield_info = api.lookup_typeinfo('django.db.models.fields.AutoField') + pk_type = helpers.get_private_descriptor_type(autofield_info, '_pyi_private_get_type', + is_nullable=False) + return FieldNode(pk_type) + + def resolve_model_lookup(api: CheckerPluginInterface, model_type_info: TypeInfo, lookup: str) -> LookupNode: """Resolve a lookup on the given model.""" if lookup == 'pk': - # Primary keys are special-cased - primary_key_type = helpers.extract_primary_key_type_for_get(model_type_info) - if primary_key_type: - return FieldNode(primary_key_type) - else: - # No PK, use the get type for AutoField as PK type. - autofield_info = api.lookup_typeinfo('django.db.models.fields.AutoField') - pk_type = helpers.get_private_descriptor_type(autofield_info, '_pyi_private_get_type', - is_nullable=False) - return FieldNode(pk_type) + return resolve_model_pk_lookup(api, model_type_info) field_name = get_actual_field_name_for_lookup_field(lookup, model_type_info) @@ -82,7 +86,7 @@ def resolve_model_lookup(api: CheckerPluginInterface, model_type_info: TypeInfo, if field_name.endswith('_id'): field_name_without_id = field_name.rstrip('_id') foreign_key_field = model_type_info.get(field_name_without_id) - if foreign_key_field is not None and helpers.is_foreign_key(foreign_key_field.type): + if foreign_key_field is not None and helpers.is_foreign_key_like(foreign_key_field.type): # Hack: If field ends with '_id' and there is a model field without the '_id' suffix, then use that field. field_node = foreign_key_field field_name = field_name_without_id @@ -92,10 +96,23 @@ def resolve_model_lookup(api: CheckerPluginInterface, model_type_info: TypeInfo, raise LookupException( f'When resolving lookup "{lookup}", could not determine type for {model_type_info.name()}.{field_name}') - if helpers.is_foreign_key(field_node_type): + if field_node_type.type.fullname() == 'builtins.object': + # could be related manager + related_manager_type = helpers.get_related_manager_type_from_metadata(model_type_info, field_name, api) + if related_manager_type: + model_arg = related_manager_type.args[0] + if not isinstance(model_arg, Instance): + raise LookupException( + f'When resolving lookup "{lookup}", could not determine type ' + f'for {model_type_info.name()}.{field_name}') + + return RelatedModelNode(typ=model_arg, is_nullable=False) + + if helpers.is_foreign_key_like(field_node_type): field_type = helpers.extract_field_getter_type(field_node_type) is_nullable = helpers.is_optional(field_type) if is_nullable: + # type is always non-optional field_type = helpers.make_required(field_type) if isinstance(field_type, Instance): @@ -104,24 +121,16 @@ def resolve_model_lookup(api: CheckerPluginInterface, model_type_info: TypeInfo, raise LookupException(f"Not an instance for field {field_type} lookup {lookup}") field_type = helpers.extract_field_getter_type(field_node_type) - if field_type: return FieldNode(typ=field_type) - else: - # Not a Field - if field_name == 'id': - # If no 'id' field was fouond, use an int - return FieldNode(api.named_generic_type('builtins.int', [])) - - related_manager_arg = None - if field_node_type.type.has_base(helpers.RELATED_MANAGER_CLASS_FULLNAME): - related_manager_arg = field_node_type.args[0] - - if related_manager_arg is not None: - # Reverse relation - return RelatedModelNode(typ=related_manager_arg, is_nullable=True) - raise LookupException( - f'When resolving lookup "{lookup}", could not determine type for {model_type_info.name()}.{field_name}') + + # Not a Field + if field_name == 'id': + # If no 'id' field was found, use an int + return FieldNode(api.named_generic_type('builtins.int', [])) + + raise LookupException( + f'When resolving lookup {lookup!r}, could not determine type for {model_type_info.name()}.{field_name}') def get_actual_field_name_for_lookup_field(lookup: str, model_type_info: TypeInfo) -> str: diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 00e5d8ef8..5e3c73937 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -1,29 +1,33 @@ import os from functools import partial -from typing import Callable, Dict, Optional, Union, cast +from typing import Callable, Dict, List, Optional, Tuple, cast -from mypy.nodes import MemberExpr, NameExpr, TypeInfo +from mypy.nodes import MypyFile, NameExpr, TypeInfo from mypy.options import Options from mypy.plugin import ( - AttributeContext, ClassDefContext, FunctionContext, MethodContext, Plugin, - AnalyzeTypeContext) -from mypy.types import ( - AnyType, CallableType, Instance, NoneTyp, Type, TypeOfAny, TypeType, UnionType, + AnalyzeTypeContext, AttributeContext, ClassDefContext, FunctionContext, MethodContext, Plugin, ) +from mypy.types import AnyType, Instance, Type, TypeOfAny -from mypy_django_plugin import helpers, monkeypatch +from mypy_django_plugin import helpers from mypy_django_plugin.config import Config from mypy_django_plugin.transformers import fields, init_create from mypy_django_plugin.transformers.forms import ( - make_meta_nested_class_inherit_from_any, + extract_proper_type_for_get_form, extract_proper_type_for_get_form_class, make_meta_nested_class_inherit_from_any, ) from mypy_django_plugin.transformers.migrations import ( - determine_model_cls_from_string_for_migrations, get_string_value_from_expr, + determine_model_cls_from_string_for_migrations, ) from mypy_django_plugin.transformers.models import process_model_class -from mypy_django_plugin.transformers.queryset import extract_proper_type_for_values_and_values_list +from mypy_django_plugin.transformers.queryset import ( + extract_proper_type_for_queryset_values, extract_proper_type_queryset_values_list, + set_first_generic_param_as_default_for_second, +) +from mypy_django_plugin.transformers.related import ( + determine_type_of_related_manager, extract_and_return_primary_key_of_bound_related_field_parameter, +) from mypy_django_plugin.transformers.settings import ( - AddSettingValuesToDjangoConfObject, get_settings_metadata, + get_type_of_setting, return_user_model_hook, ) @@ -35,20 +39,21 @@ def transform_model_class(ctx: ClassDefContext) -> None: pass else: if sym is not None and isinstance(sym.node, TypeInfo): - sym.node.metadata['django']['model_bases'][ctx.cls.fullname] = 1 + helpers.get_django_metadata(sym.node)['model_bases'][ctx.cls.fullname] = 1 + process_model_class(ctx) def transform_manager_class(ctx: ClassDefContext) -> None: sym = ctx.api.lookup_fully_qualified_or_none(helpers.MANAGER_CLASS_FULLNAME) if sym is not None and isinstance(sym.node, TypeInfo): - sym.node.metadata['django']['manager_bases'][ctx.cls.fullname] = 1 + helpers.get_django_metadata(sym.node)['manager_bases'][ctx.cls.fullname] = 1 def transform_form_class(ctx: ClassDefContext) -> None: sym = ctx.api.lookup_fully_qualified_or_none(helpers.BASEFORM_CLASS_FULLNAME) if sym is not None and isinstance(sym.node, TypeInfo): - sym.node.metadata['django']['baseform_bases'][ctx.cls.fullname] = 1 + helpers.get_django_metadata(sym.node)['baseform_bases'][ctx.cls.fullname] = 1 make_meta_nested_class_inherit_from_any(ctx) @@ -83,123 +88,31 @@ def determine_proper_manager_type(ctx: FunctionContext) -> Type: return ret -def set_first_generic_param_as_default_for_second(fullname: str, ctx: AnalyzeTypeContext) -> Type: - if not ctx.type.args: - try: - return ctx.api.named_type(fullname, [AnyType(TypeOfAny.explicit), - AnyType(TypeOfAny.explicit)]) - except KeyError: - # really should never happen - return AnyType(TypeOfAny.explicit) - - args = ctx.type.args - if len(args) == 1: - args = [args[0], args[0]] - - analyzed_args = [ctx.api.analyze_type(arg) for arg in args] - try: - return ctx.api.named_type(fullname, analyzed_args) - except KeyError: - # really should never happen - return AnyType(TypeOfAny.explicit) - - -def return_user_model_hook(ctx: FunctionContext) -> Type: - from mypy.checker import TypeChecker - - api = cast(TypeChecker, ctx.api) - setting_expr = helpers.get_setting_expr(api, 'AUTH_USER_MODEL') - if setting_expr is None: - return ctx.default_return_type - - model_path = get_string_value_from_expr(setting_expr) - if model_path is None: - return ctx.default_return_type - - app_label, _, model_class_name = model_path.rpartition('.') - if app_label is None: - return ctx.default_return_type - - model_fullname = helpers.get_model_fullname(app_label, model_class_name, - all_modules=api.modules) - if model_fullname is None: - api.fail(f'"{app_label}.{model_class_name}" model class is not imported so far. Try to import it ' - f'(under if TYPE_CHECKING) at the beginning of the current file', - context=ctx.context) - return ctx.default_return_type - - model_info = helpers.lookup_fully_qualified_generic(model_fullname, - all_modules=api.modules) - if model_info is None or not isinstance(model_info, TypeInfo): - return ctx.default_return_type - return TypeType(Instance(model_info, [])) - - -def _extract_referred_to_type_info(typ: Union[UnionType, Instance]) -> Optional[TypeInfo]: - if isinstance(typ, Instance): - return typ.type - else: - # should be Union[TYPE, None] - typ = helpers.make_required(typ) - if isinstance(typ, Instance): - return typ.type - return None - - -def extract_and_return_primary_key_of_bound_related_field_parameter(ctx: AttributeContext) -> Type: - if not isinstance(ctx.default_attr_type, Instance) or not (ctx.default_attr_type.type.fullname() == 'builtins.int'): - return ctx.default_attr_type - - if not isinstance(ctx.type, Instance) or not ctx.type.type.has_base(helpers.MODEL_CLASS_FULLNAME): - return ctx.default_attr_type - - field_name = ctx.context.name.split('_')[0] - sym = ctx.type.type.get(field_name) - if sym and isinstance(sym.type, Instance) and len(sym.type.args) > 0: - referred_to = sym.type.args[1] - if isinstance(referred_to, AnyType): - return AnyType(TypeOfAny.implementation_artifact) +def return_type_for_id_field(ctx: AttributeContext) -> Type: + if not isinstance(ctx.type, Instance): + return AnyType(TypeOfAny.from_error) - model_type = _extract_referred_to_type_info(referred_to) - if model_type is None: - return AnyType(TypeOfAny.implementation_artifact) - - primary_key_type = helpers.extract_primary_key_type_for_get(model_type) - if primary_key_type: - return primary_key_type - - is_nullable = helpers.is_field_nullable(ctx.type.type, field_name) - if is_nullable: - return helpers.make_optional(ctx.default_attr_type) - - return ctx.default_attr_type - - -def return_integer_type_for_id_for_non_defined_primary_key_in_models(ctx: AttributeContext) -> Type: - if isinstance(ctx.type, Instance) and ctx.type.type.has_base(helpers.MODEL_CLASS_FULLNAME): + model_info = ctx.type.type # type: TypeInfo + primary_key_field_name = helpers.get_primary_key_field_name(model_info) + if not primary_key_field_name: + # no field with primary_key=True, just return id as int return ctx.api.named_generic_type('builtins.int', []) - return ctx.default_attr_type + if primary_key_field_name != 'id': + # there's field with primary_key=True, but it's name is not 'id', fail + ctx.api.fail("Default primary key 'id' is not defined", ctx.context) + return AnyType(TypeOfAny.from_error) -class ExtractSettingType: - def __init__(self, module_fullname: str): - self.module_fullname = module_fullname - - def __call__(self, ctx: AttributeContext) -> Type: - from mypy.checker import TypeChecker - - api = cast(TypeChecker, ctx.api) - original_module = api.modules.get(self.module_fullname) - if original_module is None: - return ctx.default_attr_type + primary_key_sym = model_info.get(primary_key_field_name) + if primary_key_sym and isinstance(primary_key_sym.type, Instance): + pass - definition = ctx.context - if isinstance(definition, MemberExpr): - sym = original_module.names.get(definition.name) - if sym and sym.type: - return sym.type + # try to parse field type out of primary key field + field_type = helpers.extract_field_getter_type(primary_key_sym.type) + if field_type: + return field_type - return ctx.default_attr_type + return primary_key_sym.type def transform_form_view(ctx: ClassDefContext) -> None: @@ -208,80 +121,10 @@ def transform_form_view(ctx: ClassDefContext) -> None: helpers.get_django_metadata(ctx.cls.info)['form_class'] = form_class_value.fullname -def extract_proper_type_for_get_form_class(ctx: MethodContext) -> Type: - object_type = ctx.type - if not isinstance(object_type, Instance): - return ctx.default_return_type - - form_class_fullname = helpers.get_django_metadata(object_type.type).get('form_class', None) - if not form_class_fullname: - return ctx.default_return_type - - return TypeType(ctx.api.named_generic_type(form_class_fullname, [])) - - -def extract_proper_type_for_get_form(ctx: MethodContext) -> Type: - object_type = ctx.type - if not isinstance(object_type, Instance): - return ctx.default_return_type - - form_class_type = helpers.get_argument_type_by_name(ctx, 'form_class') - if form_class_type is None or isinstance(form_class_type, NoneTyp): - # extract from specified form_class in metadata - form_class_fullname = helpers.get_django_metadata(object_type.type).get('form_class', None) - if not form_class_fullname: - return ctx.default_return_type - - return ctx.api.named_generic_type(form_class_fullname, []) - - if isinstance(form_class_type, TypeType) and isinstance(form_class_type.item, Instance): - return form_class_type.item - - if isinstance(form_class_type, CallableType) and isinstance(form_class_type.ret_type, Instance): - return form_class_type.ret_type - - return ctx.default_return_type - - -def extract_proper_type_for_values_list(ctx: MethodContext) -> Type: - object_type = ctx.type - if not isinstance(object_type, Instance): - return ctx.default_return_type - - flat = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'flat')) - named = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'named')) - - ret = ctx.default_return_type - - any_type = AnyType(TypeOfAny.implementation_artifact) - if named and flat: - ctx.api.fail("'flat' and 'named' can't be used together.", ctx.context) - return ret - elif named: - # TODO: Fill in namedtuple fields/types - row_arg = ctx.api.named_generic_type('typing.NamedTuple', []) - elif flat: - # TODO: Figure out row_arg type dependent on the argument passed in - if len(ctx.args[0]) > 1: - ctx.api.fail("'flat' is not valid when values_list is called with more than one field.", ctx.context) - return ret - row_arg = any_type - else: - # TODO: Figure out tuple argument types dependent on the arguments passed in - row_arg = ctx.api.named_generic_type('builtins.tuple', [any_type]) - - first_arg = ret.args[0] if len(ret.args) > 0 else any_type - new_type_args = [first_arg, row_arg] - return helpers.reparametrize_instance(ret, new_type_args) - - class DjangoPlugin(Plugin): def __init__(self, options: Options) -> None: super().__init__(options) - monkeypatch.restore_original_load_graph() - monkeypatch.restore_original_dependencies_handling() - config_fpath = os.environ.get('MYPY_DJANGO_CONFIG', 'mypy_django.ini') if config_fpath and os.path.exists(config_fpath): self.config = Config.from_config_file(config_fpath) @@ -293,16 +136,6 @@ def __init__(self, options: Options) -> None: if 'DJANGO_SETTINGS_MODULE' in os.environ: self.django_settings_module = os.environ['DJANGO_SETTINGS_MODULE'] - settings_modules = ['django.conf.global_settings'] - if self.django_settings_module: - settings_modules.append(self.django_settings_module) - - auto_imports = ['mypy_extensions'] - auto_imports.extend(settings_modules) - - monkeypatch.add_modules_as_a_source_seed_files(auto_imports) - monkeypatch.inject_modules_as_dependencies_for_django_conf_settings(settings_modules) - def _get_current_model_bases(self) -> Dict[str, int]: model_sym = self.lookup_fully_qualified(helpers.MODEL_CLASS_FULLNAME) if model_sym is not None and isinstance(model_sym.node, TypeInfo): @@ -337,40 +170,70 @@ def _get_current_queryset_bases(self) -> Dict[str, int]: else: return {} + def _get_settings_modules_in_order_of_priority(self) -> List[str]: + settings_modules = [] + if self.django_settings_module: + settings_modules.append(self.django_settings_module) + + settings_modules.append('django.conf.global_settings') + return settings_modules + + def _get_typeinfo_or_none(self, class_name: str) -> Optional[TypeInfo]: + sym = self.lookup_fully_qualified(class_name) + if sym is not None and isinstance(sym.node, TypeInfo): + return sym.node + return None + + def get_additional_deps(self, file: MypyFile) -> List[Tuple[int, str, int]]: + if file.fullname() == 'django.conf' and self.django_settings_module: + return [(10, self.django_settings_module, -1)] + + if file.fullname() == 'django.db.models.query': + return [(10, 'mypy_extensions', -1)] + + return [] + def get_function_hook(self, fullname: str ) -> Optional[Callable[[FunctionContext], Type]]: if fullname == 'django.contrib.auth.get_user_model': - return return_user_model_hook + return partial(return_user_model_hook, + settings_modules=self._get_settings_modules_in_order_of_priority()) manager_bases = self._get_current_manager_bases() if fullname in manager_bases: return determine_proper_manager_type - sym = self.lookup_fully_qualified(fullname) - if sym is not None and isinstance(sym.node, TypeInfo): - if sym.node.has_base(helpers.FIELD_FULLNAME): + info = self._get_typeinfo_or_none(fullname) + if info: + if info.has_base(helpers.FIELD_FULLNAME): return fields.adjust_return_type_of_field_instantiation - if sym.node.metadata.get('django', {}).get('generated_init'): + if helpers.get_django_metadata(info).get('generated_init'): return init_create.redefine_and_typecheck_model_init def get_method_hook(self, fullname: str ) -> Optional[Callable[[MethodContext], Type]]: class_name, _, method_name = fullname.rpartition('.') + if method_name == 'get_form_class': - sym = self.lookup_fully_qualified(class_name) - if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME): + info = self._get_typeinfo_or_none(class_name) + if info and info.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME): return extract_proper_type_for_get_form_class if method_name == 'get_form': - sym = self.lookup_fully_qualified(class_name) - if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME): + info = self._get_typeinfo_or_none(class_name) + if info and info.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME): return extract_proper_type_for_get_form - if method_name in ('values', 'values_list'): - sym = self.lookup_fully_qualified(class_name) - if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.QUERYSET_CLASS_FULLNAME): - return partial(extract_proper_type_for_values_and_values_list, method_name) + if method_name == 'values': + model_info = self._get_typeinfo_or_none(class_name) + if model_info and model_info.has_base(helpers.QUERYSET_CLASS_FULLNAME): + return extract_proper_type_for_queryset_values + + if method_name == 'values_list': + model_info = self._get_typeinfo_or_none(class_name) + if model_info and model_info.has_base(helpers.QUERYSET_CLASS_FULLNAME): + return extract_proper_type_queryset_values_list if fullname in {'django.apps.registry.Apps.get_model', 'django.db.migrations.state.StateApps.get_model'}: @@ -384,13 +247,6 @@ def get_method_hook(self, fullname: str def get_base_class_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: - if fullname == helpers.DUMMY_SETTINGS_BASE_CLASS: - settings_modules = ['django.conf.global_settings'] - if self.django_settings_module: - settings_modules.append(self.django_settings_module) - return AddSettingValuesToDjangoConfObject(settings_modules, - self.config.ignore_missing_settings) - if fullname in self._get_current_model_bases(): return transform_model_class @@ -400,25 +256,34 @@ def get_base_class_hook(self, fullname: str if fullname in self._get_current_form_bases(): return transform_form_class - sym = self.lookup_fully_qualified(fullname) - if sym and isinstance(sym.node, TypeInfo) and sym.node.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME): + info = self._get_typeinfo_or_none(fullname) + if info and info.has_base(helpers.FORM_MIXIN_CLASS_FULLNAME): return transform_form_view return None def get_attribute_hook(self, fullname: str ) -> Optional[Callable[[AttributeContext], Type]]: - if fullname == 'builtins.object.id': - return return_integer_type_for_id_for_non_defined_primary_key_in_models - - module, _, name = fullname.rpartition('.') - sym = self.lookup_fully_qualified('django.conf.LazySettings') - if sym and isinstance(sym.node, TypeInfo): - metadata = get_settings_metadata(sym.node) - if module == 'builtins.object' and name in metadata: - return ExtractSettingType(module_fullname=metadata[name]) - - return extract_and_return_primary_key_of_bound_related_field_parameter + class_name, _, attr_name = fullname.rpartition('.') + if class_name == helpers.DUMMY_SETTINGS_BASE_CLASS: + return partial(get_type_of_setting, + setting_name=attr_name, + settings_modules=self._get_settings_modules_in_order_of_priority(), + ignore_missing_settings=self.config.ignore_missing_settings) + + if class_name in self._get_current_model_bases(): + if attr_name == 'id': + return return_type_for_id_field + + model_info = self._get_typeinfo_or_none(class_name) + if model_info: + related_managers = helpers.get_related_managers_metadata(model_info) + if attr_name in related_managers: + return partial(determine_type_of_related_manager, + related_manager_name=attr_name) + + if attr_name.endswith('_id'): + return extract_and_return_primary_key_of_bound_related_field_parameter def get_type_analyze_hook(self, fullname: str ) -> Optional[Callable[[AnalyzeTypeContext], Type]]: diff --git a/mypy_django_plugin/monkeypatch.py b/mypy_django_plugin/monkeypatch.py deleted file mode 100644 index 45252c39a..000000000 --- a/mypy_django_plugin/monkeypatch.py +++ /dev/null @@ -1,52 +0,0 @@ -from typing import List, Optional - -from mypy import build -from mypy.build import BuildManager, Graph, State -from mypy.modulefinder import BuildSource - -old_load_graph = build.load_graph -OldState = build.State - - -def is_module_present_in_sources(module_name: str, sources: List[BuildSource]): - return any([source.module == module_name for source in sources]) - - -def add_modules_as_a_source_seed_files(modules: List[str]) -> None: - def patched_load_graph(sources: List[BuildSource], manager: BuildManager, - old_graph: Optional[Graph] = None, - new_modules: Optional[List[State]] = None): - # add global settings - for module_name in modules: - if not is_module_present_in_sources(module_name, sources): - sources.append(BuildSource(None, module_name, None)) - - return old_load_graph(sources=sources, manager=manager, - old_graph=old_graph, - new_modules=new_modules) - - build.load_graph = patched_load_graph - - -def restore_original_load_graph(): - from mypy import build - - build.load_graph = old_load_graph - - -def inject_modules_as_dependencies_for_django_conf_settings(modules: List[str]) -> None: - from mypy import build - - class PatchedState(build.State): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if self.id == 'django.conf': - self.dependencies.extend(modules) - - build.State = PatchedState - - -def restore_original_dependencies_handling(): - from mypy import build - - build.State = OldState diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index bb96d0db3..85d2df7da 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -149,7 +149,7 @@ def record_field_properties_into_outer_model_class(ctx: FunctionContext) -> None if field_name is None: return - fields_metadata = outer_model.metadata.setdefault('django', {}).setdefault('fields', {}) + fields_metadata = helpers.get_fields_metadata(outer_model) # primary key is_primary_key = False diff --git a/mypy_django_plugin/transformers/forms.py b/mypy_django_plugin/transformers/forms.py index d12ba429b..2afc519da 100644 --- a/mypy_django_plugin/transformers/forms.py +++ b/mypy_django_plugin/transformers/forms.py @@ -1,4 +1,5 @@ -from mypy.plugin import ClassDefContext +from mypy.plugin import ClassDefContext, MethodContext +from mypy.types import CallableType, Instance, NoneTyp, Type, TypeType from mypy_django_plugin import helpers @@ -8,3 +9,38 @@ def make_meta_nested_class_inherit_from_any(ctx: ClassDefContext) -> None: if meta_node is None: return None meta_node.fallback_to_any = True + + +def extract_proper_type_for_get_form(ctx: MethodContext) -> Type: + object_type = ctx.type + if not isinstance(object_type, Instance): + return ctx.default_return_type + + form_class_type = helpers.get_argument_type_by_name(ctx, 'form_class') + if form_class_type is None or isinstance(form_class_type, NoneTyp): + # extract from specified form_class in metadata + form_class_fullname = helpers.get_django_metadata(object_type.type).get('form_class', None) + if not form_class_fullname: + return ctx.default_return_type + + return ctx.api.named_generic_type(form_class_fullname, []) + + if isinstance(form_class_type, TypeType) and isinstance(form_class_type.item, Instance): + return form_class_type.item + + if isinstance(form_class_type, CallableType) and isinstance(form_class_type.ret_type, Instance): + return form_class_type.ret_type + + return ctx.default_return_type + + +def extract_proper_type_for_get_form_class(ctx: MethodContext) -> Type: + object_type = ctx.type + if not isinstance(object_type, Instance): + return ctx.default_return_type + + form_class_fullname = helpers.get_django_metadata(object_type.type).get('form_class', None) + if not form_class_fullname: + return ctx.default_return_type + + return TypeType(ctx.api.named_generic_type(form_class_fullname, [])) diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index a357fa3b0..9e4f7de50 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -1,11 +1,11 @@ from abc import ABCMeta, abstractmethod -from typing import Dict, Iterator, List, Optional, Tuple, cast +from typing import Any, Dict, Iterator, List, Optional, Tuple, cast import dataclasses from mypy.nodes import ( - ARG_STAR, ARG_STAR2, MDEF, Argument, CallExpr, ClassDef, Expression, IndexExpr, Lvalue, MemberExpr, MypyFile, + ARG_POS, ARG_STAR, ARG_STAR2, MDEF, Argument, CallExpr, ClassDef, Expression, IndexExpr, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTableNode, TypeInfo, Var, - ARG_POS) +) from mypy.plugin import ClassDefContext from mypy.plugins.common import add_method from mypy.semanal import SemanticAnalyzerPass2 @@ -37,8 +37,10 @@ def is_abstract_model(self) -> bool: return self.api.parse_bool(is_abstract_expr) def add_new_node_to_model_class(self, name: str, typ: Instance) -> None: + # type=: type of the variable itself var = Var(name=name, type=typ) - var.info = typ.type + # var.info: type of the object variable is bound to + var.info = self.model_classdef.info var._fullname = self.model_classdef.info.fullname() + '.' + name var.is_inferred = True var.is_initialized_in_class = True @@ -49,14 +51,8 @@ def run(self) -> None: raise NotImplementedError() -def iter_call_assignments(klass: ClassDef) -> Iterator[Tuple[Lvalue, CallExpr]]: - for lvalue, rvalue in helpers.iter_over_assignments(klass): - if isinstance(rvalue, CallExpr): - yield lvalue, rvalue - - def iter_over_one_to_n_related_fields(klass: ClassDef) -> Iterator[Tuple[NameExpr, CallExpr]]: - for lvalue, rvalue in iter_call_assignments(klass): + for lvalue, rvalue in helpers.iter_call_assignments(klass): if (isinstance(lvalue, NameExpr) and isinstance(rvalue.callee, MemberExpr)): if rvalue.callee.fullname in {helpers.FOREIGN_KEY_FULLNAME, @@ -80,15 +76,6 @@ def run(self) -> None: meta_node.fallback_to_any = True -def get_model_argument(manager_info: TypeInfo) -> Optional[Instance]: - for base in manager_info.bases: - if base.args: - model_arg = base.args[0] - if isinstance(model_arg, Instance) and model_arg.type.has_base(helpers.MODEL_CLASS_FULLNAME): - return model_arg - return None - - class AddDefaultObjectsManager(ModelClassInitializer): def add_new_manager(self, name: str, manager_type: Optional[Instance]) -> None: if manager_type is None: @@ -103,7 +90,7 @@ def add_private_default_manager(self, manager_type: Optional[Instance]) -> None: def get_existing_managers(self) -> List[Tuple[str, TypeInfo]]: managers = [] for base in self.model_classdef.info.mro: - for name_expr, member_expr in iter_call_assignments(base.defn): + for name_expr, member_expr in helpers.iter_call_assignments(base.defn): manager_name = name_expr.name callee_expr = member_expr.callee if isinstance(callee_expr, IndexExpr): @@ -147,7 +134,7 @@ def run(self) -> None: # no need for .id attr return None - for _, rvalue in iter_call_assignments(self.model_classdef): + for _, rvalue in helpers.iter_call_assignments(self.model_classdef): if ('primary_key' in rvalue.arg_names and self.api.parse_bool(rvalue.args[rvalue.arg_names.index('primary_key')])): break @@ -156,23 +143,31 @@ def run(self) -> None: class AddRelatedManagers(ModelClassInitializer): + def add_related_manager_variable(self, manager_name: str, related_field_type_data: Dict[str, Any]) -> None: + # add dummy related manager for use later + self.add_new_node_to_model_class(manager_name, self.api.builtin_type('builtins.object')) + + # save name in metadata for use in get_attribute_hook later + related_managers_metadata = helpers.get_related_managers_metadata(self.model_classdef.info) + related_managers_metadata[manager_name] = related_field_type_data + def run(self) -> None: for module_name, module_file in self.api.modules.items(): - for defn in iter_over_classdefs(module_file): - for lvalue, rvalue in iter_call_assignments(defn): + for model_defn in helpers.iter_over_classdefs(module_file): + for lvalue, rvalue in helpers.iter_call_assignments(model_defn): if is_related_field(rvalue, module_file): try: - ref_to_fullname = extract_ref_to_fullname(rvalue, - module_file=module_file, - all_modules=self.api.modules) + referenced_model_fullname = extract_ref_to_fullname(rvalue, + module_file=module_file, + all_modules=self.api.modules) except helpers.SelfReference: - ref_to_fullname = defn.fullname + referenced_model_fullname = model_defn.fullname except helpers.SameFileModel as exc: - ref_to_fullname = module_name + '.' + exc.model_cls_name + referenced_model_fullname = module_name + '.' + exc.model_cls_name - if self.model_classdef.fullname == ref_to_fullname: - related_name = defn.name.lower() + '_set' + if self.model_classdef.fullname == referenced_model_fullname: + related_name = model_defn.name.lower() + '_set' if 'related_name' in rvalue.arg_names: related_name_expr = rvalue.args[rvalue.arg_names.index('related_name')] if not isinstance(related_name_expr, StrExpr): @@ -192,10 +187,28 @@ def run(self) -> None: else: # No related_query_name specified, default to related_name related_query_name = related_name - typ = get_related_field_type(rvalue, self.api, defn.info) - if typ is None: - continue - self.add_new_node_to_model_class(related_name, typ) + + # field_type_data = get_related_field_type(rvalue, self.api, defn.info) + # if typ is None: + # continue + + # TODO: recursively serialize types, or just https://github.com/python/mypy/issues/6506 + # as long as Model is not a Generic, one level depth is fine + if rvalue.callee.name in {'ForeignKey', 'ManyToManyField'}: + field_type_data = { + 'manager': helpers.RELATED_MANAGER_CLASS_FULLNAME, + 'of': [model_defn.info.fullname()] + } + # return api.named_type_or_none(helpers.RELATED_MANAGER_CLASS_FULLNAME, + # args=[Instance(related_model_typ, [])]) + else: + field_type_data = { + 'manager': model_defn.info.fullname(), + 'of': [] + } + + self.add_related_manager_variable(related_name, related_field_type_data=field_type_data) + if related_query_name is not None: # Only create related_query_name if it is a string literal helpers.get_lookups_metadata(self.model_classdef.info)[related_query_name] = { @@ -203,19 +216,20 @@ def run(self) -> None: } -def iter_over_classdefs(module_file: MypyFile) -> Iterator[ClassDef]: - for defn in module_file.defs: - if isinstance(defn, ClassDef): - yield defn - - -def get_related_field_type(rvalue: CallExpr, api: SemanticAnalyzerPass2, - related_model_typ: TypeInfo) -> Optional[Instance]: +def get_related_field_type(rvalue: CallExpr, related_model_typ: TypeInfo) -> Dict[str, Any]: if rvalue.callee.name in {'ForeignKey', 'ManyToManyField'}: - return api.named_type_or_none(helpers.RELATED_MANAGER_CLASS_FULLNAME, - args=[Instance(related_model_typ, [])]) + return { + 'manager': helpers.RELATED_MANAGER_CLASS_FULLNAME, + 'of': [related_model_typ.fullname()] + } + # return api.named_type_or_none(helpers.RELATED_MANAGER_CLASS_FULLNAME, + # args=[Instance(related_model_typ, [])]) else: - return Instance(related_model_typ, []) + return { + 'manager': related_model_typ.fullname(), + 'of': [] + } + # return Instance(related_model_typ, []) def is_related_field(expr: CallExpr, module_file: MypyFile) -> bool: diff --git a/mypy_django_plugin/transformers/queryset.py b/mypy_django_plugin/transformers/queryset.py index 33a6af0b5..d3b0e3378 100644 --- a/mypy_django_plugin/transformers/queryset.py +++ b/mypy_django_plugin/transformers/queryset.py @@ -1,46 +1,99 @@ from collections import OrderedDict -from typing import Union, List, cast, Optional +from typing import List, Optional, cast from mypy.checker import TypeChecker from mypy.nodes import StrExpr, TypeInfo -from mypy.plugin import MethodContext, CheckerPluginInterface -from mypy.types import Type, Instance, AnyType, TypeOfAny +from mypy.plugin import ( + AnalyzeTypeContext, CheckerPluginInterface, MethodContext, +) +from mypy.types import AnyType, Instance, Type, TypeOfAny from mypy_django_plugin import helpers -from mypy_django_plugin.lookups import resolve_lookup, RelatedModelNode, LookupException +from mypy_django_plugin.lookups import ( + LookupException, RelatedModelNode, resolve_lookup, +) -def extract_proper_type_for_values_and_values_list(method_name: str, ctx: MethodContext) -> Type: - api = cast(TypeChecker, ctx.api) +def get_queryset_model_arg(ret_type: Instance) -> Type: + if ret_type.args: + return ret_type.args[0] + else: + return AnyType(TypeOfAny.implementation_artifact) + +def extract_proper_type_for_queryset_values(ctx: MethodContext) -> Type: object_type = ctx.type if not isinstance(object_type, Instance): return ctx.default_return_type - ret = ctx.default_return_type - - any_type = AnyType(TypeOfAny.implementation_artifact) fields_arg_expr = ctx.args[ctx.callee_arg_names.index('fields')] + if len(fields_arg_expr) == 0: + # values_list/values with no args is not yet supported, so default to Any types for field types + # It should in the future include all model fields, "extra" fields and "annotated" fields + return ctx.default_return_type - model_arg: Union[AnyType, Type] = ret.args[0] if len(ret.args) > 0 else any_type + model_arg = get_queryset_model_arg(ctx.default_return_type) + if isinstance(model_arg, Instance): + model_type_info = model_arg.type + else: + model_type_info = None - column_names: List[Optional[str]] = [] column_types: OrderedDict[str, Type] = OrderedDict() - fill_column_types = True + # parse *fields + for field_expr in fields_arg_expr: + if isinstance(field_expr, StrExpr): + field_name = field_expr.value + # Default to any type + column_types[field_name] = AnyType(TypeOfAny.implementation_artifact) + + if model_type_info: + resolved_lookup_type = resolve_values_lookup(ctx.api, model_type_info, field_name) + if resolved_lookup_type is not None: + column_types[field_name] = resolved_lookup_type + else: + return ctx.default_return_type + + # parse **expressions + expression_arg_names = ctx.arg_names[ctx.callee_arg_names.index('expressions')] + for expression_name in expression_arg_names: + # Arbitrary additional annotation expressions are supported, but they all have type Any for now + column_types[expression_name] = AnyType(TypeOfAny.implementation_artifact) + + row_arg = helpers.make_typeddict(ctx.api, fields=column_types, + required_keys=set()) + return helpers.reparametrize_instance(ctx.default_return_type, [model_arg, row_arg]) + +def extract_proper_type_queryset_values_list(ctx: MethodContext) -> Type: + object_type = ctx.type + if not isinstance(object_type, Instance): + return ctx.default_return_type + + ret = ctx.default_return_type + + model_arg = get_queryset_model_arg(ctx.default_return_type) + # model_arg: Union[AnyType, Type] = ret.args[0] if len(ret.args) > 0 else any_type + + column_names: List[Optional[str]] = [] + column_types: OrderedDict[str, Type] = OrderedDict() + + fields_arg_expr = ctx.args[ctx.callee_arg_names.index('fields')] + fields_param_is_specified = True if len(fields_arg_expr) == 0: # values_list/values with no args is not yet supported, so default to Any types for field types # It should in the future include all model fields, "extra" fields and "annotated" fields - fill_column_types = False + fields_param_is_specified = False if isinstance(model_arg, Instance): model_type_info = model_arg.type else: model_type_info = None + any_type = AnyType(TypeOfAny.implementation_artifact) + # Figure out each field name passed to fields - has_dynamic_column_names = False + only_strings_as_fields_expressions = True for field_expr in fields_arg_expr: if isinstance(field_expr, StrExpr): field_name = field_expr.value @@ -55,52 +108,48 @@ def extract_proper_type_for_values_and_values_list(method_name: str, ctx: Method else: # Dynamic field names are partially supported for values_list, but not values column_names.append(None) - has_dynamic_column_names = True - - if method_name == 'values_list': - flat = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'flat')) - named = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'named')) - - if named and flat: - api.fail("'flat' and 'named' can't be used together.", ctx.context) - return ret - elif named: - if fill_column_types and not has_dynamic_column_names: - row_arg = helpers.make_named_tuple(api, fields=column_types, name="Row") - else: - row_arg = helpers.make_named_tuple(api, fields=OrderedDict(), name="Row") - elif flat: - if len(ctx.args[0]) > 1: - api.fail("'flat' is not valid when values_list is called with more than one field.", ctx.context) - return ret - if fill_column_types and not has_dynamic_column_names: - # Grab first element - row_arg = column_types[column_names[0]] - else: - row_arg = any_type - else: - if fill_column_types: - args = [ - # Fallback to Any if the column name is unknown (e.g. dynamic) - column_types.get(column_name, any_type) if column_name is not None else any_type - for column_name in column_names - ] - else: - args = [any_type] - row_arg = helpers.make_tuple(api, fields=args) - elif method_name == 'values': - expression_arg_names = ctx.arg_names[ctx.callee_arg_names.index('expressions')] - for expression_name in expression_arg_names: - # Arbitrary additional annotation expressions are supported, but they all have type Any for now - column_names.append(expression_name) - column_types[expression_name] = any_type - - if fill_column_types and not has_dynamic_column_names: - row_arg = helpers.make_typeddict(api, fields=column_types, required_keys=set()) + only_strings_as_fields_expressions = False + + flat = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'flat')) + named = helpers.parse_bool(helpers.get_argument_by_name(ctx, 'named')) + + api = cast(TypeChecker, ctx.api) + if named and flat: + api.fail("'flat' and 'named' can't be used together.", ctx.context) + return ret + + elif named: + # named=True, flat=False -> List[NamedTuple] + if fields_param_is_specified and only_strings_as_fields_expressions: + row_arg = helpers.make_named_tuple(api, fields=column_types, name="Row") else: + # fallback to catch-all NamedTuple + row_arg = helpers.make_named_tuple(api, fields=OrderedDict(), name="Row") + + elif flat: + # named=False, flat=True -> List of elements + if len(ctx.args[0]) > 1: + api.fail("'flat' is not valid when values_list is called with more than one field.", + ctx.context) return ctx.default_return_type + + if fields_param_is_specified and only_strings_as_fields_expressions: + # Grab first element + row_arg = column_types[column_names[0]] + else: + row_arg = any_type + else: - raise Exception(f"extract_proper_type_for_values_list doesn't support method {method_name}") + # named=False, flat=False -> List[Tuple] + if fields_param_is_specified: + args = [ + # Fallback to Any if the column name is unknown (e.g. dynamic) + column_types.get(column_name, any_type) if column_name is not None else any_type + for column_name in column_names + ] + else: + args = [any_type] + row_arg = helpers.make_tuple(api, fields=args) new_type_args = [model_arg, row_arg] return helpers.reparametrize_instance(ret, new_type_args) @@ -137,3 +186,24 @@ def resolve_values_lookup(api: CheckerPluginInterface, model_type_info: TypeInfo return helpers.make_optional(node_type) else: return node_type + + +def set_first_generic_param_as_default_for_second(fullname: str, ctx: AnalyzeTypeContext) -> Type: + if not ctx.type.args: + try: + return ctx.api.named_type(fullname, [AnyType(TypeOfAny.explicit), + AnyType(TypeOfAny.explicit)]) + except KeyError: + # really should never happen + return AnyType(TypeOfAny.explicit) + + args = ctx.type.args + if len(args) == 1: + args = [args[0], args[0]] + + analyzed_args = [ctx.api.analyze_type(arg) for arg in args] + try: + return ctx.api.named_type(fullname, analyzed_args) + except KeyError: + # really should never happen + return AnyType(TypeOfAny.explicit) diff --git a/mypy_django_plugin/transformers/related.py b/mypy_django_plugin/transformers/related.py new file mode 100644 index 000000000..46d2e02e6 --- /dev/null +++ b/mypy_django_plugin/transformers/related.py @@ -0,0 +1,59 @@ +from typing import Optional, Union + +from mypy.checkmember import AttributeContext +from mypy.nodes import TypeInfo +from mypy.types import AnyType, Instance, Type, TypeOfAny, UnionType + +from mypy_django_plugin import helpers + + +def _extract_referred_to_type_info(typ: Union[UnionType, Instance]) -> Optional[TypeInfo]: + if isinstance(typ, Instance): + return typ.type + else: + # should be Union[TYPE, None] + typ = helpers.make_required(typ) + if isinstance(typ, Instance): + return typ.type + return None + + +def extract_and_return_primary_key_of_bound_related_field_parameter(ctx: AttributeContext) -> Type: + if not isinstance(ctx.default_attr_type, Instance) or not (ctx.default_attr_type.type.fullname() == 'builtins.int'): + return ctx.default_attr_type + + if not isinstance(ctx.type, Instance) or not ctx.type.type.has_base(helpers.MODEL_CLASS_FULLNAME): + return ctx.default_attr_type + + field_name = ctx.context.name.split('_')[0] + sym = ctx.type.type.get(field_name) + if sym and isinstance(sym.type, Instance) and len(sym.type.args) > 0: + referred_to = sym.type.args[1] + if isinstance(referred_to, AnyType): + return AnyType(TypeOfAny.implementation_artifact) + + model_type = _extract_referred_to_type_info(referred_to) + if model_type is None: + return AnyType(TypeOfAny.implementation_artifact) + + primary_key_type = helpers.extract_primary_key_type_for_get(model_type) + if primary_key_type: + return primary_key_type + + is_nullable = helpers.is_field_nullable(ctx.type.type, field_name) + if is_nullable: + return helpers.make_optional(ctx.default_attr_type) + + return ctx.default_attr_type + + +def determine_type_of_related_manager(ctx: AttributeContext, related_manager_name: str) -> Type: + if not isinstance(ctx.type, Instance): + return ctx.default_attr_type + + related_manager_type = helpers.get_related_manager_type_from_metadata(ctx.type.type, + related_manager_name, ctx.api) + if not related_manager_type: + return ctx.default_attr_type + + return related_manager_type diff --git a/mypy_django_plugin/transformers/settings.py b/mypy_django_plugin/transformers/settings.py index 739a41e53..68963e53f 100644 --- a/mypy_django_plugin/transformers/settings.py +++ b/mypy_django_plugin/transformers/settings.py @@ -1,97 +1,78 @@ -from typing import Iterable, List, Optional, cast - -from mypy.nodes import ( - ClassDef, Context, ImportAll, MypyFile, SymbolNode, SymbolTableNode, TypeInfo, Var, -) -from mypy.plugin import ClassDefContext -from mypy.semanal import SemanticAnalyzerPass2 -from mypy.types import AnyType, Instance, NoneTyp, Type, TypeOfAny, UnionType -from mypy.util import correct_relative_import - - -def get_error_context(node: SymbolNode) -> Context: - context = Context() - context.set_line(node) - return context - - -def filter_out_nones(typ: UnionType) -> List[Type]: - return [item for item in typ.items if not isinstance(item, NoneTyp)] - - -def make_sym_copy_of_setting(sym: SymbolTableNode) -> Optional[SymbolTableNode]: - if isinstance(sym.type, Instance): - copied = sym.copy() - copied.node.info = sym.type.type - return copied - elif isinstance(sym.type, UnionType): - instances = filter_out_nones(sym.type) - if len(instances) > 1: - # plain unions not supported yet - return None - typ = instances[0] - if isinstance(typ, Instance): - copied = sym.copy() - copied.node.info = typ.type - return copied - return None - else: - return None - - -def get_settings_metadata(lazy_settings_info: TypeInfo): - return lazy_settings_info.metadata.setdefault('django', {}).setdefault('settings', {}) - - -def load_settings_from_names(settings_classdef: ClassDef, - modules: Iterable[MypyFile], - api: SemanticAnalyzerPass2) -> None: - settings_metadata = get_settings_metadata(settings_classdef.info) - - for module in modules: - for name, sym in module.names.items(): - if name.isupper() and isinstance(sym.node, Var): - if sym.type is not None: - copied = make_sym_copy_of_setting(sym) - if copied is None: - continue - settings_classdef.info.names[name] = copied - else: - var = Var(name, AnyType(TypeOfAny.unannotated)) - var.info = api.named_type('__builtins__.object').type # outer class type - settings_classdef.info.names[name] = SymbolTableNode(sym.kind, var, plugin_generated=True) - - settings_metadata[name] = module.fullname() - - -def get_import_star_modules(api: SemanticAnalyzerPass2, module: MypyFile) -> List[str]: - import_star_modules = [] - for module_import in module.imports: - # relative import * are not resolved by mypy - if isinstance(module_import, ImportAll) and module_import.relative: - absolute_import_path, correct = correct_relative_import(module.fullname(), module_import.relative, - module_import.id, is_cur_package_init_file=False) - if not correct: - return [] - for path in [absolute_import_path] + get_import_star_modules(api, - module=api.modules.get(absolute_import_path)): - if path not in import_star_modules: - import_star_modules.append(path) - return import_star_modules - - -class AddSettingValuesToDjangoConfObject: - def __init__(self, settings_modules: List[str], ignore_missing_settings: bool): - self.settings_modules = settings_modules - self.ignore_missing_settings = ignore_missing_settings - - def __call__(self, ctx: ClassDefContext) -> None: - api = cast(SemanticAnalyzerPass2, ctx.api) - for module_name in self.settings_modules: - module = api.modules[module_name] - star_deps = [api.modules[star_dep] - for star_dep in reversed(get_import_star_modules(api, module))] - load_settings_from_names(ctx.cls, modules=star_deps + [module], api=api) - - if self.ignore_missing_settings: - ctx.cls.info.fallback_to_any = True +from typing import TYPE_CHECKING, List, Optional, cast + +from mypy.checkexpr import FunctionContext +from mypy.checkmember import AttributeContext +from mypy.nodes import NameExpr, StrExpr, SymbolTableNode, TypeInfo +from mypy.types import AnyType, Instance, Type, TypeOfAny, TypeType + +from mypy_django_plugin import helpers + +if TYPE_CHECKING: + from mypy.checker import TypeChecker + + +def get_setting_sym(name: str, api: 'TypeChecker', settings_modules: List[str]) -> Optional[SymbolTableNode]: + for settings_mod_name in settings_modules: + file = api.modules[settings_mod_name] + sym = file.names.get(name) + if sym is not None: + return sym + + return None + + +def get_type_of_setting(ctx: AttributeContext, setting_name: str, + settings_modules: List[str], ignore_missing_settings: bool) -> Type: + setting_sym = get_setting_sym(setting_name, ctx.api, settings_modules) + if setting_sym: + if setting_sym.type is None: + # TODO: defer till setting_sym.type is not None + return AnyType(TypeOfAny.implementation_artifact) + + return setting_sym.type + + if not ignore_missing_settings: + ctx.api.fail(f"'Settings' object has no attribute {setting_name!r}", ctx.context) + + return ctx.default_attr_type + + +def return_user_model_hook(ctx: FunctionContext, settings_modules: List[str]) -> Type: + from mypy.checker import TypeChecker + + api = cast(TypeChecker, ctx.api) + + setting_sym = get_setting_sym('AUTH_USER_MODEL', api, settings_modules) + if setting_sym is None: + return ctx.default_return_type + + setting_module_name, _, _ = setting_sym.fullname.rpartition('.') + setting_module = api.modules[setting_module_name] + + model_path = None + for name_expr, rvalue_expr in helpers.iter_over_assignments(setting_module): + if isinstance(name_expr, NameExpr) and isinstance(rvalue_expr, StrExpr): + if name_expr.name == 'AUTH_USER_MODEL': + model_path = rvalue_expr.value + break + + if not model_path: + return ctx.default_return_type + + app_label, _, model_class_name = model_path.rpartition('.') + if app_label is None: + return ctx.default_return_type + + model_fullname = helpers.get_model_fullname(app_label, model_class_name, + all_modules=api.modules) + if model_fullname is None: + api.fail(f'"{app_label}.{model_class_name}" model class is not imported so far. Try to import it ' + f'(under if TYPE_CHECKING) at the beginning of the current file', + context=ctx.context) + return ctx.default_return_type + + model_info = helpers.lookup_fully_qualified_generic(model_fullname, + all_modules=api.modules) + if model_info is None or not isinstance(model_info, TypeInfo): + return ctx.default_return_type + return TypeType(Instance(model_info, [])) diff --git a/pyproject.toml b/pyproject.toml index 729b90b6c..78fe223df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,3 @@ [tool.black] line-length = 120 -include = 'django-stubs/.*.pyi$' +include = 'django-stubs/.*.pyi$' \ No newline at end of file diff --git a/pytest.ini b/pytest.ini index af5af9b6c..a0afe5454 100644 --- a/pytest.ini +++ b/pytest.ini @@ -4,4 +4,4 @@ addopts = --tb=native --mypy-ini-file=./test-data/plugins.ini -s - -v \ No newline at end of file + -v diff --git a/scripts/build_import_all_test.py b/scripts/build_import_all_test.py index e13ffa25b..77e1efd40 100644 --- a/scripts/build_import_all_test.py +++ b/scripts/build_import_all_test.py @@ -2,7 +2,6 @@ from pathlib import Path from typing import List - STUBS_ROOT = Path(__file__).parent.parent / 'django-stubs' diff --git a/scripts/mypy.ini b/scripts/mypy.ini index e23a468ed..0b59a4d75 100644 --- a/scripts/mypy.ini +++ b/scripts/mypy.ini @@ -5,7 +5,7 @@ check_untyped_defs = True warn_no_return = False show_traceback = True allow_redefinition = True -incremental = False +incremental = True plugins = mypy_django_plugin.main diff --git a/setup.cfg b/setup.cfg index e85d89278..a8ecb8887 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,13 +14,5 @@ exclude = test-data max_line_length = 120 -[tool:pytest] -testpaths = ./test-data -addopts = - --tb=native - --mypy-ini-file=./test-data/plugins.ini - -s - -v - [metadata] license_file = LICENSE.txt diff --git a/setup.py b/setup.py index 278326637..21680b30a 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ def find_stub_files(name): readme = f.read() dependencies = [ - 'mypy>=0.670,<0.700', + 'mypy>=0.700,<0.710', 'typing-extensions' ] if sys.version_info[:2] < (3, 7): @@ -30,7 +30,7 @@ def find_stub_files(name): setup( name="django-stubs", - version="0.11.1", + version="0.12.0", description='Django mypy stubs', long_description=readme, long_description_content_type='text/markdown', diff --git a/test-data/typecheck/fields.test b/test-data/typecheck/fields.test index 7ebd6ec4e..ac6f4d633 100644 --- a/test-data/typecheck/fields.test +++ b/test-data/typecheck/fields.test @@ -33,7 +33,7 @@ class User(models.Model): text = models.TextField() user = User() -reveal_type(user.id) # E: Revealed type is 'builtins.int*' +reveal_type(user.id) # E: Revealed type is 'builtins.int' reveal_type(user.small_int) # E: Revealed type is 'builtins.int*' reveal_type(user.name) # E: Revealed type is 'builtins.str*' reveal_type(user.slug) # E: Revealed type is 'builtins.str*' @@ -51,7 +51,7 @@ class Booking(models.Model): some_decimal = models.DecimalField(max_digits=10, decimal_places=5) booking = Booking() -reveal_type(booking.id) # E: Revealed type is 'builtins.int*' +reveal_type(booking.id) # E: Revealed type is 'builtins.int' reveal_type(booking.time_range) # E: Revealed type is 'Any' reveal_type(booking.some_decimal) # E: Revealed type is 'decimal.Decimal*' [/CASE] @@ -72,7 +72,10 @@ class User(models.Model): my_pk = models.IntegerField(primary_key=True) reveal_type(User().my_pk) # E: Revealed type is 'builtins.int*' -reveal_type(User().id) # E: Revealed type is 'Any' +reveal_type(User().id) +[out] +main:7: error: Revealed type is 'Any' +main:7: error: Default primary key 'id' is not defined [/CASE] [CASE test_meta_nested_class_allows_subclassing_in_multiple_inheritance] @@ -100,6 +103,16 @@ class User(Abstract): id = models.AutoField(primary_key=True) [/CASE] +[CASE test_primary_key_on_optional_queryset_method] +from django.db import models +class User(models.Model): + pass +reveal_type(User.objects.first().id) +[out] +main:4: error: Revealed type is 'Any' +main:4: error: Item "None" of "Optional[User]" has no attribute "id" +[/CASE] + [CASE standard_it_from_parent_model_could_be_overridden_with_non_integer_field_in_child_model] from django.db import models import uuid @@ -107,7 +120,7 @@ class ParentModel(models.Model): pass class MyModel(ParentModel): id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) -reveal_type(MyModel().id) # E: Revealed type is 'uuid.UUID*' +reveal_type(MyModel().id) # E: Revealed type is 'uuid.UUID' [/CASE] [CASE blank_and_null_char_field_allows_none] diff --git a/test-data/typecheck/queryset.test b/test-data/typecheck/queryset.test index 555682e3b..ab0623746 100644 --- a/test-data/typecheck/queryset.test +++ b/test-data/typecheck/queryset.test @@ -1,37 +1,22 @@ -[CASE test_queryset] +[CASE test_queryset_second_argument_filled_automatically] from django.db import models -class Blog(models.Model): - name = models.CharField(max_length=100) - created_at = models.DateTimeField() +class Blog(models.Model): pass # QuerySet where second type argument is not specified shouldn't raise any errors class BlogQuerySet(models.QuerySet[Blog]): pass -class Entry(models.Model): - blog = models.ForeignKey(Blog, on_delete=models.CASCADE, related_name="entries") - title = models.CharField(max_length=100) - - -# Test that second type argument gets filled automatically blog_qs: models.QuerySet[Blog] reveal_type(blog_qs) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog, main.Blog]' +[/CASE] -reveal_type(Blog.objects.in_bulk([1])) # E: Revealed type is 'builtins.dict[Any, main.Blog*]' -reveal_type(Blog.objects.in_bulk()) # E: Revealed type is 'builtins.dict[Any, main.Blog*]' -reveal_type(Blog.objects.in_bulk(['beatles_blog'], field_name='name')) # E: Revealed type is 'builtins.dict[Any, main.Blog*]' -# When ANDing QuerySets, the left-side's _Row parameter is used -reveal_type(Blog.objects.all() & Blog.objects.values()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, main.Blog*]' -reveal_type(Blog.objects.values() & Blog.objects.values()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, builtins.dict*[builtins.str, Any]]' -reveal_type(Blog.objects.values_list('id', 'name') & Blog.objects.values()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Tuple[builtins.int, builtins.str]]' -reveal_type(Blog.objects.values_list('id', 'name', named=True) & Blog.objects.values()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Tuple[builtins.int, builtins.str, fallback=main.Row]]' -reveal_type(Blog.objects.values_list('id', flat=True) & Blog.objects.values()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, builtins.int*]' +[CASE test_queryset_methods] +from django.db import models -# .dates / .datetimes -reveal_type(Blog.objects.dates("created_at", "day")) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, datetime.date]' -reveal_type(Blog.objects.datetimes("created_at", "day")) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, datetime.datetime]' +class Blog(models.Model): + created_at = models.DateTimeField() qs = Blog.objects.all() reveal_type(qs) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, main.Blog*]' @@ -44,6 +29,32 @@ reveal_type(qs[0]) # E: Revealed type is 'main.Blog*' reveal_type(qs[:9]) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, main.Blog*]' reveal_type(qs.in_bulk()) # E: Revealed type is 'builtins.dict[Any, main.Blog*]' +# .dates / .datetimes +reveal_type(Blog.objects.dates("created_at", "day")) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, datetime.date]' +reveal_type(Blog.objects.datetimes("created_at", "day")) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, datetime.datetime]' +[/CASE] + + +[CASE test_combine_querysets_with_and] +from django.db import models + +class Blog(models.Model): + name = models.CharField(max_length=100) + created_at = models.DateTimeField() + +# When ANDing QuerySets, the left-side's _Row parameter is used +reveal_type(Blog.objects.all() & Blog.objects.values()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, main.Blog*]' +reveal_type(Blog.objects.values() & Blog.objects.values()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, builtins.dict*[builtins.str, Any]]' +reveal_type(Blog.objects.values_list('id', 'name') & Blog.objects.values()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Tuple[builtins.int, builtins.str]]' +reveal_type(Blog.objects.values_list('id', 'name', named=True) & Blog.objects.values()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Tuple[builtins.int, builtins.str, fallback=main.Row]]' +reveal_type(Blog.objects.values_list('id', flat=True) & Blog.objects.values()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, builtins.int*]' +[/CASE] + + +[CASE test_queryset_values_method] +from django.db import models + +class Blog(models.Model): pass values_qs = Blog.objects.values() reveal_type(values_qs) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, builtins.dict[builtins.str, Any]]' @@ -56,7 +67,14 @@ reveal_type(values_qs.earliest()) # E: Revealed type is 'builtins.dict*[builtins reveal_type(values_qs[0]) # E: Revealed type is 'builtins.dict*[builtins.str, Any]' reveal_type(values_qs[:9]) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, builtins.dict*[builtins.str, Any]]' reveal_type(values_qs.in_bulk()) # E: Revealed type is 'builtins.dict[Any, main.Blog*]' +[/CASE] + +[CASE test_queryset_values_list_named_false_flat_false] +from django.db import models + +class Blog(models.Model): + name = models.CharField(max_length=100) values_list_qs = Blog.objects.values_list('id', 'name') reveal_type(values_list_qs) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Tuple[builtins.int, builtins.str]]' @@ -69,8 +87,15 @@ reveal_type(values_list_qs.earliest()) # E: Revealed type is 'Tuple[builtins.int reveal_type(values_list_qs[0]) # E: Revealed type is 'Tuple[builtins.int, builtins.str]' reveal_type(values_list_qs[:9]) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Tuple[builtins.int, builtins.str]]' reveal_type(values_list_qs.in_bulk()) # E: Revealed type is 'builtins.dict[Any, main.Blog*]' +[/CASE] +[CASE test_queryset_values_list_named_false_flat_true] +from django.db import models + +class Blog(models.Model): + name = models.CharField(max_length=100) + flat_values_list_qs = Blog.objects.values_list('id', flat=True) reveal_type(flat_values_list_qs) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, builtins.int]' reveal_type(flat_values_list_qs.all()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, builtins.int*]' @@ -82,36 +107,51 @@ reveal_type(flat_values_list_qs.earliest()) # E: Revealed type is 'builtins.int* reveal_type(flat_values_list_qs[0]) # E: Revealed type is 'builtins.int*' reveal_type(flat_values_list_qs[:9]) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, builtins.int*]' reveal_type(flat_values_list_qs.in_bulk()) # E: Revealed type is 'builtins.dict[Any, main.Blog*]' +[/CASE] + +[CASE test_queryset_values_list_named_true_flat_false] +from django.db import models + +class Blog(models.Model): + name = models.CharField(max_length=100) named_values_list_qs = Blog.objects.values_list('id', named=True) -reveal_type(named_values_list_qs) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Tuple[builtins.int, fallback=main.Row1]]' -reveal_type(named_values_list_qs.all()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Tuple[builtins.int, fallback=main.Row1]]' -reveal_type(named_values_list_qs.get(id=1)) # E: Revealed type is 'Tuple[builtins.int, fallback=main.Row1]' -reveal_type(iter(named_values_list_qs)) # E: Revealed type is 'typing.Iterator[Tuple[builtins.int, fallback=main.Row1]]' -reveal_type(named_values_list_qs.iterator()) # E: Revealed type is 'typing.Iterator[Tuple[builtins.int, fallback=main.Row1]]' -reveal_type(named_values_list_qs.first()) # E: Revealed type is 'Union[Tuple[builtins.int, fallback=main.Row1], None]' -reveal_type(named_values_list_qs.earliest()) # E: Revealed type is 'Tuple[builtins.int, fallback=main.Row1]' -reveal_type(named_values_list_qs[0]) # E: Revealed type is 'Tuple[builtins.int, fallback=main.Row1]' -reveal_type(named_values_list_qs[:9]) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Tuple[builtins.int, fallback=main.Row1]]' +reveal_type(named_values_list_qs) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Tuple[builtins.int, fallback=main.Row]]' +reveal_type(named_values_list_qs.all()) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Tuple[builtins.int, fallback=main.Row]]' +reveal_type(named_values_list_qs.get(id=1)) # E: Revealed type is 'Tuple[builtins.int, fallback=main.Row]' +reveal_type(iter(named_values_list_qs)) # E: Revealed type is 'typing.Iterator[Tuple[builtins.int, fallback=main.Row]]' +reveal_type(named_values_list_qs.iterator()) # E: Revealed type is 'typing.Iterator[Tuple[builtins.int, fallback=main.Row]]' +reveal_type(named_values_list_qs.first()) # E: Revealed type is 'Union[Tuple[builtins.int, fallback=main.Row], None]' +reveal_type(named_values_list_qs.earliest()) # E: Revealed type is 'Tuple[builtins.int, fallback=main.Row]' +reveal_type(named_values_list_qs[0]) # E: Revealed type is 'Tuple[builtins.int, fallback=main.Row]' +reveal_type(named_values_list_qs[:9]) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Tuple[builtins.int, fallback=main.Row]]' reveal_type(named_values_list_qs.in_bulk()) # E: Revealed type is 'builtins.dict[Any, main.Blog*]' +[/CASE] -[out] -[CASE test_queryset_values_list_custom_primary_key] +[CASE test_queryset_values_list_flat_true_custom_primary_key_get_element] from django.db import models class Blog(models.Model): primary_uuid = models.UUIDField(primary_key=True) -class Entry(models.Model): - blog = models.ForeignKey(Blog, on_delete=models.CASCADE, related_name="entries") - # Blog has a primary key field specified, so no automatic 'id' field is expected to exist reveal_type(Blog.objects.values_list('id', flat=True).get()) # E: Revealed type is 'Any' # Access Blog's pk (which is UUID field) reveal_type(Blog.objects.values_list('pk', flat=True).get()) # E: Revealed type is 'uuid.UUID*' +[/CASE] + + +[CASE test_queryset_values_list_flat_true_custom_primary_key_related_field] +from django.db import models + +class Blog(models.Model): + primary_uuid = models.UUIDField(primary_key=True) + +class Entry(models.Model): + blog = models.ForeignKey(Blog, on_delete=models.CASCADE, related_name="entries") # Accessing PK of model pointed to by foreign key reveal_type(Entry.objects.values_list('blog', flat=True).get()) # E: Revealed type is 'uuid.UUID*' @@ -122,26 +162,27 @@ reveal_type(Entry.objects.values_list('blog__pk', flat=True).get()) # E: Reveale # Blog has a primary key field specified, so no automatic 'id' field is expected to exist reveal_type(Entry.objects.values_list('blog__id', flat=True).get()) # E: Revealed type is 'Any' +[/CASE] -[CASE test_queryset_values_list] + +[CASE test_queryset_values_list_error_conditions] from django.db import models class Blog(models.Model): name = models.CharField(max_length=100) - created_at = models.DateTimeField() - -class Entry(models.Model): - blog = models.ForeignKey(Blog, on_delete=models.CASCADE, related_name="entries") - nullable_blog = models.ForeignKey(Blog, on_delete=models.CASCADE, related_name="+", null=True) - blog_with_related_query_name = models.ForeignKey(Blog, on_delete=models.CASCADE, related_query_name="my_related_query_name") - title = models.CharField(max_length=100) - -class BlogChild(Blog): - child_field = models.CharField(max_length=100) # Emulate at type-check time the errors that Django reports Blog.objects.values_list('id', flat=True, named=True) # E: 'flat' and 'named' can't be used together. -Blog.objects.values_list('id', 'created_at', flat=True) # E: 'flat' is not valid when values_list is called with more than one field. +Blog.objects.values_list('id', 'name', flat=True) # E: 'flat' is not valid when values_list is called with more than one field. +[/CASE] + + +[CASE test_queryset_values_list_returns_tuple_of_fields] +from django.db import models + +class Blog(models.Model): + name = models.CharField(max_length=100) + created_at = models.DateTimeField() # values_list where parameter types are all known reveal_type(Blog.objects.values_list('id', 'created_at').get()) # E: Revealed type is 'Tuple[builtins.int, datetime.datetime]' @@ -152,36 +193,123 @@ tup[2] # E: Tuple index out of range # values_list returning namedtuple reveal_type(Blog.objects.values_list('id', 'created_at', named=True).get()) # E: Revealed type is 'Tuple[builtins.int, datetime.datetime, fallback=main.Row]' +[/CASE] + + +[CASE test_queryset_values_list_invalid_lookups_produce_any] +from django.db import models + +class Blog(models.Model): pass +class Entry(models.Model): + blog = models.ForeignKey(Blog, on_delete=models.CASCADE, related_name="entries") # Invalid lookups produce Any type rather than giving errors. reveal_type(Blog.objects.values_list('id', 'invalid_lookup').get()) # E: Revealed type is 'Tuple[builtins.int, Any]' reveal_type(Blog.objects.values_list('entries_id', flat=True).get()) # E: Revealed type is 'Any' reveal_type(Blog.objects.values_list('entries__foo', flat=True).get()) # E: Revealed type is 'Any' reveal_type(Blog.objects.values_list('+', flat=True).get()) # E: Revealed type is 'Any' +[/CASE] + + +[CASE test_queryset_values_list_basic_inheritance] +from django.db import models + +class Blog(models.Model): + name = models.CharField(max_length=100) + created_at = models.DateTimeField() + +class BlogChild(Blog): + child_field = models.CharField(max_length=100) + +# Basic inheritance +reveal_type(BlogChild.objects.values_list('id', 'created_at', 'child_field').get()) # E: Revealed type is 'Tuple[builtins.int, datetime.datetime, builtins.str]' +[/CASE] + + +[CASE test_query_values_list_flat_true_plain_foreign_key] +from django.db import models + +class Blog(models.Model): pass +class Entry(models.Model): + blog = models.ForeignKey(Blog, on_delete=models.CASCADE) # Foreign key reveal_type(Entry.objects.values_list('blog', flat=True).get()) # E: Revealed type is 'builtins.int*' reveal_type(Entry.objects.values_list('blog__id', flat=True).get()) # E: Revealed type is 'builtins.int*' reveal_type(Entry.objects.values_list('blog__pk', flat=True).get()) # E: Revealed type is 'builtins.int*' reveal_type(Entry.objects.values_list('blog_id', flat=True).get()) # E: Revealed type is 'builtins.int*' +[/CASE] + + +[CASE test_query_values_list_flat_true_custom_primary_key] +from django.db import models + +class Blog(models.Model): + id = models.UUIDField(primary_key=True) + +class Entry(models.Model): + blog = models.ForeignKey(Blog, on_delete=models.CASCADE) + +# Foreign key +reveal_type(Entry.objects.values_list('blog', flat=True).get()) # E: Revealed type is 'uuid.UUID*' +reveal_type(Entry.objects.values_list('blog__id', flat=True).get()) # E: Revealed type is 'uuid.UUID*' +reveal_type(Entry.objects.values_list('blog__pk', flat=True).get()) # E: Revealed type is 'uuid.UUID*' +reveal_type(Entry.objects.values_list('blog_id', flat=True).get()) # E: Revealed type is 'uuid.UUID*' +[/CASE] + + +[CASE test_query_values_list_flat_true_nullable_foreign_key] +from django.db import models + +class Blog(models.Model): pass +class Entry(models.Model): + nullable_blog = models.ForeignKey(Blog, on_delete=models.CASCADE, related_name="+", null=True) # Foreign key (nullable=True) reveal_type(Entry.objects.values_list('nullable_blog', flat=True).get()) # E: Revealed type is 'Union[builtins.int, None]' reveal_type(Entry.objects.values_list('nullable_blog_id', flat=True).get()) # E: Revealed type is 'Union[builtins.int, None]' reveal_type(Entry.objects.values_list('nullable_blog__id', flat=True).get()) # E: Revealed type is 'Union[builtins.int, None]' reveal_type(Entry.objects.values_list('nullable_blog__pk', flat=True).get()) # E: Revealed type is 'Union[builtins.int, None]' +[/CASE] + + +[CASE test_query_values_list_flat_true_foreign_key_reverse_relation] +from django.db import models + +class Blog(models.Model): pass +class Entry(models.Model): + blog = models.ForeignKey(Blog, on_delete=models.CASCADE, related_name='entries') + blog_with_related_query_name = models.ForeignKey(Blog, on_delete=models.CASCADE, related_query_name="my_related_query_name") + title = models.CharField(max_length=100) # Reverse relation of ForeignKey -reveal_type(Blog.objects.values_list('entries', flat=True).get()) # E: Revealed type is 'Union[builtins.int, None]' -reveal_type(Blog.objects.values_list('entries__id', flat=True).get()) # E: Revealed type is 'Union[builtins.int, None]' -reveal_type(Blog.objects.values_list('entries__title', flat=True).get()) # E: Revealed type is 'Union[builtins.str, None]' +reveal_type(Blog.objects.values_list('entries', flat=True).get()) # E: Revealed type is 'builtins.int*' +reveal_type(Blog.objects.values_list('entries__id', flat=True).get()) # E: Revealed type is 'builtins.int*' +reveal_type(Blog.objects.values_list('entries__title', flat=True).get()) # E: Revealed type is 'builtins.str*' # Reverse relation of ForeignKey (with related_query_name set) -reveal_type(Blog.objects.values_list('my_related_query_name__id', flat=True).get()) # E: Revealed type is 'Union[builtins.int, None]' +reveal_type(Blog.objects.values_list('my_related_query_name__id', flat=True).get()) # E: Revealed type is 'builtins.int*' +[/CASE] -# Basic inheritance -reveal_type(BlogChild.objects.values_list('id', 'created_at', 'child_field').get()) # E: Revealed type is 'Tuple[builtins.int, datetime.datetime, builtins.str]' +[CASE test_query_values_list_flat_true_foreign_key_custom_primary_key_reverse_relation] +from django.db import models + +class Blog(models.Model): pass + +class Entry(models.Model): + id = models.UUIDField(primary_key=True) + blog = models.ForeignKey(Blog, on_delete=models.CASCADE, related_name='entries') + blog_with_related_query_name = models.ForeignKey(Blog, on_delete=models.CASCADE, related_query_name="my_related_query_name") + title = models.CharField(max_length=100) + +# Reverse relation of ForeignKey +reveal_type(Blog.objects.values_list('entries', flat=True).get()) # E: Revealed type is 'uuid.UUID*' +reveal_type(Blog.objects.values_list('entries__id', flat=True).get()) # E: Revealed type is 'uuid.UUID*' + +# Reverse relation of ForeignKey (with related_query_name set) +reveal_type(Blog.objects.values_list('my_related_query_name__id', flat=True).get()) # E: Revealed type is 'uuid.UUID*' +[/CASE] [CASE test_queryset_values_list_and_values_behavior_with_no_fields_specified_and_accessing_unknown_attributes] diff --git a/test-data/typecheck/related_fields.test b/test-data/typecheck/related_fields.test index 0e80c74ca..6255758f1 100644 --- a/test-data/typecheck/related_fields.test +++ b/test-data/typecheck/related_fields.test @@ -136,7 +136,8 @@ reveal_type(App().views) # E: Revealed type is 'django.db.models.manager.Relate reveal_type(App().views2) # E: Revealed type is 'django.db.models.manager.RelatedManager[main.View2]' [out] -[CASE models_imported_inside_init_file] +[CASE models_imported_inside_init_file_foreign_key] +[disable_cache] from django.db import models from myapp.models import App class View(models.Model): @@ -150,8 +151,10 @@ from .app import App from django.db import models class App(models.Model): pass +[/CASE] [CASE models_imported_inside_init_file_one_to_one_field] +[disable_cache] from django.db import models from myapp.models import User class Profile(models.Model): diff --git a/test-data/typecheck/settings.test b/test-data/typecheck/settings.test index 4ff232a48..7448f0eb5 100644 --- a/test-data/typecheck/settings.test +++ b/test-data/typecheck/settings.test @@ -3,10 +3,13 @@ [disable_cache] from django.conf import settings +# standard settings +reveal_type(settings.AUTH_USER_MODEL) # E: Revealed type is 'builtins.str' + reveal_type(settings.ROOT_DIR) # E: Revealed type is 'builtins.str' reveal_type(settings.APPS_DIR) # E: Revealed type is 'pathlib.Path' reveal_type(settings.OBJ) # E: Revealed type is 'django.utils.functional.LazyObject' -reveal_type(settings.NUMBERS) # E: Revealed type is 'builtins.list[builtins.str]' +reveal_type(settings.NUMBERS) # E: Revealed type is 'builtins.list[builtins.str*]' reveal_type(settings.DICT) # E: Revealed type is 'builtins.dict[Any, Any]' [file base.py] from pathlib import Path @@ -27,18 +30,32 @@ OBJ = LazyObject() from django.conf import settings reveal_type(settings.ROOT_DIR) # E: Revealed type is 'pathlib.Path' -reveal_type(settings.SETUP) # E: Revealed type is 'builtins.int' -reveal_type(settings.DATABASES) # E: Revealed type is 'builtins.dict[builtins.str, builtins.str]' +reveal_type(settings.SETUP) # E: Revealed type is 'Union[builtins.int, None]' +reveal_type(settings.DATABASES) # E: Revealed type is 'builtins.dict[builtins.str*, builtins.str*]' + +reveal_type(settings.LOCAL_SETTING) # E: Revealed type is 'builtins.int' +reveal_type(settings.BASE_SETTING) # E: Revealed type is 'builtins.int' + [file mysettings.py] from local import * -DATABASES = {'default': 'mydb'} +from typing import Optional +SETUP: Optional[int] = 3 + [file local.py] from base import * -SETUP = 3 +SETUP: int = 3 +DATABASES = {'default': 'mydb'} + +LOCAL_SETTING = 1 + [file base.py] from pathlib import Path - +from typing import Any +SETUP: Any = None ROOT_DIR = Path(__file__) + +BASE_SETTING = 1 + [/CASE] [CASE global_settings_are_always_loaded] @@ -73,10 +90,10 @@ LIST: List[str] = ['1', '2'] from django.conf import settings reveal_type(settings.NOT_EXISTING) -[env DJANGO_SETTINGS_MODULE=mysettings] +[env DJANGO_SETTINGS_MODULE=mysettings2] [disable_cache] -[file mysettings.py] +[file mysettings2.py] [out] main:2: error: Revealed type is 'Any' -main:2: error: "LazySettings" has no attribute "NOT_EXISTING" +main:2: error: 'Settings' object has no attribute 'NOT_EXISTING' [/CASE] \ No newline at end of file diff --git a/test-data/typecheck/shortcuts.test b/test-data/typecheck/shortcuts.test index abb24d430..72f510953 100644 --- a/test-data/typecheck/shortcuts.test +++ b/test-data/typecheck/shortcuts.test @@ -25,7 +25,10 @@ UserModel = get_user_model() reveal_type(UserModel.objects) # E: Revealed type is 'django.db.models.manager.Manager[myapp.models.MyUser]' [file mysettings.py] +from basic import * INSTALLED_APPS = ('myapp',) + +[file basic.py] AUTH_USER_MODEL = 'myapp.MyUser' [file myapp/__init__.py]