Bug 1925471 - Check for values in PipelineOptions r=vazish,urlbar-reviewers,adw, a=dmeehan

Differential Revision: https://phabricator.services.mozilla.com/D226224
This commit is contained in:
Tarek Ziadé 2024-10-30 13:05:24 +00:00
parent e5bebc56cd
commit f679596bf1
5 changed files with 499 additions and 77 deletions

View File

@ -19,16 +19,16 @@ ChromeUtils.defineESModuleGetters(lazy, {
*/
const INTENT_OPTIONS = {
taskName: "text-classification",
modelId: "mozilla/mobilebert-uncased-finetuned-LoRA-intent-classifier",
modelRevision: "v0.1.0",
dtype: "q8",
featureId: "suggest-intent-classification",
engineId: "ml-suggest-intent",
timeoutMS: 36000000,
};
const NER_OPTIONS = {
taskName: "token-classification",
modelId: "mozilla/distilbert-uncased-NER-LoRA",
modelRevision: "v0.1.1",
dtype: "q8",
featureId: "suggest-NER",
engineId: "ml-suggest-ner",
timeoutMS: 36000000,
};
// List of prepositions used in subject cleaning.
@ -115,20 +115,8 @@ class _MLSuggest {
}
}
/**
* Helper method to generate a unique key for model engines.
*
* @param {object} options
* The options object containing taskName and modelId.
* @returns {string}
* The key for the model engine.
*/
#getmodelEnginesKey(options) {
return `${options.taskName}-${options.modelId}`;
}
async #initializeModelEngine(options) {
const engineId = this.#getmodelEnginesKey(options);
const engineId = options.engineId;
// uses cache if engine was used
if (this.#modelEngines[engineId]) {
@ -153,9 +141,7 @@ class _MLSuggest {
* The predicted intent label or null if the model is not initialized.
*/
async _findIntent(query, options = {}) {
const engineIntentClassifier =
this.#modelEngines[this.#getmodelEnginesKey(INTENT_OPTIONS)];
const engineIntentClassifier = this.#modelEngines[INTENT_OPTIONS.engineId];
if (!engineIntentClassifier) {
return null;
}
@ -180,7 +166,7 @@ class _MLSuggest {
* The NER results or null if the model is not initialized.
*/
async _findNER(query, options = {}) {
const engineNER = this.#modelEngines[this.#getmodelEnginesKey(NER_OPTIONS)];
const engineNER = this.#modelEngines[NER_OPTIONS.engineId];
return engineNER?.run({ args: [query], options });
}

View File

@ -580,7 +580,7 @@ async function runInference() {
tokenizerId: modelId,
processorId: modelId,
taskName,
engineId: "about:inference",
engineId: "about-inference",
modelHubRootUrl,
modelHubUrlTemplate,
device,

View File

@ -16,6 +16,42 @@ ChromeUtils.defineESModuleGetters(
* @typedef {import("../content/Utils.sys.mjs").ProgressAndStatusCallbackParams} ProgressAndStatusCallbackParams
*/
/**
* Lists Firefox internal features
*/
const FEATURES = [
"pdfjs-alt-text", // see toolkit/components/pdfjs/content/PdfjsParent.sys.mjs
"suggest-intent-classification", // see browser/components/urlbar/private/MLSuggest.sys.mjs
"suggest-NER", // see browser/components/urlbar/private/MLSuggest.sys.mjs
];
/**
* Custom error class for validation errors.
*
* This error is thrown when a field fails validation, providing additional context such as
* the name of the field that caused the error.
*
* @augments Error
*/
class PipelineOptionsValidationError extends Error {
/**
* Create a PipelineOptionsValidationError.
*
* @param {string} field - The name of the field that caused the validation error.
* @param {any} value - The invalid value provided for the field.
* @param {string} [tips=null] - Optional tips or suggestions for valid values.
*/
constructor(field, value, tips = null) {
const baseMessage = `Invalid value "${value}" for field "${field}".`;
const message = tips ? `${baseMessage} ${tips}` : baseMessage;
super(message);
this.name = this.constructor.name;
this.field = field;
this.value = value;
}
}
/**
* Enum for execution priority.
*
@ -29,16 +65,66 @@ ChromeUtils.defineESModuleGetters(
* @enum {string}
*/
export const ExecutionPriority = {
/** High priority, needed for Firefox */
HIGH: "HIGH",
/** Normal priority, default */
NORMAL: "NORMAL",
/** Low priority, used for 3rd party calls */
LOW: "LOW",
};
/**
* Enum for quantization levels.
*
* Defines the quantization level of the task:
*
* - 'fp32': Full precision 32-bit floating point (`''`)
* - 'fp16': Half precision 16-bit floating point (`'_fp16'`)
* - 'q8': Quantized 8-bit (`'_quantized'`)
* - 'int8': Integer 8-bit quantization (`'_int8'`)
* - 'uint8': Unsigned integer 8-bit quantization (`'_uint8'`)
* - 'q4': Quantized 4-bit (`'_q4'`)
* - 'bnb4': Binary/Boolean 4-bit quantization (`'_bnb4'`)
* - 'q4f16': 16-bit floating point model with 4-bit block weight quantization (`'_q4f16'`)
*
* @readonly
* @enum {string}
*/
export const QuantizationLevel = {
FP32: "fp32",
FP16: "fp16",
Q8: "q8",
INT8: "int8",
UINT8: "uint8",
Q4: "q4",
BNB4: "bnb4",
Q4F16: "q4f16",
};
/**
* Enum for the device used for inference.
*
* @readonly
* @enum {string}
*/
export const InferenceDevice = {
GPU: "gpu",
WASM: "wasm",
};
/**
* Enum for log levels.
*
* @readonly
* @enum {string}
*/
export const LogLevel = {
TRACE: "Trace",
INFO: "Info",
DEBUG: "Debug",
WARN: "Warn",
ERROR: "Error",
CRITICAL: "Critical",
ALL: "All",
};
/**
* @typedef {import("../../translations/actors/TranslationsEngineParent.sys.mjs").TranslationsEngineParent} TranslationsEngineParent
*/
@ -138,7 +224,7 @@ export class PipelineOptions {
/**
* The log level used in the worker
*
* @type {?string}
* @type {LogLevel | null}
*/
logLevel = null;
@ -152,24 +238,14 @@ export class PipelineOptions {
/**
* Device used for inference
*
* @type {"gpu" | "wasm" | null}
* @type {InferenceDevice | null}
*/
device = null;
/**
* Quantization level
*
* - name : description (onnx file suffix)
* - 'fp32': Full precision 32-bit floating point (`''`)
* - 'fp16': Half precision 16-bit floating point (`'_fp16'`)
* - 'q8': Quantized 8-bit (`'_quantized'`)
* - 'int8': Integer 8-bit quantization (`'_int8'`)
* - 'uint8': Unsigned integer 8-bit quantization (`'_uint8'`)
* - 'q4': Quantized 4-bit (`'_q4'`)
* - 'bnb4': Binary/Boolean 4-bit quantization (`'_bnb4'`)
* - 'q4f16': 16-bit floating point model with 4-bit block weight quantization (`'_q4f16'`)
*
* @type {"fp32" | "fp16" | "q8" | "int8" | "uint8" | "q4" | "bnb4" | "q4f16" | null}
* @type {QuantizationLevel | null}
*/
dtype = null;
@ -187,7 +263,7 @@ export class PipelineOptions {
*
* @type {ExecutionPriority}
*/
executionPriority = ExecutionPriority.NORMAL;
executionPriority = null;
/**
* Create a PipelineOptions instance.
@ -198,6 +274,122 @@ export class PipelineOptions {
this.updateOptions(options);
}
/**
* Private method to validate enum fields.
*
* @param {string} field - The field being validated (e.g., 'dtype', 'device', 'executionPriority').
* @param {*} value - The value being checked against the enum.
* @throws {Error} Throws an error if the value is not valid.
* @private
*/
#validateEnum(field, value) {
const enums = {
dtype: QuantizationLevel,
device: InferenceDevice,
executionPriority: ExecutionPriority,
logLevel: LogLevel,
};
// Check if the value is part of the enum or null
if (!Object.values(enums[field]).includes(value)) {
throw new PipelineOptionsValidationError(field, value);
}
}
/**
* Validates the taskName field, ensuring it contains only alphanumeric characters, underscores, and dashes.
* Slashes are not allowed in the taskName.
*
* @param {string} field - The name of the field being validated (e.g., taskName).
* @param {string} value - The value of the field to validate.
* @throws {Error} Throws an error if the taskName contains invalid characters.
* @private
*/
#validateTaskName(field, value) {
// Define a regular expression to verify taskName pattern (alphanumeric, underscores, and dashes, no slashes)
const validTaskNamePattern = /^[a-zA-Z0-9_\-]+$/;
// Check if the value matches the pattern
if (!validTaskNamePattern.test(value)) {
throw new PipelineOptionsValidationError(
field,
value,
"Should contain only alphanumeric characters, underscores, or dashes."
);
}
}
/**
* Validates a taskName or ID.
*
* The ID can optionally be in the form `organization/name`, where both `organization` and `name`
* follow the `taskName` pattern (alphanumeric characters, underscores, and dashes).
*
* Throws an exception if the name or ID is invalid.
*
* @param {string} field - The name of the field being validated (e.g., taskName, engineId).
* @param {string} value - The value of the field to validate.
* @throws {PipelineOptionsValidationError} Throws a validation error if the ID is invalid.
* @private
*/
#validateId(field, value) {
// Define a regular expression to match the optional organization and required name
// `organization/` part is optional, and both parts should follow the taskName pattern.
const validPattern = /^(?:[a-zA-Z0-9_\-]+\/)?[a-zA-Z0-9_\-]+$/;
// Check if the value matches the pattern
if (!validPattern.test(value)) {
throw new PipelineOptionsValidationError(
field,
value,
"Should follow the format 'organization/name' or 'name', where both parts contain only alphanumeric characters, underscores, or dashes."
);
}
}
/**
* Generic method to validate an integer within a specified range.
*
* @param {string} field - The name of the field being validated.
* @param {number} value - The integer value to validate.
* @param {number} min - The minimum allowed value (inclusive).
* @param {number} max - The maximum allowed value (inclusive).
* @throws {Error} Throws an error if the value is not a valid integer within the range.
* @private
*/
#validateIntegerRange(field, value, min, max) {
if (!Number.isInteger(value) || value < min || value > max) {
throw new PipelineOptionsValidationError(
field,
value,
`Should be an integer between ${min} and ${max}.`
);
}
}
/**
* Validates the revision field.
* The revision can be `main` or a version following a pattern like `v1.0.0`, `1.0.0-beta1`, `1.0.0.alpha2`, `1.0.0.rc1`, etc.
*
* @param {string} field - The name of the field being validated (e.g., modelRevision, tokenizerRevision).
* @param {string} value - The value of the revision field to validate.
* @throws {Error} Throws an error if the revision does not follow the expected pattern.
* @private
*/
#validateRevision(field, value) {
// Regular expression to match `main` or a version like `v1`, `v1.0.0`, `1.0.0-alpha1`, `1.0.0.alpha2`, `1.0.0.rc1`, etc.
const revisionPattern =
/^v?(\d+(\.\d+){0,2})([-\.](alpha\d*|beta\d*|pre\d*|post\d*|rc\d*))?$|^main$/;
// Check if the value matches the pattern
if (!revisionPattern.test(value)) {
throw new PipelineOptionsValidationError(
field,
value,
`Should be 'main' or follow a versioning pattern like 'v1.0.0', '1.0.0-beta1', '1.0.0.alpha2', '1.0.0.rc1', etc.`
);
}
}
/**
* Updates multiple options at once.
*
@ -238,13 +430,49 @@ export class PipelineOptions {
if (!optionsKeys.includes(key) || options[key] == null) {
return;
}
if (key === "featureId" && !FEATURES.includes(options[key])) {
throw new PipelineOptionsValidationError(
key,
options[key],
`Should be one of ${FEATURES.join(", ")}`
);
}
// Validating values.
if (["taskName", "engineId"].includes(key)) {
this.#validateTaskName(key, options[key]);
}
if (["modelId", "tokenizerId", "processorId"].includes(key)) {
this.#validateId(key, options[key]);
}
if (
["modelRevision", "tokenizerRevision", "processorRevision"].includes(
key
)
) {
this.#validateRevision(key, options[key]);
}
if (["dtype", "device", "executionPriority", "logLevel"].includes(key)) {
this.#validateEnum(key, options[key]);
}
if (key === "numThreads") {
this.#validateIntegerRange(key, options[key], 0, 100);
}
if (key === "timeoutMS") {
this.#validateIntegerRange(key, options[key], 0, 36000000);
}
this[key] = options[key];
});
}
/**
* Returns an object containing all current options.
*
* @returns {object} An object with the current options.
*/
getOptions() {

View File

@ -101,7 +101,7 @@ add_task(async function test_ml_engine_pick_feature_id() {
id: "74a71cfd-1734-44e6-85c0-69cf3e874138",
},
{
featureId: "myCoolFeature",
featureId: "pdfjs-alt-text",
taskName: "moz-echo",
modelId: "mozilla/distilvit",
processorId: "mozilla/distilvit",
@ -118,7 +118,7 @@ add_task(async function test_ml_engine_pick_feature_id() {
info("Get the engine");
const engineInstance = await createEngine({
featureId: "myCoolFeature",
featureId: "pdfjs-alt-text",
taskName: "moz-echo",
});
@ -284,35 +284,6 @@ add_task(async function test_pref_is_off() {
});
});
/**
* Tests that we verify the task name is valid
*/
add_task(async function test_invalid_task_name() {
const { cleanup, remoteClients } = await setup();
const options = new PipelineOptions({ taskName: "inv#alid" });
const mlEngineParent = await EngineProcess.getMLEngineParent();
const engineInstance = await mlEngineParent.getEngine(options);
let error;
try {
const res = engineInstance.run({ data: "This gets echoed." });
await remoteClients["ml-onnx-runtime"].resolvePendingDownloads(1);
await res;
} catch (e) {
error = e;
}
is(
error?.message,
"Invalid task name. Task name should contain only alphanumeric characters and underscores/dashes.",
"The error is correctly surfaced."
);
await EngineProcess.destroyMLEngine();
await cleanup();
});
/**
* Tests the generic pipeline API
*/
@ -714,7 +685,7 @@ add_task(async function test_ml_engine_get_status() {
device: null,
dtype: "q8",
numThreads: null,
executionPriority: "NORMAL",
executionPriority: null,
},
engineId: "default-engine",
},
@ -760,3 +731,233 @@ add_task(async function test_ml_engine_not_enough_memory() {
await EngineProcess.destroyMLEngine();
await cleanup();
});
/**
* Helper function to create a basic set of valid options
*/
function getValidOptions(overrides = {}) {
return Object.assign(
{
engineId: "validEngine1",
featureId: "pdfjs-alt-text",
taskName: "valid_task",
modelHubRootUrl: "https://example.com",
modelHubUrlTemplate: "https://example.com/{modelId}",
timeoutMS: 5000,
modelId: "validModel",
modelRevision: "v1",
tokenizerId: "validTokenizer",
tokenizerRevision: "v1",
processorId: "validProcessor",
processorRevision: "v1",
logLevel: null,
runtimeFilename: "runtime.wasm",
device: InferenceDevice.GPU,
numThreads: 4,
executionPriority: ExecutionPriority.NORMAL,
},
overrides
);
}
/**
* A collection of test cases for invalid and valid values.
*/
const commonInvalidCases = [
{ description: "Invalid value (special characters)", value: "org1/my!value" },
{
description: "Invalid value (special characters in organization)",
value: "org@1/my-value",
},
{ description: "Invalid value (missing name part)", value: "org1/" },
{
description: "Invalid value (invalid characters in name)",
value: "my$value",
},
];
const commonValidCases = [
{ description: "Valid organization/name", value: "org1/my-value" },
{ description: "Valid name only", value: "my-value" },
{
description: "Valid name with underscores and dashes",
value: "my_value-123",
},
{
description: "Valid organization with underscores and dashes",
value: "org_123/my-value",
},
];
const pipelineOptionsCases = [
// Invalid cases for various fields
...commonInvalidCases.map(test => ({
description: `Invalid processorId (${test.description})`,
options: { processorId: test.value },
expectedError: /Invalid value/,
})),
...commonInvalidCases.map(test => ({
description: `Invalid tokenizerId (${test.description})`,
options: { tokenizerId: test.value },
expectedError: /Invalid value/,
})),
...commonInvalidCases.map(test => ({
description: `Invalid modelId (${test.description})`,
options: { modelId: test.value },
expectedError: /Invalid value/,
})),
// Valid cases for various fields
...commonValidCases.map(test => ({
description: `Valid processorId (${test.description})`,
options: { processorId: test.value },
expected: { processorId: test.value },
})),
...commonValidCases.map(test => ({
description: `Valid tokenizerId (${test.description})`,
options: { tokenizerId: test.value },
expected: { tokenizerId: test.value },
})),
...commonValidCases.map(test => ({
description: `Valid modelId (${test.description})`,
options: { modelId: test.value },
expected: { modelId: test.value },
})),
// Invalid dtype, device, executionPriority, featureId and logLevel cases
{
description: "Invalid featureId",
options: { featureId: "unknown" },
expectedError: /Invalid value/,
},
{
description: "Invalid dtype",
options: { dtype: "invalid_dtype" },
expectedError: /Invalid value/,
},
{
description: "Invalid device",
options: { device: "invalid_device" },
expectedError: /Invalid value/,
},
{
description: "Invalid executionPriority",
options: { executionPriority: "invalid_priority" },
expectedError: /Invalid value/,
},
{
description: "Invalid logLevel",
options: { logLevel: "invalid_log_level" },
expectedError: /Invalid value/,
},
// Valid cases for dtype, device, executionPriority, and logLevel
{
description: "Valid dtype",
options: { dtype: QuantizationLevel.FP16 },
expected: { dtype: QuantizationLevel.FP16 },
},
{
description: "Valid device",
options: { device: InferenceDevice.WASM },
expected: { device: InferenceDevice.WASM },
},
{
description: "Valid executionPriority",
options: { executionPriority: ExecutionPriority.HIGH },
expected: { executionPriority: ExecutionPriority.HIGH },
},
{
description: "Valid logLevel (Info)",
options: { logLevel: LogLevel.INFO },
expected: { logLevel: LogLevel.INFO },
},
{
description: "Valid logLevel (Critical)",
options: { logLevel: LogLevel.CRITICAL },
expected: { logLevel: LogLevel.CRITICAL },
},
{
description: "Valid logLevel (All)",
options: { logLevel: LogLevel.ALL },
expected: { logLevel: LogLevel.ALL },
},
// Invalid revision cases
{
description: "Invalid revision (random string)",
options: { modelRevision: "invalid_revision" },
expectedError: /Invalid value/,
},
{
description: "Invalid revision (too many version numbers)",
options: { tokenizerRevision: "v1.0.3.4.5" },
expectedError: /Invalid value/,
},
{
description: "Invalid revision (unknown suffix)",
options: { processorRevision: "v1.0.0-unknown" },
expectedError: /Invalid value/,
},
// Valid revision cases with new format
{
description: "Valid revision (main)",
options: { modelRevision: "main" },
expected: { modelRevision: "main" },
},
{
description: "Valid revision (v-prefixed version with alpha)",
options: { tokenizerRevision: "v1.2.3-alpha1" },
expected: { tokenizerRevision: "v1.2.3-alpha1" },
},
{
description:
"Valid revision (v-prefixed version with beta and dot separator)",
options: { tokenizerRevision: "v1.2.3.beta2" },
expected: { tokenizerRevision: "v1.2.3.beta2" },
},
{
description:
"Valid revision (non-prefixed version with rc and dash separator)",
options: { processorRevision: "1.0.0-rc3" },
expected: { processorRevision: "1.0.0-rc3" },
},
{
description:
"Valid revision (non-prefixed version with pre and dot separator)",
options: { processorRevision: "1.0.0.pre4" },
expected: { processorRevision: "1.0.0.pre4" },
},
{
description: "Valid revision (version without suffix)",
options: { modelRevision: "1.0.0" },
expected: { modelRevision: "1.0.0" },
},
];
/**
* Testing PipelineOption validation
*/
add_task(async function test_pipeline_options_validation() {
pipelineOptionsCases.forEach(testCase => {
if (testCase.expectedError) {
Assert.throws(
() => new PipelineOptions(getValidOptions(testCase.options)),
testCase.expectedError,
`${testCase.description} throws the expected error`
);
} else {
const pipelineOptions = new PipelineOptions(
getValidOptions(testCase.options)
);
Object.keys(testCase.expected).forEach(key => {
is(
pipelineOptions[key],
testCase.expected[key],
`${testCase.description} sets ${key} correctly`
);
});
}
});
});

View File

@ -16,7 +16,14 @@ const { ModelHub, IndexedDBCache } = ChromeUtils.importESModule(
"chrome://global/content/ml/ModelHub.sys.mjs"
);
const { createEngine, PipelineOptions } = ChromeUtils.importESModule(
const {
createEngine,
PipelineOptions,
QuantizationLevel,
ExecutionPriority,
InferenceDevice,
LogLevel,
} = ChromeUtils.importESModule(
"chrome://global/content/ml/EngineProcess.sys.mjs"
);