mirror of
https://github.com/Mintplex-Labs/transformersjs-electron.git
synced 2026-07-01 14:10:59 -04:00
Implement numerically stable log_softmax() (#812)
* Implement numerically stable log_softmax() * Add unit tests * Update src/utils/maths.js --------- Co-authored-by: Joshua Lochner <admin@xenova.com>
This commit is contained in:
+13
-4
@@ -158,11 +158,20 @@ export function softmax(arr) {
|
||||
* @returns {T} The resulting log_softmax array.
|
||||
*/
|
||||
export function log_softmax(arr) {
|
||||
// Compute the softmax values
|
||||
const softmaxArr = softmax(arr);
|
||||
// Compute the maximum value in the array
|
||||
const maxVal = max(arr)[0];
|
||||
|
||||
// Apply log formula to each element
|
||||
const logSoftmaxArr = softmaxArr.map(x => Math.log(x));
|
||||
// Compute the sum of the exponentials
|
||||
let sumExps = 0;
|
||||
for(let i = 0; i < arr.length; ++i) {
|
||||
sumExps += Math.exp(arr[i] - maxVal);
|
||||
}
|
||||
|
||||
// Compute the log of the sum
|
||||
const logSum = Math.log(sumExps);
|
||||
|
||||
// Compute the softmax values
|
||||
const logSoftmaxArr = arr.map(x => x - maxVal - logSum);
|
||||
|
||||
return /** @type {T} */(logSoftmaxArr);
|
||||
}
|
||||
|
||||
+18
-1
@@ -2,7 +2,7 @@
|
||||
import { compare } from './test_utils.js';
|
||||
|
||||
import { getFile } from '../src/utils/hub.js';
|
||||
import { FFT, medianFilter, bankers_round } from '../src/utils/maths.js';
|
||||
import { FFT, medianFilter, bankers_round, log_softmax } from '../src/utils/maths.js';
|
||||
|
||||
|
||||
const fft = (arr, complex = false) => {
|
||||
@@ -136,4 +136,21 @@ describe('Mathematical operations', () => {
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
describe('log softmax', () => {
|
||||
// Should match output of scipy log_softmax
|
||||
it('should compute log softmax correctly for usual values', () => {
|
||||
const input = [0, 1, 2, 3];
|
||||
const expected = [-3.4401896985611953, -2.4401896985611953, -1.4401896985611953, -0.44018969856119533];
|
||||
const output = log_softmax(input);
|
||||
compare(output, expected, 1e-13);
|
||||
});
|
||||
|
||||
it('should compute log softmax correctly for values with large differences', () => {
|
||||
const input = [1000, 1];
|
||||
const expected = [0, -999];
|
||||
const output = log_softmax(input);
|
||||
compare(output, expected, 1e-13);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user