Check for recursion in Vulkan and WebGPU entry points (#2161)

Fixes #2061
Fixes #2160
This commit is contained in:
Ryan Harrison 2018-12-05 13:58:43 -05:00 committed by GitHub
parent 2f5f5308b6
commit 378b7f3a29
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 219 additions and 6 deletions

View File

@ -175,14 +175,19 @@ spv_result_t ValidateForwardDecls(ValidationState_t& _) {
// capability is being used.
// * No function can be targeted by both an OpEntryPoint instruction and an
// OpFunctionCall instruction.
//
// Additionally enforces that entry points for Vulkan and WebGPU should not have
// recursion.
spv_result_t ValidateEntryPoints(ValidationState_t& _) {
_.ComputeFunctionToEntryPointMapping();
_.ComputeRecursiveEntryPoints();
if (_.entry_points().empty() && !_.HasCapability(SpvCapabilityLinkage)) {
return _.diag(SPV_ERROR_INVALID_BINARY, nullptr)
<< "No OpEntryPoint instruction was found. This is only allowed if "
"the Linkage capability is being used.";
}
for (const auto& entry_point : _.entry_points()) {
if (_.IsFunctionCallTarget(entry_point)) {
return _.diag(SPV_ERROR_INVALID_BINARY, _.FindDef(entry_point))
@ -190,6 +195,17 @@ spv_result_t ValidateEntryPoints(ValidationState_t& _) {
<< ") may not be targeted by both an OpEntryPoint instruction and "
"an OpFunctionCall instruction.";
}
// For Vulkan and WebGPU, the static function-call graph for an entry point
// must not contain cycles.
if (spvIsWebGPUEnv(_.context()->target_env) ||
spvIsVulkanEnv(_.context()->target_env)) {
if (_.recursive_entry_points().find(entry_point) !=
_.recursive_entry_points().end()) {
return _.diag(SPV_ERROR_INVALID_BINARY, _.FindDef(entry_point))
<< "Entry points may not have a call graph with cycles.";
}
}
}
return SPV_SUCCESS;

View File

@ -919,6 +919,39 @@ void ValidationState_t::ComputeFunctionToEntryPointMapping() {
}
}
void ValidationState_t::ComputeRecursiveEntryPoints() {
for (const Function func : functions()) {
std::stack<uint32_t> call_stack;
std::set<uint32_t> visited;
for (const uint32_t new_call : func.function_call_targets()) {
call_stack.push(new_call);
}
while (!call_stack.empty()) {
const uint32_t called_func_id = call_stack.top();
call_stack.pop();
if (!visited.insert(called_func_id).second) continue;
if (called_func_id == func.id()) {
for (const uint32_t entry_point :
function_to_entry_points_[called_func_id])
recursive_entry_points_.insert(entry_point);
break;
}
const Function* called_func = function(called_func_id);
if (called_func) {
// Other checks should error out on this invalid SPIR-V.
for (const uint32_t new_call : called_func->function_call_targets()) {
call_stack.push(new_call);
}
}
}
}
}
const std::vector<uint32_t>& ValidationState_t::FunctionEntryPoints(
uint32_t func) const {
auto iter = function_to_entry_points_.find(func);

View File

@ -222,6 +222,12 @@ class ValidationState_t {
/// Returns a list of entry point function ids
const std::vector<uint32_t>& entry_points() const { return entry_points_; }
/// Returns the set of entry points that root call graphs that contain
/// recursion.
const std::set<uint32_t>& recursive_entry_points() const {
return recursive_entry_points_;
}
/// Registers execution mode for the given entry point.
void RegisterExecutionModeForEntryPoint(uint32_t entry_point,
SpvExecutionMode execution_mode) {
@ -261,6 +267,11 @@ class ValidationState_t {
/// Note: called after fully parsing the binary.
void ComputeFunctionToEntryPointMapping();
/// Traverse call tree and computes recursive_entry_points_.
/// Note: called after fully parsing the binary and calling
/// ComputeFunctionToEntryPointMapping.
void ComputeRecursiveEntryPoints();
/// Returns all the entry points that can call |func|.
const std::vector<uint32_t>& FunctionEntryPoints(uint32_t func) const;
@ -610,6 +621,10 @@ class ValidationState_t {
std::unordered_map<uint32_t, std::vector<EntryPointDescription>>
entry_point_descriptions_;
/// IDs that are entry points, ie, arguments to OpEntryPoint, and root a call
/// graph that recurses.
std::set<uint32_t> recursive_entry_points_;
/// Functions IDs that are target of OpFunctionCall.
std::unordered_set<uint32_t> function_call_targets_;

View File

@ -29,11 +29,17 @@ using ::testing::HasSubstr;
using ValidationStateTest = spvtest::ValidateBase<bool>;
const char header[] =
const char kHeader[] =
" OpCapability Shader"
" OpCapability Linkage"
" OpMemoryModel Logical GLSL450 ";
const char kVulkanMemoryHeader[] =
" OpCapability Shader"
" OpCapability VulkanMemoryModelKHR"
" OpExtension \"SPV_KHR_vulkan_memory_model\""
" OpMemoryModel Logical VulkanKHR ";
const char kVoidFVoid[] =
" %void = OpTypeVoid"
" %void_f = OpTypeFunction %void"
@ -42,9 +48,79 @@ const char kVoidFVoid[] =
" OpReturn"
" OpFunctionEnd ";
// k*RecursiveBody examples originally from test/opt/function_test.cpp
const char* kNonRecursiveBody = R"(
OpEntryPoint Fragment %1 "main"
OpExecutionMode %1 OriginUpperLeft
%void = OpTypeVoid
%4 = OpTypeFunction %void
%float = OpTypeFloat 32
%_struct_6 = OpTypeStruct %float %float
%7 = OpTypeFunction %_struct_6
%1 = OpFunction %void Pure|Const %4
%8 = OpLabel
%2 = OpFunctionCall %_struct_6 %9
OpKill
OpFunctionEnd
%9 = OpFunction %_struct_6 None %7
%10 = OpLabel
%11 = OpFunctionCall %_struct_6 %12
OpUnreachable
OpFunctionEnd
%12 = OpFunction %_struct_6 None %7
%13 = OpLabel
OpUnreachable
OpFunctionEnd
)";
const char* kDirectlyRecursiveBody = R"(
OpEntryPoint Fragment %1 "main"
OpExecutionMode %1 OriginUpperLeft
%void = OpTypeVoid
%4 = OpTypeFunction %void
%float = OpTypeFloat 32
%_struct_6 = OpTypeStruct %float %float
%7 = OpTypeFunction %_struct_6
%1 = OpFunction %void Pure|Const %4
%8 = OpLabel
%2 = OpFunctionCall %_struct_6 %9
OpKill
OpFunctionEnd
%9 = OpFunction %_struct_6 None %7
%10 = OpLabel
%11 = OpFunctionCall %_struct_6 %9
OpUnreachable
OpFunctionEnd
)";
const char* kIndirectlyRecursiveBody = R"(
OpEntryPoint Fragment %1 "main"
OpExecutionMode %1 OriginUpperLeft
%void = OpTypeVoid
%4 = OpTypeFunction %void
%float = OpTypeFloat 32
%_struct_6 = OpTypeStruct %float %float
%7 = OpTypeFunction %_struct_6
%1 = OpFunction %void Pure|Const %4
%8 = OpLabel
%2 = OpFunctionCall %_struct_6 %9
OpKill
OpFunctionEnd
%9 = OpFunction %_struct_6 None %7
%10 = OpLabel
%11 = OpFunctionCall %_struct_6 %12
OpUnreachable
OpFunctionEnd
%12 = OpFunction %_struct_6 None %7
%13 = OpLabel
%14 = OpFunctionCall %_struct_6 %9
OpUnreachable
OpFunctionEnd
)";
// Tests that the instruction count in ValidationState is correct.
TEST_F(ValidationStateTest, CheckNumInstructions) {
std::string spirv = std::string(header) + "%int = OpTypeInt 32 0";
std::string spirv = std::string(kHeader) + "%int = OpTypeInt 32 0";
CompileSuccessfully(spirv);
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
EXPECT_EQ(size_t(4), vstate_->ordered_instructions().size());
@ -52,7 +128,7 @@ TEST_F(ValidationStateTest, CheckNumInstructions) {
// Tests that the number of global variables in ValidationState is correct.
TEST_F(ValidationStateTest, CheckNumGlobalVars) {
std::string spirv = std::string(header) + R"(
std::string spirv = std::string(kHeader) + R"(
%int = OpTypeInt 32 0
%_ptr_int = OpTypePointer Input %int
%var_1 = OpVariable %_ptr_int Input
@ -65,7 +141,7 @@ TEST_F(ValidationStateTest, CheckNumGlobalVars) {
// Tests that the number of local variables in ValidationState is correct.
TEST_F(ValidationStateTest, CheckNumLocalVars) {
std::string spirv = std::string(header) + R"(
std::string spirv = std::string(kHeader) + R"(
%int = OpTypeInt 32 0
%_ptr_int = OpTypePointer Function %int
%voidt = OpTypeVoid
@ -85,7 +161,7 @@ TEST_F(ValidationStateTest, CheckNumLocalVars) {
// Tests that the "id bound" in ValidationState is correct.
TEST_F(ValidationStateTest, CheckIdBound) {
std::string spirv = std::string(header) + R"(
std::string spirv = std::string(kHeader) + R"(
%int = OpTypeInt 32 0
%voidt = OpTypeVoid
)";
@ -96,7 +172,7 @@ TEST_F(ValidationStateTest, CheckIdBound) {
// Tests that the entry_points in ValidationState is correct.
TEST_F(ValidationStateTest, CheckEntryPoints) {
std::string spirv = std::string(header) +
std::string spirv = std::string(kHeader) +
" OpEntryPoint Vertex %func \"shader\"" +
std::string(kVoidFVoid);
CompileSuccessfully(spirv);
@ -154,6 +230,79 @@ TEST_F(ValidationStateTest, CheckAccessChainIndexesLimitOption) {
EXPECT_EQ(100u, options_->universal_limits_.max_access_chain_indexes);
}
TEST_F(ValidationStateTest, CheckNonRecursiveBodyGood) {
std::string spirv = std::string(kHeader) + kNonRecursiveBody;
CompileSuccessfully(spirv);
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
}
TEST_F(ValidationStateTest, CheckVulkanNonRecursiveBodyGood) {
std::string spirv = std::string(kVulkanMemoryHeader) + kNonRecursiveBody;
CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
EXPECT_EQ(SPV_SUCCESS,
ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1));
}
TEST_F(ValidationStateTest, CheckWebGPUNonRecursiveBodyGood) {
std::string spirv = std::string(kVulkanMemoryHeader) + kNonRecursiveBody;
CompileSuccessfully(spirv, SPV_ENV_WEBGPU_0);
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState(SPV_ENV_WEBGPU_0));
}
TEST_F(ValidationStateTest, CheckDirectlyRecursiveBodyGood) {
std::string spirv = std::string(kHeader) + kDirectlyRecursiveBody;
CompileSuccessfully(spirv);
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
}
TEST_F(ValidationStateTest, CheckVulkanDirectlyRecursiveBodyBad) {
std::string spirv = std::string(kVulkanMemoryHeader) + kDirectlyRecursiveBody;
CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
EXPECT_EQ(SPV_ERROR_INVALID_BINARY,
ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1));
EXPECT_THAT(getDiagnosticString(),
HasSubstr("Entry points may not have a call graph with cycles.\n "
" %1 = OpFunction %void Pure|Const %3\n"));
}
TEST_F(ValidationStateTest, CheckWebGPUDirectlyRecursiveBodyBad) {
std::string spirv = std::string(kVulkanMemoryHeader) + kDirectlyRecursiveBody;
CompileSuccessfully(spirv, SPV_ENV_WEBGPU_0);
EXPECT_EQ(SPV_ERROR_INVALID_BINARY,
ValidateAndRetrieveValidationState(SPV_ENV_WEBGPU_0));
EXPECT_THAT(getDiagnosticString(),
HasSubstr("Entry points may not have a call graph with cycles.\n "
" %1 = OpFunction %void Pure|Const %3\n"));
}
TEST_F(ValidationStateTest, CheckIndirectlyRecursiveBodyGood) {
std::string spirv = std::string(kHeader) + kIndirectlyRecursiveBody;
CompileSuccessfully(spirv);
EXPECT_EQ(SPV_SUCCESS, ValidateAndRetrieveValidationState());
}
TEST_F(ValidationStateTest, CheckVulkanIndirectlyRecursiveBodyBad) {
std::string spirv =
std::string(kVulkanMemoryHeader) + kIndirectlyRecursiveBody;
CompileSuccessfully(spirv, SPV_ENV_VULKAN_1_1);
EXPECT_EQ(SPV_ERROR_INVALID_BINARY,
ValidateAndRetrieveValidationState(SPV_ENV_VULKAN_1_1));
EXPECT_THAT(getDiagnosticString(),
HasSubstr("Entry points may not have a call graph with cycles.\n "
" %1 = OpFunction %void Pure|Const %3\n"));
}
TEST_F(ValidationStateTest, CheckWebGPUIndirectlyRecursiveBodyBad) {
std::string spirv =
std::string(kVulkanMemoryHeader) + kIndirectlyRecursiveBody;
CompileSuccessfully(spirv, SPV_ENV_WEBGPU_0);
EXPECT_EQ(SPV_ERROR_INVALID_BINARY,
ValidateAndRetrieveValidationState(SPV_ENV_WEBGPU_0));
EXPECT_THAT(getDiagnosticString(),
HasSubstr("Entry points may not have a call graph with cycles.\n "
" %1 = OpFunction %void Pure|Const %3\n"));
}
} // namespace
} // namespace val
} // namespace spvtools