Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support pickling of extension classes #7481

Merged
merged 3 commits into from Sep 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 3 additions & 8 deletions mypy/types.py
@@ -1,5 +1,6 @@
"""Classes for representing mypy types."""

import copy
import sys
from abc import abstractmethod
from collections import OrderedDict
Expand All @@ -17,7 +18,7 @@
FuncDef,
)
from mypy.sharedparse import argument_elide_name
from mypy.util import IdMapper, replace_object_state
from mypy.util import IdMapper
from mypy.bogus_type import Bogus


Expand Down Expand Up @@ -2077,13 +2078,7 @@ def copy_type(t: TP) -> TP:
"""
Build a copy of the type; used to mutate the copy with truthiness information
"""
# We'd like to just do a copy.copy(), but mypyc types aren't
# pickleable so we hack around it by manually creating a new type
# and copying everything in with replace_object_state.
typ = type(t)
nt = typ.__new__(typ)
replace_object_state(nt, t, copy_dict=True)
return nt
return copy.copy(t)


def function_type(func: mypy.nodes.FuncBase, fallback: Instance) -> FunctionLike:
Expand Down
8 changes: 8 additions & 0 deletions mypyc/emitclass.py
Expand Up @@ -550,6 +550,14 @@ def generate_methods_table(cl: ClassIR,
flags.append('METH_CLASS')

emitter.emit_line(' {}, NULL}},'.format(' | '.join(flags)))

# Provide a default __getstate__ and __setstate__
if not cl.has_method('__setstate__') and not cl.has_method('__getstate__'):
emitter.emit_lines(
'{"__setstate__", (PyCFunction)CPyPickle_SetState, METH_O, NULL},',
'{"__getstate__", (PyCFunction)CPyPickle_GetState, METH_NOARGS, NULL},',
)

emitter.emit_line('{NULL} /* Sentinel */')
emitter.emit_line('};')

Expand Down
24 changes: 19 additions & 5 deletions mypyc/genops.py
Expand Up @@ -531,9 +531,10 @@ def prepare_class_def(path: str, module_name: str, cdef: ClassDef,
ir = mapper.type_to_ir[cdef.info]
info = cdef.info

for name, node in info.names.items():
# We sort the table for determinism here on Python 3.5
for name, node in sorted(info.names.items()):
if isinstance(node.node, Var):
assert node.node.type, "Class member missing type"
assert node.node.type, "Class member %s missing type" % name
if not node.node.is_classvar and name != '__slots__':
ir.attributes[name] = mapper.type_to_rtype(node.node.type)
elif isinstance(node.node, (FuncDef, Decorator)):
Expand Down Expand Up @@ -599,9 +600,8 @@ def prepare_class_def(path: str, module_name: str, cdef: ClassDef,
base_mro.append(base_ir)
mro.append(base_ir)

# Generic and similar are python base classes
if cdef.removed_base_type_exprs:
ir.inherits_python = True
if cls.defn.removed_base_type_exprs or not base_ir.is_ext_class:
ir.inherits_python = True

base_idx = 1 if not ir.is_trait else 0
if len(base_mro) > base_idx:
Expand Down Expand Up @@ -1477,6 +1477,14 @@ def visit_class_def(self, cdef: ClassDef) -> None:
# Set this attribute back to None until the next non-extension class is visited.
self.non_ext_info = None

def create_mypyc_attrs_tuple(self, ir: ClassIR, line: int) -> Value:
attrs = [name for ancestor in ir.mro for name in ancestor.attributes]
if ir.inherits_python:
attrs.append('__dict__')
return self.primitive_op(new_tuple_op,
[self.load_static_unicode(attr) for attr in attrs],
line)

def allocate_class(self, cdef: ClassDef) -> None:
# OK AND NOW THE FUN PART
base_exprs = cdef.base_type_exprs + cdef.removed_base_type_exprs
Expand All @@ -1496,6 +1504,12 @@ def allocate_class(self, cdef: ClassDef) -> None:
FuncDecl(cdef.name + '_trait_vtable_setup',
None, self.module_name,
FuncSignature([], bool_rprimitive)), [], -1))
# Populate a '__mypyc_attrs__' field containing the list of attrs
self.primitive_op(py_setattr_op, [
tp, self.load_static_unicode('__mypyc_attrs__'),
self.create_mypyc_attrs_tuple(self.mapper.type_to_ir[cdef.info], cdef.line)],
cdef.line)

# Save the class
self.add(InitStatic(tp, cdef.name, self.module_name, NAMESPACE_TYPE))

Expand Down
62 changes: 62 additions & 0 deletions mypyc/lib-rt/CPy.h
Expand Up @@ -1333,6 +1333,68 @@ static int CPy_YieldFromErrorHandle(PyObject *iter, PyObject **outp)
return 2;
}

