Skip to content

Commit

Permalink
Finish the urlsafe coding engine
Browse files Browse the repository at this point in the history
  • Loading branch information
dequbed committed Sep 16, 2021
1 parent 1046070 commit 7039695
Showing 1 changed file with 135 additions and 94 deletions.
229 changes: 135 additions & 94 deletions src/engine/avx2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ pub struct AVX2Encoder {
// For STANDARD these are '+' and '/' and the engine matches against '/' i.e. 0x2F
// For URL_SAFE these are '-' and '_' and the engine matches against '_' i.e. 0x5F
singleton_mask: __m256i,
hi_witnesses: __m256i,
lo_witnesses: __m256i,
}

impl AVX2Encoder {
Expand All @@ -55,15 +57,64 @@ impl AVX2Encoder {
)
};

// These decode offsets are accessed by the high nibble of the ASCII character being
// decoded so for example 'A' (0x41) is offset -65 since it encodes 0b000000.
// The one exception to that is the value '/' (0x2F) which has to be handled specifically.
let decode_offsets = unsafe {
_mm256_setr_epi8(
// 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15
0, 16, 19, 4, -65, -65, -71, -71, 0, 0, 0, 0, 0, 0, 0, 0,
0, 16, 19, 4, -65, -65, -71, -71, 0, 0, 0, 0, 0, 0, 0, 0
// 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15
0, 0, 19, 4, -65, -65, -71, -71, 16, 0, 0, 0, 0, 0, 0, 0,
0, 0, 19, 4, -65, -65, -71, -71, 16, 0, 0, 0, 0, 0, 0, 0
)
};

let singleton_mask = unsafe { _mm256_set1_epi8(0x2F) };
// Witnesses for the high nibbles:
// 0x0 and 0x1 are never valid, no matter what the low nibble is.
// 0x2 is valid for the characters '+' (0x2B), '/' (0x2F) and '-' (0x2D), depending on the
// alphabet.
// 0x3 contains numerals but the only valid inputs are 0x30 to 0x39, so we need to make
// sure that everything from 0xA to 0xF is rejected.
// 0x4 and 0x5 contain [A-Z] and also the special character '_' (0x5F) from the urlsafe
// alphabet.
// 0x6 and 0x7 contain [a-z].
// 0x8 through 0xF are never valid; a high nibble of 0x8 or more means the byte is not ASCII.
//
// We use -0x1 as "always invalid" value so that the low witness has to only return
// something != 0 for the invalid test to trip.
let hi_witnesses = unsafe {
_mm256_setr_epi8(
// 0 1 2 3 4 5 6 7
-0x1, -0x1, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08,
// 8 9 10 11 12 13 14 15
-0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1,
// 0 1 2 3 4 5 6 7
-0x1, -0x1, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08,
// 8 9 10 11 12 13 14 15
-0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1
)
};
// Witnesses for the low nibbles.
// ASCII has the advantage that A-Z and a-z are 0x20 away from each other so you can use
// the same lo witnesses for both of those ranges.
// The easiest way to create these witness tables and what is done here is to use the hi
// witness to select a bit to probe and set the bit in the low witness for invalid nibbles
// in that range. E.g. the hi witness sets bit 1 for high nibble 0x2, bit 2 for 0x3, bit 3
// for 0x4 and 0x6, and bit 4 for 0x5 and 0x7. The lo witness then sets bit 1 for everything
// invalid in the special bytes range (i.e. everything but 0x2F, 0x2B etc.), bit 2 for
// 0xA..0xF (since those are invalid in the numeric range), bit 3 for 0x0 (rejecting '@'
// and '`') and bit 4 for 0xB..0xF.
let lo_witnesses = unsafe {
_mm256_setr_epi8(
// 0 1 2 3 4 5 6 7
0x75, 0x71, 0x71, 0x71, 0x71, 0x71, 0x71, 0x71,
// 8 9 10 11 12 13 14 15
0x71, 0x71, 0x73, 0x7A, 0x7B, 0x7B, 0x7B, 0x7A,
// 0 1 2 3 4 5 6 7
0x75, 0x71, 0x71, 0x71, 0x71, 0x71, 0x71, 0x71,
// 8 9 10 11 12 13 14 15
0x71, 0x71, 0x73, 0x7A, 0x7B, 0x7B, 0x7B, 0x7A,
)
};

