Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a SIMD (AVX2) optimised vector distance function for int7 on x64 (#…
…108088) * Adding support for x64 to native vec library * Fix: aarch64 sqr7u dims * Fix: add symbol stripping (deb lintian) --------- Co-authored-by: Chris Hegarty <62058229+ChrisHegarty@users.noreply.github.com> Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
- Loading branch information
Showing
11 changed files
with
254 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
pr: 108088 | ||
summary: Add a SIMD (AVX2) optimised vector distance function for int7 on x64 | ||
area: "Search" | ||
type: enhancement | ||
issues: [] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0 and the Server Side Public License, v 1; you may not use this file except | ||
* in compliance with, at your election, the Elastic License 2.0 or the Server | ||
* Side Public License, v 1. | ||
*/ | ||
|
||
#include <stddef.h> | ||
#include <stdint.h> | ||
#include "vec.h" | ||
|
||
#include <emmintrin.h> | ||
#include <immintrin.h> | ||
|
||
#ifndef DOT7U_STRIDE_BYTES_LEN | ||
#define DOT7U_STRIDE_BYTES_LEN 32 // Must be a power of 2 | ||
#endif | ||
|
||
#ifndef SQR7U_STRIDE_BYTES_LEN | ||
#define SQR7U_STRIDE_BYTES_LEN 32 // Must be a power of 2 | ||
#endif | ||
|
||
#ifdef _MSC_VER | ||
#include <intrin.h> | ||
#elif __GNUC__ | ||
#include <x86intrin.h> | ||
#elif __clang__ | ||
#include <x86intrin.h> | ||
#endif | ||
|
||
// Multi-platform CPUID "intrinsic"; it takes as input a "functionNumber" (or "leaf", the eax registry). "Subleaf" | ||
// is always 0. Output is stored in the passed output parameter: output[0] = eax, output[1] = ebx, output[2] = ecx, | ||
// output[3] = edx | ||
static inline void cpuid(int output[4], int functionNumber) { | ||
#if defined(__GNUC__) || defined(__clang__) | ||
// use inline assembly, Gnu/AT&T syntax | ||
int a, b, c, d; | ||
__asm("cpuid" : "=a"(a), "=b"(b), "=c"(c), "=d"(d) : "a"(functionNumber), "c"(0) : ); | ||
output[0] = a; | ||
output[1] = b; | ||
output[2] = c; | ||
output[3] = d; | ||
|
||
#elif defined (_MSC_VER) | ||
__cpuidex(output, functionNumber, 0); | ||
#else | ||
#error Unsupported compiler | ||
#endif | ||
} | ||
|
||
// Utility function to horizontally add 8 32-bit integers | ||
static inline int hsum_i32_8(const __m256i a) { | ||
const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); | ||
const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); | ||
const __m128i sum64 = _mm_add_epi32(hi64, sum128); | ||
const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); | ||
return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); | ||
} | ||
|
||
EXPORT int vec_caps() { | ||
int cpuInfo[4] = {-1}; | ||
// Calling __cpuid with 0x0 as the function_id argument | ||
// gets the number of the highest valid function ID. | ||
cpuid(cpuInfo, 0); | ||
int functionIds = cpuInfo[0]; | ||
if (functionIds >= 7) { | ||
cpuid(cpuInfo, 7); | ||
int ebx = cpuInfo[1]; | ||
// AVX2 flag is the 5th bit | ||
// We assume that all processors that have AVX2 also have FMA3 | ||
return (ebx & (1 << 5)) != 0; | ||
} | ||
return 0; | ||
} | ||
|
||
static inline int32_t dot7u_inner(int8_t* a, int8_t* b, size_t dims) { | ||
const __m256i ones = _mm256_set1_epi16(1); | ||
|
||
// Init accumulator(s) with 0 | ||
__m256i acc1 = _mm256_setzero_si256(); | ||
|
||
#pragma GCC unroll 4 | ||
for(int i = 0; i < dims; i += DOT7U_STRIDE_BYTES_LEN) { | ||
// Load packed 8-bit integers | ||
__m256i va1 = _mm256_loadu_si256(a + i); | ||
__m256i vb1 = _mm256_loadu_si256(b + i); | ||
|
||
// Perform multiplication and create 16-bit values | ||
// Vertically multiply each unsigned 8-bit integer from va with the corresponding | ||
// 8-bit integer from vb, producing intermediate signed 16-bit integers. | ||
const __m256i vab = _mm256_maddubs_epi16(va1, vb1); | ||
// Horizontally add adjacent pairs of intermediate signed 16-bit integers, and pack the results. | ||
acc1 = _mm256_add_epi32(_mm256_madd_epi16(ones, vab), acc1); | ||
} | ||
|
||
// reduce (horizontally add all) | ||
return hsum_i32_8(acc1); | ||
} | ||
|
||
EXPORT int32_t dot7u(int8_t* a, int8_t* b, size_t dims) { | ||
int32_t res = 0; | ||
int i = 0; | ||
if (dims > DOT7U_STRIDE_BYTES_LEN) { | ||
i += dims & ~(DOT7U_STRIDE_BYTES_LEN - 1); | ||
res = dot7u_inner(a, b, i); | ||
} | ||
for (; i < dims; i++) { | ||
res += a[i] * b[i]; | ||
} | ||
return res; | ||
} | ||
|
||
static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, size_t dims) { | ||
// Init accumulator(s) with 0 | ||
__m256i acc1 = _mm256_setzero_si256(); | ||
|
||
const __m256i ones = _mm256_set1_epi16(1); | ||
|
||
#pragma GCC unroll 4 | ||
for(int i = 0; i < dims; i += SQR7U_STRIDE_BYTES_LEN) { | ||
// Load packed 8-bit integers | ||
__m256i va1 = _mm256_loadu_si256(a + i); | ||
__m256i vb1 = _mm256_loadu_si256(b + i); | ||
|
||
const __m256i dist1 = _mm256_sub_epi8(va1, vb1); | ||
const __m256i abs_dist1 = _mm256_sign_epi8(dist1, dist1); | ||
const __m256i sqr1 = _mm256_maddubs_epi16(abs_dist1, abs_dist1); | ||
|
||
acc1 = _mm256_add_epi32(_mm256_madd_epi16(ones, sqr1), acc1); | ||
} | ||
|
||
// reduce (accumulate all) | ||
return hsum_i32_8(acc1); | ||
} | ||
|
||
EXPORT int32_t sqr7u(int8_t* a, int8_t* b, size_t dims) { | ||
int32_t res = 0; | ||
int i = 0; | ||
if (dims > SQR7U_STRIDE_BYTES_LEN) { | ||
i += dims & ~(SQR7U_STRIDE_BYTES_LEN - 1); | ||
res = sqr7u_inner(a, b, i); | ||
} | ||
for (; i < dims; i++) { | ||
int32_t dist = a[i] - b[i]; | ||
res += dist * dist; | ||
} | ||
return res; | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters