Handle 64-bit integers in local access chain convert (#4798)

* Handle 64-bit integers in local access chain convert

The local access chain convert pass does on run on module that have
64-bit integers, even if they have nothing to to with access chains.
This is very limiting because other passes rely on the access chains
being removed. So this commit will add this functionality to the pass.
This commit is contained in:
Steven Perron 2022-05-10 13:02:14 -04:00 committed by GitHub
parent f7a6e3b9d5
commit f74b85853c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 115 additions and 12 deletions

View File

@ -28,8 +28,6 @@ namespace {
const uint32_t kStoreValIdInIdx = 1;
const uint32_t kAccessChainPtrIdInIdx = 0;
const uint32_t kConstantValueInIdx = 0;
const uint32_t kTypeIntWidthInIdx = 0;
} // anonymous namespace
@ -67,7 +65,19 @@ void LocalAccessChainConvertPass::AppendConstantOperands(
ptrInst->ForEachInId([&iidIdx, &in_opnds, this](const uint32_t* iid) {
if (iidIdx > 0) {
const Instruction* cInst = get_def_use_mgr()->GetDef(*iid);
uint32_t val = cInst->GetSingleWordInOperand(kConstantValueInIdx);
const auto* constant_value =
context()->get_constant_mgr()->GetConstantFromInst(cInst);
assert(constant_value != nullptr &&
"Expecting the index to be a constant.");
// We take the sign extended value because OpAccessChain interprets the
// index as signed.
int64_t long_value = constant_value->GetSignExtendedValue();
assert(long_value <= UINT32_MAX && long_value >= 0 &&
"The index value is too large for a composite insert or extract "
"instruction.");
uint32_t val = static_cast<uint32_t>(long_value);
in_opnds->push_back(
{spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}});
}
@ -169,13 +179,16 @@ bool LocalAccessChainConvertPass::GenAccessChainStoreReplacement(
return true;
}
bool LocalAccessChainConvertPass::IsConstantIndexAccessChain(
bool LocalAccessChainConvertPass::Is32BitConstantIndexAccessChain(
const Instruction* acp) const {
uint32_t inIdx = 0;
return acp->WhileEachInId([&inIdx, this](const uint32_t* tid) {
if (inIdx > 0) {
Instruction* opInst = get_def_use_mgr()->GetDef(*tid);
if (opInst->opcode() != SpvOpConstant) return false;
const auto* index =
context()->get_constant_mgr()->GetConstantFromInst(opInst);
if (index->GetSignExtendedValue() > UINT32_MAX) return false;
}
++inIdx;
return true;
@ -231,7 +244,7 @@ void LocalAccessChainConvertPass::FindTargetVars(Function* func) {
break;
}
// Rule out variables accessed with non-constant indices
if (!IsConstantIndexAccessChain(ptrInst)) {
if (!Is32BitConstantIndexAccessChain(ptrInst)) {
seen_non_target_vars_.insert(varId);
seen_target_vars_.erase(varId);
break;
@ -349,12 +362,6 @@ bool LocalAccessChainConvertPass::AllExtensionsSupported() const {
}
Pass::Status LocalAccessChainConvertPass::ProcessImpl() {
// If non-32-bit integer type in module, terminate processing
// TODO(): Handle non-32-bit integer constants in access chains
for (const Instruction& inst : get_module()->types_values())
if (inst.opcode() == SpvOpTypeInt &&
inst.GetSingleWordInOperand(kTypeIntWidthInIdx) != 32)
return Status::SuccessWithoutChange;
// Do not process if module contains OpGroupDecorate. Additional
// support required in KillNamesAndDecorates().
// TODO(greg-lunarg): Add support for OpGroupDecorate

View File

@ -95,7 +95,8 @@ class LocalAccessChainConvertPass : public MemPass {
Instruction* original_load);
// Return true if all indices of access chain |acp| are OpConstant integers
bool IsConstantIndexAccessChain(const Instruction* acp) const;
// whose values can fit into an unsigned 32-bit value.
bool Is32BitConstantIndexAccessChain(const Instruction* acp) const;
// Identify all function scope variables of target type which are
// accessed only with loads, stores and access chains with constant

View File

@ -1156,6 +1156,101 @@ TEST_F(LocalAccessChainConvertTest, AccessChainWithNoIndex) {
SinglePassRunAndMatch<LocalAccessChainConvertPass>(before, true);
}
TEST_F(LocalAccessChainConvertTest, AccessChainWithLongIndex) {
// The access chain take a value that is larger than 32-bit. The index cannot
// be encoded in an OpCompositeExtract, so nothing should be done.
const std::string before =
R"(OpCapability Shader
OpCapability Int64
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %2 "main_0004f4d4_85b2f584"
OpExecutionMode %2 OriginUpperLeft
%ulong = OpTypeInt 64 0
%ulong_8589934592 = OpConstant %ulong 8589934592
%ulong_8589934591 = OpConstant %ulong 8589934591
%_arr_ulong_ulong_8589934592 = OpTypeArray %ulong %ulong_8589934592
%_ptr_Function__arr_ulong_ulong_8589934592 = OpTypePointer Function %_arr_ulong_ulong_8589934592
%_ptr_Function_ulong = OpTypePointer Function %ulong
%void = OpTypeVoid
%10 = OpTypeFunction %void
%2 = OpFunction %void None %10
%11 = OpLabel
%12 = OpVariable %_ptr_Function__arr_ulong_ulong_8589934592 Function
%13 = OpAccessChain %_ptr_Function_ulong %12 %ulong_8589934591
%14 = OpLoad %ulong %13
OpReturn
OpFunctionEnd
)";
SinglePassRunAndCheck<LocalAccessChainConvertPass>(before, before, false,
true);
}
TEST_F(LocalAccessChainConvertTest, AccessChainWith32BitIndexInLong) {
// The access chain has a value that is 32-bits, but it is stored in a 64-bit
// variable. This access change can be converted to an extract.
const std::string before =
R"(
; CHECK: OpFunction
; CHECK: [[var:%\w+]] = OpVariable
; CHECK: [[ld:%\w+]] = OpLoad {{%\w+}} [[var]]
; CHECK: OpCompositeExtract %ulong [[ld]] 3
OpCapability Shader
OpCapability Int64
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %2 "main_0004f4d4_85b2f584"
OpExecutionMode %2 OriginUpperLeft
%ulong = OpTypeInt 64 0
%ulong_8589934592 = OpConstant %ulong 8589934592
%ulong_3 = OpConstant %ulong 3
%_arr_ulong_ulong_8589934592 = OpTypeArray %ulong %ulong_8589934592
%_ptr_Function__arr_ulong_ulong_8589934592 = OpTypePointer Function %_arr_ulong_ulong_8589934592
%_ptr_Function_ulong = OpTypePointer Function %ulong
%void = OpTypeVoid
%10 = OpTypeFunction %void
%2 = OpFunction %void None %10
%11 = OpLabel
%12 = OpVariable %_ptr_Function__arr_ulong_ulong_8589934592 Function
%13 = OpAccessChain %_ptr_Function_ulong %12 %ulong_3
%14 = OpLoad %ulong %13
OpReturn
OpFunctionEnd
)";
SinglePassRunAndMatch<LocalAccessChainConvertPass>(before, true);
}
TEST_F(LocalAccessChainConvertTest, AccessChainWithVarIndex) {
// The access chain has a value that is not constant, so there should not be
// any changes.
const std::string before =
R"(OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %2 "main_0004f4d4_85b2f584"
OpExecutionMode %2 OriginUpperLeft
%uint = OpTypeInt 32 0
%uint_5 = OpConstant %uint 5
%_arr_uint_uint_5 = OpTypeArray %uint %uint_5
%_ptr_Function__arr_uint_uint_5 = OpTypePointer Function %_arr_uint_uint_5
%_ptr_Function_uint = OpTypePointer Function %uint
%8 = OpUndef %uint
%void = OpTypeVoid
%10 = OpTypeFunction %void
%2 = OpFunction %void None %10
%11 = OpLabel
%12 = OpVariable %_ptr_Function__arr_uint_uint_5 Function
%13 = OpAccessChain %_ptr_Function_uint %12 %8
%14 = OpLoad %uint %13
OpReturn
OpFunctionEnd
)";
SinglePassRunAndCheck<LocalAccessChainConvertPass>(before, before, false,
true);
}
// TODO(greg-lunarg): Add tests to verify handling of these cases:
//