Bug 1692069 - wasm: Update typing of br_on_cast instruction. r=lth

The br_on_cast instruction is currently too strict and disallows the passing
along extra values. This commit adds validation for this, adds comments clarifying
the baseline implementation of the function, normalizes argument order for
WasmOpIter, and adds a test.

Differential Revision: https://phabricator.services.mozilla.com/D108505
This commit is contained in:
Ryan Hunt 2021-03-16 16:13:11 +00:00
parent 3567b6d14a
commit 59b1f4f018
5 changed files with 277 additions and 18 deletions

View File

@ -0,0 +1,174 @@
// |jit-test| skip-if: !wasmGcEnabled()
function typingModule(types, castToTypeIndex, brParams, blockResults) {
return `(module
${types}
(func
(param ${brParams.join(' ')})
(result ${blockResults.join(' ')})
(; push params onto the stack in the same order as they appear, leaving
the last param at the top of the stack. ;)
${brParams.map((_, i) => `local.get ${i}`).join('\n')}
rtt.canon ${castToTypeIndex}
br_on_cast 0
unreachable
)
)`;
}
function validTyping(types, castToTypeIndex, brParams, blockResults) {
wasmValidateText(typingModule(types, castToTypeIndex, brParams, blockResults));
}
function invalidTyping(types, castToTypeIndex, brParams, blockResults, error) {
wasmFailValidateText(typingModule(types, castToTypeIndex, brParams, blockResults), error);
}
// valid: input eqref, output non-nullable struct
validTyping('(type $a (struct))', '$a', ['eqref'], ['(ref $a)']);
// valid: input eqref, output nullable struct
validTyping('(type $a (struct))', '$a', ['eqref'], ['(ref null $a)']);
// valid: input non-nullable struct, output non-nullable struct
validTyping('(type $a (struct)) (type $b (struct))', '$b', ['(ref $a)'], ['(ref $b)']);
// valid: input nullable struct, output non-nullable struct
validTyping('(type $a (struct)) (type $b (struct))', '$b', ['(ref null $a)'], ['(ref $b)']);
// valid: input nullable struct, output nullable struct
validTyping('(type $a (struct)) (type $b (struct))', '$b', ['(ref null $a)'], ['(ref null $b)']);
// valid: input with an extra i32
validTyping('(type $a (struct))', '$a', ['i32', 'eqref'], ['i32', '(ref $a)']);
// valid: input with an extra i32 and f32
validTyping('(type $a (struct))', '$a', ['i32', 'f32', 'eqref'], ['i32', 'f32', '(ref $a)']);
// invalid: block result type must have slot for casted-to type
invalidTyping('(type $a (struct))', '$a', ['eqref'], [], /type mismatch/);
// invalid: block result type must be subtype of casted-to type
invalidTyping('(type $a (struct)) (type $b (struct (field i32)))', '$a', ['eqref'], ['(ref $b)'], /type mismatch/);
// invalid: input is missing extra i32 from the branch target type
invalidTyping('(type $a (struct))', '$a', ['f32', 'eqref'], ['i32', 'f32', '(ref $a)'], /popping value/);
// invalid: input is has extra [i32, f32] swapped from the branch target type
invalidTyping('(type $a (struct))', '$a', ['i32', 'f32', 'eqref'], ['f32', 'i32', '(ref $a)'], /type mismatch/);
// Simple runtime test of casting
{
let { makeA, makeB, isA, isB } = wasmEvalText(`(module
(type $a (struct))
(type $b (struct (field i32)))
(func (export "makeA") (result eqref)
rtt.canon $a
struct.new_default_with_rtt $a
)
(func (export "makeB") (result eqref)
rtt.canon $b
struct.new_default_with_rtt $b
)
(func (export "isA") (param eqref) (result i32)
(block (result (ref $a))
local.get 0
rtt.canon $a
br_on_cast 0
i32.const 0
br 1
)
drop
i32.const 1
)
(func (export "isB") (param eqref) (result i32)
(block (result (ref $a))
local.get 0
rtt.canon $b
br_on_cast 0
i32.const 0
br 1
)
drop
i32.const 1
)
)`).exports;
let a = makeA();
let b = makeB();
assertEq(isA(a), 1);
assertEq(isA(b), 0);
assertEq(isB(a), 0);
assertEq(isB(b), 1);
}
// Runtime test of casting with extra values
{
function assertEqResults(a, b) {
if (!(a instanceof Array)) {
a = [a];
}
if (!(b instanceof Array)) {
b = [b];
}
if (a.length !== b.length) {
assertEq(a.length, b.length);
}
for (let i = 0; i < a.length; i++) {
let x = a[i];
let y = b[i];
// intentionally use loose equality to allow bigint to compare equally
// to number, as can happen with how we use the JS-API here.
assertEq(x == y, true);
}
}
function testExtra(values) {
let { makeT, makeF, select } = wasmEvalText(`(module
(type $t (struct))
(type $f (struct (field i32)))
(func (export "makeT") (result eqref)
rtt.canon $t
struct.new_default_with_rtt $t
)
(func (export "makeF") (result eqref)
rtt.canon $f
struct.new_default_with_rtt $f
)
(func (export "select") (param eqref) (result ${values.map((type) => type).join(" ")})
(block (result (ref $t))
local.get 0
rtt.canon $t
br_on_cast 0
${values.map((type, i) => `${type}.const ${values.length + i}`).join("\n")}
br 1
)
drop
${values.map((type, i) => `${type}.const ${i}`).join("\n")}
)
)`).exports;
let t = makeT();
let f = makeF();
let trueValues = values.map((type, i) => i);
let falseValues = values.map((type, i) => values.length + i);
assertEqResults(select(t), trueValues);
assertEqResults(select(f), falseValues);
}
// multiples of primitive valtypes
for (let valtype of ['i32', 'i64', 'f32', 'f64']) {
testExtra([valtype]);
testExtra([valtype, valtype]);
testExtra([valtype, valtype, valtype]);
testExtra([valtype, valtype, valtype, valtype, valtype, valtype, valtype, valtype]);
}
// random sundry of valtypes
testExtra(['i32', 'f32', 'i64', 'f64']);
testExtra(['i32', 'f32', 'i64', 'f64', 'i32', 'f32', 'i64', 'f64']);
}

