Skip to content

Commit

Permalink
AVX-512 hash_chunks
Browse files Browse the repository at this point in the history
  • Loading branch information
oconnor663 committed Jul 17, 2023
1 parent 6052107 commit ce3b4a3
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 4 deletions.
220 changes: 217 additions & 3 deletions c/blake3_avx512_x86-64_unix.S
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
.global _blake3_guts_avx512_compress
.global blake3_guts_avx512_compress_xof
.global _blake3_guts_avx512_compress_xof
.global blake3_guts_avx512_hash_chunks_16_exact
.global _blake3_guts_avx512_hash_chunks_16_exact
.global blake3_guts_avx512_xof_16_exact
.global _blake3_guts_avx512_xof_16_exact
.global blake3_guts_avx512_xof_xor_16_exact
Expand Down Expand Up @@ -2746,8 +2748,8 @@ blake3_guts_avx512_compress_xof:
ret

.p2align 6
_blake3_guts_kernel_16_avx512_exact:
blake3_guts_kernel_16_avx512_exact:
_blake3_guts_avx512_kernel_16:
blake3_guts_avx512_kernel_16:
vpbroadcastd zmm8, dword ptr [BLAKE3_IV_0+rip]
vpbroadcastd zmm9, dword ptr [BLAKE3_IV_1+rip]
vpbroadcastd zmm10, dword ptr [BLAKE3_IV_2+rip]
Expand Down Expand Up @@ -3538,6 +3540,218 @@ blake3_guts_kernel_16_avx512_exact:
vprord zmm4, zmm4, 7
ret

