diff --git a/testing/web-platform/tests/webnn/resources/utils_validation.js b/testing/web-platform/tests/webnn/resources/utils_validation.js index e289d3815d17..b2c1c534c8c6 100644 --- a/testing/web-platform/tests/webnn/resources/utils_validation.js +++ b/testing/web-platform/tests/webnn/resources/utils_validation.js @@ -17,6 +17,11 @@ const allWebNNOperandDataTypes = [ // range [0, 4294967295]. // 4294967295 = 2 ** 32 - 1 const kMaxUnsignedLong = 2 ** 32 - 1; + +const floatingPointTypes = ['float32', 'float16']; + +const signedIntegerTypes = ['int32', 'int64', 'int8']; + const unsignedLongType = 'unsigned long'; const dimensions0D = []; @@ -361,6 +366,51 @@ function validateOptionsAxes(operationName, inputRank) { } } +/** + * Validate a unary operation + * @param {String} operationName - An operation name + * @param {Array} supportedDataTypes - Test building with these data types + * succeeds and test building with all other data types fails + * @param {Boolean} alsoBuildActivation - If test building this operation as an + * activation + */ +function validateUnaryOperation( + operationName, supportedDataTypes, alsoBuildActivation = false) { + for (let dataType of supportedDataTypes) { + for (let dimensions of allWebNNDimensionsArray) { + promise_test( + async t => { + const input = builder.input(`input`, {dataType, dimensions}); + const output = builder[operationName](input); + assert_equals(output.dataType(), dataType); + assert_array_equals(output.shape(), dimensions); + }, + `[${operationName}] Test building an operator, dataType = ${ + dataType}, dimensions = [${dimensions}]`); + } + } + + const unsupportedDataTypes = + new Set(allWebNNOperandDataTypes).difference(new Set(supportedDataTypes)); + for (let dataType of unsupportedDataTypes) { + for (let dimensions of allWebNNDimensionsArray) { + promise_test( + async t => { + const input = builder.input(`input`, {dataType, dimensions}); + assert_throws_js(TypeError, () => builder[operationName](input)); + }, + `[${operationName}] Throw if the dataType is not supported, dataType = ${ + dataType}, dimensions = [${dimensions}]`); + } + } + + if (alsoBuildActivation) { + promise_test(async t => { + builder[operationName](); + }, `[${operationName}] Test building an activation`); + } +} + /** * Basic test that the builder method specified by `operationName` throws if * given an input from another builder. Operands which do not accept a float32 diff --git a/testing/web-platform/tests/webnn/validation_tests/elementwise-unary.https.any.js b/testing/web-platform/tests/webnn/validation_tests/elementwise-unary.https.any.js index 8f3d544d83ab..f87c61b4e456 100644 --- a/testing/web-platform/tests/webnn/validation_tests/elementwise-unary.https.any.js +++ b/testing/web-platform/tests/webnn/validation_tests/elementwise-unary.https.any.js @@ -12,3 +12,28 @@ const kElementwiseUnaryOperators = [ kElementwiseUnaryOperators.forEach((operatorName) => { validateInputFromAnotherBuilder(operatorName); }); + +const kElementwiseUnaryOperations = [ + { + name: 'abs', + supportedDataTypes: [...floatingPointTypes, ...signedIntegerTypes] + }, + {name: 'ceil', supportedDataTypes: floatingPointTypes}, + {name: 'exp', supportedDataTypes: floatingPointTypes}, + {name: 'floor', supportedDataTypes: floatingPointTypes}, + {name: 'log', supportedDataTypes: floatingPointTypes}, { + name: 'neg', + supportedDataTypes: [...floatingPointTypes, ...signedIntegerTypes] + }, + {name: 'sin', supportedDataTypes: floatingPointTypes}, + {name: 'tan', supportedDataTypes: floatingPointTypes}, + {name: 'erf', supportedDataTypes: floatingPointTypes}, + {name: 'identity', supportedDataTypes: allWebNNOperandDataTypes}, + {name: 'logicalNot', supportedDataTypes: ['uint8']}, + {name: 'reciprocal', supportedDataTypes: floatingPointTypes}, + {name: 'sqrt', supportedDataTypes: floatingPointTypes} +]; + +kElementwiseUnaryOperations.forEach((operation) => { + validateUnaryOperation(operation.name, operation.supportedDataTypes); +}); diff --git a/testing/web-platform/tests/webnn/validation_tests/hardSwish.https.any.js b/testing/web-platform/tests/webnn/validation_tests/hardSwish.https.any.js index dd85165cf56e..97ecfb4142de 100644 --- a/testing/web-platform/tests/webnn/validation_tests/hardSwish.https.any.js +++ b/testing/web-platform/tests/webnn/validation_tests/hardSwish.https.any.js @@ -5,3 +5,6 @@ 'use strict'; validateInputFromAnotherBuilder('hardSwish'); + +validateUnaryOperation( + 'hardSwish', floatingPointTypes, /*alsoBuildActivation=*/ true); diff --git a/testing/web-platform/tests/webnn/validation_tests/relu.https.any.js b/testing/web-platform/tests/webnn/validation_tests/relu.https.any.js index 4c44d3c0dc17..237c1c3eda3a 100644 --- a/testing/web-platform/tests/webnn/validation_tests/relu.https.any.js +++ b/testing/web-platform/tests/webnn/validation_tests/relu.https.any.js @@ -5,3 +5,6 @@ 'use strict'; validateInputFromAnotherBuilder('relu'); + +validateUnaryOperation( + 'relu', allWebNNOperandDataTypes, /*alsoBuildActivation=*/ true); diff --git a/testing/web-platform/tests/webnn/validation_tests/sigmoid.https.any.js b/testing/web-platform/tests/webnn/validation_tests/sigmoid.https.any.js index 206b6eda26b9..b40ddc3fd4a7 100644 --- a/testing/web-platform/tests/webnn/validation_tests/sigmoid.https.any.js +++ b/testing/web-platform/tests/webnn/validation_tests/sigmoid.https.any.js @@ -5,3 +5,6 @@ 'use strict'; validateInputFromAnotherBuilder('sigmoid'); + +validateUnaryOperation( + 'sigmoid', floatingPointTypes, /*alsoBuildActivation=*/ true); diff --git a/testing/web-platform/tests/webnn/validation_tests/softsign.https.any.js b/testing/web-platform/tests/webnn/validation_tests/softsign.https.any.js index a4a3847c2a3d..58ec48715996 100644 --- a/testing/web-platform/tests/webnn/validation_tests/softsign.https.any.js +++ b/testing/web-platform/tests/webnn/validation_tests/softsign.https.any.js @@ -5,3 +5,6 @@ 'use strict'; validateInputFromAnotherBuilder('softsign'); + +validateUnaryOperation( + 'softsign', floatingPointTypes, /*alsoBuildActivation=*/ true); diff --git a/testing/web-platform/tests/webnn/validation_tests/tanh.https.any.js b/testing/web-platform/tests/webnn/validation_tests/tanh.https.any.js index 4be36b9dbf2f..4f9de919f61e 100644 --- a/testing/web-platform/tests/webnn/validation_tests/tanh.https.any.js +++ b/testing/web-platform/tests/webnn/validation_tests/tanh.https.any.js @@ -5,3 +5,6 @@ 'use strict'; validateInputFromAnotherBuilder('tanh'); + +validateUnaryOperation( + 'tanh', floatingPointTypes, /*alsoBuildActivation=*/ true);