From 7c57e7d94cb84544f97472761b2d95d99291a564 Mon Sep 17 00:00:00 2001 From: Helder Eijs Date: Fri, 17 Aug 2018 17:39:56 +0200 Subject: [PATCH] Fix issue #198: AESNI breaks with messages shorter than 16 bytes --- lib/Crypto/SelfTest/Cipher/test_AES.py | 23 +++++++++++++++++++++++ src/AESNI.c | 4 ++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/lib/Crypto/SelfTest/Cipher/test_AES.py b/lib/Crypto/SelfTest/Cipher/test_AES.py index e6dd69f68..3733fb981 100644 --- a/lib/Crypto/SelfTest/Cipher/test_AES.py +++ b/lib/Crypto/SelfTest/Cipher/test_AES.py @@ -1265,16 +1265,39 @@ def runTest(self): self.assertEqual(SHA256.new(ct).hexdigest(), expected) +class TestIncompleteBlocks(unittest.TestCase): + + def __init__(self, use_aesni): + unittest.TestCase.__init__(self) + self.use_aesni = use_aesni + + def runTest(self): + # Encrypt data with length not multiple of 16 bytes + + cipher = AES.new(b'4'*16, AES.MODE_ECB, use_aesni=self.use_aesni) + + for msg_len in range(1, 16): + self.assertRaises(ValueError, cipher.encrypt, b'1' * msg_len) + self.assertRaises(ValueError, cipher.encrypt, b'1' * (msg_len+16)) + self.assertRaises(ValueError, cipher.decrypt, b'1' * msg_len) + self.assertRaises(ValueError, cipher.decrypt, b'1' * (msg_len+16)) + + self.assertEqual(cipher.encrypt(b''), b'') + self.assertEqual(cipher.decrypt(b''), b'') + + def get_tests(config={}): from Crypto.Util import _cpu_features from common import make_block_tests tests = make_block_tests(AES, "AES", test_data, {'use_aesni': False}) tests += [ TestMultipleBlocks(False) ] + tests += [ TestIncompleteBlocks(False) ] if _cpu_features.have_aes_ni(): # Run tests with AES-NI instructions if they are available. tests += make_block_tests(AES, "AESNI", test_data, {'use_aesni': True}) tests += [ TestMultipleBlocks(True) ] + tests += [ TestIncompleteBlocks(True) ] else: print "Skipping AESNI tests" return tests diff --git a/src/AESNI.c b/src/AESNI.c index 7c1a92321..38acfa0c5 100644 --- a/src/AESNI.c +++ b/src/AESNI.c @@ -222,7 +222,7 @@ static int AESNI_encrypt(const BlockBase *bb, const uint8_t *in, uint8_t *out, s } /** There are 7 blocks or fewer left **/ - for (;data_len>0; data_len-=16, in+=16, out+=16) { + for (;data_len>=BLOCK_SIZE; data_len-=BLOCK_SIZE, in+=BLOCK_SIZE, out+=BLOCK_SIZE) { __m128i pt, data; unsigned i; @@ -331,7 +331,7 @@ static int AESNI_decrypt(const BlockBase *bb, const uint8_t *in, uint8_t *out, s } /** There are 7 blocks or fewer left **/ - for (;data_len>0; data_len-=16, in+=16, out+=16) { + for (;data_len>=BLOCK_SIZE; data_len-=BLOCK_SIZE, in+=BLOCK_SIZE, out+=BLOCK_SIZE) { __m128i ct, data; unsigned i;