mirror of
https://github.com/mozilla/gecko-dev.git
synced 2024-12-11 16:32:59 +00:00
Bug 1925471 - Check for values in PipelineOptions r=vazish,urlbar-reviewers,adw, a=dmeehan
Differential Revision: https://phabricator.services.mozilla.com/D226224
This commit is contained in:
parent
e5bebc56cd
commit
f679596bf1
@ -19,16 +19,16 @@ ChromeUtils.defineESModuleGetters(lazy, {
|
||||
*/
|
||||
const INTENT_OPTIONS = {
|
||||
taskName: "text-classification",
|
||||
modelId: "mozilla/mobilebert-uncased-finetuned-LoRA-intent-classifier",
|
||||
modelRevision: "v0.1.0",
|
||||
dtype: "q8",
|
||||
featureId: "suggest-intent-classification",
|
||||
engineId: "ml-suggest-intent",
|
||||
timeoutMS: 36000000,
|
||||
};
|
||||
|
||||
const NER_OPTIONS = {
|
||||
taskName: "token-classification",
|
||||
modelId: "mozilla/distilbert-uncased-NER-LoRA",
|
||||
modelRevision: "v0.1.1",
|
||||
dtype: "q8",
|
||||
featureId: "suggest-NER",
|
||||
engineId: "ml-suggest-ner",
|
||||
timeoutMS: 36000000,
|
||||
};
|
||||
|
||||
// List of prepositions used in subject cleaning.
|
||||
@ -115,20 +115,8 @@ class _MLSuggest {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper method to generate a unique key for model engines.
|
||||
*
|
||||
* @param {object} options
|
||||
* The options object containing taskName and modelId.
|
||||
* @returns {string}
|
||||
* The key for the model engine.
|
||||
*/
|
||||
#getmodelEnginesKey(options) {
|
||||
return `${options.taskName}-${options.modelId}`;
|
||||
}
|
||||
|
||||
async #initializeModelEngine(options) {
|
||||
const engineId = this.#getmodelEnginesKey(options);
|
||||
const engineId = options.engineId;
|
||||
|
||||
// uses cache if engine was used
|
||||
if (this.#modelEngines[engineId]) {
|
||||
@ -153,9 +141,7 @@ class _MLSuggest {
|
||||
* The predicted intent label or null if the model is not initialized.
|
||||
*/
|
||||
async _findIntent(query, options = {}) {
|
||||
const engineIntentClassifier =
|
||||
this.#modelEngines[this.#getmodelEnginesKey(INTENT_OPTIONS)];
|
||||
|
||||
const engineIntentClassifier = this.#modelEngines[INTENT_OPTIONS.engineId];
|
||||
if (!engineIntentClassifier) {
|
||||
return null;
|
||||
}
|
||||
@ -180,7 +166,7 @@ class _MLSuggest {
|
||||
* The NER results or null if the model is not initialized.
|
||||
*/
|
||||
async _findNER(query, options = {}) {
|
||||
const engineNER = this.#modelEngines[this.#getmodelEnginesKey(NER_OPTIONS)];
|
||||
const engineNER = this.#modelEngines[NER_OPTIONS.engineId];
|
||||
return engineNER?.run({ args: [query], options });
|
||||
}
|
||||
|
||||
|
@ -580,7 +580,7 @@ async function runInference() {
|
||||
tokenizerId: modelId,
|
||||
processorId: modelId,
|
||||
taskName,
|
||||
engineId: "about:inference",
|
||||
engineId: "about-inference",
|
||||
modelHubRootUrl,
|
||||
modelHubUrlTemplate,
|
||||
device,
|
||||
|
@ -16,6 +16,42 @@ ChromeUtils.defineESModuleGetters(
|
||||
* @typedef {import("../content/Utils.sys.mjs").ProgressAndStatusCallbackParams} ProgressAndStatusCallbackParams
|
||||
*/
|
||||
|
||||
/**
|
||||
* Lists Firefox internal features
|
||||
*/
|
||||
const FEATURES = [
|
||||
"pdfjs-alt-text", // see toolkit/components/pdfjs/content/PdfjsParent.sys.mjs
|
||||
"suggest-intent-classification", // see browser/components/urlbar/private/MLSuggest.sys.mjs
|
||||
"suggest-NER", // see browser/components/urlbar/private/MLSuggest.sys.mjs
|
||||
];
|
||||
|
||||
/**
|
||||
* Custom error class for validation errors.
|
||||
*
|
||||
* This error is thrown when a field fails validation, providing additional context such as
|
||||
* the name of the field that caused the error.
|
||||
*
|
||||
* @augments Error
|
||||
*/
|
||||
class PipelineOptionsValidationError extends Error {
|
||||
/**
|
||||
* Create a PipelineOptionsValidationError.
|
||||
*
|
||||
* @param {string} field - The name of the field that caused the validation error.
|
||||
* @param {any} value - The invalid value provided for the field.
|
||||
* @param {string} [tips=null] - Optional tips or suggestions for valid values.
|
||||
*/
|
||||
constructor(field, value, tips = null) {
|
||||
const baseMessage = `Invalid value "${value}" for field "${field}".`;
|
||||
const message = tips ? `${baseMessage} ${tips}` : baseMessage;
|
||||
super(message);
|
||||
|
||||
this.name = this.constructor.name;
|
||||
this.field = field;
|
||||
this.value = value;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Enum for execution priority.
|
||||
*
|
||||
@ -29,16 +65,66 @@ ChromeUtils.defineESModuleGetters(
|
||||
* @enum {string}
|
||||
*/
|
||||
export const ExecutionPriority = {
|
||||
/** High priority, needed for Firefox */
|
||||
HIGH: "HIGH",
|
||||
|
||||
/** Normal priority, default */
|
||||
NORMAL: "NORMAL",
|
||||
|
||||
/** Low priority, used for 3rd party calls */
|
||||
LOW: "LOW",
|
||||
};
|
||||
|
||||
/**
|
||||
* Enum for quantization levels.
|
||||
*
|
||||
* Defines the quantization level of the task:
|
||||
*
|
||||
* - 'fp32': Full precision 32-bit floating point (`''`)
|
||||
* - 'fp16': Half precision 16-bit floating point (`'_fp16'`)
|
||||
* - 'q8': Quantized 8-bit (`'_quantized'`)
|
||||
* - 'int8': Integer 8-bit quantization (`'_int8'`)
|
||||
* - 'uint8': Unsigned integer 8-bit quantization (`'_uint8'`)
|
||||
* - 'q4': Quantized 4-bit (`'_q4'`)
|
||||
* - 'bnb4': Binary/Boolean 4-bit quantization (`'_bnb4'`)
|
||||
* - 'q4f16': 16-bit floating point model with 4-bit block weight quantization (`'_q4f16'`)
|
||||
*
|
||||
* @readonly
|
||||
* @enum {string}
|
||||
*/
|
||||
export const QuantizationLevel = {
|
||||
FP32: "fp32",
|
||||
FP16: "fp16",
|
||||
Q8: "q8",
|
||||
INT8: "int8",
|
||||
UINT8: "uint8",
|
||||
Q4: "q4",
|
||||
BNB4: "bnb4",
|
||||
Q4F16: "q4f16",
|
||||
};
|
||||
|
||||
/**
|
||||
* Enum for the device used for inference.
|
||||
*
|
||||
* @readonly
|
||||
* @enum {string}
|
||||
*/
|
||||
export const InferenceDevice = {
|
||||
GPU: "gpu",
|
||||
WASM: "wasm",
|
||||
};
|
||||
|
||||
/**
|
||||
* Enum for log levels.
|
||||
*
|
||||
* @readonly
|
||||
* @enum {string}
|
||||
*/
|
||||
export const LogLevel = {
|
||||
TRACE: "Trace",
|
||||
INFO: "Info",
|
||||
DEBUG: "Debug",
|
||||
WARN: "Warn",
|
||||
ERROR: "Error",
|
||||
CRITICAL: "Critical",
|
||||
ALL: "All",
|
||||
};
|
||||
|
||||
/**
|
||||
* @typedef {import("../../translations/actors/TranslationsEngineParent.sys.mjs").TranslationsEngineParent} TranslationsEngineParent
|
||||
*/
|
||||
@ -138,7 +224,7 @@ export class PipelineOptions {
|
||||
/**
|
||||
* The log level used in the worker
|
||||
*
|
||||
* @type {?string}
|
||||
* @type {LogLevel | null}
|
||||
*/
|
||||
logLevel = null;
|
||||
|
||||
@ -152,24 +238,14 @@ export class PipelineOptions {
|
||||
/**
|
||||
* Device used for inference
|
||||
*
|
||||
* @type {"gpu" | "wasm" | null}
|
||||
* @type {InferenceDevice | null}
|
||||
*/
|
||||
device = null;
|
||||
|
||||
/**
|
||||
* Quantization level
|
||||
*
|
||||
* - name : description (onnx file suffix)
|
||||
* - 'fp32': Full precision 32-bit floating point (`''`)
|
||||
* - 'fp16': Half precision 16-bit floating point (`'_fp16'`)
|
||||
* - 'q8': Quantized 8-bit (`'_quantized'`)
|
||||
* - 'int8': Integer 8-bit quantization (`'_int8'`)
|
||||
* - 'uint8': Unsigned integer 8-bit quantization (`'_uint8'`)
|
||||
* - 'q4': Quantized 4-bit (`'_q4'`)
|
||||
* - 'bnb4': Binary/Boolean 4-bit quantization (`'_bnb4'`)
|
||||
* - 'q4f16': 16-bit floating point model with 4-bit block weight quantization (`'_q4f16'`)
|
||||
*
|
||||
* @type {"fp32" | "fp16" | "q8" | "int8" | "uint8" | "q4" | "bnb4" | "q4f16" | null}
|
||||
* @type {QuantizationLevel | null}
|
||||
*/
|
||||
dtype = null;
|
||||
|
||||
@ -187,7 +263,7 @@ export class PipelineOptions {
|
||||
*
|
||||
* @type {ExecutionPriority}
|
||||
*/
|
||||
executionPriority = ExecutionPriority.NORMAL;
|
||||
executionPriority = null;
|
||||
|
||||
/**
|
||||
* Create a PipelineOptions instance.
|
||||
@ -198,6 +274,122 @@ export class PipelineOptions {
|
||||
this.updateOptions(options);
|
||||
}
|
||||
|
||||
/**
|
||||
* Private method to validate enum fields.
|
||||
*
|
||||
* @param {string} field - The field being validated (e.g., 'dtype', 'device', 'executionPriority').
|
||||
* @param {*} value - The value being checked against the enum.
|
||||
* @throws {Error} Throws an error if the value is not valid.
|
||||
* @private
|
||||
*/
|
||||
#validateEnum(field, value) {
|
||||
const enums = {
|
||||
dtype: QuantizationLevel,
|
||||
device: InferenceDevice,
|
||||
executionPriority: ExecutionPriority,
|
||||
logLevel: LogLevel,
|
||||
};
|
||||
// Check if the value is part of the enum or null
|
||||
if (!Object.values(enums[field]).includes(value)) {
|
||||
throw new PipelineOptionsValidationError(field, value);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates the taskName field, ensuring it contains only alphanumeric characters, underscores, and dashes.
|
||||
* Slashes are not allowed in the taskName.
|
||||
*
|
||||
* @param {string} field - The name of the field being validated (e.g., taskName).
|
||||
* @param {string} value - The value of the field to validate.
|
||||
* @throws {Error} Throws an error if the taskName contains invalid characters.
|
||||
* @private
|
||||
*/
|
||||
#validateTaskName(field, value) {
|
||||
// Define a regular expression to verify taskName pattern (alphanumeric, underscores, and dashes, no slashes)
|
||||
const validTaskNamePattern = /^[a-zA-Z0-9_\-]+$/;
|
||||
|
||||
// Check if the value matches the pattern
|
||||
if (!validTaskNamePattern.test(value)) {
|
||||
throw new PipelineOptionsValidationError(
|
||||
field,
|
||||
value,
|
||||
"Should contain only alphanumeric characters, underscores, or dashes."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates a taskName or ID.
|
||||
*
|
||||
* The ID can optionally be in the form `organization/name`, where both `organization` and `name`
|
||||
* follow the `taskName` pattern (alphanumeric characters, underscores, and dashes).
|
||||
*
|
||||
* Throws an exception if the name or ID is invalid.
|
||||
*
|
||||
* @param {string} field - The name of the field being validated (e.g., taskName, engineId).
|
||||
* @param {string} value - The value of the field to validate.
|
||||
* @throws {PipelineOptionsValidationError} Throws a validation error if the ID is invalid.
|
||||
* @private
|
||||
*/
|
||||
#validateId(field, value) {
|
||||
// Define a regular expression to match the optional organization and required name
|
||||
// `organization/` part is optional, and both parts should follow the taskName pattern.
|
||||
const validPattern = /^(?:[a-zA-Z0-9_\-]+\/)?[a-zA-Z0-9_\-]+$/;
|
||||
|
||||
// Check if the value matches the pattern
|
||||
if (!validPattern.test(value)) {
|
||||
throw new PipelineOptionsValidationError(
|
||||
field,
|
||||
value,
|
||||
"Should follow the format 'organization/name' or 'name', where both parts contain only alphanumeric characters, underscores, or dashes."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generic method to validate an integer within a specified range.
|
||||
*
|
||||
* @param {string} field - The name of the field being validated.
|
||||
* @param {number} value - The integer value to validate.
|
||||
* @param {number} min - The minimum allowed value (inclusive).
|
||||
* @param {number} max - The maximum allowed value (inclusive).
|
||||
* @throws {Error} Throws an error if the value is not a valid integer within the range.
|
||||
* @private
|
||||
*/
|
||||
#validateIntegerRange(field, value, min, max) {
|
||||
if (!Number.isInteger(value) || value < min || value > max) {
|
||||
throw new PipelineOptionsValidationError(
|
||||
field,
|
||||
value,
|
||||
`Should be an integer between ${min} and ${max}.`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates the revision field.
|
||||
* The revision can be `main` or a version following a pattern like `v1.0.0`, `1.0.0-beta1`, `1.0.0.alpha2`, `1.0.0.rc1`, etc.
|
||||
*
|
||||
* @param {string} field - The name of the field being validated (e.g., modelRevision, tokenizerRevision).
|
||||
* @param {string} value - The value of the revision field to validate.
|
||||
* @throws {Error} Throws an error if the revision does not follow the expected pattern.
|
||||
* @private
|
||||
*/
|
||||
#validateRevision(field, value) {
|
||||
// Regular expression to match `main` or a version like `v1`, `v1.0.0`, `1.0.0-alpha1`, `1.0.0.alpha2`, `1.0.0.rc1`, etc.
|
||||
const revisionPattern =
|
||||
/^v?(\d+(\.\d+){0,2})([-\.](alpha\d*|beta\d*|pre\d*|post\d*|rc\d*))?$|^main$/;
|
||||
|
||||
// Check if the value matches the pattern
|
||||
if (!revisionPattern.test(value)) {
|
||||
throw new PipelineOptionsValidationError(
|
||||
field,
|
||||
value,
|
||||
`Should be 'main' or follow a versioning pattern like 'v1.0.0', '1.0.0-beta1', '1.0.0.alpha2', '1.0.0.rc1', etc.`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates multiple options at once.
|
||||
*
|
||||
@ -238,13 +430,49 @@ export class PipelineOptions {
|
||||
if (!optionsKeys.includes(key) || options[key] == null) {
|
||||
return;
|
||||
}
|
||||
if (key === "featureId" && !FEATURES.includes(options[key])) {
|
||||
throw new PipelineOptionsValidationError(
|
||||
key,
|
||||
options[key],
|
||||
`Should be one of ${FEATURES.join(", ")}`
|
||||
);
|
||||
}
|
||||
// Validating values.
|
||||
if (["taskName", "engineId"].includes(key)) {
|
||||
this.#validateTaskName(key, options[key]);
|
||||
}
|
||||
|
||||
if (["modelId", "tokenizerId", "processorId"].includes(key)) {
|
||||
this.#validateId(key, options[key]);
|
||||
}
|
||||
|
||||
if (
|
||||
["modelRevision", "tokenizerRevision", "processorRevision"].includes(
|
||||
key
|
||||
)
|
||||
) {
|
||||
this.#validateRevision(key, options[key]);
|
||||
}
|
||||
|
||||
if (["dtype", "device", "executionPriority", "logLevel"].includes(key)) {
|
||||
this.#validateEnum(key, options[key]);
|
||||
}
|
||||
|
||||
if (key === "numThreads") {
|
||||
this.#validateIntegerRange(key, options[key], 0, 100);
|
||||
}
|
||||
|
||||
if (key === "timeoutMS") {
|
||||
this.#validateIntegerRange(key, options[key], 0, 36000000);
|
||||
}
|
||||
|
||||
this[key] = options[key];
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an object containing all current options.
|
||||
|
||||
*
|
||||
* @returns {object} An object with the current options.
|
||||
*/
|
||||
getOptions() {
|
||||
|
@ -101,7 +101,7 @@ add_task(async function test_ml_engine_pick_feature_id() {
|
||||
id: "74a71cfd-1734-44e6-85c0-69cf3e874138",
|
||||
},
|
||||
{
|
||||
featureId: "myCoolFeature",
|
||||
featureId: "pdfjs-alt-text",
|
||||
taskName: "moz-echo",
|
||||
modelId: "mozilla/distilvit",
|
||||
processorId: "mozilla/distilvit",
|
||||
@ -118,7 +118,7 @@ add_task(async function test_ml_engine_pick_feature_id() {
|
||||
|
||||
info("Get the engine");
|
||||
const engineInstance = await createEngine({
|
||||
featureId: "myCoolFeature",
|
||||
featureId: "pdfjs-alt-text",
|
||||
taskName: "moz-echo",
|
||||
});
|
||||
|
||||
@ -284,35 +284,6 @@ add_task(async function test_pref_is_off() {
|
||||
});
|
||||
});
|
||||
|
||||
/**
|
||||
* Tests that we verify the task name is valid
|
||||
*/
|
||||
add_task(async function test_invalid_task_name() {
|
||||
const { cleanup, remoteClients } = await setup();
|
||||
|
||||
const options = new PipelineOptions({ taskName: "inv#alid" });
|
||||
const mlEngineParent = await EngineProcess.getMLEngineParent();
|
||||
const engineInstance = await mlEngineParent.getEngine(options);
|
||||
|
||||
let error;
|
||||
|
||||
try {
|
||||
const res = engineInstance.run({ data: "This gets echoed." });
|
||||
await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1);
|
||||
await res;
|
||||
} catch (e) {
|
||||
error = e;
|
||||
}
|
||||
is(
|
||||
error?.message,
|
||||
"Invalid task name. Task name should contain only alphanumeric characters and underscores/dashes.",
|
||||
"The error is correctly surfaced."
|
||||
);
|
||||
|
||||
await EngineProcess.destroyMLEngine();
|
||||
await cleanup();
|
||||
});
|
||||
|
||||
/**
|
||||
* Tests the generic pipeline API
|
||||
*/
|
||||
@ -714,7 +685,7 @@ add_task(async function test_ml_engine_get_status() {
|
||||
device: null,
|
||||
dtype: "q8",
|
||||
numThreads: null,
|
||||
executionPriority: "NORMAL",
|
||||
executionPriority: null,
|
||||
},
|
||||
engineId: "default-engine",
|
||||
},
|
||||
@ -760,3 +731,233 @@ add_task(async function test_ml_engine_not_enough_memory() {
|
||||
await EngineProcess.destroyMLEngine();
|
||||
await cleanup();
|
||||
});
|
||||
|
||||
/**
|
||||
* Helper function to create a basic set of valid options
|
||||
*/
|
||||
function getValidOptions(overrides = {}) {
|
||||
return Object.assign(
|
||||
{
|
||||
engineId: "validEngine1",
|
||||
featureId: "pdfjs-alt-text",
|
||||
taskName: "valid_task",
|
||||
modelHubRootUrl: "https://example.com",
|
||||
modelHubUrlTemplate: "https://example.com/{modelId}",
|
||||
timeoutMS: 5000,
|
||||
modelId: "validModel",
|
||||
modelRevision: "v1",
|
||||
tokenizerId: "validTokenizer",
|
||||
tokenizerRevision: "v1",
|
||||
processorId: "validProcessor",
|
||||
processorRevision: "v1",
|
||||
logLevel: null,
|
||||
runtimeFilename: "runtime.wasm",
|
||||
device: InferenceDevice.GPU,
|
||||
numThreads: 4,
|
||||
executionPriority: ExecutionPriority.NORMAL,
|
||||
},
|
||||
overrides
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* A collection of test cases for invalid and valid values.
|
||||
*/
|
||||
const commonInvalidCases = [
|
||||
{ description: "Invalid value (special characters)", value: "org1/my!value" },
|
||||
{
|
||||
description: "Invalid value (special characters in organization)",
|
||||
value: "org@1/my-value",
|
||||
},
|
||||
{ description: "Invalid value (missing name part)", value: "org1/" },
|
||||
{
|
||||
description: "Invalid value (invalid characters in name)",
|
||||
value: "my$value",
|
||||
},
|
||||
];
|
||||
|
||||
const commonValidCases = [
|
||||
{ description: "Valid organization/name", value: "org1/my-value" },
|
||||
{ description: "Valid name only", value: "my-value" },
|
||||
{
|
||||
description: "Valid name with underscores and dashes",
|
||||
value: "my_value-123",
|
||||
},
|
||||
{
|
||||
description: "Valid organization with underscores and dashes",
|
||||
value: "org_123/my-value",
|
||||
},
|
||||
];
|
||||
|
||||
const pipelineOptionsCases = [
|
||||
// Invalid cases for various fields
|
||||
...commonInvalidCases.map(test => ({
|
||||
description: `Invalid processorId (${test.description})`,
|
||||
options: { processorId: test.value },
|
||||
expectedError: /Invalid value/,
|
||||
})),
|
||||
...commonInvalidCases.map(test => ({
|
||||
description: `Invalid tokenizerId (${test.description})`,
|
||||
options: { tokenizerId: test.value },
|
||||
expectedError: /Invalid value/,
|
||||
})),
|
||||
...commonInvalidCases.map(test => ({
|
||||
description: `Invalid modelId (${test.description})`,
|
||||
options: { modelId: test.value },
|
||||
expectedError: /Invalid value/,
|
||||
})),
|
||||
|
||||
// Valid cases for various fields
|
||||
...commonValidCases.map(test => ({
|
||||
description: `Valid processorId (${test.description})`,
|
||||
options: { processorId: test.value },
|
||||
expected: { processorId: test.value },
|
||||
})),
|
||||
...commonValidCases.map(test => ({
|
||||
description: `Valid tokenizerId (${test.description})`,
|
||||
options: { tokenizerId: test.value },
|
||||
expected: { tokenizerId: test.value },
|
||||
})),
|
||||
...commonValidCases.map(test => ({
|
||||
description: `Valid modelId (${test.description})`,
|
||||
options: { modelId: test.value },
|
||||
expected: { modelId: test.value },
|
||||
})),
|
||||
|
||||
// Invalid dtype, device, executionPriority, featureId and logLevel cases
|
||||
{
|
||||
description: "Invalid featureId",
|
||||
options: { featureId: "unknown" },
|
||||
expectedError: /Invalid value/,
|
||||
},
|
||||
{
|
||||
description: "Invalid dtype",
|
||||
options: { dtype: "invalid_dtype" },
|
||||
expectedError: /Invalid value/,
|
||||
},
|
||||
{
|
||||
description: "Invalid device",
|
||||
options: { device: "invalid_device" },
|
||||
expectedError: /Invalid value/,
|
||||
},
|
||||
{
|
||||
description: "Invalid executionPriority",
|
||||
options: { executionPriority: "invalid_priority" },
|
||||
expectedError: /Invalid value/,
|
||||
},
|
||||
{
|
||||
description: "Invalid logLevel",
|
||||
options: { logLevel: "invalid_log_level" },
|
||||
expectedError: /Invalid value/,
|
||||
},
|
||||
|
||||
// Valid cases for dtype, device, executionPriority, and logLevel
|
||||
{
|
||||
description: "Valid dtype",
|
||||
options: { dtype: QuantizationLevel.FP16 },
|
||||
expected: { dtype: QuantizationLevel.FP16 },
|
||||
},
|
||||
{
|
||||
description: "Valid device",
|
||||
options: { device: InferenceDevice.WASM },
|
||||
expected: { device: InferenceDevice.WASM },
|
||||
},
|
||||
{
|
||||
description: "Valid executionPriority",
|
||||
options: { executionPriority: ExecutionPriority.HIGH },
|
||||
expected: { executionPriority: ExecutionPriority.HIGH },
|
||||
},
|
||||
{
|
||||
description: "Valid logLevel (Info)",
|
||||
options: { logLevel: LogLevel.INFO },
|
||||
expected: { logLevel: LogLevel.INFO },
|
||||
},
|
||||
{
|
||||
description: "Valid logLevel (Critical)",
|
||||
options: { logLevel: LogLevel.CRITICAL },
|
||||
expected: { logLevel: LogLevel.CRITICAL },
|
||||
},
|
||||
{
|
||||
description: "Valid logLevel (All)",
|
||||
options: { logLevel: LogLevel.ALL },
|
||||
expected: { logLevel: LogLevel.ALL },
|
||||
},
|
||||
|
||||
// Invalid revision cases
|
||||
{
|
||||
description: "Invalid revision (random string)",
|
||||
options: { modelRevision: "invalid_revision" },
|
||||
expectedError: /Invalid value/,
|
||||
},
|
||||
{
|
||||
description: "Invalid revision (too many version numbers)",
|
||||
options: { tokenizerRevision: "v1.0.3.4.5" },
|
||||
expectedError: /Invalid value/,
|
||||
},
|
||||
{
|
||||
description: "Invalid revision (unknown suffix)",
|
||||
options: { processorRevision: "v1.0.0-unknown" },
|
||||
expectedError: /Invalid value/,
|
||||
},
|
||||
|
||||
// Valid revision cases with new format
|
||||
{
|
||||
description: "Valid revision (main)",
|
||||
options: { modelRevision: "main" },
|
||||
expected: { modelRevision: "main" },
|
||||
},
|
||||
{
|
||||
description: "Valid revision (v-prefixed version with alpha)",
|
||||
options: { tokenizerRevision: "v1.2.3-alpha1" },
|
||||
expected: { tokenizerRevision: "v1.2.3-alpha1" },
|
||||
},
|
||||
{
|
||||
description:
|
||||
"Valid revision (v-prefixed version with beta and dot separator)",
|
||||
options: { tokenizerRevision: "v1.2.3.beta2" },
|
||||
expected: { tokenizerRevision: "v1.2.3.beta2" },
|
||||
},
|
||||
{
|
||||
description:
|
||||
"Valid revision (non-prefixed version with rc and dash separator)",
|
||||
options: { processorRevision: "1.0.0-rc3" },
|
||||
expected: { processorRevision: "1.0.0-rc3" },
|
||||
},
|
||||
{
|
||||
description:
|
||||
"Valid revision (non-prefixed version with pre and dot separator)",
|
||||
options: { processorRevision: "1.0.0.pre4" },
|
||||
expected: { processorRevision: "1.0.0.pre4" },
|
||||
},
|
||||
{
|
||||
description: "Valid revision (version without suffix)",
|
||||
options: { modelRevision: "1.0.0" },
|
||||
expected: { modelRevision: "1.0.0" },
|
||||
},
|
||||
];
|
||||
|
||||
/**
|
||||
* Testing PipelineOption validation
|
||||
*/
|
||||
add_task(async function test_pipeline_options_validation() {
|
||||
pipelineOptionsCases.forEach(testCase => {
|
||||
if (testCase.expectedError) {
|
||||
Assert.throws(
|
||||
() => new PipelineOptions(getValidOptions(testCase.options)),
|
||||
testCase.expectedError,
|
||||
`${testCase.description} throws the expected error`
|
||||
);
|
||||
} else {
|
||||
const pipelineOptions = new PipelineOptions(
|
||||
getValidOptions(testCase.options)
|
||||
);
|
||||
Object.keys(testCase.expected).forEach(key => {
|
||||
is(
|
||||
pipelineOptions[key],
|
||||
testCase.expected[key],
|
||||
`${testCase.description} sets ${key} correctly`
|
||||
);
|
||||
});
|
||||
}
|
||||
});
|
||||
});
|
||||
|
@ -16,7 +16,14 @@ const { ModelHub, IndexedDBCache } = ChromeUtils.importESModule(
|
||||
"chrome://global/content/ml/ModelHub.sys.mjs"
|
||||
);
|
||||
|
||||
const { createEngine, PipelineOptions } = ChromeUtils.importESModule(
|
||||
const {
|
||||
createEngine,
|
||||
PipelineOptions,
|
||||
QuantizationLevel,
|
||||
ExecutionPriority,
|
||||
InferenceDevice,
|
||||
LogLevel,
|
||||
} = ChromeUtils.importESModule(
|
||||
"chrome://global/content/ml/EngineProcess.sys.mjs"
|
||||
);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user