Split function opcode validation into new files.

* Moved function opcode validation out of idUsage and into new files
 * minor style changes
 * General opcode checking is in validate_function.cpp
 * Execution limitation checking is in
 validate_execution_limitations.cpp
* Execution limitations was split into a new pass as it requires other
validation to register those limitations first.
This commit is contained in:
Alan Baker 2018-08-10 09:53:17 -04:00
parent 967bfa2d17
commit e7fdcdba75
9 changed files with 264 additions and 151 deletions

View File

@ -51,6 +51,8 @@ SPVTOOLS_SRC_FILES := \
source/val/validate_decorations.cpp \
source/val/validate_derivatives.cpp \
source/val/validate_ext_inst.cpp \
source/val/validate_execution_limitations.cpp \
source/val/validate_function.cpp \
source/val/validate_id.cpp \
source/val/validate_image.cpp \
source/val/validate_interfaces.cpp \

View File

@ -373,7 +373,9 @@ static_library("spvtools_val") {
"source/val/validate_debug.cpp",
"source/val/validate_decorations.cpp",
"source/val/validate_derivatives.cpp",
"source/val/validate_execution_limitations.cpp",
"source/val/validate_ext_inst.cpp",
"source/val/validate_function.cpp",
"source/val/validate_id.cpp",
"source/val/validate_image.cpp",
"source/val/validate_instruction.cpp",

View File

@ -297,6 +297,8 @@ set(SPIRV_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_decorations.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_derivatives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_ext_inst.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_execution_limitations.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_function.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_id.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_image.cpp
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_interfaces.cpp

View File

@ -319,7 +319,7 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
// Constants
if (auto error = ValidateMemoryInstructions(*vstate, &instruction))
return error;
// Functions
if (auto error = FunctionPass(*vstate, &instruction)) return error;
if (auto error = ImagePass(*vstate, &instruction)) return error;
if (auto error = ConversionPass(*vstate, &instruction)) return error;
if (auto error = CompositesPass(*vstate, &instruction)) return error;
@ -352,6 +352,11 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
// TODO(dsinclair): Restructure ValidateBuiltins so we can move into the
// for() above as it loops over all ordered_instructions internally.
if (auto error = ValidateBuiltIns(*vstate)) return error;
// These checks must be performed after individual opcode checks because
// those checks register the limitation checked here.
for (const auto inst : vstate->ordered_instructions()) {
if (auto error = ValidateExecutionLimitations(*vstate, &inst)) return error;
}
// NOTE: Copy each instruction for easier processing
std::vector<spv_instruction_t> instructions;

View File

@ -194,6 +194,15 @@ spv_result_t PrimitivesPass(ValidationState_t& _, const Instruction* inst);
/// Validates correctness of mode setting instructions.
spv_result_t ModeSettingPass(ValidationState_t& _, const Instruction* inst);
/// Validates correctness of function instructions.
spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst);
/// Validates execution limitations.
///
/// Verifies execution models are allowed for all functionality they contain.
spv_result_t ValidateExecutionLimitations(ValidationState_t& _,
const Instruction* inst);
/// @brief Validate the ID usage of the instruction stream
///
/// @param[in] pInsts stream of instructions

View File

@ -0,0 +1,61 @@
// Copyright (c) 2018 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/val/validate.h"
#include "source/val/function.h"
#include "source/val/validation_state.h"
namespace spvtools {
namespace val {
spv_result_t ValidateExecutionLimitations(ValidationState_t& _,
const Instruction* inst) {
if (inst->opcode() != SpvOpFunction) {
return SPV_SUCCESS;
}
const auto func = _.function(inst->id());
if (!func) {
return _.diag(SPV_ERROR_INTERNAL, inst)
<< "Internal error: missing function id " << inst->id() << ".";
}
for (uint32_t entry_id : _.FunctionEntryPoints(inst->id())) {
const auto* models = _.GetExecutionModels(entry_id);
if (models) {
if (models->empty()) {
return _.diag(SPV_ERROR_INTERNAL, inst)
<< "Internal error: empty execution models for function id "
<< entry_id << ".";
}
for (const auto model : *models) {
std::string reason;
if (!func->IsCompatibleWithExecutionModel(model, &reason)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpEntryPoint Entry Point <id> '" << _.getIdName(entry_id)
<< "'s callgraph contains function <id> "
<< _.getIdName(inst->id())
<< ", which cannot be used with the current execution "
"model:\n"
<< reason;
}
}
}
}
return SPV_SUCCESS;
}
} // namespace val
} // namespace spvtools

View File

@ -0,0 +1,179 @@
// Copyright (c) 2018 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/val/validate.h"
#include "source/opcode.h"
#include "source/val/instruction.h"
#include "source/val/validation_state.h"
namespace spvtools {
namespace val {
namespace {
spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
const auto function_type_id = inst->GetOperandAs<uint32_t>(3);
const auto function_type = _.FindDef(function_type_id);
if (!function_type || SpvOpTypeFunction != function_type->opcode()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpFunction Function Type <id> '" << _.getIdName(function_type_id)
<< "' is not a function type.";
}
const auto return_id = function_type->GetOperandAs<uint32_t>(1);
if (return_id != inst->type_id()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpFunction Result Type <id> '" << _.getIdName(inst->type_id())
<< "' does not match the Function Type's return type <id> '"
<< _.getIdName(return_id) << "'.";
}
return SPV_SUCCESS;
}
spv_result_t ValidateFunctionParameter(ValidationState_t& _,
const Instruction* inst) {
// NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
size_t param_index = 0;
size_t inst_num = inst->LineNum() - 1;
if (inst_num == 0) {
return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
<< "Function parameter cannot be the first instruction.";
}
auto func_inst = &_.ordered_instructions()[inst_num];
while (--inst_num) {
func_inst = &_.ordered_instructions()[inst_num];
if (func_inst->opcode() == SpvOpFunction) {
break;
} else if (func_inst->opcode() == SpvOpFunctionParameter) {
++param_index;
}
}
if (func_inst->opcode() != SpvOpFunction) {
return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
<< "Function parameter must be preceded by a function.";
}
const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3);
const auto function_type = _.FindDef(function_type_id);
if (!function_type) {
return _.diag(SPV_ERROR_INVALID_ID, func_inst)
<< "Missing function type definition.";
}
if (param_index >= function_type->words().size() - 3) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Too many OpFunctionParameters for " << func_inst->id()
<< ": expected " << function_type->words().size() - 3
<< " based on the function's type";
}
const auto param_type =
_.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2));
if (!param_type || inst->type_id() != param_type->id()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpFunctionParameter Result Type <id> '"
<< _.getIdName(inst->type_id())
<< "' does not match the OpTypeFunction parameter "
"type of the same index.";
}
return SPV_SUCCESS;
}
spv_result_t ValidateFunctionCall(ValidationState_t& _,
const Instruction* inst) {
const auto function_id = inst->GetOperandAs<uint32_t>(2);
const auto function = _.FindDef(function_id);
if (!function || SpvOpFunction != function->opcode()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpFunctionCall Function <id> '" << _.getIdName(function_id)
<< "' is not a function.";
}
auto return_type = _.FindDef(function->type_id());
if (!return_type || return_type->id() != inst->type_id()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpFunctionCall Result Type <id> '"
<< _.getIdName(inst->type_id())
<< "'s type does not match Function <id> '"
<< _.getIdName(return_type->id()) << "'s return type.";
}
const auto function_type_id = function->GetOperandAs<uint32_t>(3);
const auto function_type = _.FindDef(function_type_id);
if (!function_type || function_type->opcode() != SpvOpTypeFunction) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Missing function type definition.";
}
const auto function_call_arg_count = inst->words().size() - 4;
const auto function_param_count = function_type->words().size() - 3;
if (function_param_count != function_call_arg_count) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpFunctionCall Function <id>'s parameter count does not match "
"the argument count.";
}
for (size_t argument_index = 3, param_index = 2;
argument_index < inst->operands().size();
argument_index++, param_index++) {
const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index);
const auto argument = _.FindDef(argument_id);
if (!argument) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Missing argument " << argument_index - 3 << " definition.";
}
const auto argument_type = _.FindDef(argument->type_id());
if (!argument_type) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Missing argument " << argument_index - 3
<< " type definition.";
}
const auto parameter_type_id =
function_type->GetOperandAs<uint32_t>(param_index);
const auto parameter_type = _.FindDef(parameter_type_id);
if (!parameter_type || argument_type->id() != parameter_type->id()) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpFunctionCall Argument <id> '" << _.getIdName(argument_id)
<< "'s type does not match Function <id> '"
<< _.getIdName(parameter_type_id) << "'s parameter type.";
}
}
return SPV_SUCCESS;
}
} // namespace
spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
switch (inst->opcode()) {
case SpvOpFunction:
if (auto error = ValidateFunction(_, inst)) return error;
break;
case SpvOpFunctionParameter:
if (auto error = ValidateFunctionParameter(_, inst)) return error;
break;
case SpvOpFunctionCall:
if (auto error = ValidateFunctionCall(_, inst)) return error;
break;
default:
break;
}
return SPV_SUCCESS;
}
} // namespace val
} // namespace spvtools

