mirror of
https://github.com/mozilla/gecko-dev.git
synced 2024-11-23 12:51:06 +00:00
Bug 1905384 Implement Download Progress in Firefox's Local Inference Engine r=calixte,tarek
Differential Revision: https://phabricator.services.mozilla.com/D215635
This commit is contained in:
parent
0150dcbf24
commit
36a89c5f0f
@ -10,6 +10,7 @@ import { XPCOMUtils } from "resource://gre/modules/XPCOMUtils.sys.mjs";
|
||||
|
||||
/**
|
||||
* @typedef {object} Lazy
|
||||
* @typedef {import("../content/Utils.sys.mjs").ProgressAndStatusCallbackParams} ProgressAndStatusCallbackParams
|
||||
* @property {typeof import("../../promiseworker/PromiseWorker.sys.mjs").BasePromiseWorker} BasePromiseWorker
|
||||
* @property {typeof setTimeout} setTimeout
|
||||
* @property {typeof clearTimeout} clearTimeout
|
||||
@ -160,9 +161,10 @@ class EngineDispatcher {
|
||||
* Any exception here will be bubbled up for the constructor to log.
|
||||
*
|
||||
* @param {PipelineOptions} pipelineOptions
|
||||
* @param {?function(ProgressAndStatusCallbackParams):void} notificationsCallback The callback to call for updating about notifications such as dowload progress status.
|
||||
* @returns {Promise<Engine>}
|
||||
*/
|
||||
async initializeInferenceEngine(pipelineOptions) {
|
||||
async initializeInferenceEngine(pipelineOptions, notificationsCallback) {
|
||||
// Create the inference engine given the wasm runtime and the options.
|
||||
const wasm = await this.mlEngineChild.getWasmArrayBuffer();
|
||||
const inferenceOptions = await this.mlEngineChild.getInferenceOptions(
|
||||
@ -171,7 +173,11 @@ class EngineDispatcher {
|
||||
lazy.console.debug("Inference engine options:", inferenceOptions);
|
||||
pipelineOptions.updateOptions(inferenceOptions);
|
||||
|
||||
return InferenceEngine.create(wasm, pipelineOptions);
|
||||
return InferenceEngine.create({
|
||||
wasm,
|
||||
pipelineOptions,
|
||||
notificationsCallback,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
@ -184,7 +190,12 @@ class EngineDispatcher {
|
||||
this.#taskName = pipelineOptions.taskName;
|
||||
this.timeoutMS = pipelineOptions.timeoutMS;
|
||||
|
||||
this.#engine = this.initializeInferenceEngine(pipelineOptions);
|
||||
this.#engine = this.initializeInferenceEngine(
|
||||
pipelineOptions,
|
||||
notificationsData => {
|
||||
this.handleInitProgressStatus(port, notificationsData);
|
||||
}
|
||||
);
|
||||
|
||||
// Trigger the keep alive timer.
|
||||
this.#engine
|
||||
@ -201,6 +212,13 @@ class EngineDispatcher {
|
||||
this.setupMessageHandler(port);
|
||||
}
|
||||
|
||||
handleInitProgressStatus(port, notificationsData) {
|
||||
port.postMessage({
|
||||
type: "EnginePort:InitProgress",
|
||||
statusResponse: notificationsData,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* The worker needs to be shutdown after some amount of time of not being used.
|
||||
*/
|
||||
@ -338,12 +356,14 @@ let modelHub = null; // This will hold the ModelHub instance to reuse it.
|
||||
* then fetches the model file using the ModelHub API. The `modelHub` instance is created
|
||||
* only once and reused for subsequent calls to optimize performance.
|
||||
*
|
||||
* @param {string} url - The URL of the model file to fetch. Can be a path relative to
|
||||
* @param {object} config
|
||||
* @param {string} config.url - The URL of the model file to fetch. Can be a path relative to
|
||||
* the model hub root or an absolute URL.
|
||||
* @param {?function(ProgressAndStatusCallbackParams):void} config.progressCallback The callback to call for notifying about download progress status.
|
||||
* @returns {Promise} A promise that resolves to a Meta object containing the URL, response headers,
|
||||
* and data as an ArrayBuffer. The data is marked for transfer to avoid cloning.
|
||||
*/
|
||||
async function getModelFile(url) {
|
||||
async function getModelFile({ url, progressCallback }) {
|
||||
// Create the model hub instance if needed
|
||||
if (!modelHub) {
|
||||
lazy.console.debug("Creating model hub instance");
|
||||
@ -365,7 +385,10 @@ async function getModelFile(url) {
|
||||
// if this errors out, it will be caught in the worker
|
||||
const parsedUrl = modelHub.parseUrl(url);
|
||||
|
||||
let [data, headers] = await modelHub.getModelFileAsArrayBuffer(parsedUrl);
|
||||
let [data, headers] = await modelHub.getModelFileAsArrayBuffer({
|
||||
...parsedUrl,
|
||||
progressCallback,
|
||||
});
|
||||
return new lazy.BasePromiseWorker.Meta([url, headers, data], {
|
||||
transfers: [data],
|
||||
});
|
||||
@ -381,16 +404,22 @@ class InferenceEngine {
|
||||
/**
|
||||
* Initialize the worker.
|
||||
*
|
||||
* @param {ArrayBuffer} wasm
|
||||
* @param {PipelineOptions} pipelineOptions
|
||||
* @param {object} config
|
||||
* @param {ArrayBuffer} config.wasm
|
||||
* @param {PipelineOptions} config.pipelineOptions
|
||||
* @param {?function(ProgressAndStatusCallbackParams):void} config.notificationsCallback The callback to call for updating about notifications such as dowload progress status.
|
||||
* @returns {InferenceEngine}
|
||||
*/
|
||||
static async create(wasm, pipelineOptions) {
|
||||
static async create({ wasm, pipelineOptions, notificationsCallback }) {
|
||||
/** @type {BasePromiseWorker} */
|
||||
const worker = new lazy.BasePromiseWorker(
|
||||
"chrome://global/content/ml/MLEngine.worker.mjs",
|
||||
{ type: "module" },
|
||||
{ getModelFile }
|
||||
{
|
||||
getModelFile: async url => {
|
||||
return getModelFile({ url, progressCallback: notificationsCallback });
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
const args = [wasm, pipelineOptions];
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
/**
|
||||
* @typedef {object} Lazy
|
||||
* @typedef {import("../content/Utils.sys.mjs").ProgressAndStatusCallbackParams} ProgressAndStatusCallbackParams
|
||||
* @property {typeof console} console
|
||||
* @property {typeof import("../content/Utils.sys.mjs").getRuntimeWasmFilename} getRuntimeWasmFilename
|
||||
* @property {typeof import("../content/EngineProcess.sys.mjs").EngineProcess} EngineProcess
|
||||
@ -80,10 +81,15 @@ export class MLEngineParent extends JSWindowActorParent {
|
||||
/** Creates a new MLEngine.
|
||||
*
|
||||
* @param {PipelineOptions} pipelineOptions
|
||||
* @param {?function(ProgressAndStatusCallbackParams):void} notificationsCallback A function to call to indicate progress status.
|
||||
* @returns {MLEngine}
|
||||
*/
|
||||
getEngine(pipelineOptions) {
|
||||
return new MLEngine({ mlEngineParent: this, pipelineOptions });
|
||||
getEngine(pipelineOptions, notificationsCallback = null) {
|
||||
return new MLEngine({
|
||||
mlEngineParent: this,
|
||||
pipelineOptions,
|
||||
notificationsCallback,
|
||||
});
|
||||
}
|
||||
|
||||
/** Extracts the task name from the name and validates it.
|
||||
@ -315,14 +321,23 @@ class MLEngine {
|
||||
*/
|
||||
engineStatus = "uninitialized";
|
||||
|
||||
/**
|
||||
* Callback to call when receiving an initializing progress status.
|
||||
*
|
||||
* @type {?function(ProgressAndStatusCallbackParams):void}
|
||||
*/
|
||||
notificationsCallback = null;
|
||||
|
||||
/**
|
||||
* @param {object} config - The configuration object for the instance.
|
||||
* @param {object} config.mlEngineParent - The parent machine learning engine associated with this instance.
|
||||
* @param {object} config.pipelineOptions - The options for configuring the pipeline associated with this instance.
|
||||
* @param {?function(ProgressAndStatusCallbackParams):void} config.notificationsCallback - The initialization progress callback function to call.
|
||||
*/
|
||||
constructor({ mlEngineParent, pipelineOptions }) {
|
||||
constructor({ mlEngineParent, pipelineOptions, notificationsCallback }) {
|
||||
this.mlEngineParent = mlEngineParent;
|
||||
this.pipelineOptions = pipelineOptions;
|
||||
this.notificationsCallback = notificationsCallback;
|
||||
this.#setupPortCommunication();
|
||||
}
|
||||
|
||||
@ -402,6 +417,10 @@ class MLEngine {
|
||||
this.discardPort();
|
||||
break;
|
||||
}
|
||||
case "EnginePort:InitProgress": {
|
||||
this.notificationsCallback?.(data.statusResponse);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
lazy.console.error("Unknown port message from engine", data);
|
||||
break;
|
||||
|
@ -16,6 +16,7 @@ ChromeUtils.defineESModuleGetters(
|
||||
|
||||
/**
|
||||
* @typedef {import("../actors/MLEngineParent.sys.mjs").MLEngineParent} MLEngineParent
|
||||
* @typedef {import("../content/Utils.sys.mjs").ProgressAndStatusCallbackParams} ProgressAndStatusCallbackParams
|
||||
*/
|
||||
|
||||
/**
|
||||
@ -428,11 +429,12 @@ export class EngineProcess {
|
||||
* Creates a new ML engine instance with the provided options.
|
||||
*
|
||||
* @param {object} options - Configuration options for the ML engine.
|
||||
* @param {?function(ProgressAndStatusCallbackParams):void} notificationsCallback A function to call to indicate notifications.
|
||||
* @returns {Promise<MLEngine>} - A promise that resolves to the ML engine instance.
|
||||
*
|
||||
*/
|
||||
export async function createEngine(options) {
|
||||
export async function createEngine(options, notificationsCallback = null) {
|
||||
const pipelineOptions = new PipelineOptions(options);
|
||||
const engineParent = await EngineProcess.getMLEngineParent();
|
||||
return engineParent.getEngine(pipelineOptions);
|
||||
return engineParent.getEngine(pipelineOptions, notificationsCallback);
|
||||
}
|
||||
|
@ -1,11 +1,17 @@
|
||||
/* This Source Code Form is subject to the terms of the Mozilla Public
|
||||
* License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
|
||||
|
||||
/**
|
||||
* @typedef {import("./Utils.sys.mjs").ProgressAndStatusCallbackParams} ProgressAndStatusCallbackParams
|
||||
*/
|
||||
|
||||
const lazy = {};
|
||||
|
||||
ChromeUtils.defineESModuleGetters(lazy, {
|
||||
clearTimeout: "resource://gre/modules/Timer.sys.mjs",
|
||||
setTimeout: "resource://gre/modules/Timer.sys.mjs",
|
||||
Progress: "chrome://global/content/ml/Utils.sys.mjs",
|
||||
});
|
||||
|
||||
ChromeUtils.defineLazyGetter(lazy, "console", () => {
|
||||
@ -22,7 +28,6 @@ const ALLOWED_HUBS = [
|
||||
"https://localhost",
|
||||
"https://model-hub.mozilla.org",
|
||||
];
|
||||
|
||||
const ALLOWED_HEADERS_KEYS = ["Content-Type", "ETag", "status"];
|
||||
const DEFAULT_URL_TEMPLATE = "{model}/resolve/{revision}";
|
||||
|
||||
@ -640,25 +645,7 @@ export class ModelHub {
|
||||
}
|
||||
|
||||
/**
|
||||
* Given an organization, model, and version, fetch a model file in the hub as an ArrayBuffer.
|
||||
*
|
||||
* @param {object} config
|
||||
* @param {string} config.model
|
||||
* @param {string} config.revision
|
||||
* @param {string} config.file
|
||||
* @returns {Promise<[ArrayBuffer, headers]>} The file content
|
||||
*/
|
||||
async getModelFileAsArrayBuffer({ model, revision, file }) {
|
||||
const [blob, headers] = await this.getModelFileAsBlob({
|
||||
model,
|
||||
revision,
|
||||
file,
|
||||
});
|
||||
return [await blob.arrayBuffer(), headers];
|
||||
}
|
||||
|
||||
/**
|
||||
* Given an organization, model, and version, fetch a model file in the hub as blob.
|
||||
* Given an organization, model, and version, fetch a model file in the hub as an blob.
|
||||
*
|
||||
* @param {object} config
|
||||
* @param {string} config.model
|
||||
@ -667,6 +654,26 @@ export class ModelHub {
|
||||
* @returns {Promise<[Blob, object]>} The file content
|
||||
*/
|
||||
async getModelFileAsBlob({ model, revision, file }) {
|
||||
const [buffer, headers] = await this.getModelFileAsArrayBuffer({
|
||||
model,
|
||||
revision,
|
||||
file,
|
||||
});
|
||||
return [new Blob([buffer]), headers];
|
||||
}
|
||||
|
||||
/**
|
||||
* Given an organization, model, and version, fetch a model file in the hub as an ArrayBuffer
|
||||
* while supporting status callback.
|
||||
*
|
||||
* @param {object} config
|
||||
* @param {string} config.model
|
||||
* @param {string} config.revision
|
||||
* @param {string} config.file
|
||||
* @param {?function(ProgressAndStatusCallbackParams):void} config.progressCallback A function to call to indicate progress status.
|
||||
* @returns {Promise<[ArrayBuffer, headers]>} The file content
|
||||
*/
|
||||
async getModelFileAsArrayBuffer({ model, revision, file, progressCallback }) {
|
||||
// Make sure inputs are clean. We don't sanitize them but throw an exception
|
||||
let checkError = this.#checkInput(model, revision, file);
|
||||
if (checkError) {
|
||||
@ -696,16 +703,77 @@ export class ModelHub {
|
||||
useCached = await this.cache.fileExists(model, revision, file);
|
||||
}
|
||||
|
||||
const progressInfo = {
|
||||
progress: null,
|
||||
totalLoaded: null,
|
||||
currentLoaded: null,
|
||||
total: null,
|
||||
};
|
||||
|
||||
const statusInfo = {
|
||||
metadata: { model, revision, file, url },
|
||||
ok: true,
|
||||
id: url,
|
||||
};
|
||||
|
||||
if (useCached) {
|
||||
lazy.console.debug(`Cache Hit for ${url}`);
|
||||
return await this.cache.getFile(model, revision, file);
|
||||
progressCallback?.(
|
||||
new lazy.Progress.ProgressAndStatusCallbackParams({
|
||||
...statusInfo,
|
||||
...progressInfo,
|
||||
type: lazy.Progress.ProgressType.LOAD_FROM_CACHE,
|
||||
statusText: lazy.Progress.ProgressStatusText.INITIATE,
|
||||
})
|
||||
);
|
||||
const [blob, headers] = await this.cache.getFile(model, revision, file);
|
||||
progressCallback?.(
|
||||
new lazy.Progress.ProgressAndStatusCallbackParams({
|
||||
...statusInfo,
|
||||
...progressInfo,
|
||||
type: lazy.Progress.ProgressType.LOAD_FROM_CACHE,
|
||||
statusText: lazy.Progress.ProgressStatusText.DONE,
|
||||
})
|
||||
);
|
||||
return [await blob.arrayBuffer(), headers];
|
||||
}
|
||||
|
||||
progressCallback?.(
|
||||
new lazy.Progress.ProgressAndStatusCallbackParams({
|
||||
...statusInfo,
|
||||
...progressInfo,
|
||||
type: lazy.Progress.ProgressType.DOWNLOAD,
|
||||
statusText: lazy.Progress.ProgressStatusText.INITIATE,
|
||||
})
|
||||
);
|
||||
|
||||
lazy.console.debug(`Fetching ${url}`);
|
||||
try {
|
||||
const response = await fetch(url);
|
||||
let response = await fetch(url);
|
||||
let isFirstCall = true;
|
||||
let responseContentArray = await lazy.Progress.readResponse(
|
||||
response,
|
||||
progressData => {
|
||||
progressCallback?.(
|
||||
new lazy.Progress.ProgressAndStatusCallbackParams({
|
||||
...progressInfo,
|
||||
...progressData,
|
||||
statusText: isFirstCall
|
||||
? lazy.Progress.ProgressStatusText.SIZE_ESTIMATE
|
||||
: lazy.Progress.ProgressStatusText.IN_PROGRESS,
|
||||
type: lazy.Progress.ProgressType.DOWNLOAD,
|
||||
...statusInfo,
|
||||
})
|
||||
);
|
||||
isFirstCall = false;
|
||||
}
|
||||
);
|
||||
let responseContent = responseContentArray.buffer.slice(
|
||||
responseContentArray.byteOffset,
|
||||
responseContentArray.byteLength + responseContentArray.byteOffset
|
||||
);
|
||||
|
||||
if (response.ok) {
|
||||
const clone = response.clone();
|
||||
const headers = {
|
||||
// We don't store the boundary or the charset, just the content type,
|
||||
// so we drop what's after the semicolon.
|
||||
@ -717,15 +785,36 @@ export class ModelHub {
|
||||
model,
|
||||
revision,
|
||||
file,
|
||||
await clone.blob(),
|
||||
new Blob([responseContent]),
|
||||
headers
|
||||
);
|
||||
return [await response.blob(), headers];
|
||||
|
||||
progressCallback?.(
|
||||
new lazy.Progress.ProgressAndStatusCallbackParams({
|
||||
...statusInfo,
|
||||
...progressInfo,
|
||||
type: lazy.Progress.ProgressType.DOWNLOAD,
|
||||
statusText: lazy.Progress.ProgressStatusText.DONE,
|
||||
})
|
||||
);
|
||||
|
||||
return [responseContent, headers];
|
||||
}
|
||||
} catch (error) {
|
||||
lazy.console.error(`Failed to fetch ${url}:`, error);
|
||||
}
|
||||
|
||||
// Indicate there is an error
|
||||
progressCallback?.(
|
||||
new lazy.Progress.ProgressAndStatusCallbackParams({
|
||||
...statusInfo,
|
||||
...progressInfo,
|
||||
type: lazy.Progress.ProgressType.DOWNLOAD,
|
||||
statusText: lazy.Progress.ProgressStatusText.DONE,
|
||||
ok: false,
|
||||
})
|
||||
);
|
||||
|
||||
throw new Error(`Failed to fetch the model file: ${url}`);
|
||||
}
|
||||
}
|
||||
|
@ -75,3 +75,321 @@ export function getRuntimeWasmFilename(browsingContext = null) {
|
||||
cachedRuntimeWasmFilename = res;
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* Enumeration for the progress status text.
|
||||
*/
|
||||
export const ProgressStatusText = Object.freeze({
|
||||
// The value of the status text indicating that an operation is started.
|
||||
INITIATE: "initiate",
|
||||
// The value of the status text indicating an estimate for the size of the operation.
|
||||
SIZE_ESTIMATE: "size_estimate",
|
||||
// The value of the status text indicating that an operation is in progress.
|
||||
IN_PROGRESS: "in_progress",
|
||||
// The value of the status text indicating that an operation has completed.
|
||||
DONE: "done",
|
||||
});
|
||||
|
||||
/**
|
||||
* Enumeration for type of progress operations.
|
||||
*/
|
||||
export const ProgressType = Object.freeze({
|
||||
// The value of the operation type for a remote downloading.
|
||||
DOWNLOAD: "downloading",
|
||||
// The value of the operation type when loading from cache
|
||||
LOAD_FROM_CACHE: "loading_from_cache",
|
||||
});
|
||||
|
||||
/**
|
||||
* This class encapsulates the parameters supported by a progress and status callback.
|
||||
*/
|
||||
export class ProgressAndStatusCallbackParams {
|
||||
// Params for progress callback
|
||||
|
||||
/**
|
||||
* A float indicating the percentage of data loaded. Note that
|
||||
* 100% does not necessarily mean the operation is complete.
|
||||
*
|
||||
* @type {?float}
|
||||
*/
|
||||
progress = null;
|
||||
|
||||
/**
|
||||
* A float indicating the total amount of data loaded so far.
|
||||
* In particular, this is the sum of currentLoaded across all call of the callback.
|
||||
*
|
||||
* @type {?float}
|
||||
*/
|
||||
totalLoaded = null;
|
||||
|
||||
/**
|
||||
* The amount of data loaded in the current callback call.
|
||||
*
|
||||
* @type {?float}
|
||||
*/
|
||||
currentLoaded = null;
|
||||
|
||||
/**
|
||||
* A float indicating an estimate of the total amount of data to be loaded.
|
||||
* Do not rely on this number as this is an estimate and the true total could be
|
||||
* either lower or higher.
|
||||
*
|
||||
* @type {?float}
|
||||
*/
|
||||
total = null;
|
||||
|
||||
/**
|
||||
* The units in which the amounts are reported.
|
||||
*
|
||||
* @type {?string}
|
||||
*/
|
||||
units = null;
|
||||
|
||||
// Params for status callback
|
||||
/**
|
||||
* The name of the operation being tracked.
|
||||
*
|
||||
* @type {?string}
|
||||
*/
|
||||
type = null;
|
||||
|
||||
/**
|
||||
* A message indicating the status of the tracked operation.
|
||||
*
|
||||
* @type {?string}
|
||||
*/
|
||||
statusText = null;
|
||||
|
||||
/**
|
||||
* An ID uniquely identifying the object/file being tracked.
|
||||
*
|
||||
* @type {?string}
|
||||
*/
|
||||
id = null;
|
||||
|
||||
/**
|
||||
* A boolean indicating if the operation was successful.
|
||||
* true means we have a successful operation.
|
||||
*
|
||||
* @type {?boolean}
|
||||
*/
|
||||
ok = null;
|
||||
|
||||
/**
|
||||
* Any additional metadata for the operation being tracked.
|
||||
*
|
||||
* @type {?object}
|
||||
*/
|
||||
metadata = null;
|
||||
|
||||
constructor(params = {}) {
|
||||
this.update(params);
|
||||
}
|
||||
|
||||
update(params = {}) {
|
||||
const allowedKeys = new Set(Object.keys(this));
|
||||
const invalidKeys = Object.keys(params).filter(x => !allowedKeys.has(x));
|
||||
if (invalidKeys.length) {
|
||||
throw new Error(`Received Invalid option: ${invalidKeys}`);
|
||||
}
|
||||
for (const key of allowedKeys) {
|
||||
if (key in params) {
|
||||
this[key] = params[key];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Read and track progress when reading a Response object
|
||||
*
|
||||
* @param {any} response The Response object to read
|
||||
* @param {?function(ProgressAndStatusCallbackParams):void} progressCallback The function to call with progress updates
|
||||
*
|
||||
* @returns {Promise<Uint8Array>} A Promise that resolves with the Uint8Array buffer
|
||||
*/
|
||||
export async function readResponse(response, progressCallback) {
|
||||
const contentLength = response.headers.get("Content-Length");
|
||||
if (!contentLength) {
|
||||
console.warn(
|
||||
"Unable to determine content-length from response headers. Will expand buffer when needed."
|
||||
);
|
||||
}
|
||||
let total = parseInt(contentLength ?? "0");
|
||||
|
||||
progressCallback?.(
|
||||
new ProgressAndStatusCallbackParams({
|
||||
progress: 0,
|
||||
totalLoaded: 0,
|
||||
currentLoaded: 0,
|
||||
total,
|
||||
units: "bytes",
|
||||
})
|
||||
);
|
||||
|
||||
let buffer = new Uint8Array(total);
|
||||
let loaded = 0;
|
||||
|
||||
for await (const value of response.body) {
|
||||
let newLoaded = loaded + value.length;
|
||||
if (newLoaded > total) {
|
||||
total = newLoaded;
|
||||
|
||||
// Adding the new data will overflow buffer.
|
||||
// In this case, we extend the buffer
|
||||
// Happened when the content-length is lower than the actual lenght
|
||||
let newBuffer = new Uint8Array(total);
|
||||
|
||||
// copy contents
|
||||
newBuffer.set(buffer);
|
||||
|
||||
buffer = newBuffer;
|
||||
}
|
||||
buffer.set(value, loaded);
|
||||
loaded = newLoaded;
|
||||
|
||||
const progress = (loaded / total) * 100;
|
||||
|
||||
progressCallback?.(
|
||||
new ProgressAndStatusCallbackParams({
|
||||
progress,
|
||||
totalLoaded: loaded,
|
||||
currentLoaded: value.length,
|
||||
total,
|
||||
units: "bytes",
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// Ensure that buffer is not bigger than loaded
|
||||
// Sometimes content length is larger than the actual size
|
||||
buffer = buffer.slice(0, loaded);
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Class for watching the progress bar of multiple events and combining
|
||||
* then into a single progress bar.
|
||||
*/
|
||||
export class MultiProgressAggregator {
|
||||
/**
|
||||
* A function to call with the aggregated statistics.
|
||||
*
|
||||
* @type {?function(ProgressAndStatusCallbackParams):void}
|
||||
*/
|
||||
progressCallback = null;
|
||||
|
||||
/**
|
||||
* The name of the key that contains status information.
|
||||
*
|
||||
* @type {Set<string>}
|
||||
*/
|
||||
watchedTypes;
|
||||
|
||||
/**
|
||||
* The total amount of information loaded so far.
|
||||
*
|
||||
* @type {float}
|
||||
*/
|
||||
#combinedLoaded = 0;
|
||||
|
||||
/**
|
||||
* The total amount of information to be loaded.
|
||||
*
|
||||
* @type {float}
|
||||
*/
|
||||
#combinedTotal = 0;
|
||||
|
||||
/**
|
||||
* The number of operations that are yet to be completed.
|
||||
*
|
||||
* @type {float}
|
||||
*/
|
||||
#remainingEvents = 0;
|
||||
|
||||
/**
|
||||
* The type of operation seen so far.
|
||||
*
|
||||
* @type {Set<string>}
|
||||
*/
|
||||
#seenTypes;
|
||||
|
||||
/**
|
||||
* The status of text seen so far.
|
||||
*
|
||||
* @type {Set<string>}
|
||||
*/
|
||||
#seenStatus;
|
||||
|
||||
/**
|
||||
* @param {object} config
|
||||
* @param {?function(ProgressAndStatusCallbackParams):void} config.progressCallback - A function to call with the aggregated statistics.
|
||||
* @param {Iterable<string>} config.watchedTypes - The types to watch for aggregation
|
||||
*/
|
||||
constructor({ progressCallback, watchedTypes = [ProgressType.DOWNLOAD] }) {
|
||||
this.progressCallback = progressCallback;
|
||||
this.watchedTypes = new Set(watchedTypes);
|
||||
|
||||
this.#seenTypes = new Set();
|
||||
this.#seenStatus = new Set();
|
||||
}
|
||||
|
||||
/**
|
||||
* Callback function that will combined data from different objects/files.
|
||||
*
|
||||
* @param {ProgressAndStatusCallbackParams} data - object containing the data
|
||||
*/
|
||||
aggregateCallback(data) {
|
||||
if (this.watchedTypes.has(data.type)) {
|
||||
this.#seenTypes.add(data.type);
|
||||
this.#seenStatus.add(data.statusText);
|
||||
if (data.statusText == ProgressStatusText.INITIATE) {
|
||||
this.#remainingEvents += 1;
|
||||
}
|
||||
|
||||
if (data.statusText == ProgressStatusText.SIZE_ESTIMATE) {
|
||||
this.#combinedTotal += data.total ?? 0;
|
||||
}
|
||||
|
||||
if (data.statusText == ProgressStatusText.DONE) {
|
||||
this.#remainingEvents -= 1;
|
||||
}
|
||||
|
||||
this.#combinedLoaded += data.currentLoaded ?? 0;
|
||||
|
||||
if (this.progressCallback) {
|
||||
let statusText = data.statusText;
|
||||
if (this.#seenStatus.has(ProgressStatusText.IN_PROGRESS)) {
|
||||
statusText = ProgressStatusText.IN_PROGRESS;
|
||||
}
|
||||
|
||||
if (this.#remainingEvents == 0) {
|
||||
statusText = ProgressStatusText.DONE;
|
||||
}
|
||||
|
||||
this.progressCallback(
|
||||
new ProgressAndStatusCallbackParams({
|
||||
type: data.type,
|
||||
statusText,
|
||||
id: data.id,
|
||||
total: this.#combinedTotal,
|
||||
currentLoaded: data.currentLoaded,
|
||||
totalLoaded: this.#combinedLoaded,
|
||||
progress: (this.#combinedLoaded / this.#combinedTotal) * 100,
|
||||
ok: data.ok,
|
||||
units: data.units,
|
||||
metadata: data,
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a "namespace" to make it easier to import multiple names.
|
||||
export var Progress = Progress || {};
|
||||
Progress.ProgressAndStatusCallbackParams = ProgressAndStatusCallbackParams;
|
||||
Progress.ProgressStatusText = ProgressStatusText;
|
||||
Progress.ProgressType = ProgressType;
|
||||
Progress.readResponse = readResponse;
|
||||
|
@ -129,3 +129,50 @@ In the example below, an image is converted to text using the `moz-image-to-text
|
||||
The following internal tasks are supported by the machine learning engine:
|
||||
|
||||
.. js:autofunction:: imageToText
|
||||
|
||||
|
||||
Notifications callback
|
||||
::::::::::::::::::::::
|
||||
|
||||
When initializing or running the engine, certain operations may take considerable time to complete.
|
||||
You can receive progress notifications for these operations using a callback function.
|
||||
|
||||
Currently, progress notifications are supported only for model downloads.
|
||||
When the engine is created, it will download any model not already in the cache.
|
||||
|
||||
Below is an example of using the callback function with the image-to-text model:
|
||||
|
||||
.. code-block:: javascript
|
||||
|
||||
const { createEngine } = ChromeUtils.importESModule("chrome://global/content/ml/EngineProcess.sys.mjs");
|
||||
|
||||
// options needed for the task
|
||||
const options = {taskName: "moz-image-to-text" };
|
||||
|
||||
// We create the engine object, using options and a callback
|
||||
const engine = await createEngine(options, progressData => {
|
||||
console.log("Received progress data", progressData);
|
||||
});
|
||||
|
||||
|
||||
In the code above, **progressData** is an object of type `ProgressAndStatusCallbackParams` containing the following fields:
|
||||
|
||||
- **progress**: A float indicating the percentage of data loaded. Note that 100% does not necessarily mean the operation is complete.
|
||||
- **totalLoaded**: A float indicating the total amount of data loaded so far.
|
||||
- **currentLoaded**: The amount of data loaded in the current callback call.
|
||||
- **total**: A float indicating an estimate of the total amount of data to be loaded.
|
||||
- **units**: The units in which the amounts are reported.
|
||||
- **type**: The name of the operation being tracked. It will be one of `ProgressType.DOWNLOAD`, `ProgressType.LOAD_FROM_CACHE`.
|
||||
- **statusText**: A message indicating the status of the tracked operation, which can be:
|
||||
|
||||
- `ProgressStatusText.INITIATE` Indicates that an operation has started. This will be used exactly once for each operation uniquely identified by `id` and `type`.
|
||||
|
||||
- `ProgressStatusText.SIZE_ESTIMATE` Indicates an estimate for the size of the operation. This will be used exactly once for each operation uniquely identified by `id` and `type`, updating the `total`` field.
|
||||
|
||||
- `ProgressStatusText.IN_PROGRESS` Indicates that an operation is in progress. This will be used each time progress occurs, updating the `totalLoaded`` and `currentLoaded`` fields.
|
||||
|
||||
- `ProgressStatusText.DONE` indicating that an operation has completed.
|
||||
|
||||
- **id**: An ID uniquely identifying the object/file being tracked.
|
||||
- **ok**: A boolean indicating if the operation was succesfull.
|
||||
- **metadata**: Any additional metadata for the operation being tracked.
|
||||
|
@ -6,6 +6,10 @@ const { sinon } = ChromeUtils.importESModule(
|
||||
"resource://testing-common/Sinon.sys.mjs"
|
||||
);
|
||||
|
||||
const { ProgressStatusText, ProgressType } = ChromeUtils.importESModule(
|
||||
"chrome://global/content/ml/Utils.sys.mjs"
|
||||
);
|
||||
|
||||
// Root URL of the fake hub, see the `data` dir in the tests.
|
||||
const FAKE_HUB =
|
||||
"chrome://mochitests/content/browser/toolkit/components/ml/tests/browser/data";
|
||||
@ -270,6 +274,157 @@ add_task(async function test_getting_file_from_cache() {
|
||||
Assert.deepEqual(array, array2);
|
||||
});
|
||||
|
||||
/**
|
||||
* Test that the callback is appropriately called when the data is retrieved from the server
|
||||
* or from the cache.
|
||||
*/
|
||||
add_task(async function test_getting_file_from_url_cache_with_callback() {
|
||||
const hub = new ModelHub({ rootUrl: FAKE_HUB });
|
||||
|
||||
hub.cache = await initializeCache();
|
||||
|
||||
let numCalls = 0;
|
||||
let currentData = null;
|
||||
let array = await hub.getModelFileAsArrayBuffer({
|
||||
...FAKE_MODEL_ARGS,
|
||||
progressCallback: data => {
|
||||
// expecting initiate status and download
|
||||
currentData = data;
|
||||
if (numCalls == 0) {
|
||||
Assert.deepEqual(
|
||||
{
|
||||
type: data.type,
|
||||
statusText: data.statusText,
|
||||
ok: data.ok,
|
||||
model: currentData?.metadata?.model,
|
||||
file: currentData?.metadata?.file,
|
||||
revision: currentData?.metadata?.revision,
|
||||
},
|
||||
{
|
||||
type: ProgressType.DOWNLOAD,
|
||||
statusText: ProgressStatusText.INITIATE,
|
||||
ok: true,
|
||||
...FAKE_MODEL_ARGS,
|
||||
},
|
||||
"Initiate Data from server should be correct"
|
||||
);
|
||||
}
|
||||
|
||||
if (numCalls == 1) {
|
||||
Assert.deepEqual(
|
||||
{
|
||||
type: data.type,
|
||||
statusText: data.statusText,
|
||||
ok: data.ok,
|
||||
model: currentData?.metadata?.model,
|
||||
file: currentData?.metadata?.file,
|
||||
revision: currentData?.metadata?.revision,
|
||||
},
|
||||
{
|
||||
type: ProgressType.DOWNLOAD,
|
||||
statusText: ProgressStatusText.SIZE_ESTIMATE,
|
||||
ok: true,
|
||||
...FAKE_MODEL_ARGS,
|
||||
},
|
||||
"size estimate Data from server should be correct"
|
||||
);
|
||||
}
|
||||
|
||||
numCalls += 1;
|
||||
},
|
||||
});
|
||||
|
||||
Assert.greaterOrEqual(numCalls, 3);
|
||||
|
||||
// last received message is DONE
|
||||
Assert.deepEqual(
|
||||
{
|
||||
type: currentData?.type,
|
||||
statusText: currentData?.statusText,
|
||||
ok: currentData?.ok,
|
||||
model: currentData?.metadata?.model,
|
||||
file: currentData?.metadata?.file,
|
||||
revision: currentData?.metadata?.revision,
|
||||
},
|
||||
{
|
||||
type: ProgressType.DOWNLOAD,
|
||||
statusText: ProgressStatusText.DONE,
|
||||
ok: true,
|
||||
...FAKE_MODEL_ARGS,
|
||||
},
|
||||
"Done Data from server should be correct"
|
||||
);
|
||||
|
||||
// stub to verify that the data was retrieved from IndexDB
|
||||
let matchMethod = hub.cache._testGetData;
|
||||
|
||||
sinon.stub(hub.cache, "_testGetData").callsFake(function () {
|
||||
return matchMethod.apply(this, arguments).then(result => {
|
||||
Assert.notEqual(result, null);
|
||||
return result;
|
||||
});
|
||||
});
|
||||
|
||||
numCalls = 0;
|
||||
currentData = null;
|
||||
|
||||
// Now we expect the callback to indicate cache usage.
|
||||
let array2 = await hub.getModelFileAsArrayBuffer({
|
||||
...FAKE_MODEL_ARGS,
|
||||
progressCallback: data => {
|
||||
// expecting initiate status and download
|
||||
|
||||
currentData = data;
|
||||
|
||||
if (numCalls == 0) {
|
||||
Assert.deepEqual(
|
||||
{
|
||||
type: data.type,
|
||||
statusText: data.statusText,
|
||||
ok: data.ok,
|
||||
model: currentData?.metadata?.model,
|
||||
file: currentData?.metadata?.file,
|
||||
revision: currentData?.metadata?.revision,
|
||||
},
|
||||
{
|
||||
type: ProgressType.LOAD_FROM_CACHE,
|
||||
statusText: ProgressStatusText.INITIATE,
|
||||
ok: true,
|
||||
...FAKE_MODEL_ARGS,
|
||||
},
|
||||
"Initiate Data from cache should be correct"
|
||||
);
|
||||
}
|
||||
|
||||
numCalls += 1;
|
||||
},
|
||||
});
|
||||
hub.cache._testGetData.restore();
|
||||
|
||||
Assert.deepEqual(array, array2);
|
||||
|
||||
// last received message is DONE
|
||||
Assert.deepEqual(
|
||||
{
|
||||
type: currentData?.type,
|
||||
statusText: currentData?.statusText,
|
||||
ok: currentData?.ok,
|
||||
model: currentData?.metadata?.model,
|
||||
file: currentData?.metadata?.file,
|
||||
revision: currentData?.metadata?.revision,
|
||||
},
|
||||
{
|
||||
type: ProgressType.LOAD_FROM_CACHE,
|
||||
statusText: ProgressStatusText.DONE,
|
||||
ok: true,
|
||||
...FAKE_MODEL_ARGS,
|
||||
},
|
||||
"Done Data from cache should be correct"
|
||||
);
|
||||
|
||||
await deleteCache(hub.cache);
|
||||
});
|
||||
|
||||
/**
|
||||
* Test parsing of a well-formed full URL, including protocol and path.
|
||||
*/
|
||||
|
@ -2,9 +2,13 @@
|
||||
http://creativecommons.org/publicdomain/zero/1.0/ */
|
||||
"use strict";
|
||||
|
||||
const { arrayBufferToBlobURL } = ChromeUtils.importESModule(
|
||||
"chrome://global/content/ml/Utils.sys.mjs"
|
||||
);
|
||||
const {
|
||||
arrayBufferToBlobURL,
|
||||
MultiProgressAggregator,
|
||||
ProgressAndStatusCallbackParams,
|
||||
ProgressStatusText,
|
||||
readResponse,
|
||||
} = ChromeUtils.importESModule("chrome://global/content/ml/Utils.sys.mjs");
|
||||
|
||||
/**
|
||||
* Test arrayBufferToBlobURL function.
|
||||
@ -24,3 +28,392 @@ add_task(async function test_ml_utils_array_buffer_to_blob_url() {
|
||||
"The returned string should be a Blob URL"
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* Test that we can retrieve the correct content without a callback.
|
||||
*/
|
||||
add_task(async function test_correct_response_no_callback() {
|
||||
const content = "This is the expected response.";
|
||||
const blob = new Blob([content]);
|
||||
const response = new Response(blob, {
|
||||
headers: new Headers({ "Content-Length": blob.size }),
|
||||
});
|
||||
|
||||
const responseArray = await readResponse(response);
|
||||
|
||||
const responseContent = new TextDecoder().decode(responseArray);
|
||||
|
||||
Assert.equal(content, responseContent, "The response content should match.");
|
||||
});
|
||||
|
||||
/**
|
||||
* Test that we can retrieve the correct content with a callback.
|
||||
*/
|
||||
add_task(async function test_correct_response_callback() {
|
||||
const content = "This is the expected response.";
|
||||
const blob = new Blob([content]);
|
||||
const response = new Response(blob, {
|
||||
headers: new Headers({ "Content-Length": blob.size }),
|
||||
});
|
||||
|
||||
const responseArray = await readResponse(response, data => {
|
||||
data;
|
||||
});
|
||||
|
||||
const responseContent = new TextDecoder().decode(responseArray);
|
||||
|
||||
Assert.equal(content, responseContent, "The response content should match.");
|
||||
});
|
||||
|
||||
/**
|
||||
* Test that we can retrieve the correct content with a content-lenght lower than the actual len
|
||||
*/
|
||||
add_task(async function test_correct_response_content_length_under_reported() {
|
||||
const content = "This is the expected response.";
|
||||
const blob = new Blob([content]);
|
||||
const response = new Response(blob, {
|
||||
headers: new Headers({
|
||||
"Content-Length": 1,
|
||||
}),
|
||||
});
|
||||
|
||||
const responseArray = await readResponse(response, data => {
|
||||
data;
|
||||
});
|
||||
|
||||
const responseContent = new TextDecoder().decode(responseArray);
|
||||
|
||||
Assert.equal(content, responseContent, "The response content should match.");
|
||||
});
|
||||
|
||||
/**
|
||||
* Test that we can retrieve the correct content with a content-lenght larger than the actual len
|
||||
*/
|
||||
add_task(async function test_correct_response_content_length_over_reported() {
|
||||
const content = "This is the expected response.";
|
||||
const blob = new Blob([content]);
|
||||
const response = new Response(blob, {
|
||||
headers: new Headers({
|
||||
"Content-Length": 2 * blob.size + 20,
|
||||
}),
|
||||
});
|
||||
|
||||
const responseArray = await readResponse(response, data => {
|
||||
data;
|
||||
});
|
||||
|
||||
const responseContent = new TextDecoder().decode(responseArray);
|
||||
|
||||
Assert.equal(content, responseContent, "The response content should match.");
|
||||
});
|
||||
|
||||
/**
|
||||
* Test that we can retrieve and the callback provide correct information
|
||||
*/
|
||||
add_task(async function test_correct_response_callback_correct() {
|
||||
const contents = ["Carrot", "Broccoli", "Tomato", "Spinach"];
|
||||
|
||||
let contentSizes = [];
|
||||
|
||||
let totalSize = 0;
|
||||
|
||||
for (const value of contents) {
|
||||
contentSizes.push(new Blob([value]).size);
|
||||
|
||||
totalSize += contentSizes[contentSizes.length - 1];
|
||||
}
|
||||
|
||||
const numChunks = contents.length;
|
||||
|
||||
let encoder = new TextEncoder();
|
||||
|
||||
// const stream = ReadableStream.from(contents);
|
||||
|
||||
let streamId = -1;
|
||||
|
||||
const stream = new ReadableStream({
|
||||
pull(controller) {
|
||||
streamId += 1;
|
||||
|
||||
if (streamId < numChunks) {
|
||||
controller.enqueue(encoder.encode(contents[streamId]));
|
||||
} else {
|
||||
controller.close();
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
const response = new Response(stream, {
|
||||
headers: new Headers({
|
||||
"Content-Length": totalSize,
|
||||
}),
|
||||
});
|
||||
|
||||
let chunkId = -1;
|
||||
let expectedTotalLoaded = 0;
|
||||
|
||||
const responseArray = await readResponse(response, data => {
|
||||
chunkId += 1;
|
||||
// The callback is called on time with no data loaded and just the total
|
||||
if (chunkId == 0) {
|
||||
Assert.deepEqual(
|
||||
{
|
||||
total: data.total,
|
||||
currentLoaded: data.currentLoaded,
|
||||
totalLoaded: data.totalLoaded,
|
||||
},
|
||||
{
|
||||
total: totalSize,
|
||||
currentLoaded: 0,
|
||||
totalLoaded: 0,
|
||||
},
|
||||
"The callback should be called on time with an estimate of the total size and no data read. "
|
||||
);
|
||||
} else {
|
||||
Assert.less(
|
||||
chunkId - 1,
|
||||
numChunks,
|
||||
"The number of times the callback is called should be lower than the number of chunks"
|
||||
);
|
||||
|
||||
expectedTotalLoaded += contentSizes[chunkId - 1];
|
||||
|
||||
Assert.deepEqual(
|
||||
{
|
||||
total: data.total,
|
||||
currentLoaded: data.currentLoaded,
|
||||
totalLoaded: data.totalLoaded,
|
||||
},
|
||||
{
|
||||
total: totalSize,
|
||||
currentLoaded: contentSizes[chunkId - 1],
|
||||
totalLoaded: expectedTotalLoaded,
|
||||
},
|
||||
"The reported value by the callback should match the correct values"
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
Assert.equal(
|
||||
chunkId,
|
||||
numChunks,
|
||||
"The callback should be called exactly as many times as the number of chunks."
|
||||
);
|
||||
|
||||
const responseContent = new TextDecoder().decode(
|
||||
responseArray.buffer.slice(
|
||||
responseArray.byteOffset,
|
||||
responseArray.byteLength + responseArray.byteOffset
|
||||
)
|
||||
);
|
||||
|
||||
Assert.equal(
|
||||
contents.join(""),
|
||||
responseContent,
|
||||
"The response content should match."
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* Test that multi-aggregator only call the callback for the provided types.
|
||||
*/
|
||||
add_task(async function test_multi_aggregator_watchtypes() {
|
||||
let numCalls = 0;
|
||||
let aggregator = new MultiProgressAggregator({
|
||||
progressCallback: _ => {
|
||||
numCalls += 1;
|
||||
},
|
||||
watchedTypes: ["t1"],
|
||||
});
|
||||
|
||||
aggregator.aggregateCallback(
|
||||
new ProgressAndStatusCallbackParams({
|
||||
type: "download",
|
||||
})
|
||||
);
|
||||
|
||||
Assert.equal(numCalls, 0);
|
||||
|
||||
aggregator.aggregateCallback(
|
||||
new ProgressAndStatusCallbackParams({
|
||||
type: "t1",
|
||||
})
|
||||
);
|
||||
|
||||
Assert.equal(numCalls, 1);
|
||||
});
|
||||
|
||||
/**
|
||||
* Test that multi-aggregator aggregate correctly.
|
||||
*/
|
||||
add_task(async function test_multi_aggregator() {
|
||||
// Ids for all available tasks. Should be unique per element.
|
||||
const taskIds = ["A", "B", "C", "D", "E", "F"];
|
||||
|
||||
// The type for each available tasks.
|
||||
const taskTypes = ["t1", "t1", "t2", "t2", "t3", "t3"];
|
||||
|
||||
// The total size available for each task
|
||||
const taskSizes = [5, 11, 13, 17, 19, 23];
|
||||
|
||||
// The chunk sizes. The sum for indices with same chunk task index (according to chunkTaskIndex)
|
||||
// should be equal to the corresponding size in taskSizes
|
||||
const chunkSizes = [2, 3, 5, 6, 11, 7, 12, 6, 8, 9, 9, 10];
|
||||
|
||||
// Task index for each chunk. Index in the array taskIds. Order was chosen so that we can simulate
|
||||
// overlaps in tasks.
|
||||
const chunkTaskIndex = [0, 0, 1, 2, 5, 2, 5, 1, 3, 4, 3, 4];
|
||||
|
||||
// Indicating how much has been loaded for the task so far.
|
||||
const chunkTaskLoaded = [2, 5, 5, 6, 11, 13, 23, 11, 8, 9, 17, 19];
|
||||
|
||||
// Whether the
|
||||
const chunkIsFinal = [0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1];
|
||||
|
||||
let numDone = 0;
|
||||
|
||||
let currentData = null;
|
||||
|
||||
let expectedTotalToLoad = 0;
|
||||
|
||||
let numCalls = 0;
|
||||
|
||||
let expectedNumCalls = 0;
|
||||
|
||||
let expectedTotalLoaded = 0;
|
||||
|
||||
const aggregator = new MultiProgressAggregator({
|
||||
progressCallback: data => {
|
||||
currentData = data;
|
||||
|
||||
numCalls += 1;
|
||||
|
||||
if (data.statusText == ProgressStatusText.DONE) {
|
||||
numDone += 1;
|
||||
}
|
||||
},
|
||||
watchedTypes: ["t1", "t2", "t3"],
|
||||
});
|
||||
|
||||
// Initiate and advertise the size for each task
|
||||
for (const i in taskTypes) {
|
||||
currentData = null;
|
||||
expectedNumCalls += 1;
|
||||
aggregator.aggregateCallback(
|
||||
new ProgressAndStatusCallbackParams({
|
||||
type: taskTypes[i],
|
||||
statusText: ProgressStatusText.INITIATE,
|
||||
id: taskIds[i],
|
||||
total: taskSizes[i],
|
||||
})
|
||||
);
|
||||
|
||||
Assert.ok(currentData, "Received data should be defined");
|
||||
Assert.deepEqual(
|
||||
{
|
||||
statusText: currentData?.statusText,
|
||||
type: currentData?.type,
|
||||
id: currentData?.id,
|
||||
numDone,
|
||||
numCalls,
|
||||
},
|
||||
{
|
||||
statusText: ProgressStatusText.INITIATE,
|
||||
type: taskTypes[i],
|
||||
id: taskIds[i],
|
||||
numDone: 0,
|
||||
numCalls: expectedNumCalls,
|
||||
},
|
||||
"Data received after initiate should be correct"
|
||||
);
|
||||
currentData = null;
|
||||
expectedNumCalls += 1;
|
||||
aggregator.aggregateCallback(
|
||||
new ProgressAndStatusCallbackParams({
|
||||
type: taskTypes[i],
|
||||
statusText: ProgressStatusText.SIZE_ESTIMATE,
|
||||
id: taskIds[i],
|
||||
total: taskSizes[i],
|
||||
})
|
||||
);
|
||||
Assert.ok(currentData, "Received data should be defined");
|
||||
|
||||
expectedTotalToLoad += taskSizes[i];
|
||||
|
||||
Assert.deepEqual(
|
||||
{
|
||||
numDone,
|
||||
numCalls,
|
||||
total: currentData.total,
|
||||
},
|
||||
{
|
||||
numDone: 0,
|
||||
total: expectedTotalToLoad,
|
||||
numCalls: expectedNumCalls,
|
||||
},
|
||||
"Data received after size estimate should be correct."
|
||||
);
|
||||
}
|
||||
|
||||
// Send progress status for each chunk.
|
||||
for (const chunkIndex in chunkTaskIndex) {
|
||||
let taskIndex = chunkTaskIndex[chunkIndex];
|
||||
currentData = null;
|
||||
expectedNumCalls += 1;
|
||||
expectedTotalLoaded += chunkSizes[chunkIndex];
|
||||
aggregator.aggregateCallback(
|
||||
new ProgressAndStatusCallbackParams({
|
||||
type: taskTypes[taskIndex],
|
||||
statusText: ProgressStatusText.IN_PROGRESS,
|
||||
id: taskIds[taskIndex],
|
||||
total: taskSizes[taskIndex],
|
||||
currentLoaded: chunkSizes[chunkIndex],
|
||||
totalLoaded: chunkTaskLoaded[chunkIndex],
|
||||
})
|
||||
);
|
||||
|
||||
Assert.ok(currentData, "Received data should be defined");
|
||||
|
||||
Assert.deepEqual(
|
||||
{
|
||||
numDone,
|
||||
numCalls,
|
||||
total: currentData?.total,
|
||||
currentLoaded: currentData?.currentLoaded,
|
||||
totalLoaded: currentData?.totalLoaded,
|
||||
},
|
||||
{
|
||||
numDone: 0,
|
||||
numCalls: expectedNumCalls,
|
||||
total: expectedTotalToLoad,
|
||||
currentLoaded: chunkSizes[chunkIndex],
|
||||
totalLoaded: expectedTotalLoaded,
|
||||
},
|
||||
"Data received after in progress should be correct"
|
||||
);
|
||||
|
||||
// Notify of task is done
|
||||
if (chunkIsFinal[chunkIndex]) {
|
||||
currentData = null;
|
||||
expectedNumCalls += 1;
|
||||
aggregator.aggregateCallback(
|
||||
new ProgressAndStatusCallbackParams({
|
||||
type: taskTypes[taskIndex],
|
||||
statusText: ProgressStatusText.DONE,
|
||||
id: taskIds[taskIndex],
|
||||
total: taskSizes[chunkIndex],
|
||||
})
|
||||
);
|
||||
|
||||
Assert.ok(currentData, "Received data should be defined");
|
||||
|
||||
Assert.deepEqual(
|
||||
{ total: currentData.total, numCalls },
|
||||
{ total: expectedTotalToLoad, numCalls: expectedNumCalls },
|
||||
"Data received after completed tasks should be correct"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Assert.equal(numDone, 1, "Done status should be received");
|
||||
});
|
||||
|
@ -19,12 +19,20 @@ const lazy = {};
|
||||
|
||||
ChromeUtils.defineESModuleGetters(lazy, {
|
||||
createEngine: "chrome://global/content/ml/EngineProcess.sys.mjs",
|
||||
MultiProgressAggregator: "chrome://global/content/ml/Utils.sys.mjs",
|
||||
NimbusFeatures: "resource://nimbus/ExperimentAPI.sys.mjs",
|
||||
PdfJsTelemetry: "resource://pdf.js/PdfJsTelemetry.sys.mjs",
|
||||
PrivateBrowsingUtils: "resource://gre/modules/PrivateBrowsingUtils.sys.mjs",
|
||||
SetClipboardSearchString: "resource://gre/modules/Finder.sys.mjs",
|
||||
});
|
||||
|
||||
ChromeUtils.defineLazyGetter(lazy, "console", () => {
|
||||
return console.createInstance({
|
||||
maxLogLevelPref: "browser.ml.logLevel",
|
||||
prefix: "PDF_JS",
|
||||
});
|
||||
});
|
||||
|
||||
var Svc = {};
|
||||
XPCOMUtils.defineLazyServiceGetter(
|
||||
Svc,
|
||||
@ -95,6 +103,10 @@ export class PdfjsParent extends JSWindowActorParent {
|
||||
lazy.PdfJsTelemetry.report(aMsg.data);
|
||||
}
|
||||
|
||||
_initProgressBar(progressData) {
|
||||
lazy.console.debug("progess_from_pdfjs", progressData);
|
||||
}
|
||||
|
||||
async _mlGuess({ data: { service, request } }) {
|
||||
if (!lazy.createEngine) {
|
||||
return null;
|
||||
@ -102,10 +114,20 @@ export class PdfjsParent extends JSWindowActorParent {
|
||||
if (service !== "image-to-text") {
|
||||
throw new Error("Invalid service");
|
||||
}
|
||||
// We are using the internal task name prefixed with moz-
|
||||
const engine = await lazy.createEngine({
|
||||
taskName: "moz-image-to-text",
|
||||
|
||||
let aggregator = new lazy.MultiProgressAggregator({
|
||||
progressCallback: this._initProgressBar,
|
||||
});
|
||||
|
||||
const callback = aggregator.aggregateCallback.bind(aggregator);
|
||||
|
||||
// We are using the internal task name prefixed with moz-
|
||||
const engine = await lazy.createEngine(
|
||||
{
|
||||
taskName: "moz-image-to-text",
|
||||
},
|
||||
callback
|
||||
);
|
||||
return engine.run(request);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user