diff --git a/numpy/_core/src/umath/stringdtype_ufuncs.cpp b/numpy/_core/src/umath/stringdtype_ufuncs.cpp index e28f4100cadc..052c4381a4b5 100644 --- a/numpy/_core/src/umath/stringdtype_ufuncs.cpp +++ b/numpy/_core/src/umath/stringdtype_ufuncs.cpp @@ -1300,7 +1300,9 @@ string_replace_strided_loop( PyArray_StringDTypeObject *descr0 = (PyArray_StringDTypeObject *)context->descriptors[0]; + int has_null = descr0->na_object != NULL; int has_string_na = descr0->has_string_na; + int has_nan_na = descr0->has_nan_na; const npy_static_string *default_string = &descr0->default_string; @@ -1330,11 +1332,30 @@ string_replace_strided_loop( goto fail; } else if (i1_isnull || i2_isnull || i3_isnull) { - if (!has_string_na) { - npy_gil_error(PyExc_ValueError, - "Null values are not supported as replacement arguments " - "for replace"); - goto fail; + if (has_null && !has_string_na) { + if (i2_isnull || i3_isnull) { + npy_gil_error(PyExc_ValueError, + "Null values are not supported as search " + "patterns or replacement strings for " + "replace"); + goto fail; + } + else if (i1_isnull) { + if (has_nan_na) { + if (NpyString_pack_null(oallocator, ops) < 0) { + npy_gil_error(PyExc_MemoryError, + "Failed to deallocate string in replace"); + goto fail; + } + goto next_step; + } + else { + npy_gil_error(PyExc_ValueError, + "Only string or NaN-like null strings can " + "be used as search strings for replace"); + goto fail; + } + } } else { if (i1_isnull) { @@ -1349,32 +1370,57 @@ string_replace_strided_loop( } } - // conservatively overallocate - // TODO check overflow - size_t max_size; - if (i2s.size == 0) { - // interleaving - max_size = i1s.size + (i1s.size + 1)*(i3s.size); - } - else { - // replace i2 with i3 - max_size = i1s.size * (i3s.size/i2s.size + 1); - } - char *new_buf = (char *)PyMem_RawCalloc(max_size, 1); - Buffer buf1((char *)i1s.buf, i1s.size); - Buffer buf2((char *)i2s.buf, i2s.size); - Buffer buf3((char *)i3s.buf, i3s.size); - Buffer outbuf(new_buf, max_size); + { + Buffer 
buf1((char *)i1s.buf, i1s.size); + Buffer buf2((char *)i2s.buf, i2s.size); - size_t new_buf_size = string_replace( - buf1, buf2, buf3, *(npy_int64 *)in4, outbuf); + npy_int64 in_count = *(npy_int64*)in4; + // any negative count means "replace all", matching str.replace + if (in_count < 0) { + in_count = NPY_MAX_INT64; + } - if (NpyString_pack(oallocator, ops, new_buf, new_buf_size) < 0) { - npy_gil_error(PyExc_MemoryError, "Failed to pack string in replace"); - goto fail; - } + npy_int64 found_count = string_count( + buf1, buf2, 0, NPY_MAX_INT64); + if (found_count < 0) { + goto fail; + } - PyMem_RawFree(new_buf); + npy_intp count = Py_MIN(in_count, found_count); + + Buffer buf3((char *)i3s.buf, i3s.size); + + // conservatively overallocate + // TODO check overflow + size_t max_size; + if (i2s.size == 0) { + // interleaving + max_size = i1s.size + (i1s.size + 1)*(i3s.size); + } + else { + // replace i2 with i3 + size_t change = i2s.size >= i3s.size ? 0 : i3s.size - i2s.size; + max_size = i1s.size + count * change; + } + char *new_buf = (char *)PyMem_RawCalloc(max_size, 1); + if (new_buf == NULL) { + npy_gil_error(PyExc_MemoryError, "Failed to allocate string in replace"); + goto fail; + } + Buffer outbuf(new_buf, max_size); + + size_t new_buf_size = string_replace( + buf1, buf2, buf3, count, outbuf); + + if (NpyString_pack(oallocator, ops, new_buf, new_buf_size) < 0) { + PyMem_RawFree(new_buf); + npy_gil_error(PyExc_MemoryError, "Failed to pack string in replace"); + goto fail; + } + + PyMem_RawFree(new_buf); + } + next_step: in1 += strides[0]; in2 += strides[1]; diff --git a/numpy/_core/strings.py b/numpy/_core/strings.py index d30c4be3d62e..83034705f525 100644 --- a/numpy/_core/strings.py +++ b/numpy/_core/strings.py @@ -1153,15 +1153,15 @@ def replace(a, old, new, count=-1): a_dt = arr.dtype old = np.asanyarray(old, dtype=getattr(old, 'dtype', a_dt)) new = np.asanyarray(new, dtype=getattr(new, 'dtype', a_dt)) + count = np.asanyarray(count) + + if arr.dtype.char == "T": + return _replace(arr, old, new, count) max_int64 = np.iinfo(np.int64).max counts = _count_ufunc(arr, old, 0, max_int64) - count = np.asanyarray(count) counts = np.where(count < 0, counts, 
np.minimum(counts, count)) - if arr.dtype.char == "T": - return _replace(arr, old, new, counts) - buffersizes = str_len(arr) + counts * (str_len(new) - str_len(old)) out_dtype = f"{arr.dtype.char}{buffersizes.max()}" out = np.empty_like(arr, shape=buffersizes.shape, dtype=out_dtype) diff --git a/numpy/_core/tests/test_stringdtype.py b/numpy/_core/tests/test_stringdtype.py index 8d3c0a381e5b..dd6ac36999e6 100644 --- a/numpy/_core/tests/test_stringdtype.py +++ b/numpy/_core/tests/test_stringdtype.py @@ -1218,6 +1218,7 @@ def test_unary(string_array, unicode_array, function_name): "strip", "lstrip", "rstrip", + "replace", "zfill", ] @@ -1230,7 +1231,6 @@ def test_unary(string_array, unicode_array, function_name): "count", "find", "rfind", - "replace", ] SUPPORTS_NULLS = (