diff --git a/Cargo.lock b/Cargo.lock index 24fa1e55..84cdfad3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -352,7 +352,7 @@ dependencies = [ [[package]] name = "mgm" -version = "0.4.3" +version = "0.4.4" dependencies = [ "aead", "cipher", diff --git a/mgm/CHANGELOG.md b/mgm/CHANGELOG.md index 27241509..a928bd59 100644 --- a/mgm/CHANGELOG.md +++ b/mgm/CHANGELOG.md @@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## 0.4.4 (2021-08-24) +### Changed +- Decrypt ciphertext only after tag verification ([#356]) + +[#356]: https://github.com/RustCrypto/AEADs/pull/356 + ## 0.4.3 (2021-07-20) ### Changed - Pin `subtle` dependency to v2.4 ([#349]) diff --git a/mgm/Cargo.toml b/mgm/Cargo.toml index 3671c792..42658f7b 100644 --- a/mgm/Cargo.toml +++ b/mgm/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mgm" -version = "0.4.3" +version = "0.4.4" description = "Generic implementation of the Multilinear Galois Mode (MGM) cipher" authors = ["RustCrypto Developers"] edition = "2018" diff --git a/mgm/src/lib.rs b/mgm/src/lib.rs index 4e1ff00a..e33cee98 100644 --- a/mgm/src/lib.rs +++ b/mgm/src/lib.rs @@ -39,6 +39,7 @@ use aead::{ AeadCore, AeadInPlace, Error, Key, NewAead, }; use cipher::{BlockCipher, BlockEncrypt, NewBlockCipher}; +use subtle::ConstantTimeEq; pub use aead; @@ -55,6 +56,7 @@ pub type Nonce = GenericArray; pub type Tag = GenericArray; type Block = GenericArray::BlockSize>; +type Element = <::BlockSize as Sealed>::Element; /// Trait implemented for block cipher sizes usable with MGM. pub trait MgmBlockSize: sealed::Sealed {} @@ -76,16 +78,20 @@ where C: BlockEncrypt, C::BlockSize: MgmBlockSize, { - fn get_h(&self, counter: &Counter) -> Block { + fn apply_ks_block(&self, counter: &mut Counter, buf: &mut [u8]) { let mut block = C::BlockSize::ctr2block(counter); self.cipher.encrypt_block(&mut block); - block + for i in 0..core::cmp::min(block.len(), buf.len()) { + buf[i] ^= block[i]; + } + C::BlockSize::incr_r(counter); } - fn encrypt_counter(&self, counter: &Counter) -> Block { - let mut block = C::BlockSize::ctr2block(counter); - self.cipher.encrypt_block(&mut block); - block + fn update_tag(&self, tag: &mut Element, tag_ctr: &mut Counter, block: &Block) { + let mut h = C::BlockSize::ctr2block(tag_ctr); + self.cipher.encrypt_block(&mut h); + tag.mul_sum(&h, block); + C::BlockSize::incr_l(tag_ctr); } } @@ -136,57 +142,49 @@ where if nonce[0] >> 7 != 0 { return Err(Error); } - let final_block = C::BlockSize::lengths2block(adata.len(), buffer.len())?; - let mut enc_counter = nonce.clone(); - let mut tag_counter = nonce.clone(); - enc_counter[0] &= 0b0111_1111; - tag_counter[0] |= 0b1000_0000; - - self.cipher.encrypt_block(&mut enc_counter); - self.cipher.encrypt_block(&mut tag_counter); - - let mut enc_counter = C::BlockSize::block2ctr(&enc_counter); - let mut tag_counter = C::BlockSize::block2ctr(&tag_counter); + let mut tag_ctr = nonce.clone(); + tag_ctr[0] |= 0b1000_0000; + self.cipher.encrypt_block(&mut tag_ctr); + let mut tag_ctr = C::BlockSize::block2ctr(&tag_ctr); let mut tag = ::Element::new(); // process adata let mut iter = adata.chunks_exact(C::BlockSize::USIZE); - for chunk in (&mut iter).map(Block::::from_slice) { - tag.mul_sum(&self.get_h(&tag_counter), chunk); - C::BlockSize::incr_l(&mut tag_counter); + for block in (&mut iter).map(Block::::from_slice) { + self.update_tag(&mut tag, &mut tag_ctr, block); } let rem = iter.remainder(); if !rem.is_empty() { - let mut chunk: Block = Default::default(); - chunk[..rem.len()].copy_from_slice(rem); - tag.mul_sum(&self.get_h(&tag_counter), &chunk); - C::BlockSize::incr_l(&mut tag_counter); + let mut block: Block = Default::default(); + block[..rem.len()].copy_from_slice(rem); + self.update_tag(&mut tag, &mut tag_ctr, &block); } + let mut enc_ctr = nonce.clone(); + enc_ctr[0] &= 0b0111_1111; + self.cipher.encrypt_block(&mut enc_ctr); + let mut enc_ctr = C::BlockSize::block2ctr(&enc_ctr); + // process plaintext let mut iter = buffer.chunks_exact_mut(C::BlockSize::USIZE); - for chunk in (&mut iter).map(Block::::from_mut_slice) { - xor(chunk, &self.encrypt_counter(&enc_counter)); - tag.mul_sum(&self.get_h(&tag_counter), chunk); - C::BlockSize::incr_r(&mut enc_counter); - C::BlockSize::incr_l(&mut tag_counter); + for block in (&mut iter).map(Block::::from_mut_slice) { + self.apply_ks_block(&mut enc_ctr, block); + self.update_tag(&mut tag, &mut tag_ctr, block); } let rem = iter.into_remainder(); if !rem.is_empty() { - let n = rem.len(); - let e = self.encrypt_counter(&enc_counter); - xor(rem, &e[..n]); - - let mut ct = Block::::default(); - ct[..n].copy_from_slice(rem); + self.apply_ks_block(&mut enc_ctr, rem); - tag.mul_sum(&self.get_h(&tag_counter), &ct); - C::BlockSize::incr_l(&mut tag_counter); + let mut block = Block::::default(); + let n = rem.len(); + block[..n].copy_from_slice(rem); + self.update_tag(&mut tag, &mut tag_ctr, &block); } - tag.mul_sum(&self.get_h(&tag_counter), &final_block); + let block = C::BlockSize::lengths2block(adata.len(), buffer.len())?; + self.update_tag(&mut tag, &mut tag_ctr, &block); let mut tag = tag.into_bytes(); self.cipher.encrypt_block(&mut tag); @@ -205,73 +203,65 @@ where if nonce[0] >> 7 != 0 { return Err(Error); } - let final_block = C::BlockSize::lengths2block(adata.len(), buffer.len())?; - - let mut enc_counter = nonce.clone(); - let mut tag_counter = nonce.clone(); - enc_counter[0] &= 0b0111_1111; - tag_counter[0] |= 0b1000_0000; - - self.cipher.encrypt_block(&mut enc_counter); - self.cipher.encrypt_block(&mut tag_counter); - let mut dec_counter = C::BlockSize::block2ctr(&enc_counter); - let mut tag_counter = C::BlockSize::block2ctr(&tag_counter); + let mut tag_ctr = nonce.clone(); + tag_ctr[0] |= 0b1000_0000; + self.cipher.encrypt_block(&mut tag_ctr); + let mut tag_ctr = C::BlockSize::block2ctr(&tag_ctr); let mut tag = ::Element::new(); // process adata let mut iter = adata.chunks_exact(C::BlockSize::USIZE); - for chunk in (&mut iter).map(Block::::from_slice) { - tag.mul_sum(&self.get_h(&tag_counter), chunk); - C::BlockSize::incr_l(&mut tag_counter); + for block in (&mut iter).map(Block::::from_slice) { + self.update_tag(&mut tag, &mut tag_ctr, block); } let rem = iter.remainder(); if !rem.is_empty() { - let mut chunk: Block = Default::default(); - chunk[..rem.len()].copy_from_slice(rem); - tag.mul_sum(&self.get_h(&tag_counter), &chunk); - C::BlockSize::incr_l(&mut tag_counter); + let mut block: Block = Default::default(); + block[..rem.len()].copy_from_slice(rem); + self.update_tag(&mut tag, &mut tag_ctr, &block); } - // process ciphertext let mut iter = buffer.chunks_exact_mut(C::BlockSize::USIZE); - for chunk in (&mut iter).map(Block::::from_mut_slice) { - tag.mul_sum(&self.get_h(&tag_counter), chunk); - xor(chunk, &self.encrypt_counter(&dec_counter)); - C::BlockSize::incr_r(&mut dec_counter); - C::BlockSize::incr_l(&mut tag_counter); + for block in (&mut iter).map(Block::::from_mut_slice) { + self.update_tag(&mut tag, &mut tag_ctr, block); } let rem = iter.into_remainder(); if !rem.is_empty() { let n = rem.len(); - let e = self.encrypt_counter(&dec_counter); - let mut ct = Block::::default(); - ct[..n].copy_from_slice(rem); + let mut block = Block::::default(); + block[..n].copy_from_slice(rem); - tag.mul_sum(&self.get_h(&tag_counter), &ct); - xor(rem, &e[..n]); - C::BlockSize::incr_l(&mut tag_counter); + self.update_tag(&mut tag, &mut tag_ctr, &block); } - tag.mul_sum(&self.get_h(&tag_counter), &final_block); + let block = C::BlockSize::lengths2block(adata.len(), buffer.len())?; + self.update_tag(&mut tag, &mut tag_ctr, &block); let mut tag = tag.into_bytes(); self.cipher.encrypt_block(&mut tag); - use subtle::ConstantTimeEq; - if expected_tag.ct_eq(&tag).unwrap_u8() == 1 { - Ok(()) - } else { - Err(Error) + if expected_tag.ct_eq(&tag).unwrap_u8() == 0 { + return Err(Error); + } + + // decrypt ciphertext + let mut dec_ctr = nonce.clone(); + dec_ctr[0] &= 0b0111_1111; + self.cipher.encrypt_block(&mut dec_ctr); + let mut dec_ctr = C::BlockSize::block2ctr(&dec_ctr); + + let mut iter = buffer.chunks_exact_mut(C::BlockSize::USIZE); + for block in (&mut iter).map(Block::::from_mut_slice) { + self.apply_ks_block(&mut dec_ctr, block); + } + let rem = iter.into_remainder(); + if !rem.is_empty() { + self.apply_ks_block(&mut dec_ctr, rem); } - } -} -fn xor(buf: &mut [u8], val: &[u8]) { - debug_assert_eq!(buf.len(), val.len()); - for (a, b) in buf.iter_mut().zip(val.iter()) { - *a ^= *b; + Ok(()) } }