Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Second Attempt of VAES (Alternative Proposal) #187

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.idea/
.vscode/
Cargo.lock
target
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ atomic-polyfill = [ "dep:atomic-polyfill", "once_cell/atomic-polyfill"]
# Nightly-only support for AES intrinsics on 32-bit ARM
nightly-arm-aes = []

# Nightly-only support for VAES intrinsics with 256-bit SIMD registers
vaes = []

[[bench]]
name = "ahash"
path = "tests/bench.rs"
Expand Down
5 changes: 4 additions & 1 deletion smhasher/ahash-cbindings/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,7 @@ lto = 'fat'
debug-assertions = false

[dependencies]
ahash = { path = "../../", default-features = false }
ahash = { path = "../../", default-features = false }

[features]
vaes = ["ahash/vaes"]
10 changes: 9 additions & 1 deletion smhasher/ahash-cbindings/install.sh
Original file line number Diff line number Diff line change
@@ -1 +1,9 @@
RUSTFLAGS="-C opt-level=3 -C target-cpu=native -C codegen-units=1" cargo build --release && sudo cp target/release/libahash_c.a /usr/local/lib/

# Pass --features=vaes to cargo when any script argument mentions "vaes".
# A POSIX `case` is used instead of the bash-only `[[ ]]` test because this
# file has no shebang and may be run by a plain /bin/sh. `export` is not
# needed: CARGO_OPTS is only expanded by this script itself.
case "$*" in
    *vaes*) CARGO_OPTS="--features=vaes" ;;
    *)      CARGO_OPTS="" ;;
esac

# Build with native-CPU optimizations and install the static library.
# ${CARGO_OPTS} is deliberately left unquoted so an empty value expands to
# no argument at all rather than an empty-string argument.
RUSTFLAGS="-C opt-level=3 -C target-cpu=native -C codegen-units=1" cargo build ${CARGO_OPTS} --release && sudo cp target/release/libahash_c.a /usr/local/lib/
1 change: 0 additions & 1 deletion smhasher/ahash-cbindings/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use ahash::*;
use core::slice;
use std::hash::{BuildHasher};