View File

@ -642,150 +642,6 @@ bool idUsage::isValid<SpvOpSpecConstantComposite>(const spv_instruction_t* inst,
return true;
}
template <>
bool idUsage::isValid<SpvOpFunction>(const spv_instruction_t* inst,
const spv_opcode_desc) {
const auto* thisInst = module_.FindDef(inst->words[2u]);
if (!thisInst) return false;
for (uint32_t entryId : module_.FunctionEntryPoints(thisInst->id())) {
const Function* thisFunc = module_.function(thisInst->id());
assert(thisFunc);
const auto* models = module_.GetExecutionModels(entryId);
if (models) {
assert(models->size());
for (auto model : *models) {
std::string reason;
if (!thisFunc->IsCompatibleWithExecutionModel(model, &reason)) {
DIAG(module_.FindDef(inst->words[2]))
<< "OpEntryPoint Entry Point <id> '" << module_.getIdName(entryId)
<< "'s callgraph contains function <id> "
<< module_.getIdName(thisInst->id())
<< ", which cannot be used with the current execution model:\n"
<< reason;
return false;
}
}
}
}
auto resultTypeIndex = 1;
auto resultType = module_.FindDef(inst->words[resultTypeIndex]);
if (!resultType) return false;
auto functionTypeIndex = 4;
auto functionType = module_.FindDef(inst->words[functionTypeIndex]);
if (!functionType || SpvOpTypeFunction != functionType->opcode()) {
DIAG(functionType) << "OpFunction Function Type <id> '"
<< module_.getIdName(inst->words[functionTypeIndex])
<< "' is not a function type.";
return false;
}
auto returnType = module_.FindDef(functionType->words()[2]);
assert(returnType);
if (returnType->id() != resultType->id()) {
DIAG(resultType) << "OpFunction Result Type <id> '"
<< module_.getIdName(inst->words[resultTypeIndex])
<< "' does not match the Function Type <id> '"
<< module_.getIdName(resultType->id())
<< "'s return type.";
return false;
}
return true;
}
template <>
bool idUsage::isValid<SpvOpFunctionParameter>(const spv_instruction_t* inst,
const spv_opcode_desc) {
auto resultTypeIndex = 1;
auto resultType = module_.FindDef(inst->words[resultTypeIndex]);
if (!resultType) return false;
// NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
size_t paramIndex = 0;
assert(firstInst < inst && "Invalid instruction pointer");
while (firstInst != --inst) {
if (SpvOpFunction == inst->opcode) {
break;
} else if (SpvOpFunctionParameter == inst->opcode) {
paramIndex++;
}
}
auto functionType = module_.FindDef(inst->words[4]);
assert(functionType);
if (paramIndex >= functionType->words().size() - 3) {
DIAG(module_.FindDef(inst->words[0]))
<< "Too many OpFunctionParameters for " << inst->words[2]
<< ": expected " << functionType->words().size() - 3
<< " based on the function's type";
return false;
}
auto paramType = module_.FindDef(functionType->words()[paramIndex + 3]);
assert(paramType);
if (resultType->id() != paramType->id()) {
DIAG(resultType) << "OpFunctionParameter Result Type <id> '"
<< module_.getIdName(inst->words[resultTypeIndex])
<< "' does not match the OpTypeFunction parameter "
"type of the same index.";
return false;
}
return true;
}
template <>
bool idUsage::isValid<SpvOpFunctionCall>(const spv_instruction_t* inst,
const spv_opcode_desc) {
auto resultTypeIndex = 1;
auto resultType = module_.FindDef(inst->words[resultTypeIndex]);
if (!resultType) return false;
auto functionIndex = 3;
auto function = module_.FindDef(inst->words[functionIndex]);
if (!function || SpvOpFunction != function->opcode()) {
DIAG(function) << "OpFunctionCall Function <id> '"
<< module_.getIdName(inst->words[functionIndex])
<< "' is not a function.";
return false;
}
auto returnType = module_.FindDef(function->type_id());
assert(returnType);
if (returnType->id() != resultType->id()) {
DIAG(resultType) << "OpFunctionCall Result Type <id> '"
<< module_.getIdName(inst->words[resultTypeIndex])
<< "'s type does not match Function <id> '"
<< module_.getIdName(returnType->id())
<< "'s return type.";
return false;
}
auto functionType = module_.FindDef(function->words()[4]);
assert(functionType);
auto functionCallArgCount = inst->words.size() - 4;
auto functionParamCount = functionType->words().size() - 3;
if (functionParamCount != functionCallArgCount) {
DIAG(module_.FindDef(inst->words.back()))
<< "OpFunctionCall Function <id>'s parameter count does not match "
"the argument count.";
return false;
}
for (size_t argumentIndex = 4, paramIndex = 3;
argumentIndex < inst->words.size(); argumentIndex++, paramIndex++) {
auto argument = module_.FindDef(inst->words[argumentIndex]);
if (!argument) return false;
auto argumentType = module_.FindDef(argument->type_id());
assert(argumentType);
auto parameterType = module_.FindDef(functionType->words()[paramIndex]);
assert(parameterType);
if (argumentType->id() != parameterType->id()) {
DIAG(argument) << "OpFunctionCall Argument <id> '"
<< module_.getIdName(inst->words[argumentIndex])
<< "'s type does not match Function <id> '"
<< module_.getIdName(parameterType->id())
<< "'s parameter type.";
return false;
}
}
return true;
}
#undef DIAG
bool idUsage::isValid(const spv_instruction_t* inst) {
spv_opcode_desc opcodeEntry = nullptr;
if (spvOpcodeTableValueLookup(targetEnv, opcodeTable, inst->opcode,
@ -804,9 +660,6 @@ bool idUsage::isValid(const spv_instruction_t* inst) {
CASE(OpSpecConstantFalse)
CASE(OpSpecConstantComposite)
CASE(OpSampledImage)
CASE(OpFunction)
CASE(OpFunctionParameter)
CASE(OpFunctionCall)
// Other composite opcodes are validated in validate_composites.cpp.
// Arithmetic opcodes are validated in validate_arithmetics.cpp.
// Bitwise opcodes are validated in validate_bitwise.cpp.

View File

@ -2652,7 +2652,7 @@ TEST_F(ValidateIdWithMessage, OpCopyMemoryVoidTarget) {
%2 = OpTypeInt 32 0
%3 = OpTypePointer Uniform %1
%4 = OpTypePointer Uniform %2
%5 = OpTypeFunction %1 %2 %3
%5 = OpTypeFunction %1 %3 %4
%6 = OpFunction %1 None %5
%7 = OpFunctionParameter %3
%8 = OpFunctionParameter %4
@ -2674,7 +2674,7 @@ TEST_F(ValidateIdWithMessage, OpCopyMemoryVoidSource) {
%2 = OpTypeInt 32 0
%3 = OpTypePointer Uniform %1
%4 = OpTypePointer Uniform %2
%5 = OpTypeFunction %1 %2 %3
%5 = OpTypeFunction %1 %3 %4
%6 = OpFunction %1 None %5
%7 = OpFunctionParameter %3
%8 = OpFunctionParameter %4
@ -3544,7 +3544,7 @@ TEST_F(ValidateIdWithMessage, OpFunctionResultTypeBad) {
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("OpFunction Result Type <id> '2' does not match the "
"Function Type <id> '2's return type."));
"Function Type's return type <id> '1'."));
}
TEST_F(ValidateIdWithMessage, OpReturnValueTypeBad) {
std::string spirv = kGLSL450MemoryModel + R"(