Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG, SIMD: Fix 64-bit/8-bit integer division by a scalar #20297

Merged
merged 3 commits into from Nov 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 4 additions & 3 deletions numpy/core/src/common/simd/intdiv.h
Expand Up @@ -162,11 +162,12 @@ NPY_FINLINE npy_uint64 npyv__divh128_u64(npy_uint64 high, npy_uint64 divisor)
npy_uint32 divisor_hi = divisor >> 32;
npy_uint32 divisor_lo = divisor & 0xFFFFFFFF;
// compute high quotient digit
npy_uint32 quotient_hi = (npy_uint32)(high / divisor_hi);
npy_uint64 quotient_hi = high / divisor_hi;
npy_uint64 remainder = high - divisor_hi * quotient_hi;
npy_uint64 base32 = 1ULL << 32;
while (quotient_hi >= base32 || quotient_hi*divisor_lo > base32*remainder) {
remainder += --divisor_hi;
--quotient_hi;
remainder += divisor_hi;
if (remainder >= base32) {
break;
}
Expand Down Expand Up @@ -200,7 +201,7 @@ NPY_FINLINE npyv_u8x3 npyv_divisor_u8(npy_uint8 d)
default:
l = npyv__bitscan_revnz_u32(d - 1) + 1; // ceil(log2(d))
l2 = (npy_uint8)(1 << l); // 2^l, overflow to 0 if l = 8
m = ((l2 - d) << 8) / d + 1; // multiplier
m = ((npy_uint16)((l2 - d) << 8)) / d + 1; // multiplier
sh1 = 1; sh2 = l - 1; // shift counts
}
npyv_u8x3 divisor;
Expand Down
75 changes: 25 additions & 50 deletions numpy/core/tests/test_simd.py
Expand Up @@ -329,7 +329,7 @@ def test_square(self):
data_square = [x*x for x in data]
square = self.square(vdata)
assert square == data_square

def test_max(self):
"""
Test intrinsics:
Expand Down Expand Up @@ -818,6 +818,7 @@ def test_arithmetic_intdiv(self):
if self._is_fp():
return

int_min = self._int_min()
def trunc_div(a, d):
"""
Divide towards zero works with large integers > 2^53,
Expand All @@ -830,57 +831,31 @@ def trunc_div(a, d):
return a // d
return (a + sign_d - sign_a) // d + 1

int_min = self._int_min() if self._is_signed() else 1
int_max = self._int_max()
rdata = (
0, 1, self.nlanes, int_max-self.nlanes,
int_min, int_min//2 + 1
)
divisors = (1, 2, 9, 13, self.nlanes, int_min, int_max, int_max//2)

for x, d in itertools.product(rdata, divisors):
data = self._data(x)
vdata = self.load(data)
data_divc = [trunc_div(a, d) for a in data]
divisor = self.divisor(d)
divc = self.divc(vdata, divisor)
assert divc == data_divc

if not self._is_signed():
return

safe_neg = lambda x: -x-1 if -x > int_max else -x
# test round division for signed integers
for x, d in itertools.product(rdata, divisors):
d_neg = safe_neg(d)
data = self._data(x)
data_neg = [safe_neg(a) for a in data]
vdata = self.load(data)
vdata_neg = self.load(data_neg)
divisor = self.divisor(d)
divisor_neg = self.divisor(d_neg)

# round towards zero
data_divc = [trunc_div(a, d_neg) for a in data]
divc = self.divc(vdata, divisor_neg)
assert divc == data_divc
data_divc = [trunc_div(a, d) for a in data_neg]
divc = self.divc(vdata_neg, divisor)
data = [1, -int_min] # to test overflow
data += range(0, 2**8, 2**5)
data += range(0, 2**8, 2**5-1)
bsize = self._scalar_size()
if bsize > 8:
data += range(2**8, 2**16, 2**13)
data += range(2**8, 2**16, 2**13-1)
if bsize > 16:
data += range(2**16, 2**32, 2**29)
data += range(2**16, 2**32, 2**29-1)
if bsize > 32:
data += range(2**32, 2**64, 2**61)
data += range(2**32, 2**64, 2**61-1)
# negate
data += [-x for x in data]
for dividend, divisor in itertools.product(data, data):
divisor = self.setall(divisor)[0] # cast
if divisor == 0:
continue
dividend = self.load(self._data(dividend))
data_divc = [trunc_div(a, divisor) for a in dividend]
divisor_parms = self.divisor(divisor)
divc = self.divc(dividend, divisor_parms)
assert divc == data_divc

# test truncate sign if the dividend is zero
vzero = self.zero()
for d in (-1, -10, -100, int_min//2, int_min):
divisor = self.divisor(d)
divc = self.divc(vzero, divisor)
assert divc == vzero

# test overflow
vmin = self.setall(int_min)
divisor = self.divisor(-1)
divc = self.divc(vmin, divisor)
assert divc == vmin

def test_arithmetic_reduce_sum(self):
"""
Test reduce sum intrinsics:
Expand Down