Skip to content

Commit

Permalink
arm aes: add neon implementation using the crypto extension
Browse files Browse the repository at this point in the history
  • Loading branch information
mr-c committed Oct 16, 2023
1 parent 7198d6d commit fb3554f
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 45 deletions.
3 changes: 3 additions & 0 deletions simde/simde-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@
#if defined(__ARM_FEATURE_FMA) && __ARM_FEATURE_FMA
# define SIMDE_ARCH_ARM_FMA
#endif
#if defined(__ARM_FEATURE_CRYPTO)
# define SIMDE_ARCH_ARM_CRYPTO
#endif

/* Blackfin
<https://en.wikipedia.org/wiki/Blackfin> */
Expand Down
122 changes: 77 additions & 45 deletions simde/x86/aes.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,16 +295,16 @@ static uint8_t simde_x_aes_inv_s_box[256] = {
* Key length equals 128 bits/16 bytes).
*/
SIMDE_FUNCTION_ATTRIBUTES
void simde_x_aes_add_round_key(uint8_t *state, uint8_t *w, uint8_t r) {
void simde_x_aes_add_round_key(uint8_t *state, simde__m128i_private w, uint8_t r) {

int Nb = simde_x_aes_Nb;
uint8_t c;

for (c = 0; c < Nb; c++) {
state[Nb*0+c] = state[Nb*0+c]^w[4*Nb*r+4*c+0];
state[Nb*1+c] = state[Nb*1+c]^w[4*Nb*r+4*c+1];
state[Nb*2+c] = state[Nb*2+c]^w[4*Nb*r+4*c+2];
state[Nb*3+c] = state[Nb*3+c]^w[4*Nb*r+4*c+3];
state[Nb*0+c] = state[Nb*0+c]^w.u8[4*Nb*r+4*c+0];
state[Nb*1+c] = state[Nb*1+c]^w.u8[4*Nb*r+4*c+1];
state[Nb*2+c] = state[Nb*2+c]^w.u8[4*Nb*r+4*c+2];
state[Nb*3+c] = state[Nb*3+c]^w.u8[4*Nb*r+4*c+3];
}
}

Expand Down Expand Up @@ -453,15 +453,15 @@ void simde_x_aes_inv_sub_bytes(uint8_t *state) {
* Performs the AES cipher operation
*/
SIMDE_FUNCTION_ATTRIBUTES
void simde_x_aes_enc(uint8_t *in, uint8_t *out, uint8_t *w, int is_last) {
void simde_x_aes_enc(simde__m128i_private in, simde__m128i_private *out, simde__m128i_private w, int is_last) {

int Nb = simde_x_aes_Nb;
uint8_t state[4*simde_x_aes_Nb];
uint8_t r = 0, i, j;

for (i = 0; i < 4; i++) {
for (j = 0; j < Nb; j++) {
state[Nb*i+j] = in[i+4*j];
state[Nb*i+j] = in.u8[i+4*j];
}
}

Expand All @@ -475,7 +475,7 @@ void simde_x_aes_enc(uint8_t *in, uint8_t *out, uint8_t *w, int is_last) {

for (i = 0; i < 4; i++) {
for (j = 0; j < Nb; j++) {
out[i+4*j] = state[Nb*i+j];
out->u8[i+4*j] = state[Nb*i+j];
}
}
}
Expand All @@ -484,15 +484,15 @@ void simde_x_aes_enc(uint8_t *in, uint8_t *out, uint8_t *w, int is_last) {
* Performs the AES inverse cipher operation
*/
SIMDE_FUNCTION_ATTRIBUTES
void simde_x_aes_dec(uint8_t *in, uint8_t *out, uint8_t *w, int is_last) {
void simde_x_aes_dec(simde__m128i_private in, simde__m128i_private *out, simde__m128i_private w, int is_last) {

int Nb = simde_x_aes_Nb;
uint8_t state[4*simde_x_aes_Nb];
uint8_t r = 0, i, j;

for (i = 0; i < 4; i++) {
for (j = 0; j < Nb; j++) {
state[Nb*i+j] = in[i+4*j];
state[Nb*i+j] = in.u8[i+4*j];
}
}

Expand All @@ -506,7 +506,7 @@ void simde_x_aes_dec(uint8_t *in, uint8_t *out, uint8_t *w, int is_last) {

for (i = 0; i < 4; i++) {
for (j = 0; j < Nb; j++) {
out[i+4*j] = state[Nb*i+j];
out->u8[i+4*j] = state[Nb*i+j];
}
}
}
Expand All @@ -517,9 +517,17 @@ simde__m128i simde_mm_aesenc_si128(simde__m128i a, simde__m128i round_key) {
#if defined(SIMDE_X86_AES_NATIVE)
return _mm_aesenc_si128(a, round_key);
#else
simde__m128i result;
simde_x_aes_enc(HEDLEY_REINTERPRET_CAST(uint8_t *, &a), HEDLEY_REINTERPRET_CAST(uint8_t *, &result), HEDLEY_REINTERPRET_CAST(uint8_t *, &round_key), 0);
return result;
simde__m128i_private result_;
simde__m128i_private a_ = simde__m128i_to_private(a);
simde__m128i_private round_key_ = simde__m128i_to_private(round_key);
#if defined(SIMDE_ARM_NEON_A32V7_NATIVE) && defined(SIMDE_ARCH_ARM_CRYPTO)
result_.neon_u8 = veorq_u8(
vaesmcq_u8(vaeseq_u8(a_.neon_u8, vdupq_n_u8(0))),
round_key_.neon_u8);
#else
simde_x_aes_enc(a_, &result_, round_key_, 0);
#endif
return simde__m128i_from_private(result_);
#endif
}
#if defined(SIMDE_X86_AES_ENABLE_NATIVE_ALIASES)
Expand All @@ -531,9 +539,17 @@ simde__m128i simde_mm_aesdec_si128(simde__m128i a, simde__m128i round_key) {
#if defined(SIMDE_X86_AES_NATIVE)
return _mm_aesdec_si128(a, round_key);
#else
simde__m128i result;
simde_x_aes_dec(HEDLEY_REINTERPRET_CAST(uint8_t *, &a), HEDLEY_REINTERPRET_CAST(uint8_t *, &result), HEDLEY_REINTERPRET_CAST(uint8_t *, &round_key), 0);
return result;
simde__m128i_private result_;
simde__m128i_private a_ = simde__m128i_to_private(a);
simde__m128i_private round_key_ = simde__m128i_to_private(round_key);
#if defined(SIMDE_ARM_NEON_A32V7_NATIVE) && defined(SIMDE_ARCH_ARM_CRYPTO)
result_.neon_u8 = veorq_u8(
vaesimcq_u8(vaesdq_u8(a_.neon_u8, vdupq_n_u8(0))),
round_key_.neon_u8);
#else
simde_x_aes_dec(a_, &result_, round_key_, 0);
#endif
return simde__m128i_from_private(result_);
#endif
}
#if defined(SIMDE_X86_AES_ENABLE_NATIVE_ALIASES)
Expand All @@ -545,9 +561,16 @@ simde__m128i simde_mm_aesenclast_si128(simde__m128i a, simde__m128i round_key) {
#if defined(SIMDE_X86_AES_NATIVE)
return _mm_aesenclast_si128(a, round_key);
#else
simde__m128i result;
simde_x_aes_enc(HEDLEY_REINTERPRET_CAST(uint8_t *, &a), HEDLEY_REINTERPRET_CAST(uint8_t *, &result), HEDLEY_REINTERPRET_CAST(uint8_t *, &round_key), 1);
return result;
simde__m128i_private result_;
simde__m128i_private a_ = simde__m128i_to_private(a);
simde__m128i_private round_key_ = simde__m128i_to_private(round_key);
#if defined(SIMDE_ARM_NEON_A32V7_NATIVE) && defined(SIMDE_ARCH_ARM_CRYPTO)
result_.neon_u8 = vaeseq_u8(a_.neon_u8, vdupq_n_u8(0));
result_.neon_i32 = veorq_s32(result_.neon_i32, round_key_.neon_i32); // _mm_xor_si128
#else
simde_x_aes_enc(a_, &result_, round_key_, 1);
#endif
return simde__m128i_from_private(result_);
#endif
}
#if defined(SIMDE_X86_AES_ENABLE_NATIVE_ALIASES)
Expand All @@ -559,9 +582,17 @@ simde__m128i simde_mm_aesdeclast_si128(simde__m128i a, simde__m128i round_key) {
#if defined(SIMDE_X86_AES_NATIVE)
return _mm_aesdeclast_si128(a, round_key);
#else
simde__m128i result;
simde_x_aes_dec(HEDLEY_REINTERPRET_CAST(uint8_t *, &a), HEDLEY_REINTERPRET_CAST(uint8_t *, &result), HEDLEY_REINTERPRET_CAST(uint8_t *, &round_key), 1);
return result;
simde__m128i_private result_;
simde__m128i_private a_ = simde__m128i_to_private(a);
simde__m128i_private round_key_ = simde__m128i_to_private(round_key);
#if defined(SIMDE_ARM_NEON_A32V7_NATIVE) && defined(SIMDE_ARCH_ARM_CRYPTO)
result_.neon_u8 = veorq_u8(
vaesdq_u8(a_.neon_u8, vdupq_n_u8(0)),
round_key_.neon_u8);
#else
simde_x_aes_dec(a_, &result_, round_key_, 1);
#endif
return simde__m128i_from_private(result_);
#endif
}
#if defined(SIMDE_X86_AES_ENABLE_NATIVE_ALIASES)
Expand All @@ -573,29 +604,30 @@ simde__m128i simde_mm_aesimc_si128(simde__m128i a) {
#if defined(SIMDE_X86_AES_NATIVE)
return _mm_aesimc_si128(a);
#else
simde__m128i result;

uint8_t *in = HEDLEY_REINTERPRET_CAST(uint8_t *, &a);
uint8_t *out = HEDLEY_REINTERPRET_CAST(uint8_t *, &result);

int Nb = simde_x_aes_Nb;
// uint8_t k[] = {0x0e, 0x09, 0x0d, 0x0b}; // a(x) = {0e} + {09}x + {0d}x2 + {0b}x3
uint8_t i, j, col[4], res[4];

for (j = 0; j < Nb; j++) {
for (i = 0; i < 4; i++) {
col[i] = in[Nb*j+i];
simde__m128i_private result_ = simde__m128i_to_private(simde_mm_setzero_si128());
simde__m128i_private a_ = simde__m128i_to_private(a);

#if defined(SIMDE_ARM_NEON_A32V7_NATIVE) && defined(SIMDE_ARCH_ARM_CRYPTO)
result_.neon_u8 = vaesimcq_u8(a_.neon_u8);
#else
int Nb = simde_x_aes_Nb;
// uint8_t k[] = {0x0e, 0x09, 0x0d, 0x0b}; // a(x) = {0e} + {09}x + {0d}x2 + {0b}x3
uint8_t i, j, col[4], res[4];

for (j = 0; j < Nb; j++) {
for (i = 0; i < 4; i++) {
col[i] = a_.u8[Nb*j+i];
}

//coef_mult(k, col, res);
simde_x_aes_coef_mult_lookup(4, col, res);

for (i = 0; i < 4; i++) {
result_.u8[Nb*j+i] = res[i];
}
}

//coef_mult(k, col, res);
simde_x_aes_coef_mult_lookup(4, col, res);

for (i = 0; i < 4; i++) {
out[Nb*j+i] = res[i];
}
}

return result;
#endif
return simde__m128i_from_private(result_);
#endif
}
#if defined(SIMDE_X86_AES_ENABLE_NATIVE_ALIASES)
Expand Down

0 comments on commit fb3554f

Please sign in to comment.