diff --git a/chacha-simd.cpp b/chacha-simd.cpp
index 6ec01834..dbaf1689 100644
--- a/chacha-simd.cpp
+++ b/chacha-simd.cpp
@@ -87,6 +87,7 @@ inline uint32x4_t RotateLeft<8>(const uint32x4_t& val)
     return vreinterpretq_u32_u8(
         vqtbl1q_u8(vreinterpretq_u8_u32(val), mask));
 #else
+    // fallback to slower C++ rotation.
     return vorrq_u32(vshlq_n_u32(val, 8),
         vshrq_n_u32(val, 32 - 8));
 #endif
@@ -99,6 +100,7 @@ inline uint32x4_t RotateLeft<16>(const uint32x4_t& val)
     return vreinterpretq_u32_u16(
         vrev32q_u16(vreinterpretq_u16_u32(val)));
 #else
+    // fallback to slower C++ rotation.
     return vorrq_u32(vshlq_n_u32(val, 16),
         vshrq_n_u32(val, 32 - 16));
 #endif
@@ -114,6 +116,7 @@ inline uint32x4_t RotateRight<8>(const uint32x4_t& val)
     return vreinterpretq_u32_u8(
         vqtbl1q_u8(vreinterpretq_u8_u32(val), mask));
 #else
+    // fallback to slower C++ rotation.
     return vorrq_u32(vshrq_n_u32(val, 8),
         vshlq_n_u32(val, 32 - 8));
 #endif
@@ -126,12 +129,14 @@ inline uint32x4_t RotateRight<16>(const uint32x4_t& val)
     return vreinterpretq_u32_u16(
         vrev32q_u16(vreinterpretq_u16_u32(val)));
 #else
+    // fallback to slower C++ rotation.
     return vorrq_u32(vshrq_n_u32(val, 16),
         vshlq_n_u32(val, 32 - 16));
 #endif
 }
 
-// ChaCha's use of shuffle is really a 4, 8, or 12 byte rotation:
+// ChaCha's use of x86 shuffle is really a 4, 8, or 12 byte
+// rotation on the 128-bit vector word:
 // * [3,2,1,0] => [0,3,2,1] is Extract<1>(x)
 // * [3,2,1,0] => [1,0,3,2] is Extract<2>(x)
 // * [3,2,1,0] => [2,1,0,3] is Extract<3>(x)
@@ -141,6 +146,15 @@ inline uint32x4_t Extract(const uint32x4_t& val)
     return vextq_u32(val, val, S);
 }
 
+// Helper to perform 64-bit addition across two elements of 32-bit vectors
+inline uint32x4_t Add64(const uint32x4_t& a, const uint32x4_t& b)
+{
+    return vreinterpretq_u32_u64(
+        vaddq_u64(
+            vreinterpretq_u64_u32(a),
+            vreinterpretq_u64_u32(b)));
+}
+
 #endif  // CRYPTOPP_ARM_NEON_AVAILABLE
 
 // ***************************** SSE2 ***************************** //
@@ -200,8 +214,8 @@ void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte *
     const uint32x4_t state2 = vld1q_u32(state + 2*4);
     const uint32x4_t state3 = vld1q_u32(state + 3*4);
 
-    const uint64x2_t CTRS[3] = {
-        {1, 0}, {2, 0}, {3, 0}
+    const uint32x4_t CTRS[3] = {
+        {1,0,0,0}, {2,0,0,0}, {3,0,0,0}
     };
 
     uint32x4_t r0_0 = state0;
@@ -212,20 +226,17 @@ void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte *
     uint32x4_t r1_0 = state0;
     uint32x4_t r1_1 = state1;
     uint32x4_t r1_2 = state2;
-    uint32x4_t r1_3 = vreinterpretq_u32_u64(vaddq_u64(
-        vreinterpretq_u64_u32(r0_3), CTRS[0]));
+    uint32x4_t r1_3 = Add64(r0_3, CTRS[0]);
 
     uint32x4_t r2_0 = state0;
     uint32x4_t r2_1 = state1;
     uint32x4_t r2_2 = state2;
-    uint32x4_t r2_3 = vreinterpretq_u32_u64(vaddq_u64(
-        vreinterpretq_u64_u32(r0_3), CTRS[1]));
+    uint32x4_t r2_3 = Add64(r0_3, CTRS[1]);
 
     uint32x4_t r3_0 = state0;
     uint32x4_t r3_1 = state1;
     uint32x4_t r3_2 = state2;
-    uint32x4_t r3_3 = vreinterpretq_u32_u64(vaddq_u64(
-        vreinterpretq_u64_u32(r0_3), CTRS[2]));
+    uint32x4_t r3_3 = Add64(r0_3, CTRS[2]);
 
     for (int i = static_cast<int>(rounds); i > 0; i -= 2)
     {
@@ -391,22 +402,19 @@ void ChaCha_OperateKeystream_NEON(const word32 *state, const byte* input, byte *
     r1_1 = vaddq_u32(r1_1, state1);
     r1_2 = vaddq_u32(r1_2, state2);
     r1_3 = vaddq_u32(r1_3, state3);
-    r1_3 = vreinterpretq_u32_u64(vaddq_u64(
-        vreinterpretq_u64_u32(r1_3), CTRS[0]));
+    r1_3 = Add64(r1_3, CTRS[0]);
 
     r2_0 = vaddq_u32(r2_0, state0);
     r2_1 = vaddq_u32(r2_1, state1);
     r2_2 = vaddq_u32(r2_2, state2);
     r2_3 = vaddq_u32(r2_3, state3);
-    r2_3 = vreinterpretq_u32_u64(vaddq_u64(
-        vreinterpretq_u64_u32(r2_3), CTRS[1]));
+    r2_3 = Add64(r2_3, CTRS[1]);
 
     r3_0 = vaddq_u32(r3_0, state0);
     r3_1 = vaddq_u32(r3_1, state1);
     r3_2 = vaddq_u32(r3_2, state2);
     r3_3 = vaddq_u32(r3_3, state3);
-    r3_3 = vreinterpretq_u32_u64(vaddq_u64(
-        vreinterpretq_u64_u32(r3_3), CTRS[2]));
+    r3_3 = Add64(r3_3, CTRS[2]);
 
     if (input)
     {
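
For reference, a minimal standalone sketch of the Extract-as-rotation idea described in the comment block above. It assumes a little-endian AArch64 toolchain (GCC/Clang) with <arm_neon.h>; the template wrapper mirrors the patch, while main() and the sample values are illustrative only, not part of the patch:

#include <arm_neon.h>
#include <cstdio>

// vextq_u32(val, val, S) rotates the four 32-bit lanes of the
// 128-bit vector by S elements (4*S bytes), matching the x86
// byte-shuffle described in the comment above.
template <unsigned int S>
inline uint32x4_t Extract(const uint32x4_t& val)
{
    return vextq_u32(val, val, S);
}

int main()
{
    const uint32x4_t x = {0, 1, 2, 3};  // lanes 0..3
    uint32_t out[4];
    vst1q_u32(out, Extract<1>(x));
    // A 4-byte rotation: lanes [3,2,1,0] => [0,3,2,1]
    std::printf("%u %u %u %u\n", out[0], out[1], out[2], out[3]);  // 1 2 3 0
    return 0;
}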
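
Likewise, a minimal sketch of the 64-bit counter addition that the new Add64 helper performs across pairs of 32-bit lanes, under the same toolchain assumptions; main() and the counter values are hypothetical. ChaCha keeps a 64-bit block counter in the low two lanes of the fourth state row, so reinterpreting the vector as uint64x2_t lets vaddq_u64 carry from lane 0 into lane 1:

#include <arm_neon.h>
#include <cstdio>

// Verbatim from the patch: 64-bit addition over reinterpreted lanes.
inline uint32x4_t Add64(const uint32x4_t& a, const uint32x4_t& b)
{
    return vreinterpretq_u32_u64(
        vaddq_u64(
            vreinterpretq_u64_u32(a),
            vreinterpretq_u64_u32(b)));
}

int main()
{
    // Counter at 0xFFFFFFFF: adding {1,0,0,0} must carry into lane 1,
    // which a plain vaddq_u32 would not do. Lanes 2 and 3 (the nonce
    // half of the row) are left unchanged.
    const uint32x4_t ctr  = {0xFFFFFFFFu, 0, 0x12345678u, 0x9ABCDEF0u};
    const uint32x4_t one  = {1, 0, 0, 0};
    const uint32x4_t next = Add64(ctr, one);

    uint32_t out[4];
    vst1q_u32(out, next);
    std::printf("%08x %08x %08x %08x\n", out[0], out[1], out[2], out[3]);
    // expected: 00000000 00000001 12345678 9abcdef0
    return 0;
}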