#[no_mangle]
pub extern "C" fn ahash64(buf: *const (), len: usize, seed: u64) -> u64 {
Expand Down
72 changes: 72 additions & 0 deletions src/aes_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,74 @@ pub struct AHasher {
key: u128,
}

#[cfg(any(
all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "aes", not(miri)),
all(target_arch = "aarch64", target_feature = "aes", not(miri)),
all(feature = "nightly-arm-aes", target_arch = "arm", target_feature = "aes", not(miri)),
))]
fn hash_batch_128b(data: &mut &[u8], hasher: &mut AHasher) {
let tail = data.read_last_u128x8();
let current = [
aesenc(hasher.key, tail[0]),
aesdec(hasher.key, tail[1]),
aesenc(hasher.key, tail[2]),
aesdec(hasher.key, tail[3]),
aesenc(hasher.key, tail[4]),
aesdec(hasher.key, tail[5]),
aesenc(hasher.key, tail[6]),
aesdec(hasher.key, tail[7]),
];
let mut current = [
convert_u128_to_vec256(current[0], current[1]),
convert_u128_to_vec256(current[2], current[3]),
convert_u128_to_vec256(current[4], current[5]),
convert_u128_to_vec256(current[6], current[7]),
];
let mut sum: [Vector256; 2] = [
convert_u128_to_vec256(hasher.key, !hasher.key),
convert_u128_to_vec256(hasher.key, !hasher.key)
];
let tail = [
convert_u128_to_vec256(tail[0], tail[1]),
convert_u128_to_vec256(tail[2], tail[3]),
convert_u128_to_vec256(tail[4], tail[5]),
convert_u128_to_vec256(tail[6], tail[7]),
];
sum[0] = add_by_64s_vec256(sum[0], tail[0]);
sum[0] = add_by_64s_vec256(sum[0], tail[1]);
sum[1] = shuffle_and_add_vec256(sum[1], tail[2]);
sum[1] = shuffle_and_add_vec256(sum[1], tail[3]);
while data.len() > 128 {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could go even further to preserve the algorithm. In existing code it is doing 4 blocks of 128 bits in parallel, Here it is doing 4 blocks each of which is 2 inner blocks of 128 which are parallelized by the vaes instruction. But could instead make it 2 blocks each of which is 2 inner blocks parallelized by vaes which would make the output identical with and without the vaes instruction. This would make it much easier to audit. It would also eliminate the need for the additional if check which will improve the speed of variable length strings.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested this before with aesenc: the AMD Zen 3 CPU, at least, performs worse with VAES than with classical AES when processing the same amount of data. The performance benefit only appears with larger batch sizes. However, that was a long time ago; maybe I should rerun it.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do see some speedup with

 if data.len() > 64 {
                    let tail = data.read_last_u128x4();
                    let mut current: [u128; 4] = [self.key; 4];
                    current[0] = aesenc(current[0], tail[0]);
                    current[1] = aesdec(current[1], tail[1]);
                    current[2] = aesenc(current[2], tail[2]);
                    current[3] = aesdec(current[3], tail[3]);
                    let mut sum: [u128; 2] = [self.key, !self.key];
                    sum[0] = add_by_64s(sum[0].convert(), tail[0].convert()).convert();
                    sum[1] = add_by_64s(sum[1].convert(), tail[1].convert()).convert();
                    sum[0] = shuffle_and_add(sum[0], tail[2]);
                    sum[1] = shuffle_and_add(sum[1], tail[3]);
                    let mut sum = convert_u128_to_vec256(sum[0], sum[1]);
                    let mut current = [
                        convert_u128_to_vec256(current[0], current[1]),
                        convert_u128_to_vec256(current[2], current[3]),
                    ];
                    while data.len() > 64 {
                        let (blocks, rest) = read2_vec256(data);
                        current[0] = aesdec_vec256(current[0], blocks[0]);
                        current[1] = aesdec_vec256(current[1], blocks[1]);
                        sum = shuffle_and_add_vec256(sum, blocks[0]);
                        sum = shuffle_and_add_vec256(sum, blocks[1]);
                        data = rest;
                    }
                    let current = [convert_vec256_to_u128(current[0]), convert_vec256_to_u128(current[1])];
                    let sum = convert_vec256_to_u128(sum);
                    self.hash_in_2(current[0][0], current[0][1]);
                    self.hash_in_2(current[1][0], current[1][1]);
                    self.hash_in_2(sum[0], sum[1]);
                } else {
                    //len 33-64
                    let (head, _) = data.read_u128x2();
                    let tail = data.read_last_u128x2();
                    self.hash_in_2(head[0], head[1]);
                    self.hash_in_2(tail[0], tail[1]);
                }

However, it is much slower (~65000MB/s).
It is worth noting that when compiled without VAES there is also a slowdown, so I suppose there may be room for improvement. Could you check it out on your side?

let (blocks, rest) = read4_vec256(data);
current[0] = aesdec_vec256(current[0], blocks[0]);
current[1] = aesdec_vec256(current[1], blocks[1]);
current[2] = aesdec_vec256(current[2], blocks[2]);
current[3] = aesdec_vec256(current[3], blocks[3]);
sum[0] = shuffle_and_add_vec256(sum[0], blocks[0]);
sum[0] = shuffle_and_add_vec256(sum[0], blocks[1]);
sum[1] = shuffle_and_add_vec256(sum[1], blocks[2]);
sum[1] = shuffle_and_add_vec256(sum[1], blocks[3]);
*data = rest;
}
let current = [
convert_vec256_to_u128(current[0]),
convert_vec256_to_u128(current[1]),
convert_vec256_to_u128(current[2]),
convert_vec256_to_u128(current[3]),
];
let sum = [convert_vec256_to_u128(sum[0]), convert_vec256_to_u128(sum[1])];
hasher.hash_in_2(
aesdec(current[0][0], current[0][1]),
aesdec(current[1][0], current[1][1]),
);
hasher.hash_in(add_by_64s(sum[0][0].convert(), sum[0][1].convert()).convert());
hasher.hash_in_2(
aesdec(current[2][0], current[2][1]),
aesdec(current[3][0], current[3][1]),
);
hasher.hash_in(add_by_64s(sum[1][0].convert(), sum[1][1].convert()).convert());
}

