Skip to content

Commit

Permalink
Support pickling of extension classes (#7481)
Browse files Browse the repository at this point in the history
This operates by providing default implementations of `__getstate__`
and `__setstate__` for extension classes. Our implementations work by
storing a `__mypyc_attrs__` tuple in each class that we generate and
collecting all of the attributes in it into a dict.

Fixes #697.
  • Loading branch information
msullivan committed Sep 9, 2019
1 parent 9f1b8e9 commit 88e2b67
Show file tree
Hide file tree
Showing 8 changed files with 250 additions and 75 deletions.
11 changes: 3 additions & 8 deletions mypy/types.py
@@ -1,5 +1,6 @@
"""Classes for representing mypy types."""

import copy
import sys
from abc import abstractmethod
from collections import OrderedDict
Expand All @@ -17,7 +18,7 @@
FuncDef,
)
from mypy.sharedparse import argument_elide_name
from mypy.util import IdMapper, replace_object_state
from mypy.util import IdMapper
from mypy.bogus_type import Bogus


Expand Down Expand Up @@ -2077,13 +2078,7 @@ def copy_type(t: TP) -> TP:
"""
Build a copy of the type; used to mutate the copy with truthiness information
"""
# We'd like to just do a copy.copy(), but mypyc types aren't
# pickleable so we hack around it by manually creating a new type
# and copying everything in with replace_object_state.
typ = type(t)
nt = typ.__new__(typ)
replace_object_state(nt, t, copy_dict=True)
return nt
return copy.copy(t)


def function_type(func: mypy.nodes.FuncBase, fallback: Instance) -> FunctionLike:
Expand Down
8 changes: 8 additions & 0 deletions mypyc/emitclass.py
Expand Up @@ -550,6 +550,14 @@ def generate_methods_table(cl: ClassIR,
flags.append('METH_CLASS')

emitter.emit_line(' {}, NULL}},'.format(' | '.join(flags)))

# Provide a default __getstate__ and __setstate__
if not cl.has_method('__setstate__') and not cl.has_method('__getstate__'):
emitter.emit_lines(
'{"__setstate__", (PyCFunction)CPyPickle_SetState, METH_O, NULL},',
'{"__getstate__", (PyCFunction)CPyPickle_GetState, METH_NOARGS, NULL},',
)

emitter.emit_line('{NULL} /* Sentinel */')
emitter.emit_line('};')

Expand Down
24 changes: 19 additions & 5 deletions mypyc/genops.py
Expand Up @@ -531,9 +531,10 @@ def prepare_class_def(path: str, module_name: str, cdef: ClassDef,
ir = mapper.type_to_ir[cdef.info]
info = cdef.info

for name, node in info.names.items():
# We sort the table for determinism here on Python 3.5
for name, node in sorted(info.names.items()):
if isinstance(node.node, Var):
assert node.node.type, "Class member missing type"
assert node.node.type, "Class member %s missing type" % name
if not node.node.is_classvar and name != '__slots__':
ir.attributes[name] = mapper.type_to_rtype(node.node.type)
elif isinstance(node.node, (FuncDef, Decorator)):
Expand Down Expand Up @@ -599,9 +600,8 @@ def prepare_class_def(path: str, module_name: str, cdef: ClassDef,
base_mro.append(base_ir)
mro.append(base_ir)

# Generic and similar are python base classes
if cdef.removed_base_type_exprs:
ir.inherits_python = True
if cls.defn.removed_base_type_exprs or not base_ir.is_ext_class:
ir.inherits_python = True

base_idx = 1 if not ir.is_trait else 0
if len(base_mro) > base_idx:
Expand Down Expand Up @@ -1477,6 +1477,14 @@ def visit_class_def(self, cdef: ClassDef) -> None:
# Set this attribute back to None until the next non-extension class is visited.
self.non_ext_info = None

def create_mypyc_attrs_tuple(self, ir: ClassIR, line: int) -> Value:
attrs = [name for ancestor in ir.mro for name in ancestor.attributes]
if ir.inherits_python:
attrs.append('__dict__')
return self.primitive_op(new_tuple_op,
[self.load_static_unicode(attr) for attr in attrs],
line)

def allocate_class(self, cdef: ClassDef) -> None:
# OK AND NOW THE FUN PART
base_exprs = cdef.base_type_exprs + cdef.removed_base_type_exprs
Expand All @@ -1496,6 +1504,12 @@ def allocate_class(self, cdef: ClassDef) -> None:
FuncDecl(cdef.name + '_trait_vtable_setup',
None, self.module_name,
FuncSignature([], bool_rprimitive)), [], -1))
# Populate a '__mypyc_attrs__' field containing the list of attrs
self.primitive_op(py_setattr_op, [
tp, self.load_static_unicode('__mypyc_attrs__'),
self.create_mypyc_attrs_tuple(self.mapper.type_to_ir[cdef.info], cdef.line)],
cdef.line)

# Save the class
self.add(InitStatic(tp, cdef.name, self.module_name, NAMESPACE_TYPE))

Expand Down
62 changes: 62 additions & 0 deletions mypyc/lib-rt/CPy.h
Expand Up @@ -1333,6 +1333,68 @@ static int CPy_YieldFromErrorHandle(PyObject *iter, PyObject **outp)
return 2;
}

