diff --git a/source/opt/set_spec_constant_default_value_pass.cpp b/source/opt/set_spec_constant_default_value_pass.cpp index 83712b36..4def2b09 100644 --- a/source/opt/set_spec_constant_default_value_pass.cpp +++ b/source/opt/set_spec_constant_default_value_pass.cpp @@ -85,6 +85,10 @@ std::vector ParseDefaultValueStr(const char* text, // with 0x1, which represents a 'true'. // If all words in the bit pattern are zero, returns a bit pattern with 0x0, // which represents a 'false'. +// For integer and floating point types narrower than 32 bits, the upper bits +// in the input bit pattern are ignored. Instead the upper bits are set +// according to SPIR-V literal requirements: sign extend a signed integer, and +// otherwise set the upper bits to zero. std::vector ParseDefaultValueBitPattern( const std::vector& input_bit_pattern, const analysis::Type* type) { @@ -98,16 +102,33 @@ std::vector ParseDefaultValueBitPattern( } return result; } else if (const auto* IT = type->AsInteger()) { - auto width = IT->width(); - if (width == 8 || width == 16) width = 32; - if (width == input_bit_pattern.size() * sizeof(uint32_t) * 8) { - return std::vector(input_bit_pattern); + const auto width = IT->width(); + assert(width > 0); + const auto adjusted_width = std::max(32u, width); + if (adjusted_width == input_bit_pattern.size() * sizeof(uint32_t) * 8) { + result = std::vector(input_bit_pattern); + if (width < 32) { + const uint32_t high_active_bit = (1u << width) >> 1; + if (IT->IsSigned() && (high_active_bit & result[0])) { + // Sign extend. This overwrites the sign bit again, but that's ok. + result[0] = result[0] | ~(high_active_bit - 1); + } else { + // Upper bits must be zero. + result[0] = result[0] & ((1u << width) - 1); + } + } + return result; } } else if (const auto* FT = type->AsFloat()) { - auto width = FT->width(); - if (width == 8 || width == 16) width = 32; - if (width == input_bit_pattern.size() * sizeof(uint32_t) * 8) { - return std::vector(input_bit_pattern); + const auto width = FT->width(); + const auto adjusted_width = std::max(32u, width); + if (adjusted_width == input_bit_pattern.size() * sizeof(uint32_t) * 8) { + result = std::vector(input_bit_pattern); + if (width < 32) { + // Upper bits must be zero. + result[0] = result[0] & ((1u << width) - 1); + } + return result; } } result.clear(); diff --git a/test/opt/set_spec_const_default_value_test.cpp b/test/opt/set_spec_const_default_value_test.cpp index 58acd431..f1dd50ee 100644 --- a/test/opt/set_spec_const_default_value_test.cpp +++ b/test/opt/set_spec_const_default_value_test.cpp @@ -935,7 +935,7 @@ INSTANTIATE_TEST_SUITE_P( "%2 = OpSpecConstantTrue %bool\n" "%3 = OpSpecConstantTrue %bool\n", }, - // 19. 16-bit int type. + // 19. 16-bit signed int type. { // code "OpDecorate %1 SpecId 100\n" @@ -947,17 +947,39 @@ INSTANTIATE_TEST_SUITE_P( "%3 = OpSpecConstant %short 11\n", // default values SpecIdToValueBitPatternMap{ - {100, {32768}}, {101, {0xffff}}, {102, {0xffffffd6}}}, - // expected + {100, {32767}}, {101, {0xffff}}, {102, {0xffffffd6}}}, + // expected. These are sign-extended "OpDecorate %1 SpecId 100\n" "OpDecorate %2 SpecId 101\n" "OpDecorate %3 SpecId 102\n" "%short = OpTypeInt 16 1\n" - "%1 = OpSpecConstant %short 32768\n" - "%2 = OpSpecConstant %short 65535\n" + "%1 = OpSpecConstant %short 32767\n" + "%2 = OpSpecConstant %short -1\n" "%3 = OpSpecConstant %short -42\n", }, - // 20. 8-bit int type. + // 20. 16-bit unsigned int type. + { + // code + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "OpDecorate %3 SpecId 102\n" + "%ushort = OpTypeInt 16 0\n" + "%1 = OpSpecConstant %ushort 10\n" + "%2 = OpSpecConstant %ushort 11\n" + "%3 = OpSpecConstant %ushort 11\n", + // default values + SpecIdToValueBitPatternMap{ + {100, {32767}}, {101, {0xffff}}, {102, {0xffffffd6}}}, + // expected. Upper bits are always zero. + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "OpDecorate %3 SpecId 102\n" + "%ushort = OpTypeInt 16 0\n" + "%1 = OpSpecConstant %ushort 32767\n" + "%2 = OpSpecConstant %ushort 65535\n" + "%3 = OpSpecConstant %ushort 65494\n", + }, + // 21. 8-bit signed int type. { // code "OpDecorate %1 SpecId 100\n" @@ -969,16 +991,42 @@ INSTANTIATE_TEST_SUITE_P( "%3 = OpSpecConstant %char 11\n", // default values SpecIdToValueBitPatternMap{ - {100, {128}}, {101, {129}}, {102, {0xffffffd6}}}, - // expected + {100, {127}}, {101, {128}}, {102, {0xd6}}}, + // expected. These are sign extended "OpDecorate %1 SpecId 100\n" "OpDecorate %2 SpecId 101\n" "OpDecorate %3 SpecId 102\n" "%char = OpTypeInt 8 1\n" - "%1 = OpSpecConstant %char 128\n" - "%2 = OpSpecConstant %char 129\n" + "%1 = OpSpecConstant %char 127\n" + "%2 = OpSpecConstant %char -128\n" "%3 = OpSpecConstant %char -42\n", }, + // 22. 8-bit unsigned int type. + { + // code + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "OpDecorate %3 SpecId 102\n" + "OpDecorate %4 SpecId 103\n" + "%uchar = OpTypeInt 8 0\n" + "%1 = OpSpecConstant %uchar 10\n" + "%2 = OpSpecConstant %uchar 11\n" + "%3 = OpSpecConstant %uchar 11\n" + "%4 = OpSpecConstant %uchar 11\n", + // default values + SpecIdToValueBitPatternMap{ + {100, {127}}, {101, {128}}, {102, {256}}, {103, {0xffffffd6}}}, + // expected. Upper bits are always zero. + "OpDecorate %1 SpecId 100\n" + "OpDecorate %2 SpecId 101\n" + "OpDecorate %3 SpecId 102\n" + "OpDecorate %4 SpecId 103\n" + "%uchar = OpTypeInt 8 0\n" + "%1 = OpSpecConstant %uchar 127\n" + "%2 = OpSpecConstant %uchar 128\n" + "%3 = OpSpecConstant %uchar 0\n" + "%4 = OpSpecConstant %uchar 214\n", + }, })); INSTANTIATE_TEST_SUITE_P(