From b2066a79a997e41c6d7eeb7f371d13d11e687413 Mon Sep 17 00:00:00 2001 From: garenchan Date: Tue, 31 Aug 2021 17:11:25 +0800 Subject: [PATCH] Add a default keyword argument to dumps The dump and dumps functions in Python's json standard-library module have a default keyword argument. It's useful for serializing complex objects. Supporting this argument will improve the compatibility and flexibility of ujson. --- python/objToJSON.c | 57 +++++++++++++++++++++++++++++++++++++++++---- tests/test_ujson.py | 35 ++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 4 deletions(-) diff --git a/python/objToJSON.c b/python/objToJSON.c index 1297856d..3aae152a 100644 --- a/python/objToJSON.c +++ b/python/objToJSON.c @@ -72,6 +72,12 @@ typedef struct __TypeContext #define GET_TC(__ptrtc) ((TypeContext *)((__ptrtc)->prv)) +// If newObj is set, we should use it rather than JSOBJ +#define GET_OBJ(__jsobj, __ptrtc) (GET_TC(__ptrtc)->newObj ? GET_TC(__ptrtc)->newObj : __jsobj) + +// Avoid infinite loop caused by the default function +#define DEFAULT_FN_MAX_DEPTH 3 + struct PyDictIterState { PyObject *keys; @@ -432,7 +438,8 @@ static void SetupDictIter(PyObject *dictObj, TypeContext *pc, JSONObjectEncoder static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObjectEncoder *enc) { - PyObject *obj, *objRepr, *exc; + PyObject *obj, *objRepr, *exc, *defaultFn, *newObj; + int level = 0; TypeContext *pc; PRINTMARK(); if (!_obj) @@ -442,6 +449,7 @@ static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObject } obj = (PyObject*) _obj; + defaultFn = (PyObject*) enc->prv; tc->prv = PyObject_Malloc(sizeof(TypeContext)); pc = (TypeContext *) tc->prv; @@ -462,6 +470,7 @@ static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObject pc->longValue = 0; pc->rawJSONValue = NULL; +BEGIN: if (PyIter_Check(obj)) { PRINTMARK(); @@ -553,7 +562,6 @@ static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObject return; } - ISITERABLE: if
(PyDict_Check(obj)) { @@ -651,6 +659,31 @@ static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObject return; } +DEFAULT: + if (defaultFn) + { + // Break infinite loop + if (level >= DEFAULT_FN_MAX_DEPTH) + { + PRINTMARK(); + PyErr_Format(PyExc_TypeError, "maximum recursion depth exceeded"); + goto INVALID; + } + + newObj = PyObject_CallFunctionObjArgs(defaultFn, obj, NULL); + if (newObj) + { + PRINTMARK(); + obj = pc->newObj = newObj; + level += 1; + goto BEGIN; + } + else + { + goto INVALID; + } + } + PRINTMARK(); PyErr_Clear(); @@ -682,12 +715,14 @@ static void Object_endTypeContext(JSOBJ obj, JSONTypeContext *tc) static const char *Object_getStringValue(JSOBJ obj, JSONTypeContext *tc, size_t *_outLen) { + obj = GET_OBJ(obj, tc); return GET_TC(tc)->PyTypeToJSON (obj, tc, NULL, _outLen); } static JSINT64 Object_getLongValue(JSOBJ obj, JSONTypeContext *tc) { JSINT64 ret; + obj = GET_OBJ(obj, tc); GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL); return ret; } @@ -695,6 +730,7 @@ static JSINT64 Object_getLongValue(JSOBJ obj, JSONTypeContext *tc) static JSUINT64 Object_getUnsignedLongValue(JSOBJ obj, JSONTypeContext *tc) { JSUINT64 ret; + obj = GET_OBJ(obj, tc); GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL); return ret; } @@ -702,6 +738,7 @@ static JSUINT64 Object_getUnsignedLongValue(JSOBJ obj, JSONTypeContext *tc) static JSINT32 Object_getIntValue(JSOBJ obj, JSONTypeContext *tc) { JSINT32 ret; + obj = GET_OBJ(obj, tc); GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL); return ret; } @@ -709,6 +746,7 @@ static JSINT32 Object_getIntValue(JSOBJ obj, JSONTypeContext *tc) static double Object_getDoubleValue(JSOBJ obj, JSONTypeContext *tc) { double ret; + obj = GET_OBJ(obj, tc); GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL); return ret; } @@ -720,27 +758,31 @@ static void Object_releaseObject(JSOBJ _obj) static int Object_iterNext(JSOBJ obj, JSONTypeContext *tc) { + obj = GET_OBJ(obj, tc); return GET_TC(tc)->iterNext(obj, tc); } static void 
Object_iterEnd(JSOBJ obj, JSONTypeContext *tc) { + obj = GET_OBJ(obj, tc); GET_TC(tc)->iterEnd(obj, tc); } static JSOBJ Object_iterGetValue(JSOBJ obj, JSONTypeContext *tc) { + obj = GET_OBJ(obj, tc); return GET_TC(tc)->iterGetValue(obj, tc); } static char *Object_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen) { + obj = GET_OBJ(obj, tc); return GET_TC(tc)->iterGetName(obj, tc, outLen); } PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs) { - static char *kwlist[] = { "obj", "ensure_ascii", "encode_html_chars", "escape_forward_slashes", "sort_keys", "indent", "allow_nan", "reject_bytes", NULL }; + static char *kwlist[] = { "obj", "ensure_ascii", "encode_html_chars", "escape_forward_slashes", "sort_keys", "indent", "allow_nan", "reject_bytes", "default", NULL }; char buffer[65536]; char *ret; @@ -751,6 +793,7 @@ PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs) PyObject *oencodeHTMLChars = NULL; PyObject *oescapeForwardSlashes = NULL; PyObject *osortKeys = NULL; + PyObject *odefaultFn = NULL; int allowNan = -1; int orejectBytes = -1; @@ -785,7 +828,7 @@ PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs) PRINTMARK(); - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|OOOOiii", kwlist, &oinput, &oensureAscii, &oencodeHTMLChars, &oescapeForwardSlashes, &osortKeys, &encoder.indent, &allowNan, &orejectBytes)) + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|OOOOiiiO", kwlist, &oinput, &oensureAscii, &oencodeHTMLChars, &oescapeForwardSlashes, &osortKeys, &encoder.indent, &allowNan, &orejectBytes, &odefaultFn)) { return NULL; } @@ -815,6 +858,12 @@ PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs) encoder.allowNan = allowNan; } + if (odefaultFn != NULL && odefaultFn != Py_None) + { + // Here use prv to store default function + encoder.prv = odefaultFn; + } + if (encoder.allowNan) { csInf = "Inf"; diff --git a/tests/test_ujson.py b/tests/test_ujson.py index 2a0d5af7..920fa43e 100644 
--- a/tests/test_ujson.py +++ b/tests/test_ujson.py @@ -1,9 +1,11 @@ +import datetime import decimal import io import json import math import re import sys +import uuid from collections import OrderedDict import pytest @@ -828,6 +830,39 @@ def test_encode_none_key(): assert ujson.dumps(data) == '{"null":null}' +def test_default_function(): + iso8601_time_format = "%Y-%m-%dT%H:%M:%S.%f" + + class CustomObject: + pass + + class UnjsonableObject: + pass + + def default(value): + if isinstance(value, datetime.datetime): + return value.strftime(iso8601_time_format) + elif isinstance(value, uuid.UUID): + return value.hex + elif isinstance(value, CustomObject): + raise ValueError("invalid value") + return value + + dt = datetime.datetime.now() + assert ujson.dumps(dt, default=default) == '"%s"' % dt.strftime(iso8601_time_format) + + uuid4 = uuid.uuid4() + assert ujson.dumps(uuid4, default=default) == '"%s"' % uuid4.hex + + custom_obj = CustomObject() + with pytest.raises(ValueError, match="invalid value"): + ujson.dumps(custom_obj, default=default) + + unjsonable_obj = UnjsonableObject() + with pytest.raises(TypeError, match="maximum recursion depth exceeded"): + ujson.dumps(unjsonable_obj, default=default) + + """ def test_decode_numeric_int_frc_overflow(): input = "X.Y"