Skip to content

Commit

Permalink
Update logical_and to avoid extra mask ops in AVX512
Browse files Browse the repository at this point in the history
  • Loading branch information
Developer-Ecosystem-Engineering committed Dec 22, 2022
1 parent b0919fe commit b7ca02f
Showing 1 changed file with 28 additions and 16 deletions.
44 changes: 28 additions & 16 deletions numpy/core/src/umath/loops_logical.dispatch.c.src
Expand Up @@ -44,6 +44,30 @@ NPY_FINLINE npyv_u8 mask_to_true(npyv_b8 v)
const npyv_u8 truemask = npyv_setall_u8(1 == 1);
return npyv_and_u8(truemask, npyv_cvt_u8_b8(v));
}
/*
* For logical_and, we have to be careful to handle non-bool inputs where
* bits of each operand might not overlap. Example: a = 0x01, b = 0x80
* Both evaluate to boolean true, however, a & b is false. Return value
* should be consistent with byte_to_true().
*/
NPY_FINLINE npyv_u8 simd_logical_and_u8(npyv_u8 a, npyv_u8 b)
{
const npyv_u8 zero = npyv_zero_u8();
const npyv_u8 truemask = npyv_setall_u8(1 == 1);
npyv_b8 ma = npyv_cmpeq_u8(a, zero);
npyv_b8 mb = npyv_cmpeq_u8(b, zero);
npyv_u8 r = npyv_cvt_u8_b8(npyv_or_b8(ma, mb));
return npyv_andc_u8(truemask, r);
}
/*
* We don't really need the following, but it simplifies the templating code
* below since it is paired with simd_logical_and_u8() above.
*/
NPY_FINLINE npyv_u8 simd_logical_or_u8(npyv_u8 a, npyv_u8 b)
{
npyv_u8 r = npyv_or_u8(a, b);
return byte_to_true(r);
}


/**begin repeat
Expand All @@ -63,10 +87,6 @@ simd_binary_@kind@_BOOL(npy_bool * op, npy_bool * ip1, npy_bool * ip2, npy_intp
const int vstep = npyv_nlanes_u8;
const int wstep = vstep * UNROLL;

#if @and@
const npyv_u8 zero = npyv_zero_u8();
#endif
// Unrolled vectors loop
for (; len >= wstep; len -= wstep, ip1 += wstep, ip2 += wstep, op += wstep) {
/**begin repeat1
Expand All @@ -75,12 +95,8 @@ simd_binary_@kind@_BOOL(npy_bool * op, npy_bool * ip1, npy_bool * ip2, npy_intp
#if UNROLL > @unroll@
npyv_u8 a@unroll@ = npyv_load_u8(ip1 + vstep * @unroll@);
npyv_u8 b@unroll@ = npyv_load_u8(ip2 + vstep * @unroll@);
#if @and@
// a = 0x00/0xff if any bit is set. ensures non-bool inputs are handled properly.
a@unroll@ = npyv_cvt_u8_b8(npyv_cmpgt_u8(a@unroll@, zero));
#endif
npyv_u8 r@unroll@ = npyv_@intrin@_u8(a@unroll@, b@unroll@);
npyv_store_u8(op + vstep * @unroll@, byte_to_true(r@unroll@));
npyv_u8 r@unroll@ = simd_logical_@intrin@_u8(a@unroll@, b@unroll@);
npyv_store_u8(op + vstep * @unroll@, r@unroll@);
#endif
/**end repeat1**/
}
Expand All @@ -90,12 +106,8 @@ simd_binary_@kind@_BOOL(npy_bool * op, npy_bool * ip1, npy_bool * ip2, npy_intp
for (; len >= vstep; len -= vstep, ip1 += vstep, ip2 += vstep, op += vstep) {
npyv_u8 a = npyv_load_u8(ip1);
npyv_u8 b = npyv_load_u8(ip2);
#if @and@
// a = 0x00/0xff if any bit is set. ensures non-bool inputs are handled properly.
a = npyv_cvt_u8_b8(npyv_cmpgt_u8(a, zero));
#endif
npyv_u8 r = npyv_@intrin@_u8(a, b);
npyv_store_u8(op, byte_to_true(r));
npyv_u8 r = simd_logical_@intrin@_u8(a, b);
npyv_store_u8(op, r);
}

// Scalar loop to finish off
Expand Down

0 comments on commit b7ca02f

Please sign in to comment.