Optimize FFT (#766)

* Optimize FFT for real transforms

* Throw error if power is not specified

https://github.com/huggingface/transformers/issues/27772
This commit is contained in:
Joshua Lochner
2024-05-23 01:43:05 +02:00
committed by GitHub
parent 8963720585
commit 8d166ca642
2 changed files with 60 additions and 42 deletions
+7 -2
View File
@@ -473,6 +473,13 @@ export function spectrogram(
throw new Error("hop_length must be greater than zero");
}
if (power === null && mel_filters !== null) {
throw new Error(
"You have provided `mel_filters` but `power` is `None`. Mel spectrogram computation is not yet supported for complex-valued spectrogram. " +
"Specify `power` to fix this issue."
);
}
if (center) {
if (pad_mode !== 'reflect') {
throw new Error(`pad_mode="${pad_mode}" not implemented yet.`)
@@ -547,8 +554,6 @@ export function spectrogram(
magnitudes[i] = row;
}
// TODO what should happen if power is None?
// https://github.com/huggingface/transformers/issues/27772
if (power !== null && power !== 2) {
// slight optimization to not sqrt
const pow = 2 / power; // we use 2 since we already squared
+53 -40
View File
@@ -364,20 +364,6 @@ class P2FFT {
return res;
}
/**
* Completes the spectrum by adding its mirrored negative frequency components.
* @param {Float64Array} spectrum The input spectrum.
* @returns {void}
*/
completeSpectrum(spectrum) {
const size = this._csize;
const half = size >>> 1;
for (let i = 2; i < half; i += 2) {
spectrum[size - i] = spectrum[i];
spectrum[size - i + 1] = -spectrum[i + 1];
}
}
/**
* Performs a Fast Fourier Transform (FFT) on the given input data and stores the result in the output buffer.
*
@@ -466,6 +452,7 @@ class P2FFT {
}
// Loop through steps in decreasing order
const table = this.table;
for (step >>= 2; step >= 2; step >>= 2) {
len = (size / step) << 1;
const quarterLen = len >>> 2;
@@ -490,18 +477,18 @@ class P2FFT {
const Dr = out[D];
const Di = out[D + 1];
const tableBr = this.table[k];
const tableBi = inv * this.table[k + 1];
const tableBr = table[k];
const tableBi = inv * table[k + 1];
const MBr = Br * tableBr - Bi * tableBi;
const MBi = Br * tableBi + Bi * tableBr;
const tableCr = this.table[2 * k];
const tableCi = inv * this.table[2 * k + 1];
const tableCr = table[2 * k];
const tableCi = inv * table[2 * k + 1];
const MCr = Cr * tableCr - Ci * tableCi;
const MCi = Cr * tableCi + Ci * tableCr;
const tableDr = this.table[3 * k];
const tableDi = inv * this.table[3 * k + 1];
const tableDr = table[3 * k];
const tableDi = inv * table[3 * k + 1];
const MDr = Dr * tableDr - Di * tableDi;
const MDi = Dr * tableDi + Di * tableDr;
@@ -634,18 +621,18 @@ class P2FFT {
}
}
// TODO: Optimize once https://github.com/indutny/fft.js/issues/25 is fixed
// Loop through steps in decreasing order
const table = this.table;
for (step >>= 2; step >= 2; step >>= 2) {
len = (size / step) << 1;
const quarterLen = len >>> 2;
const halfLen = len >>> 1;
const quarterLen = halfLen >>> 1;
const hquarterLen = quarterLen >>> 1;
// Loop through offsets in the data
for (outOff = 0; outOff < size; outOff += len) {
// Full case
const limit = outOff + quarterLen - 1;
for (let i = outOff, k = 0; i < limit; i += 2, k += step) {
const A = i;
for (let i = 0, k = 0; i <= hquarterLen; i += 2, k += step) {
const A = outOff + i;
const B = A + quarterLen;
const C = B + quarterLen;
const D = C + quarterLen;
@@ -660,26 +647,30 @@ class P2FFT {
const Dr = out[D];
const Di = out[D + 1];
const tableBr = this.table[k];
const tableBi = inv * this.table[k + 1];
// Middle values
const MAr = Ar;
const MAi = Ai;
const tableBr = table[k];
const tableBi = inv * table[k + 1];
const MBr = Br * tableBr - Bi * tableBi;
const MBi = Br * tableBi + Bi * tableBr;
const tableCr = this.table[2 * k];
const tableCi = inv * this.table[2 * k + 1];
const tableCr = table[2 * k];
const tableCi = inv * table[2 * k + 1];
const MCr = Cr * tableCr - Ci * tableCi;
const MCi = Cr * tableCi + Ci * tableCr;
const tableDr = this.table[3 * k];
const tableDi = inv * this.table[3 * k + 1];
const tableDr = table[3 * k];
const tableDi = inv * table[3 * k + 1];
const MDr = Dr * tableDr - Di * tableDi;
const MDi = Dr * tableDi + Di * tableDr;
// Pre-Final values
const T0r = Ar + MCr;
const T0i = Ai + MCi;
const T1r = Ar - MCr;
const T1i = Ai - MCi;
const T0r = MAr + MCr;
const T0i = MAi + MCi;
const T1r = MAr - MCr;
const T1i = MAi - MCi;
const T2r = MBr + MDr;
const T2i = MBi + MDi;
const T3r = inv * (MBr - MDr);
@@ -690,13 +681,35 @@ class P2FFT {
out[A + 1] = T0i + T2i;
out[B] = T1r + T3i;
out[B + 1] = T1i - T3r;
out[C] = T0r - T2r;
out[C + 1] = T0i - T2i;
out[D] = T1r - T3i;
out[D + 1] = T1i + T3r;
// Output final middle point
if (i === 0) {
out[C] = T0r - T2r;
out[C + 1] = T0i - T2i;
continue;
}
// Do not overwrite ourselves
if (i === hquarterLen)
continue;
const SA = outOff + quarterLen - i;
const SB = outOff + halfLen - i;
out[SA] = T1r - inv * T3i;
out[SA + 1] = -T1i - inv * T3r;
out[SB] = T0r - inv * T2r;
out[SB + 1] = -T0i + inv * T2i;
}
}
}
// Complete the spectrum by adding its mirrored negative frequency components.
const half = size >>> 1;
for (let i = 2; i < half; i += 2) {
out[size - i] = out[i];
out[size - i + 1] = -out[i + 1];
}
}
/**