mirror of
https://gitee.com/openharmony/third_party_spirv-tools
synced 2024-11-27 17:40:28 +00:00
Check for recursion in Vulkan and WebGPU entry points (#2161)
Fixes #2061 Fixes #2160
This commit is contained in:
parent
2f5f5308b6
commit
378b7f3a29
@ -175,14 +175,19 @@ spv_result_t ValidateForwardDecls(ValidationState_t& _) {
|
||||
// capability is being used.
|
||||
// * No function can be targeted by both an OpEntryPoint instruction and an
|
||||
// OpFunctionCall instruction.
|
||||
//
|
||||
// Additionally enforces that entry points for Vulkan and WebGPU should not have
|
||||
// recursion.
|
||||
spv_result_t ValidateEntryPoints(ValidationState_t& _) {
|
||||
_.ComputeFunctionToEntryPointMapping();
|
||||
_.ComputeRecursiveEntryPoints();
|
||||
|
||||
if (_.entry_points().empty() && !_.HasCapability(SpvCapabilityLinkage)) {
|
||||
return _.diag(SPV_ERROR_INVALID_BINARY, nullptr)
|
||||
<< "No OpEntryPoint instruction was found. This is only allowed if "
|
||||
"the Linkage capability is being used.";
|
||||
}
|
||||
|
||||
for (const auto& entry_point : _.entry_points()) {
|
||||
if (_.IsFunctionCallTarget(entry_point)) {
|
||||
return _.diag(SPV_ERROR_INVALID_BINARY, _.FindDef(entry_point))
|
||||
@ -190,6 +195,17 @@ spv_result_t ValidateEntryPoints(ValidationState_t& _) {
|
||||
<< ") may not be targeted by both an OpEntryPoint instruction and "
|
||||
"an OpFunctionCall instruction.";
|
||||
}
|
||||
|
||||
// For Vulkan and WebGPU, the static function-call graph for an entry point
|
||||
// must not contain cycles.
|
||||
if (spvIsWebGPUEnv(_.context()->target_env) ||
|
||||
spvIsVulkanEnv(_.context()->target_env)) {
|
||||
if (_.recursive_entry_points().find(entry_point) !=
|
||||
_.recursive_entry_points().end()) {
|
||||
return _.diag(SPV_ERROR_INVALID_BINARY, _.FindDef(entry_point))
|
||||
<< "Entry points may not have a call graph with cycles.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return SPV_SUCCESS;
|
||||
|
@ -919,6 +919,39 @@ void ValidationState_t::ComputeFunctionToEntryPointMapping() {
|
||||
}
|
||||
}
|
||||
|
||||
void ValidationState_t::ComputeRecursiveEntryPoints() {
|
||||
for (const Function func : functions()) {
|
||||
std::stack<uint32_t> call_stack;
|
||||
std::set<uint32_t> visited;
|
||||
|
||||
for (const uint32_t new_call : func.function_call_targets()) {
|
||||
call_stack.push(new_call);
|
||||
}
|
||||
|
||||
while (!call_stack.empty()) {
|
||||
const uint32_t called_func_id = call_stack.top();
|
||||
call_stack.pop();
|
||||
|
||||
if (!visited.insert(called_func_id).second) continue;
|
||||
|
||||
if (called_func_id == func.id()) {
|
||||
for (const uint32_t entry_point :
|
||||
function_to_entry_points_[called_func_id])
|
||||
recursive_entry_points_.insert(entry_point);
|
||||
break;
|
||||
}
|
||||
|
||||
const Function* called_func = function(called_func_id);
|
||||
if (called_func) {
|
||||
// Other checks should error out on this invalid SPIR-V.
|
||||
for (const uint32_t new_call : called_func->function_call_targets()) {
|
||||
call_stack.push(new_call);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const std::vector<uint32_t>& ValidationState_t::FunctionEntryPoints(
|
||||
uint32_t func) const {
|
||||
auto iter = function_to_entry_points_.find(func);
|
||||
|
@ -222,6 +222,12 @@ class ValidationState_t {
|
||||
/// Returns a list of entry point function ids
|
||||
const std::vector<uint32_t>& entry_points() const { return entry_points_; }
|
||||
|
||||
/// Returns the set of entry points that root call graphs that contain
|
||||
/// recursion.
|
||||
const std::set<uint32_t>& recursive_entry_points() const {
|
||||
return recursive_entry_points_;
|
||||
}
|
||||
|
||||
/// Registers execution mode for the given entry point.
|
||||
void RegisterExecutionModeForEntryPoint(uint32_t entry_point,
|
||||
SpvExecutionMode execution_mode) {
|
||||
@ -261,6 +267,11 @@ class ValidationState_t {
|
||||
/// Note: called after fully parsing the binary.
|
||||
void ComputeFunctionToEntryPointMapping();
|
||||
|
||||
/// Traverse call tree and computes recursive_entry_points_.
|
||||
/// Note: called after fully parsing the binary and calling
|
||||
/// ComputeFunctionToEntryPointMapping.
|
||||
void ComputeRecursiveEntryPoints();
|
||||
|
||||
/// Returns all the entry points that can call |func|.
|
||||
const std::vector<uint32_t>& FunctionEntryPoints(uint32_t func) const;
|
||||
|
||||
@ -610,6 +621,10 @@ class ValidationState_t {
|
||||
std::unordered_map<uint32_t, std::vector<EntryPointDescription>>
|
||||
entry_point_descriptions_;
|
||||
|
||||
/// IDs that are entry points, ie, arguments to OpEntryPoint, and root a call
|
||||
/// graph that recurses.
|
||||
std::set<uint32_t> recursive_entry_points_;
|
||||
|
||||
/// Functions IDs that are target of OpFunctionCall.
|
||||
std::unordered_set<uint32_t> function_call_targets_;
|
||||
|
||||
|
@ -29,11 +29,17 @@ using ::testing::HasSubstr;
|
||||
|
||||
using ValidationStateTest = spvtest::ValidateBase<bool>;
|
||||
|
||||
const char header[] =
|
||||
const char kHeader[] =
|
||||
" OpCapability Shader"
|
||||
" OpCapability Linkage"
|
||||
" OpMemoryModel Logical GLSL450 ";
|
||||
|
||||
const char kVulkanMemoryHeader[] =
|
||||
" OpCapability Shader"
|
||||
" OpCapability VulkanMemoryModelKHR"
|
||||
" OpExtension \"SPV_KHR_vulkan_memory_model\""
|
||||
" OpMemoryModel Logical VulkanKHR ";
|
||||
|
||||
const char kVoidFVoid[] =
|
||||
" %void = OpTypeVoid"
|
||||
" %void_f = OpTypeFunction %void"
|
||||
@ -42,9 +48,79 @@ const char kVoidFVoid[] =
|
||||
" OpReturn"
|
||||
" OpFunctionEnd ";
|
||||
|
||||
// k*RecursiveBody examples originally from test/opt/function_test.cpp
|
||||
const char* kNonRecursiveBody = R"(
|
||||
OpEntryPoint Fragment %1 "main"
|
||||
OpExecutionMode %1 OriginUpperLeft
|
||||
%void = OpTypeVoid
|
||||
%4 = OpTypeFunction %void
|
||||
%float = OpTypeFloat 32
|
||||
%_struct_6 = OpTypeStruct %float %float
|
||||
%7 = OpTypeFunction %_struct_6
|
||||
%1 = OpFunction %void Pure|Const %4
|
||||
%8 = OpLabel
|
||||
%2 = OpFunctionCall %_struct_6 %9
|
||||
OpKill
|
||||
OpFunctionEnd
|
||||
%9 = OpFunction %_struct_6 None %7
|
||||
%10 = OpLabel
|
||||
%11 = OpFunctionCall %_struct_6 %12
|
||||
OpUnreachable
|
||||
OpFunctionEnd
|
||||
%12 = OpFunction %_struct_6 None %7
|
||||
%13 = OpLabel
|
||||
OpUnreachable
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
const char* kDirectlyRecursiveBody = R"(
|
||||
OpEntryPoint Fragment %1 "main"
|
||||
OpExecutionMode %1 OriginUpperLeft
|
||||
%void = OpTypeVoid
|
||||
%4 = OpTypeFunction %void
|
||||
%float = OpTypeFloat 32
|
||||
%_struct_6 = OpTypeStruct %float %float
|
||||
%7 = OpTypeFunction %_struct_6
|
||||
%1 = OpFunction %void Pure|Const %4
|
||||
%8 = OpLabel
|
||||
%2 = OpFunctionCall %_struct_6 %9
|
||||
OpKill
|
||||
OpFunctionEnd
|
||||
%9 = OpFunction %_struct_6 None %7
|
||||
%10 = OpLabel
|
||||
%11 = OpFunctionCall %_struct_6 %9
|
||||
OpUnreachable
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
const char* kIndirectlyRecursiveBody = R"(
|
||||
OpEntryPoint Fragment %1 "main"
|
||||
OpExecutionMode %1 OriginUpperLeft
|
||||
%void = OpTypeVoid
|
||||
%4 = OpTypeFunction %void
|
||||
%float = OpTypeFloat 32
|
||||
%_struct_6 = OpTypeStruct %float %float
|
||||
%7 = OpTypeFunction %_struct_6
|
||||
%1 = OpFunction %void Pure|Const %4
|
||||
%8 = OpLabel
|
||||
%2 = OpFunctionCall %_struct_6 %9
|
||||
OpKill
|
||||
OpFunctionEnd
|
||||
%9 = OpFunction %_struct_6 None %7
|
||||
%10 = OpLabel
|
||||
%11 = OpFunctionCall %_struct_6 %12
|
||||
OpUnreachable
|
||||
OpFunctionEnd
|
||||
%12 = OpFunction %_struct_6 None %7
|
||||
%13 = OpLabel
|
||||
%14 = OpFunctionCall %_struct_6 %9
|
||||
OpUnreachable
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
// Tests that the instruction count in ValidationState is correct.
|
||||
TEST_F(ValidationStateTest, CheckNumInstructions) {
|
||||
std::string spirv = std::string(header) + "%int = OpTypeInt 32 0";
|
||||
std::string spirv = std::string(kHeader) + "%int = OpTypeInt 32 0";
|
||||
CompileSuccessfully(spirv);
|
||||
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
|
||||
EXPECT_EQ(size_t(4), vstate_->ordered_instructions().size());
|
||||
@ -52,7 +128,7 @@ TEST_F(ValidationStateTest, CheckNumInstructions) {
|
||||
|
||||
// Tests that the number of global variables in ValidationState is correct.
|
||||
TEST_F(ValidationStateTest, CheckNumGlobalVars) {
|
||||
std::string spirv = std::string(header) + R"(
|
||||
std::string spirv = std::string(kHeader) + R"(
|
||||
%int = OpTypeInt 32 0
|
||||
%_ptr_int = OpTypePointer Input %int
|
||||
%var_1 = OpVariable %_ptr_int Input
|
||||
@ -65,7 +141,7 @@ TEST_F(ValidationStateTest, CheckNumGlobalVars) {
|
||||
|
||||
// Tests that the number of local variables in ValidationState is correct.
|
||||
TEST_F(ValidationStateTest, CheckNumLocalVars) {
|
||||
std::string spirv = std::string(header) + R"(
|
||||
std::string spirv = std::string(kHeader) + R"(
|
||||
%int = OpTypeInt 32 0
|
||||
%_ptr_int = OpTypePointer Function %int
|
||||
%voidt = OpTypeVoid
|
||||
@ -85,7 +161,7 @@ TEST_F(ValidationStateTest, CheckNumLocalVars) {
|
||||
|
||||
// Tests that the "id bound" in ValidationState is correct.
|
||||
TEST_F(ValidationStateTest, CheckIdBound) {
|
||||
std::string spirv = std::string(header) + R"(
|
||||
std::string spirv = std::string(kHeader) + R"(
|
||||
%int = OpTypeInt 32 0
|
||||
%voidt = OpTypeVoid
|
||||
)";
|
||||
@ -96,7 +172,7 @@ TEST_F(ValidationStateTest, CheckIdBound) {
|
||||
|
||||
// Tests that the entry_points in ValidationState is correct.
|
||||
TEST_F(ValidationStateTest, CheckEntryPoints) {
|
||||
std::string spirv = std::string(header) +
|
||||
std::string spirv = std::string(kHeader) +
|
||||
" OpEntryPoint Vertex %func \"shader\"" +
|
||||
std::string(kVoidFVoid);
|
||||
CompileSuccessfully(spirv);
|
||||
@ -154,6 +230,79 @@ TEST_F(ValidationStateTest, CheckAccessChainIndexesLimitOption) {
|
||||
EXPECT_EQ(100u, options_->universal_limits_.max_access_chain_indexes);
|
||||
}
|
||||
|
||||
TEST_F(ValidationStateTest, CheckNonRecursiveBodyGood) {
|
||||
std::string spirv = std::string(kHeader) + kNonRecursiveBody;
|
||||
CompileSuccessfully(spirv);
|
||||
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
|
||||
}
|
||||
|
||||
TEST_F(ValidationStateTest, CheckVulkanNonRecursiveBodyGood) {
|
||||
std::string spirv = std::string(kVulkanMemoryHeader) + kNonRecursiveBody;
|
||||
CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
|
||||
EXPECT_EQ(SPV_SUCCESS,
|
||||
ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1));
|
||||
}
|
||||
|
||||
TEST_F(ValidationStateTest, CheckWebGPUNonRecursiveBodyGood) {
|
||||
std::string spirv = std::string(kVulkanMemoryHeader) + kNonRecursiveBody;
|
||||
CompileSuccessfully(spirv, SPV_ENV_WEBGPU_0);
|
||||
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState(SPV_ENV_WEBGPU_0));
|
||||
}
|
||||
|
||||
TEST_F(ValidationStateTest, CheckDirectlyRecursiveBodyGood) {
|
||||
std::string spirv = std::string(kHeader) + kDirectlyRecursiveBody;
|
||||
CompileSuccessfully(spirv);
|
||||
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
|
||||
}
|
||||
|
||||
TEST_F(ValidationStateTest, CheckVulkanDirectlyRecursiveBodyBad) {
|
||||
std::string spirv = std::string(kVulkanMemoryHeader) + kDirectlyRecursiveBody;
|
||||
CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
|
||||
EXPECT_EQ(SPV_ERROR_INVALID_BINARY,
|
||||
ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1));
|
||||
EXPECT_THAT(getDiagnosticString(),
|
||||
HasSubstr("Entry points may not have a call graph with cycles.\n "
|
||||
" %1 = OpFunction %void Pure|Const %3\n"));
|
||||
}
|
||||
|
||||
TEST_F(ValidationStateTest, CheckWebGPUDirectlyRecursiveBodyBad) {
|
||||
std::string spirv = std::string(kVulkanMemoryHeader) + kDirectlyRecursiveBody;
|
||||
CompileSuccessfully(spirv, SPV_ENV_WEBGPU_0);
|
||||
EXPECT_EQ(SPV_ERROR_INVALID_BINARY,
|
||||
ValidateAndRetrieveValidationState(SPV_ENV_WEBGPU_0));
|
||||
EXPECT_THAT(getDiagnosticString(),
|
||||
HasSubstr("Entry points may not have a call graph with cycles.\n "
|
||||
" %1 = OpFunction %void Pure|Const %3\n"));
|
||||
}
|
||||
|
||||
TEST_F(ValidationStateTest, CheckIndirectlyRecursiveBodyGood) {
|
||||
std::string spirv = std::string(kHeader) + kIndirectlyRecursiveBody;
|
||||
CompileSuccessfully(spirv);
|
||||
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
|
||||
}
|
||||
|
||||
TEST_F(ValidationStateTest, CheckVulkanIndirectlyRecursiveBodyBad) {
|
||||
std::string spirv =
|
||||
std::string(kVulkanMemoryHeader) + kIndirectlyRecursiveBody;
|
||||
CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
|
||||
EXPECT_EQ(SPV_ERROR_INVALID_BINARY,
|
||||
ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1));
|
||||
EXPECT_THAT(getDiagnosticString(),
|
||||
HasSubstr("Entry points may not have a call graph with cycles.\n "
|
||||
" %1 = OpFunction %void Pure|Const %3\n"));
|
||||
}
|
||||
|
||||
TEST_F(ValidationStateTest, CheckWebGPUIndirectlyRecursiveBodyBad) {
|
||||
std::string spirv =
|
||||
std::string(kVulkanMemoryHeader) + kIndirectlyRecursiveBody;
|
||||
CompileSuccessfully(spirv, SPV_ENV_WEBGPU_0);
|
||||
EXPECT_EQ(SPV_ERROR_INVALID_BINARY,
|
||||
ValidateAndRetrieveValidationState(SPV_ENV_WEBGPU_0));
|
||||
EXPECT_THAT(getDiagnosticString(),
|
||||
HasSubstr("Entry points may not have a call graph with cycles.\n "
|
||||
" %1 = OpFunction %void Pure|Const %3\n"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace val
|
||||
} // namespace spvtools
|
||||
|
Loading…
Reference in New Issue
Block a user