// Support for pickling; reusable getstate and setstate functions
static PyObject *
CPyPickle_SetState(PyObject *obj, PyObject *state)
{
Py_ssize_t pos = 0;
PyObject *key, *value;
while (PyDict_Next(state, &pos, &key, &value)) {
if (PyObject_SetAttr(obj, key, value) != 0) {
return NULL;
}
}
Py_RETURN_NONE;
}

static PyObject *
CPyPickle_GetState(PyObject *obj)
{
PyObject *attrs = NULL, *state = NULL;

attrs = PyObject_GetAttrString((PyObject *)Py_TYPE(obj), "__mypyc_attrs__");
if (!attrs) {
goto fail;
}
if (!PyTuple_Check(attrs)) {
PyErr_SetString(PyExc_TypeError, "__mypyc_attrs__ is not a tuple");
goto fail;
}
state = PyDict_New();
if (!state) {
goto fail;
}

// Collect all the values of attributes in __mypyc_attrs__
// Attributes that are missing we just ignore
int i;
for (i = 0; i < PyTuple_GET_SIZE(attrs); i++) {
PyObject *key = PyTuple_GET_ITEM(attrs, i);
PyObject *value = PyObject_GetAttr(obj, key);
if (!value) {
if (PyErr_ExceptionMatches(PyExc_AttributeError)) {
PyErr_Clear();
continue;
}
goto fail;
}
int result = PyDict_SetItem(state, key, value);
Py_DECREF(value);
if (result != 0) {
goto fail;
}
}

Py_DECREF(attrs);

return state;
fail:
Py_XDECREF(attrs);
Py_XDECREF(state);
return NULL;
}


int CPyArg_ParseTupleAndKeywords(PyObject *, PyObject *,
const char *, char **, ...);