Self {
config,
Expand All @@ -74,6 +125,8 @@ impl AVX2Encoder {
encode_offsets,
decode_offsets,
singleton_mask,
hi_witnesses,
lo_witnesses,
}
}
/// Create an AVX2Encoder for the urlsafe alphabet with the given config.
Expand All @@ -89,13 +142,39 @@ impl AVX2Encoder {

let decode_offsets = unsafe {
_mm256_setr_epi8(
// 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15
0, -32, 17, 4, -65, -65, -71, -71, 0, 0, 0, 0, 0, 0, 0, 0,
0, -32, 17, 4, -65, -65, -71, -71, 0, 0, 0, 0, 0, 0, 0, 0
// 00 01 02 03 04 05 06 07 08 09 10 11 12 13 14 15
0, 0, 17, 4, -65, -65, -71, -71, 0, 0, 0,-32, 0, 0, 0, 0,
0, 0, 17, 4, -65, -65, -71, -71, 0, 0, 0,-32, 0, 0, 0, 0
)
};

let singleton_mask = unsafe { _mm256_set1_epi8(0x2B) };
let singleton_mask = unsafe { _mm256_set1_epi8(0x5F) };
let hi_witnesses = unsafe {
_mm256_setr_epi8(
// 0 1 2 3 4 5 6 7
-0x1, -0x1, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08,
// 8 9 10 11 12 13 14 15
-0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1,
// 0 1 2 3 4 5 6 7
-0x1, -0x1, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08,
// 8 9 10 11 12 13 14 15
-0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1
)
};
// Lo witnesses for url-safe are slightly different from those for standard:
// Inputs 0x5F ('_') and 0x2D ('-') are valid, inputs 0x2F ('/') and 0x2B ('+') are not.
let lo_witnesses = unsafe {
_mm256_setr_epi8(
// 0 1 2 3 4 5 6 7
0x75, 0x71, 0x71, 0x71, 0x71, 0x71, 0x71, 0x71,
// 8 9 A B C D E F
0x71, 0x71, 0x73, 0x7B, 0x7B, 0x7A, 0x7B, 0x73,
// 0 1 2 3 4 5 6 7
0x75, 0x71, 0x71, 0x71, 0x71, 0x71, 0x71, 0x71,
// 8 9 A B C D E F
0x71, 0x71, 0x73, 0x7B, 0x7B, 0x7A, 0x7B, 0x73,
)
};

Self {
config,
Expand All @@ -106,6 +185,8 @@ impl AVX2Encoder {
encode_offsets,
decode_offsets,
singleton_mask,
hi_witnesses,
lo_witnesses,
}
}
}
Expand Down Expand Up @@ -191,7 +272,7 @@ unsafe fn decode(
// The main optimization this algorithm makes and the source for its assumptions is that it
// relies on the fact that the alphabet used has contiguous ordered ranges of inputs that thus
// share an offset, and that these ranges are distinguishable by their upper nibble.
// In other words because for [A-Z] subtracting 65 gets you to the correct value and for [a-z]
// In other words for [A-Z] subtracting 65 gets you to the correct value and for [a-z]
// subtracting 71 does as well. While decoding we just have to figure out which range an input
// belongs to and directly know what offset to apply.
// However, we need to check for invalid inputs. The algorithm again optimizes that by using
Expand All @@ -214,10 +295,21 @@ unsafe fn decode(
return _mm256_and_si256(witness_hi, witness_lo);
}

// Next we check for one of the singleton bytes. In neither the standard nor the url-safe
// alphabet do the two special characters share an offset to their encoded value, and they
// can't be distinguished from the other ranges by their high nibble alone ('_' has high
// nibble 5 like P-Z, '/' and '+' both have 2), so we need to explicitly match against one of them.
let eq_singleton = _mm256_cmpeq_epi8(block, mask_singleton);
let roll = _mm256_shuffle_epi8(offsets, _mm256_add_epi8(eq_singleton, hi_nibbles));
let shuffeled = _mm256_add_epi8(block, roll);

// In the last decoding step we do two things: Add 0x6 to all hi nibbles where we found our
// singleton. This makes input 0x2F check for offset in offsets[8] and 0x5F in offsets[11].
// Then, get the actual offset amount from `offsets` and add it to our input.
let offsetidxs = _mm256_add_epi8(hi_nibbles, _mm256_and_si256(eq_singleton, _mm256_set1_epi8(0x6)));
let offsetvals = _mm256_shuffle_epi8(offsets, offsetidxs);
let shuffeled = _mm256_add_epi8(block, offsetvals);

// This merges the 16, 6 bit wide but byte aligned, values in each half into a packed 12 byte
// block of data each.
let merge_ab_and_bc = _mm256_maddubs_epi16(shuffeled,
_mm256_set1_epi32(0x01400140));
let madd = _mm256_madd_epi16(merge_ab_and_bc, _mm256_set1_epi32(0x00011000));
Expand All @@ -243,27 +335,29 @@ unsafe fn decode_masked(
invalid: &mut bool,
lo_witness_lut: __m256i,
hi_witness_lut: __m256i,
lut_roll: __m256i,
offsets: __m256i,
mask_singleton: __m256i,
mask_input: __m256i,
block: __m256i
) -> __m256i {
let hi_nibbles = _mm256_srli_epi32(block, 4);
let lo_nibbles = _mm256_and_si256(block, mask_singleton);
let eq_singleton = _mm256_cmpeq_epi8(block, mask_singleton);
let hi_nibbles = _mm256_and_si256(hi_nibbles, mask_singleton);

let lo = _mm256_shuffle_epi8(lo_witness_lut, lo_nibbles);
let hi = _mm256_shuffle_epi8(hi_witness_lut, hi_nibbles);
// Special case: If we have a masked input we need to forward this mask here to not
// trip the test below
let hi = _mm256_and_si256(hi, mask_input);
if _mm256_testz_si256(lo, hi) == 0 {
let mask_nib = _mm256_set1_epi8(0b00001111);
let block_shifted = _mm256_srli_epi32(block, 4);
let hi_nibbles = _mm256_and_si256(block_shifted, mask_nib);
let lo_nibbles = _mm256_and_si256(block, mask_nib);

let witness_lo = _mm256_shuffle_epi8(lo_witness_lut, lo_nibbles);
let witness_hi = _mm256_shuffle_epi8(hi_witness_lut, hi_nibbles);

let witness_hi = _mm256_and_si256(witness_hi, mask_input);
if _mm256_testz_si256(witness_lo, witness_hi) == 0 {
*invalid = true;
return _mm256_and_si256(witness_hi, witness_lo);
}

let roll = _mm256_shuffle_epi8(lut_roll, _mm256_add_epi8(eq_singleton, hi_nibbles));
let shuffeled = _mm256_add_epi8(block, roll);
let eq_singleton = _mm256_cmpeq_epi8(block, mask_singleton);
let offsetidxs = _mm256_add_epi8(hi_nibbles, _mm256_and_si256(eq_singleton, _mm256_set1_epi8(0x6)));
let offsetvals = _mm256_shuffle_epi8(offsets, offsetidxs);
let shuffeled = _mm256_add_epi8(block, offsetvals);

let merge_ab_and_bc = _mm256_maddubs_epi16(shuffeled,
_mm256_set1_epi32(0x01400140));
Expand Down Expand Up @@ -513,70 +607,6 @@ impl super::Engine for AVX2Encoder {
let mut block: __m256i;
let mut invalid: bool = false;

// Witnesses for the high nibbles:
// 0x0 and 0x1 are never valid, no matter what the low nibble is.
// 0x2 is valid for the characters '+' (0x2B), '/' (0x2F) and '-' (0x2D), depending on the
// alphabet.
// 0x3 contains numerals but the only valid inputs are 0x30 to 0x39, so we need to make
// sure that everything from 0xA to 0xF is rejected.
// 0x4 and 0x5 contain [A-Z] and also the special character '_' (0x5F) from the urlsafe
// alphabet.
// 0x6 and 0x7 contain [a-z].
// 0x8 through 0xF are never valid; a high nibble of 0x8 or more means the byte is not ASCII.
//
// We use -0x1 as "always invalid" value so that the low witness has to only return
// something != 0 for the invalid test to trip.
let hi_witness_lut = unsafe { _mm256_setr_epi8(
// 0 1 2 3 4 5 6 7
-0x1, -0x1, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08,
// 8 9 10 11 12 13 14 15
-0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1,
// 0 1 2 3 4 5 6 7
-0x1, -0x1, 0x01, 0x02, 0x04, 0x08, 0x04, 0x08,
// 8 9 10 11 12 13 14 15
-0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1, -0x1
)};
// Witnesses for the low nibbles. The requirements for the given hi witnesses are then:
// // Be invalid if hi is.
// - lo[..] & -0x1 == 1
// // Numerals
// - lo[0..9] & 0x2 == 0
// - lo[10..15] & 0x2 == 1
// // Capitals
// - lo[0] & 0x4 == 1
// - lo[1..] & 0x4 == 0
// - lo[0..10] & 0x8 == 0
// - lo[11..15] & 0x8 == 1
// // Miniscules
// - lo[0] & 0x4 == 1
// - lo[1..] & 0x4 == 1
// - lo[..10] & 0x8 == 1
// - lo[11..15] & 0x8 == 1
// // Special, depending on the alphabet
// // standard
// - lo[15] & 0x1 == 1
// - lo[11] & 0x1 == 1
// // urlsafe
// - lo[13] & 0x1 == 1
// - lo[15] & 0x8
// ASCII has the advantage that A-Z and a-z are 0x20 away from each other so you can use
// the same lo witnesses.
// The easiest way to create these witness tables and what is done here is to use the hi
// witness to select a bit to probe and set the bit in the low witness for valid nibbles in
// that range. E.g. the hi witness sets bit 1 for high nibble 0x2 and bit 3 for 0x4 and
// 0x6, and the lo witness only sets bit 1 for valid inputs with high nibble 0x2 (like
// 0x2F, 0x2B etc.) and bit 3 for valid letters [A-Za-z].
let lo_witness_lut = unsafe { _mm256_setr_epi8(
// 0 1 2 3 4 5 6 7
0x15, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
// 8 9 10 11 12 13 14 15
0x11, 0x11, 0x13, 0x1A, 0x1B, 0x1B, 0x1B, 0x1A,
// 0 1 2 3 4 5 6 7
0x15, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
// 8 9 10 11 12 13 14 15
0x11, 0x11, 0x13, 0x1A, 0x1B, 0x1B, 0x1B, 0x1A
)};

// This will only evaluate to true if we have an input of 33 bytes or more;
// skip_final_bytes is at least input.len() otherwise.
if last_fast_index > 0 {
Expand All @@ -599,8 +629,8 @@ impl super::Engine for AVX2Encoder {
unsafe {
block = _mm256_loadu_si256(input_chunk.as_ptr().cast());
block = decode(&mut invalid,
lo_witness_lut,
hi_witness_lut,
self.lo_witnesses,
self.hi_witnesses,
self.decode_offsets,
self.singleton_mask,
block);
Expand Down Expand Up @@ -640,7 +670,12 @@ impl super::Engine for AVX2Encoder {
let mask_output = _mm256_loadu_si256(MASKLOAD[2..10].as_ptr().cast());

block = _mm256_loadu_si256(input_chunk.as_ptr().cast());
block = decode(&mut invalid, lo_witness_lut, hi_witness_lut, self.decode_offsets, self.singleton_mask, block);
block = decode(&mut invalid,
self.lo_witnesses,
self.hi_witnesses,
self.decode_offsets,
self.singleton_mask,
block);

_mm256_maskstore_epi32(output_chunk.as_mut_ptr().cast(), mask_output, block);
}
Expand Down Expand Up @@ -691,9 +726,15 @@ impl super::Engine for AVX2Encoder {
let mask_output = _mm256_loadu_si256(output_mask.as_ptr().cast());

block = _mm256_maskload_epi32(input_chunk.as_ptr().cast(), mask_input);
let outblock
= decode_masked(&mut invalid,
lo_witness_lut, hi_witness_lut, self.decode_offsets, self.singleton_mask, mask_input, block);
let outblock = decode_masked(
&mut invalid,
self.lo_witnesses,
self.hi_witnesses,
self.decode_offsets,
self.singleton_mask,
mask_input,
block
);

if invalid {
return Err(find_invalid_input(input_index, input_chunk, &self.decode_table));
Expand Down

0 comments on commit 7039695

Please sign in to comment.