Bug 1905384 Implement Download Progress in Firefox's Local Inference Engine r=calixte,tarek

Differential Revision: https://phabricator.services.mozilla.com/D215635
This commit is contained in:
Aristide Tossou 2024-07-17 19:45:24 +00:00
parent 0150dcbf24
commit 36a89c5f0f
9 changed files with 1120 additions and 46 deletions

View File

@ -10,6 +10,7 @@ import { XPCOMUtils } from "resource://gre/modules/XPCOMUtils.sys.mjs";
/**
* @typedef {import("../content/Utils.sys.mjs").ProgressAndStatusCallbackParams} ProgressAndStatusCallbackParams
* @typedef {object} Lazy
* @property {typeof import("../../promiseworker/PromiseWorker.sys.mjs").BasePromiseWorker} BasePromiseWorker
* @property {typeof setTimeout} setTimeout
* @property {typeof clearTimeout} clearTimeout
@ -160,9 +161,10 @@ class EngineDispatcher {
* Any exception here will be bubbled up for the constructor to log.
*
* @param {PipelineOptions} pipelineOptions
* @param {?function(ProgressAndStatusCallbackParams):void} notificationsCallback The callback to call with notifications such as download progress status.
* @returns {Promise<Engine>}
*/
async initializeInferenceEngine(pipelineOptions) {
async initializeInferenceEngine(pipelineOptions, notificationsCallback) {
// Create the inference engine given the wasm runtime and the options.
const wasm = await this.mlEngineChild.getWasmArrayBuffer();
const inferenceOptions = await this.mlEngineChild.getInferenceOptions(
@ -171,7 +173,11 @@ class EngineDispatcher {
lazy.console.debug("Inference engine options:", inferenceOptions);
pipelineOptions.updateOptions(inferenceOptions);
return InferenceEngine.create(wasm, pipelineOptions);
return InferenceEngine.create({
wasm,
pipelineOptions,
notificationsCallback,
});
}
/**
@ -184,7 +190,12 @@ class EngineDispatcher {
this.#taskName = pipelineOptions.taskName;
this.timeoutMS = pipelineOptions.timeoutMS;
this.#engine = this.initializeInferenceEngine(pipelineOptions);
this.#engine = this.initializeInferenceEngine(
pipelineOptions,
notificationsData => {
this.handleInitProgressStatus(port, notificationsData);
}
);
// Trigger the keep alive timer.
this.#engine
@ -201,6 +212,13 @@ class EngineDispatcher {
this.setupMessageHandler(port);
}
handleInitProgressStatus(port, notificationsData) {
port.postMessage({
type: "EnginePort:InitProgress",
statusResponse: notificationsData,
});
}
/**
* The worker needs to be shutdown after some amount of time of not being used.
*/
@ -338,12 +356,14 @@ let modelHub = null; // This will hold the ModelHub instance to reuse it.
* then fetches the model file using the ModelHub API. The `modelHub` instance is created
* only once and reused for subsequent calls to optimize performance.
*
* @param {string} url - The URL of the model file to fetch. Can be a path relative to
* @param {object} config
* @param {string} config.url - The URL of the model file to fetch. Can be a path relative to
* the model hub root or an absolute URL.
* @param {?function(ProgressAndStatusCallbackParams):void} config.progressCallback The callback to call for notifying about download progress status.
* @returns {Promise} A promise that resolves to a Meta object containing the URL, response headers,
* and data as an ArrayBuffer. The data is marked for transfer to avoid cloning.
*/
async function getModelFile(url) {
async function getModelFile({ url, progressCallback }) {
// Create the model hub instance if needed
if (!modelHub) {
lazy.console.debug("Creating model hub instance");
@ -365,7 +385,10 @@ async function getModelFile(url) {
// if this errors out, it will be caught in the worker
const parsedUrl = modelHub.parseUrl(url);
let [data, headers] = await modelHub.getModelFileAsArrayBuffer(parsedUrl);
let [data, headers] = await modelHub.getModelFileAsArrayBuffer({
...parsedUrl,
progressCallback,
});
return new lazy.BasePromiseWorker.Meta([url, headers, data], {
transfers: [data],
});
@ -381,16 +404,22 @@ class InferenceEngine {
/**
* Initialize the worker.
*
* @param {ArrayBuffer} wasm
* @param {PipelineOptions} pipelineOptions
* @param {object} config
* @param {ArrayBuffer} config.wasm
* @param {PipelineOptions} config.pipelineOptions
* @param {?function(ProgressAndStatusCallbackParams):void} config.notificationsCallback The callback to call with notifications such as download progress status.
* @returns {InferenceEngine}
*/
static async create(wasm, pipelineOptions) {
static async create({ wasm, pipelineOptions, notificationsCallback }) {
/** @type {BasePromiseWorker} */
const worker = new lazy.BasePromiseWorker(
"chrome://global/content/ml/MLEngine.worker.mjs",
{ type: "module" },
{ getModelFile }
{
getModelFile: async url => {
return getModelFile({ url, progressCallback: notificationsCallback });
},
}
);
const args = [wasm, pipelineOptions];

View File

@ -4,6 +4,7 @@
/**
* @typedef {import("../content/Utils.sys.mjs").ProgressAndStatusCallbackParams} ProgressAndStatusCallbackParams
* @typedef {object} Lazy
* @property {typeof console} console
* @property {typeof import("../content/Utils.sys.mjs").getRuntimeWasmFilename} getRuntimeWasmFilename
* @property {typeof import("../content/EngineProcess.sys.mjs").EngineProcess} EngineProcess
@ -80,10 +81,15 @@ export class MLEngineParent extends JSWindowActorParent {
/** Creates a new MLEngine.
*
* @param {PipelineOptions} pipelineOptions
* @param {?function(ProgressAndStatusCallbackParams):void} notificationsCallback A function to call to indicate progress status.
* @returns {MLEngine}
*/
getEngine(pipelineOptions) {
return new MLEngine({ mlEngineParent: this, pipelineOptions });
getEngine(pipelineOptions, notificationsCallback = null) {
return new MLEngine({
mlEngineParent: this,
pipelineOptions,
notificationsCallback,
});
}
/** Extracts the task name from the name and validates it.
@ -315,14 +321,23 @@ class MLEngine {
*/
engineStatus = "uninitialized";
/**
* Callback to call when receiving an initializing progress status.
*
* @type {?function(ProgressAndStatusCallbackParams):void}
*/
notificationsCallback = null;
/**
* @param {object} config - The configuration object for the instance.
* @param {object} config.mlEngineParent - The parent machine learning engine associated with this instance.
* @param {object} config.pipelineOptions - The options for configuring the pipeline associated with this instance.
* @param {?function(ProgressAndStatusCallbackParams):void} config.notificationsCallback - The initialization progress callback function to call.
*/
constructor({ mlEngineParent, pipelineOptions }) {
constructor({ mlEngineParent, pipelineOptions, notificationsCallback }) {
this.mlEngineParent = mlEngineParent;
this.pipelineOptions = pipelineOptions;
this.notificationsCallback = notificationsCallback;
this.#setupPortCommunication();
}
@ -402,6 +417,10 @@ class MLEngine {
this.discardPort();
break;
}
case "EnginePort:InitProgress": {
this.notificationsCallback?.(data.statusResponse);
break;
}
default:
lazy.console.error("Unknown port message from engine", data);
break;

View File

@ -16,6 +16,7 @@ ChromeUtils.defineESModuleGetters(
/**
* @typedef {import("../actors/MLEngineParent.sys.mjs").MLEngineParent} MLEngineParent
* @typedef {import("../content/Utils.sys.mjs").ProgressAndStatusCallbackParams} ProgressAndStatusCallbackParams
*/
/**
@ -428,11 +429,12 @@ export class EngineProcess {
* Creates a new ML engine instance with the provided options.
*
* @param {object} options - Configuration options for the ML engine.
* @param {?function(ProgressAndStatusCallbackParams):void} notificationsCallback A function to call to indicate notifications.
* @returns {Promise<MLEngine>} - A promise that resolves to the ML engine instance.
*
*/
export async function createEngine(options) {
export async function createEngine(options, notificationsCallback = null) {
const pipelineOptions = new PipelineOptions(options);
const engineParent = await EngineProcess.getMLEngineParent();
return engineParent.getEngine(pipelineOptions);
return engineParent.getEngine(pipelineOptions, notificationsCallback);
}

View File

@ -1,11 +1,17 @@
/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
/**
* @typedef {import("./Utils.sys.mjs").ProgressAndStatusCallbackParams} ProgressAndStatusCallbackParams
*/
const lazy = {};
ChromeUtils.defineESModuleGetters(lazy, {
clearTimeout: "resource://gre/modules/Timer.sys.mjs",
setTimeout: "resource://gre/modules/Timer.sys.mjs",
Progress: "chrome://global/content/ml/Utils.sys.mjs",
});
ChromeUtils.defineLazyGetter(lazy, "console", () => {
@ -22,7 +28,6 @@ const ALLOWED_HUBS = [
"https://localhost",
"https://model-hub.mozilla.org",
];
const ALLOWED_HEADERS_KEYS = ["Content-Type", "ETag", "status"];
const DEFAULT_URL_TEMPLATE = "{model}/resolve/{revision}";
@ -640,25 +645,7 @@ export class ModelHub {
}
/**
* Given an organization, model, and version, fetch a model file in the hub as an ArrayBuffer.
*
* @param {object} config
* @param {string} config.model
* @param {string} config.revision
* @param {string} config.file
* @returns {Promise<[ArrayBuffer, headers]>} The file content
*/
async getModelFileAsArrayBuffer({ model, revision, file }) {
const [blob, headers] = await this.getModelFileAsBlob({
model,
revision,
file,
});
return [await blob.arrayBuffer(), headers];
}
/**
* Given an organization, model, and version, fetch a model file in the hub as blob.
* Given an organization, model, and version, fetch a model file in the hub as a blob.
*
* @param {object} config
* @param {string} config.model
@ -667,6 +654,26 @@ export class ModelHub {
* @returns {Promise<[Blob, object]>} The file content
*/
async getModelFileAsBlob({ model, revision, file }) {
const [buffer, headers] = await this.getModelFileAsArrayBuffer({
model,
revision,
file,
});
return [new Blob([buffer]), headers];
}
/**
* Given an organization, model, and version, fetch a model file in the hub as an ArrayBuffer
* while supporting status callback.
*
* @param {object} config
* @param {string} config.model
* @param {string} config.revision
* @param {string} config.file
* @param {?function(ProgressAndStatusCallbackParams):void} config.progressCallback A function to call to indicate progress status.
* @returns {Promise<[ArrayBuffer, headers]>} The file content
*/
async getModelFileAsArrayBuffer({ model, revision, file, progressCallback }) {
// Make sure inputs are clean. We don't sanitize them but throw an exception
let checkError = this.#checkInput(model, revision, file);
if (checkError) {
@ -696,16 +703,77 @@ export class ModelHub {
useCached = await this.cache.fileExists(model, revision, file);
}
const progressInfo = {
progress: null,
totalLoaded: null,
currentLoaded: null,
total: null,
};
const statusInfo = {
metadata: { model, revision, file, url },
ok: true,
id: url,
};
if (useCached) {
lazy.console.debug(`Cache Hit for ${url}`);
return await this.cache.getFile(model, revision, file);
progressCallback?.(
new lazy.Progress.ProgressAndStatusCallbackParams({
...statusInfo,
...progressInfo,
type: lazy.Progress.ProgressType.LOAD_FROM_CACHE,
statusText: lazy.Progress.ProgressStatusText.INITIATE,
})
);
const [blob, headers] = await this.cache.getFile(model, revision, file);
progressCallback?.(
new lazy.Progress.ProgressAndStatusCallbackParams({
...statusInfo,
...progressInfo,
type: lazy.Progress.ProgressType.LOAD_FROM_CACHE,
statusText: lazy.Progress.ProgressStatusText.DONE,
})
);
return [await blob.arrayBuffer(), headers];
}
progressCallback?.(
new lazy.Progress.ProgressAndStatusCallbackParams({
...statusInfo,
...progressInfo,
type: lazy.Progress.ProgressType.DOWNLOAD,
statusText: lazy.Progress.ProgressStatusText.INITIATE,
})
);
lazy.console.debug(`Fetching ${url}`);
try {
const response = await fetch(url);
let response = await fetch(url);
let isFirstCall = true;
let responseContentArray = await lazy.Progress.readResponse(
response,
progressData => {
progressCallback?.(
new lazy.Progress.ProgressAndStatusCallbackParams({
...progressInfo,
...progressData,
statusText: isFirstCall
? lazy.Progress.ProgressStatusText.SIZE_ESTIMATE
: lazy.Progress.ProgressStatusText.IN_PROGRESS,
type: lazy.Progress.ProgressType.DOWNLOAD,
...statusInfo,
})
);
isFirstCall = false;
}
);
let responseContent = responseContentArray.buffer.slice(
responseContentArray.byteOffset,
responseContentArray.byteLength + responseContentArray.byteOffset
);
if (response.ok) {
const clone = response.clone();
const headers = {
// We don't store the boundary or the charset, just the content type,
// so we drop what's after the semicolon.
@ -717,15 +785,36 @@ export class ModelHub {
model,
revision,
file,
await clone.blob(),
new Blob([responseContent]),
headers
);
return [await response.blob(), headers];
progressCallback?.(
new lazy.Progress.ProgressAndStatusCallbackParams({
...statusInfo,
...progressInfo,
type: lazy.Progress.ProgressType.DOWNLOAD,
statusText: lazy.Progress.ProgressStatusText.DONE,
})
);
return [responseContent, headers];
}
} catch (error) {
lazy.console.error(`Failed to fetch ${url}:`, error);
}
// Indicate there is an error
progressCallback?.(
new lazy.Progress.ProgressAndStatusCallbackParams({
...statusInfo,
...progressInfo,
type: lazy.Progress.ProgressType.DOWNLOAD,
statusText: lazy.Progress.ProgressStatusText.DONE,
ok: false,
})
);
throw new Error(`Failed to fetch the model file: ${url}`);
}
}

View File

@ -75,3 +75,321 @@ export function getRuntimeWasmFilename(browsingContext = null) {
cachedRuntimeWasmFilename = res;
return res;
}
/**
 * Enumeration of the values a notification's `statusText` field can take.
 * The object is frozen, so the set of statuses cannot be altered at runtime.
 */
export const ProgressStatusText = Object.freeze({
// An operation has just started.
INITIATE: "initiate",
// An estimate of the operation's total size is being reported.
SIZE_ESTIMATE: "size_estimate",
// The operation is under way and some amount of data has been loaded.
IN_PROGRESS: "in_progress",
// The operation has completed.
DONE: "done",
});
/**
 * Enumeration of the kinds of operation a notification's `type` field can describe.
 * The object is frozen, so the set of types cannot be altered at runtime.
 */
export const ProgressType = Object.freeze({
// The data is being downloaded from a remote location.
DOWNLOAD: "downloading",
// The data is being loaded from the cache.
LOAD_FROM_CACHE: "loading_from_cache",
});
/**
* This class encapsulates the parameters supported by a progress and status callback.
*/
export class ProgressAndStatusCallbackParams {
  // ---- Progress-related fields ----

  /**
   * Percentage of data loaded, as a float. Reaching 100% does not
   * necessarily mean that the operation is complete.
   *
   * @type {?float}
   */
  progress = null;

  /**
   * Total amount of data loaded so far, i.e. the sum of `currentLoaded`
   * over every call of the callback.
   *
   * @type {?float}
   */
  totalLoaded = null;

  /**
   * Amount of data loaded during the current call of the callback.
   *
   * @type {?float}
   */
  currentLoaded = null;

  /**
   * Estimate of the total amount of data to load. This is only an estimate:
   * the true total may turn out to be lower or higher.
   *
   * @type {?float}
   */
  total = null;

  /**
   * Units in which the amounts above are expressed.
   *
   * @type {?string}
   */
  units = null;

  // ---- Status-related fields ----

  /**
   * Name of the operation being tracked.
   *
   * @type {?string}
   */
  type = null;

  /**
   * Message describing the status of the tracked operation.
   *
   * @type {?string}
   */
  statusText = null;

  /**
   * Identifier that uniquely designates the object/file being tracked.
   *
   * @type {?string}
   */
  id = null;

  /**
   * Whether the operation was successful (true means success).
   *
   * @type {?boolean}
   */
  ok = null;

  /**
   * Any extra metadata attached to the tracked operation.
   *
   * @type {?object}
   */
  metadata = null;

  /**
   * @param {object} params - Initial values, keyed by field name.
   * @throws {Error} If `params` has a key that is not one of the fields above.
   */
  constructor(params = {}) {
    this.update(params);
  }

  /**
   * Overwrite the fields named in `params`, leaving the others untouched.
   *
   * @param {object} params - New values, keyed by field name.
   * @throws {Error} If `params` has a key that is not one of the fields above.
   */
  update(params = {}) {
    // The declared class fields are own enumerable properties, so they
    // double as the list of accepted option names.
    const knownFields = Object.keys(this);

    const unknownFields = [];
    for (const name of Object.keys(params)) {
      if (!knownFields.includes(name)) {
        unknownFields.push(name);
      }
    }
    if (unknownFields.length) {
      throw new Error(`Received Invalid option: ${unknownFields}`);
    }

    knownFields.forEach(field => {
      if (field in params) {
        this[field] = params[field];
      }
    });
  }
}
/**
 * Read the whole body of a Response object into a single Uint8Array while
 * reporting progress.
 *
 * The `Content-Length` header is only used as a hint to pre-size the buffer
 * and to report an estimated total. A missing, unparsable or negative value is
 * treated as "unknown" (0), and the buffer is grown or trimmed as needed so the
 * returned array always holds exactly the bytes that were received.
 *
 * The callback is invoked once before any data is read (with the estimated
 * total and nothing loaded), then once per chunk read from the body.
 *
 * @param {Response} response The Response object to read
 * @param {?function(ProgressAndStatusCallbackParams):void} progressCallback The function to call with progress updates
 *
 * @returns {Promise<Uint8Array>} A Promise that resolves with the Uint8Array buffer
 */
export async function readResponse(response, progressCallback) {
  const contentLength = response.headers.get("Content-Length");
  const advertisedLength = Number.parseInt(contentLength ?? "0", 10);
  const isLengthUsable =
    Number.isFinite(advertisedLength) && advertisedLength >= 0;

  if (!contentLength || !isLengthUsable) {
    console.warn(
      "Unable to determine content-length from response headers. Will expand buffer when needed."
    );
  }

  // Never let NaN or a negative number reach the typed-array constructor or
  // the overflow comparison below: NaN would silently disable buffer growth
  // and a negative length would throw.
  let total = isLengthUsable ? advertisedLength : 0;

  progressCallback?.(
    new ProgressAndStatusCallbackParams({
      progress: 0,
      totalLoaded: 0,
      currentLoaded: 0,
      total,
      units: "bytes",
    })
  );

  let buffer = new Uint8Array(total);
  let loaded = 0;

  for await (const value of response.body) {
    const newLoaded = loaded + value.length;

    if (newLoaded > total) {
      // The content length was missing or lower than the actual length:
      // the best estimate of the total is now what has been received.
      total = newLoaded;
    }

    if (newLoaded > buffer.length) {
      // Adding the new data would overflow the buffer. Grow it geometrically
      // so that a stream without a usable content length does not re-copy
      // everything received so far on every single chunk.
      const newBuffer = new Uint8Array(Math.max(newLoaded, buffer.length * 2));
      newBuffer.set(buffer.subarray(0, loaded));
      buffer = newBuffer;
    }

    buffer.set(value, loaded);
    loaded = newLoaded;

    // total >= loaded here, so total is only 0 when nothing was received yet.
    const progress = total > 0 ? (loaded / total) * 100 : 0;

    progressCallback?.(
      new ProgressAndStatusCallbackParams({
        progress,
        totalLoaded: loaded,
        currentLoaded: value.length,
        total,
        units: "bytes",
      })
    );
  }

  // The buffer can be larger than what was received (over-reported content
  // length, or spare capacity from growing). Only copy when trimming is needed.
  return loaded === buffer.length ? buffer : buffer.slice(0, loaded);
}
/**
 * Class for watching the progress of multiple operations and combining
 * them into a single progress report.
 *
 * Feed every notification to `aggregateCallback`; for each notification whose
 * type is watched, `progressCallback` is called with the combined statistics.
 */
export class MultiProgressAggregator {
  /**
   * A function to call with the aggregated statistics.
   *
   * @type {?function(ProgressAndStatusCallbackParams):void}
   */
  progressCallback = null;

  /**
   * The operation types (values of the `type` field) that are aggregated.
   * Notifications of any other type are ignored.
   *
   * @type {Set<string>}
   */
  watchedTypes;

  /**
   * The total amount of data loaded so far, across all watched operations.
   *
   * @type {float}
   */
  #combinedLoaded = 0;

  /**
   * The total amount of data to be loaded, across all watched operations.
   *
   * @type {float}
   */
  #combinedTotal = 0;

  /**
   * The number of operations that were initiated but are not done yet.
   *
   * @type {number}
   */
  #remainingEvents = 0;

  /**
   * The types of operation seen so far.
   *
   * @type {Set<string>}
   */
  #seenTypes;

  /**
   * The status texts seen so far.
   *
   * @type {Set<string>}
   */
  #seenStatus;

  /**
   * @param {object} config
   * @param {?function(ProgressAndStatusCallbackParams):void} config.progressCallback - A function to call with the aggregated statistics.
   * @param {Iterable<string>} config.watchedTypes - The types to watch for aggregation
   */
  constructor({ progressCallback, watchedTypes = [ProgressType.DOWNLOAD] }) {
    this.progressCallback = progressCallback;
    this.watchedTypes = new Set(watchedTypes);
    this.#seenTypes = new Set();
    this.#seenStatus = new Set();
  }

  /**
   * Callback function that combines data from different objects/files.
   *
   * @param {ProgressAndStatusCallbackParams} data - object containing the data
   */
  aggregateCallback(data) {
    if (!this.watchedTypes.has(data.type)) {
      return;
    }

    this.#seenTypes.add(data.type);
    this.#seenStatus.add(data.statusText);

    // INITIATE and DONE bracket an operation; SIZE_ESTIMATE carries its size.
    if (data.statusText == ProgressStatusText.INITIATE) {
      this.#remainingEvents += 1;
    }
    if (data.statusText == ProgressStatusText.SIZE_ESTIMATE) {
      this.#combinedTotal += data.total ?? 0;
    }
    if (data.statusText == ProgressStatusText.DONE) {
      this.#remainingEvents -= 1;
    }

    this.#combinedLoaded += data.currentLoaded ?? 0;

    if (!this.progressCallback) {
      return;
    }

    // Once any operation reported progress, the aggregate is "in progress"
    // regardless of the status of the current notification...
    let statusText = data.statusText;
    if (this.#seenStatus.has(ProgressStatusText.IN_PROGRESS)) {
      statusText = ProgressStatusText.IN_PROGRESS;
    }
    // ...until no initiated operation is left pending.
    if (this.#remainingEvents == 0) {
      statusText = ProgressStatusText.DONE;
    }

    // Until a size estimate has been received the combined total is 0 and a
    // percentage cannot be computed: report it as unknown (null) rather than
    // as NaN or Infinity.
    const progress =
      this.#combinedTotal > 0
        ? (this.#combinedLoaded / this.#combinedTotal) * 100
        : null;

    this.progressCallback(
      new ProgressAndStatusCallbackParams({
        type: data.type,
        statusText,
        id: data.id,
        total: this.#combinedTotal,
        currentLoaded: data.currentLoaded,
        totalLoaded: this.#combinedLoaded,
        progress,
        ok: data.ok,
        units: data.units,
        // The notification that triggered this report is passed along as-is.
        metadata: data,
      })
    );
  }
}
// Create a "namespace" to make it easier to import multiple names.
// Consumers can then import the single `Progress` name and reach the helpers
// through it (e.g. `Progress.readResponse`).
// `var` is hoisted, so `Progress` is undefined when the right-hand side is first
// evaluated and `Progress || {}` yields a fresh empty object.
// NOTE(review): MultiProgressAggregator is exported by this module but is not
// attached to the namespace — confirm whether that is intentional.
export var Progress = Progress || {};
Progress.ProgressAndStatusCallbackParams = ProgressAndStatusCallbackParams;
Progress.ProgressStatusText = ProgressStatusText;
Progress.ProgressType = ProgressType;
Progress.readResponse = readResponse;

View File

@ -129,3 +129,50 @@ In the example below, an image is converted to text using the `moz-image-to-text
The following internal tasks are supported by the machine learning engine:
.. js:autofunction:: imageToText
Notifications callback
::::::::::::::::::::::
When initializing or running the engine, certain operations may take considerable time to complete.
You can receive progress notifications for these operations using a callback function.
Currently, progress notifications are supported only for model downloads.
When the engine is created, it will download any model not already in the cache.
Below is an example of using the callback function with the image-to-text model:
.. code-block:: javascript
const { createEngine } = ChromeUtils.importESModule("chrome://global/content/ml/EngineProcess.sys.mjs");
// options needed for the task
const options = {taskName: "moz-image-to-text" };
// We create the engine object, using options and a callback
const engine = await createEngine(options, progressData => {
console.log("Received progress data", progressData);
});
In the code above, **progressData** is an object of type `ProgressAndStatusCallbackParams` containing the following fields:
- **progress**: A float indicating the percentage of data loaded. Note that 100% does not necessarily mean the operation is complete.
- **totalLoaded**: A float indicating the total amount of data loaded so far.
- **currentLoaded**: The amount of data loaded in the current callback call.
- **total**: A float indicating an estimate of the total amount of data to be loaded.
- **units**: The units in which the amounts are reported.
- **type**: The name of the operation being tracked. It will be one of `ProgressType.DOWNLOAD`, `ProgressType.LOAD_FROM_CACHE`.
- **statusText**: A message indicating the status of the tracked operation, which can be:
- `ProgressStatusText.INITIATE` Indicates that an operation has started. This will be used exactly once for each operation uniquely identified by `id` and `type`.
- `ProgressStatusText.SIZE_ESTIMATE` Indicates an estimate for the size of the operation. This will be used exactly once for each operation uniquely identified by `id` and `type`, updating the `total` field.
- `ProgressStatusText.IN_PROGRESS` Indicates that an operation is in progress. This will be used each time progress occurs, updating the `totalLoaded` and `currentLoaded` fields.
- `ProgressStatusText.DONE` Indicates that an operation has completed.
- **id**: An ID uniquely identifying the object/file being tracked.
- **ok**: A boolean indicating if the operation was successful.
- **metadata**: Any additional metadata for the operation being tracked.

View File

@ -6,6 +6,10 @@ const { sinon } = ChromeUtils.importESModule(
"resource://testing-common/Sinon.sys.mjs"
);
const { ProgressStatusText, ProgressType } = ChromeUtils.importESModule(
"chrome://global/content/ml/Utils.sys.mjs"
);
// Root URL of the fake hub, see the `data` dir in the tests.
const FAKE_HUB =
"chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data";
@ -270,6 +274,157 @@ add_task(async function test_getting_file_from_cache() {
Assert.deepEqual(array, array2);
});
/**
* Test that the callback is appropriately called when the data is retrieved from the server
* or from the cache.
*/
add_task(async function test_getting_file_from_url_cache_with_callback() {
  const hub = new ModelHub({ rootUrl: FAKE_HUB });
  hub.cache = await initializeCache();

  // Reduce a notification to the fields this test checks.
  const summarize = data => ({
    type: data?.type,
    statusText: data?.statusText,
    ok: data?.ok,
    model: data?.metadata?.model,
    file: data?.metadata?.file,
    revision: data?.metadata?.revision,
  });
  // Build the summary expected for a successful notification.
  const expectedSummary = (type, statusText) => ({
    type,
    statusText,
    ok: true,
    ...FAKE_MODEL_ARGS,
  });

  let numCalls = 0;
  let currentData = null;

  // First retrieval: nothing is cached yet, so the file comes from the server.
  const array = await hub.getModelFileAsArrayBuffer({
    ...FAKE_MODEL_ARGS,
    progressCallback: data => {
      currentData = data;
      switch (numCalls) {
        case 0:
          // The very first notification announces the download.
          Assert.deepEqual(
            summarize(data),
            expectedSummary(ProgressType.DOWNLOAD, ProgressStatusText.INITIATE),
            "Initiate Data from server should be correct"
          );
          break;
        case 1:
          // The second one carries the size estimate.
          Assert.deepEqual(
            summarize(data),
            expectedSummary(
              ProgressType.DOWNLOAD,
              ProgressStatusText.SIZE_ESTIMATE
            ),
            "size estimate Data from server should be correct"
          );
          break;
      }
      numCalls += 1;
    },
  });

  Assert.greaterOrEqual(numCalls, 3);

  // The last notification received must be DONE.
  Assert.deepEqual(
    summarize(currentData),
    expectedSummary(ProgressType.DOWNLOAD, ProgressStatusText.DONE),
    "Done Data from server should be correct"
  );

  // Wrap the cache getter to verify that the data was retrieved from IndexDB.
  const originalGetData = hub.cache._testGetData;
  sinon.stub(hub.cache, "_testGetData").callsFake(function (...args) {
    return originalGetData.apply(this, args).then(result => {
      Assert.notEqual(result, null);
      return result;
    });
  });

  numCalls = 0;
  currentData = null;

  // Second retrieval: the callback is now expected to report cache usage.
  const array2 = await hub.getModelFileAsArrayBuffer({
    ...FAKE_MODEL_ARGS,
    progressCallback: data => {
      currentData = data;
      if (numCalls == 0) {
        Assert.deepEqual(
          summarize(data),
          expectedSummary(
            ProgressType.LOAD_FROM_CACHE,
            ProgressStatusText.INITIATE
          ),
          "Initiate Data from cache should be correct"
        );
      }
      numCalls += 1;
    },
  });

  hub.cache._testGetData.restore();

  Assert.deepEqual(array, array2);

  // The last notification received must be DONE.
  Assert.deepEqual(
    summarize(currentData),
    expectedSummary(ProgressType.LOAD_FROM_CACHE, ProgressStatusText.DONE),
    "Done Data from cache should be correct"
  );

  await deleteCache(hub.cache);
});
/**
* Test parsing of a well-formed full URL, including protocol and path.
*/

View File

@ -2,9 +2,13 @@
http://creativecommons.org/publicdomain/zero/1.0/ */
"use strict";
const { arrayBufferToBlobURL } = ChromeUtils.importESModule(
"chrome://global/content/ml/Utils.sys.mjs"
);
const {
arrayBufferToBlobURL,
MultiProgressAggregator,
ProgressAndStatusCallbackParams,
ProgressStatusText,
readResponse,
} = ChromeUtils.importESModule("chrome://global/content/ml/Utils.sys.mjs");
/**
* Test arrayBufferToBlobURL function.
@ -24,3 +28,392 @@ add_task(async function test_ml_utils_array_buffer_to_blob_url() {
"The returned string should be a Blob URL"
);
});
/**
* Test that we can retrieve the correct content without a callback.
*/
add_task(async function test_correct_response_no_callback() {
  const content = "This is the expected response.";
  const body = new Blob([content]);
  const headers = new Headers({ "Content-Length": body.size });

  // No progress callback is supplied here.
  const bytes = await readResponse(new Response(body, { headers }));

  Assert.equal(
    content,
    new TextDecoder().decode(bytes),
    "The response content should match."
  );
});
/**
* Test that we can retrieve the correct content with a callback.
*/
add_task(async function test_correct_response_callback() {
  const content = "This is the expected response.";
  const body = new Blob([content]);
  const headers = new Headers({ "Content-Length": body.size });

  // The callback deliberately does nothing: only the returned content is checked.
  const ignoreProgress = _data => {};
  const bytes = await readResponse(
    new Response(body, { headers }),
    ignoreProgress
  );

  Assert.equal(
    content,
    new TextDecoder().decode(bytes),
    "The response content should match."
  );
});
/**
* Test that we can retrieve the correct content with a Content-Length lower than the actual length
*/
add_task(async function test_correct_response_content_length_under_reported() {
  const content = "This is the expected response.";
  const body = new Blob([content]);
  // Advertise a single byte although the body is longer.
  const headers = new Headers({ "Content-Length": 1 });

  // The callback deliberately does nothing: only the returned content is checked.
  const ignoreProgress = _data => {};
  const bytes = await readResponse(
    new Response(body, { headers }),
    ignoreProgress
  );

  Assert.equal(
    content,
    new TextDecoder().decode(bytes),
    "The response content should match."
  );
});
/**
* Test that we can retrieve the correct content with a Content-Length larger than the actual length
*/
add_task(async function test_correct_response_content_length_over_reported() {
  const content = "This is the expected response.";
  const body = new Blob([content]);
  // Advertise more than twice the real size of the body.
  const advertisedLength = 2 * body.size + 20;
  const headers = new Headers({ "Content-Length": advertisedLength });

  // The callback deliberately does nothing: only the returned content is checked.
  const ignoreProgress = _data => {};
  const bytes = await readResponse(
    new Response(body, { headers }),
    ignoreProgress
  );

  Assert.equal(
    content,
    new TextDecoder().decode(bytes),
    "The response content should match."
  );
});
/**
* Test that we can retrieve and the callback provide correct information
*/
add_task(async function test_correct_response_callback_correct() {
  const contents = ["Carrot", "Broccoli", "Tomato", "Spinach"];
  const contentSizes = contents.map(text => new Blob([text]).size);
  const totalSize = contentSizes.reduce((sum, size) => sum + size, 0);
  const numChunks = contents.length;
  const encoder = new TextEncoder();

  // Stream that hands out one entry of `contents` per pull, then closes.
  let nextChunk = 0;
  const stream = new ReadableStream({
    pull(controller) {
      if (nextChunk < numChunks) {
        controller.enqueue(encoder.encode(contents[nextChunk]));
        nextChunk += 1;
      } else {
        controller.close();
      }
    },
  });

  const headers = new Headers({ "Content-Length": totalSize });
  const response = new Response(stream, { headers });

  // Pick the fields of a notification that are compared below.
  const amounts = data => ({
    total: data.total,
    currentLoaded: data.currentLoaded,
    totalLoaded: data.totalLoaded,
  });

  let callCount = 0;
  let expectedTotalLoaded = 0;

  const responseArray = await readResponse(response, data => {
    const callIndex = callCount;
    callCount += 1;

    if (callIndex == 0) {
      // First call: only the total is known, nothing has been read yet.
      Assert.deepEqual(
        amounts(data),
        { total: totalSize, currentLoaded: 0, totalLoaded: 0 },
        "The callback should be called on time with an estimate of the total size and no data read. "
      );
      return;
    }

    // Every later call corresponds to one chunk, in order.
    const chunkIndex = callIndex - 1;
    Assert.less(
      chunkIndex,
      numChunks,
      "The number of times the callback is called should be lower than the number of chunks"
    );

    expectedTotalLoaded += contentSizes[chunkIndex];

    Assert.deepEqual(
      amounts(data),
      {
        total: totalSize,
        currentLoaded: contentSizes[chunkIndex],
        totalLoaded: expectedTotalLoaded,
      },
      "The reported value by the callback should match the correct values"
    );
  });

  // One initial call plus one call per chunk.
  Assert.equal(
    callCount - 1,
    numChunks,
    "The callback should be called exactly as many times as the number of chunks."
  );

  const start = responseArray.byteOffset;
  const end = responseArray.byteLength + responseArray.byteOffset;
  const responseContent = new TextDecoder().decode(
    responseArray.buffer.slice(start, end)
  );

  Assert.equal(
    contents.join(""),
    responseContent,
    "The response content should match."
  );
});
/**
* Test that multi-aggregator only call the callback for the provided types.
*/
add_task(async function test_multi_aggregator_watchtypes() {
  let numCalls = 0;
  const countCall = _ => {
    numCalls += 1;
  };

  const aggregator = new MultiProgressAggregator({
    progressCallback: countCall,
    watchedTypes: ["t1"],
  });

  // A type that is not watched must not reach the callback.
  const unwatched = new ProgressAndStatusCallbackParams({ type: "download" });
  aggregator.aggregateCallback(unwatched);
  Assert.equal(numCalls, 0);

  // A watched type must reach it.
  const watched = new ProgressAndStatusCallbackParams({ type: "t1" });
  aggregator.aggregateCallback(watched);
  Assert.equal(numCalls, 1);
});
/**
 * Test that the multi-aggregator aggregates correctly.
 *
 * Six tasks of three different types are announced (INITIATE followed by
 * SIZE_ESTIMATE), then their chunks are reported in an interleaved order
 * (IN_PROGRESS), each task being marked DONE right after its last chunk.
 * After every notification the data forwarded to the progress callback is
 * compared with the running totals maintained by the test.
 */
add_task(async function test_multi_aggregator() {
  // Ids for all available tasks. Should be unique per element.
  const taskIds = ["A", "B", "C", "D", "E", "F"];

  // The type for each available tasks.
  const taskTypes = ["t1", "t1", "t2", "t2", "t3", "t3"];

  // The total size available for each task
  const taskSizes = [5, 11, 13, 17, 19, 23];

  // The chunk sizes. The sum for indices with same chunk task index (according to chunkTaskIndex)
  // should be equal to the corresponding size in taskSizes
  const chunkSizes = [2, 3, 5, 6, 11, 7, 12, 6, 8, 9, 9, 10];

  // Task index for each chunk. Index in the array taskIds. Order was chosen so that we can simulate
  // overlaps in tasks.
  const chunkTaskIndex = [0, 0, 1, 2, 5, 2, 5, 1, 3, 4, 3, 4];

  // Indicating how much has been loaded for the task so far.
  const chunkTaskLoaded = [2, 5, 5, 6, 11, 13, 23, 11, 8, 9, 17, 19];

  // Whether the chunk is the last one of its task (1) or not (0).
  const chunkIsFinal = [0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1];

  let numDone = 0;
  let currentData = null;
  let expectedTotalToLoad = 0;
  let numCalls = 0;
  let expectedNumCalls = 0;
  let expectedTotalLoaded = 0;

  const aggregator = new MultiProgressAggregator({
    progressCallback: data => {
      currentData = data;
      numCalls += 1;

      if (data.statusText == ProgressStatusText.DONE) {
        numDone += 1;
      }
    },
    watchedTypes: ["t1", "t2", "t3"],
  });

  // Initiate and advertise the size for each task
  for (const [i, taskType] of taskTypes.entries()) {
    currentData = null;
    expectedNumCalls += 1;
    aggregator.aggregateCallback(
      new ProgressAndStatusCallbackParams({
        type: taskType,
        statusText: ProgressStatusText.INITIATE,
        id: taskIds[i],
        total: taskSizes[i],
      })
    );
    Assert.ok(currentData, "Received data should be defined");
    Assert.deepEqual(
      {
        statusText: currentData?.statusText,
        type: currentData?.type,
        id: currentData?.id,
        numDone,
        numCalls,
      },
      {
        statusText: ProgressStatusText.INITIATE,
        type: taskType,
        id: taskIds[i],
        numDone: 0,
        numCalls: expectedNumCalls,
      },
      "Data received after initiate should be correct"
    );

    currentData = null;
    expectedNumCalls += 1;
    aggregator.aggregateCallback(
      new ProgressAndStatusCallbackParams({
        type: taskType,
        statusText: ProgressStatusText.SIZE_ESTIMATE,
        id: taskIds[i],
        total: taskSizes[i],
      })
    );
    Assert.ok(currentData, "Received data should be defined");
    // The test expects the reported total to grow with each size estimate.
    expectedTotalToLoad += taskSizes[i];
    Assert.deepEqual(
      {
        numDone,
        numCalls,
        total: currentData.total,
      },
      {
        numDone: 0,
        total: expectedTotalToLoad,
        numCalls: expectedNumCalls,
      },
      "Data received after size estimate should be correct."
    );
  }

  // Send progress status for each chunk.
  for (const [chunkIndex, taskIndex] of chunkTaskIndex.entries()) {
    currentData = null;
    expectedNumCalls += 1;
    expectedTotalLoaded += chunkSizes[chunkIndex];
    aggregator.aggregateCallback(
      new ProgressAndStatusCallbackParams({
        type: taskTypes[taskIndex],
        statusText: ProgressStatusText.IN_PROGRESS,
        id: taskIds[taskIndex],
        total: taskSizes[taskIndex],
        currentLoaded: chunkSizes[chunkIndex],
        totalLoaded: chunkTaskLoaded[chunkIndex],
      })
    );
    Assert.ok(currentData, "Received data should be defined");
    Assert.deepEqual(
      {
        numDone,
        numCalls,
        total: currentData?.total,
        currentLoaded: currentData?.currentLoaded,
        totalLoaded: currentData?.totalLoaded,
      },
      {
        numDone: 0,
        numCalls: expectedNumCalls,
        total: expectedTotalToLoad,
        currentLoaded: chunkSizes[chunkIndex],
        totalLoaded: expectedTotalLoaded,
      },
      "Data received after in progress should be correct"
    );

    // Notify of task is done
    if (chunkIsFinal[chunkIndex]) {
      currentData = null;
      expectedNumCalls += 1;
      aggregator.aggregateCallback(
        new ProgressAndStatusCallbackParams({
          type: taskTypes[taskIndex],
          statusText: ProgressStatusText.DONE,
          id: taskIds[taskIndex],
          // taskSizes is indexed by task, not by chunk.
          total: taskSizes[taskIndex],
        })
      );
      Assert.ok(currentData, "Received data should be defined");
      Assert.deepEqual(
        { total: currentData.total, numCalls },
        { total: expectedTotalToLoad, numCalls: expectedNumCalls },
        "Data received after completed tasks should be correct"
      );
    }
  }

  // The test expects a single DONE status to reach the callback overall.
  Assert.equal(numDone, 1, "Done status should be received");
});

View File

@ -19,12 +19,20 @@ const lazy = {};
ChromeUtils.defineESModuleGetters(lazy, {
createEngine: "chrome://global/content/ml/EngineProcess.sys.mjs",
MultiProgressAggregator: "chrome://global/content/ml/Utils.sys.mjs",
NimbusFeatures: "resource://nimbus/ExperimentAPI.sys.mjs",
PdfJsTelemetry: "resource://pdf.js/PdfJsTelemetry.sys.mjs",
PrivateBrowsingUtils: "resource://gre/modules/PrivateBrowsingUtils.sys.mjs",
SetClipboardSearchString: "resource://gre/modules/Finder.sys.mjs",
});
// Lazily create the logger exposed as `lazy.console`. Messages are prefixed
// with "PDF_JS" and the maximum log level is read from the
// "browser.ml.logLevel" preference.
ChromeUtils.defineLazyGetter(lazy, "console", () =>
  console.createInstance({
    prefix: "PDF_JS",
    maxLogLevelPref: "browser.ml.logLevel",
  })
);
var Svc = {};
XPCOMUtils.defineLazyServiceGetter(
Svc,
@ -95,6 +103,10 @@ export class PdfjsParent extends JSWindowActorParent {
lazy.PdfJsTelemetry.report(aMsg.data);
}
_initProgressBar(progressData) {
lazy.console.debug("progess_from_pdfjs", progressData);
}
async _mlGuess({ data: { service, request } }) {
if (!lazy.createEngine) {
return null;
@ -102,10 +114,20 @@ export class PdfjsParent extends JSWindowActorParent {
if (service !== "image-to-text") {
throw new Error("Invalid service");
}
// We are using the internal task name prefixed with moz-
const engine = await lazy.createEngine({
taskName: "moz-image-to-text",
let aggregator = new lazy.MultiProgressAggregator({
progressCallback: this._initProgressBar,
});
const callback = aggregator.aggregateCallback.bind(aggregator);
// We are using the internal task name prefixed with moz-
const engine = await lazy.createEngine(
{
taskName: "moz-image-to-text",
},
callback
);
return engine.run(request);
}