From b7ca02f27146591fde90c917184a563b647fe6db Mon Sep 17 00:00:00 2001 From: Developer-Ecosystem-Engineering <65677710+Developer-Ecosystem-Engineering@users.noreply.github.com> Date: Wed, 21 Dec 2022 19:20:42 -0800 Subject: [PATCH] Update logical_and to avoid extra mask ops in AVX512 --- .../src/umath/loops_logical.dispatch.c.src | 44 ++++++++++++------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/numpy/core/src/umath/loops_logical.dispatch.c.src b/numpy/core/src/umath/loops_logical.dispatch.c.src index 4a021d50b0e3..c07525be402a 100644 --- a/numpy/core/src/umath/loops_logical.dispatch.c.src +++ b/numpy/core/src/umath/loops_logical.dispatch.c.src @@ -44,6 +44,30 @@ NPY_FINLINE npyv_u8 mask_to_true(npyv_b8 v) const npyv_u8 truemask = npyv_setall_u8(1 == 1); return npyv_and_u8(truemask, npyv_cvt_u8_b8(v)); } +/* + * For logical_and, we have to be careful to handle non-bool inputs where + * bits of each operand might not overlap. Example: a = 0x01, b = 0x80 + * Both evaluate to boolean true, however, a & b is false. Return value + * should be consistent with byte_to_true(). + */ +NPY_FINLINE npyv_u8 simd_logical_and_u8(npyv_u8 a, npyv_u8 b) +{ + const npyv_u8 zero = npyv_zero_u8(); + const npyv_u8 truemask = npyv_setall_u8(1 == 1); + npyv_b8 ma = npyv_cmpeq_u8(a, zero); + npyv_b8 mb = npyv_cmpeq_u8(b, zero); + npyv_u8 r = npyv_cvt_u8_b8(npyv_or_b8(ma, mb)); + return npyv_andc_u8(truemask, r); +} +/* + * We don't really need the following, but it simplifies the templating code + * below since it is paired with simd_logical_and_u8() above. + */ +NPY_FINLINE npyv_u8 simd_logical_or_u8(npyv_u8 a, npyv_u8 b) +{ + npyv_u8 r = npyv_or_u8(a, b); + return byte_to_true(r); +} /**begin repeat @@ -63,10 +87,6 @@ simd_binary_@kind@_BOOL(npy_bool * op, npy_bool * ip1, npy_bool * ip2, npy_intp const int vstep = npyv_nlanes_u8; const int wstep = vstep * UNROLL; -#if @and@ - const npyv_u8 zero = npyv_zero_u8(); -#endif - // Unrolled vectors loop for (; len >= wstep; len -= wstep, ip1 += wstep, ip2 += wstep, op += wstep) { /**begin repeat1 @@ -75,12 +95,8 @@ simd_binary_@kind@_BOOL(npy_bool * op, npy_bool * ip1, npy_bool * ip2, npy_intp #if UNROLL > @unroll@ npyv_u8 a@unroll@ = npyv_load_u8(ip1 + vstep * @unroll@); npyv_u8 b@unroll@ = npyv_load_u8(ip2 + vstep * @unroll@); -#if @and@ - // a = 0x00/0xff if any bit is set. ensures non-bool inputs are handled properly. - a@unroll@ = npyv_cvt_u8_b8(npyv_cmpgt_u8(a@unroll@, zero)); -#endif - npyv_u8 r@unroll@ = npyv_@intrin@_u8(a@unroll@, b@unroll@); - npyv_store_u8(op + vstep * @unroll@, byte_to_true(r@unroll@)); + npyv_u8 r@unroll@ = simd_logical_@intrin@_u8(a@unroll@, b@unroll@); + npyv_store_u8(op + vstep * @unroll@, r@unroll@); #endif /**end repeat1**/ } @@ -90,12 +106,8 @@ simd_binary_@kind@_BOOL(npy_bool * op, npy_bool * ip1, npy_bool * ip2, npy_intp for (; len >= vstep; len -= vstep, ip1 += vstep, ip2 += vstep, op += vstep) { npyv_u8 a = npyv_load_u8(ip1); npyv_u8 b = npyv_load_u8(ip2); -#if @and@ - // a = 0x00/0xff if any bit is set. ensures non-bool inputs are handled properly. - a = npyv_cvt_u8_b8(npyv_cmpgt_u8(a, zero)); -#endif - npyv_u8 r = npyv_@intrin@_u8(a, b); - npyv_store_u8(op, byte_to_true(r)); + npyv_u8 r = simd_logical_@intrin@_u8(a, b); + npyv_store_u8(op, r); } // Scalar loop to finish off