Skip to content

Commit

Permalink
Add an default keyword argument to dumps
Browse files Browse the repository at this point in the history
dump and dumps functions in python json stdlib have a default keyword argument.
It's useful for serializing complex objects. Supporting this argument will improve compatibility and flexibility of ujson.
  • Loading branch information
garenchan committed Aug 31, 2021
1 parent 8f5ad61 commit ff9d82e
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 4 deletions.
56 changes: 52 additions & 4 deletions python/objToJSON.c
Expand Up @@ -72,6 +72,12 @@ typedef struct __TypeContext

#define GET_TC(__ptrtc) ((TypeContext *)((__ptrtc)->prv))

// If newObj is set, we should use it rather than JSOBJ
#define GET_OBJ(__jsobj, __ptrtc) (GET_TC(__ptrtc)->newObj ? GET_TC(__ptrtc)->newObj : __jsobj)

// Avoid infinite loop caused by the default function
#define DEFAULT_FN_MAX_DEPTH 3

struct PyDictIterState
{
PyObject *keys;
Expand Down Expand Up @@ -432,7 +438,8 @@ static void SetupDictIter(PyObject *dictObj, TypeContext *pc, JSONObjectEncoder

static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObjectEncoder *enc)
{
PyObject *obj, *objRepr, *exc;
PyObject *obj, *objRepr, *exc, *defaultFn, *newObj;
int level = 0;
TypeContext *pc;
PRINTMARK();
if (!_obj)
Expand All @@ -442,6 +449,7 @@ static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObject
}

obj = (PyObject*) _obj;
defaultFn = (PyObject*) enc->prv;

tc->prv = PyObject_Malloc(sizeof(TypeContext));
pc = (TypeContext *) tc->prv;
Expand All @@ -462,6 +470,7 @@ static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObject
pc->longValue = 0;
pc->rawJSONValue = NULL;

BEGIN:
if (PyIter_Check(obj))
{
PRINTMARK();
Expand Down Expand Up @@ -553,7 +562,6 @@ static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObject
return;
}


ISITERABLE:
if (PyDict_Check(obj))
{
Expand Down Expand Up @@ -651,6 +659,30 @@ static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObject
return;
}

DEFAULT:
if (defaultFn)
{
// Break infinite loop
if (level >= DEFAULT_FN_MAX_DEPTH)
{
PRINTMARK();
PyErr_Format(PyExc_TypeError, "maximum recursion depth exceeded");
goto INVALID;
}

newObj = PyObject_CallFunctionObjArgs(defaultFn, obj, NULL);
if (newObj)
{
PRINTMARK();
obj = pc->newObj = newObj;
level += 1;
goto BEGIN;
}
else{
goto INVALID;
}
}

PRINTMARK();
PyErr_Clear();

Expand Down Expand Up @@ -682,33 +714,38 @@ static void Object_endTypeContext(JSOBJ obj, JSONTypeContext *tc)

static const char *Object_getStringValue(JSOBJ obj, JSONTypeContext *tc, size_t *_outLen)
{
obj = GET_OBJ(obj, tc);
return GET_TC(tc)->PyTypeToJSON (obj, tc, NULL, _outLen);
}

static JSINT64 Object_getLongValue(JSOBJ obj, JSONTypeContext *tc)
{
JSINT64 ret;
obj = GET_OBJ(obj, tc);
GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
return ret;
}

static JSUINT64 Object_getUnsignedLongValue(JSOBJ obj, JSONTypeContext *tc)
{
JSUINT64 ret;
obj = GET_OBJ(obj, tc);
GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
return ret;
}

static JSINT32 Object_getIntValue(JSOBJ obj, JSONTypeContext *tc)
{
JSINT32 ret;
obj = GET_OBJ(obj, tc);
GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
return ret;
}

static double Object_getDoubleValue(JSOBJ obj, JSONTypeContext *tc)
{
double ret;
obj = GET_OBJ(obj, tc);
GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
return ret;
}
Expand All @@ -720,27 +757,31 @@ static void Object_releaseObject(JSOBJ _obj)

static int Object_iterNext(JSOBJ obj, JSONTypeContext *tc)
{
obj = GET_OBJ(obj, tc);
return GET_TC(tc)->iterNext(obj, tc);
}

static void Object_iterEnd(JSOBJ obj, JSONTypeContext *tc)
{
obj = GET_OBJ(obj, tc);
GET_TC(tc)->iterEnd(obj, tc);
}

static JSOBJ Object_iterGetValue(JSOBJ obj, JSONTypeContext *tc)
{
obj = GET_OBJ(obj, tc);
return GET_TC(tc)->iterGetValue(obj, tc);
}

static char *Object_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen)
{
obj = GET_OBJ(obj, tc);
return GET_TC(tc)->iterGetName(obj, tc, outLen);
}

PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs)
{
static char *kwlist[] = { "obj", "ensure_ascii", "encode_html_chars", "escape_forward_slashes", "sort_keys", "indent", "allow_nan", "reject_bytes", NULL };
static char *kwlist[] = { "obj", "ensure_ascii", "encode_html_chars", "escape_forward_slashes", "sort_keys", "indent", "allow_nan", "reject_bytes", "default", NULL };

char buffer[65536];
char *ret;
Expand All @@ -751,6 +792,7 @@ PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs)
PyObject *oencodeHTMLChars = NULL;
PyObject *oescapeForwardSlashes = NULL;
PyObject *osortKeys = NULL;
PyObject *odefaultFn = NULL;
int allowNan = -1;
int orejectBytes = -1;

Expand Down Expand Up @@ -785,7 +827,7 @@ PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs)

PRINTMARK();

if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|OOOOiii", kwlist, &oinput, &oensureAscii, &oencodeHTMLChars, &oescapeForwardSlashes, &osortKeys, &encoder.indent, &allowNan, &orejectBytes))
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|OOOOiiiO", kwlist, &oinput, &oensureAscii, &oencodeHTMLChars, &oescapeForwardSlashes, &osortKeys, &encoder.indent, &allowNan, &orejectBytes, &odefaultFn))
{
return NULL;
}
Expand Down Expand Up @@ -815,6 +857,12 @@ PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs)
encoder.allowNan = allowNan;
}

if (odefaultFn != NULL && odefaultFn != Py_None)
{
// Here use prv to store default function
encoder.prv = odefaultFn;
}

if (encoder.allowNan)
{
csInf = "Inf";
Expand Down
23 changes: 23 additions & 0 deletions tests/test_ujson.py
@@ -1,9 +1,11 @@
import datetime
import decimal
import io
import json
import math
import re
import sys
import uuid
from collections import OrderedDict

import pytest
Expand Down Expand Up @@ -828,6 +830,27 @@ def test_encode_none_key():
assert ujson.dumps(data) == '{"null":null}'


def test_default_function():
iso8601_time_format = '%Y-%m-%dT%H:%M:%S.%f'

def default(value):
if isinstance(value, datetime.datetime):
return value.strftime(iso8601_time_format)
elif isinstance(value, uuid.UUID):
return value.hex
return value

dt = datetime.datetime.now()
assert ujson.dumps(dt, default=default) == '"%s"' % dt.strftime(iso8601_time_format)

uuid4 = uuid.uuid4()
assert ujson.dumps(uuid4, default=default) == '"%s"' % uuid4.hex

some_obj = SomeObject()
with pytest.raises(TypeError, match="maximum recursion depth exceeded"):
ujson.dumps(some_obj, default=default)


"""
def test_decode_numeric_int_frc_overflow():
input = "X.Y"
Expand Down

0 comments on commit ff9d82e

Please sign in to comment.