// Internal helper: compress one 64-byte block from each of 16 inputs in
// parallel and fold the result into the chaining values held in zmm0-zmm7.
// It is reached via `call` from blake3_guts_avx512_hash_chunks_16_exact and
// inherits that function's register assignments.
//
// Inputs:
//   rdi:       pointer to the current block of input 0; input i's block is
//              read at [rdi + i*0x400] (0x400 is presumably CHUNK_LEN --
//              confirm against the Rust side). rdi is not modified here.
//   esi:       [unused]
//   rdx:       not read by this routine (the caller has already broadcast
//              the cv words into zmm0-zmm7)
//   rcx:       64-bit counter for lane 0; rcx itself is left unchanged
//   r8d:       flags, broadcast to every lane
//   r9:        not touched by this routine
//   zmm0-zmm7: transposed chaining values, one state word per register
// Outputs:
//   zmm0-zmm7: zmm0-zmm7 XOR zmm8-zmm15 as left by the kernel
// Clobbers (visible here): rax, k1, zmm8-zmm31, plus whatever
// blake3_guts_avx512_kernel_16 itself clobbers.
.p2align 6
_blake3_guts_avx512_hash_blocks_16_exact:
blake3_guts_avx512_hash_blocks_16_exact:
// zmm0-zmm7 are already populated
// load the message words
// First 32 bytes of each block. Each ymm load + vinserti64x4 pair builds one
// zmm whose low 256 bits come from input i and whose high 256 bits come from
// input i+8. The qword unpacks then interleave neighbouring inputs.
vmovdqu32 ymm16, ymmword ptr [rdi+0x0*0x400]
vinserti64x4 zmm16, zmm16, ymmword ptr [rdi+0x8*0x400], 0x01
vmovdqu32 ymm17, ymmword ptr [rdi+0x1*0x400]
vinserti64x4 zmm17, zmm17, ymmword ptr [rdi+0x9*0x400], 0x01
vpunpcklqdq zmm8, zmm16, zmm17
vpunpckhqdq zmm9, zmm16, zmm17
vmovdqu32 ymm18, ymmword ptr [rdi+0x2*0x400]
vinserti64x4 zmm18, zmm18, ymmword ptr [rdi+0xa*0x400], 0x01
vmovdqu32 ymm19, ymmword ptr [rdi+0x3*0x400]
vinserti64x4 zmm19, zmm19, ymmword ptr [rdi+0xb*0x400], 0x01
vpunpcklqdq zmm10, zmm18, zmm19
vpunpckhqdq zmm11, zmm18, zmm19
vmovdqu32 ymm16, ymmword ptr [rdi+0x4*0x400]
vinserti64x4 zmm16, zmm16, ymmword ptr [rdi+0xc*0x400], 0x01
vmovdqu32 ymm17, ymmword ptr [rdi+0x5*0x400]
vinserti64x4 zmm17, zmm17, ymmword ptr [rdi+0xd*0x400], 0x01
vpunpcklqdq zmm12, zmm16, zmm17
vpunpckhqdq zmm13, zmm16, zmm17
vmovdqu32 ymm18, ymmword ptr [rdi+0x6*0x400]
vinserti64x4 zmm18, zmm18, ymmword ptr [rdi+0xe*0x400], 0x01
vmovdqu32 ymm19, ymmword ptr [rdi+0x7*0x400]
vinserti64x4 zmm19, zmm19, ymmword ptr [rdi+0xf*0x400], 0x01
vpunpcklqdq zmm14, zmm18, zmm19
vpunpckhqdq zmm15, zmm18, zmm19
// zmm27/zmm31 hold the dword permutation indices used by every
// vpermt2d/vpermi2d below. They must stay intact until the final vpermi2d
// pair, which overwrites them with results.
vmovdqa32 zmm27, zmmword ptr [INDEX0+rip]
vmovdqa32 zmm31, zmmword ptr [INDEX1+rip]
// Finish the transpose of the first 32 bytes. vshufps with 136 (0x88) picks
// the even dwords of each 128-bit lane and 221 (0xDD) picks the odd dwords;
// the two permutes per group then gather across lanes. Results land in
// zmm16-zmm23 (presumably one message word per register -- the kernel that
// consumes them is not fully visible here).
vshufps zmm16, zmm8, zmm10, 136
vshufps zmm17, zmm12, zmm14, 136
vmovdqa32 zmm20, zmm16
vpermt2d zmm16, zmm27, zmm17
vpermt2d zmm20, zmm31, zmm17
vshufps zmm17, zmm8, zmm10, 221
vshufps zmm30, zmm12, zmm14, 221
vmovdqa32 zmm21, zmm17
vpermt2d zmm17, zmm27, zmm30
vpermt2d zmm21, zmm31, zmm30
vshufps zmm18, zmm9, zmm11, 136
vshufps zmm8, zmm13, zmm15, 136
vmovdqa32 zmm22, zmm18
vpermt2d zmm18, zmm27, zmm8
vpermt2d zmm22, zmm31, zmm8
vshufps zmm19, zmm9, zmm11, 221
vshufps zmm8, zmm13, zmm15, 221
vmovdqa32 zmm23, zmm19
vpermt2d zmm19, zmm27, zmm8
vpermt2d zmm23, zmm31, zmm8
// Second 32 bytes of each block (offset +0x20), same load/unpack pattern.
// zmm8-zmm15 are reused as scratch because the first half now lives in
// zmm16-zmm23.
vmovdqu32 ymm24, ymmword ptr [rdi+0x0*0x400+0x20]
vinserti64x4 zmm24, zmm24, ymmword ptr [rdi+0x8*0x400+0x20], 0x01
vmovdqu32 ymm25, ymmword ptr [rdi+0x1*0x400+0x20]
vinserti64x4 zmm25, zmm25, ymmword ptr [rdi+0x9*0x400+0x20], 0x01
vpunpcklqdq zmm8, zmm24, zmm25
vpunpckhqdq zmm9, zmm24, zmm25
vmovdqu32 ymm24, ymmword ptr [rdi+0x2*0x400+0x20]
vinserti64x4 zmm24, zmm24, ymmword ptr [rdi+0xa*0x400+0x20], 0x01
vmovdqu32 ymm25, ymmword ptr [rdi+0x3*0x400+0x20]
vinserti64x4 zmm25, zmm25, ymmword ptr [rdi+0xb*0x400+0x20], 0x01
vpunpcklqdq zmm10, zmm24, zmm25
vpunpckhqdq zmm11, zmm24, zmm25
// Prefetch 0x80 bytes (two 64-byte blocks) past the current read position
// for the inputs just loaded. Prefetches never fault, so running past the
// end of an input is harmless.
prefetcht0 [rdi+0x0*0x400+0x80]
prefetcht0 [rdi+0x8*0x400+0x80]
prefetcht0 [rdi+0x1*0x400+0x80]
prefetcht0 [rdi+0x9*0x400+0x80]
prefetcht0 [rdi+0x2*0x400+0x80]
prefetcht0 [rdi+0xa*0x400+0x80]
prefetcht0 [rdi+0x3*0x400+0x80]
prefetcht0 [rdi+0xb*0x400+0x80]
vmovdqu32 ymm24, ymmword ptr [rdi+0x4*0x400+0x20]
vinserti64x4 zmm24, zmm24, ymmword ptr [rdi+0xc*0x400+0x20], 0x01
vmovdqu32 ymm25, ymmword ptr [rdi+0x5*0x400+0x20]
vinserti64x4 zmm25, zmm25, ymmword ptr [rdi+0xd*0x400+0x20], 0x01
vpunpcklqdq zmm12, zmm24, zmm25
vpunpckhqdq zmm13, zmm24, zmm25
vmovdqu32 ymm24, ymmword ptr [rdi+0x6*0x400+0x20]
vinserti64x4 zmm24, zmm24, ymmword ptr [rdi+0xe*0x400+0x20], 0x01
vmovdqu32 ymm25, ymmword ptr [rdi+0x7*0x400+0x20]
vinserti64x4 zmm25, zmm25, ymmword ptr [rdi+0xf*0x400+0x20], 0x01
vpunpcklqdq zmm14, zmm24, zmm25
vpunpckhqdq zmm15, zmm24, zmm25
prefetcht0 [rdi+0x4*0x400+0x80]
prefetcht0 [rdi+0xc*0x400+0x80]
prefetcht0 [rdi+0x5*0x400+0x80]
prefetcht0 [rdi+0xd*0x400+0x80]
prefetcht0 [rdi+0x6*0x400+0x80]
prefetcht0 [rdi+0xe*0x400+0x80]
prefetcht0 [rdi+0x7*0x400+0x80]
prefetcht0 [rdi+0xf*0x400+0x80]
// Finish the transpose of the second 32 bytes into zmm24-zmm31. zmm30 is
// used as scratch for the first two groups and only receives its final
// value in the third group.
vshufps zmm24, zmm8, zmm10, 136
vshufps zmm30, zmm12, zmm14, 136
vmovdqa32 zmm28, zmm24
vpermt2d zmm24, zmm27, zmm30
vpermt2d zmm28, zmm31, zmm30
vshufps zmm25, zmm8, zmm10, 221
vshufps zmm30, zmm12, zmm14, 221
vmovdqa32 zmm29, zmm25
vpermt2d zmm25, zmm27, zmm30
vpermt2d zmm29, zmm31, zmm30
vshufps zmm26, zmm9, zmm11, 136
vshufps zmm8, zmm13, zmm15, 136
vmovdqa32 zmm30, zmm26
vpermt2d zmm26, zmm27, zmm8
vpermt2d zmm30, zmm31, zmm8
vshufps zmm8, zmm9, zmm11, 221
vshufps zmm10, zmm13, zmm15, 221
// vpermi2d writes its result over the index operand, so the last group is
// produced directly in zmm27/zmm31. This must stay last because it consumes
// the index tables.
vpermi2d zmm27, zmm8, zmm10
vpermi2d zmm31, zmm8, zmm10
// increment and broadcast the counter
// zmm14 = low 32 bits of rcx and zmm13 = high 32 bits, in every lane.
// zmm12 = low word plus the per-lane values at ADD0 (presumably 0..15 --
// the table is defined outside this view). k1 flags lanes where that add
// wrapped (sum < original, unsigned), and only those lanes get the
// broadcast dword at ADD1 added to the high word.
vpbroadcastd zmm14,ecx
mov rax, rcx
shr rax,0x20
vpbroadcastd zmm13,eax
vpaddd zmm12,zmm14,ZMMWORD PTR [ADD0+rip]
vpcmpltud k1,zmm12,zmm14
vpaddd zmm13{k1},zmm13,DWORD PTR [ADD1+rip]{1to16}
// broadcast the block length
// (zmm14 is free again now that the carry mask has been computed)
mov eax, 64
vpbroadcastd zmm14, eax
// broadcast the flags
vpbroadcastd zmm15, r8d

