Skip to content

Commit

Permalink
Add a default keyword argument to dumps
Browse files Browse the repository at this point in the history
dump and dumps functions in python json stdlib have a default keyword argument.
It's useful for serializing complex objects. Supporting this argument will improve compatibility and flexibility of ujson.
  • Loading branch information
garenchan committed Sep 1, 2021
1 parent 8f5ad61 commit b2066a7
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 4 deletions.
57 changes: 53 additions & 4 deletions python/objToJSON.c
Expand Up @@ -72,6 +72,12 @@ typedef struct __TypeContext

#define GET_TC(__ptrtc) ((TypeContext *)((__ptrtc)->prv))

// If newObj is set, we should use it rather than JSOBJ
#define GET_OBJ(__jsobj, __ptrtc) (GET_TC(__ptrtc)->newObj ? GET_TC(__ptrtc)->newObj : __jsobj)

// Avoid infinite loop caused by the default function
#define DEFAULT_FN_MAX_DEPTH 3

struct PyDictIterState
{
PyObject *keys;
Expand Down Expand Up @@ -432,7 +438,8 @@ static void SetupDictIter(PyObject *dictObj, TypeContext *pc, JSONObjectEncoder

static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObjectEncoder *enc)
{
PyObject *obj, *objRepr, *exc;
PyObject *obj, *objRepr, *exc, *defaultFn, *newObj;
int level = 0;
TypeContext *pc;
PRINTMARK();
if (!_obj)
Expand All @@ -442,6 +449,7 @@ static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObject
}

obj = (PyObject*) _obj;
defaultFn = (PyObject*) enc->prv;

tc->prv = PyObject_Malloc(sizeof(TypeContext));
pc = (TypeContext *) tc->prv;
Expand All @@ -462,6 +470,7 @@ static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObject
pc->longValue = 0;
pc->rawJSONValue = NULL;

BEGIN:
if (PyIter_Check(obj))
{
PRINTMARK();
Expand Down Expand Up @@ -553,7 +562,6 @@ static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObject
return;
}


ISITERABLE:
if (PyDict_Check(obj))
{
Expand Down Expand Up @@ -651,6 +659,31 @@ static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObject
return;
}

DEFAULT:
if (defaultFn)
{
// Break infinite loop
if (level >= DEFAULT_FN_MAX_DEPTH)
{
PRINTMARK();
PyErr_Format(PyExc_TypeError, "maximum recursion depth exceeded");
goto INVALID;
}

newObj = PyObject_CallFunctionObjArgs(defaultFn, obj, NULL);
if (newObj)
{
PRINTMARK();
obj = pc->newObj = newObj;
level += 1;
goto BEGIN;
}
else
{
goto INVALID;
}
}

PRINTMARK();
PyErr_Clear();

Expand Down Expand Up @@ -682,33 +715,38 @@ static void Object_endTypeContext(JSOBJ obj, JSONTypeContext *tc)

static const char *Object_getStringValue(JSOBJ obj, JSONTypeContext *tc, size_t *_outLen)
{
obj = GET_OBJ(obj, tc);
return GET_TC(tc)->PyTypeToJSON (obj, tc, NULL, _outLen);
}

static JSINT64 Object_getLongValue(JSOBJ obj, JSONTypeContext *tc)
{
JSINT64 ret;
obj = GET_OBJ(obj, tc);
GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
return ret;
}

static JSUINT64 Object_getUnsignedLongValue(JSOBJ obj, JSONTypeContext *tc)
{
JSUINT64 ret;
obj = GET_OBJ(obj, tc);
GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
return ret;
}

static JSINT32 Object_getIntValue(JSOBJ obj, JSONTypeContext *tc)
{
JSINT32 ret;
obj = GET_OBJ(obj, tc);
GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
return ret;
}

static double Object_getDoubleValue(JSOBJ obj, JSONTypeContext *tc)
{
double ret;
obj = GET_OBJ(obj, tc);
GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
return ret;
}
Expand All @@ -720,27 +758,31 @@ static void Object_releaseObject(JSOBJ _obj)

static int Object_iterNext(JSOBJ obj, JSONTypeContext *tc)
{
obj = GET_OBJ(obj, tc);
return GET_TC(tc)->iterNext(obj, tc);
}

static void Object_iterEnd(JSOBJ obj, JSONTypeContext *tc)
{
obj = GET_OBJ(obj, tc);
GET_TC(tc)->iterEnd(obj, tc);
}

static JSOBJ Object_iterGetValue(JSOBJ obj, JSONTypeContext *tc)
{
obj = GET_OBJ(obj, tc);
return GET_TC(tc)->iterGetValue(obj, tc);
}

static char *Object_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen)
{
obj = GET_OBJ(obj, tc);
return GET_TC(tc)->iterGetName(obj, tc, outLen);
}

PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs)
{
static char *kwlist[] = { "obj", "ensure_ascii", "encode_html_chars", "escape_forward_slashes", "sort_keys", "indent", "allow_nan", "reject_bytes", NULL };
static char *kwlist[] = { "obj", "ensure_ascii", "encode_html_chars", "escape_forward_slashes", "sort_keys", "indent", "allow_nan", "reject_bytes", "default", NULL };

char buffer[65536];
char *ret;
Expand All @@ -751,6 +793,7 @@ PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs)
PyObject *oencodeHTMLChars = NULL;
PyObject *oescapeForwardSlashes = NULL;
PyObject *osortKeys = NULL;
PyObject *odefaultFn = NULL;
int allowNan = -1;
int orejectBytes = -1;

Expand Down Expand Up @@ -785,7 +828,7 @@ PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs)

PRINTMARK();

if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|OOOOiii", kwlist, &oinput, &oensureAscii, &oencodeHTMLChars, &oescapeForwardSlashes, &osortKeys, &encoder.indent, &allowNan, &orejectBytes))
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|OOOOiiiO", kwlist, &oinput, &oensureAscii, &oencodeHTMLChars, &oescapeForwardSlashes, &osortKeys, &encoder.indent, &allowNan, &orejectBytes, &odefaultFn))
{
return NULL;
}
Expand Down Expand Up @@ -815,6 +858,12 @@ PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs)
encoder.allowNan = allowNan;
}

if (odefaultFn != NULL && odefaultFn != Py_None)
{
// Here use prv to store default function
encoder.prv = odefaultFn;
}

if (encoder.allowNan)
{
csInf = "Inf";
Expand Down
35 changes: 35 additions & 0 deletions tests/test_ujson.py
@@ -1,9 +1,11 @@
import datetime
import decimal
import io
import json
import math
import re
import sys
import uuid
from collections import OrderedDict

import pytest
Expand Down Expand Up @@ -828,6 +830,39 @@ def test_encode_none_key():
assert ujson.dumps(data) == '{"null":null}'


def test_default_function():
iso8601_time_format = "%Y-%m-%dT%H:%M:%S.%f"

class CustomObject:
pass

class UnjsonableObject:
pass

def default(value):
if isinstance(value, datetime.datetime):
return value.strftime(iso8601_time_format)
elif isinstance(value, uuid.UUID):
return value.hex
elif isinstance(value, CustomObject):
raise ValueError("invalid value")
return value

dt = datetime.datetime.now()
assert ujson.dumps(dt, default=default) == '"%s"' % dt.strftime(iso8601_time_format)

uuid4 = uuid.uuid4()
assert ujson.dumps(uuid4, default=default) == '"%s"' % uuid4.hex

custom_obj = CustomObject()
with pytest.raises(ValueError, match="invalid value"):
ujson.dumps(custom_obj, default=default)

unjsonable_obj = UnjsonableObject()
with pytest.raises(TypeError, match="maximum recursion depth exceeded"):
ujson.dumps(unjsonable_obj, default=default)


"""
def test_decode_numeric_int_frc_overflow():
input = "X.Y"
Expand Down

0 comments on commit b2066a7

Please sign in to comment.