Use 6x blocks for ARMv8 AES rather than 4x

We gain 0.1 to 0.3 cycles per byte (cpb), depending on the mode.
Jeffrey Walton 2017-09-14 20:32:06 -04:00
parent 51752cb91a
commit 25efb7a140

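Why six blocks instead of four: each vaeseq_u8/vaesmcq_u8 result feeds the next round, so a single block forms one long dependency chain and leaves the pipelined AES unit idle between rounds. Independent blocks fill those slots. A minimal sketch of the idea (hypothetical helper, not part of this commit; assumes an AArch64 compiler with the Crypto extension, e.g. -march=armv8-a+crypto):

    #include <arm_neon.h>

    // One AES round over two independent blocks. The b1 chain has no data
    // dependency on b0, so it can issue while b0's round is still in flight;
    // six blocks extend the pattern without exhausting the 32 NEON registers.
    static inline void two_blocks_one_round(uint8x16_t &b0, uint8x16_t &b1, uint8x16_t rk)
    {
        b0 = vaeseq_u8(b0, rk);   // AddRoundKey + SubBytes + ShiftRows
        b1 = vaeseq_u8(b1, rk);
        b0 = vaesmcq_u8(b0);      // MixColumns
        b1 = vaesmcq_u8(b1);
    }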

@@ -168,167 +168,173 @@ const word32 s_one[] = {0, 0, 0, 1}; // uint32x4_t
inline void ARMV8_Enc_Block(uint8x16_t &block, const word32 *subkeys, unsigned int rounds)
{
CRYPTOPP_ASSERT(subkeys);
CRYPTOPP_ASSERT(rounds >= 9);
const byte *keys = reinterpret_cast<const byte*>(subkeys);
- // Unroll the loop, profit 0.3 to 0.5 cpb.
- block = vaeseq_u8(block, vld1q_u8(keys+0));
- block = vaesmcq_u8(block);
- block = vaeseq_u8(block, vld1q_u8(keys+16));
- block = vaesmcq_u8(block);
- block = vaeseq_u8(block, vld1q_u8(keys+32));
- block = vaesmcq_u8(block);
- block = vaeseq_u8(block, vld1q_u8(keys+48));
- block = vaesmcq_u8(block);
- block = vaeseq_u8(block, vld1q_u8(keys+64));
- block = vaesmcq_u8(block);
- block = vaeseq_u8(block, vld1q_u8(keys+80));
- block = vaesmcq_u8(block);
- block = vaeseq_u8(block, vld1q_u8(keys+96));
- block = vaesmcq_u8(block);
- block = vaeseq_u8(block, vld1q_u8(keys+112));
- block = vaesmcq_u8(block);
- block = vaeseq_u8(block, vld1q_u8(keys+128));
+ // AES single round encryption
+ block = vaeseq_u8(block, vld1q_u8(keys+0*16));
+ // AES mix columns
+ block = vaesmcq_u8(block);
- unsigned int i=9;
- for ( ; i<rounds-1; ++i)
+ for (unsigned int i=1; i<rounds-1; i+=2)
{
// AES single round encryption
block = vaeseq_u8(block, vld1q_u8(keys+i*16));
// AES mix columns
block = vaesmcq_u8(block);
+ // AES single round encryption
+ block = vaeseq_u8(block, vld1q_u8(keys+(i+1)*16));
+ // AES mix columns
+ block = vaesmcq_u8(block);
}
// AES single round encryption
- block = vaeseq_u8(block, vld1q_u8(keys+i*16));
+ block = vaeseq_u8(block, vld1q_u8(keys+(rounds-1)*16));
// Final Add (bitwise Xor)
- block = veorq_u8(block, vld1q_u8(keys+(i+1)*16));
+ block = veorq_u8(block, vld1q_u8(keys+rounds*16));
}
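The rewritten single-block routine is the standard ARMv8 shape: rounds-1 full rounds (vaeseq_u8 performs AddRoundKey, SubBytes and ShiftRows; vaesmcq_u8 performs MixColumns), a final round without MixColumns, and a plain XOR for the last round key. An AES-128-only restatement of the same schedule (hypothetical helper; assumes rk holds the 11 expanded round keys):

    #include <arm_neon.h>

    static inline uint8x16_t aes128_encrypt_block(uint8x16_t block, const uint8x16_t rk[11])
    {
        for (unsigned int i = 0; i < 9; ++i)   // nine full rounds
            block = vaesmcq_u8(vaeseq_u8(block, rk[i]));
        block = vaeseq_u8(block, rk[9]);       // final round: no MixColumns
        return veorq_u8(block, rk[10]);        // last AddRoundKey
    }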
- inline void ARMV8_Enc_4_Blocks(uint8x16_t &block0, uint8x16_t &block1, uint8x16_t &block2,
- uint8x16_t &block3, const word32 *subkeys, unsigned int rounds)
+ inline void ARMV8_Enc_6_Blocks(uint8x16_t &block0, uint8x16_t &block1, uint8x16_t &block2,
+ uint8x16_t &block3, uint8x16_t &block4, uint8x16_t &block5,
+ const word32 *subkeys, unsigned int rounds)
{
CRYPTOPP_ASSERT(subkeys);
const byte *keys = reinterpret_cast<const byte*>(subkeys);
+ uint8x16_t key;
- unsigned int i=0;
- for ( ; i<rounds-1; ++i)
+ for (unsigned int i=0; i<rounds-1; ++i)
{
+ key = vld1q_u8(keys+i*16);
// AES single round encryption
- block0 = vaeseq_u8(block0, vld1q_u8(keys+i*16));
+ block0 = vaeseq_u8(block0, key);
// AES mix columns
block0 = vaesmcq_u8(block0);
// AES single round encryption
- block1 = vaeseq_u8(block1, vld1q_u8(keys+i*16));
+ block1 = vaeseq_u8(block1, key);
// AES mix columns
block1 = vaesmcq_u8(block1);
// AES single round encryption
- block2 = vaeseq_u8(block2, vld1q_u8(keys+i*16));
+ block2 = vaeseq_u8(block2, key);
// AES mix columns
block2 = vaesmcq_u8(block2);
// AES single round encryption
- block3 = vaeseq_u8(block3, vld1q_u8(keys+i*16));
+ block3 = vaeseq_u8(block3, key);
// AES mix columns
block3 = vaesmcq_u8(block3);
+ // AES single round encryption
+ block4 = vaeseq_u8(block4, key);
+ // AES mix columns
+ block4 = vaesmcq_u8(block4);
+ // AES single round encryption
+ block5 = vaeseq_u8(block5, key);
+ // AES mix columns
+ block5 = vaesmcq_u8(block5);
}
// AES single round encryption
- block0 = vaeseq_u8(block0, vld1q_u8(keys+i*16));
- block1 = vaeseq_u8(block1, vld1q_u8(keys+i*16));
- block2 = vaeseq_u8(block2, vld1q_u8(keys+i*16));
- block3 = vaeseq_u8(block3, vld1q_u8(keys+i*16));
+ key = vld1q_u8(keys+(rounds-1)*16);
+ block0 = vaeseq_u8(block0, key);
+ block1 = vaeseq_u8(block1, key);
+ block2 = vaeseq_u8(block2, key);
+ block3 = vaeseq_u8(block3, key);
+ block4 = vaeseq_u8(block4, key);
+ block5 = vaeseq_u8(block5, key);
// Final Add (bitwise Xor)
- block0 = veorq_u8(block0, vld1q_u8(keys+(i+1)*16));
- block1 = veorq_u8(block1, vld1q_u8(keys+(i+1)*16));
- block2 = veorq_u8(block2, vld1q_u8(keys+(i+1)*16));
- block3 = veorq_u8(block3, vld1q_u8(keys+(i+1)*16));
+ key = vld1q_u8(keys+rounds*16);
+ block0 = veorq_u8(block0, key);
+ block1 = veorq_u8(block1, key);
+ block2 = veorq_u8(block2, key);
+ block3 = veorq_u8(block3, key);
+ block4 = veorq_u8(block4, key);
+ block5 = veorq_u8(block5, key);
}
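Besides the two extra blocks, the 6-block worker hoists the round-key load: one vld1q_u8 per round now feeds all six vaeseq_u8 calls instead of being repeated per block. A two-block sketch of the difference (hypothetical helpers):

    #include <arm_neon.h>

    // Per-block reload: the compiler may fold the duplicate loads, but it is not guaranteed.
    static inline void round_reload(uint8x16_t &b0, uint8x16_t &b1, const unsigned char *keys, unsigned int i)
    {
        b0 = vaesmcq_u8(vaeseq_u8(b0, vld1q_u8(keys+i*16)));
        b1 = vaesmcq_u8(vaeseq_u8(b1, vld1q_u8(keys+i*16)));
    }

    // Shared load: the round key is fetched once and held in a register.
    static inline void round_shared(uint8x16_t &b0, uint8x16_t &b1, const unsigned char *keys, unsigned int i)
    {
        const uint8x16_t key = vld1q_u8(keys+i*16);
        b0 = vaesmcq_u8(vaeseq_u8(b0, key));
        b1 = vaesmcq_u8(vaeseq_u8(b1, key));
    }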
inline void ARMV8_Dec_Block(uint8x16_t &block, const word32 *subkeys, unsigned int rounds)
{
CRYPTOPP_ASSERT(subkeys);
CRYPTOPP_ASSERT(rounds >= 9);
const byte *keys = reinterpret_cast<const byte*>(subkeys);
- // Unroll the loop, profit 0.3 to 0.5 cpb.
- block = vaesdq_u8(block, vld1q_u8(keys+0));
- block = vaesimcq_u8(block);
- block = vaesdq_u8(block, vld1q_u8(keys+16));
- block = vaesimcq_u8(block);
- block = vaesdq_u8(block, vld1q_u8(keys+32));
- block = vaesimcq_u8(block);
- block = vaesdq_u8(block, vld1q_u8(keys+48));
- block = vaesimcq_u8(block);
- block = vaesdq_u8(block, vld1q_u8(keys+64));
- block = vaesimcq_u8(block);
- block = vaesdq_u8(block, vld1q_u8(keys+80));
- block = vaesimcq_u8(block);
- block = vaesdq_u8(block, vld1q_u8(keys+96));
- block = vaesimcq_u8(block);
- block = vaesdq_u8(block, vld1q_u8(keys+112));
- block = vaesimcq_u8(block);
- block = vaesdq_u8(block, vld1q_u8(keys+128));
+ // AES single round decryption
+ block = vaesdq_u8(block, vld1q_u8(keys+0*16));
+ // AES inverse mix columns
+ block = vaesimcq_u8(block);
- unsigned int i=9;
- for ( ; i<rounds-1; ++i)
+ for (unsigned int i=1; i<rounds-1; i+=2)
{
// AES single round decryption
block = vaesdq_u8(block, vld1q_u8(keys+i*16));
// AES inverse mix columns
block = vaesimcq_u8(block);
+ // AES single round decryption
+ block = vaesdq_u8(block, vld1q_u8(keys+(i+1)*16));
+ // AES inverse mix columns
+ block = vaesimcq_u8(block);
}
// AES single round decryption
- block = vaesdq_u8(block, vld1q_u8(keys+i*16));
+ block = vaesdq_u8(block, vld1q_u8(keys+(rounds-1)*16));
// Final Add (bitwise Xor)
- block = veorq_u8(block, vld1q_u8(keys+(i+1)*16));
+ block = veorq_u8(block, vld1q_u8(keys+rounds*16));
}
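Decryption mirrors encryption with vaesdq_u8 (AddRoundKey, InvShiftRows, InvSubBytes) and vaesimcq_u8 (InvMixColumns). For that pairing to invert AES, the subkeys these routines receive must already be laid out for the equivalent inverse cipher: reversed order, with InvMixColumns applied to every round key except the first and last. A sketch of that preparation (hypothetical helper; in Crypto++ this is assumed to happen during key setup, not here):

    #include <arm_neon.h>

    static inline void make_decrypt_keys(uint8x16_t *dec, const uint8x16_t *enc, unsigned int rounds)
    {
        dec[0] = enc[rounds];                       // last encryption key first
        for (unsigned int i = 1; i < rounds; ++i)   // middle keys get InvMixColumns
            dec[i] = vaesimcq_u8(enc[rounds-i]);
        dec[rounds] = enc[0];                       // first encryption key last
    }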
- inline void ARMV8_Dec_4_Blocks(uint8x16_t &block0, uint8x16_t &block1, uint8x16_t &block2,
- uint8x16_t &block3, const word32 *subkeys, unsigned int rounds)
+ inline void ARMV8_Dec_6_Blocks(uint8x16_t &block0, uint8x16_t &block1, uint8x16_t &block2,
+ uint8x16_t &block3, uint8x16_t &block4, uint8x16_t &block5,
+ const word32 *subkeys, unsigned int rounds)
{
CRYPTOPP_ASSERT(subkeys);
const byte *keys = reinterpret_cast<const byte*>(subkeys);
- unsigned int i=0;
- for ( ; i<rounds-1; ++i)
+ uint8x16_t key;
+ for (unsigned int i=0; i<rounds-1; ++i)
{
+ key = vld1q_u8(keys+i*16);
// AES single round decryption
- block0 = vaesdq_u8(block0, vld1q_u8(keys+i*16));
+ block0 = vaesdq_u8(block0, key);
// AES inverse mix columns
block0 = vaesimcq_u8(block0);
// AES single round decryption
- block1 = vaesdq_u8(block1, vld1q_u8(keys+i*16));
+ block1 = vaesdq_u8(block1, key);
// AES inverse mix columns
block1 = vaesimcq_u8(block1);
// AES single round decryption
- block2 = vaesdq_u8(block2, vld1q_u8(keys+i*16));
+ block2 = vaesdq_u8(block2, key);
// AES inverse mix columns
block2 = vaesimcq_u8(block2);
// AES single round decryption
- block3 = vaesdq_u8(block3, vld1q_u8(keys+i*16));
+ block3 = vaesdq_u8(block3, key);
// AES inverse mix columns
block3 = vaesimcq_u8(block3);
+ // AES single round decryption
+ block4 = vaesdq_u8(block4, key);
+ // AES inverse mix columns
+ block4 = vaesimcq_u8(block4);
+ // AES single round decryption
+ block5 = vaesdq_u8(block5, key);
+ // AES inverse mix columns
+ block5 = vaesimcq_u8(block5);
}
// AES single round decryption
- block0 = vaesdq_u8(block0, vld1q_u8(keys+i*16));
- block1 = vaesdq_u8(block1, vld1q_u8(keys+i*16));
- block2 = vaesdq_u8(block2, vld1q_u8(keys+i*16));
- block3 = vaesdq_u8(block3, vld1q_u8(keys+i*16));
+ key = vld1q_u8(keys+(rounds-1)*16);
+ block0 = vaesdq_u8(block0, key);
+ block1 = vaesdq_u8(block1, key);
+ block2 = vaesdq_u8(block2, key);
+ block3 = vaesdq_u8(block3, key);
+ block4 = vaesdq_u8(block4, key);
+ block5 = vaesdq_u8(block5, key);
// Final Add (bitwise Xor)
- block0 = veorq_u8(block0, vld1q_u8(keys+(i+1)*16));
- block1 = veorq_u8(block1, vld1q_u8(keys+(i+1)*16));
- block2 = veorq_u8(block2, vld1q_u8(keys+(i+1)*16));
- block3 = veorq_u8(block3, vld1q_u8(keys+(i+1)*16));
+ key = vld1q_u8(keys+rounds*16);
+ block0 = veorq_u8(block0, key);
+ block1 = veorq_u8(block1, key);
+ block2 = veorq_u8(block2, key);
+ block3 = veorq_u8(block3, key);
+ block4 = veorq_u8(block4, key);
+ block5 = veorq_u8(block5, key);
}
- template <typename F1, typename F4>
- size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *subKeys, size_t rounds,
+ template <typename F1, typename F6>
+ size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F6 func6, const word32 *subKeys, size_t rounds,
const byte *inBlocks, const byte *xorBlocks, byte *outBlocks, size_t length, word32 flags)
{
CRYPTOPP_ASSERT(subKeys);
@@ -353,9 +359,9 @@ size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *su
if (flags & BlockTransformation::BT_AllowParallel)
{
- while (length >= 4*blockSize)
+ while (length >= 6*blockSize)
{
- uint8x16_t block0, block1, block2, block3, temp;
+ uint8x16_t block0, block1, block2, block3, block4, block5, temp;
block0 = vld1q_u8(inBlocks);
if (flags & BlockTransformation::BT_InBlockIsCounter)
@@ -364,7 +370,9 @@ size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *su
block1 = vaddq_u8(block0, vreinterpretq_u8_u32(be));
block2 = vaddq_u8(block1, vreinterpretq_u8_u32(be));
block3 = vaddq_u8(block2, vreinterpretq_u8_u32(be));
- temp = vaddq_u8(block3, vreinterpretq_u8_u32(be));
+ block4 = vaddq_u8(block3, vreinterpretq_u8_u32(be));
+ block5 = vaddq_u8(block4, vreinterpretq_u8_u32(be));
+ temp = vaddq_u8(block5, vreinterpretq_u8_u32(be));
vst1q_u8(const_cast<byte*>(inBlocks), temp);
}
else
@@ -376,6 +384,10 @@ size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *su
inBlocks += inIncrement;
block3 = vld1q_u8(inBlocks);
inBlocks += inIncrement;
+ block4 = vld1q_u8(inBlocks);
+ inBlocks += inIncrement;
+ block5 = vld1q_u8(inBlocks);
+ inBlocks += inIncrement;
}
if (flags & BlockTransformation::BT_XorInput)
@@ -388,9 +400,13 @@ size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *su
xorBlocks += xorIncrement;
block3 = veorq_u8(block3, vld1q_u8(xorBlocks));
xorBlocks += xorIncrement;
+ block4 = veorq_u8(block4, vld1q_u8(xorBlocks));
+ xorBlocks += xorIncrement;
+ block5 = veorq_u8(block5, vld1q_u8(xorBlocks));
+ xorBlocks += xorIncrement;
}
- func4(block0, block1, block2, block3, subKeys, rounds);
+ func6(block0, block1, block2, block3, block4, block5, subKeys, rounds);
if (xorBlocks && !(flags & BlockTransformation::BT_XorInput))
{
@@ -402,6 +418,10 @@ size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *su
xorBlocks += xorIncrement;
block3 = veorq_u8(block3, vld1q_u8(xorBlocks));
xorBlocks += xorIncrement;
+ block4 = veorq_u8(block4, vld1q_u8(xorBlocks));
+ xorBlocks += xorIncrement;
+ block5 = veorq_u8(block5, vld1q_u8(xorBlocks));
+ xorBlocks += xorIncrement;
}
vst1q_u8(outBlocks, block0);
@@ -412,8 +432,12 @@ size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *su
outBlocks += outIncrement;
vst1q_u8(outBlocks, block3);
outBlocks += outIncrement;
+ vst1q_u8(outBlocks, block4);
+ outBlocks += outIncrement;
+ vst1q_u8(outBlocks, block5);
+ outBlocks += outIncrement;
- length -= 4*blockSize;
+ length -= 6*blockSize;
}
}
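In the counter branch above, six consecutive counter blocks come from one memory load: each block is the previous one plus s_one, and the last increment is stored back so the caller's counter stays current. Note that vaddq_u8 adds byte lanes with no carry between them, so which byte s_one bumps (and any rollover across a byte boundary) is a matter for the surrounding counter management, not this loop. A standalone illustration (hypothetical values; uses an explicit byte-15 increment rather than the library's endian-dependent s_one):

    #include <arm_neon.h>
    #include <stdio.h>

    int main()
    {
        unsigned char ctr[16] = {0};
        ctr[15] = 0x10;                             // low-order counter byte
        uint8x16_t one = vsetq_lane_u8(1, vdupq_n_u8(0), 15);
        uint8x16_t b0 = vld1q_u8(ctr);              // counter
        uint8x16_t b1 = vaddq_u8(b0, one);          // counter+1
        uint8x16_t b2 = vaddq_u8(b1, one);          // counter+2
        unsigned char out[16];
        vst1q_u8(out, b2);
        printf("%02x\n", out[15]);                  // prints 12
        return 0;
    }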
@@ -446,14 +470,14 @@ size_t Rijndael_AdvancedProcessBlocks_ARMV8(F1 func1, F4 func4, const word32 *su
size_t Rijndael_Enc_AdvancedProcessBlocks_ARMV8(const word32 *subKeys, size_t rounds,
const byte *inBlocks, const byte *xorBlocks, byte *outBlocks, size_t length, word32 flags)
{
- return Rijndael_AdvancedProcessBlocks_ARMV8(ARMV8_Enc_Block, ARMV8_Enc_4_Blocks,
+ return Rijndael_AdvancedProcessBlocks_ARMV8(ARMV8_Enc_Block, ARMV8_Enc_6_Blocks,
subKeys, rounds, inBlocks, xorBlocks, outBlocks, length, flags);
}
size_t Rijndael_Dec_AdvancedProcessBlocks_ARMV8(const word32 *subKeys, size_t rounds,
const byte *inBlocks, const byte *xorBlocks, byte *outBlocks, size_t length, word32 flags)
{
- return Rijndael_AdvancedProcessBlocks_ARMV8(ARMV8_Dec_Block, ARMV8_Dec_4_Blocks,
+ return Rijndael_AdvancedProcessBlocks_ARMV8(ARMV8_Dec_Block, ARMV8_Dec_6_Blocks,
subKeys, rounds, inBlocks, xorBlocks, outBlocks, length, flags);
}
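The entry points bind the one-block and six-block workers as template parameters, so the batching loop is written once and the workers inline into it with no per-block indirect call. A stripped-down restatement of the dispatch (hypothetical; omits the flag, counter and xor handling of the real routine):

    #include <arm_neon.h>
    #include <stddef.h>

    template <typename F1, typename F6>
    void process_blocks(F1 func1, F6 func6, const unsigned int *subkeys, unsigned int rounds,
                        const unsigned char *in, unsigned char *out, size_t length)
    {
        while (length >= 6*16)   // six-block batches while input lasts
        {
            uint8x16_t b0 = vld1q_u8(in+0*16), b1 = vld1q_u8(in+1*16), b2 = vld1q_u8(in+2*16);
            uint8x16_t b3 = vld1q_u8(in+3*16), b4 = vld1q_u8(in+4*16), b5 = vld1q_u8(in+5*16);
            func6(b0, b1, b2, b3, b4, b5, subkeys, rounds);
            vst1q_u8(out+0*16, b0); vst1q_u8(out+1*16, b1); vst1q_u8(out+2*16, b2);
            vst1q_u8(out+3*16, b3); vst1q_u8(out+4*16, b4); vst1q_u8(out+5*16, b5);
            in += 6*16; out += 6*16; length -= 6*16;
        }
        while (length >= 16)     // single-block tail
        {
            uint8x16_t b = vld1q_u8(in);
            func1(b, subkeys, rounds);
            vst1q_u8(out, b);
            in += 16; out += 16; length -= 16;
        }
    }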