Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support pickling Record-s #1000

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion asyncpg/protocol/protocol.pyx
Expand Up @@ -1024,4 +1024,4 @@ def _create_record(object mapping, tuple elems):
return rec


Record = <object>record.ApgRecord_InitTypes()
Record, RecordDescriptor = record.ApgRecord_InitTypes()
2 changes: 1 addition & 1 deletion asyncpg/protocol/record/__init__.pxd
Expand Up @@ -10,7 +10,7 @@ cimport cpython

cdef extern from "record/recordobj.h":

cpython.PyTypeObject *ApgRecord_InitTypes() except NULL
tuple ApgRecord_InitTypes()

int ApgRecord_CheckExact(object)
object ApgRecord_New(type, object, int)
Expand Down
169 changes: 142 additions & 27 deletions asyncpg/protocol/record/recordobj.c
Expand Up @@ -20,6 +20,9 @@ static PyObject * record_new_items_iter(PyObject *);
static ApgRecordObject *free_list[ApgRecord_MAXSAVESIZE];
static int numfree[ApgRecord_MAXSAVESIZE];

static PyObject *record_reconstruct_obj;
static PyObject *record_desc_reconstruct_obj;

static size_t MAX_RECORD_SIZE = (
((size_t)PY_SSIZE_T_MAX - sizeof(ApgRecordObject) - sizeof(PyObject *))
/ sizeof(PyObject *)
Expand Down Expand Up @@ -575,14 +578,14 @@ record_repr(ApgRecordObject *v)


static PyObject *
record_values(PyObject *o, PyObject *args)
record_values(PyObject *o, PyObject *Py_UNUSED(unused))
{
return record_iter(o);
}


static PyObject *
record_keys(PyObject *o, PyObject *args)
record_keys(PyObject *o, PyObject *Py_UNUSED(unused))
{
if (!ApgRecord_Check(o)) {
PyErr_BadInternalCall();
Expand All @@ -594,7 +597,7 @@ record_keys(PyObject *o, PyObject *args)


static PyObject *
record_items(PyObject *o, PyObject *args)
record_items(PyObject *o, PyObject *Py_UNUSED(unused))
{
if (!ApgRecord_Check(o)) {
PyErr_BadInternalCall();
Expand Down Expand Up @@ -658,11 +661,69 @@ static PyMappingMethods record_as_mapping = {
};


static PyObject *
record_reduce(ApgRecordObject *o, PyObject *Py_UNUSED(unused))
{
PyObject *value = PyTuple_New(2);
if (value == NULL) {
return NULL;
}
Py_ssize_t len = Py_SIZE(o);
PyObject *state = PyTuple_New(1 + len);
if (state == NULL) {
Py_DECREF(value);
return NULL;
}
PyTuple_SET_ITEM(value, 0, record_reconstruct_obj);
Py_INCREF(record_reconstruct_obj);
PyTuple_SET_ITEM(value, 1, state);
PyTuple_SET_ITEM(state, 0, (PyObject *)o->desc);
Py_INCREF(o->desc);
for (Py_ssize_t i = 0; i < len; i++) {
PyObject *item = ApgRecord_GET_ITEM(o, i);
PyTuple_SET_ITEM(state, i + 1, item);
Py_INCREF(item);
}
return value;
}

static PyObject *
record_reconstruct(PyObject *Py_UNUSED(unused), PyObject *args)
{
if (!PyTuple_CheckExact(args)) {
return NULL;
}
Py_ssize_t len = PyTuple_GET_SIZE(args);
if (len < 2) {
return NULL;
}
len--;
ApgRecordDescObject *desc = (ApgRecordDescObject *)PyTuple_GET_ITEM(args, 0);
if (!ApgRecordDesc_CheckExact(desc)) {
return NULL;
}
if (PyObject_Length(desc->mapping) != len) {
return NULL;
}
PyObject *record = ApgRecord_New(&ApgRecord_Type, (PyObject *)desc, len);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would lose the original record subtype if one was used, so you need to pickle the class reference also.

if (record == NULL) {
return NULL;
}
for (Py_ssize_t i = 0; i < len; i++) {
PyObject *item = PyTuple_GET_ITEM(args, i + 1);
ApgRecord_SET_ITEM(record, i, item);
Py_INCREF(item);
}
return record;
}

static PyMethodDef record_methods[] = {
{"values", (PyCFunction)record_values, METH_NOARGS},
{"keys", (PyCFunction)record_keys, METH_NOARGS},
{"items", (PyCFunction)record_items, METH_NOARGS},
{"get", (PyCFunction)record_get, METH_VARARGS},
{"__reduce__", (PyCFunction)record_reduce, METH_NOARGS},
{"__reconstruct__", (PyCFunction)record_reconstruct, METH_VARARGS | METH_STATIC},
{NULL, NULL} /* sentinel */
};

Expand Down Expand Up @@ -942,29 +1003,6 @@ record_new_items_iter(PyObject *seq)
}


PyTypeObject *
ApgRecord_InitTypes(void)
{
if (PyType_Ready(&ApgRecord_Type) < 0) {
return NULL;
}

if (PyType_Ready(&ApgRecordDesc_Type) < 0) {
return NULL;
}

if (PyType_Ready(&ApgRecordIter_Type) < 0) {
return NULL;
}

if (PyType_Ready(&ApgRecordItems_Type) < 0) {
return NULL;
}

return &ApgRecord_Type;
}


/* ----------------- */


Expand All @@ -987,15 +1025,54 @@ record_desc_traverse(ApgRecordDescObject *o, visitproc visit, void *arg)
}


static PyObject *record_desc_reduce(ApgRecordDescObject *o, PyObject *Py_UNUSED(unused))
{
PyObject *value = PyTuple_New(2);
if (value == NULL) {
return NULL;
}
PyObject *state = PyTuple_New(2);
if (state == NULL) {
Py_DECREF(value);
return NULL;
}
PyTuple_SET_ITEM(value, 0, record_desc_reconstruct_obj);
Py_INCREF(record_desc_reconstruct_obj);
PyTuple_SET_ITEM(value, 1, state);
PyTuple_SET_ITEM(state, 0, o->mapping);
Py_INCREF(o->mapping);
PyTuple_SET_ITEM(state, 1, o->keys);
Py_INCREF(o->keys);
return value;
}


static PyObject *record_desc_reconstruct(PyObject *Py_UNUSED(unused), PyObject *args)
{
if (PyTuple_GET_SIZE(args) != 2) {
return NULL;
}
return ApgRecordDesc_New(PyTuple_GET_ITEM(args, 0), PyTuple_GET_ITEM(args, 1));
}


static PyMethodDef record_desc_methods[] = {
{"__reduce__", (PyCFunction)record_desc_reduce, METH_NOARGS},
{"__reconstruct__", (PyCFunction)record_desc_reconstruct, METH_VARARGS | METH_STATIC},
{NULL, NULL} /* sentinel */
};


PyTypeObject ApgRecordDesc_Type = {
PyVarObject_HEAD_INIT(NULL, 0)
.tp_name = "RecordDescriptor",
.tp_name = "asyncpg.protocol.protocol.RecordDescriptor",
.tp_basicsize = sizeof(ApgRecordDescObject),
.tp_dealloc = (destructor)record_desc_dealloc,
.tp_getattro = PyObject_GenericGetAttr,
.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
.tp_traverse = (traverseproc)record_desc_traverse,
.tp_iter = PyObject_SelfIter,
.tp_methods = record_desc_methods,
};


Expand Down Expand Up @@ -1023,3 +1100,41 @@ ApgRecordDesc_New(PyObject *mapping, PyObject *keys)
PyObject_GC_Track(o);
return (PyObject *) o;
}


PyObject *
ApgRecord_InitTypes(void)
{
if (PyType_Ready(&ApgRecord_Type) < 0) {
return NULL;
}

if (PyType_Ready(&ApgRecordDesc_Type) < 0) {
return NULL;
}

if (PyType_Ready(&ApgRecordIter_Type) < 0) {
return NULL;
}

if (PyType_Ready(&ApgRecordItems_Type) < 0) {
return NULL;
}

record_reconstruct_obj = PyCFunction_New(
&record_methods[5], (PyObject *)&ApgRecord_Type
);
record_desc_reconstruct_obj = PyCFunction_New(
&record_desc_methods[1], (PyObject *)&ApgRecordDesc_Type
);

PyObject *types = PyTuple_New(2);
if (types == NULL) {
return NULL;
}
PyTuple_SET_ITEM(types, 0, (PyObject *)&ApgRecord_Type);
Py_INCREF(&ApgRecord_Type);
PyTuple_SET_ITEM(types, 1, (PyObject *)&ApgRecordDesc_Type);
Py_INCREF(&ApgRecordDesc_Type);
return types;
}
2 changes: 1 addition & 1 deletion asyncpg/protocol/record/recordobj.h
Expand Up @@ -46,7 +46,7 @@ extern PyTypeObject ApgRecordDesc_Type;
#define ApgRecord_GET_ITEM(op, i) \
(((ApgRecordObject *)(op))->ob_item[i])

PyTypeObject *ApgRecord_InitTypes(void);
PyObject *ApgRecord_InitTypes(void);
PyObject *ApgRecord_New(PyTypeObject *, PyObject *, Py_ssize_t);
PyObject *ApgRecordDesc_New(PyObject *, PyObject *);

Expand Down
13 changes: 8 additions & 5 deletions tests/test_record.py
Expand Up @@ -287,11 +287,6 @@ def test_record_get(self):
self.assertEqual(r.get('nonexistent'), None)
self.assertEqual(r.get('nonexistent', 'default'), 'default')

def test_record_not_pickleable(self):
r = Record(R_A, (42,))
with self.assertRaises(Exception):
pickle.dumps(r)

def test_record_empty(self):
r = Record(None, ())
self.assertEqual(r, ())
Expand Down Expand Up @@ -575,3 +570,11 @@ class MyRecordBad:
'record_class is expected to be a subclass of asyncpg.Record',
):
await self.connect(record_class=MyRecordBad)

def test_record_pickle(self):
Copy link

@frake23 frake23 Jul 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i suggest to write a test of pickling nested records

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Postgres doesn't support returning nested records. asyncpg doesn't support nested records anywhere. I don't think it's a good idea.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nested composite types are returned as nested Record instances, e.g:

import asyncio
import asyncpg

async def main():
    conn = await asyncpg.connect()
    await conn.execute('CREATE TYPE complex AS (r float, imag float)')
    print(await conn.fetchrow("SELECT 1, '2', (3, 4)::complex"))

asyncio.run(main())

r = pickle.loads(pickle.dumps(Record(R_AB, (42, 43))))
self.assertEqual(len(r), 2)
self.assertEqual(r[0], 42)
self.assertEqual(r[1], 43)
self.assertEqual(r['a'], 42)
self.assertEqual(r['b'], 43)