impl AHasher {
/// Creates a new hasher keyed to the provided keys.
///
Expand All @@ -47,6 +115,7 @@ impl AHasher {
///
/// println!("Hash is {:x}!", hasher.finish());
/// ```
#[allow(unused)]
#[inline]
pub(crate) fn new_with_keys(key1: u128, key2: u128) -> Self {
let pi: [u128; 2] = PI.convert();
Expand Down Expand Up @@ -160,6 +229,9 @@ impl Hasher for AHasher {
self.hash_in(value.convert());
} else {
if data.len() > 32 {
if data.len() > 128 {
return hash_batch_128b(&mut data, self);
}
if data.len() > 64 {
let tail = data.read_last_u128x4();
let mut current: [u128; 4] = [self.key; 4];
Expand Down
14 changes: 14 additions & 0 deletions src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ macro_rules! convert {
};
}

convert!([u128; 8], [u8; 128]);
convert!([u128; 4], [u64; 8]);
convert!([u128; 4], [u32; 16]);
convert!([u128; 4], [u16; 32]);
Expand Down Expand Up @@ -79,12 +80,14 @@ pub(crate) trait ReadFromSlice {
fn read_u128(&self) -> (u128, &[u8]);
fn read_u128x2(&self) -> ([u128; 2], &[u8]);
fn read_u128x4(&self) -> ([u128; 4], &[u8]);
fn read_u128x8(&self) -> ([u128; 8], &[u8]);
fn read_last_u16(&self) -> u16;
fn read_last_u32(&self) -> u32;
fn read_last_u64(&self) -> u64;
fn read_last_u128(&self) -> u128;
fn read_last_u128x2(&self) -> [u128; 2];
fn read_last_u128x4(&self) -> [u128; 4];
fn read_last_u128x8(&self) -> [u128; 8];
}

impl ReadFromSlice for [u8] {
Expand Down Expand Up @@ -124,6 +127,12 @@ impl ReadFromSlice for [u8] {
(as_array!(value, 64).convert(), rest)
}

/// Reads the first 128 bytes of the slice as eight `u128` values
/// (bitwise reinterpretation via the file's `convert` helper) and
/// returns them together with the remaining bytes.
/// Panics if the slice is shorter than 128 bytes (`split_at(128)`).
#[inline(always)]
fn read_u128x8(&self) -> ([u128; 8], &[u8]) {
let (value, rest) = self.split_at(128);
(as_array!(value, 128).convert(), rest)
}

#[inline(always)]
fn read_last_u16(&self) -> u16 {
let (_, value) = self.split_at(self.len() - 2);
Expand Down Expand Up @@ -159,4 +168,9 @@ impl ReadFromSlice for [u8] {
let (_, value) = self.split_at(self.len() - 64);
as_array!(value, 64).convert()
}

/// Reads the last 128 bytes of the slice as eight `u128` values
/// (bitwise reinterpretation via the file's `convert` helper).
/// Panics if the slice is shorter than 128 bytes.
// `#[inline(always)]` added for consistency: every other reader in this
// impl (read_u128x8, read_last_u16, read_last_u128x4, ...) carries it,
// and these readers sit on the hot path of `Hasher::write`.
#[inline(always)]
fn read_last_u128x8(&self) -> [u128; 8] {
    let (_, value) = self.split_at(self.len() - 128);
    as_array!(value, 128).convert()
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ Note the import of [HashMapExt]. This is needed for the constructor.
#![cfg_attr(all(not(test), not(feature = "std")), no_std)]
#![cfg_attr(feature = "specialize", feature(min_specialization))]
#![cfg_attr(feature = "nightly-arm-aes", feature(stdarch_arm_neon_intrinsics))]
#![cfg_attr(feature = "vaes", feature(stdsimd))]

#[macro_use]
mod convert;
Expand Down
180 changes: 180 additions & 0 deletions src/operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,186 @@ pub(crate) fn add_in_length(enc: &mut u128, len: u64) {
}
}

#[cfg(any(
    all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "aes", not(miri)),
    all(target_arch = "aarch64", target_feature = "aes", not(miri)),
    all(feature = "nightly-arm-aes", target_arch = "arm", target_feature = "aes", not(miri)),
))]
mod vaes {
    //! 256-bit-wide helpers for the large-input hashing loop.
    //!
    //! When the `vaes` cargo feature is enabled on an x86_64 target that also
    //! enables the `vaes` target feature, `Vector256` is a hardware `__m256i`
    //! and these helpers compile to AVX/VAES intrinsics. On every other
    //! AES-capable target they fall back to two independent 128-bit lanes
    //! (`[u128; 2]`) driven by the scalar helpers in the parent module.
    //!
    //! NOTE(review): the intrinsic path was previously enabled for 32-bit
    //! `target_arch = "x86"` as well, but every intrinsic (and the `Vector256`
    //! alias itself) comes from `core::arch::x86_64`, which does not exist on
    //! 32-bit x86 — that combination could not compile. The intrinsic path is
    //! now x86_64-only; 32-bit x86 uses the portable fallback.
    use super::*;

    cfg_if::cfg_if! {
        if #[cfg(all(
            target_arch = "x86_64",
            target_feature = "vaes",
            feature = "vaes",
            not(miri)
        ))] {
            /// Hardware 256-bit vector: two 128-bit AES lanes.
            pub type Vector256 = core::arch::x86_64::__m256i;
        }
        else {
            /// Portable stand-in: two independent 128-bit lanes.
            pub type Vector256 = [u128; 2];
        }
    }

