Handle dontinline function in spread-volatile-semantics (#4776)

Handle function calls in spread-volatile-semantics
This commit is contained in:
JiaoluAMD 2022-05-04 22:52:58 +08:00 committed by GitHub
parent 58dc37ea6a
commit 2c7fb9707b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 183 additions and 58 deletions

View File

@ -926,6 +926,19 @@ bool IRContext::ProcessCallTreeFromRoots(ProcessFunction& pfn,
return modified;
}
void IRContext::CollectCallTreeFromRoots(unsigned entryId,
std::unordered_set<uint32_t>* funcs) {
std::queue<uint32_t> roots;
roots.push(entryId);
while (!roots.empty()) {
const uint32_t fi = roots.front();
roots.pop();
funcs->insert(fi);
Function* fn = GetFunction(fi);
AddCalls(fn, &roots);
}
}
void IRContext::EmitErrorMessage(std::string message, Instruction* inst) {
if (!consumer()) {
return;

View File

@ -411,6 +411,10 @@ class IRContext {
void CollectNonSemanticTree(Instruction* inst,
std::unordered_set<Instruction*>* to_kill);
// Collect function reachable from |entryId|, returns |funcs|
void CollectCallTreeFromRoots(unsigned entryId,
std::unordered_set<uint32_t>* funcs);
// Returns true if all of the given analyses are valid.
bool AreAnalysesValid(Analysis set) { return (set & valid_analyses_) == set; }

View File

@ -68,38 +68,12 @@ bool HasVolatileDecoration(analysis::DecorationManager* decoration_manager,
return decoration_manager->HasDecoration(var_id, SpvDecorationVolatile);
}
bool HasOnlyEntryPointsAsFunctions(IRContext* context, Module* module) {
std::unordered_set<uint32_t> entry_function_ids;
for (Instruction& entry_point : module->entry_points()) {
entry_function_ids.insert(
entry_point.GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint));
}
for (auto& function : *module) {
if (entry_function_ids.find(function.result_id()) ==
entry_function_ids.end()) {
std::string message(
"Functions of SPIR-V for spread-volatile-semantics pass input must "
"be inlined except entry points");
message += "\n " + function.DefInst().PrettyPrint(
SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
context->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
return false;
}
}
return true;
}
} // namespace
Pass::Status SpreadVolatileSemantics::Process() {
if (HasNoExecutionModel()) {
return Status::SuccessWithoutChange;
}
if (!HasOnlyEntryPointsAsFunctions(context(), get_module())) {
return Status::Failure;
}
const bool is_vk_memory_model_enabled =
context()->get_feature_mgr()->HasCapability(
SpvCapabilityVulkanMemoryModel);
@ -142,6 +116,8 @@ bool SpreadVolatileSemantics::IsTargetUsedByNonVolatileLoadInEntryPoint(
uint32_t var_id, Instruction* entry_point) {
uint32_t entry_function_id =
entry_point->GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint);
std::unordered_set<uint32_t> funcs;
context()->CollectCallTreeFromRoots(entry_function_id, &funcs);
return !VisitLoadsOfPointersToVariableInEntries(
var_id,
[](Instruction* load) {
@ -154,7 +130,7 @@ bool SpreadVolatileSemantics::IsTargetUsedByNonVolatileLoadInEntryPoint(
load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands);
return (memory_operands & SpvMemoryAccessVolatileMask) != 0;
},
{entry_function_id});
funcs);
}
bool SpreadVolatileSemantics::HasInterfaceInConflictOfVolatileSemantics() {
@ -225,7 +201,7 @@ void SpreadVolatileSemantics::DecorateVarWithVolatile(Instruction* var) {
bool SpreadVolatileSemantics::VisitLoadsOfPointersToVariableInEntries(
uint32_t var_id, const std::function<bool(Instruction*)>& handle_load,
const std::unordered_set<uint32_t>& entry_function_ids) {
const std::unordered_set<uint32_t>& function_ids) {
std::vector<uint32_t> worklist({var_id});
auto* def_use_mgr = context()->get_def_use_mgr();
while (!worklist.empty()) {
@ -233,11 +209,11 @@ bool SpreadVolatileSemantics::VisitLoadsOfPointersToVariableInEntries(
worklist.pop_back();
bool finish_traversal = !def_use_mgr->WhileEachUser(
ptr_id, [this, &worklist, &ptr_id, handle_load,
&entry_function_ids](Instruction* user) {
&function_ids](Instruction* user) {
BasicBlock* block = context()->get_instr_block(user);
if (block == nullptr ||
entry_function_ids.find(block->GetParent()->result_id()) ==
entry_function_ids.end()) {
function_ids.find(block->GetParent()->result_id()) ==
function_ids.end()) {
return true;
}
@ -266,21 +242,25 @@ void SpreadVolatileSemantics::SetVolatileForLoadsInEntries(
Instruction* var, const std::unordered_set<uint32_t>& entry_function_ids) {
// Set Volatile memory operand for all load instructions if they do not have
// it.
VisitLoadsOfPointersToVariableInEntries(
var->result_id(),
[](Instruction* load) {
if (load->NumInOperands() <= kOpLoadInOperandMemoryOperands) {
load->AddOperand(
{SPV_OPERAND_TYPE_MEMORY_ACCESS, {SpvMemoryAccessVolatileMask}});
for (auto entry_id : entry_function_ids) {
std::unordered_set<uint32_t> funcs;
context()->CollectCallTreeFromRoots(entry_id, &funcs);
VisitLoadsOfPointersToVariableInEntries(
var->result_id(),
[](Instruction* load) {
if (load->NumInOperands() <= kOpLoadInOperandMemoryOperands) {
load->AddOperand({SPV_OPERAND_TYPE_MEMORY_ACCESS,
{SpvMemoryAccessVolatileMask}});
return true;
}
uint32_t memory_operands =
load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands);
memory_operands |= SpvMemoryAccessVolatileMask;
load->SetInOperand(kOpLoadInOperandMemoryOperands, {memory_operands});
return true;
}
uint32_t memory_operands =
load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands);
memory_operands |= SpvMemoryAccessVolatileMask;
load->SetInOperand(kOpLoadInOperandMemoryOperands, {memory_operands});
return true;
},
entry_function_ids);
},
funcs);
}
}
bool SpreadVolatileSemantics::IsTargetForVolatileSemantics(

View File

@ -72,15 +72,14 @@ class SpreadVolatileSemantics : public Pass {
Instruction* entry_point);
// Visits load instructions of pointers to variable whose result id is
// |var_id| if the load instructions are in entry points whose
// function id is one of |entry_function_ids|. |handle_load| is a function to
// do some actions for the load instructions. Finishes the traversal and
// returns false if |handle_load| returns false for a load instruction.
// Otherwise, returns true after running |handle_load| for all the load
// instructions.
// |var_id| if the load instructions are in reachable functions from entry
// points. |handle_load| is a function to do some actions for the load
// instructions. Finishes the traversal and returns false if |handle_load|
// returns false for a load instruction. Otherwise, returns true after running
// |handle_load| for all the load instructions.
bool VisitLoadsOfPointersToVariableInEntries(
uint32_t var_id, const std::function<bool(Instruction*)>& handle_load,
const std::unordered_set<uint32_t>& entry_function_ids);
const std::unordered_set<uint32_t>& function_ids);
// Sets Memory Operands of OpLoad instructions that load |var| or pointers
// of |var| as Volatile if the function id of the OpLoad instruction is

View File

@ -54,6 +54,7 @@ OpSource GLSL 460
OpSourceExtension "GL_EXT_nonuniform_qualifier"
OpSourceExtension "GL_KHR_ray_tracing"
OpName %main "main"
OpName %fn "fn"
OpName %StorageBuffer "StorageBuffer"
OpMemberName %StorageBuffer 0 "index"
OpMemberName %StorageBuffer 1 "red"
@ -109,6 +110,11 @@ OpDecorate %var BuiltIn )") + built_in + std::string(R"(
%29 = OpCompositeExtract %float %27 0
%31 = OpAccessChain %_ptr_Uniform_float %sbo %int_1
OpStore %31 %29
%32 = OpFunctionCall %void %fn
OpReturn
OpFunctionEnd
%fn = OpFunction %void None %3
%33 = OpLabel
OpReturn
OpFunctionEnd
)");
@ -782,12 +788,7 @@ OpReturn
OpFunctionEnd
)";
EXPECT_EQ(RunPass(text), Pass::Status::Failure);
const char expected_error[] =
"ERROR: 0: Functions of SPIR-V for spread-volatile-semantics pass "
"input must be inlined except entry points";
EXPECT_STREQ(GetErrorMessage().substr(0, sizeof(expected_error) - 1).c_str(),
expected_error);
EXPECT_EQ(RunPass(text), Pass::Status::SuccessWithoutChange);
}
TEST_F(VolatileSpreadErrorTest, VarNotUsedInEntryPointForVolatile) {
@ -1133,6 +1134,134 @@ OpFunctionEnd
EXPECT_EQ(status, Pass::Status::SuccessWithoutChange);
}
TEST_F(VolatileSpreadTest, NoInlinedfuncCalls) {
const std::string text = R"(
OpCapability RayTracingNV
OpCapability VulkanMemoryModel
OpCapability GroupNonUniform
OpExtension "SPV_NV_ray_tracing"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical Vulkan
OpEntryPoint RayGenerationNV %main "main" %SubgroupSize
OpSource HLSL 630
OpName %main "main"
OpName %src_main "src.main"
OpName %bb_entry "bb.entry"
OpName %func0 "func0"
OpName %bb_entry_0 "bb.entry"
OpName %func2 "func2"
OpName %bb_entry_1 "bb.entry"
OpName %param_var_count "param.var.count"
OpName %func1 "func1"
OpName %bb_entry_2 "bb.entry"
OpName %func3 "func3"
OpName %count "count"
OpName %bb_entry_3 "bb.entry"
OpDecorate %SubgroupSize BuiltIn SubgroupSize
%uint = OpTypeInt 32 0
%_ptr_Input_uint = OpTypePointer Input %uint
%void = OpTypeVoid
%6 = OpTypeFunction %void
%_ptr_Function_uint = OpTypePointer Function %uint
%25 = OpTypeFunction %void %_ptr_Function_uint
%SubgroupSize = OpVariable %_ptr_Input_uint Input
%main = OpFunction %void None %6
%7 = OpLabel
%8 = OpFunctionCall %void %src_main
OpReturn
OpFunctionEnd
%src_main = OpFunction %void None %6
%bb_entry = OpLabel
%11 = OpFunctionCall %void %func0
OpReturn
OpFunctionEnd
%func0 = OpFunction %void DontInline %6
%bb_entry_0 = OpLabel
%14 = OpFunctionCall %void %func2
%16 = OpFunctionCall %void %func1
OpReturn
OpFunctionEnd
%func2 = OpFunction %void DontInline %6
%bb_entry_1 = OpLabel
%param_var_count = OpVariable %_ptr_Function_uint Function
; CHECK: {{%\w+}} = OpLoad %uint %SubgroupSize Volatile
%21 = OpLoad %uint %SubgroupSize
OpStore %param_var_count %21
%22 = OpFunctionCall %void %func3 %param_var_count
OpReturn
OpFunctionEnd
%func1 = OpFunction %void DontInline %6
%bb_entry_2 = OpLabel
OpReturn
OpFunctionEnd
%func3 = OpFunction %void DontInline %25
%count = OpFunctionParameter %_ptr_Function_uint
%bb_entry_3 = OpLabel
OpReturn
OpFunctionEnd
)";
SinglePassRunAndMatch<SpreadVolatileSemantics>(text, true);
}
TEST_F(VolatileSpreadErrorTest, NoInlinedMultiEntryfuncCalls) {
const std::string text = R"(
OpCapability RayTracingNV
OpCapability SubgroupBallotKHR
OpExtension "SPV_NV_ray_tracing"
OpExtension "SPV_KHR_shader_ballot"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint RayGenerationNV %main "main" %SubgroupSize
OpEntryPoint GLCompute %main2 "main2" %gl_LocalInvocationIndex %SubgroupSize
OpSource HLSL 630
OpName %main "main"
OpName %bb_entry "bb.entry"
OpName %main2 "main2"
OpName %bb_entry_0 "bb.entry"
OpName %func "func"
OpName %count "count"
OpName %bb_entry_1 "bb.entry"
OpDecorate %gl_LocalInvocationIndex BuiltIn LocalInvocationIndex
OpDecorate %SubgroupSize BuiltIn SubgroupSize
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%_ptr_Input_uint = OpTypePointer Input %uint
%float = OpTypeFloat 32
%v4float = OpTypeVector %float 4
%void = OpTypeVoid
%12 = OpTypeFunction %void
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Function_v4float = OpTypePointer Function %v4float
%29 = OpTypeFunction %void %_ptr_Function_v4float
%34 = OpTypeFunction %void %_ptr_Function_uint
%SubgroupSize = OpVariable %_ptr_Input_uint Input
%gl_LocalInvocationIndex = OpVariable %_ptr_Input_uint Input
%main = OpFunction %void None %12
%bb_entry = OpLabel
%20 = OpFunctionCall %void %func
OpReturn
OpFunctionEnd
%main2 = OpFunction %void None %12
%bb_entry_0 = OpLabel
%33 = OpFunctionCall %void %func
OpReturn
OpFunctionEnd
%func = OpFunction %void DontInline %12
%bb_entry_1 = OpLabel
%count = OpVariable %_ptr_Function_uint Function
%35 = OpLoad %uint %SubgroupSize
OpStore %count %35
OpReturn
OpFunctionEnd
)";
EXPECT_EQ(RunPass(text), Pass::Status::Failure);
const char expected_error[] =
"ERROR: 0: Variable is a target for Volatile semantics for an entry "
"point, but it is not for another entry point";
EXPECT_STREQ(GetErrorMessage().substr(0, sizeof(expected_error) - 1).c_str(),
expected_error);
}
} // namespace
} // namespace opt
} // namespace spvtools