From 36ff135341989364e8ea8fcc2b6d8fec9dac0f8f Mon Sep 17 00:00:00 2001 From: Alastair Donaldson Date: Tue, 14 Sep 2021 22:09:05 +0100 Subject: [PATCH] spirv-opt: Avoid integer overflow during constant folding (#4511) In SPIR-V, integers use 2s complement representation, so that signed integer overflow and underflow is well defined. However, the constant folder was causing overflow / underflow at the C++ level. This change avoids such overflows by performing constant folding for IAdd, ISub and IMul in the context of unsigned values, which works because signedness is irrelevant according to the SPIR-V semantics for these instructions. Fixes #4510. --- source/opt/folding_rules.cpp | 34 ++--- test/opt/fold_test.cpp | 259 ++++++++++++++++++++++++++++++++++- 2 files changed, 273 insertions(+), 20 deletions(-) diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index 6ae078fb..20051a6b 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -523,7 +523,8 @@ uint32_t PerformFloatingPointOperation(analysis::ConstantManager* const_mgr, float fval = val.getAsFloat(); \ if (!IsValidResult(fval)) return 0; \ words = val.GetWords(); \ - } static_assert(true, "require extra semicolon") + } \ + static_assert(true, "require extra semicolon") switch (opcode) { case SpvOpFMul: FOLD_OP(*); @@ -558,24 +559,19 @@ uint32_t PerformIntegerOperation(analysis::ConstantManager* const_mgr, uint32_t width = type->AsInteger()->width(); assert(width == 32 || width == 64); std::vector words; -#define FOLD_OP(op) \ - if (width == 64) { \ - if (type->IsSigned()) { \ - int64_t val = input1->GetS64() op input2->GetS64(); \ - words = ExtractInts(static_cast(val)); \ - } else { \ - uint64_t val = input1->GetU64() op input2->GetU64(); \ - words = ExtractInts(val); \ - } \ - } else { \ - if (type->IsSigned()) { \ - int32_t val = input1->GetS32() op input2->GetS32(); \ - words.push_back(static_cast(val)); \ - } else { \ - uint32_t val = input1->GetU32() op input2->GetU32(); \ - words.push_back(val); \ - } \ - } static_assert(true, "require extra semicalon") + // Regardless of the sign of the constant, folding is performed on an unsigned + // interpretation of the constant data. This avoids signed integer overflow + // while folding, and works because sign is irrelevant for the IAdd, ISub and + // IMul instructions. +#define FOLD_OP(op) \ + if (width == 64) { \ + uint64_t val = input1->GetU64() op input2->GetU64(); \ + words = ExtractInts(val); \ + } else { \ + uint32_t val = input1->GetU32() op input2->GetU32(); \ + words.push_back(val); \ + } \ + static_assert(true, "require extra semicolon") switch (opcode) { case SpvOpIMul: FOLD_OP(*); diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index da5b017d..292a869e 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -137,6 +137,7 @@ OpName %main "main" %int = OpTypeInt 32 1 %long = OpTypeInt 64 1 %uint = OpTypeInt 32 0 +%ulong = OpTypeInt 64 0 %v2int = OpTypeVector %int 2 %v4int = OpTypeVector %int 4 %v4float = OpTypeVector %float 4 @@ -154,6 +155,7 @@ OpName %main "main" %_ptr_double = OpTypePointer Function %double %_ptr_half = OpTypePointer Function %half %_ptr_long = OpTypePointer Function %long +%_ptr_ulong = OpTypePointer Function %ulong %_ptr_v2int = OpTypePointer Function %v2int %_ptr_v4int = OpTypePointer Function %v4int %_ptr_v4float = OpTypePointer Function %v4float @@ -171,12 +173,23 @@ OpName %main "main" %int_2 = OpConstant %int 2 %int_3 = OpConstant %int 3 %int_4 = OpConstant %int 4 +%int_10 = OpConstant %int 10 +%int_1073741824 = OpConstant %int 1073741824 +%int_n1 = OpConstant %int -1 %int_n24 = OpConstant %int -24 +%int_n858993459 = OpConstant %int -858993459 %int_min = OpConstant %int -2147483648 %int_max = OpConstant %int 2147483647 %long_0 = OpConstant %long 0 +%long_1 = OpConstant %long 1 %long_2 = OpConstant %long 2 %long_3 = OpConstant %long 3 +%long_10 = OpConstant %long 10 +%long_4611686018427387904 = OpConstant %long 4611686018427387904 +%long_n1 = OpConstant %long -1 +%long_n3689348814741910323 = OpConstant %long -3689348814741910323 +%long_min = OpConstant %long -9223372036854775808 +%long_max = OpConstant %long 9223372036854775807 %uint_0 = OpConstant %uint 0 %uint_1 = OpConstant %uint 1 %uint_2 = OpConstant %uint 2 @@ -184,7 +197,13 @@ OpName %main "main" %uint_4 = OpConstant %uint 4 %uint_32 = OpConstant %uint 32 %uint_42 = OpConstant %uint 42 +%uint_2147483649 = OpConstant %uint 2147483649 %uint_max = OpConstant %uint 4294967295 +%ulong_0 = OpConstant %ulong 0 +%ulong_1 = OpConstant %ulong 1 +%ulong_2 = OpConstant %ulong 2 +%ulong_9223372036854775809 = OpConstant %ulong 9223372036854775809 +%ulong_max = OpConstant %ulong 18446744073709551615 %v2int_undef = OpUndef %v2int %v2int_0_0 = OpConstantComposite %v2int %int_0 %int_0 %v2int_1_0 = OpConstantComposite %v2int %int_1 %int_0 @@ -5572,7 +5591,109 @@ INSTANTIATE_TEST_SUITE_P(MergeMulTest, MatchingInstructionFoldingTest, "%5 = OpFMul %float %4 %2\n" + "OpReturn\n" + "OpFunctionEnd\n", - 5, true) + 5, true), + // Test case 25: fold overflowing signed 32 bit imuls + // (x * 1073741824) * 2 = x * int_min + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32\n" + + "; CHECK: [[int_min:%\\w+]] = OpConstant [[int]] -2147483648\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpIMul [[int]] [[ld]] [[int_min]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpIMul %int %2 %int_1073741824\n" + + "%4 = OpIMul %int %3 %int_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 26: fold overflowing signed 64 bit imuls + // (x * 4611686018427387904) * 2 = x * long_min + InstructionFoldingCase( + Header() + + "; CHECK: [[long:%\\w+]] = OpTypeInt 64\n" + + "; CHECK: [[long_min:%\\w+]] = OpConstant [[long]] -9223372036854775808\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" + + "; CHECK: %4 = OpIMul [[long]] [[ld]] [[long_min]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_long Function\n" + + "%2 = OpLoad %long %var\n" + + "%3 = OpIMul %long %2 %long_4611686018427387904\n" + + "%4 = OpIMul %long %3 %long_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 27: fold overflowing 32 bit unsigned imuls + // (x * 2147483649) * 2 = x * 2 + InstructionFoldingCase( + Header() + + "; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" + + "; CHECK: [[uint_2:%\\w+]] = OpConstant [[uint]] 2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[uint]]\n" + + "; CHECK: %4 = OpIMul [[uint]] [[ld]] [[uint_2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_uint Function\n" + + "%2 = OpLoad %uint %var\n" + + "%3 = OpIMul %uint %2 %uint_2147483649\n" + + "%4 = OpIMul %uint %3 %uint_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 28: fold overflowing 64 bit unsigned imuls + // (x * 9223372036854775809) * 2 = x * 2 + InstructionFoldingCase( + Header() + + "; CHECK: [[ulong:%\\w+]] = OpTypeInt 64 0\n" + + "; CHECK: [[ulong_2:%\\w+]] = OpConstant [[ulong]] 2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[ulong]]\n" + + "; CHECK: %4 = OpIMul [[ulong]] [[ld]] [[ulong_2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_ulong Function\n" + + "%2 = OpLoad %ulong %var\n" + + "%3 = OpIMul %ulong %2 %ulong_9223372036854775809\n" + + "%4 = OpIMul %ulong %3 %ulong_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 29: fold underflowing signed 32 bit imuls + // (x * (-858993459)) * 10 = x * 2 + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32\n" + + "; CHECK: [[int_2:%\\w+]] = OpConstant [[int]] 2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpIMul [[int]] [[ld]] [[int_2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpIMul %int %2 %int_n858993459\n" + + "%4 = OpIMul %int %3 %int_10\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 30: fold underflowing signed 64 bit imuls + // (x * (-3689348814741910323)) * 10 = x * 2 + InstructionFoldingCase( + Header() + + "; CHECK: [[long:%\\w+]] = OpTypeInt 64\n" + + "; CHECK: [[long_2:%\\w+]] = OpConstant [[long]] 2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" + + "; CHECK: %4 = OpIMul [[long]] [[ld]] [[long_2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_long Function\n" + + "%2 = OpLoad %long %var\n" + + "%3 = OpIMul %long %2 %long_n3689348814741910323\n" + + "%4 = OpIMul %long %3 %long_10\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true) )); INSTANTIATE_TEST_SUITE_P(MergeDivTest, MatchingInstructionFoldingTest, @@ -6052,6 +6173,108 @@ INSTANTIATE_TEST_SUITE_P(MergeAddTest, MatchingInstructionFoldingTest, "%4 = OpFAdd %float %float_2 %3\n" + "OpReturn\n" + "OpFunctionEnd\n", + 4, true), + // Test case 12: fold overflowing signed 32 bit iadds + // (x + int_max) + 1 = x + int_min + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32\n" + + "; CHECK: [[int_min:%\\w+]] = OpConstant [[int]] -2147483648\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpIAdd [[int]] [[ld]] [[int_min]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpIAdd %int %2 %int_max\n" + + "%4 = OpIAdd %int %3 %int_1\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 13: fold overflowing signed 64 bit iadds + // (x + long_max) + 1 = x + long_min + InstructionFoldingCase( + Header() + + "; CHECK: [[long:%\\w+]] = OpTypeInt 64\n" + + "; CHECK: [[long_min:%\\w+]] = OpConstant [[long]] -9223372036854775808\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" + + "; CHECK: %4 = OpIAdd [[long]] [[ld]] [[long_min]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_long Function\n" + + "%2 = OpLoad %long %var\n" + + "%3 = OpIAdd %long %2 %long_max\n" + + "%4 = OpIAdd %long %3 %long_1\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 14: fold overflowing 32 bit unsigned iadds + // (x + uint_max) + 2 = x + 1 + InstructionFoldingCase( + Header() + + "; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" + + "; CHECK: [[uint_1:%\\w+]] = OpConstant [[uint]] 1\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[uint]]\n" + + "; CHECK: %4 = OpIAdd [[uint]] [[ld]] [[uint_1]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_uint Function\n" + + "%2 = OpLoad %uint %var\n" + + "%3 = OpIAdd %uint %2 %uint_max\n" + + "%4 = OpIAdd %uint %3 %uint_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 15: fold overflowing 64 bit unsigned iadds + // (x + ulong_max) + 2 = x + 1 + InstructionFoldingCase( + Header() + + "; CHECK: [[ulong:%\\w+]] = OpTypeInt 64 0\n" + + "; CHECK: [[ulong_1:%\\w+]] = OpConstant [[ulong]] 1\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[ulong]]\n" + + "; CHECK: %4 = OpIAdd [[ulong]] [[ld]] [[ulong_1]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_ulong Function\n" + + "%2 = OpLoad %ulong %var\n" + + "%3 = OpIAdd %ulong %2 %ulong_max\n" + + "%4 = OpIAdd %ulong %3 %ulong_2\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 16: fold underflowing signed 32 bit iadds + // (x + int_min) + (-1) = x + int_max + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32\n" + + "; CHECK: [[int_max:%\\w+]] = OpConstant [[int]] 2147483647\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpIAdd [[int]] [[ld]] [[int_max]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpIAdd %int %2 %int_min\n" + + "%4 = OpIAdd %int %3 %int_n1\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 17: fold underflowing signed 64 bit iadds + // (x + long_min) + (-1) = x + long_max + InstructionFoldingCase( + Header() + + "; CHECK: [[long:%\\w+]] = OpTypeInt 64\n" + + "; CHECK: [[long_max:%\\w+]] = OpConstant [[long]] 9223372036854775807\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" + + "; CHECK: %4 = OpIAdd [[long]] [[ld]] [[long_max]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_long Function\n" + + "%2 = OpLoad %long %var\n" + + "%3 = OpIAdd %long %2 %long_min\n" + + "%4 = OpIAdd %long %3 %long_n1\n" + + "OpReturn\n" + + "OpFunctionEnd\n", 4, true) )); @@ -6420,6 +6643,40 @@ INSTANTIATE_TEST_SUITE_P(MergeSubTest, MatchingInstructionFoldingTest, "%4 = OpISub %int %int_2 %3\n" + "OpReturn\n" + "OpFunctionEnd\n", + 4, true), + // Test case 14: fold overflowing signed 32 bit isubs + // (x - int_max) - 1 = x - int_min + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32\n" + + "; CHECK: [[int_min:%\\w+]] = OpConstant [[int]] -2147483648\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpISub [[int]] [[ld]] [[int_min]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpISub %int %2 %int_max\n" + + "%4 = OpISub %int %3 %int_1\n" + + "OpReturn\n" + + "OpFunctionEnd\n", + 4, true), + // Test case 15: fold overflowing signed 64 bit isubs + // (x - long_max) - 1 = x - long_min + InstructionFoldingCase( + Header() + + "; CHECK: [[long:%\\w+]] = OpTypeInt 64\n" + + "; CHECK: [[long_min:%\\w+]] = OpConstant [[long]] -9223372036854775808\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" + + "; CHECK: %4 = OpISub [[long]] [[ld]] [[long_min]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_long Function\n" + + "%2 = OpLoad %long %var\n" + + "%3 = OpISub %long %2 %long_max\n" + + "%4 = OpISub %long %3 %long_1\n" + + "OpReturn\n" + + "OpFunctionEnd\n", 4, true) ));