Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: support nan-like null strings in [l,r]strip #26392

Merged
merged 2 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
86 changes: 60 additions & 26 deletions numpy/_core/src/umath/stringdtype_ufuncs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1046,6 +1046,7 @@ string_lrstrip_chars_strided_loop(
PyArray_StringDTypeObject *s1descr = (PyArray_StringDTypeObject *)context->descriptors[0];
int has_null = s1descr->na_object != NULL;
int has_string_na = s1descr->has_string_na;
int has_nan_na = s1descr->has_nan_na;

const npy_static_string *default_string = &s1descr->default_string;
npy_intp N = dimensions[0];
Expand All @@ -1072,28 +1073,46 @@ string_lrstrip_chars_strided_loop(
s2 = *default_string;
}
}
else if (has_nan_na) {
if (s2_isnull) {
npy_gil_error(PyExc_ValueError,
"Cannot use a null string that is not a "
"string as the %s delimiter", ufunc_name);
}
if (s1_isnull) {
if (NpyString_pack_null(oallocator, ops) < 0) {
npy_gil_error(PyExc_MemoryError,
"Failed to deallocate string in %s",
ufunc_name);
goto fail;
}
goto next_step;
}
}
else {
npy_gil_error(PyExc_ValueError,
"Cannot strip null values that are not strings");
"Can only strip null values that are strings "
"or NaN-like values");
goto fail;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One day it might be nice to see whether this check is common among a few ufuncs and, if so, to write the "ufunc-name() cannot support ..." error with a single helper.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's definitely some code duplication. I added this as a task to #25693.

}
}
{
char *new_buf = (char *)PyMem_RawCalloc(s1.size, 1);
Buffer<ENCODING::UTF8> buf1((char *)s1.buf, s1.size);
Buffer<ENCODING::UTF8> buf2((char *)s2.buf, s2.size);
Buffer<ENCODING::UTF8> outbuf(new_buf, s1.size);
size_t new_buf_size = string_lrstrip_chars
(buf1, buf2, outbuf, striptype);

if (NpyString_pack(oallocator, ops, new_buf, new_buf_size) < 0) {
npy_gil_error(PyExc_MemoryError, "Failed to pack string in %s",
ufunc_name);
goto fail;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically, this is missing the PyMem_RawFree(new_buf), isn't it? I probably said that before, but it would be nice to avoid reallocating a small buffer every time here (maybe by directly allocating the result, although that isn't possible if the operation is in-place).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this error path indicates there's memory corruption or a bug in the implementation, so a memory leak is the least of your worries if you get here. That said, it's easy enough to add a PyMem_RawFree; thanks for the catch.

}

char *new_buf = (char *)PyMem_RawCalloc(s1.size, 1);
Buffer<ENCODING::UTF8> buf1((char *)s1.buf, s1.size);
Buffer<ENCODING::UTF8> buf2((char *)s2.buf, s2.size);
Buffer<ENCODING::UTF8> outbuf(new_buf, s1.size);
size_t new_buf_size = string_lrstrip_chars
(buf1, buf2, outbuf, striptype);

if (NpyString_pack(oallocator, ops, new_buf, new_buf_size) < 0) {
npy_gil_error(PyExc_MemoryError, "Failed to pack string in %s",
ufunc_name);
goto fail;
PyMem_RawFree(new_buf);
}

PyMem_RawFree(new_buf);
next_step:

in1 += strides[0];
in2 += strides[1];
Expand Down Expand Up @@ -1150,8 +1169,9 @@ string_lrstrip_whitespace_strided_loop(
const char *ufunc_name = ((PyUFuncObject *)context->caller)->name;
STRIPTYPE striptype = *(STRIPTYPE *)context->method->static_data;
PyArray_StringDTypeObject *descr = (PyArray_StringDTypeObject *)context->descriptors[0];
int has_string_na = descr->has_string_na;
int has_null = descr->na_object != NULL;
int has_string_na = descr->has_string_na;
int has_nan_na = descr->has_nan_na;
const npy_static_string *default_string = &descr->default_string;

npy_string_allocator *allocators[2] = {};
Expand All @@ -1169,6 +1189,7 @@ string_lrstrip_whitespace_strided_loop(
npy_static_string s = {0, NULL};
int s_isnull = NpyString_load(allocator, ps, &s);


Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

But doesn't matter.

if (s_isnull == -1) {
npy_gil_error(PyExc_MemoryError, "Failed to load string in %s",
ufunc_name);
Expand All @@ -1181,26 +1202,39 @@ string_lrstrip_whitespace_strided_loop(
if (has_string_na || !has_null) {
s = *default_string;
}
else if (has_nan_na) {
if (NpyString_pack_null(oallocator, ops) < 0) {
npy_gil_error(PyExc_MemoryError,
"Failed to deallocate string in %s",
ufunc_name);
goto fail;
}
goto next_step;
}
else {
npy_gil_error(PyExc_ValueError,
"Cannot strip null values that are not strings");
"Can only strip null values that are strings or "
"NaN-like values");
goto fail;
}
}
{
char *new_buf = (char *)PyMem_RawCalloc(s.size, 1);
Buffer<ENCODING::UTF8> buf((char *)s.buf, s.size);
Buffer<ENCODING::UTF8> outbuf(new_buf, s.size);
size_t new_buf_size = string_lrstrip_whitespace(
buf, outbuf, striptype);

char *new_buf = (char *)PyMem_RawCalloc(s.size, 1);
Buffer<ENCODING::UTF8> buf((char *)s.buf, s.size);
Buffer<ENCODING::UTF8> outbuf(new_buf, s.size);
size_t new_buf_size = string_lrstrip_whitespace(
buf, outbuf, striptype);
if (NpyString_pack(oallocator, ops, new_buf, new_buf_size) < 0) {
npy_gil_error(PyExc_MemoryError, "Failed to pack string in %s",
ufunc_name);
goto fail;
}

if (NpyString_pack(oallocator, ops, new_buf, new_buf_size) < 0) {
npy_gil_error(PyExc_MemoryError, "Failed to pack string in %s",
ufunc_name);
goto fail;
PyMem_RawFree(new_buf);
}

PyMem_RawFree(new_buf);
next_step:

in += strides[0];
out += strides[1];
Expand Down
41 changes: 36 additions & 5 deletions numpy/_core/tests/test_stringdtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,13 @@ def unicode_array():
"capitalize",
"expandtabs",
"lower",
"splitlines" "swapcase" "title" "upper",
"lstrip",
"rstrip",
"splitlines",
"strip",
"swapcase",
"title",
"upper",
]

BOOL_OUTPUT_FUNCTIONS = [
Expand All @@ -1107,7 +1113,10 @@ def unicode_array():
"istitle",
"isupper",
"lower",
"lstrip",
"rstrip",
"splitlines",
"strip",
"swapcase",
"title",
"upper",
Expand All @@ -1129,10 +1138,20 @@ def unicode_array():
"upper",
]

ONLY_IN_NP_CHAR = [
"join",
"split",
"rsplit",
"splitlines"
]


@pytest.mark.parametrize("function_name", UNARY_FUNCTIONS)
def test_unary(string_array, unicode_array, function_name):
func = getattr(np.char, function_name)
if function_name in ONLY_IN_NP_CHAR:
func = getattr(np.char, function_name)
else:
func = getattr(np.strings, function_name)
dtype = string_array.dtype
sres = func(string_array)
ures = func(unicode_array)
Expand Down Expand Up @@ -1173,6 +1192,10 @@ def test_unary(string_array, unicode_array, function_name):
with pytest.raises(ValueError):
func(na_arr)
return
if not (is_nan or is_str):
with pytest.raises(ValueError):
func(na_arr)
return
res = func(na_arr)
if is_nan and function_name in NAN_PRESERVING_FUNCTIONS:
assert res[0] is dtype.na_object
Expand All @@ -1197,13 +1220,17 @@ def test_unary(string_array, unicode_array, function_name):
("index", (None, "e")),
("join", ("-", None)),
("ljust", (None, 12)),
("lstrip", (None, "A")),
("partition", (None, "A")),
("replace", (None, "A", "B")),
("rfind", (None, "A")),
("rindex", (None, "e")),
("rjust", (None, 12)),
("rsplit", (None, "A")),
("rstrip", (None, "A")),
("rpartition", (None, "A")),
("split", (None, "A")),
("strip", (None, "A")),
("startswith", (None, "A")),
("zfill", (None, 12)),
]
Expand Down Expand Up @@ -1260,10 +1287,13 @@ def call_func(func, args, array, sanitize=True):

@pytest.mark.parametrize("function_name, args", BINARY_FUNCTIONS)
def test_binary(string_array, unicode_array, function_name, args):
func = getattr(np.char, function_name)
if function_name in ONLY_IN_NP_CHAR:
func = getattr(np.char, function_name)
else:
func = getattr(np.strings, function_name)
sres = call_func(func, args, string_array)
ures = call_func(func, args, unicode_array, sanitize=False)
if sres.dtype == StringDType():
if not isinstance(sres, tuple) and sres.dtype == StringDType():
ures = ures.astype(StringDType())
assert_array_equal(sres, ures)

Expand Down Expand Up @@ -1462,7 +1492,8 @@ def test_setup(self):
view = self.get_view(self.a)
sizes = np.where(is_short, view['size_and_flags'] & 0xf,
view['size'])
assert_array_equal(sizes, np.strings.str_len(self.a))
assert_array_equal(sizes, np.strings
.str_len(self.a))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert_array_equal(sizes, np.strings
.str_len(self.a))
assert_array_equal(sizes, np.strings.str_len(self.a))

assert_array_equal(view['xsiz'][2:],
np.void(b'\x00' * (self.sizeofstr // 4 - 1)))
# Check that the medium string uses only 1 byte for its length
Expand Down