diff --git a/mypy/types.py b/mypy/types.py index 796bc8025437..fb84152008b4 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1,5 +1,6 @@ """Classes for representing mypy types.""" +import copy import sys from abc import abstractmethod from collections import OrderedDict @@ -17,7 +18,7 @@ FuncDef, ) from mypy.sharedparse import argument_elide_name -from mypy.util import IdMapper, replace_object_state +from mypy.util import IdMapper from mypy.bogus_type import Bogus @@ -2077,13 +2078,7 @@ def copy_type(t: TP) -> TP: """ Build a copy of the type; used to mutate the copy with truthiness information """ - # We'd like to just do a copy.copy(), but mypyc types aren't - # pickleable so we hack around it by manually creating a new type - # and copying everything in with replace_object_state. - typ = type(t) - nt = typ.__new__(typ) - replace_object_state(nt, t, copy_dict=True) - return nt + return copy.copy(t) def function_type(func: mypy.nodes.FuncBase, fallback: Instance) -> FunctionLike: diff --git a/mypyc/emitclass.py b/mypyc/emitclass.py index dac16ea64c8b..542e3ecd68d5 100644 --- a/mypyc/emitclass.py +++ b/mypyc/emitclass.py @@ -550,6 +550,14 @@ def generate_methods_table(cl: ClassIR, flags.append('METH_CLASS') emitter.emit_line(' {}, NULL}},'.format(' | '.join(flags))) + + # Provide a default __getstate__ and __setstate__ + if not cl.has_method('__setstate__') and not cl.has_method('__getstate__'): + emitter.emit_lines( + '{"__setstate__", (PyCFunction)CPyPickle_SetState, METH_O, NULL},', + '{"__getstate__", (PyCFunction)CPyPickle_GetState, METH_NOARGS, NULL},', + ) + emitter.emit_line('{NULL} /* Sentinel */') emitter.emit_line('};') diff --git a/mypyc/genops.py b/mypyc/genops.py index 09528b77ee69..26246cef0e4e 100644 --- a/mypyc/genops.py +++ b/mypyc/genops.py @@ -531,9 +531,10 @@ def prepare_class_def(path: str, module_name: str, cdef: ClassDef, ir = mapper.type_to_ir[cdef.info] info = cdef.info - for name, node in info.names.items(): + # We sort the table for determinism here on Python 3.5 + for name, node in sorted(info.names.items()): if isinstance(node.node, Var): - assert node.node.type, "Class member missing type" + assert node.node.type, "Class member %s missing type" % name if not node.node.is_classvar and name != '__slots__': ir.attributes[name] = mapper.type_to_rtype(node.node.type) elif isinstance(node.node, (FuncDef, Decorator)): @@ -599,9 +600,8 @@ def prepare_class_def(path: str, module_name: str, cdef: ClassDef, base_mro.append(base_ir) mro.append(base_ir) - # Generic and similar are python base classes - if cdef.removed_base_type_exprs: - ir.inherits_python = True + if cls.defn.removed_base_type_exprs or not base_ir.is_ext_class: + ir.inherits_python = True base_idx = 1 if not ir.is_trait else 0 if len(base_mro) > base_idx: @@ -1477,6 +1477,14 @@ def visit_class_def(self, cdef: ClassDef) -> None: # Set this attribute back to None until the next non-extension class is visited. self.non_ext_info = None + def create_mypyc_attrs_tuple(self, ir: ClassIR, line: int) -> Value: + attrs = [name for ancestor in ir.mro for name in ancestor.attributes] + if ir.inherits_python: + attrs.append('__dict__') + return self.primitive_op(new_tuple_op, + [self.load_static_unicode(attr) for attr in attrs], + line) + def allocate_class(self, cdef: ClassDef) -> None: # OK AND NOW THE FUN PART base_exprs = cdef.base_type_exprs + cdef.removed_base_type_exprs @@ -1496,6 +1504,12 @@ def allocate_class(self, cdef: ClassDef) -> None: FuncDecl(cdef.name + '_trait_vtable_setup', None, self.module_name, FuncSignature([], bool_rprimitive)), [], -1)) + # Populate a '__mypyc_attrs__' field containing the list of attrs + self.primitive_op(py_setattr_op, [ + tp, self.load_static_unicode('__mypyc_attrs__'), + self.create_mypyc_attrs_tuple(self.mapper.type_to_ir[cdef.info], cdef.line)], + cdef.line) + # Save the class self.add(InitStatic(tp, cdef.name, self.module_name, NAMESPACE_TYPE)) diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index c01953511011..da4deeb5a727 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -1333,6 +1333,68 @@ static int CPy_YieldFromErrorHandle(PyObject *iter, PyObject **outp) return 2; } +// Support for pickling; reusable getstate and setstate functions +static PyObject * +CPyPickle_SetState(PyObject *obj, PyObject *state) +{ + Py_ssize_t pos = 0; + PyObject *key, *value; + while (PyDict_Next(state, &pos, &key, &value)) { + if (PyObject_SetAttr(obj, key, value) != 0) { + return NULL; + } + } + Py_RETURN_NONE; +} + +static PyObject * +CPyPickle_GetState(PyObject *obj) +{ + PyObject *attrs = NULL, *state = NULL; + + attrs = PyObject_GetAttrString((PyObject *)Py_TYPE(obj), "__mypyc_attrs__"); + if (!attrs) { + goto fail; + } + if (!PyTuple_Check(attrs)) { + PyErr_SetString(PyExc_TypeError, "__mypyc_attrs__ is not a tuple"); + goto fail; + } + state = PyDict_New(); + if (!state) { + goto fail; + } + + // Collect all the values of attributes in __mypyc_attrs__ + // Attributes that are missing we just ignore + int i; + for (i = 0; i < PyTuple_GET_SIZE(attrs); i++) { + PyObject *key = PyTuple_GET_ITEM(attrs, i); + PyObject *value = PyObject_GetAttr(obj, key); + if (!value) { + if (PyErr_ExceptionMatches(PyExc_AttributeError)) { + PyErr_Clear(); + continue; + } + goto fail; + } + int result = PyDict_SetItem(state, key, value); + Py_DECREF(value); + if (result != 0) { + goto fail; + } + } + + Py_DECREF(attrs); + + return state; +fail: + Py_XDECREF(attrs); + Py_XDECREF(state); + return NULL; +} + + int CPyArg_ParseTupleAndKeywords(PyObject *, PyObject *, const char *, char **, ...); diff --git a/mypyc/test-data/genops-basic.test b/mypyc/test-data/genops-basic.test index 42278d190f14..45ef856fc1ac 100644 --- a/mypyc/test-data/genops-basic.test +++ b/mypyc/test-data/genops-basic.test @@ -1588,11 +1588,11 @@ def g(a): r4 :: str r5, r6 :: None L0: - r0 = unicode_3 :: static ('a') + r0 = unicode_4 :: static ('a') r1 = 0 r2 = a.f(r1, r0) r3 = 1 - r4 = unicode_4 :: static ('b') + r4 = unicode_5 :: static ('b') r5 = a.f(r3, r4) r6 = None return r6 @@ -1828,7 +1828,7 @@ L1: L2: if is_error(z) goto L3 else goto L4 L3: - r1 = unicode_3 :: static ('test') + r1 = unicode_4 :: static ('test') z = r1 L4: r2 = None diff --git a/mypyc/test-data/genops-classes.test b/mypyc/test-data/genops-classes.test index 416925a4a3cb..4c7a745c5e24 100644 --- a/mypyc/test-data/genops-classes.test +++ b/mypyc/test-data/genops-classes.test @@ -94,7 +94,7 @@ def g(a): r3 :: None L0: r0 = 1 - r1 = unicode_3 :: static ('hi') + r1 = unicode_4 :: static ('hi') r2 = a.f(r0, r1) r3 = None return r3 @@ -360,31 +360,40 @@ def __top_level__(): r40 :: str r41, r42 :: object r43 :: bool - r44 :: dict - r45 :: str + r44 :: str + r45 :: tuple r46 :: bool - r47 :: object + r47 :: dict r48 :: str - r49, r50 :: object - r51 :: bool - r52 :: dict - r53 :: str + r49 :: bool + r50 :: object + r51 :: str + r52, r53 :: object r54 :: bool - r55, r56 :: object - r57 :: dict - r58 :: str - r59 :: object - r60 :: dict - r61 :: str - r62, r63 :: object - r64 :: tuple - r65 :: str - r66, r67 :: object - r68 :: bool - r69 :: dict - r70 :: str - r71 :: bool - r72 :: None + r55 :: str + r56 :: tuple + r57 :: bool + r58 :: dict + r59 :: str + r60 :: bool + r61, r62 :: object + r63 :: dict + r64 :: str + r65 :: object + r66 :: dict + r67 :: str + r68, r69 :: object + r70 :: tuple + r71 :: str + r72, r73 :: object + r74 :: bool + r75, r76 :: str + r77 :: tuple + r78 :: bool + r79 :: dict + r80 :: str + r81 :: bool + r82 :: None L0: r0 = builtins.module :: static r1 = builtins.None :: object @@ -442,39 +451,49 @@ L6: r41 = __main__.C_template :: type r42 = pytype_from_template(r41, r39, r40) r43 = C_trait_vtable_setup() + r44 = unicode_8 :: static ('__mypyc_attrs__') + r45 = () :: tuple + r46 = setattr r42, r44, r45 __main__.C = r42 :: type - r44 = __main__.globals :: static - r45 = unicode_8 :: static ('C') - r46 = r44.__setitem__(r45, r42) :: dict - r47 = :: object - r48 = unicode_7 :: static ('__main__') - r49 = __main__.S_template :: type - r50 = pytype_from_template(r49, r47, r48) - r51 = S_trait_vtable_setup() - __main__.S = r50 :: type - r52 = __main__.globals :: static - r53 = unicode_9 :: static ('S') - r54 = r52.__setitem__(r53, r50) :: dict - r55 = __main__.C :: type - r56 = __main__.S :: type - r57 = __main__.globals :: static - r58 = unicode_3 :: static ('Generic') - r59 = r57[r58] :: dict - r60 = __main__.globals :: static - r61 = unicode_6 :: static ('T') - r62 = r60[r61] :: dict - r63 = r59[r62] :: object - r64 = (r55, r56, r63) :: tuple - r65 = unicode_7 :: static ('__main__') - r66 = __main__.D_template :: type - r67 = pytype_from_template(r66, r64, r65) - r68 = D_trait_vtable_setup() - __main__.D = r67 :: type - r69 = __main__.globals :: static - r70 = unicode_10 :: static ('D') - r71 = r69.__setitem__(r70, r67) :: dict - r72 = None - return r72 + r47 = __main__.globals :: static + r48 = unicode_9 :: static ('C') + r49 = r47.__setitem__(r48, r42) :: dict + r50 = :: object + r51 = unicode_7 :: static ('__main__') + r52 = __main__.S_template :: type + r53 = pytype_from_template(r52, r50, r51) + r54 = S_trait_vtable_setup() + r55 = unicode_8 :: static ('__mypyc_attrs__') + r56 = () :: tuple + r57 = setattr r53, r55, r56 + __main__.S = r53 :: type + r58 = __main__.globals :: static + r59 = unicode_10 :: static ('S') + r60 = r58.__setitem__(r59, r53) :: dict + r61 = __main__.C :: type + r62 = __main__.S :: type + r63 = __main__.globals :: static + r64 = unicode_3 :: static ('Generic') + r65 = r63[r64] :: dict + r66 = __main__.globals :: static + r67 = unicode_6 :: static ('T') + r68 = r66[r67] :: dict + r69 = r65[r68] :: object + r70 = (r61, r62, r69) :: tuple + r71 = unicode_7 :: static ('__main__') + r72 = __main__.D_template :: type + r73 = pytype_from_template(r72, r70, r71) + r74 = D_trait_vtable_setup() + r75 = unicode_8 :: static ('__mypyc_attrs__') + r76 = unicode_11 :: static ('__dict__') + r77 = (r76) :: tuple + r78 = setattr r73, r75, r77 + __main__.D = r73 :: type + r79 = __main__.globals :: static + r80 = unicode_12 :: static ('D') + r81 = r79.__setitem__(r80, r73) :: dict + r82 = None + return r82 [case testIsInstance] class A: pass @@ -785,12 +804,11 @@ def f(): r3 :: int L0: r0 = __main__.A :: type - r1 = unicode_5 :: static ('x') + r1 = unicode_6 :: static ('x') r2 = getattr r0, r1 r3 = unbox(int, r2) return r3 - [case testNoEqDefined] class A: pass @@ -1014,7 +1032,7 @@ L0: r0 = 10 __mypyc_self__.x = r0; r1 = is_error r2 = __main__.globals :: static - r3 = unicode_7 :: static ('LOL') + r3 = unicode_9 :: static ('LOL') r4 = r2[r3] :: dict r5 = cast(str, r4) __mypyc_self__.y = r5; r6 = is_error diff --git a/mypyc/test-data/genops-optional.test b/mypyc/test-data/genops-optional.test index c361c8bb3b71..df13146ecc82 100644 --- a/mypyc/test-data/genops-optional.test +++ b/mypyc/test-data/genops-optional.test @@ -384,7 +384,7 @@ def set(o, s): r1 :: bool r2 :: None L0: - r0 = unicode_6 :: static ('a') + r0 = unicode_5 :: static ('a') r1 = setattr o, r0, s r2 = None return r2 diff --git a/mypyc/test-data/run-classes.test b/mypyc/test-data/run-classes.test index 3985a1302d82..a48c63e1ae07 100644 --- a/mypyc/test-data/run-classes.test +++ b/mypyc/test-data/run-classes.test @@ -1031,3 +1031,81 @@ try: import native except TypeError as e: assert(str(e) == "mypyc classes can't have a metaclass") + +[case testPickling] +from mypy_extensions import trait +from typing import Any, TypeVar, Generic + +def dec(x: Any) -> Any: + return x + +class A: + x: int + y: str + +class B(A): + z: bool + + def __init__(self, x: int, y: str, z: bool) -> None: + self.x = x + self.y = y + self.z = z + +@trait +class T: + a: str + +class C(B, T): + w: object + + # property shouldn't go in + @property + def foo(self) -> int: + return 0 + +@dec +class D: + x: int + +class E(D): + y: int + + +U = TypeVar('U') + +class F(Generic[U]): + y: int + +class G(F[int]): + pass + +[file driver.py] +from native import A, B, T, C, D, E, F, G + +import copy +import pickle + +assert A.__mypyc_attrs__ == ('x', 'y') +assert B.__mypyc_attrs__ == ('z', 'x', 'y') +assert T.__mypyc_attrs__ == ('a',) +assert C.__mypyc_attrs__ == ('w', 'z', 'x', 'y', 'a') +assert not hasattr(D, '__mypyc_attrs__') +assert E.__mypyc_attrs__ == ('y', '__dict__') +assert F.__mypyc_attrs__ == ('y', '__dict__') +assert G.__mypyc_attrs__ == ('y', '__dict__') + +b = B(10, '20', False) +assert b.__getstate__() == {'z': False, 'x': 10, 'y': '20'} +b2 = copy.copy(b) +assert b is not b2 and b.y == b2.y + +b3 = pickle.loads(pickle.dumps(b)) +assert b is not b3 and b.y == b3.y + +e = E() +e.x = 10 +e.y = 20 + +assert e.__getstate__() == {'y': 20, '__dict__': {'x': 10}} +e2 = pickle.loads(pickle.dumps(e)) +assert e is not e2 and e.x == e2.x and e.y == e2.y