Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Second Attempt of VAES (Alternative Proposal) #187

Open
wants to merge 7 commits into
base: master
Choose a base branch
from

Conversation

SchrodingerZhu
Copy link

@SchrodingerZhu SchrodingerZhu commented Nov 21, 2023

This is another proposal for VAES where the batched 128-byte function tries its best to keep the same form as the 64-byte loops.

It starts up with mixed aesenc/aesdec:

image

Compared with #186, this is more similar to the original algorithm (it is just running the original algorithm in two interleaved blocks in parallel). Performance is similar to #186, but this can be slower when the size is only marginally larger than 128 bytes.

@SchrodingerZhu SchrodingerZhu marked this pull request as ready for review November 21, 2023 22:56
@SchrodingerZhu SchrodingerZhu changed the title Vaes 3rd Second Attempt of VAES (Alternative Proposal) Nov 21, 2023
@SchrodingerZhu
Copy link
Author

SchrodingerZhu commented Nov 21, 2023

without VAES
image


With VAES:
image

@SchrodingerZhu
Copy link
Author

From e62fa50c585c52da99cb69c1e615a5e8752fefae Mon Sep 17 00:00:00 2001
From: Schrodinger ZHU Yifan <yifanzhu@rochester.edu>
Date: Tue, 21 Nov 2023 17:51:46 -0500
Subject: [PATCH] make it similar to updated version

---
 src/aes_hash.rs   | 47 ++++++++++++++++++++++++----------
 src/operations.rs | 65 +++++------------------------------------------
 2 files changed, 40 insertions(+), 72 deletions(-)

diff --git a/src/aes_hash.rs b/src/aes_hash.rs
index ed7052d..01d68d0 100644
--- a/src/aes_hash.rs
+++ b/src/aes_hash.rs
@@ -28,17 +28,36 @@ pub struct AHasher {
     all(feature = "nightly-arm-aes", target_arch = "arm", target_feature = "aes", not(miri)),
 ))]
 fn hash_batch_128b(data: &mut &[u8], hasher: &mut AHasher) {
-    let tail = read_last4_vec256(data);
+    let tail = data.read_last_u128x8();
+    let current = [
+        aesenc(hasher.key, tail[0]),
+        aesdec(hasher.key, tail[1]),
+        aesenc(hasher.key, tail[2]),
+        aesdec(hasher.key, tail[3]),
+        aesenc(hasher.key, tail[4]),
+        aesdec(hasher.key, tail[5]),
+        aesenc(hasher.key, tail[6]),
+        aesdec(hasher.key, tail[7]),
+    ];
     let mut current = [
-        aesenc_vec256(convert_u128_to_vec256(hasher.key), tail[0]),
-        aesdec_vec256(convert_u128_to_vec256(hasher.key), tail[1]),
-        aesenc_vec256(convert_u128_to_vec256(hasher.key), tail[2]),
-        aesdec_vec256(convert_u128_to_vec256(hasher.key), tail[3]),
+        convert_u128_to_vec256(current[0], current[1]),
+        convert_u128_to_vec256(current[2], current[3]),
+        convert_u128_to_vec256(current[4], current[5]),
+        convert_u128_to_vec256(current[6], current[7]),
+    ];
+    let mut sum: [Vector256; 2] = [
+        convert_u128_to_vec256(hasher.key, !hasher.key), 
+        convert_u128_to_vec256(hasher.key, !hasher.key)
+    ];
+    let tail = [
+        convert_u128_to_vec256(tail[0], tail[1]),
+        convert_u128_to_vec256(tail[2], tail[3]),
+        convert_u128_to_vec256(tail[4], tail[5]),
+        convert_u128_to_vec256(tail[6], tail[7]),
     ];
-    let mut sum: [Vector256; 2] = [convert_u128_to_vec256(hasher.key), convert_u128_to_vec256(!hasher.key)];
     sum[0] = add_by_64s_vec256(sum[0], tail[0]);
-    sum[1] = add_by_64s_vec256(sum[1], tail[1]);
-    sum[0] = shuffle_and_add_vec256(sum[0], tail[2]);
+    sum[0] = add_by_64s_vec256(sum[0], tail[1]);
+    sum[1] = shuffle_and_add_vec256(sum[1], tail[2]);
     sum[1] = shuffle_and_add_vec256(sum[1], tail[3]);
     while data.len() > 128 {
         let (blocks, rest) = read4_vec256(data);
@@ -47,8 +66,8 @@ fn hash_batch_128b(data: &mut &[u8], hasher: &mut AHasher) {
         current[2] = aesdec_vec256(current[2], blocks[2]);
         current[3] = aesdec_vec256(current[3], blocks[3]);
         sum[0] = shuffle_and_add_vec256(sum[0], blocks[0]);
-        sum[1] = shuffle_and_add_vec256(sum[1], blocks[1]);
-        sum[0] = shuffle_and_add_vec256(sum[0], blocks[2]);
+        sum[0] = shuffle_and_add_vec256(sum[0], blocks[1]);
+        sum[1] = shuffle_and_add_vec256(sum[1], blocks[2]);
         sum[1] = shuffle_and_add_vec256(sum[1], blocks[3]);
         *data = rest;
     }
@@ -60,13 +79,13 @@ fn hash_batch_128b(data: &mut &[u8], hasher: &mut AHasher) {
     ];
     let sum = [convert_vec256_to_u128(sum[0]), convert_vec256_to_u128(sum[1])];
     hasher.hash_in_2(
-        aesenc(current[0][0], current[0][1]),
-        aesenc(current[1][0], current[1][1]),
+        aesdec(current[0][0], current[0][1]),
+        aesdec(current[1][0], current[1][1]),
     );
     hasher.hash_in(add_by_64s(sum[0][0].convert(), sum[0][1].convert()).convert());
     hasher.hash_in_2(
-        aesenc(current[2][0], current[2][1]),
-        aesenc(current[3][0], current[3][1]),
+        aesdec(current[2][0], current[2][1]),
+        aesdec(current[3][0], current[3][1]),
     );
     hasher.hash_in(add_by_64s(sum[1][0].convert(), sum[1][1].convert()).convert());
 }
diff --git a/src/operations.rs b/src/operations.rs
index 911202c..fe5c0f7 100644
--- a/src/operations.rs
+++ b/src/operations.rs
@@ -201,29 +201,6 @@ mod vaes {
         }
     }
 
