spirv-opt: Avoid integer overflow during constant folding (#4511)

In SPIR-V, integers use 2s complement representation, so that signed
integer overflow and underflow is well defined. However, the constant
folder was causing overflow / underflow at the C++ level. This change
avoids such overflows by performing constant folding for IAdd, ISub and
IMul in the context of unsigned values, which works because signedness
is irrelevant according to the SPIR-V semantics for these instructions.

Fixes #4510.
This commit is contained in:
Alastair Donaldson 2021-09-14 22:09:05 +01:00 committed by GitHub
parent cb6c66917a
commit 36ff135341
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 273 additions and 20 deletions

View File

@ -523,7 +523,8 @@ uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr,
float fval = val.getAsFloat(); \ float fval = val.getAsFloat(); \
if (!IsValidResult(fval)) return 0; \ if (!IsValidResult(fval)) return 0; \
words = val.GetWords(); \ words = val.GetWords(); \
} static_assert(true, "require extra semicolon") } \
static_assert(true, "require extra semicolon")
switch (opcode) { switch (opcode) {
case SpvOpFMul: case SpvOpFMul:
FOLD_OP(*); FOLD_OP(*);
@ -558,24 +559,19 @@ uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr,
uint32_t width = type->AsInteger()->width(); uint32_t width = type->AsInteger()->width();
assert(width == 32 || width == 64); assert(width == 32 || width == 64);
std::vector<uint32_t> words; std::vector<uint32_t> words;
#define FOLD_OP(op) \ // Regardless of the sign of the constant, folding is performed on an unsigned
if (width == 64) { \ // interpretation of the constant data. This avoids signed integer overflow
if (type->IsSigned()) { \ // while folding, and works because sign is irrelevant for the IAdd, ISub and
int64_t val = input1->GetS64() op input2->GetS64(); \ // IMul instructions.
words = ExtractInts(static_cast<uint64_t>(val)); \ #define FOLD_OP(op) \
} else { \ if (width == 64) { \
uint64_t val = input1->GetU64() op input2->GetU64(); \ uint64_t val = input1->GetU64() op input2->GetU64(); \
words = ExtractInts(val); \ words = ExtractInts(val); \
} \ } else { \
} else { \ uint32_t val = input1->GetU32() op input2->GetU32(); \
if (type->IsSigned()) { \ words.push_back(val); \
int32_t val = input1->GetS32() op input2->GetS32(); \ } \
words.push_back(static_cast<uint32_t>(val)); \ static_assert(true, "require extra semicolon")
} else { \
uint32_t val = input1->GetU32() op input2->GetU32(); \
words.push_back(val); \
} \
} static_assert(true, "require extra semicalon")
switch (opcode) { switch (opcode) {
case SpvOpIMul: case SpvOpIMul:
FOLD_OP(*); FOLD_OP(*);

View File

@ -137,6 +137,7 @@ OpName %main "main"
%int = OpTypeInt 32 1 %int = OpTypeInt 32 1
%long = OpTypeInt 64 1 %long = OpTypeInt 64 1
%uint = OpTypeInt 32 0 %uint = OpTypeInt 32 0
%ulong = OpTypeInt 64 0
%v2int = OpTypeVector %int 2 %v2int = OpTypeVector %int 2
%v4int = OpTypeVector %int 4 %v4int = OpTypeVector %int 4
%v4float = OpTypeVector %float 4 %v4float = OpTypeVector %float 4
@ -154,6 +155,7 @@ OpName %main "main"
%_ptr_double = OpTypePointer Function %double %_ptr_double = OpTypePointer Function %double
%_ptr_half = OpTypePointer Function %half %_ptr_half = OpTypePointer Function %half
%_ptr_long = OpTypePointer Function %long %_ptr_long = OpTypePointer Function %long
%_ptr_ulong = OpTypePointer Function %ulong
%_ptr_v2int = OpTypePointer Function %v2int %_ptr_v2int = OpTypePointer Function %v2int
%_ptr_v4int = OpTypePointer Function %v4int %_ptr_v4int = OpTypePointer Function %v4int
%_ptr_v4float = OpTypePointer Function %v4float %_ptr_v4float = OpTypePointer Function %v4float
@ -171,12 +173,23 @@ OpName %main "main"
%int_2 = OpConstant %int 2 %int_2 = OpConstant %int 2
%int_3 = OpConstant %int 3 %int_3 = OpConstant %int 3
%int_4 = OpConstant %int 4 %int_4 = OpConstant %int 4
%int_10 = OpConstant %int 10
%int_1073741824 = OpConstant %int 1073741824
%int_n1 = OpConstant %int -1
%int_n24 = OpConstant %int -24 %int_n24 = OpConstant %int -24
%int_n858993459 = OpConstant %int -858993459
%int_min = OpConstant %int -2147483648 %int_min = OpConstant %int -2147483648
%int_max = OpConstant %int 2147483647 %int_max = OpConstant %int 2147483647
%long_0 = OpConstant %long 0 %long_0 = OpConstant %long 0
%long_1 = OpConstant %long 1
%long_2 = OpConstant %long 2 %long_2 = OpConstant %long 2
%long_3 = OpConstant %long 3 %long_3 = OpConstant %long 3
%long_10 = OpConstant %long 10
%long_4611686018427387904 = OpConstant %long 4611686018427387904
%long_n1 = OpConstant %long -1
%long_n3689348814741910323 = OpConstant %long -3689348814741910323
%long_min = OpConstant %long -9223372036854775808
%long_max = OpConstant %long 9223372036854775807
%uint_0 = OpConstant %uint 0 %uint_0 = OpConstant %uint 0
%uint_1 = OpConstant %uint 1 %uint_1 = OpConstant %uint 1
%uint_2 = OpConstant %uint 2 %uint_2 = OpConstant %uint 2
@ -184,7 +197,13 @@ OpName %main "main"
%uint_4 = OpConstant %uint 4 %uint_4 = OpConstant %uint 4
%uint_32 = OpConstant %uint 32 %uint_32 = OpConstant %uint 32
%uint_42 = OpConstant %uint 42 %uint_42 = OpConstant %uint 42
%uint_2147483649 = OpConstant %uint 2147483649
%uint_max = OpConstant %uint 4294967295 %uint_max = OpConstant %uint 4294967295
%ulong_0 = OpConstant %ulong 0
%ulong_1 = OpConstant %ulong 1
%ulong_2 = OpConstant %ulong 2
%ulong_9223372036854775809 = OpConstant %ulong 9223372036854775809
%ulong_max = OpConstant %ulong 18446744073709551615
%v2int_undef = OpUndef %v2int %v2int_undef = OpUndef %v2int
%v2int_0_0 = OpConstantComposite %v2int %int_0 %int_0 %v2int_0_0 = OpConstantComposite %v2int %int_0 %int_0
%v2int_1_0 = OpConstantComposite %v2int %int_1 %int_0 %v2int_1_0 = OpConstantComposite %v2int %int_1 %int_0
@ -5572,7 +5591,109 @@ INSTANTIATE_TEST_SUITE_P(MergeMulTest, MatchingInstructionFoldingTest,
"%5 = OpFMul %float %4 %2\n" + "%5 = OpFMul %float %4 %2\n" +
"OpReturn\n" + "OpReturn\n" +
"OpFunctionEnd\n", "OpFunctionEnd\n",
5, true) 5, true),
// Test case 25: fold overflowing signed 32 bit imuls
// (x * 1073741824) * 2 = x * int_min
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32\n" +
"; CHECK: [[int_min:%\\w+]] = OpConstant [[int]] -2147483648\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpIMul [[int]] [[ld]] [[int_min]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpIMul %int %2 %int_1073741824\n" +
"%4 = OpIMul %int %3 %int_2\n" +
"OpReturn\n" +
"OpFunctionEnd\n",
4, true),
// Test case 26: fold overflowing signed 64 bit imuls
// (x * 4611686018427387904) * 2 = x * long_min
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[long:%\\w+]] = OpTypeInt 64\n" +
"; CHECK: [[long_min:%\\w+]] = OpConstant [[long]] -9223372036854775808\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
"; CHECK: %4 = OpIMul [[long]] [[ld]] [[long_min]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_long Function\n" +
"%2 = OpLoad %long %var\n" +
"%3 = OpIMul %long %2 %long_4611686018427387904\n" +
"%4 = OpIMul %long %3 %long_2\n" +
"OpReturn\n" +
"OpFunctionEnd\n",
4, true),
// Test case 27: fold overflowing 32 bit unsigned imuls
// (x * 2147483649) * 2 = x * 2
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" +
"; CHECK: [[uint_2:%\\w+]] = OpConstant [[uint]] 2\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[uint]]\n" +
"; CHECK: %4 = OpIMul [[uint]] [[ld]] [[uint_2]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_uint Function\n" +
"%2 = OpLoad %uint %var\n" +
"%3 = OpIMul %uint %2 %uint_2147483649\n" +
"%4 = OpIMul %uint %3 %uint_2\n" +
"OpReturn\n" +
"OpFunctionEnd\n",
4, true),
// Test case 28: fold overflowing 64 bit unsigned imuls
// (x * 9223372036854775809) * 2 = x * 2
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[ulong:%\\w+]] = OpTypeInt 64 0\n" +
"; CHECK: [[ulong_2:%\\w+]] = OpConstant [[ulong]] 2\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[ulong]]\n" +
"; CHECK: %4 = OpIMul [[ulong]] [[ld]] [[ulong_2]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_ulong Function\n" +
"%2 = OpLoad %ulong %var\n" +
"%3 = OpIMul %ulong %2 %ulong_9223372036854775809\n" +
"%4 = OpIMul %ulong %3 %ulong_2\n" +
"OpReturn\n" +
"OpFunctionEnd\n",
4, true),
// Test case 29: fold underflowing signed 32 bit imuls
// (x * (-858993459)) * 10 = x * 2
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32\n" +
"; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpIMul [[int]] [[ld]] [[int_2]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpIMul %int %2 %int_n858993459\n" +
"%4 = OpIMul %int %3 %int_10\n" +
"OpReturn\n" +
"OpFunctionEnd\n",
4, true),
// Test case 30: fold underflowing signed 64 bit imuls
// (x * (-3689348814741910323)) * 10 = x * 2
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[long:%\\w+]] = OpTypeInt 64\n" +
"; CHECK: [[long_2:%\\w+]] = OpConstant [[long]] 2\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
"; CHECK: %4 = OpIMul [[long]] [[ld]] [[long_2]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_long Function\n" +
"%2 = OpLoad %long %var\n" +
"%3 = OpIMul %long %2 %long_n3689348814741910323\n" +
"%4 = OpIMul %long %3 %long_10\n" +
"OpReturn\n" +
"OpFunctionEnd\n",
4, true)
)); ));
INSTANTIATE_TEST_SUITE_P(MergeDivTest, MatchingInstructionFoldingTest, INSTANTIATE_TEST_SUITE_P(MergeDivTest, MatchingInstructionFoldingTest,
@ -6052,6 +6173,108 @@ INSTANTIATE_TEST_SUITE_P(MergeAddTest, MatchingInstructionFoldingTest,
"%4 = OpFAdd %float %float_2 %3\n" + "%4 = OpFAdd %float %float_2 %3\n" +
"OpReturn\n" + "OpReturn\n" +
"OpFunctionEnd\n", "OpFunctionEnd\n",
4, true),
// Test case 12: fold overflowing signed 32 bit iadds
// (x + int_max) + 1 = x + int_min
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32\n" +
"; CHECK: [[int_min:%\\w+]] = OpConstant [[int]] -2147483648\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpIAdd [[int]] [[ld]] [[int_min]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpIAdd %int %2 %int_max\n" +
"%4 = OpIAdd %int %3 %int_1\n" +
"OpReturn\n" +
"OpFunctionEnd\n",
4, true),
// Test case 13: fold overflowing signed 64 bit iadds
// (x + long_max) + 1 = x + long_min
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[long:%\\w+]] = OpTypeInt 64\n" +
"; CHECK: [[long_min:%\\w+]] = OpConstant [[long]] -9223372036854775808\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
"; CHECK: %4 = OpIAdd [[long]] [[ld]] [[long_min]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_long Function\n" +
"%2 = OpLoad %long %var\n" +
"%3 = OpIAdd %long %2 %long_max\n" +
"%4 = OpIAdd %long %3 %long_1\n" +
"OpReturn\n" +
"OpFunctionEnd\n",
4, true),
// Test case 14: fold overflowing 32 bit unsigned iadds
// (x + uint_max) + 2 = x + 1
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" +
"; CHECK: [[uint_1:%\\w+]] = OpConstant [[uint]] 1\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[uint]]\n" +
"; CHECK: %4 = OpIAdd [[uint]] [[ld]] [[uint_1]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_uint Function\n" +
"%2 = OpLoad %uint %var\n" +
"%3 = OpIAdd %uint %2 %uint_max\n" +
"%4 = OpIAdd %uint %3 %uint_2\n" +
"OpReturn\n" +
"OpFunctionEnd\n",
4, true),
// Test case 15: fold overflowing 64 bit unsigned iadds
// (x + ulong_max) + 2 = x + 1
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[ulong:%\\w+]] = OpTypeInt 64 0\n" +
"; CHECK: [[ulong_1:%\\w+]] = OpConstant [[ulong]] 1\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[ulong]]\n" +
"; CHECK: %4 = OpIAdd [[ulong]] [[ld]] [[ulong_1]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_ulong Function\n" +
"%2 = OpLoad %ulong %var\n" +
"%3 = OpIAdd %ulong %2 %ulong_max\n" +
"%4 = OpIAdd %ulong %3 %ulong_2\n" +
"OpReturn\n" +
"OpFunctionEnd\n",
4, true),
// Test case 16: fold underflowing signed 32 bit iadds
// (x + int_min) + (-1) = x + int_max
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32\n" +
"; CHECK: [[int_max:%\\w+]] = OpConstant [[int]] 2147483647\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpIAdd [[int]] [[ld]] [[int_max]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpIAdd %int %2 %int_min\n" +
"%4 = OpIAdd %int %3 %int_n1\n" +
"OpReturn\n" +
"OpFunctionEnd\n",
4, true),
// Test case 17: fold underflowing signed 64 bit iadds
// (x + long_min) + (-1) = x + long_max
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[long:%\\w+]] = OpTypeInt 64\n" +
"; CHECK: [[long_max:%\\w+]] = OpConstant [[long]] 9223372036854775807\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
"; CHECK: %4 = OpIAdd [[long]] [[ld]] [[long_max]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_long Function\n" +
"%2 = OpLoad %long %var\n" +
"%3 = OpIAdd %long %2 %long_min\n" +
"%4 = OpIAdd %long %3 %long_n1\n" +
"OpReturn\n" +
"OpFunctionEnd\n",
4, true) 4, true)
)); ));
@ -6420,6 +6643,40 @@ INSTANTIATE_TEST_SUITE_P(MergeSubTest, MatchingInstructionFoldingTest,
"%4 = OpISub %int %int_2 %3\n" + "%4 = OpISub %int %int_2 %3\n" +
"OpReturn\n" + "OpReturn\n" +
"OpFunctionEnd\n", "OpFunctionEnd\n",
4, true),
// Test case 14: fold overflowing signed 32 bit isubs
// (x - int_max) - 1 = x - int_min
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32\n" +
"; CHECK: [[int_min:%\\w+]] = OpConstant [[int]] -2147483648\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpISub [[int]] [[ld]] [[int_min]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpISub %int %2 %int_max\n" +
"%4 = OpISub %int %3 %int_1\n" +
"OpReturn\n" +
"OpFunctionEnd\n",
4, true),
// Test case 15: fold overflowing signed 64 bit isubs
// (x - long_max) - 1 = x - long_min
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[long:%\\w+]] = OpTypeInt 64\n" +
"; CHECK: [[long_min:%\\w+]] = OpConstant [[long]] -9223372036854775808\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
"; CHECK: %4 = OpISub [[long]] [[ld]] [[long_min]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_long Function\n" +
"%2 = OpLoad %long %var\n" +
"%3 = OpISub %long %2 %long_max\n" +
"%4 = OpISub %long %3 %long_1\n" +
"OpReturn\n" +
"OpFunctionEnd\n",
4, true) 4, true)
)); ));