Skip to content

Commit

Permalink
BUG: ensure text padding ufuncs handle stringdtype nulls
Browse files Browse the repository at this point in the history
  • Loading branch information
ngoldbaum committed Apr 26, 2024
1 parent 764f17c commit 197c915
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
12 changes: 11 additions & 1 deletion numpy/_core/src/umath/stringdtype_ufuncs.cpp
Expand Up @@ -1625,12 +1625,19 @@ center_ljust_rjust_strided_loop(PyArrayMethod_Context *context,
Buffer<ENCODING::UTF8> inbuf((char *)s1.buf, s1.size);
Buffer<ENCODING::UTF8> fill((char *)s2.buf, s2.size);

size_t num_codepoints = inbuf.num_codepoints();
npy_intp width = (npy_intp)*(npy_int64*)in2;

if (num_codepoints > (size_t)width) {
width = num_codepoints;
}

char *buf = NULL;
npy_intp newsize;
int overflowed = npy_mul_sizes_with_overflow(
&(newsize),
(npy_intp)num_bytes_for_utf8_character((unsigned char *)s2.buf),
(npy_intp)*(npy_int64*)in2 - inbuf.num_codepoints());
width - num_codepoints);
newsize += s1.size;

if (overflowed) {
Expand Down Expand Up @@ -1752,6 +1759,9 @@ zfill_strided_loop(PyArrayMethod_Context *context,
Buffer<ENCODING::UTF8> inbuf((char *)is.buf, is.size);
size_t in_codepoints = inbuf.num_codepoints();
size_t width = (size_t)*(npy_int64 *)in2;
if (in_codepoints > width) {
width = in_codepoints;
}
// number of leading one-byte characters plus the size of the
// original string
size_t outsize = (width - in_codepoints) + is.size;
Expand Down
8 changes: 4 additions & 4 deletions numpy/_core/strings.py
Expand Up @@ -626,7 +626,6 @@ def center(a, width, fillchar=' '):
"""
a = np.asanyarray(a)
width = np.maximum(str_len(a), width)
fillchar = np.asanyarray(fillchar, dtype=a.dtype)

if np.any(str_len(fillchar) != 1):
Expand All @@ -636,6 +635,7 @@ def center(a, width, fillchar=' '):
if a.dtype.char == "T":
return _center(a, width, fillchar)

width = np.maximum(str_len(a), width)
out_dtype = f"{a.dtype.char}{width.max()}"
shape = np.broadcast_shapes(a.shape, width.shape, fillchar.shape)
out = np.empty_like(a, shape=shape, dtype=out_dtype)
Expand Down Expand Up @@ -682,7 +682,6 @@ def ljust(a, width, fillchar=' '):
"""
a = np.asanyarray(a)
width = np.maximum(str_len(a), width)
fillchar = np.asanyarray(fillchar, dtype=a.dtype)

if np.any(str_len(fillchar) != 1):
Expand All @@ -692,6 +691,7 @@ def ljust(a, width, fillchar=' '):
if a.dtype.char == "T":
return _ljust(a, width, fillchar)

width = np.maximum(str_len(a), width)
shape = np.broadcast_shapes(a.shape, width.shape, fillchar.shape)
out_dtype = f"{a.dtype.char}{width.max()}"
out = np.empty_like(a, shape=shape, dtype=out_dtype)
Expand Down Expand Up @@ -738,7 +738,6 @@ def rjust(a, width, fillchar=' '):
"""
a = np.asanyarray(a)
width = np.maximum(str_len(a), width)
fillchar = np.asanyarray(fillchar, dtype=a.dtype)

if np.any(str_len(fillchar) != 1):
Expand All @@ -748,6 +747,7 @@ def rjust(a, width, fillchar=' '):
if a.dtype.char == "T":
return _rjust(a, width, fillchar)

width = np.maximum(str_len(a), width)
shape = np.broadcast_shapes(a.shape, width.shape, fillchar.shape)
out_dtype = f"{a.dtype.char}{width.max()}"
out = np.empty_like(a, shape=shape, dtype=out_dtype)
Expand Down Expand Up @@ -784,11 +784,11 @@ def zfill(a, width):
"""
a = np.asanyarray(a)
width = np.maximum(str_len(a), width)

if a.dtype.char == "T":
return _zfill(a, width)

width = np.maximum(str_len(a), width)
shape = np.broadcast_shapes(a.shape, width.shape)
out_dtype = f"{a.dtype.char}{width.max()}"
out = np.empty_like(a, shape=shape, dtype=out_dtype)
Expand Down
4 changes: 4 additions & 0 deletions numpy/_core/tests/test_stringdtype.py
Expand Up @@ -1210,11 +1210,15 @@ def test_unary(string_array, unicode_array, function_name):

PASSES_THROUGH_NAN_NULLS = [
"add",
"center",
"ljust",
"multiply",
"replace",
"rjust",
"strip",
"lstrip",
"rstrip",
"zfill",
]

NULLS_ARE_FALSEY = [
Expand Down

0 comments on commit 197c915

Please sign in to comment.