Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a default keyword argument to dumps #470

Merged
merged 1 commit into from Sep 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
57 changes: 53 additions & 4 deletions python/objToJSON.c
Expand Up @@ -72,6 +72,12 @@ typedef struct __TypeContext

#define GET_TC(__ptrtc) ((TypeContext *)((__ptrtc)->prv))

// If newObj is set, we should use it rather than JSOBJ
#define GET_OBJ(__jsobj, __ptrtc) (GET_TC(__ptrtc)->newObj ? GET_TC(__ptrtc)->newObj : __jsobj)

// Avoid infinite loop caused by the default function
#define DEFAULT_FN_MAX_DEPTH 3

struct PyDictIterState
{
PyObject *keys;
Expand Down Expand Up @@ -432,7 +438,8 @@ static void SetupDictIter(PyObject *dictObj, TypeContext *pc, JSONObjectEncoder

static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObjectEncoder *enc)
{
PyObject *obj, *objRepr, *exc;
PyObject *obj, *objRepr, *exc, *defaultFn, *newObj;
int level = 0;
TypeContext *pc;
PRINTMARK();
if (!_obj)
Expand All @@ -442,6 +449,7 @@ static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObject
}

obj = (PyObject*) _obj;
defaultFn = (PyObject*) enc->prv;

tc->prv = PyObject_Malloc(sizeof(TypeContext));
pc = (TypeContext *) tc->prv;
Expand All @@ -462,6 +470,7 @@ static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObject
pc->longValue = 0;
pc->rawJSONValue = NULL;

BEGIN:
if (PyIter_Check(obj))
{
PRINTMARK();
Expand Down Expand Up @@ -553,7 +562,6 @@ static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObject
return;
}


ISITERABLE:
if (PyDict_Check(obj))
{
Expand Down Expand Up @@ -651,6 +659,31 @@ static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObject
return;
}

DEFAULT:
if (defaultFn)
{
// Break infinite loop
if (level >= DEFAULT_FN_MAX_DEPTH)
{
PRINTMARK();
PyErr_Format(PyExc_TypeError, "maximum recursion depth exceeded");
goto INVALID;
}

newObj = PyObject_CallFunctionObjArgs(defaultFn, obj, NULL);
if (newObj)
{
PRINTMARK();
obj = pc->newObj = newObj;
level += 1;
goto BEGIN;
}
else
{
goto INVALID;
}
}

PRINTMARK();
PyErr_Clear();

Expand Down Expand Up @@ -682,33 +715,38 @@ static void Object_endTypeContext(JSOBJ obj, JSONTypeContext *tc)

static const char *Object_getStringValue(JSOBJ obj, JSONTypeContext *tc, size_t *_outLen)
{
obj = GET_OBJ(obj, tc);
return GET_TC(tc)->PyTypeToJSON (obj, tc, NULL, _outLen);
}

static JSINT64 Object_getLongValue(JSOBJ obj, JSONTypeContext *tc)
{
JSINT64 ret;
obj = GET_OBJ(obj, tc);
GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
return ret;
}

static JSUINT64 Object_getUnsignedLongValue(JSOBJ obj, JSONTypeContext *tc)
{
JSUINT64 ret;
obj = GET_OBJ(obj, tc);
GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
return ret;
}

static JSINT32 Object_getIntValue(JSOBJ obj, JSONTypeContext *tc)
{
JSINT32 ret;
obj = GET_OBJ(obj, tc);
GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
return ret;
}

static double Object_getDoubleValue(JSOBJ obj, JSONTypeContext *tc)
{
double ret;
obj = GET_OBJ(obj, tc);
GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
return ret;
}
Expand All @@ -720,27 +758,31 @@ static void Object_releaseObject(JSOBJ _obj)

static int Object_iterNext(JSOBJ obj, JSONTypeContext *tc)
{
obj = GET_OBJ(obj, tc);
return GET_TC(tc)->iterNext(obj, tc);
}

static void Object_iterEnd(JSOBJ obj, JSONTypeContext *tc)
{
obj = GET_OBJ(obj, tc);
GET_TC(tc)->iterEnd(obj, tc);
}

static JSOBJ Object_iterGetValue(JSOBJ obj, JSONTypeContext *tc)
{
obj = GET_OBJ(obj, tc);
return GET_TC(tc)->iterGetValue(obj, tc);
}

static char *Object_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen)
{
obj = GET_OBJ(obj, tc);
return GET_TC(tc)->iterGetName(obj, tc, outLen);
}

PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs)
{
static char *kwlist[] = { "obj", "ensure_ascii", "encode_html_chars", "escape_forward_slashes", "sort_keys", "indent", "allow_nan", "reject_bytes", NULL };
static char *kwlist[] = { "obj", "ensure_ascii", "encode_html_chars", "escape_forward_slashes", "sort_keys", "indent", "allow_nan", "reject_bytes", "default", NULL };

char buffer[65536];
char *ret;
Expand All @@ -751,6 +793,7 @@ PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs)
PyObject *oencodeHTMLChars = NULL;
PyObject *oescapeForwardSlashes = NULL;
PyObject *osortKeys = NULL;
PyObject *odefaultFn = NULL;
int allowNan = -1;
int orejectBytes = -1;

Expand Down Expand Up @@ -785,7 +828,7 @@ PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs)

PRINTMARK();

if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|OOOOiii", kwlist, &oinput, &oensureAscii, &oencodeHTMLChars, &oescapeForwardSlashes, &osortKeys, &encoder.indent, &allowNan, &orejectBytes))
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|OOOOiiiO", kwlist, &oinput, &oensureAscii, &oencodeHTMLChars, &oescapeForwardSlashes, &osortKeys, &encoder.indent, &allowNan, &orejectBytes, &odefaultFn))
{
return NULL;
}
Expand Down Expand Up @@ -815,6 +858,12 @@ PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs)
encoder.allowNan = allowNan;
}

if (odefaultFn != NULL && odefaultFn != Py_None)
{
// Here use prv to store default function
encoder.prv = odefaultFn;
}

if (encoder.allowNan)
{
csInf = "Inf";
Expand Down
37 changes: 37 additions & 0 deletions tests/test_ujson.py
@@ -1,9 +1,11 @@
import datetime as dt
import decimal
import io
import json
import math
import re
import sys
import uuid
from collections import OrderedDict

import pytest
Expand Down Expand Up @@ -828,6 +830,41 @@ def test_encode_none_key():
assert ujson.dumps(data) == '{"null":null}'


def test_default_function():
iso8601_time_format = "%Y-%m-%dT%H:%M:%S.%f"

class CustomObject:
pass

class UnjsonableObject:
pass

def default(value):
if isinstance(value, dt.datetime):
return value.strftime(iso8601_time_format)
elif isinstance(value, uuid.UUID):
return value.hex
elif isinstance(value, CustomObject):
raise ValueError("invalid value")
return value

now = dt.datetime.now()
expected_output = '"%s"' % now.strftime(iso8601_time_format)
assert ujson.dumps(now, default=default) == expected_output

uuid4 = uuid.uuid4()
expected_output = '"%s"' % uuid4.hex
assert ujson.dumps(uuid4, default=default) == expected_output

custom_obj = CustomObject()
with pytest.raises(ValueError, match="invalid value"):
ujson.dumps(custom_obj, default=default)

unjsonable_obj = UnjsonableObject()
with pytest.raises(TypeError, match="maximum recursion depth exceeded"):
ujson.dumps(unjsonable_obj, default=default)


"""
def test_decode_numeric_int_frc_overflow():
input = "X.Y"
Expand Down