-    #[inline(always)]
-    pub(crate) fn aesenc_vec256(value: Vector256, xor: Vector256) -> Vector256 {
-        cfg_if::cfg_if! {
-            if #[cfg(all(
-                    any(target_arch = "x86", target_arch = "x86_64"),
-                    target_feature = "vaes",
-                    feature = "vaes",
-                    not(miri)
-                ))] {
-                use core::arch::x86_64::*;
-                unsafe {
-                    _mm256_aesenc_epi128(value, xor)
-                }
-            }
-            else {
-                    [
-                        aesenc(value[0], xor[0]),
-                        aesenc(value[1], xor[1]),
-                    ]
-            }
-        }
-    }
-
     #[inline(always)]
     pub(crate) fn aesdec_vec256(value: Vector256, xor: Vector256) -> Vector256 {
         cfg_if::cfg_if! {
@@ -279,7 +256,7 @@ mod vaes {
             ))] {
                 unsafe {
                     use core::arch::x86_64::*;
-                    let mask = convert_u128_to_vec256(SHUFFLE_MASK);
+                    let mask = convert_u128_to_vec256(SHUFFLE_MASK, SHUFFLE_MASK);
                     _mm256_shuffle_epi8(value, mask)
                 }
             }
@@ -327,35 +304,7 @@ mod vaes {
 
     // We specialize this routine because sometimes the compiler is not able to
     // optimize it properly.
-    pub(crate) fn read_last4_vec256(data: &[u8]) -> [Vector256; 4] {
-        cfg_if::cfg_if! {
-            if #[cfg(all(
-                any(target_arch = "x86", target_arch = "x86_64"),
-                target_feature = "vaes",
-                feature = "vaes",
-                not(miri)
-            ))] {
-                use core::arch::x86_64::*;
-                let (_, arr) = data.split_at(data.len() - 128);
-                let arr = unsafe {
-                   [ _mm256_loadu_si256(arr.as_ptr().cast::<__m256i>()),
-                     _mm256_loadu_si256(arr.as_ptr().add(32).cast::<__m256i>()),
-                     _mm256_loadu_si256(arr.as_ptr().add(64).cast::<__m256i>()),
-                     _mm256_loadu_si256(arr.as_ptr().add(96).cast::<__m256i>()),
-                   ]
-                };
-                arr
-            }
-            else {
-                let arr = data.read_last_u128x8();
-                transmute!(arr)
-            }
-        }
-    }
-
-    // We specialize this routine because sometimes the compiler is not able to
-    // optimize it properly.
-    pub(crate) fn convert_u128_to_vec256(x: u128) -> Vector256 {
+    pub(crate) fn convert_u128_to_vec256(low: u128, high: u128) -> Vector256 {
         cfg_if::cfg_if! {
             if #[cfg(all(
                 any(target_arch = "x86", target_arch = "x86_64"),
@@ -366,15 +315,15 @@ mod vaes {
                 use core::arch::x86_64::*;
                 unsafe {
                     _mm256_set_epi64x(
-                        (x >> 64) as i64,
-                        x as i64,
-                        (x >> 64) as i64,
-                        x as i64,
+                        (high >> 64) as i64,
+                        high as i64,
+                        (low >> 64) as i64,
+                        low as i64,
                     )
                 }
             }
             else {
-                transmute!([x, x])
+                transmute!([low, high])
             }
         }
     }
-- 
2.43.0

@SchrodingerZhu
Copy link
Author

@tkaitchuck a gentle ping~

sum[0] = add_by_64s_vec256(sum[0], tail[1]);
sum[1] = shuffle_and_add_vec256(sum[1], tail[2]);
sum[1] = shuffle_and_add_vec256(sum[1], tail[3]);
while data.len() > 128 {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could go even further to preserve the algorithm. In the existing code, it processes 4 blocks of 128 bits in parallel. Here it processes 4 blocks, each of which is 2 inner blocks of 128 bits parallelized by the VAES instruction. We could instead make it 2 blocks, each of which is 2 inner blocks parallelized by VAES, which would make the output identical with and without the VAES instruction. This would make it much easier to audit. It would also eliminate the need for the additional if check, which will improve the speed on variable-length strings.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I tested this before with aesenc, the AMD Zen3 CPU at least performed worse with VAES when processing the same amount of data as the classical AES path. The performance benefit does not appear except with larger batch sizes. However, that was a long time ago; maybe I should rerun it.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do see some speedup with

 if data.len() > 64 {
                    let tail = data.read_last_u128x4();
                    let mut current: [u128; 4] = [self.key; 4];
                    current[0] = aesenc(current[0], tail[0]);
                    current[1] = aesdec(current[1], tail[1]);
                    current[2] = aesenc(current[2], tail[2]);
                    current[3] = aesdec(current[3], tail[3]);
                    let mut sum: [u128; 2] = [self.key, !self.key];
                    sum[0] = add_by_64s(sum[0].convert(), tail[0].convert()).convert();
                    sum[1] = add_by_64s(sum[1].convert(), tail[1].convert()).convert();
                    sum[0] = shuffle_and_add(sum[0], tail[2]);
                    sum[1] = shuffle_and_add(sum[1], tail[3]);
                    let mut sum = convert_u128_to_vec256(sum[0], sum[1]);
                    let mut current = [
                        convert_u128_to_vec256(current[0], current[1]),
                        convert_u128_to_vec256(current[2], current[3]),
                    ];
                    while data.len() > 64 {
                        let (blocks, rest) = read2_vec256(data);
                        current[0] = aesdec_vec256(current[0], blocks[0]);
                        current[1] = aesdec_vec256(current[1], blocks[1]);
                        sum = shuffle_and_add_vec256(sum, blocks[0]);
                        sum = shuffle_and_add_vec256(sum, blocks[1]);
                        data = rest;
                    }
                    let current = [convert_vec256_to_u128(current[0]), convert_vec256_to_u128(current[1])];
                    let sum = convert_vec256_to_u128(sum);
                    self.hash_in_2(current[0][0], current[0][1]);
                    self.hash_in_2(current[1][0], current[1][1]);
                    self.hash_in_2(sum[0], sum[1]);
                } else {
                    //len 33-64
                    let (head, _) = data.read_u128x2();
                    let tail = data.read_last_u128x2();
                    self.hash_in_2(head[0], head[1]);
                    self.hash_in_2(tail[0], tail[1]);
                }

However, it is much slower (~65000MB/s).
It is worth noting that when compiled without VAES there is also a slowdown, so I suppose there may be room for improvement. Could you check it out on your side?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants