Skip to content

Commit

Permalink
Merge pull request #597 from Cuda-Chen/improve-mm-dp-ps
Browse files Browse the repository at this point in the history
Remove Kahan algorithm in `_mm_dp_ps`
  • Loading branch information
jserv committed May 18, 2023
2 parents b7417bc + 424279b commit 39d8540
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 47 deletions.
54 changes: 24 additions & 30 deletions sse2neon.h
Original file line number Diff line number Diff line change
Expand Up @@ -779,16 +779,6 @@ FORCE_INLINE __m128 _mm_shuffle_ps_2032(__m128 a, __m128 b)
return vreinterpretq_m128_f32(vcombine_f32(a32, b20));
}

// One step of Kahan (compensated) summation on single-precision floats.
// Reference: http://blog.zachbjornson.com/2019/08/11/fast-float-summation.html
//
// sum : running total, updated in place.
// c   : running compensation term, updated in place.
// y   : the next value to fold into the total.
FORCE_INLINE void _sse2neon_kadd_f32(float *sum, float *c, float y)
{
    // Remove the compensation carried over from the previous step.
    const float adjusted = y - *c;
    const float total = *sum + adjusted;

    // (total - *sum) is the part of `adjusted` that was actually absorbed;
    // subtracting `adjusted` leaves the new compensation term.
    *c = (total - *sum) - adjusted;
    *sum = total;
}

#if defined(__ARM_FEATURE_CRYPTO) && \
(defined(__aarch64__) || __has_builtin(__builtin_arm_crypto_vmullp64))
// Wraps vmull_p64
Expand Down Expand Up @@ -6894,40 +6884,44 @@ FORCE_INLINE __m128d _mm_dp_pd(__m128d a, __m128d b, const int imm)
// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_dp_ps
FORCE_INLINE __m128 _mm_dp_ps(__m128 a, __m128 b, const int imm)
{
// Masked dot product of the four float lanes of a and b, controlled by imm:
//   bits 4-7: lane i's product a[i] * b[i] enters the sum when bit (4 + i)
//             is set;
//   bits 0-3: result lane i receives the sum when bit i is set, else 0.
//
// NOTE(review): this listing appears to interleave lines from the earlier
// Kahan-based implementation with their replacements (two consecutive
// returns in the 0xFF branch, `s` declared twice, an unclosed 0x7F branch).
// Confirm against the actual header before relying on it.

// Lane-wise products, computed once and shared by every path below.
float32x4_t elementwise_prod = _mm_mul_ps(a, b);

#if defined(__aarch64__)
/* shortcuts */
// Every product and every result lane selected: broadcast the horizontal sum.
if (imm == 0xFF) {
return _mm_set1_ps(vaddvq_f32(_mm_mul_ps(a, b)));
return _mm_set1_ps(vaddvq_f32(elementwise_prod));
}
if (imm == 0x7F) {
float32x4_t m = _mm_mul_ps(a, b);
m[3] = 0;
return _mm_set1_ps(vaddvq_f32(m));

// All four result lanes selected: zero the products whose select bit
// (4-7) is clear, then broadcast the horizontal sum of what remains.
if ((imm & 0x0F) == 0x0F) {
if (!(imm & (1 << 4)))
elementwise_prod[0] = 0.0f;
if (!(imm & (1 << 5)))
elementwise_prod[1] = 0.0f;
if (!(imm & (1 << 6)))
elementwise_prod[2] = 0.0f;
if (!(imm & (1 << 7)))
elementwise_prod[3] = 0.0f;

return _mm_set1_ps(vaddvq_f32(elementwise_prod));
}
#endif

// Generic path: accumulate the selected products into the scalar s.
float s = 0, c = 0;
float32x4_t f32a = vreinterpretq_f32_m128(a);
float32x4_t f32b = vreinterpretq_f32_m128(b);
float s = 0.0f;

/* To improve the accuracy of floating-point summation, Kahan algorithm
* is used for each operation.
*/
if (imm & (1 << 4))
_sse2neon_kadd_f32(&s, &c, f32a[0] * f32b[0]);
s += elementwise_prod[0];
if (imm & (1 << 5))
_sse2neon_kadd_f32(&s, &c, f32a[1] * f32b[1]);
s += elementwise_prod[1];
if (imm & (1 << 6))
_sse2neon_kadd_f32(&s, &c, f32a[2] * f32b[2]);
s += elementwise_prod[2];
if (imm & (1 << 7))
_sse2neon_kadd_f32(&s, &c, f32a[3] * f32b[3]);
s += c;
s += elementwise_prod[3];

// Place the sum in the lanes chosen by imm bits 0-3; other lanes are zero.
float32x4_t res = {
(imm & 0x1) ? s : 0,
(imm & 0x2) ? s : 0,
(imm & 0x4) ? s : 0,
(imm & 0x8) ? s : 0,
(imm & 0x1) ? s : 0.0f,
(imm & 0x2) ? s : 0.0f,
(imm & 0x4) ? s : 0.0f,
(imm & 0x8) ? s : 0.0f,
};
return vreinterpretq_m128_f32(res);
}
Expand Down
48 changes: 31 additions & 17 deletions tests/impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8347,25 +8347,39 @@ result_t test_mm_dp_pd(const SSE2NEONTestImpl &impl, uint32_t iter)
return TEST_SUCCESS;
}

// NOTE(review): this listing appears to interleave the earlier single-case
// test body (fixed imm = 0xFF) with the macro-driven version that replaces
// it; test_mm_dp_ps is defined twice below. Confirm against the actual file.
result_t test_mm_dp_ps(const SSE2NEONTestImpl &impl, uint32_t iter)
{
const float *_a = impl.mTestFloatPointer1;
const float *_b = impl.mTestFloatPointer2;
const int imm = 0xFF;
__m128 a = load_m128(_a);
__m128 b = load_m128(_b);
__m128 out = _mm_dp_ps(a, b, imm);

float r[4]; /* the reference */
float sum = 0;
/* Runs one _mm_dp_ps check for the immediate IMM. It computes a scalar
 * reference by summing, in lane order, the products _a[i] * _b[i] whose
 * select bit (4 + i) is set, and stores that sum in each reference lane
 * whose bit i is set (0 otherwise). On a mismatch reported by
 * validateFloatEpsilon it returns TEST_FAIL from the enclosing function,
 * so the macro may only be used inside a function returning result_t.
 */
#define MM_DP_PS_TEST_CASE_WITH(IMM) \
do { \
const float *_a = impl.mTestFloatPointer1; \
const float *_b = impl.mTestFloatPointer2; \
const int imm = IMM; \
__m128 a = load_m128(_a); \
__m128 b = load_m128(_b); \
__m128 out = _mm_dp_ps(a, b, imm); \
float r[4]; /* the reference */ \
float sum = 0; \
for (size_t i = 0; i < 4; i++) \
sum += ((imm) & (1 << (i + 4))) ? _a[i] * _b[i] : 0; \
for (size_t i = 0; i < 4; i++) \
r[i] = (imm & (1 << i)) ? sum : 0; \
/* the epsilon has to be large enough, otherwise test suite fails. */ \
if (validateFloatEpsilon(out, r[0], r[1], r[2], r[3], 2050.0f) != \
TEST_SUCCESS) \
return TEST_FAIL; \
} while (0)

for (size_t i = 0; i < 4; i++)
sum += ((imm) & (1 << (i + 4))) ? _a[i] * _b[i] : 0;
for (size_t i = 0; i < 4; i++)
r[i] = (imm & (1 << i)) ? sum : 0;
/* Expands to one check per immediate. The list covers all products with
 * all lanes (0xFF), subsets of products with all four result lanes
 * (0x7F, 0x9F, 0x2F), no products at all (0x0F), and partial result-lane
 * masks (0x23, 0xB5).
 */
#define GENERATE_MM_DP_PS_TEST_CASES \
MM_DP_PS_TEST_CASE_WITH(0xFF); \
MM_DP_PS_TEST_CASE_WITH(0x7F); \
MM_DP_PS_TEST_CASE_WITH(0x9F); \
MM_DP_PS_TEST_CASE_WITH(0x2F); \
MM_DP_PS_TEST_CASE_WITH(0x0F); \
MM_DP_PS_TEST_CASE_WITH(0x23); \
MM_DP_PS_TEST_CASE_WITH(0xB5);

/* the epsilon has to be large enough, otherwise test suite fails. */
return validateFloatEpsilon(out, r[0], r[1], r[2], r[3], 2050.0f);
// Runs every generated case; each case returns TEST_FAIL on its own, so
// reaching the final return means all immediates passed.
result_t test_mm_dp_ps(const SSE2NEONTestImpl &impl, uint32_t iter)
{
GENERATE_MM_DP_PS_TEST_CASES
return TEST_SUCCESS;
}

result_t test_mm_extract_epi32(const SSE2NEONTestImpl &impl, uint32_t iter)
Expand Down

0 comments on commit 39d8540

Please sign in to comment.