From 7bb1f3718e9e527ea01feb11530af167721dc0b4 Mon Sep 17 00:00:00 2001 From: pranavrajpal <78008260+pranavrajpal@users.noreply.github.com> Date: Tue, 8 Jun 2021 20:40:28 -0700 Subject: [PATCH] [mypyc] Avoid crash when importing unknown module with from import (#10550) Fixes mypyc/mypyc#851 This fixes a bug where code compiled with mypyc would crash on from imports (from x import y) if: * y is a module * mypy doesn't know that y is a module (due to an ignore_missing_imports configuration option or something else) The bug was caused by using getattr to import modules (i.e. y = getattr(x, 'y')) and changing this to import x.y as y when it can determine that y is a module. This doesn't work when we don't know that y is a module. I changed the from import handling to use something similar to the method shown in the __import__ docs. I also removed the special casing of from imports for modules (from x import y where y is a module) mentioned earlier, because these changes make that special casing unnecessary. --- mypyc/irbuild/builder.py | 38 +- mypyc/irbuild/statement.py | 9 +- mypyc/primitives/misc_ops.py | 12 +- mypyc/test-data/irbuild-basic.test | 498 ++++++++++++++------------- mypyc/test-data/irbuild-classes.test | 340 +++++++++--------- mypyc/test-data/run-imports.test | 43 +++ 6 files changed, 522 insertions(+), 418 deletions(-) diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index fca709d85fa4..1411222010a1 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -35,7 +35,7 @@ SetAttr, LoadStatic, InitStatic, NAMESPACE_MODULE, RaiseStandardError ) from mypyc.ir.rtypes import ( - RType, RTuple, RInstance, int_rprimitive, dict_rprimitive, + RType, RTuple, RInstance, c_int_rprimitive, int_rprimitive, dict_rprimitive, none_rprimitive, is_none_rprimitive, object_rprimitive, is_object_rprimitive, str_rprimitive, is_tagged, is_list_rprimitive, is_tuple_rprimitive, c_pyssize_t_rprimitive ) @@ -45,7 +45,9 @@ from mypyc.primitives.list_ops import to_list, list_pop_last, list_get_item_unsafe_op from mypyc.primitives.dict_ops import dict_get_item_op, dict_set_item_op from mypyc.primitives.generic_ops import py_setattr_op, iter_op, next_op -from mypyc.primitives.misc_ops import import_op, check_unpack_count_op, get_module_dict_op +from mypyc.primitives.misc_ops import ( + import_op, check_unpack_count_op, get_module_dict_op, import_extra_args_op +) from mypyc.crash import catch_errors from mypyc.options import CompilerOptions from mypyc.errors import Errors @@ -286,19 +288,45 @@ def add_to_non_ext_dict(self, non_ext: NonExtClassInfo, key_unicode = self.load_str(key) self.call_c(dict_set_item_op, [non_ext.dict, key_unicode, val], line) + def gen_import_from(self, id: str, line: int, imported: List[str]) -> None: + self.imports[id] = None + + globals_dict = self.load_globals_dict() + null = Integer(0, dict_rprimitive, line) + names_to_import = self.new_list_op([self.load_str(name) for name in imported], line) + + level = Integer(0, c_int_rprimitive, line) + value = self.call_c( + import_extra_args_op, + [self.load_str(id), globals_dict, null, names_to_import, level], + line, + ) + self.add(InitStatic(value, id, namespace=NAMESPACE_MODULE)) + def gen_import(self, id: str, line: int) -> None: self.imports[id] = None needs_import, out = BasicBlock(), BasicBlock() - first_load = self.load_module(id) - comparison = self.translate_is_op(first_load, self.none_object(), 'is not', line) - self.add_bool_branch(comparison, out, needs_import) + self.check_if_module_loaded(id, line, needs_import, out) self.activate_block(needs_import) value = self.call_c(import_op, [self.load_str(id)], line) self.add(InitStatic(value, id, namespace=NAMESPACE_MODULE)) self.goto_and_activate(out) + def check_if_module_loaded(self, id: str, line: int, + needs_import: BasicBlock, out: BasicBlock) -> None: + """Generate code that checks if the module `id` has been loaded yet. + + Arguments: + id: name of module to check if imported + line: line number that the import occurs on + needs_import: the BasicBlock that is run if the module has not been loaded yet + out: the BasicBlock that is run if the module has already been loaded""" + first_load = self.load_module(id) + comparison = self.translate_is_op(first_load, self.none_object(), 'is not', line) + self.add_bool_branch(comparison, out, needs_import) + def get_module(self, module: str, line: int) -> Value: # Python 3.7 has a nice 'PyImport_GetModule' function that we can't use :( mod_dict = self.call_c(get_module_dict_op, [], line) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index f33509eb93d9..f37e2455f3e2 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -172,7 +172,8 @@ def transform_import_from(builder: IRBuilder, node: ImportFrom) -> None: id = importlib.util.resolve_name('.' * node.relative + node.id, module_package) - builder.gen_import(id, node.line) + imported = [name for name, _ in node.names] + builder.gen_import_from(id, node.line, imported) module = builder.load_module(id) # Copy everything into our module's dict. @@ -181,12 +182,6 @@ def transform_import_from(builder: IRBuilder, node: ImportFrom) -> None: # This probably doesn't matter much and the code runs basically right. globals = builder.load_globals_dict() for name, maybe_as_name in node.names: - # If one of the things we are importing is a module, - # import it as a module also. - fullname = id + '.' + name - if fullname in builder.graph or fullname in module_state.suppressed: - builder.gen_import(fullname, node.line) - as_name = maybe_as_name or name obj = builder.py_get_attr(module, name, node.line) builder.gen_method_call( diff --git a/mypyc/primitives/misc_ops.py b/mypyc/primitives/misc_ops.py index 3c8f8e4fa9dd..8d18978622ec 100644 --- a/mypyc/primitives/misc_ops.py +++ b/mypyc/primitives/misc_ops.py @@ -3,7 +3,8 @@ from mypyc.ir.ops import ERR_NEVER, ERR_MAGIC, ERR_FALSE from mypyc.ir.rtypes import ( bool_rprimitive, object_rprimitive, str_rprimitive, object_pointer_rprimitive, - int_rprimitive, dict_rprimitive, c_int_rprimitive, bit_rprimitive, c_pyssize_t_rprimitive + int_rprimitive, dict_rprimitive, c_int_rprimitive, bit_rprimitive, c_pyssize_t_rprimitive, + list_rprimitive, ) from mypyc.primitives.registry import ( function_op, custom_op, load_address_op, ERR_NEG_INT @@ -113,6 +114,15 @@ c_function_name='PyImport_Import', error_kind=ERR_MAGIC) +# Import with extra arguments (used in from import handling) +import_extra_args_op = custom_op( + arg_types=[str_rprimitive, dict_rprimitive, dict_rprimitive, + list_rprimitive, c_int_rprimitive], + return_type=object_rprimitive, + c_function_name='PyImport_ImportModuleLevelObject', + error_kind=ERR_MAGIC +) + # Get the sys.modules dictionary get_module_dict_op = custom_op( arg_types=[], diff --git a/mypyc/test-data/irbuild-basic.test b/mypyc/test-data/irbuild-basic.test index 98a86e3f7ee9..bd1df8276f59 100644 --- a/mypyc/test-data/irbuild-basic.test +++ b/mypyc/test-data/irbuild-basic.test @@ -2491,80 +2491,83 @@ def __top_level__(): r0, r1 :: object r2 :: bit r3 :: str - r4, r5, r6 :: object - r7 :: bit - r8 :: str - r9, r10 :: object - r11 :: dict - r12 :: str - r13 :: object + r4 :: object + r5 :: dict + r6, r7, r8 :: str + r9 :: list + r10, r11, r12, r13 :: ptr r14 :: str - r15 :: int32 - r16 :: bit - r17 :: str - r18 :: object - r19 :: str - r20 :: int32 - r21 :: bit - r22 :: str - r23 :: object - r24 :: str - r25 :: int32 - r26 :: bit - r27, r28 :: str + r15, r16 :: object + r17 :: dict + r18 :: str + r19 :: object + r20 :: str + r21 :: int32 + r22 :: bit + r23 :: str + r24 :: object + r25 :: str + r26 :: int32 + r27 :: bit + r28 :: str r29 :: object - r30 :: tuple[str, object] - r31 :: object - r32 :: str - r33 :: object - r34 :: tuple[str, object] + r30 :: str + r31 :: int32 + r32 :: bit + r33, r34 :: str r35 :: object - r36 :: tuple[object, object] + r36 :: tuple[str, object] r37 :: object - r38 :: dict - r39 :: str - r40, r41 :: object - r42 :: dict - r43 :: str - r44 :: int32 - r45 :: bit - r46 :: str - r47 :: dict - r48 :: str - r49, r50, r51 :: object - r52 :: tuple + r38 :: str + r39 :: object + r40 :: tuple[str, object] + r41 :: object + r42 :: tuple[object, object] + r43 :: object + r44 :: dict + r45 :: str + r46, r47 :: object + r48 :: dict + r49 :: str + r50 :: int32 + r51 :: bit + r52 :: str r53 :: dict r54 :: str - r55 :: int32 - r56 :: bit - r57 :: dict - r58 :: str - r59, r60, r61 :: object - r62 :: dict - r63 :: str - r64 :: int32 - r65 :: bit - r66 :: str - r67 :: dict - r68 :: str - r69 :: object - r70 :: dict - r71 :: str - r72, r73 :: object - r74 :: dict - r75 :: str - r76 :: int32 - r77 :: bit - r78 :: list - r79, r80, r81 :: object - r82, r83, r84, r85 :: ptr - r86 :: dict - r87 :: str - r88, r89 :: object - r90 :: dict - r91 :: str - r92 :: int32 - r93 :: bit + r55, r56, r57 :: object + r58 :: tuple + r59 :: dict + r60 :: str + r61 :: int32 + r62 :: bit + r63 :: dict + r64 :: str + r65, r66, r67 :: object + r68 :: dict + r69 :: str + r70 :: int32 + r71 :: bit + r72 :: str + r73 :: dict + r74 :: str + r75 :: object + r76 :: dict + r77 :: str + r78, r79 :: object + r80 :: dict + r81 :: str + r82 :: int32 + r83 :: bit + r84 :: list + r85, r86, r87 :: object + r88, r89, r90, r91 :: ptr + r92 :: dict + r93 :: str + r94, r95 :: object + r96 :: dict + r97 :: str + r98 :: int32 + r99 :: bit L0: r0 = builtins :: module r1 = load_address _Py_NoneStruct @@ -2575,103 +2578,110 @@ L1: r4 = PyImport_Import(r3) builtins = r4 :: module L2: - r5 = typing :: module - r6 = load_address _Py_NoneStruct - r7 = r5 != r6 - if r7 goto L4 else goto L3 :: bool -L3: - r8 = 'typing' - r9 = PyImport_Import(r8) - typing = r9 :: module -L4: - r10 = typing :: module - r11 = __main__.globals :: static - r12 = 'List' - r13 = CPyObject_GetAttr(r10, r12) - r14 = 'List' - r15 = CPyDict_SetItem(r11, r14, r13) - r16 = r15 >= 0 :: signed - r17 = 'NewType' - r18 = CPyObject_GetAttr(r10, r17) - r19 = 'NewType' - r20 = CPyDict_SetItem(r11, r19, r18) - r21 = r20 >= 0 :: signed - r22 = 'NamedTuple' - r23 = CPyObject_GetAttr(r10, r22) - r24 = 'NamedTuple' - r25 = CPyDict_SetItem(r11, r24, r23) - r26 = r25 >= 0 :: signed - r27 = 'Lol' - r28 = 'a' - r29 = load_address PyLong_Type - r30 = (r28, r29) - r31 = box(tuple[str, object], r30) - r32 = 'b' - r33 = load_address PyUnicode_Type - r34 = (r32, r33) - r35 = box(tuple[str, object], r34) - r36 = (r31, r35) - r37 = box(tuple[object, object], r36) - r38 = __main__.globals :: static - r39 = 'NamedTuple' - r40 = CPyDict_GetItem(r38, r39) - r41 = PyObject_CallFunctionObjArgs(r40, r27, r37, 0) - r42 = __main__.globals :: static - r43 = 'Lol' - r44 = CPyDict_SetItem(r42, r43, r41) - r45 = r44 >= 0 :: signed - r46 = '' - r47 = __main__.globals :: static - r48 = 'Lol' - r49 = CPyDict_GetItem(r47, r48) - r50 = box(short_int, 2) - r51 = PyObject_CallFunctionObjArgs(r49, r50, r46, 0) - r52 = cast(tuple, r51) + r5 = __main__.globals :: static + r6 = 'List' + r7 = 'NewType' + r8 = 'NamedTuple' + r9 = PyList_New(3) + r10 = get_element_ptr r9 ob_item :: PyListObject + r11 = load_mem r10 :: ptr* + set_mem r11, r6 :: builtins.object* + r12 = r11 + WORD_SIZE*1 + set_mem r12, r7 :: builtins.object* + r13 = r11 + WORD_SIZE*2 + set_mem r13, r8 :: builtins.object* + keep_alive r9 + r14 = 'typing' + r15 = PyImport_ImportModuleLevelObject(r14, r5, 0, r9, 0) + typing = r15 :: module + r16 = typing :: module + r17 = __main__.globals :: static + r18 = 'List' + r19 = CPyObject_GetAttr(r16, r18) + r20 = 'List' + r21 = CPyDict_SetItem(r17, r20, r19) + r22 = r21 >= 0 :: signed + r23 = 'NewType' + r24 = CPyObject_GetAttr(r16, r23) + r25 = 'NewType' + r26 = CPyDict_SetItem(r17, r25, r24) + r27 = r26 >= 0 :: signed + r28 = 'NamedTuple' + r29 = CPyObject_GetAttr(r16, r28) + r30 = 'NamedTuple' + r31 = CPyDict_SetItem(r17, r30, r29) + r32 = r31 >= 0 :: signed + r33 = 'Lol' + r34 = 'a' + r35 = load_address PyLong_Type + r36 = (r34, r35) + r37 = box(tuple[str, object], r36) + r38 = 'b' + r39 = load_address PyUnicode_Type + r40 = (r38, r39) + r41 = box(tuple[str, object], r40) + r42 = (r37, r41) + r43 = box(tuple[object, object], r42) + r44 = __main__.globals :: static + r45 = 'NamedTuple' + r46 = CPyDict_GetItem(r44, r45) + r47 = PyObject_CallFunctionObjArgs(r46, r33, r43, 0) + r48 = __main__.globals :: static + r49 = 'Lol' + r50 = CPyDict_SetItem(r48, r49, r47) + r51 = r50 >= 0 :: signed + r52 = '' r53 = __main__.globals :: static - r54 = 'x' - r55 = CPyDict_SetItem(r53, r54, r52) - r56 = r55 >= 0 :: signed - r57 = __main__.globals :: static - r58 = 'List' - r59 = CPyDict_GetItem(r57, r58) - r60 = load_address PyLong_Type - r61 = PyObject_GetItem(r59, r60) - r62 = __main__.globals :: static - r63 = 'Foo' - r64 = CPyDict_SetItem(r62, r63, r61) - r65 = r64 >= 0 :: signed - r66 = 'Bar' - r67 = __main__.globals :: static - r68 = 'Foo' - r69 = CPyDict_GetItem(r67, r68) - r70 = __main__.globals :: static - r71 = 'NewType' - r72 = CPyDict_GetItem(r70, r71) - r73 = PyObject_CallFunctionObjArgs(r72, r66, r69, 0) - r74 = __main__.globals :: static - r75 = 'Bar' - r76 = CPyDict_SetItem(r74, r75, r73) - r77 = r76 >= 0 :: signed - r78 = PyList_New(3) - r79 = box(short_int, 2) - r80 = box(short_int, 4) - r81 = box(short_int, 6) - r82 = get_element_ptr r78 ob_item :: PyListObject - r83 = load_mem r82 :: ptr* - set_mem r83, r79 :: builtins.object* - r84 = r83 + WORD_SIZE*1 - set_mem r84, r80 :: builtins.object* - r85 = r83 + WORD_SIZE*2 - set_mem r85, r81 :: builtins.object* - keep_alive r78 - r86 = __main__.globals :: static - r87 = 'Bar' - r88 = CPyDict_GetItem(r86, r87) - r89 = PyObject_CallFunctionObjArgs(r88, r78, 0) - r90 = __main__.globals :: static - r91 = 'y' - r92 = CPyDict_SetItem(r90, r91, r89) - r93 = r92 >= 0 :: signed + r54 = 'Lol' + r55 = CPyDict_GetItem(r53, r54) + r56 = box(short_int, 2) + r57 = PyObject_CallFunctionObjArgs(r55, r56, r52, 0) + r58 = cast(tuple, r57) + r59 = __main__.globals :: static + r60 = 'x' + r61 = CPyDict_SetItem(r59, r60, r58) + r62 = r61 >= 0 :: signed + r63 = __main__.globals :: static + r64 = 'List' + r65 = CPyDict_GetItem(r63, r64) + r66 = load_address PyLong_Type + r67 = PyObject_GetItem(r65, r66) + r68 = __main__.globals :: static + r69 = 'Foo' + r70 = CPyDict_SetItem(r68, r69, r67) + r71 = r70 >= 0 :: signed + r72 = 'Bar' + r73 = __main__.globals :: static + r74 = 'Foo' + r75 = CPyDict_GetItem(r73, r74) + r76 = __main__.globals :: static + r77 = 'NewType' + r78 = CPyDict_GetItem(r76, r77) + r79 = PyObject_CallFunctionObjArgs(r78, r72, r75, 0) + r80 = __main__.globals :: static + r81 = 'Bar' + r82 = CPyDict_SetItem(r80, r81, r79) + r83 = r82 >= 0 :: signed + r84 = PyList_New(3) + r85 = box(short_int, 2) + r86 = box(short_int, 4) + r87 = box(short_int, 6) + r88 = get_element_ptr r84 ob_item :: PyListObject + r89 = load_mem r88 :: ptr* + set_mem r89, r85 :: builtins.object* + r90 = r89 + WORD_SIZE*1 + set_mem r90, r86 :: builtins.object* + r91 = r89 + WORD_SIZE*2 + set_mem r91, r87 :: builtins.object* + keep_alive r84 + r92 = __main__.globals :: static + r93 = 'Bar' + r94 = CPyDict_GetItem(r92, r93) + r95 = PyObject_CallFunctionObjArgs(r94, r84, 0) + r96 = __main__.globals :: static + r97 = 'y' + r98 = CPyDict_SetItem(r96, r97, r95) + r99 = r98 >= 0 :: signed return 1 [case testChainedConditional] @@ -2987,29 +2997,32 @@ def __top_level__(): r0, r1 :: object r2 :: bit r3 :: str - r4, r5, r6 :: object - r7 :: bit - r8 :: str - r9, r10 :: object - r11 :: dict - r12 :: str - r13 :: object + r4 :: object + r5 :: dict + r6 :: str + r7 :: list + r8, r9 :: ptr + r10 :: str + r11, r12 :: object + r13 :: dict r14 :: str - r15 :: int32 - r16 :: bit - r17 :: dict - r18 :: str - r19 :: object - r20 :: dict - r21 :: str - r22, r23 :: object - r24 :: dict - r25 :: str - r26, r27 :: object - r28 :: dict - r29 :: str - r30 :: int32 - r31 :: bit + r15 :: object + r16 :: str + r17 :: int32 + r18 :: bit + r19 :: dict + r20 :: str + r21 :: object + r22 :: dict + r23 :: str + r24, r25 :: object + r26 :: dict + r27 :: str + r28, r29 :: object + r30 :: dict + r31 :: str + r32 :: int32 + r33 :: bit L0: r0 = builtins :: module r1 = load_address _Py_NoneStruct @@ -3020,37 +3033,38 @@ L1: r4 = PyImport_Import(r3) builtins = r4 :: module L2: - r5 = typing :: module - r6 = load_address _Py_NoneStruct - r7 = r5 != r6 - if r7 goto L4 else goto L3 :: bool -L3: - r8 = 'typing' - r9 = PyImport_Import(r8) - typing = r9 :: module -L4: - r10 = typing :: module - r11 = __main__.globals :: static - r12 = 'Callable' - r13 = CPyObject_GetAttr(r10, r12) + r5 = __main__.globals :: static + r6 = 'Callable' + r7 = PyList_New(1) + r8 = get_element_ptr r7 ob_item :: PyListObject + r9 = load_mem r8 :: ptr* + set_mem r9, r6 :: builtins.object* + keep_alive r7 + r10 = 'typing' + r11 = PyImport_ImportModuleLevelObject(r10, r5, 0, r7, 0) + typing = r11 :: module + r12 = typing :: module + r13 = __main__.globals :: static r14 = 'Callable' - r15 = CPyDict_SetItem(r11, r14, r13) - r16 = r15 >= 0 :: signed - r17 = __main__.globals :: static - r18 = '__mypyc_c_decorator_helper__' - r19 = CPyDict_GetItem(r17, r18) - r20 = __main__.globals :: static - r21 = 'b' - r22 = CPyDict_GetItem(r20, r21) - r23 = PyObject_CallFunctionObjArgs(r22, r19, 0) - r24 = __main__.globals :: static - r25 = 'a' - r26 = CPyDict_GetItem(r24, r25) - r27 = PyObject_CallFunctionObjArgs(r26, r23, 0) - r28 = __main__.globals :: static - r29 = 'c' - r30 = CPyDict_SetItem(r28, r29, r27) - r31 = r30 >= 0 :: signed + r15 = CPyObject_GetAttr(r12, r14) + r16 = 'Callable' + r17 = CPyDict_SetItem(r13, r16, r15) + r18 = r17 >= 0 :: signed + r19 = __main__.globals :: static + r20 = '__mypyc_c_decorator_helper__' + r21 = CPyDict_GetItem(r19, r20) + r22 = __main__.globals :: static + r23 = 'b' + r24 = CPyDict_GetItem(r22, r23) + r25 = PyObject_CallFunctionObjArgs(r24, r21, 0) + r26 = __main__.globals :: static + r27 = 'a' + r28 = CPyDict_GetItem(r26, r27) + r29 = PyObject_CallFunctionObjArgs(r28, r25, 0) + r30 = __main__.globals :: static + r31 = 'c' + r32 = CPyDict_SetItem(r30, r31, r29) + r33 = r32 >= 0 :: signed return 1 [case testDecoratorsSimple_toplevel] @@ -3125,16 +3139,19 @@ def __top_level__(): r0, r1 :: object r2 :: bit r3 :: str - r4, r5, r6 :: object - r7 :: bit - r8 :: str - r9, r10 :: object - r11 :: dict - r12 :: str - r13 :: object + r4 :: object + r5 :: dict + r6 :: str + r7 :: list + r8, r9 :: ptr + r10 :: str + r11, r12 :: object + r13 :: dict r14 :: str - r15 :: int32 - r16 :: bit + r15 :: object + r16 :: str + r17 :: int32 + r18 :: bit L0: r0 = builtins :: module r1 = load_address _Py_NoneStruct @@ -3145,22 +3162,23 @@ L1: r4 = PyImport_Import(r3) builtins = r4 :: module L2: - r5 = typing :: module - r6 = load_address _Py_NoneStruct - r7 = r5 != r6 - if r7 goto L4 else goto L3 :: bool -L3: - r8 = 'typing' - r9 = PyImport_Import(r8) - typing = r9 :: module -L4: - r10 = typing :: module - r11 = __main__.globals :: static - r12 = 'Callable' - r13 = CPyObject_GetAttr(r10, r12) + r5 = __main__.globals :: static + r6 = 'Callable' + r7 = PyList_New(1) + r8 = get_element_ptr r7 ob_item :: PyListObject + r9 = load_mem r8 :: ptr* + set_mem r9, r6 :: builtins.object* + keep_alive r7 + r10 = 'typing' + r11 = PyImport_ImportModuleLevelObject(r10, r5, 0, r7, 0) + typing = r11 :: module + r12 = typing :: module + r13 = __main__.globals :: static r14 = 'Callable' - r15 = CPyDict_SetItem(r11, r14, r13) - r16 = r15 >= 0 :: signed + r15 = CPyObject_GetAttr(r12, r14) + r16 = 'Callable' + r17 = CPyDict_SetItem(r13, r16, r15) + r18 = r17 >= 0 :: signed return 1 [case testAnyAllG] diff --git a/mypyc/test-data/irbuild-classes.test b/mypyc/test-data/irbuild-classes.test index 1db7cba6f249..5962c2498768 100644 --- a/mypyc/test-data/irbuild-classes.test +++ b/mypyc/test-data/irbuild-classes.test @@ -290,81 +290,86 @@ def __top_level__(): r0, r1 :: object r2 :: bit r3 :: str - r4, r5, r6 :: object - r7 :: bit - r8 :: str - r9, r10 :: object - r11 :: dict + r4 :: object + r5 :: dict + r6, r7 :: str + r8 :: list + r9, r10, r11 :: ptr r12 :: str - r13 :: object - r14 :: str - r15 :: int32 - r16 :: bit - r17 :: str - r18 :: object - r19 :: str - r20 :: int32 - r21 :: bit - r22, r23 :: object - r24 :: bit - r25 :: str - r26, r27 :: object - r28 :: dict - r29 :: str - r30 :: object + r13, r14 :: object + r15 :: dict + r16 :: str + r17 :: object + r18 :: str + r19 :: int32 + r20 :: bit + r21 :: str + r22 :: object + r23 :: str + r24 :: int32 + r25 :: bit + r26 :: dict + r27 :: str + r28 :: list + r29, r30 :: ptr r31 :: str - r32 :: int32 - r33 :: bit - r34 :: str - r35 :: dict - r36 :: str - r37, r38 :: object - r39 :: dict + r32, r33 :: object + r34 :: dict + r35 :: str + r36 :: object + r37 :: str + r38 :: int32 + r39 :: bit r40 :: str - r41 :: int32 - r42 :: bit - r43 :: object - r44 :: str - r45, r46 :: object - r47 :: bool - r48 :: str - r49 :: tuple - r50 :: int32 - r51 :: bit - r52 :: dict - r53 :: str - r54 :: int32 - r55 :: bit - r56 :: object - r57 :: str - r58, r59 :: object - r60 :: str - r61 :: tuple - r62 :: int32 - r63 :: bit - r64 :: dict - r65 :: str - r66 :: int32 - r67 :: bit - r68, r69 :: object + r41 :: dict + r42 :: str + r43, r44 :: object + r45 :: dict + r46 :: str + r47 :: int32 + r48 :: bit + r49 :: object + r50 :: str + r51, r52 :: object + r53 :: bool + r54 :: str + r55 :: tuple + r56 :: int32 + r57 :: bit + r58 :: dict + r59 :: str + r60 :: int32 + r61 :: bit + r62 :: object + r63 :: str + r64, r65 :: object + r66 :: str + r67 :: tuple + r68 :: int32 + r69 :: bit r70 :: dict r71 :: str - r72 :: object - r73 :: dict - r74 :: str - r75, r76 :: object - r77 :: tuple - r78 :: str - r79, r80 :: object - r81 :: bool - r82, r83 :: str - r84 :: tuple - r85 :: int32 - r86 :: bit - r87 :: dict - r88 :: str - r89 :: int32 - r90 :: bit + r72 :: int32 + r73 :: bit + r74, r75 :: object + r76 :: dict + r77 :: str + r78 :: object + r79 :: dict + r80 :: str + r81, r82 :: object + r83 :: tuple + r84 :: str + r85, r86 :: object + r87 :: bool + r88, r89 :: str + r90 :: tuple + r91 :: int32 + r92 :: bit + r93 :: dict + r94 :: str + r95 :: int32 + r96 :: bit L0: r0 = builtins :: module r1 = load_address _Py_NoneStruct @@ -375,103 +380,108 @@ L1: r4 = PyImport_Import(r3) builtins = r4 :: module L2: - r5 = typing :: module - r6 = load_address _Py_NoneStruct - r7 = r5 != r6 - if r7 goto L4 else goto L3 :: bool -L3: - r8 = 'typing' - r9 = PyImport_Import(r8) - typing = r9 :: module -L4: - r10 = typing :: module - r11 = __main__.globals :: static - r12 = 'TypeVar' - r13 = CPyObject_GetAttr(r10, r12) - r14 = 'TypeVar' - r15 = CPyDict_SetItem(r11, r14, r13) - r16 = r15 >= 0 :: signed - r17 = 'Generic' - r18 = CPyObject_GetAttr(r10, r17) - r19 = 'Generic' - r20 = CPyDict_SetItem(r11, r19, r18) - r21 = r20 >= 0 :: signed - r22 = mypy_extensions :: module - r23 = load_address _Py_NoneStruct - r24 = r22 != r23 - if r24 goto L6 else goto L5 :: bool -L5: - r25 = 'mypy_extensions' - r26 = PyImport_Import(r25) - mypy_extensions = r26 :: module -L6: - r27 = mypy_extensions :: module - r28 = __main__.globals :: static - r29 = 'trait' - r30 = CPyObject_GetAttr(r27, r29) - r31 = 'trait' - r32 = CPyDict_SetItem(r28, r31, r30) - r33 = r32 >= 0 :: signed - r34 = 'T' - r35 = __main__.globals :: static - r36 = 'TypeVar' - r37 = CPyDict_GetItem(r35, r36) - r38 = PyObject_CallFunctionObjArgs(r37, r34, 0) - r39 = __main__.globals :: static + r5 = __main__.globals :: static + r6 = 'TypeVar' + r7 = 'Generic' + r8 = PyList_New(2) + r9 = get_element_ptr r8 ob_item :: PyListObject + r10 = load_mem r9 :: ptr* + set_mem r10, r6 :: builtins.object* + r11 = r10 + WORD_SIZE*1 + set_mem r11, r7 :: builtins.object* + keep_alive r8 + r12 = 'typing' + r13 = PyImport_ImportModuleLevelObject(r12, r5, 0, r8, 0) + typing = r13 :: module + r14 = typing :: module + r15 = __main__.globals :: static + r16 = 'TypeVar' + r17 = CPyObject_GetAttr(r14, r16) + r18 = 'TypeVar' + r19 = CPyDict_SetItem(r15, r18, r17) + r20 = r19 >= 0 :: signed + r21 = 'Generic' + r22 = CPyObject_GetAttr(r14, r21) + r23 = 'Generic' + r24 = CPyDict_SetItem(r15, r23, r22) + r25 = r24 >= 0 :: signed + r26 = __main__.globals :: static + r27 = 'trait' + r28 = PyList_New(1) + r29 = get_element_ptr r28 ob_item :: PyListObject + r30 = load_mem r29 :: ptr* + set_mem r30, r27 :: builtins.object* + keep_alive r28 + r31 = 'mypy_extensions' + r32 = PyImport_ImportModuleLevelObject(r31, r26, 0, r28, 0) + mypy_extensions = r32 :: module + r33 = mypy_extensions :: module + r34 = __main__.globals :: static + r35 = 'trait' + r36 = CPyObject_GetAttr(r33, r35) + r37 = 'trait' + r38 = CPyDict_SetItem(r34, r37, r36) + r39 = r38 >= 0 :: signed r40 = 'T' - r41 = CPyDict_SetItem(r39, r40, r38) - r42 = r41 >= 0 :: signed - r43 = :: object - r44 = '__main__' - r45 = __main__.C_template :: type - r46 = CPyType_FromTemplate(r45, r43, r44) - r47 = C_trait_vtable_setup() - r48 = '__mypyc_attrs__' - r49 = PyTuple_Pack(0) - r50 = PyObject_SetAttr(r46, r48, r49) - r51 = r50 >= 0 :: signed - __main__.C = r46 :: type - r52 = __main__.globals :: static - r53 = 'C' - r54 = CPyDict_SetItem(r52, r53, r46) - r55 = r54 >= 0 :: signed - r56 = :: object - r57 = '__main__' - r58 = __main__.S_template :: type - r59 = CPyType_FromTemplate(r58, r56, r57) - r60 = '__mypyc_attrs__' - r61 = PyTuple_Pack(0) - r62 = PyObject_SetAttr(r59, r60, r61) - r63 = r62 >= 0 :: signed - __main__.S = r59 :: type - r64 = __main__.globals :: static - r65 = 'S' - r66 = CPyDict_SetItem(r64, r65, r59) - r67 = r66 >= 0 :: signed - r68 = __main__.C :: type - r69 = __main__.S :: type + r41 = __main__.globals :: static + r42 = 'TypeVar' + r43 = CPyDict_GetItem(r41, r42) + r44 = PyObject_CallFunctionObjArgs(r43, r40, 0) + r45 = __main__.globals :: static + r46 = 'T' + r47 = CPyDict_SetItem(r45, r46, r44) + r48 = r47 >= 0 :: signed + r49 = :: object + r50 = '__main__' + r51 = __main__.C_template :: type + r52 = CPyType_FromTemplate(r51, r49, r50) + r53 = C_trait_vtable_setup() + r54 = '__mypyc_attrs__' + r55 = PyTuple_Pack(0) + r56 = PyObject_SetAttr(r52, r54, r55) + r57 = r56 >= 0 :: signed + __main__.C = r52 :: type + r58 = __main__.globals :: static + r59 = 'C' + r60 = CPyDict_SetItem(r58, r59, r52) + r61 = r60 >= 0 :: signed + r62 = :: object + r63 = '__main__' + r64 = __main__.S_template :: type + r65 = CPyType_FromTemplate(r64, r62, r63) + r66 = '__mypyc_attrs__' + r67 = PyTuple_Pack(0) + r68 = PyObject_SetAttr(r65, r66, r67) + r69 = r68 >= 0 :: signed + __main__.S = r65 :: type r70 = __main__.globals :: static - r71 = 'Generic' - r72 = CPyDict_GetItem(r70, r71) - r73 = __main__.globals :: static - r74 = 'T' - r75 = CPyDict_GetItem(r73, r74) - r76 = PyObject_GetItem(r72, r75) - r77 = PyTuple_Pack(3, r68, r69, r76) - r78 = '__main__' - r79 = __main__.D_template :: type - r80 = CPyType_FromTemplate(r79, r77, r78) - r81 = D_trait_vtable_setup() - r82 = '__mypyc_attrs__' - r83 = '__dict__' - r84 = PyTuple_Pack(1, r83) - r85 = PyObject_SetAttr(r80, r82, r84) - r86 = r85 >= 0 :: signed - __main__.D = r80 :: type - r87 = __main__.globals :: static - r88 = 'D' - r89 = CPyDict_SetItem(r87, r88, r80) - r90 = r89 >= 0 :: signed + r71 = 'S' + r72 = CPyDict_SetItem(r70, r71, r65) + r73 = r72 >= 0 :: signed + r74 = __main__.C :: type + r75 = __main__.S :: type + r76 = __main__.globals :: static + r77 = 'Generic' + r78 = CPyDict_GetItem(r76, r77) + r79 = __main__.globals :: static + r80 = 'T' + r81 = CPyDict_GetItem(r79, r80) + r82 = PyObject_GetItem(r78, r81) + r83 = PyTuple_Pack(3, r74, r75, r82) + r84 = '__main__' + r85 = __main__.D_template :: type + r86 = CPyType_FromTemplate(r85, r83, r84) + r87 = D_trait_vtable_setup() + r88 = '__mypyc_attrs__' + r89 = '__dict__' + r90 = PyTuple_Pack(1, r89) + r91 = PyObject_SetAttr(r86, r88, r90) + r92 = r91 >= 0 :: signed + __main__.D = r86 :: type + r93 = __main__.globals :: static + r94 = 'D' + r95 = CPyDict_SetItem(r93, r94, r86) + r96 = r95 >= 0 :: signed return 1 [case testIsInstance] diff --git a/mypyc/test-data/run-imports.test b/mypyc/test-data/run-imports.test index 78b167861ae8..cff07f158190 100644 --- a/mypyc/test-data/run-imports.test +++ b/mypyc/test-data/run-imports.test @@ -86,6 +86,49 @@ def g(x: int) -> int: from native import f assert f(1) == 2 +[case testFromImportWithUntypedModule] + +# avoid including an __init__.py and use type: ignore to test what happens +# if mypy can't tell if mod isn't a module +from pkg import mod # type: ignore + +def test_import() -> None: + assert mod.h(8) == 24 + +[file pkg/mod.py] +def h(x): + return x * 3 + +[case testFromImportWithKnownModule] +from pkg import mod + +def test_import() -> None: + assert mod.h(8) == 24 + +[file pkg/__init__.py] +[file pkg/mod.py] +def h(x: int) -> int: + return x * 3 + +[case testMultipleFromImportsWithSamePackageButDifferentModules] +from pkg import a +from pkg import b + +def test_import() -> None: + assert a.g() == 4 + assert b.h() == 39 + +[file pkg/__init__.py] +[file pkg/a.py] + +def g() -> int: + return 4 + +[file pkg/b.py] + +def h() -> int: + return 39 + [case testReexport] # Test that we properly handle accessing values that have been reexported import a