Bug 1841230 - Update gemmology to latest revision r=yury

This brings in performance improvements for the aarch64 (NEON) code paths.

Differential Revision: https://phabricator.services.mozilla.com/D182547
serge-sans-paille 2023-07-06 07:19:52 +00:00
parent faf7e1d3b4
commit 77748547a3
2 changed files with 66 additions and 79 deletions
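The aarch64 improvement comes from replacing a widening multiply followed by a pairwise add with a single fused widening multiply-accumulate (vmlal_high_s16 in the first hunk, and vmlal_high_s8 for the 8-bit madd). The following is a minimal scalar sketch of the int16 case, written for this note rather than taken from gemmology (the function names are illustrative): each output lane still holds the sum of exactly two products, only paired differently, and since the callers go on to reduce each sum register horizontally, the pairing does not matter.

#include <array>
#include <cstdint>

// Illustrative scalar model of madd(int16x8, int16x8) -> int32x4 on neon64.
// Old path: vpaddq_s32(vmull_s16(lo(x), lo(y)), vmull_high_s16(x, y))
//           -> lane i = x[2i]*y[2i] + x[2i+1]*y[2i+1]
// New path: vmlal_high_s16(vmull_s16(lo(x), lo(y)), x, y)
//           -> lane i = x[i]*y[i] + x[i+4]*y[i+4]
// Every product lands in exactly one lane either way, so any caller that
// later sums all four lanes sees the same total, with one instruction fewer.
std::array<int32_t, 4> madd_old(const int16_t x[8], const int16_t y[8]) {
  std::array<int32_t, 4> r{};
  for (int i = 0; i < 4; ++i)
    r[i] = int32_t(x[2 * i]) * y[2 * i] + int32_t(x[2 * i + 1]) * y[2 * i + 1];
  return r;
}

std::array<int32_t, 4> madd_new(const int16_t x[8], const int16_t y[8]) {
  std::array<int32_t, 4> r{};
  for (int i = 0; i < 4; ++i)
    r[i] = int32_t(x[i]) * y[i] + int32_t(x[i + 4]) * y[i + 4];
  return r;
}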

@@ -534,8 +534,7 @@ inline xsimd::batch<int32_t, Arch>
madd(xsimd::batch<int16_t, Arch> x, xsimd::batch<int16_t, Arch> y,
xsimd::kernel::requires_arch<xsimd::neon64>) {
int32x4_t low = vmull_s16(vget_low_s16(x), vget_low_s16(y));
int32x4_t high = vmull_high_s16(x, y);
return vpaddq_s32(low, high);
return vmlal_high_s16(low, x, y);
}
template <class Arch>
@@ -550,17 +549,37 @@ madd(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
return vqaddq_s16(vuzp1q_s16(tl, th), vuzp2q_s16(tl, th));
}
template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
xsimd::kernel::requires_arch<xsimd::neon64>) {
int16x8_t tl = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(x))),
vmovl_s8(vget_low_s8(y)));
int16x8_t th = vmulq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(x))),
vmovl_s8(vget_high_s8(y)));
int32x4_t pl = vpaddlq_s16(tl);
int32x4_t ph = vpaddlq_s16(th);
return vpaddq_s32(pl, ph);
}
template <class Arch>
inline xsimd::batch<int16_t, Arch>
madd(xsimd::batch<int8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
xsimd::kernel::requires_arch<xsimd::neon64>) {
int16x8_t low = vmull_s8(vget_low_s8(x), vget_low_s8(y));
int16x8_t high = vmull_high_s8(x, y);
return vpaddq_s16(low, high);
return vmlal_high_s8(low, x, y);
}
#endif
template <class Arch>
inline xsimd::batch<int32_t, Arch>
maddw(xsimd::batch<uint8_t, Arch> x, xsimd::batch<int8_t, Arch> y,
xsimd::kernel::requires_arch<xsimd::generic>) {
return madd(xsimd::batch<int16_t, Arch>(1), madd(x, y, Arch{}), Arch{});
}
} // namespace kernel
//
@@ -599,6 +618,11 @@ inline xsimd::batch<int16_t, Arch> madd(xsimd::batch<uint8_t, Arch> x,
xsimd::batch<int8_t, Arch> y) {
return kernel::madd(x, y, Arch{});
}
template <class Arch>
inline xsimd::batch<int32_t, Arch> maddw(xsimd::batch<uint8_t, Arch> x,
xsimd::batch<int8_t, Arch> y) {
return kernel::maddw(x, y, Arch{});
}
template <class Arch>
inline auto PermuteSummer(xsimd::batch<int32_t, Arch> pack0123,
@@ -1188,7 +1212,6 @@ void Engine<Arch>::Shift::Multiply(const uint8_t *A, const int8_t *B,
using batch8 = xsimd::batch<int8_t, Arch>;
using ubatch8 = xsimd::batch<uint8_t, Arch>;
using batch16 = xsimd::batch<int16_t, Arch>;
using batch32 = xsimd::batch<int32_t, Arch>;
const size_t simd_width = width / batch8::size;
@@ -1202,47 +1225,30 @@ void Engine<Arch>::Shift::Multiply(const uint8_t *A, const int8_t *B,
reinterpret_cast<const ubatch8 *>(A + A_rowidx * width);
/* These will be packed 16-bit integers containing sums for each row of B
multiplied by the row of A. Iterate over shared (inner) dimension.*/
size_t k = 0;
ubatch8 a = *(A_row + k);
batch16 sum0 = madd(a, *(B0_col + k * 8));
batch16 sum1 = madd(a, *(B0_col + k * 8 + 1));
batch16 sum2 = madd(a, *(B0_col + k * 8 + 2));
batch16 sum3 = madd(a, *(B0_col + k * 8 + 3));
batch16 sum4 = madd(a, *(B0_col + k * 8 + 4));
batch16 sum5 = madd(a, *(B0_col + k * 8 + 5));
batch16 sum6 = madd(a, *(B0_col + k * 8 + 6));
batch16 sum7 = madd(a, *(B0_col + k * 8 + 7));
/* Upcast to 32-bit and horizontally add. Seems a bit faster if this is
* declared here.*/
batch16 ones(1);
batch32 isum0 = madd(sum0, ones);
batch32 isum1 = madd(sum1, ones);
batch32 isum2 = madd(sum2, ones);
batch32 isum3 = madd(sum3, ones);
batch32 isum4 = madd(sum4, ones);
batch32 isum5 = madd(sum5, ones);
batch32 isum6 = madd(sum6, ones);
batch32 isum7 = madd(sum7, ones);
size_t k = 0;
ubatch8 a = *(A_row + k);
batch32 isum0 = maddw(a, *(B0_col + k * 8));
batch32 isum1 = maddw(a, *(B0_col + k * 8 + 1));
batch32 isum2 = maddw(a, *(B0_col + k * 8 + 2));
batch32 isum3 = maddw(a, *(B0_col + k * 8 + 3));
batch32 isum4 = maddw(a, *(B0_col + k * 8 + 4));
batch32 isum5 = maddw(a, *(B0_col + k * 8 + 5));
batch32 isum6 = maddw(a, *(B0_col + k * 8 + 6));
batch32 isum7 = maddw(a, *(B0_col + k * 8 + 7));
for (k = 1; k < simd_width; ++k) {
a = *(A_row + k);
/* Multiply 8-bit, horizontally add to packed 16-bit integers.*/
batch16 mult0 = madd(a, *(B0_col + k * 8));
batch16 mult1 = madd(a, *(B0_col + k * 8 + 1));
batch16 mult2 = madd(a, *(B0_col + k * 8 + 2));
batch16 mult3 = madd(a, *(B0_col + k * 8 + 3));
batch16 mult4 = madd(a, *(B0_col + k * 8 + 4));
batch16 mult5 = madd(a, *(B0_col + k * 8 + 5));
batch16 mult6 = madd(a, *(B0_col + k * 8 + 6));
batch16 mult7 = madd(a, *(B0_col + k * 8 + 7));
/* Upcast to 32-bit and horizontally add.*/
batch32 imult0 = madd(mult0, ones);
batch32 imult1 = madd(mult1, ones);
batch32 imult2 = madd(mult2, ones);
batch32 imult3 = madd(mult3, ones);
batch32 imult4 = madd(mult4, ones);
batch32 imult5 = madd(mult5, ones);
batch32 imult6 = madd(mult6, ones);
batch32 imult7 = madd(mult7, ones);
batch32 imult0 = maddw(a, *(B0_col + k * 8));
batch32 imult1 = maddw(a, *(B0_col + k * 8 + 1));
batch32 imult2 = maddw(a, *(B0_col + k * 8 + 2));
batch32 imult3 = maddw(a, *(B0_col + k * 8 + 3));
batch32 imult4 = maddw(a, *(B0_col + k * 8 + 4));
batch32 imult5 = maddw(a, *(B0_col + k * 8 + 5));
batch32 imult6 = maddw(a, *(B0_col + k * 8 + 6));
batch32 imult7 = maddw(a, *(B0_col + k * 8 + 7));
/*Add in 32bit*/
isum0 += imult0;
isum1 += imult1;
@@ -1268,7 +1274,6 @@ template <class Callback>
void Engine<Arch>::Shift::PrepareBias(const int8_t *B, size_t width,
size_t B_cols, Callback C) {
using batch8 = xsimd::batch<int8_t, Arch>;
using batch16 = xsimd::batch<int16_t, Arch>;
const size_t simd_width = width / batch8::size;
xsimd::batch<uint8_t, Arch> a(1);
for (size_t j = 0; j < B_cols; j += 8) {
@@ -1280,46 +1285,28 @@ void Engine<Arch>::Shift::PrepareBias(const int8_t *B, size_t width,
* first.*/
/* These will be packed 16-bit integers containing sums for each column of
* B multiplied by the row of A.*/
auto sum0 = madd(a, batch8::load_aligned(&B_j[0 * batch8::size]));
auto sum1 = madd(a, batch8::load_aligned(&B_j[1 * batch8::size]));
auto sum2 = madd(a, batch8::load_aligned(&B_j[2 * batch8::size]));
auto sum3 = madd(a, batch8::load_aligned(&B_j[3 * batch8::size]));
auto sum4 = madd(a, batch8::load_aligned(&B_j[4 * batch8::size]));
auto sum5 = madd(a, batch8::load_aligned(&B_j[5 * batch8::size]));
auto sum6 = madd(a, batch8::load_aligned(&B_j[6 * batch8::size]));
auto sum7 = madd(a, batch8::load_aligned(&B_j[7 * batch8::size]));
/* Upcast to 32-bit and horizontally add. Seems a bit faster if this is
* declared here.*/
auto isum0 = maddw(a, batch8::load_aligned(&B_j[0 * batch8::size]));
auto isum1 = maddw(a, batch8::load_aligned(&B_j[1 * batch8::size]));
auto isum2 = maddw(a, batch8::load_aligned(&B_j[2 * batch8::size]));
auto isum3 = maddw(a, batch8::load_aligned(&B_j[3 * batch8::size]));
auto isum4 = maddw(a, batch8::load_aligned(&B_j[4 * batch8::size]));
auto isum5 = maddw(a, batch8::load_aligned(&B_j[5 * batch8::size]));
auto isum6 = maddw(a, batch8::load_aligned(&B_j[6 * batch8::size]));
auto isum7 = maddw(a, batch8::load_aligned(&B_j[7 * batch8::size]));
B_j += 8 * batch8::size;
/* Upcast to 32-bit and horizontally add. Seems a bit faster if this is
* declared here.*/
batch16 ones(1);
auto isum0 = madd(sum0, ones);
auto isum1 = madd(sum1, ones);
auto isum2 = madd(sum2, ones);
auto isum3 = madd(sum3, ones);
auto isum4 = madd(sum4, ones);
auto isum5 = madd(sum5, ones);
auto isum6 = madd(sum6, ones);
auto isum7 = madd(sum7, ones);
for (size_t k = 1; k < simd_width; ++k, B_j += 8 * batch8::size) {
isum0 +=
madd(madd(a, batch8::load_aligned(&B_j[0 * batch8::size])), ones);
isum1 +=
madd(madd(a, batch8::load_aligned(&B_j[1 * batch8::size])), ones);
isum2 +=
madd(madd(a, batch8::load_aligned(&B_j[2 * batch8::size])), ones);
isum3 +=
madd(madd(a, batch8::load_aligned(&B_j[3 * batch8::size])), ones);
isum4 +=
madd(madd(a, batch8::load_aligned(&B_j[4 * batch8::size])), ones);
isum5 +=
madd(madd(a, batch8::load_aligned(&B_j[5 * batch8::size])), ones);
isum6 +=
madd(madd(a, batch8::load_aligned(&B_j[6 * batch8::size])), ones);
isum7 +=
madd(madd(a, batch8::load_aligned(&B_j[7 * batch8::size])), ones);
isum0 += maddw(a, batch8::load_aligned(&B_j[0 * batch8::size]));
isum1 += maddw(a, batch8::load_aligned(&B_j[1 * batch8::size]));
isum2 += maddw(a, batch8::load_aligned(&B_j[2 * batch8::size]));
isum3 += maddw(a, batch8::load_aligned(&B_j[3 * batch8::size]));
isum4 += maddw(a, batch8::load_aligned(&B_j[4 * batch8::size]));
isum5 += maddw(a, batch8::load_aligned(&B_j[5 * batch8::size]));
isum6 += maddw(a, batch8::load_aligned(&B_j[6 * batch8::size]));
isum7 += maddw(a, batch8::load_aligned(&B_j[7 * batch8::size]));
}
auto pack0123 = Pack0123(isum0, isum1, isum2, isum3);
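The restructuring of Shift::Multiply and Shift::PrepareBias above follows from the new maddw kernel: the old two-step pattern of madd(a, b) into 16-bit lanes followed by madd(sum, ones) to widen to 32-bit is fused into one call that goes from the 8-bit inputs straight to 32-bit sums, which is why the 16-bit intermediates (and the ones constant used to widen them) drop out of both functions. Below is a rough scalar model of what each 32-bit lane of maddw holds, written for this note rather than taken from gemmology (maddw_ref is an illustrative name), ignoring the 16-bit saturation that the generic fallback can still hit.

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative scalar model of maddw(uint8xN, int8xN) -> int32x(N/4):
// each 32-bit lane accumulates four adjacent unsigned*signed products,
// i.e. out[i] = sum of x[j]*y[j] for j in [4i, 4i+4).
// The neon64 kernel computes this exactly (widen to 16-bit, vpaddlq_s16,
// then vpaddq_s32); the generic fallback madd(ones, madd(x, y)) matches it
// except that its 16-bit intermediate may saturate.
std::vector<int32_t> maddw_ref(const std::vector<uint8_t> &x,
                               const std::vector<int8_t> &y) {
  std::vector<int32_t> out(x.size() / 4, 0);
  for (std::size_t j = 0; j < x.size(); ++j)
    out[j / 4] += int32_t(x[j]) * int32_t(y[j]);
  return out;
}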

@@ -10,8 +10,8 @@ origin:
url: https://github.com/mozilla/gemmology
release: 3c20b313a6c49ca674a830630c0ef5ea5663b91e (2023-03-29T14:50:11Z).
revision: 3c20b313a6c49ca674a830630c0ef5ea5663b91e
release: 6497d9d3a50f3de62ec902efe0946d221f1f9095 (2023-05-05T08:43:01Z).
revision: 6497d9d3a50f3de62ec902efe0946d221f1f9095
license: MIT