    /// Performs one AES "decrypt" round on each 128-bit lane of `value`,
    /// mixing in the corresponding lane of `xor` as the round key.
    #[inline(always)]
    pub(crate) fn aesdec_vec256(value: Vector256, xor: Vector256) -> Vector256 {
        cfg_if::cfg_if! {
            if #[cfg(all(
                target_arch = "x86_64",
                target_feature = "vaes",
                feature = "vaes",
                not(miri)
            ))] {
                use core::arch::x86_64::*;
                // SAFETY: guarded by cfg(target_feature = "vaes") above.
                unsafe {
                    _mm256_aesdec_epi128(value, xor)
                }
            }
            else {
                [
                    aesdec(value[0], xor[0]),
                    aesdec(value[1], xor[1]),
                ]
            }
        }
    }

    /// Lane-wise wrapping addition of the four 64-bit words in each vector.
    #[inline(always)]
    pub(crate) fn add_by_64s_vec256(a: Vector256, b: Vector256) -> Vector256 {
        cfg_if::cfg_if! {
            if #[cfg(all(
                target_arch = "x86_64",
                target_feature = "vaes",
                feature = "vaes",
                not(miri)
            ))] {
                use core::arch::x86_64::*;
                // SAFETY: guarded by cfg(target_feature = "vaes") above.
                unsafe { _mm256_add_epi64(a, b) }
            }
            else {
                [
                    transmute!(add_by_64s(transmute!(a[0]), transmute!(b[0]))),
                    transmute!(add_by_64s(transmute!(a[1]), transmute!(b[1]))),
                ]
            }
        }
    }

    /// Applies the byte permutation defined by `SHUFFLE_MASK` to each
    /// 128-bit lane of `value` independently.
    #[inline(always)]
    pub(crate) fn shuffle_vec256(value: Vector256) -> Vector256 {
        cfg_if::cfg_if! {
            if #[cfg(all(
                target_arch = "x86_64",
                target_feature = "vaes",
                feature = "vaes",
                not(miri)
            ))] {
                // SAFETY: guarded by cfg(target_feature = "vaes") above;
                // _mm256_shuffle_epi8 operates purely on register values.
                unsafe {
                    use core::arch::x86_64::*;
                    let mask = convert_u128_to_vec256(SHUFFLE_MASK, SHUFFLE_MASK);
                    _mm256_shuffle_epi8(value, mask)
                }
            }
            else {
                [
                    shuffle(value[0]),
                    shuffle(value[1]),
                ]
            }
        }
    }

    /// Shuffles `base` and then adds `to_add` into it 64 bits at a time —
    /// the 256-bit analogue of the parent module's `shuffle_and_add`.
    // `#[inline(always)]` added for consistency with the other helpers here.
    #[inline(always)]
    pub(crate) fn shuffle_and_add_vec256(base: Vector256, to_add: Vector256) -> Vector256 {
        add_by_64s_vec256(shuffle_vec256(base), to_add)
    }

    /// Reads four 256-bit vectors (128 bytes) from the front of `data` and
    /// returns them together with the remaining bytes.
    ///
    /// Panics if `data.len() < 128`.
    ///
    /// We specialize this routine because sometimes the compiler is not able
    /// to optimize it properly.
    #[inline(always)]
    pub(crate) fn read4_vec256(data: &[u8]) -> ([Vector256; 4], &[u8]) {
        cfg_if::cfg_if! {
            if #[cfg(all(
                target_arch = "x86_64",
                target_feature = "vaes",
                feature = "vaes",
                not(miri)
            ))] {
                use core::arch::x86_64::*;
                let (arr, rem) = data.split_at(128);
                // SAFETY: `arr` is exactly 128 bytes long, so all four
                // unaligned 32-byte loads (offsets 0, 32, 64, 96) are in
                // bounds; _mm256_loadu_si256 has no alignment requirement.
                let arr = unsafe {
                    [ _mm256_loadu_si256(arr.as_ptr().cast::<__m256i>()),
                      _mm256_loadu_si256(arr.as_ptr().add(32).cast::<__m256i>()),
                      _mm256_loadu_si256(arr.as_ptr().add(64).cast::<__m256i>()),
                      _mm256_loadu_si256(arr.as_ptr().add(96).cast::<__m256i>()),
                    ]
                };
                (arr, rem)
            }
            else {
                let (arr, slice) = data.read_u128x8();
                (transmute!(arr), slice)
            }
        }
    }

    /// Packs two `u128` values into one 256-bit vector: `low` in the low
    /// 128-bit lane, `high` in the high lane.
    ///
    /// We specialize this routine because sometimes the compiler is not able
    /// to optimize it properly.
    #[inline(always)]
    pub(crate) fn convert_u128_to_vec256(low: u128, high: u128) -> Vector256 {
        cfg_if::cfg_if! {
            if #[cfg(all(
                target_arch = "x86_64",
                target_feature = "vaes",
                feature = "vaes",
                not(miri)
            ))] {
                use core::arch::x86_64::*;
                // SAFETY: _mm256_set_epi64x operates purely on register
                // values; it is unsafe only because it is an intrinsic.
                unsafe {
                    _mm256_set_epi64x(
                        (high >> 64) as i64,
                        high as i64,
                        (low >> 64) as i64,
                        low as i64,
                    )
                }
            }
            else {
                transmute!([low, high])
            }
        }
    }

    /// Splits a 256-bit vector back into its two 128-bit lanes as
    /// `[low, high]` — the inverse of `convert_u128_to_vec256`.
    ///
    /// We specialize this routine because sometimes the compiler is not able
    /// to optimize it properly.
    #[inline(always)]
    pub(crate) fn convert_vec256_to_u128(x: Vector256) -> [u128; 2] {
        cfg_if::cfg_if! {
            if #[cfg(all(
                target_arch = "x86_64",
                target_feature = "vaes",
                feature = "vaes",
                not(miri)
            ))] {
                use core::arch::x86_64::*;
                // SAFETY: lane extraction operates purely on register values.
                // The lane index is a const generic on current rustc; the
                // former two-argument call form `(x, 0)` no longer compiles.
                unsafe {
                    [
                        transmute!(_mm256_extracti128_si256::<0>(x)),
                        transmute!(_mm256_extracti128_si256::<1>(x)),
                    ]
                }
            }
            else {
                x
            }
        }
    }
}

#[cfg(any(
all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "aes", not(miri)),
all(target_arch = "aarch64", target_feature = "aes", not(miri)),
all(feature = "nightly-arm-aes", target_arch = "arm", target_feature = "aes", not(miri)),
))]
pub(crate) use vaes::*;

#[cfg(test)]
mod test {
use super::*;
Expand Down