Skip to content

Commit

Permalink
Merge pull request #22868 from charris/backport-22855
Browse files Browse the repository at this point in the history
BUG: Fortify string casts against floating point warnings
  • Loading branch information
charris committed Dec 22, 2022
2 parents 38046ce + 190fcb2 commit 3e2a0ba
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 4 deletions.
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/convert_datatype.c
Expand Up @@ -2881,7 +2881,7 @@ add_other_to_and_from_string_cast(
.name = "legacy_cast_to_string",
.nin = 1,
.nout = 1,
.flags = NPY_METH_REQUIRES_PYAPI,
.flags = NPY_METH_REQUIRES_PYAPI | NPY_METH_NO_FLOATINGPOINT_ERRORS,
.dtypes = dtypes,
.slots = slots,
};
Expand Down
12 changes: 9 additions & 3 deletions numpy/core/src/multiarray/scalartypes.c.src
Expand Up @@ -896,15 +896,21 @@ static PyObject *
@name@type_@kind@_either(npy_@name@ val, TrimMode trim_pos, TrimMode trim_sci,
npy_bool sign)
{
npy_@name@ absval;

if (npy_legacy_print_mode <= 113) {
return legacy_@name@_format@kind@(val);
}

absval = val < 0 ? -val : val;
int use_positional;
if (npy_isnan(val) || val == 0) {
use_positional = 1;
}
else {
npy_@name@ absval = val < 0 ? -val : val;
use_positional = absval < 1.e16L && absval >= 1.e-4L;
}

if (absval == 0 || (absval < 1.e16L && absval >= 1.e-4L) ) {
if (use_positional) {
return format_@name@(val, 0, -1, sign, trim_pos, -1, -1, -1);
}
return format_@name@(val, 1, -1, sign, trim_sci, -1, -1, -1);
Expand Down
14 changes: 14 additions & 0 deletions numpy/core/tests/test_strings.py
Expand Up @@ -83,3 +83,17 @@ def test_string_comparisons_empty(op, ufunc, sym, dtypes):
assert_array_equal(op(arr, arr2), expected)
assert_array_equal(ufunc(arr, arr2), expected)
assert_array_equal(np.compare_chararrays(arr, arr2, sym, False), expected)


@pytest.mark.parametrize("str_dt", ["S", "U"])
@pytest.mark.parametrize("float_dt", np.typecodes["AllFloat"])
def test_float_to_string_cast(str_dt, float_dt):
float_dt = np.dtype(float_dt)
fi = np.finfo(float_dt)
arr = np.array([np.nan, np.inf, -np.inf, fi.max, fi.min], dtype=float_dt)
expected = ["nan", "inf", "-inf", repr(fi.max), repr(fi.min)]
if float_dt.kind == 'c':
expected = [f"({r}+0j)" for r in expected]

res = arr.astype(str_dt)
assert_array_equal(res, np.array(expected, dtype=str_dt))

0 comments on commit 3e2a0ba

Please sign in to comment.