// execute the kernel
call blake3_guts_avx512_kernel_16

// xor the two halves of the state
// The result becomes the next chaining value, left in zmm0-zmm7 for the
// caller.
vpxord zmm0, zmm0, zmm8
vpxord zmm1, zmm1, zmm9
vpxord zmm2, zmm2, zmm10
vpxord zmm3, zmm3, zmm11
vpxord zmm4, zmm4, zmm12
vpxord zmm5, zmm5, zmm13
vpxord zmm6, zmm6, zmm14
vpxord zmm7, zmm7, zmm15
ret

// Hash exactly 16 inputs of 16 blocks (16 * 0x40 = 0x400 bytes) each in
// parallel and store the 16 resulting chaining values in transposed form.
// ABI: System V AMD64. Exported entry point, called from Rust as
// blake3_guts_avx512_hash_chunks_16_exact(input, input_len, key, counter,
//                                         flags, transposed_output).
//
// rdi: input pointer; input i starts at [rdi + i*0x400]
// esi: [unused]
// rdx: pointer to the 8-dword key / cv, broadcast to all 16 lanes
// rcx: counter for lane 0 (the helper derives the per-lane counters)
// r8d: flags; CHUNK_START (0x1) and CHUNK_END (0x2) are managed here
// r9: out pointer; must be 64-byte aligned because the stores below use
//     vmovdqa32, which faults on an unaligned address
//
// Clobbers: rax, rdi, r8d, flags, k1, zmm0-zmm31 (all caller-saved).
.p2align 6
_blake3_guts_avx512_hash_chunks_16_exact:
blake3_guts_avx512_hash_chunks_16_exact:
// broadcast the key: zmm0-zmm7 each hold one cv word in all 16 lanes
vpbroadcastd zmm0,DWORD PTR [rdx]
vpbroadcastd zmm1,DWORD PTR [rdx+0x4]
vpbroadcastd zmm2,DWORD PTR [rdx+0x8]
vpbroadcastd zmm3,DWORD PTR [rdx+0xc]
vpbroadcastd zmm4,DWORD PTR [rdx+0x10]
vpbroadcastd zmm5,DWORD PTR [rdx+0x14]
vpbroadcastd zmm6,DWORD PTR [rdx+0x18]
vpbroadcastd zmm7,DWORD PTR [rdx+0x1c]
// hash sixteen blocks, managing the flags and block pointer
// block 0: set CHUNK_START
or r8d, 0x1
call blake3_guts_avx512_hash_blocks_16_exact
add rdi, 0x40
// blocks 1-14: clear CHUNK_START, then compress and advance one block
// (0x40 bytes) at a time. The helper leaves rdi, rcx and r8d untouched.
and r8d, 0xFFFFFFFE
.rept 14
call blake3_guts_avx512_hash_blocks_16_exact
add rdi, 0x40
.endr
// block 15: set CHUNK_END
or r8d, 0x2
call blake3_guts_avx512_hash_blocks_16_exact

