Skip to content

Commit

Permalink
try doing 512-bit loads
Browse files Browse the repository at this point in the history
  • Loading branch information
oconnor663 committed Dec 23, 2022
1 parent 0c80427 commit 93fd593
Showing 1 changed file with 60 additions and 47 deletions.
107 changes: 60 additions & 47 deletions src/kernel2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -827,55 +827,39 @@ global_asm!(
"ret",
);

#[inline]
#[target_feature(enable = "avx512f,avx512vl")]
#[inline(always)]
unsafe fn load_transposed_16(input: *const u8) -> [__m512i; 16] {
// We're going to load 16 vectors, each containing 16 words (64 bytes). We assume that these
// vectors are coming from contiguous chunks, so each is offset by CHUNK_LEN (1024 bytes) from
// the last. We'll name the input vectors a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, and p.
// We'll denote the words of the input vectors:
//
// a_0, a_1, a_2, a_3, a_4, a_5, a_6, a_7, a_8, a_9, a_a, a_b, a_c, a_d, a_e, a_f
// b_0, b_1, b_2, b_3, b_4, b_5, b_6, b_7, b_8, b_9, b_a, b_b, b_c, b_d, b_e, b_f
// etc.
//
// Our goal is to load and transpose these into output vectors that look like:
//
// a_0, b_0, c_0, d_0, e_0, f_0, g_0, h_0, i_0, j_0, k_0, l_0, m_0, n_0, o_0, p_0
// a_1, b_1, c_1, d_1, e_1, f_1, g_1, h_1, i_1, j_1, k_1, l_1, m_1, n_1, o_1, p_1
// etc.
// lane selectors for _mm512_permutex2var_epi64
let lower_256 = _mm512_setr_epi64(0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xa, 0xb);
let upper_256 = _mm512_setr_epi64(0x4, 0x5, 0x6, 0x7, 0xc, 0xd, 0xe, 0xf);
let lower_128 = _mm512_setr_epi64(0x0, 0x1, 0x8, 0x9, 0x4, 0x5, 0xc, 0xd);
let upper_128 = _mm512_setr_epi64(0x2, 0x3, 0xa, 0xb, 0x6, 0x7, 0xe, 0xf);

// Because operations that cross 128-bit lanes are relatively expensive, we split each 512-bit
// load into four 128-bit loads. This results in vectors like:
// a0, a1, a2, a3, e0, e1, e2, e3, i0, i1, i2, i3, m0, m1, m2, m3
#[inline(always)]
unsafe fn load_4_lanes(input: *const u8) -> __m512i {
    // Performs four unaligned 128-bit loads, taken at byte offsets 0, 4, 8 and 12 times
    // CHUNK_LEN from `input`, and packs them into 128-bit lanes 0 through 3 of the result.
    let lane_stride = 4 * CHUNK_LEN;
    let first = _mm_loadu_epi32(input as *const i32);
    let second = _mm_loadu_epi32(input.add(lane_stride) as *const i32);
    let third = _mm_loadu_epi32(input.add(2 * lane_stride) as *const i32);
    let fourth = _mm_loadu_epi32(input.add(3 * lane_stride) as *const i32);

    // Widen the first lane to 512 bits; the upper lanes are overwritten by the inserts below.
    let packed = _mm512_castsi128_si512(first);
    let packed = _mm512_inserti32x4::<1>(packed, second);
    let packed = _mm512_inserti32x4::<2>(packed, third);
    _mm512_inserti32x4::<3>(packed, fourth)
}
let aeim_0123 = load_4_lanes(input.add(0 * CHUNK_LEN + 0 * 16));
let aeim_4567 = load_4_lanes(input.add(0 * CHUNK_LEN + 1 * 16));
let aeim_89ab = load_4_lanes(input.add(0 * CHUNK_LEN + 2 * 16));
let aeim_cdef = load_4_lanes(input.add(0 * CHUNK_LEN + 3 * 16));
let bfjn_0123 = load_4_lanes(input.add(1 * CHUNK_LEN + 0 * 16));
let bfjn_4567 = load_4_lanes(input.add(1 * CHUNK_LEN + 1 * 16));
let bfjn_89ab = load_4_lanes(input.add(1 * CHUNK_LEN + 2 * 16));
let bfjn_cdef = load_4_lanes(input.add(1 * CHUNK_LEN + 3 * 16));
let cgko_0123 = load_4_lanes(input.add(2 * CHUNK_LEN + 0 * 16));
let cgko_4567 = load_4_lanes(input.add(2 * CHUNK_LEN + 1 * 16));
let cgko_89ab = load_4_lanes(input.add(2 * CHUNK_LEN + 2 * 16));
let cgko_cdef = load_4_lanes(input.add(2 * CHUNK_LEN + 3 * 16));
let dhlp_0123 = load_4_lanes(input.add(3 * CHUNK_LEN + 0 * 16));
let dhlp_4567 = load_4_lanes(input.add(3 * CHUNK_LEN + 1 * 16));
let dhlp_89ab = load_4_lanes(input.add(3 * CHUNK_LEN + 2 * 16));
let dhlp_cdef = load_4_lanes(input.add(3 * CHUNK_LEN + 3 * 16));
let a = _mm512_loadu_si512(input.add(0x0 * CHUNK_LEN) as *const i32);
let i = _mm512_loadu_si512(input.add(0x8 * CHUNK_LEN) as *const i32);
let ai_01234567 = _mm512_permutex2var_epi64(a, lower_256, i);
let ai_89abcdef = _mm512_permutex2var_epi64(a, upper_256, i);
let e = _mm512_loadu_si512(input.add(0x4 * CHUNK_LEN) as *const i32);
let m = _mm512_loadu_si512(input.add(0xc * CHUNK_LEN) as *const i32);
let em_01234567 = _mm512_permutex2var_epi64(e, lower_256, m);
let em_89abcdef = _mm512_permutex2var_epi64(e, upper_256, m);
let aeim_0123 = _mm512_permutex2var_epi64(ai_01234567, lower_128, em_01234567);
let aeim_4567 = _mm512_permutex2var_epi64(ai_01234567, upper_128, em_01234567);
let aeim_89ab = _mm512_permutex2var_epi64(ai_89abcdef, lower_128, em_89abcdef);
let aeim_cdef = _mm512_permutex2var_epi64(ai_89abcdef, upper_128, em_89abcdef);

let b = _mm512_loadu_si512(input.add(0x1 * CHUNK_LEN) as *const i32);
let j = _mm512_loadu_si512(input.add(0x9 * CHUNK_LEN) as *const i32);
let bj_01234567 = _mm512_permutex2var_epi64(b, lower_256, j);
let bj_89abcdef = _mm512_permutex2var_epi64(b, upper_256, j);
let f = _mm512_loadu_si512(input.add(0x5 * CHUNK_LEN) as *const i32);
let n = _mm512_loadu_si512(input.add(0xd * CHUNK_LEN) as *const i32);
let fn_01234567 = _mm512_permutex2var_epi64(f, lower_256, n);
let fn_89abcdef = _mm512_permutex2var_epi64(f, upper_256, n);
let bfjn_0123 = _mm512_permutex2var_epi64(bj_01234567, lower_128, fn_01234567);
let bfjn_4567 = _mm512_permutex2var_epi64(bj_01234567, upper_128, fn_01234567);
let bfjn_89ab = _mm512_permutex2var_epi64(bj_89abcdef, lower_128, fn_89abcdef);
let bfjn_cdef = _mm512_permutex2var_epi64(bj_89abcdef, upper_128, fn_89abcdef);

// Interleave 32-bit words. This results in vectors like:
// a0, b0, a1, b1, e0, f0, e1, f1, i0, j0, i1, j1, m0, n0, m1, n1
Expand All @@ -887,6 +871,35 @@ unsafe fn load_transposed_16(input: *const u8) -> [__m512i; 16] {
let abefijmn_ab = _mm512_unpackhi_epi32(aeim_89ab, bfjn_89ab);
let abefijmn_cd = _mm512_unpacklo_epi32(aeim_cdef, bfjn_cdef);
let abefijmn_ef = _mm512_unpackhi_epi32(aeim_cdef, bfjn_cdef);

let c = _mm512_loadu_si512(input.add(0x2 * CHUNK_LEN) as *const i32);
let k = _mm512_loadu_si512(input.add(0xa * CHUNK_LEN) as *const i32);
let ck_01234567 = _mm512_permutex2var_epi64(c, lower_256, k);
let ck_89abcdef = _mm512_permutex2var_epi64(c, upper_256, k);
let g = _mm512_loadu_si512(input.add(0x6 * CHUNK_LEN) as *const i32);
let o = _mm512_loadu_si512(input.add(0xe * CHUNK_LEN) as *const i32);
let go_01234567 = _mm512_permutex2var_epi64(g, lower_256, o);
let go_89abcdef = _mm512_permutex2var_epi64(g, upper_256, o);
let cgko_0123 = _mm512_permutex2var_epi64(ck_01234567, lower_128, go_01234567);
let cgko_4567 = _mm512_permutex2var_epi64(ck_01234567, upper_128, go_01234567);
let cgko_89ab = _mm512_permutex2var_epi64(ck_89abcdef, lower_128, go_89abcdef);
let cgko_cdef = _mm512_permutex2var_epi64(ck_89abcdef, upper_128, go_89abcdef);

let d = _mm512_loadu_si512(input.add(0x3 * CHUNK_LEN) as *const i32);
let l = _mm512_loadu_si512(input.add(0xb * CHUNK_LEN) as *const i32);
let dl_01234567 = _mm512_permutex2var_epi64(d, lower_256, l);
let dl_89abcdef = _mm512_permutex2var_epi64(d, upper_256, l);
let h = _mm512_loadu_si512(input.add(0x7 * CHUNK_LEN) as *const i32);
let p = _mm512_loadu_si512(input.add(0xf * CHUNK_LEN) as *const i32);
let hp_01234567 = _mm512_permutex2var_epi64(h, lower_256, p);
let hp_89abcdef = _mm512_permutex2var_epi64(h, upper_256, p);
let dhlp_0123 = _mm512_permutex2var_epi64(dl_01234567, lower_128, hp_01234567);
let dhlp_4567 = _mm512_permutex2var_epi64(dl_01234567, upper_128, hp_01234567);
let dhlp_89ab = _mm512_permutex2var_epi64(dl_89abcdef, lower_128, hp_89abcdef);
let dhlp_cdef = _mm512_permutex2var_epi64(dl_89abcdef, upper_128, hp_89abcdef);

// Interleave 32-bit words. This results in vectors like:
// c0, d0, c1, d1, g0, h0, g1, h1, k0, l0, k1, l1, o0, p0, o1, p1
let cdghklop_01 = _mm512_unpacklo_epi32(cgko_0123, dhlp_0123);
let cdghklop_23 = _mm512_unpackhi_epi32(cgko_0123, dhlp_0123);
let cdghklop_45 = _mm512_unpacklo_epi32(cgko_4567, dhlp_4567);
Expand Down

0 comments on commit 93fd593

Please sign in to comment.