Skip to content

Commit

Permalink
ENH: add support for nan-like null strings in string replace (#26355)
Browse files Browse the repository at this point in the history
This fixes an issue similar to the one fixed by #26353.

In particular, np.strings.replace currently calls the count ufunc up front to get the number of replacements. This is necessary for fixed-width strings, where the output buffer size is computed from those counts, but it makes it impossible to support null strings in replace.

Instead, this change finds the replacement counts inline in the ufunc loop. That makes it possible to add support for NaN-like null strings, which pandas turns out to need.
  • Loading branch information
ngoldbaum committed Apr 30, 2024
1 parent 05f8351 commit 4e6d2bf
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 33 deletions.
95 changes: 67 additions & 28 deletions numpy/_core/src/umath/stringdtype_ufuncs.cpp
Expand Up @@ -1300,7 +1300,9 @@ string_replace_strided_loop(

PyArray_StringDTypeObject *descr0 =
(PyArray_StringDTypeObject *)context->descriptors[0];
int has_null = descr0->na_object != NULL;
int has_string_na = descr0->has_string_na;
int has_nan_na = descr0->has_nan_na;
const npy_static_string *default_string = &descr0->default_string;


Expand Down Expand Up @@ -1330,11 +1332,29 @@ string_replace_strided_loop(
goto fail;
}
else if (i1_isnull || i2_isnull || i3_isnull) {
if (!has_string_na) {
npy_gil_error(PyExc_ValueError,
"Null values are not supported as replacement arguments "
"for replace");
goto fail;
if (has_null && !has_string_na) {
if (i2_isnull || i3_isnull) {
npy_gil_error(PyExc_ValueError,
"Null values are not supported as search "
"patterns or replacement strings for "
"replace");
goto fail;
}
else if (i1_isnull) {
if (has_nan_na) {
if (NpyString_pack_null(oallocator, ops) < 0) {
npy_gil_error(PyExc_MemoryError,
"Failed to deallocate string in replace");
goto fail;
}
goto next_step;
}
else {
npy_gil_error(PyExc_ValueError,
"Only string or NaN-like null strings can "
"be used as search strings for replace");
}
}
}
else {
if (i1_isnull) {
Expand All @@ -1349,32 +1369,51 @@ string_replace_strided_loop(
}
}

// conservatively overallocate
// TODO check overflow
size_t max_size;
if (i2s.size == 0) {
// interleaving
max_size = i1s.size + (i1s.size + 1)*(i3s.size);
}
else {
// replace i2 with i3
max_size = i1s.size * (i3s.size/i2s.size + 1);
}
char *new_buf = (char *)PyMem_RawCalloc(max_size, 1);
Buffer<ENCODING::UTF8> buf1((char *)i1s.buf, i1s.size);
Buffer<ENCODING::UTF8> buf2((char *)i2s.buf, i2s.size);
Buffer<ENCODING::UTF8> buf3((char *)i3s.buf, i3s.size);
Buffer<ENCODING::UTF8> outbuf(new_buf, max_size);
{
Buffer<ENCODING::UTF8> buf1((char *)i1s.buf, i1s.size);
Buffer<ENCODING::UTF8> buf2((char *)i2s.buf, i2s.size);

size_t new_buf_size = string_replace(
buf1, buf2, buf3, *(npy_int64 *)in4, outbuf);
npy_int64 in_count = *(npy_int64*)in4;
if (in_count == -1) {
in_count = NPY_MAX_INT64;
}

if (NpyString_pack(oallocator, ops, new_buf, new_buf_size) < 0) {
npy_gil_error(PyExc_MemoryError, "Failed to pack string in replace");
goto fail;
}
npy_int64 found_count = string_count<ENCODING::UTF8>(
buf1, buf2, 0, NPY_MAX_INT64);
if (found_count < 0) {
goto fail;
}

PyMem_RawFree(new_buf);
npy_intp count = Py_MIN(in_count, found_count);

Buffer<ENCODING::UTF8> buf3((char *)i3s.buf, i3s.size);

// conservatively overallocate
// TODO check overflow
size_t max_size;
if (i2s.size == 0) {
// interleaving
max_size = i1s.size + (i1s.size + 1)*(i3s.size);
}
else {
// replace i2 with i3
size_t change = i2s.size >= i3s.size ? 0 : i3s.size - i2s.size;
max_size = i1s.size + count * change;
}
char *new_buf = (char *)PyMem_RawCalloc(max_size, 1);
Buffer<ENCODING::UTF8> outbuf(new_buf, max_size);

size_t new_buf_size = string_replace(
buf1, buf2, buf3, count, outbuf);

if (NpyString_pack(oallocator, ops, new_buf, new_buf_size) < 0) {
npy_gil_error(PyExc_MemoryError, "Failed to pack string in replace");
goto fail;
}

PyMem_RawFree(new_buf);
}
next_step:

in1 += strides[0];
in2 += strides[1];
Expand Down
8 changes: 4 additions & 4 deletions numpy/_core/strings.py
Expand Up @@ -1153,15 +1153,15 @@ def replace(a, old, new, count=-1):
a_dt = arr.dtype
old = np.asanyarray(old, dtype=getattr(old, 'dtype', a_dt))
new = np.asanyarray(new, dtype=getattr(new, 'dtype', a_dt))
count = np.asanyarray(count)

if arr.dtype.char == "T":
return _replace(arr, old, new, count)

max_int64 = np.iinfo(np.int64).max
counts = _count_ufunc(arr, old, 0, max_int64)
count = np.asanyarray(count)
counts = np.where(count < 0, counts, np.minimum(counts, count))

if arr.dtype.char == "T":
return _replace(arr, old, new, counts)

buffersizes = str_len(arr) + counts * (str_len(new) - str_len(old))
out_dtype = f"{arr.dtype.char}{buffersizes.max()}"
out = np.empty_like(arr, shape=buffersizes.shape, dtype=out_dtype)
Expand Down
2 changes: 1 addition & 1 deletion numpy/_core/tests/test_stringdtype.py
Expand Up @@ -1218,6 +1218,7 @@ def test_unary(string_array, unicode_array, function_name):
"strip",
"lstrip",
"rstrip",
"replace"
"zfill",
]

Expand All @@ -1230,7 +1231,6 @@ def test_unary(string_array, unicode_array, function_name):
"count",
"find",
"rfind",
"replace",
]

SUPPORTS_NULLS = (
Expand Down

0 comments on commit 4e6d2bf

Please sign in to comment.