diff --git a/rijndael-simd.cpp b/rijndael-simd.cpp index 282f0ddb..e650de1c 100644 --- a/rijndael-simd.cpp +++ b/rijndael-simd.cpp @@ -168,167 +168,173 @@ const word32 s_one[] = {0, 0, 0, 1}; // uint32x4_t inline void ARMV8_Enc_Block(uint8x16_t &block, const word32 *subkeys, unsigned int rounds) { CRYPTOPP_ASSERT(subkeys); - CRYPTOPP_ASSERT(rounds >= 9); const byte *keys = reinterpret_cast(subkeys); - // Unroll the loop, profit 0.3 to 0.5 cpb. - block = vaeseq_u8(block, vld1q_u8(keys+0)); - block = vaesmcq_u8(block); - block = vaeseq_u8(block, vld1q_u8(keys+16)); - block = vaesmcq_u8(block); - block = vaeseq_u8(block, vld1q_u8(keys+32)); - block = vaesmcq_u8(block); - block = vaeseq_u8(block, vld1q_u8(keys+48)); - block = vaesmcq_u8(block); - block = vaeseq_u8(block, vld1q_u8(keys+64)); - block = vaesmcq_u8(block); - block = vaeseq_u8(block, vld1q_u8(keys+80)); - block = vaesmcq_u8(block); - block = vaeseq_u8(block, vld1q_u8(keys+96)); - block = vaesmcq_u8(block); - block = vaeseq_u8(block, vld1q_u8(keys+112)); - block = vaesmcq_u8(block); - block = vaeseq_u8(block, vld1q_u8(keys+128)); + // AES single round encryption + block = vaeseq_u8(block, vld1q_u8(keys+0*16)); + // AES mix columns block = vaesmcq_u8(block); - unsigned int i=9; - for ( ; i(subkeys); + uint8x16_t key; - unsigned int i=0; - for ( ; i= 9); const byte *keys = reinterpret_cast(subkeys); - // Unroll the loop, profit 0.3 to 0.5 cpb. - block = vaesdq_u8(block, vld1q_u8(keys+0)); - block = vaesimcq_u8(block); - block = vaesdq_u8(block, vld1q_u8(keys+16)); - block = vaesimcq_u8(block); - block = vaesdq_u8(block, vld1q_u8(keys+32)); - block = vaesimcq_u8(block); - block = vaesdq_u8(block, vld1q_u8(keys+48)); - block = vaesimcq_u8(block); - block = vaesdq_u8(block, vld1q_u8(keys+64)); - block = vaesimcq_u8(block); - block = vaesdq_u8(block, vld1q_u8(keys+80)); - block = vaesimcq_u8(block); - block = vaesdq_u8(block, vld1q_u8(keys+96)); - block = vaesimcq_u8(block); - block = vaesdq_u8(block, vld1q_u8(keys+112)); - block = vaesimcq_u8(block); - block = vaesdq_u8(block, vld1q_u8(keys+128)); + // AES single round decryption + block = vaesdq_u8(block, vld1q_u8(keys+0*16)); + // AES inverse mix columns block = vaesimcq_u8(block); - unsigned int i=9; - for ( ; i(subkeys); - unsigned int i=0; - for ( ; i -size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *subKeys, size_t rounds, +template +size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F6 func6, const word32 *subKeys, size_t rounds, const byte *inBlocks, const byte *xorBlocks, byte *outBlocks, size_t length, word32 flags) { CRYPTOPP_ASSERT(subKeys); @@ -353,9 +359,9 @@ size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *su if (flags & BlockTransformation::BT_AllowParallel) { - while (length >= 4*blockSize) + while (length >= 6*blockSize) { - uint8x16_t block0, block1, block2, block3, temp; + uint8x16_t block0, block1, block2, block3, block4, block5, temp; block0 = vld1q_u8(inBlocks); if (flags & BlockTransformation::BT_InBlockIsCounter) @@ -364,7 +370,9 @@ size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *su block1 = vaddq_u8(block0, vreinterpretq_u8_u32(be)); block2 = vaddq_u8(block1, vreinterpretq_u8_u32(be)); block3 = vaddq_u8(block2, vreinterpretq_u8_u32(be)); - temp = vaddq_u8(block3, vreinterpretq_u8_u32(be)); + block4 = vaddq_u8(block3, vreinterpretq_u8_u32(be)); + block5 = vaddq_u8(block4, vreinterpretq_u8_u32(be)); + temp = vaddq_u8(block5, vreinterpretq_u8_u32(be)); vst1q_u8(const_cast(inBlocks), temp); } else @@ -376,6 +384,10 @@ size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *su inBlocks += inIncrement; block3 = vld1q_u8(inBlocks); inBlocks += inIncrement; + block4 = vld1q_u8(inBlocks); + inBlocks += inIncrement; + block5 = vld1q_u8(inBlocks); + inBlocks += inIncrement; } if (flags & BlockTransformation::BT_XorInput) @@ -388,9 +400,13 @@ size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *su xorBlocks += xorIncrement; block3 = veorq_u8(block3, vld1q_u8(xorBlocks)); xorBlocks += xorIncrement; + block4 = veorq_u8(block4, vld1q_u8(xorBlocks)); + xorBlocks += xorIncrement; + block5 = veorq_u8(block5, vld1q_u8(xorBlocks)); + xorBlocks += xorIncrement; } - func4(block0, block1, block2, block3, subKeys, rounds); + func6(block0, block1, block2, block3, block4, block5, subKeys, rounds); if (xorBlocks && !(flags & BlockTransformation::BT_XorInput)) { @@ -402,6 +418,10 @@ size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *su xorBlocks += xorIncrement; block3 = veorq_u8(block3, vld1q_u8(xorBlocks)); xorBlocks += xorIncrement; + block4 = veorq_u8(block4, vld1q_u8(xorBlocks)); + xorBlocks += xorIncrement; + block5 = veorq_u8(block5, vld1q_u8(xorBlocks)); + xorBlocks += xorIncrement; } vst1q_u8(outBlocks, block0); @@ -412,8 +432,12 @@ size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *su outBlocks += outIncrement; vst1q_u8(outBlocks, block3); outBlocks += outIncrement; + vst1q_u8(outBlocks, block4); + outBlocks += outIncrement; + vst1q_u8(outBlocks, block5); + outBlocks += outIncrement; - length -= 4*blockSize; + length -= 6*blockSize; } } @@ -446,14 +470,14 @@ size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *su size_t Rijndael_Enc_AdvancedProcessBlocks_ARMV8(const word32 *subKeys, size_t rounds, const byte *inBlocks, const byte *xorBlocks, byte *outBlocks, size_t length, word32 flags) { - return Rijndael_AdvancedProcessBlocks_ARMV8(ARMV8_Enc_Block, ARMV8_Enc_4_Blocks, + return Rijndael_AdvancedProcessBlocks_ARMV8(ARMV8_Enc_Block, ARMV8_Enc_6_Blocks, subKeys, rounds, inBlocks, xorBlocks, outBlocks, length, flags); } size_t Rijndael_Dec_AdvancedProcessBlocks_ARMV8(const word32 *subKeys, size_t rounds, const byte *inBlocks, const byte *xorBlocks, byte *outBlocks, size_t length, word32 flags) { - return Rijndael_AdvancedProcessBlocks_ARMV8(ARMV8_Dec_Block, ARMV8_Dec_4_Blocks, + return Rijndael_AdvancedProcessBlocks_ARMV8(ARMV8_Dec_Block, ARMV8_Dec_6_Blocks, subKeys, rounds, inBlocks, xorBlocks, outBlocks, length, flags); }