View File

@ -14119,6 +14119,8 @@ bool BaseCompiler::emitRttSub() {
return true; return true;
} }
// rttSub builtin has same signature as rtt.sub instruction, stack is
// guaranteed to be in the right condition due to validation.
if (!emitInstanceCall(lineOrBytecode, SASigRttSub)) { if (!emitInstanceCall(lineOrBytecode, SASigRttSub)) {
return false; return false;
} }
@ -14138,6 +14140,9 @@ bool BaseCompiler::emitRefTest() {
if (deadCode_) { if (deadCode_) {
return true; return true;
} }
// refTest builtin has same signature as ref.test instruction, stack is
// guaranteed to be in the right condition due to validation.
return emitInstanceCall(lineOrBytecode, SASigRefTest); return emitInstanceCall(lineOrBytecode, SASigRefTest);
} }
@ -14157,16 +14162,20 @@ bool BaseCompiler::emitRefCast() {
RegRef rttPtr = popRef(); RegRef rttPtr = popRef();
RegRef refPtr = popRef(); RegRef refPtr = popRef();
// 1. duplicate and shuffle from [ref, rtt] to [ref, ref, rtt]
RegRef castedPtr = needRef(); RegRef castedPtr = needRef();
moveRef(refPtr, castedPtr); moveRef(refPtr, castedPtr);
pushRef(castedPtr); pushRef(castedPtr);
pushRef(refPtr); pushRef(refPtr);
pushRef(rttPtr); pushRef(rttPtr);
// 2. ref.test : [ref, rtt] -> [i32]
if (!emitInstanceCall(lineOrBytecode, SASigRefTest)) { if (!emitInstanceCall(lineOrBytecode, SASigRefTest)) {
return false; return false;
} }
// 3. trap if result is zero, leaving [ref] as result
RegI32 result = popI32(); RegI32 result = popI32();
Label nonZero; Label nonZero;
masm.branchTest32(Assembler::NonZero, result, result, &nonZero); masm.branchTest32(Assembler::NonZero, result, result, &nonZero);
@ -14186,9 +14195,9 @@ bool BaseCompiler::emitBrOnCast() {
NothingVector unused_values; NothingVector unused_values;
uint32_t rttTypeIndex; uint32_t rttTypeIndex;
uint32_t rttDepth; uint32_t rttDepth;
ResultType type; ResultType branchTargetType;
if (!iter_.readBrOnCast(&relativeDepth, &unused, &rttTypeIndex, &rttDepth, if (!iter_.readBrOnCast(&relativeDepth, &unused, &rttTypeIndex, &rttDepth,
&unused_values, &type)) { &branchTargetType, &unused_values)) {
return false; return false;
} }
@ -14201,17 +14210,22 @@ bool BaseCompiler::emitBrOnCast() {
RegRef rttPtr = popRef(); RegRef rttPtr = popRef();
RegRef refPtr = popRef(); RegRef refPtr = popRef();
// 1. duplicate and shuffle from [T*, ref, rtt] to [T*, ref, ref, rtt]
RegRef castedPtr = needRef(); RegRef castedPtr = needRef();
moveRef(refPtr, castedPtr); moveRef(refPtr, castedPtr);
pushRef(castedPtr); pushRef(castedPtr);
pushRef(refPtr); pushRef(refPtr);
pushRef(rttPtr); pushRef(rttPtr);
// 2. ref.test : [ref, rtt] -> [i32]
if (!emitInstanceCall(lineOrBytecode, SASigRefTest)) { if (!emitInstanceCall(lineOrBytecode, SASigRefTest)) {
return false; return false;
} }
BranchState b(&target.label, target.stackHeight, InvertBranch(false), type); // 3. br_if $l : [T*, ref, i32] -> [T*, ref]
BranchState b(&target.label, target.stackHeight, InvertBranch(false),
branchTargetType);
if (b.hasBlockResults()) { if (b.hasBlockResults()) {
needResultRegisters(b.resultType); needResultRegisters(b.resultType);
} }

View File

@ -362,8 +362,11 @@ class MOZ_STACK_CLASS OpIter : private Policy {
[[nodiscard]] bool getControl(uint32_t relativeDepth, Control** controlEntry); [[nodiscard]] bool getControl(uint32_t relativeDepth, Control** controlEntry);
[[nodiscard]] bool checkBranchValue(uint32_t relativeDepth, ResultType* type, [[nodiscard]] bool checkBranchValue(uint32_t relativeDepth, ResultType* type,
ValueVector* values); ValueVector* values);
[[nodiscard]] bool checkBranchType(uint32_t relativeDepth, [[nodiscard]] bool checkCastedBranchValue(uint32_t relativeDepth,
ResultType expectedType); ValType castedFromType,
ValType castedToType,
ResultType* branchTargetType,
ValueVector* values);
[[nodiscard]] bool checkBrTableEntry(uint32_t* relativeDepth, [[nodiscard]] bool checkBrTableEntry(uint32_t* relativeDepth,
ResultType prevBranchType, ResultType prevBranchType,
ResultType* branchType, ResultType* branchType,
@ -577,7 +580,8 @@ class MOZ_STACK_CLASS OpIter : private Policy {
uint32_t* rttDepth, Value* ref); uint32_t* rttDepth, Value* ref);
[[nodiscard]] bool readBrOnCast(uint32_t* relativeDepth, Value* rtt, [[nodiscard]] bool readBrOnCast(uint32_t* relativeDepth, Value* rtt,
uint32_t* rttTypeIndex, uint32_t* rttDepth, uint32_t* rttTypeIndex, uint32_t* rttDepth,
ValueVector* values, ResultType* types); ResultType* branchTargetType,
ValueVector* values);
[[nodiscard]] bool readValType(ValType* type); [[nodiscard]] bool readValType(ValType* type);
[[nodiscard]] bool readHeapType(bool nullable, RefType* type); [[nodiscard]] bool readHeapType(bool nullable, RefType* type);
[[nodiscard]] bool readReferenceType(ValType* type, [[nodiscard]] bool readReferenceType(ValType* type,
@ -1261,14 +1265,63 @@ inline bool OpIter<Policy>::checkBranchValue(uint32_t relativeDepth,
return topWithType(*type, values); return topWithType(*type, values);
} }
// Check the typing of a branch instruction which casts an input type to
// an output type, branching on success to a target which takes the output
// type along with extra values from the stack. On casting failure, the
// original input type and extra values are left on the stack.
template <typename Policy> template <typename Policy>
inline bool OpIter<Policy>::checkBranchType(uint32_t relativeDepth, inline bool OpIter<Policy>::checkCastedBranchValue(uint32_t relativeDepth,
ResultType expectedType) { ValType castedFromType,
ValType castedToType,
ResultType* branchTargetType,
ValueVector* values) {
// Get the branch target type, which will determine the type of extra values
// that are passed along with the casted type.
Control* block = nullptr; Control* block = nullptr;
if (!getControl(relativeDepth, &block)) { if (!getControl(relativeDepth, &block)) {
return false; return false;
} }
return checkIsSubtypeOf(expectedType, block->branchTargetType()); *branchTargetType = block->branchTargetType();
// Check we at least have one type in the branch target type, which will take
// the casted type.
if (branchTargetType->length() < 1) {
UniqueChars expectedText = ToString(castedToType);
if (!expectedText) {
return false;
}
UniqueChars error(JS_smprintf("type mismatch: expected [_, %s], got []",
expectedText.get()));
if (!error) {
return false;
}
return fail(error.get());
}
// The top of the stack is the type that is being cast. This is the last type
// in the branch target type. This is guaranteed to exist by the above check.
const size_t castTypeIndex = branchTargetType->length() - 1;
// Check that the branch target type can accept the castedToType. The branch
// target may specify a super type of the castedToType, and this is okay.
if (!checkIsSubtypeOf(castedToType, (*branchTargetType)[castTypeIndex])) {
return false;
}
// Create a copy of the branch target type, with the castTypeIndex replaced
// with the castedFromType. Use this to check that the stack has the proper
// types to branch to the target type.
//
// TODO: We could avoid a potential allocation here by handwriting a custom
// topWithType that handles this case.
ValTypeVector stackTargetType;
if (!branchTargetType->cloneToVector(&stackTargetType)) {
return false;
}
stackTargetType[castTypeIndex] = castedFromType;
return topWithType(ResultType::Vector(stackTargetType), values);
} }
template <typename Policy> template <typename Policy>
@ -2967,8 +3020,8 @@ template <typename Policy>
inline bool OpIter<Policy>::readBrOnCast(uint32_t* relativeDepth, Value* rtt, inline bool OpIter<Policy>::readBrOnCast(uint32_t* relativeDepth, Value* rtt,
uint32_t* rttTypeIndex, uint32_t* rttTypeIndex,
uint32_t* rttDepth, uint32_t* rttDepth,
ValueVector* values, ResultType* branchTargetType,
ResultType* types) { ValueVector* values) {
MOZ_ASSERT(Classify(op_) == OpKind::BrOnCast); MOZ_ASSERT(Classify(op_) == OpKind::BrOnCast);
if (!readVarU32(relativeDepth)) { if (!readVarU32(relativeDepth)) {
@ -2979,13 +3032,15 @@ inline bool OpIter<Policy>::readBrOnCast(uint32_t* relativeDepth, Value* rtt,
return false; return false;
} }
*types = // The casted from type is any subtype of eqref
ResultType::Single(ValType(RefType::fromTypeIndex(*rttTypeIndex, false))); ValType castedFromType(RefType::eq());
if (!checkBranchType(*relativeDepth, *types)) {
return false;
}
return topWithType(ResultType::Single(ValType(RefType::eq())), values); // The casted to type is a non-nullable reference to the type index specified
// by the input rtt on the stack
ValType castedToType(RefType::fromTypeIndex(*rttTypeIndex, false));
return checkCastedBranchValue(*relativeDepth, castedFromType, castedToType,
branchTargetType, values);
} }
#ifdef ENABLE_WASM_SIMD #ifdef ENABLE_WASM_SIMD

View File

@ -1899,6 +1899,22 @@ class ResultType {
} }
} }
[[nodiscard]] bool cloneToVector(ValTypeVector* out) {
MOZ_ASSERT(out->empty());
switch (kind()) {
case EmptyKind:
return true;
case SingleKind:
return out->append(singleValType());
#ifdef ENABLE_WASM_MULTI_VALUE
case VectorKind:
return out->appendAll(values());
#endif
default:
MOZ_CRASH("bad resulttype");
}
}
bool empty() const { return kind() == EmptyKind; } bool empty() const { return kind() == EmptyKind; }
size_t length() const { size_t length() const {

View File

@ -974,7 +974,7 @@ static bool DecodeFunctionBodyExprs(const ModuleEnvironment& env,
uint32_t unusedRttDepth; uint32_t unusedRttDepth;
CHECK(iter.readBrOnCast(&unusedRelativeDepth, &nothing, CHECK(iter.readBrOnCast(&unusedRelativeDepth, &nothing,
&unusedRttTypeIndex, &unusedRttDepth, &unusedRttTypeIndex, &unusedRttDepth,
&nothings, &unusedType)); &unusedType, &nothings));
} }
default: default:
return iter.unrecognizedOpcode(&op); return iter.unrecognizedOpcode(&op);