Expand Down
6 changes: 3 additions & 3 deletions mypyc/test-data/genops-basic.test
Expand Up @@ -1588,11 +1588,11 @@ def g(a):
r4 :: str
r5, r6 :: None
L0:
r0 = unicode_3 :: static ('a')
r0 = unicode_4 :: static ('a')
r1 = 0
r2 = a.f(r1, r0)
r3 = 1
r4 = unicode_4 :: static ('b')
r4 = unicode_5 :: static ('b')
r5 = a.f(r3, r4)
r6 = None
return r6
Expand Down Expand Up @@ -1828,7 +1828,7 @@ L1:
L2:
if is_error(z) goto L3 else goto L4
L3:
r1 = unicode_3 :: static ('test')
r1 = unicode_4 :: static ('test')
z = r1
L4:
r2 = None
Expand Down
134 changes: 76 additions & 58 deletions mypyc/test-data/genops-classes.test
Expand Up @@ -94,7 +94,7 @@ def g(a):
r3 :: None
L0:
r0 = 1
r1 = unicode_3 :: static ('hi')
r1 = unicode_4 :: static ('hi')
r2 = a.f(r0, r1)
r3 = None
return r3
Expand Down Expand Up @@ -360,31 +360,40 @@ def __top_level__():
r40 :: str
r41, r42 :: object
r43 :: bool
r44 :: dict
r45 :: str
r44 :: str
r45 :: tuple
r46 :: bool
r47 :: object
r47 :: dict
r48 :: str
r49, r50 :: object
r51 :: bool
r52 :: dict
r53 :: str
r49 :: bool
r50 :: object
r51 :: str
r52, r53 :: object
r54 :: bool
r55, r56 :: object
r57 :: dict
r58 :: str
r59 :: object
r60 :: dict
r61 :: str
r62, r63 :: object
r64 :: tuple
r65 :: str
r66, r67 :: object
r68 :: bool
r69 :: dict
r70 :: str
r71 :: bool
r72 :: None
r55 :: str
r56 :: tuple
r57 :: bool
r58 :: dict
r59 :: str
r60 :: bool
r61, r62 :: object
r63 :: dict
r64 :: str
r65 :: object
r66 :: dict
r67 :: str
r68, r69 :: object
r70 :: tuple
r71 :: str
r72, r73 :: object
r74 :: bool
r75, r76 :: str
r77 :: tuple
r78 :: bool
r79 :: dict
r80 :: str
r81 :: bool
r82 :: None
L0:
r0 = builtins.module :: static
r1 = builtins.None :: object
Expand Down Expand Up @@ -442,39 +451,49 @@ L6:
r41 = __main__.C_template :: type
r42 = pytype_from_template(r41, r39, r40)
r43 = C_trait_vtable_setup()
r44 = unicode_8 :: static ('__mypyc_attrs__')
r45 = () :: tuple
r46 = setattr r42, r44, r45
__main__.C = r42 :: type
r44 = __main__.globals :: static
r45 = unicode_8 :: static ('C')
r46 = r44.__setitem__(r45, r42) :: dict
r47 = <error> :: object
r48 = unicode_7 :: static ('__main__')
r49 = __main__.S_template :: type
r50 = pytype_from_template(r49, r47, r48)
r51 = S_trait_vtable_setup()
__main__.S = r50 :: type
r52 = __main__.globals :: static
r53 = unicode_9 :: static ('S')
r54 = r52.__setitem__(r53, r50) :: dict
r55 = __main__.C :: type
r56 = __main__.S :: type
r57 = __main__.globals :: static
r58 = unicode_3 :: static ('Generic')
r59 = r57[r58] :: dict
r60 = __main__.globals :: static
r61 = unicode_6 :: static ('T')
r62 = r60[r61] :: dict
r63 = r59[r62] :: object
r64 = (r55, r56, r63) :: tuple
r65 = unicode_7 :: static ('__main__')
r66 = __main__.D_template :: type
r67 = pytype_from_template(r66, r64, r65)
r68 = D_trait_vtable_setup()
__main__.D = r67 :: type
r69 = __main__.globals :: static
r70 = unicode_10 :: static ('D')
r71 = r69.__setitem__(r70, r67) :: dict
r72 = None
return r72
r47 = __main__.globals :: static
r48 = unicode_9 :: static ('C')
r49 = r47.__setitem__(r48, r42) :: dict
r50 = <error> :: object
r51 = unicode_7 :: static ('__main__')
r52 = __main__.S_template :: type
r53 = pytype_from_template(r52, r50, r51)
r54 = S_trait_vtable_setup()
r55 = unicode_8 :: static ('__mypyc_attrs__')
r56 = () :: tuple
r57 = setattr r53, r55, r56
__main__.S = r53 :: type
r58 = __main__.globals :: static
r59 = unicode_10 :: static ('S')
r60 = r58.__setitem__(r59, r53) :: dict
r61 = __main__.C :: type
r62 = __main__.S :: type
r63 = __main__.globals :: static
r64 = unicode_3 :: static ('Generic')
r65 = r63[r64] :: dict
r66 = __main__.globals :: static
r67 = unicode_6 :: static ('T')
r68 = r66[r67] :: dict
r69 = r65[r68] :: object
r70 = (r61, r62, r69) :: tuple
r71 = unicode_7 :: static ('__main__')
r72 = __main__.D_template :: type
r73 = pytype_from_template(r72, r70, r71)
r74 = D_trait_vtable_setup()
r75 = unicode_8 :: static ('__mypyc_attrs__')
r76 = unicode_11 :: static ('__dict__')
r77 = (r76) :: tuple
r78 = setattr r73, r75, r77
__main__.D = r73 :: type
r79 = __main__.globals :: static
r80 = unicode_12 :: static ('D')
r81 = r79.__setitem__(r80, r73) :: dict
r82 = None
return r82

[case testIsInstance]
class A: pass
Expand Down Expand Up @@ -785,12 +804,11 @@ def f():
r3 :: int
L0:
r0 = __main__.A :: type
r1 = unicode_5 :: static ('x')
r1 = unicode_6 :: static ('x')
r2 = getattr r0, r1
r3 = unbox(int, r2)
return r3


[case testNoEqDefined]
class A:
pass
Expand Down Expand Up @@ -1014,7 +1032,7 @@ L0:
r0 = 10
__mypyc_self__.x = r0; r1 = is_error
r2 = __main__.globals :: static
r3 = unicode_7 :: static ('LOL')
r3 = unicode_9 :: static ('LOL')
r4 = r2[r3] :: dict
r5 = cast(str, r4)
__mypyc_self__.y = r5; r6 = is_error
Expand Down
2 changes: 1 addition & 1 deletion mypyc/test-data/genops-optional.test
Expand Up @@ -384,7 +384,7 @@ def set(o, s):
r1 :: bool
r2 :: None
L0:
r0 = unicode_6 :: static ('a')
r0 = unicode_5 :: static ('a')
r1 = setattr o, r0, s
r2 = None
return r2
Expand Down