// Support for pickling; reusable getstate and setstate functions
static PyObject *
CPyPickle_SetState(PyObject *obj, PyObject *state)
{
Py_ssize_t pos = 0;
PyObject *key, *value;
while (PyDict_Next(state, &pos, &key, &value)) {
if (PyObject_SetAttr(obj, key, value) != 0) {
return NULL;
}
}
Py_RETURN_NONE;
}

static PyObject *
CPyPickle_GetState(PyObject *obj)
{
PyObject *attrs = NULL, *state = NULL;

attrs = PyObject_GetAttrString((PyObject *)Py_TYPE(obj), "__mypyc_attrs__");
if (!attrs) {
goto fail;
}
if (!PyTuple_Check(attrs)) {
PyErr_SetString(PyExc_TypeError, "__mypyc_attrs__ is not a tuple");
goto fail;
}
state = PyDict_New();
if (!state) {
goto fail;
}

// Collect all the values of attributes in __mypyc_attrs__
// Attributes that are missing we just ignore
int i;
for (i = 0; i < PyTuple_GET_SIZE(attrs); i++) {
PyObject *key = PyTuple_GET_ITEM(attrs, i);
PyObject *value = PyObject_GetAttr(obj, key);
if (!value) {
if (PyErr_ExceptionMatches(PyExc_AttributeError)) {
PyErr_Clear();
continue;
}
goto fail;
}
int result = PyDict_SetItem(state, key, value);
Py_DECREF(value);
if (result != 0) {
goto fail;
}
}

Py_DECREF(attrs);

return state;
fail:
Py_XDECREF(attrs);
Py_XDECREF(state);
return NULL;
}


int CPyArg_ParseTupleAndKeywords(PyObject *, PyObject *,
const char *, char **, ...);

