mirror of
https://gitee.com/openharmony/third_party_spirv-tools
synced 2024-11-23 07:20:28 +00:00
Handle dontinline function in spread-volatile-semantics (#4776)
Handle function calls in spread-volatile-semantics
This commit is contained in:
parent
58dc37ea6a
commit
2c7fb9707b
@ -926,6 +926,19 @@ bool IRContext::ProcessCallTreeFromRoots(ProcessFunction& pfn,
|
||||
return modified;
|
||||
}
|
||||
|
||||
void IRContext::CollectCallTreeFromRoots(unsigned entryId,
|
||||
std::unordered_set<uint32_t>* funcs) {
|
||||
std::queue<uint32_t> roots;
|
||||
roots.push(entryId);
|
||||
while (!roots.empty()) {
|
||||
const uint32_t fi = roots.front();
|
||||
roots.pop();
|
||||
funcs->insert(fi);
|
||||
Function* fn = GetFunction(fi);
|
||||
AddCalls(fn, &roots);
|
||||
}
|
||||
}
|
||||
|
||||
void IRContext::EmitErrorMessage(std::string message, Instruction* inst) {
|
||||
if (!consumer()) {
|
||||
return;
|
||||
|
@ -411,6 +411,10 @@ class IRContext {
|
||||
void CollectNonSemanticTree(Instruction* inst,
|
||||
std::unordered_set<Instruction*>* to_kill);
|
||||
|
||||
// Collect function reachable from |entryId|, returns |funcs|
|
||||
void CollectCallTreeFromRoots(unsigned entryId,
|
||||
std::unordered_set<uint32_t>* funcs);
|
||||
|
||||
// Returns true if all of the given analyses are valid.
|
||||
bool AreAnalysesValid(Analysis set) { return (set & valid_analyses_) == set; }
|
||||
|
||||
|
@ -68,38 +68,12 @@ bool HasVolatileDecoration(analysis::DecorationManager* decoration_manager,
|
||||
return decoration_manager->HasDecoration(var_id, SpvDecorationVolatile);
|
||||
}
|
||||
|
||||
bool HasOnlyEntryPointsAsFunctions(IRContext* context, Module* module) {
|
||||
std::unordered_set<uint32_t> entry_function_ids;
|
||||
for (Instruction& entry_point : module->entry_points()) {
|
||||
entry_function_ids.insert(
|
||||
entry_point.GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint));
|
||||
}
|
||||
for (auto& function : *module) {
|
||||
if (entry_function_ids.find(function.result_id()) ==
|
||||
entry_function_ids.end()) {
|
||||
std::string message(
|
||||
"Functions of SPIR-V for spread-volatile-semantics pass input must "
|
||||
"be inlined except entry points");
|
||||
message += "\n " + function.DefInst().PrettyPrint(
|
||||
SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
|
||||
context->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Pass::Status SpreadVolatileSemantics::Process() {
|
||||
if (HasNoExecutionModel()) {
|
||||
return Status::SuccessWithoutChange;
|
||||
}
|
||||
|
||||
if (!HasOnlyEntryPointsAsFunctions(context(), get_module())) {
|
||||
return Status::Failure;
|
||||
}
|
||||
|
||||
const bool is_vk_memory_model_enabled =
|
||||
context()->get_feature_mgr()->HasCapability(
|
||||
SpvCapabilityVulkanMemoryModel);
|
||||
@ -142,6 +116,8 @@ bool SpreadVolatileSemantics::IsTargetUsedByNonVolatileLoadInEntryPoint(
|
||||
uint32_t var_id, Instruction* entry_point) {
|
||||
uint32_t entry_function_id =
|
||||
entry_point->GetSingleWordInOperand(kOpEntryPointInOperandEntryPoint);
|
||||
std::unordered_set<uint32_t> funcs;
|
||||
context()->CollectCallTreeFromRoots(entry_function_id, &funcs);
|
||||
return !VisitLoadsOfPointersToVariableInEntries(
|
||||
var_id,
|
||||
[](Instruction* load) {
|
||||
@ -154,7 +130,7 @@ bool SpreadVolatileSemantics::IsTargetUsedByNonVolatileLoadInEntryPoint(
|
||||
load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands);
|
||||
return (memory_operands & SpvMemoryAccessVolatileMask) != 0;
|
||||
},
|
||||
{entry_function_id});
|
||||
funcs);
|
||||
}
|
||||
|
||||
bool SpreadVolatileSemantics::HasInterfaceInConflictOfVolatileSemantics() {
|
||||
@ -225,7 +201,7 @@ void SpreadVolatileSemantics::DecorateVarWithVolatile(Instruction* var) {
|
||||
|
||||
bool SpreadVolatileSemantics::VisitLoadsOfPointersToVariableInEntries(
|
||||
uint32_t var_id, const std::function<bool(Instruction*)>& handle_load,
|
||||
const std::unordered_set<uint32_t>& entry_function_ids) {
|
||||
const std::unordered_set<uint32_t>& function_ids) {
|
||||
std::vector<uint32_t> worklist({var_id});
|
||||
auto* def_use_mgr = context()->get_def_use_mgr();
|
||||
while (!worklist.empty()) {
|
||||
@ -233,11 +209,11 @@ bool SpreadVolatileSemantics::VisitLoadsOfPointersToVariableInEntries(
|
||||
worklist.pop_back();
|
||||
bool finish_traversal = !def_use_mgr->WhileEachUser(
|
||||
ptr_id, [this, &worklist, &ptr_id, handle_load,
|
||||
&entry_function_ids](Instruction* user) {
|
||||
&function_ids](Instruction* user) {
|
||||
BasicBlock* block = context()->get_instr_block(user);
|
||||
if (block == nullptr ||
|
||||
entry_function_ids.find(block->GetParent()->result_id()) ==
|
||||
entry_function_ids.end()) {
|
||||
function_ids.find(block->GetParent()->result_id()) ==
|
||||
function_ids.end()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -266,21 +242,25 @@ void SpreadVolatileSemantics::SetVolatileForLoadsInEntries(
|
||||
Instruction* var, const std::unordered_set<uint32_t>& entry_function_ids) {
|
||||
// Set Volatile memory operand for all load instructions if they do not have
|
||||
// it.
|
||||
VisitLoadsOfPointersToVariableInEntries(
|
||||
var->result_id(),
|
||||
[](Instruction* load) {
|
||||
if (load->NumInOperands() <= kOpLoadInOperandMemoryOperands) {
|
||||
load->AddOperand(
|
||||
{SPV_OPERAND_TYPE_MEMORY_ACCESS, {SpvMemoryAccessVolatileMask}});
|
||||
for (auto entry_id : entry_function_ids) {
|
||||
std::unordered_set<uint32_t> funcs;
|
||||
context()->CollectCallTreeFromRoots(entry_id, &funcs);
|
||||
VisitLoadsOfPointersToVariableInEntries(
|
||||
var->result_id(),
|
||||
[](Instruction* load) {
|
||||
if (load->NumInOperands() <= kOpLoadInOperandMemoryOperands) {
|
||||
load->AddOperand({SPV_OPERAND_TYPE_MEMORY_ACCESS,
|
||||
{SpvMemoryAccessVolatileMask}});
|
||||
return true;
|
||||
}
|
||||
uint32_t memory_operands =
|
||||
load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands);
|
||||
memory_operands |= SpvMemoryAccessVolatileMask;
|
||||
load->SetInOperand(kOpLoadInOperandMemoryOperands, {memory_operands});
|
||||
return true;
|
||||
}
|
||||
uint32_t memory_operands =
|
||||
load->GetSingleWordInOperand(kOpLoadInOperandMemoryOperands);
|
||||
memory_operands |= SpvMemoryAccessVolatileMask;
|
||||
load->SetInOperand(kOpLoadInOperandMemoryOperands, {memory_operands});
|
||||
return true;
|
||||
},
|
||||
entry_function_ids);
|
||||
},
|
||||
funcs);
|
||||
}
|
||||
}
|
||||
|
||||
bool SpreadVolatileSemantics::IsTargetForVolatileSemantics(
|
||||
|
@ -72,15 +72,14 @@ class SpreadVolatileSemantics : public Pass {
|
||||
Instruction* entry_point);
|
||||
|
||||
// Visits load instructions of pointers to variable whose result id is
|
||||
// |var_id| if the load instructions are in entry points whose
|
||||
// function id is one of |entry_function_ids|. |handle_load| is a function to
|
||||
// do some actions for the load instructions. Finishes the traversal and
|
||||
// returns false if |handle_load| returns false for a load instruction.
|
||||
// Otherwise, returns true after running |handle_load| for all the load
|
||||
// instructions.
|
||||
// |var_id| if the load instructions are in reachable functions from entry
|
||||
// points. |handle_load| is a function to do some actions for the load
|
||||
// instructions. Finishes the traversal and returns false if |handle_load|
|
||||
// returns false for a load instruction. Otherwise, returns true after running
|
||||
// |handle_load| for all the load instructions.
|
||||
bool VisitLoadsOfPointersToVariableInEntries(
|
||||
uint32_t var_id, const std::function<bool(Instruction*)>& handle_load,
|
||||
const std::unordered_set<uint32_t>& entry_function_ids);
|
||||
const std::unordered_set<uint32_t>& function_ids);
|
||||
|
||||
// Sets Memory Operands of OpLoad instructions that load |var| or pointers
|
||||
// of |var| as Volatile if the function id of the OpLoad instruction is
|
||||
|
@ -54,6 +54,7 @@ OpSource GLSL 460
|
||||
OpSourceExtension "GL_EXT_nonuniform_qualifier"
|
||||
OpSourceExtension "GL_KHR_ray_tracing"
|
||||
OpName %main "main"
|
||||
OpName %fn "fn"
|
||||
OpName %StorageBuffer "StorageBuffer"
|
||||
OpMemberName %StorageBuffer 0 "index"
|
||||
OpMemberName %StorageBuffer 1 "red"
|
||||
@ -109,6 +110,11 @@ OpDecorate %var BuiltIn )") + built_in + std::string(R"(
|
||||
%29 = OpCompositeExtract %float %27 0
|
||||
%31 = OpAccessChain %_ptr_Uniform_float %sbo %int_1
|
||||
OpStore %31 %29
|
||||
%32 = OpFunctionCall %void %fn
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
%fn = OpFunction %void None %3
|
||||
%33 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)");
|
||||
@ -782,12 +788,7 @@ OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
EXPECT_EQ(RunPass(text), Pass::Status::Failure);
|
||||
const char expected_error[] =
|
||||
"ERROR: 0: Functions of SPIR-V for spread-volatile-semantics pass "
|
||||
"input must be inlined except entry points";
|
||||
EXPECT_STREQ(GetErrorMessage().substr(0, sizeof(expected_error) - 1).c_str(),
|
||||
expected_error);
|
||||
EXPECT_EQ(RunPass(text), Pass::Status::SuccessWithoutChange);
|
||||
}
|
||||
|
||||
TEST_F(VolatileSpreadErrorTest, VarNotUsedInEntryPointForVolatile) {
|
||||
@ -1133,6 +1134,134 @@ OpFunctionEnd
|
||||
EXPECT_EQ(status, Pass::Status::SuccessWithoutChange);
|
||||
}
|
||||
|
||||
TEST_F(VolatileSpreadTest, NoInlinedfuncCalls) {
|
||||
const std::string text = R"(
|
||||
OpCapability RayTracingNV
|
||||
OpCapability VulkanMemoryModel
|
||||
OpCapability GroupNonUniform
|
||||
OpExtension "SPV_NV_ray_tracing"
|
||||
OpExtension "SPV_KHR_vulkan_memory_model"
|
||||
OpMemoryModel Logical Vulkan
|
||||
OpEntryPoint RayGenerationNV %main "main" %SubgroupSize
|
||||
OpSource HLSL 630
|
||||
OpName %main "main"
|
||||
OpName %src_main "src.main"
|
||||
OpName %bb_entry "bb.entry"
|
||||
OpName %func0 "func0"
|
||||
OpName %bb_entry_0 "bb.entry"
|
||||
OpName %func2 "func2"
|
||||
OpName %bb_entry_1 "bb.entry"
|
||||
OpName %param_var_count "param.var.count"
|
||||
OpName %func1 "func1"
|
||||
OpName %bb_entry_2 "bb.entry"
|
||||
OpName %func3 "func3"
|
||||
OpName %count "count"
|
||||
OpName %bb_entry_3 "bb.entry"
|
||||
OpDecorate %SubgroupSize BuiltIn SubgroupSize
|
||||
%uint = OpTypeInt 32 0
|
||||
%_ptr_Input_uint = OpTypePointer Input %uint
|
||||
%void = OpTypeVoid
|
||||
%6 = OpTypeFunction %void
|
||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||
%25 = OpTypeFunction %void %_ptr_Function_uint
|
||||
%SubgroupSize = OpVariable %_ptr_Input_uint Input
|
||||
%main = OpFunction %void None %6
|
||||
%7 = OpLabel
|
||||
%8 = OpFunctionCall %void %src_main
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
%src_main = OpFunction %void None %6
|
||||
%bb_entry = OpLabel
|
||||
%11 = OpFunctionCall %void %func0
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
%func0 = OpFunction %void DontInline %6
|
||||
%bb_entry_0 = OpLabel
|
||||
%14 = OpFunctionCall %void %func2
|
||||
%16 = OpFunctionCall %void %func1
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
%func2 = OpFunction %void DontInline %6
|
||||
%bb_entry_1 = OpLabel
|
||||
%param_var_count = OpVariable %_ptr_Function_uint Function
|
||||
; CHECK: {{%\w+}} = OpLoad %uint %SubgroupSize Volatile
|
||||
%21 = OpLoad %uint %SubgroupSize
|
||||
OpStore %param_var_count %21
|
||||
%22 = OpFunctionCall %void %func3 %param_var_count
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
%func1 = OpFunction %void DontInline %6
|
||||
%bb_entry_2 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
%func3 = OpFunction %void DontInline %25
|
||||
%count = OpFunctionParameter %_ptr_Function_uint
|
||||
%bb_entry_3 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
SinglePassRunAndMatch<SpreadVolatileSemantics>(text, true);
|
||||
}
|
||||
|
||||
TEST_F(VolatileSpreadErrorTest, NoInlinedMultiEntryfuncCalls) {
|
||||
const std::string text = R"(
|
||||
OpCapability RayTracingNV
|
||||
OpCapability SubgroupBallotKHR
|
||||
OpExtension "SPV_NV_ray_tracing"
|
||||
OpExtension "SPV_KHR_shader_ballot"
|
||||
%1 = OpExtInstImport "GLSL.std.450"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint RayGenerationNV %main "main" %SubgroupSize
|
||||
OpEntryPoint GLCompute %main2 "main2" %gl_LocalInvocationIndex %SubgroupSize
|
||||
OpSource HLSL 630
|
||||
OpName %main "main"
|
||||
OpName %bb_entry "bb.entry"
|
||||
OpName %main2 "main2"
|
||||
OpName %bb_entry_0 "bb.entry"
|
||||
OpName %func "func"
|
||||
OpName %count "count"
|
||||
OpName %bb_entry_1 "bb.entry"
|
||||
OpDecorate %gl_LocalInvocationIndex BuiltIn LocalInvocationIndex
|
||||
OpDecorate %SubgroupSize BuiltIn SubgroupSize
|
||||
%uint = OpTypeInt 32 0
|
||||
%uint_0 = OpConstant %uint 0
|
||||
%_ptr_Input_uint = OpTypePointer Input %uint
|
||||
%float = OpTypeFloat 32
|
||||
%v4float = OpTypeVector %float 4
|
||||
%void = OpTypeVoid
|
||||
%12 = OpTypeFunction %void
|
||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||
%_ptr_Function_v4float = OpTypePointer Function %v4float
|
||||
%29 = OpTypeFunction %void %_ptr_Function_v4float
|
||||
%34 = OpTypeFunction %void %_ptr_Function_uint
|
||||
%SubgroupSize = OpVariable %_ptr_Input_uint Input
|
||||
%gl_LocalInvocationIndex = OpVariable %_ptr_Input_uint Input
|
||||
%main = OpFunction %void None %12
|
||||
%bb_entry = OpLabel
|
||||
%20 = OpFunctionCall %void %func
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
%main2 = OpFunction %void None %12
|
||||
%bb_entry_0 = OpLabel
|
||||
%33 = OpFunctionCall %void %func
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
%func = OpFunction %void DontInline %12
|
||||
%bb_entry_1 = OpLabel
|
||||
%count = OpVariable %_ptr_Function_uint Function
|
||||
%35 = OpLoad %uint %SubgroupSize
|
||||
OpStore %count %35
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
EXPECT_EQ(RunPass(text), Pass::Status::Failure);
|
||||
const char expected_error[] =
|
||||
"ERROR: 0: Variable is a target for Volatile semantics for an entry "
|
||||
"point, but it is not for another entry point";
|
||||
EXPECT_STREQ(GetErrorMessage().substr(0, sizeof(expected_error) - 1).c_str(),
|
||||
expected_error);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace opt
|
||||
} // namespace spvtools
|
||||
|
Loading…
Reference in New Issue
Block a user