From a8eca17bc5fee4772788c3b5dbbdc843b83d36fa Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Fri, 26 Apr 2024 20:13:32 -0600 Subject: [PATCH 1/4] ENH: add support for nan-like null strings in replace --- numpy/_core/src/umath/stringdtype_ufuncs.cpp | 96 ++++++++++++++------ numpy/_core/strings.py | 8 +- numpy/_core/tests/test_stringdtype.py | 2 +- 3 files changed, 73 insertions(+), 33 deletions(-) diff --git a/numpy/_core/src/umath/stringdtype_ufuncs.cpp b/numpy/_core/src/umath/stringdtype_ufuncs.cpp index e28f4100cadc..a0e2f3db5747 100644 --- a/numpy/_core/src/umath/stringdtype_ufuncs.cpp +++ b/numpy/_core/src/umath/stringdtype_ufuncs.cpp @@ -1300,7 +1300,9 @@ string_replace_strided_loop( PyArray_StringDTypeObject *descr0 = (PyArray_StringDTypeObject *)context->descriptors[0]; + int has_null = descr0->na_object != NULL; int has_string_na = descr0->has_string_na; + int has_nan_na = descr0->has_nan_na; const npy_static_string *default_string = &descr0->default_string; @@ -1330,11 +1332,29 @@ string_replace_strided_loop( goto fail; } else if (i1_isnull || i2_isnull || i3_isnull) { - if (!has_string_na) { - npy_gil_error(PyExc_ValueError, - "Null values are not supported as replacement arguments " - "for replace"); - goto fail; + if (has_null && !has_string_na) { + if (i2_isnull || i3_isnull) { + npy_gil_error(PyExc_ValueError, + "Null values are not supported as search " + "patterns or replacement strings for " + "replace"); + goto fail; + } + else if (i1_isnull) { + if (has_nan_na) { + if (NpyString_pack_null(oallocator, ops) < 0) { + npy_gil_error(PyExc_MemoryError, + "Failed to deallocate string in replace"); + goto fail; + } + goto next_step; + } + else { + npy_gil_error(PyExc_ValueError, + "Only nan-like null values are not supported " + "as search strings for replace"); + } + } } else { if (i1_isnull) { @@ -1349,32 +1369,52 @@ string_replace_strided_loop( } } - // conservatively overallocate - // TODO check overflow - size_t max_size; - if (i2s.size == 0) { - // interleaving - max_size = i1s.size + (i1s.size + 1)*(i3s.size); - } - else { - // replace i2 with i3 - max_size = i1s.size * (i3s.size/i2s.size + 1); - } - char *new_buf = (char *)PyMem_RawCalloc(max_size, 1); - Buffer buf1((char *)i1s.buf, i1s.size); - Buffer buf2((char *)i2s.buf, i2s.size); - Buffer buf3((char *)i3s.buf, i3s.size); - Buffer outbuf(new_buf, max_size); + { + Buffer buf1((char *)i1s.buf, i1s.size); + Buffer buf2((char *)i2s.buf, i2s.size); - size_t new_buf_size = string_replace( - buf1, buf2, buf3, *(npy_int64 *)in4, outbuf); + npy_int64 in_count = *(npy_int64*)in4; + if (in_count == -1) { + in_count = NPY_MAX_INT64; + } - if (NpyString_pack(oallocator, ops, new_buf, new_buf_size) < 0) { - npy_gil_error(PyExc_MemoryError, "Failed to pack string in replace"); - goto fail; - } + npy_int64 start = 0; + npy_int64 end = NPY_MAX_INT64; - PyMem_RawFree(new_buf); + npy_int64 found_count = string_count(buf1, buf2, start, end); + if (found_count == -2) { + goto fail; + } + + npy_intp count = Py_MIN(in_count, found_count); + + Buffer buf3((char *)i3s.buf, i3s.size); + + // conservatively overallocate + // TODO check overflow + size_t max_size; + if (i2s.size == 0) { + // interleaving + max_size = i1s.size + (i1s.size + 1)*(i3s.size); + } + else { + // replace i2 with i3 + max_size = i1s.size * (i3s.size/i2s.size + 1); + } + char *new_buf = (char *)PyMem_RawCalloc(max_size, 1); + Buffer outbuf(new_buf, max_size); + + size_t new_buf_size = string_replace( + buf1, buf2, buf3, count, outbuf); + + if (NpyString_pack(oallocator, ops, new_buf, new_buf_size) < 0) { + npy_gil_error(PyExc_MemoryError, "Failed to pack string in replace"); + goto fail; + } + + PyMem_RawFree(new_buf); + } + next_step: in1 += strides[0]; in2 += strides[1]; diff --git a/numpy/_core/strings.py b/numpy/_core/strings.py index d30c4be3d62e..83034705f525 100644 --- a/numpy/_core/strings.py +++ b/numpy/_core/strings.py @@ -1153,15 +1153,15 @@ def replace(a, old, new, count=-1): a_dt = arr.dtype old = np.asanyarray(old, dtype=getattr(old, 'dtype', a_dt)) new = np.asanyarray(new, dtype=getattr(new, 'dtype', a_dt)) + count = np.asanyarray(count) + + if arr.dtype.char == "T": + return _replace(arr, old, new, count) max_int64 = np.iinfo(np.int64).max counts = _count_ufunc(arr, old, 0, max_int64) - count = np.asanyarray(count) counts = np.where(count < 0, counts, np.minimum(counts, count)) - if arr.dtype.char == "T": - return _replace(arr, old, new, counts) - buffersizes = str_len(arr) + counts * (str_len(new) - str_len(old)) out_dtype = f"{arr.dtype.char}{buffersizes.max()}" out = np.empty_like(arr, shape=buffersizes.shape, dtype=out_dtype) diff --git a/numpy/_core/tests/test_stringdtype.py b/numpy/_core/tests/test_stringdtype.py index 8d3c0a381e5b..dd6ac36999e6 100644 --- a/numpy/_core/tests/test_stringdtype.py +++ b/numpy/_core/tests/test_stringdtype.py @@ -1218,6 +1218,7 @@ def test_unary(string_array, unicode_array, function_name): "strip", "lstrip", "rstrip", + "replace" "zfill", ] @@ -1230,7 +1231,6 @@ def test_unary(string_array, unicode_array, function_name): "count", "find", "rfind", - "replace", ] SUPPORTS_NULLS = ( From 4cc651b475d7868813de8a326c03406153ab3f7d Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Mon, 29 Apr 2024 13:06:51 -0600 Subject: [PATCH 2/4] MNT: respond to PR feedback --- numpy/_core/src/umath/stringdtype_ufuncs.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/numpy/_core/src/umath/stringdtype_ufuncs.cpp b/numpy/_core/src/umath/stringdtype_ufuncs.cpp index a0e2f3db5747..612799165df2 100644 --- a/numpy/_core/src/umath/stringdtype_ufuncs.cpp +++ b/numpy/_core/src/umath/stringdtype_ufuncs.cpp @@ -1350,9 +1350,9 @@ string_replace_strided_loop( goto next_step; } else { - npy_gil_error(PyExc_ValueError, - "Only nan-like null values are not supported " - "as search strings for replace"); + npy_gil_error(PyExc_ValueError, + "Only NaN-like null strings can be used " + "as as search strings for replace"); } } } @@ -1378,10 +1378,8 @@ string_replace_strided_loop( in_count = NPY_MAX_INT64; } - npy_int64 start = 0; - npy_int64 end = NPY_MAX_INT64; - - npy_int64 found_count = string_count(buf1, buf2, start, end); + npy_int64 found_count = string_count( + buf1, buf2, 0, NPY_MAX_INT64); if (found_count == -2) { goto fail; } From e85e7a5563aa614bb503c12407e323c607df9d87 Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Mon, 29 Apr 2024 15:02:11 -0600 Subject: [PATCH 3/4] MNT: typo fix --- numpy/_core/src/umath/stringdtype_ufuncs.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpy/_core/src/umath/stringdtype_ufuncs.cpp b/numpy/_core/src/umath/stringdtype_ufuncs.cpp index 612799165df2..15a9e136371a 100644 --- a/numpy/_core/src/umath/stringdtype_ufuncs.cpp +++ b/numpy/_core/src/umath/stringdtype_ufuncs.cpp @@ -1352,7 +1352,7 @@ string_replace_strided_loop( else { npy_gil_error(PyExc_ValueError, "Only NaN-like null strings can be used " - "as as search strings for replace"); + "as search strings for replace"); } } } From ffec40623be371f8f98d8fbc889c3067de3b115e Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Tue, 30 Apr 2024 13:25:52 -0600 Subject: [PATCH 4/4] MNT: respond to sebastian's comments --- numpy/_core/src/umath/stringdtype_ufuncs.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/numpy/_core/src/umath/stringdtype_ufuncs.cpp b/numpy/_core/src/umath/stringdtype_ufuncs.cpp index 15a9e136371a..052c4381a4b5 100644 --- a/numpy/_core/src/umath/stringdtype_ufuncs.cpp +++ b/numpy/_core/src/umath/stringdtype_ufuncs.cpp @@ -1351,8 +1351,8 @@ string_replace_strided_loop( } else { npy_gil_error(PyExc_ValueError, - "Only NaN-like null strings can be used " - "as search strings for replace"); + "Only string or NaN-like null strings can " + "be used as search strings for replace"); } } } @@ -1380,7 +1380,7 @@ string_replace_strided_loop( npy_int64 found_count = string_count( buf1, buf2, 0, NPY_MAX_INT64); - if (found_count == -2) { + if (found_count < 0) { goto fail; } @@ -1397,7 +1397,8 @@ string_replace_strided_loop( } else { // replace i2 with i3 - max_size = i1s.size * (i3s.size/i2s.size + 1); + size_t change = i2s.size >= i3s.size ? 0 : i3s.size - i2s.size; + max_size = i1s.size + count * change; } char *new_buf = (char *)PyMem_RawCalloc(max_size, 1); Buffer outbuf(new_buf, max_size);