Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix memory leak on encoding errors when the buffer was resized #549

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions lib/ultrajsonenc.c
Expand Up @@ -986,6 +986,11 @@ char *JSON_EncodeObject(JSOBJ obj, JSONObjectEncoder *enc, char *_buffer, size_t

if (enc->errorMsg)
{
if (enc->heap == 1)
{
// Buffer was realloc'd at some point, or no initial buffer was provided.
enc->free(enc->start);
}
return NULL;
}

Expand Down
8 changes: 4 additions & 4 deletions python/objToJSON.c
Expand Up @@ -876,19 +876,19 @@ PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs)

dconv_d2s_free(&encoder.d2s);

if (PyErr_Occurred())
if (encoder.errorMsg && !PyErr_Occurred())
{
return NULL;
// If there is an error message and we don't already have a Python exception, set one.
PyErr_Format (PyExc_OverflowError, "%s", encoder.errorMsg);
}

if (encoder.errorMsg)
if (PyErr_Occurred())
{
if (ret != buffer)
{
encoder.free (ret);
}

PyErr_Format (PyExc_OverflowError, "%s", encoder.errorMsg);
return NULL;
}

Expand Down
45 changes: 45 additions & 0 deletions tests/memory.py
@@ -0,0 +1,45 @@
import gc
import sys
import tracemalloc

# exec the first argument to get func() and n
exec_globals = {}
exec(sys.argv[1], exec_globals)
func = exec_globals["func"]
n = int(sys.argv[2]) if sys.argv[2:] else 1

# Pre-run once
try:
func()
except Exception:
pass

# Create filter to only report leaks on the 'tracemalloc: measure' line below
filters = []
with open(__file__) as fp:
for i, line in enumerate(fp, start=1):
if "tracemalloc: measure" in line:
filters.append(tracemalloc.Filter(True, __file__, i))

# Clean up and take a snapshot
tracemalloc.start()
gc.collect()
before = tracemalloc.take_snapshot().filter_traces(filters)

# Run
for i in range(n):
try:
func() # tracemalloc: measure
except Exception:
pass

# Clean up and second snapshot
gc.collect()
after = tracemalloc.take_snapshot().filter_traces(filters)

# Check that nothing got leaked
diff = after.compare_to(before, "lineno")
if diff:
for stat in diff:
print(stat)
sys.exit(1)
18 changes: 18 additions & 0 deletions tests/test_ujson.py
Expand Up @@ -3,7 +3,9 @@
import io
import json
import math
import os.path
import re
import subprocess
import sys
import uuid
from collections import OrderedDict
Expand Down Expand Up @@ -1026,6 +1028,22 @@ def __str__(self):
ujson.dumps({Obj(): 1}, sort_keys=sort_keys)


def no_memory_leak(func_code, n=None):
code = f"import functools, ujson; func = {func_code}"
path = os.path.join(os.path.dirname(__file__), "memory.py")
n = [str(n)] if n is not None else []
p = subprocess.run([sys.executable, path, code] + n)
assert p.returncode == 0


@pytest.mark.skipif(
hasattr(sys, "pypy_version_info"), reason="PyPy uses incompatible GC"
)
@pytest.mark.parametrize("input", ['["a" * 11000, b""]'])
def test_no_memory_leak_encoding_errors(input):
no_memory_leak(f"functools.partial(ujson.dumps, {input})")


"""
def test_decode_numeric_int_frc_overflow():
input = "X.Y"
Expand Down