// write aligned, transposed outputs with a stride of 2*MAX_SIMD_DEGREE words
// (0x80 bytes): row j receives state word j of all 16 lanes
vmovdqa32 ZMMWORD PTR [r9+0x0*0x80],zmm0
vmovdqa32 ZMMWORD PTR [r9+0x1*0x80],zmm1
vmovdqa32 ZMMWORD PTR [r9+0x2*0x80],zmm2
vmovdqa32 ZMMWORD PTR [r9+0x3*0x80],zmm3
vmovdqa32 ZMMWORD PTR [r9+0x4*0x80],zmm4
vmovdqa32 ZMMWORD PTR [r9+0x5*0x80],zmm5
vmovdqa32 ZMMWORD PTR [r9+0x6*0x80],zmm6
vmovdqa32 ZMMWORD PTR [r9+0x7*0x80],zmm7
// Clear the dirty upper vector state before returning to compiled code, so
// that subsequent legacy-SSE instructions in the caller do not pay
// transition / false-dependency penalties. Safe here: the results are
// already in memory and every vector register is caller-saved on SysV.
vzeroupper
ret

// rdi: block pointer
// esi: block_len
// rdx: cv
Expand Down Expand Up @@ -3586,7 +3800,7 @@ blake3_guts_avx512_xof_inner_16_exact:
vpbroadcastd zmm15, r8d

// execute the kernel
call blake3_guts_kernel_16_avx512_exact
call blake3_guts_avx512_kernel_16

// xor the two halves of the state
vpxord zmm0, zmm0, zmm8
Expand Down
22 changes: 21 additions & 1 deletion rust/guts/src/avx512.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{BlockBytes, CVBytes, Implementation, BLOCK_LEN};
use crate::{BlockBytes, CVBytes, Implementation, BLOCK_LEN, CHUNK_LEN};

const DEGREE: usize = 16;

Expand All @@ -19,6 +19,14 @@ extern "C" {
flags: u32,
out: *mut BlockBytes,
);
fn blake3_guts_avx512_hash_chunks_16_exact(
input: *const u8,
input_len: usize,
key: *const CVBytes,
counter: u64,
flags: u32,
transposed_output: *mut u32,
);
fn blake3_guts_avx512_xof_16_exact(
block: *const BlockBytes,
block_len: u32,
Expand All @@ -45,6 +53,18 @@ unsafe extern "C" fn hash_chunks(
flags: u32,
transposed_output: *mut u32,
) {
debug_assert!(input_len <= 16 * CHUNK_LEN);
if input_len == 16 * CHUNK_LEN {
blake3_guts_avx512_hash_chunks_16_exact(
input,
0, // unused
key,
counter,
flags,
transposed_output,
);
return;
}
crate::hash_chunks_using_compress(
blake3_guts_avx512_compress,
input,
Expand Down
1 change: 1 addition & 0 deletions rust/guts/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ pub fn test_hash_chunks_vs_portable(test_impl: &Implementation) {
let input1 = &input[..test_impl.degree() * CHUNK_LEN];
let input2 = &input[test_impl.degree() * CHUNK_LEN..][..input_2_len];
for initial_counter in INITIAL_COUNTERS {
dbg!(initial_counter);
// Make two calls, to test the output_column parameter.
let mut portable_output = TransposedVectors::new();
let (portable_left, portable_right) =
Expand Down

0 comments on commit ce3b4a3

Please sign in to comment.