Implement numerically stable log_softmax() (#812)

* Implement numerically stable log_softmax()

* Add unit tests

* Update src/utils/maths.js

---------

Co-authored-by: Joshua Lochner <admin@xenova.com>
This commit is contained in:
Taha Yassine
2024-07-01 14:25:18 +01:00
committed by GitHub
parent fc34517091
commit 75f557b505
2 changed files with 31 additions and 5 deletions
+13 -4
View File
@@ -158,11 +158,20 @@ export function softmax(arr) {
* @returns {T} The resulting log_softmax array.
*/
export function log_softmax(arr) {
// Compute the softmax values
const softmaxArr = softmax(arr);
// Compute the maximum value in the array
const maxVal = max(arr)[0];
// Apply log formula to each element
const logSoftmaxArr = softmaxArr.map(x => Math.log(x));
// Compute the sum of the exponentials
let sumExps = 0;
for(let i = 0; i < arr.length; ++i) {
sumExps += Math.exp(arr[i] - maxVal);
}
// Compute the log of the sum
const logSum = Math.log(sumExps);
// Compute the softmax values
const logSoftmaxArr = arr.map(x => x - maxVal - logSum);
return /** @type {T} */(logSoftmaxArr);
}
+18 -1
View File
@@ -2,7 +2,7 @@
import { compare } from './test_utils.js';
import { getFile } from '../src/utils/hub.js';
import { FFT, medianFilter, bankers_round } from '../src/utils/maths.js';
import { FFT, medianFilter, bankers_round, log_softmax } from '../src/utils/maths.js';
const fft = (arr, complex = false) => {
@@ -136,4 +136,21 @@ describe('Mathematical operations', () => {
});
}
});
describe('log softmax', () => {
// Should match output of scipy log_softmax
it('should compute log softmax correctly for usual values', () => {
const input = [0, 1, 2, 3];
const expected = [-3.4401896985611953, -2.4401896985611953, -1.4401896985611953, -0.44018969856119533];
const output = log_softmax(input);
compare(output, expected, 1e-13);
});
it('should compute log softmax correctly for values with large differences', () => {
const input = [1000, 1];
const expected = [0, -999];
const output = log_softmax(input);
compare(output, expected, 1e-13);
});
});
});