Expand Down
6 changes: 3 additions & 3 deletions mypyc/test-data/genops-basic.test
Expand Up @@ -1588,11 +1588,11 @@ def g(a):
r4 :: str
r5, r6 :: None
L0:
r0 = unicode_3 :: static ('a')
r0 = unicode_4 :: static ('a')
r1 = 0
r2 = a.f(r1, r0)
r3 = 1
r4 = unicode_4 :: static ('b')
r4 = unicode_5 :: static ('b')
r5 = a.f(r3, r4)
r6 = None
return r6
Expand Down Expand Up @@ -1828,7 +1828,7 @@ L1:
L2:
if is_error(z) goto L3 else goto L4
L3:
r1 = unicode_3 :: static ('test')
r1 = unicode_4 :: static ('test')
z = r1
L4:
r2 = None
Expand Down
134 changes: 76 additions & 58 deletions mypyc/test-data/genops-classes.test
Expand Up @@ -94,7 +94,7 @@ def g(a):
r3 :: None
L0:
r0 = 1
r1 = unicode_3 :: static ('hi')
r1 = unicode_4 :: static ('hi')
r2 = a.f(r0, r1)
r3 = None
return r3
Expand Down Expand Up @@ -360,31 +360,40 @@ def __top_level__():
r40 :: str
r41, r42 :: object
r43 :: bool
r44 :: dict
r45 :: str
r44 :: str
r45 :: tuple
r46 :: bool
r47 :: object
r47 :: dict
r48 :: str
r49, r50 :: object
r51 :: bool
r52 :: dict
r53 :: str
r49 :: bool
r50 :: object
r51 :: str
r52, r53 :: object
r54 :: bool
r55, r56 :: object
r57 :: dict
r58 :: str
r59 :: object
r60 :: dict
r61 :: str
r62, r63 :: object
r64 :: tuple
r65 :: str
r66, r67 :: object
r68 :: bool
r69 :: dict
r70 :: str
r71 :: bool
r72 :: None
r55 :: str
r56 :: tuple
r57 :: bool
r58 :: dict
r59 :: str
r60 :: bool
r61, r62 :: object
r63 :: dict
r64 :: str
r65 :: object
r66 :: dict
r67 :: str
r68, r69 :: object
r70 :: tuple
r71 :: str
r72, r73 :: object
r74 :: bool
r75, r76 :: str
r77 :: tuple
r78 :: bool
r79 :: dict
r80 :: str
r81 :: bool
r82 :: None
L0:
r0 = builtins.module :: static
r1 = builtins.None :: object
Expand Down Expand Up @@ -442,39 +451,49 @@ L6:
r41 = __main__.C_template :: type
r42 = pytype_from_template(r41, r39, r40)
r43 = C_trait_vtable_setup()
r44 = unicode_8 :: static ('__mypyc_attrs__')
r45 = () :: tuple
r46 = setattr r42, r44, r45
__main__.C = r42 :: type
r44 = __main__.globals :: static
r45 = unicode_8 :: static ('C')
r46 = r44.__setitem__(r45, r42) :: dict
r47 = <error> :: object
r48 = unicode_7 :: static ('__main__')
r49 = __main__.S_template :: type
r50 = pytype_from_template(r49, r47, r48)
r51 = S_trait_vtable_setup()
__main__.S = r50 :: type
r52 = __main__.globals :: static
r53 = unicode_9 :: static ('S')
r54 = r52.__setitem__(r53, r50) :: dict
r55 = __main__.C :: type
r56 = __main__.S :: type
r57 = __main__.globals :: static
r58 = unicode_3 :: static ('Generic')
r59 = r57[r58] :: dict
r60 = __main__.globals :: static
r61 = unicode_6 :: static ('T')
r62 = r60[r61] :: dict
r63 = r59[r62] :: object
r64 = (r55, r56, r63) :: tuple
r65 = unicode_7 :: static ('__main__')
r66 = __main__.D_template :: type
r67 = pytype_from_template(r66, r64, r65)
r68 = D_trait_vtable_setup()
__main__.D = r67 :: type
r69 = __main__.globals :: static
r70 = unicode_10 :: static ('D')
r71 = r69.__setitem__(r70, r67) :: dict
r72 = None
return r72
r47 = __main__.globals :: static
r48 = unicode_9 :: static ('C')
r49 = r47.__setitem__(r48, r42) :: dict
r50 = <error> :: object
r51 = unicode_7 :: static ('__main__')
r52 = __main__.S_template :: type
r53 = pytype_from_template(r52, r50, r51)
r54 = S_trait_vtable_setup()
r55 = unicode_8 :: static ('__mypyc_attrs__')
r56 = () :: tuple
r57 = setattr r53, r55, r56
__main__.S = r53 :: type
r58 = __main__.globals :: static
r59 = unicode_10 :: static ('S')
r60 = r58.__setitem__(r59, r53) :: dict
r61 = __main__.C :: type
r62 = __main__.S :: type
r63 = __main__.globals :: static
r64 = unicode_3 :: static ('Generic')
r65 = r63[r64] :: dict
r66 = __main__.globals :: static
r67 = unicode_6 :: static ('T')
r68 = r66[r67] :: dict
r69 = r65[r68] :: object
r70 = (r61, r62, r69) :: tuple
r71 = unicode_7 :: static ('__main__')
r72 = __main__.D_template :: type
r73 = pytype_from_template(r72, r70, r71)
r74 = D_trait_vtable_setup()
r75 = unicode_8 :: static ('__mypyc_attrs__')
r76 = unicode_11 :: static ('__dict__')
r77 = (r76) :: tuple
r78 = setattr r73, r75, r77
__main__.D = r73 :: type
r79 = __main__.globals :: static
r80 = unicode_12 :: static ('D')
r81 = r79.__setitem__(r80, r73) :: dict
r82 = None
return r82

[case testIsInstance]
class A: pass
Expand Down Expand Up @@ -785,12 +804,11 @@ def f():
r3 :: int
L0:
r0 = __main__.A :: type
r1 = unicode_5 :: static ('x')
r1 = unicode_6 :: static ('x')
r2 = getattr r0, r1
r3 = unbox(int, r2)
return r3


[case testNoEqDefined]
class A:
pass
Expand Down Expand Up @@ -1014,7 +1032,7 @@ L0:
r0 = 10
__mypyc_self__.x = r0; r1 = is_error
r2 = __main__.globals :: static
r3 = unicode_7 :: static ('LOL')
r3 = unicode_9 :: static ('LOL')
r4 = r2[r3] :: dict
r5 = cast(str, r4)
__mypyc_self__.y = r5; r6 = is_error
Expand Down
2 changes: 1 addition & 1 deletion mypyc/test-data/genops-optional.test
Expand Up @@ -384,7 +384,7 @@ def set(o, s):
r1 :: bool
r2 :: None
L0:
r0 = unicode_6 :: static ('a')
r0 = unicode_5 :: static ('a')
r1 = setattr o, r0, s
r2 = None
return r2
Expand Down

0 comments on commit 88e2b67

Please sign in to comment.