diff --git a/include/spirv-tools/libspirv.h b/include/spirv-tools/libspirv.h index 42ccb2a9..c9812ecc 100644 --- a/include/spirv-tools/libspirv.h +++ b/include/spirv-tools/libspirv.h @@ -340,14 +340,14 @@ typedef enum { SPV_ENV_UNIVERSAL_1_0, // SPIR-V 1.0 latest revision, no other restrictions. SPV_ENV_VULKAN_1_0, // Vulkan 1.0 latest revision. SPV_ENV_UNIVERSAL_1_1, // SPIR-V 1.1 latest revision, no other restrictions. - SPV_ENV_OPENCL_2_1, // OpenCL 2.1 latest revision. - SPV_ENV_OPENCL_2_2, // OpenCL 2.2 latest revision. - SPV_ENV_OPENGL_4_0, // OpenGL 4.0 plus GL_ARB_gl_spirv, latest revisions. - SPV_ENV_OPENGL_4_1, // OpenGL 4.1 plus GL_ARB_gl_spirv, latest revisions. - SPV_ENV_OPENGL_4_2, // OpenGL 4.2 plus GL_ARB_gl_spirv, latest revisions. - SPV_ENV_OPENGL_4_3, // OpenGL 4.3 plus GL_ARB_gl_spirv, latest revisions. + SPV_ENV_OPENCL_2_1, // OpenCL 2.1 latest revision. + SPV_ENV_OPENCL_2_2, // OpenCL 2.2 latest revision. + SPV_ENV_OPENGL_4_0, // OpenGL 4.0 plus GL_ARB_gl_spirv, latest revisions. + SPV_ENV_OPENGL_4_1, // OpenGL 4.1 plus GL_ARB_gl_spirv, latest revisions. + SPV_ENV_OPENGL_4_2, // OpenGL 4.2 plus GL_ARB_gl_spirv, latest revisions. + SPV_ENV_OPENGL_4_3, // OpenGL 4.3 plus GL_ARB_gl_spirv, latest revisions. // There is no variant for OpenGL 4.4. - SPV_ENV_OPENGL_4_5, // OpenGL 4.5 plus GL_ARB_gl_spirv, latest revisions. + SPV_ENV_OPENGL_4_5, // OpenGL 4.5 plus GL_ARB_gl_spirv, latest revisions. } spv_target_env; // Returns a string describing the given SPIR-V target environment. @@ -361,8 +361,9 @@ void spvContextDestroy(spv_context context); // Encodes the given SPIR-V assembly text to its binary representation. The // length parameter specifies the number of bytes for text. Encoded binary will -// be stored into *binary. Any error will be written into *diagnostic. The -// generated binary is independent of the context and may outlive it. +// be stored into *binary. Any error will be written into *diagnostic if +// diagnostic is non-null. The generated binary is independent of the context +// and may outlive it. spv_result_t spvTextToBinary(const spv_const_context context, const char* text, const size_t length, spv_binary* binary, spv_diagnostic* diagnostic); @@ -374,7 +375,8 @@ void spvTextDestroy(spv_text text); // Decodes the given SPIR-V binary representation to its assembly text. The // word_count parameter specifies the number of words for binary. The options // parameter is a bit field of spv_binary_to_text_options_t. Decoded text will -// be stored into *text. Any error will be written into *diagnostic. +// be stored into *text. Any error will be written into *diagnostic if +// diagnostic is non-null. spv_result_t spvBinaryToText(const spv_const_context context, const uint32_t* binary, const size_t word_count, const uint32_t options, spv_text* text, @@ -385,7 +387,7 @@ spv_result_t spvBinaryToText(const spv_const_context context, void spvBinaryDestroy(spv_binary binary); // Validates a SPIR-V binary for correctness. Any errors will be written into -// *diagnostic. +// *diagnostic if diagnostic is non-null. spv_result_t spvValidate(const spv_const_context context, const spv_const_binary binary, spv_diagnostic* diagnostic); diff --git a/source/binary.cpp b/source/binary.cpp index a709c023..c8b7c38c 100644 --- a/source/binary.cpp +++ b/source/binary.cpp @@ -61,6 +61,7 @@ class Parser { spv_parsed_header_fn_t parsed_header_fn, spv_parsed_instruction_fn_t parsed_instruction_fn) : grammar_(context), + consumer_(context->consumer), user_data_(user_data), parsed_header_fn_(parsed_header_fn), parsed_instruction_fn_(parsed_instruction_fn) {} @@ -120,8 +121,7 @@ class Parser { // returned object will be propagated to the current parse's diagnostic // object. libspirv::DiagnosticStream diagnostic(spv_result_t error) { - return libspirv::DiagnosticStream({0, 0, _.word_index}, _.diagnostic, - error); + return libspirv::DiagnosticStream({0, 0, _.word_index}, consumer_, error); } // Returns a diagnostic stream object with the default parse error code. @@ -156,6 +156,7 @@ class Parser { // Data members const libspirv::AssemblyGrammar grammar_; // SPIR-V syntax utility. + const spvtools::MessageConsumer& consumer_; // Message consumer callback. void* const user_data_; // Context for the callbacks const spv_parsed_header_fn_t parsed_header_fn_; // Parsed header callback const spv_parsed_instruction_fn_t @@ -752,7 +753,12 @@ spv_result_t spvBinaryParse(const spv_const_context context, void* user_data, spv_parsed_header_fn_t parsed_header, spv_parsed_instruction_fn_t parsed_instruction, spv_diagnostic* diagnostic) { - Parser parser(context, user_data, parsed_header, parsed_instruction); + spv_context_t hijack_context = *context; + if (diagnostic) { + *diagnostic = nullptr; + libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic); + } + Parser parser(&hijack_context, user_data, parsed_header, parsed_instruction); return parser.parse(code, num_words, diagnostic); } diff --git a/source/diagnostic.cpp b/source/diagnostic.cpp index b1a9cacf..8f0e3487 100644 --- a/source/diagnostic.cpp +++ b/source/diagnostic.cpp @@ -20,6 +20,7 @@ #include #include "spirv-tools/libspirv.h" +#include "table.h" // Diagnostic API @@ -68,12 +69,47 @@ spv_result_t spvDiagnosticPrint(const spv_diagnostic diagnostic) { namespace libspirv { DiagnosticStream::~DiagnosticStream() { - if (pDiagnostic_ && error_ != SPV_FAILED_MATCH) { - *pDiagnostic_ = spvDiagnosticCreate(&position_, stream_.str().c_str()); + using spvtools::MessageLevel; + if (error_ != SPV_FAILED_MATCH && consumer_ != nullptr) { + auto level = MessageLevel::Error; + switch (error_) { + case SPV_SUCCESS: + case SPV_REQUESTED_TERMINATION: // Essentially success. + level = MessageLevel::Info; + break; + case SPV_WARNING: + level = MessageLevel::Warning; + break; + case SPV_UNSUPPORTED: + case SPV_ERROR_INTERNAL: + case SPV_ERROR_INVALID_TABLE: + level = MessageLevel::InternalError; + break; + case SPV_ERROR_OUT_OF_MEMORY: + level = MessageLevel::Fatal; + break; + default: + break; + } + consumer_(level, "", position_, stream_.str().c_str()); } } -std::string -spvResultToString(spv_result_t res) { + +void UseDiagnosticAsMessageConsumer(spv_context context, + spv_diagnostic* diagnostic) { + assert(diagnostic && *diagnostic == nullptr); + + auto create_diagnostic = [diagnostic](spvtools::MessageLevel, const char*, + const spv_position_t& position, + const char* message) { + auto p = position; + spvDiagnosticDestroy(*diagnostic); // Avoid memory leak. + *diagnostic = spvDiagnosticCreate(&p, message); + }; + SetContextMessageConsumer(context, std::move(create_diagnostic)); +} + +std::string spvResultToString(spv_result_t res) { std::string out; switch (res) { case SPV_SUCCESS: diff --git a/source/diagnostic.h b/source/diagnostic.h index 840ec9ff..9bf9ae20 100644 --- a/source/diagnostic.h +++ b/source/diagnostic.h @@ -19,6 +19,7 @@ #include #include +#include "message.h" #include "spirv-tools/libspirv.h" namespace libspirv { @@ -29,20 +30,16 @@ namespace libspirv { // emitted during the destructor. class DiagnosticStream { public: - DiagnosticStream(spv_position_t position, spv_diagnostic* pDiagnostic, + DiagnosticStream(spv_position_t position, + const spvtools::MessageConsumer& consumer, spv_result_t error) - : position_(position), pDiagnostic_(pDiagnostic), error_(error) {} + : position_(position), consumer_(consumer), error_(error) {} DiagnosticStream(DiagnosticStream&& other) : stream_(other.stream_.str()), position_(other.position_), - pDiagnostic_(other.pDiagnostic_), - error_(other.error_) { - // The other object's destructor will emit the text in its stream_ - // member if its pDiagnostic_ member is non-null. Prevent that, - // since emitting that text is now the responsibility of *this. - other.pDiagnostic_ = nullptr; - } + consumer_(other.consumer_), + error_(other.error_) {} ~DiagnosticStream(); @@ -59,10 +56,18 @@ class DiagnosticStream { private: std::stringstream stream_; spv_position_t position_; - spv_diagnostic* pDiagnostic_; + const spvtools::MessageConsumer& consumer_; // Message consumer callback. spv_result_t error_; }; +// Changes the MessageConsumer in |context| to one that updates |diagnostic| +// with the last message received. +// +// This function expects that |diagnostic| is not nullptr and its content is a +// nullptr. +void UseDiagnosticAsMessageConsumer(spv_context context, + spv_diagnostic* diagnostic); + std::string spvResultToString(spv_result_t res); } // namespace libspirv diff --git a/source/disassemble.cpp b/source/disassemble.cpp index e71581da..267ed172 100644 --- a/source/disassemble.cpp +++ b/source/disassemble.cpp @@ -393,10 +393,13 @@ spv_result_t spvBinaryToText(const spv_const_context context, const uint32_t* code, const size_t wordCount, const uint32_t options, spv_text* pText, spv_diagnostic* pDiagnostic) { - // Invalid arguments return error codes, but don't necessarily generate - // diagnostics. These are programmer errors, not user errors. - if (!pDiagnostic) return SPV_ERROR_INVALID_DIAGNOSTIC; - const libspirv::AssemblyGrammar grammar(context); + spv_context_t hijack_context = *context; + if (pDiagnostic) { + *pDiagnostic = nullptr; + libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); + } + + const libspirv::AssemblyGrammar grammar(&hijack_context); if (!grammar.isValid()) return SPV_ERROR_INVALID_TABLE; // Generate friendly names for Ids if requested. @@ -404,15 +407,15 @@ spv_result_t spvBinaryToText(const spv_const_context context, libspirv::NameMapper name_mapper = libspirv::GetTrivialNameMapper(); if (options & SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES) { friendly_mapper.reset( - new libspirv::FriendlyNameMapper(context, code, wordCount)); + new libspirv::FriendlyNameMapper(&hijack_context, code, wordCount)); name_mapper = friendly_mapper->GetNameMapper(); } // Now disassemble! Disassembler disassembler(grammar, options, name_mapper); - if (auto error = spvBinaryParse(context, &disassembler, code, wordCount, - DisassembleHeader, DisassembleInstruction, - pDiagnostic)) { + if (auto error = spvBinaryParse(&hijack_context, &disassembler, code, + wordCount, DisassembleHeader, + DisassembleInstruction, pDiagnostic)) { return error; } diff --git a/source/message.h b/source/message.h new file mode 100644 index 00000000..68e4cd73 --- /dev/null +++ b/source/message.h @@ -0,0 +1,47 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SPIRV_TOOLS_MESSAGE_H_ +#define SPIRV_TOOLS_MESSAGE_H_ + +#include + +#include "spirv-tools/libspirv.h" + +namespace spvtools { + +// TODO(antiagainst): This eventually should be in the C++ interface. + +// Severity levels of messages communicated to the consumer. +enum class MessageLevel { + Fatal, // Unrecoverable error due to environment. Will abort the program + // immediately. E.g., out of memory. + InternalError, // Unrecoverable error due to SPIRV-Tools internals. Will + // abort the program immediately. E.g., unimplemented feature. + Error, // Normal error due to user input. + Warning, // Warning information. + Info, // General information. + Debug, // Debug information. +}; + +// Message consumer. The C strings for source and message are only alive for the +// specific invocation. +using MessageConsumer = std::function; + +} // namespace spvtools + +#endif // SPIRV_TOOLS_MESSAGE_H_ diff --git a/source/opt/libspirv.cpp b/source/opt/libspirv.cpp index 0636c645..eabc260f 100644 --- a/source/opt/libspirv.cpp +++ b/source/opt/libspirv.cpp @@ -16,6 +16,7 @@ #include "ir_loader.h" #include "make_unique.h" +#include "table.h" namespace spvtools { @@ -39,6 +40,10 @@ spv_result_t SetSpvInst(void* builder, const spv_parsed_instruction_t* inst) { } // annoymous namespace +void SpvTools::SetMessageConsumer(MessageConsumer consumer) { + SetContextMessageConsumer(context_, std::move(consumer)); +} + spv_result_t SpvTools::Assemble(const std::string& text, std::vector* binary) { spv_binary spvbinary = nullptr; diff --git a/source/opt/libspirv.hpp b/source/opt/libspirv.hpp index 46d73187..e645fcc6 100644 --- a/source/opt/libspirv.hpp +++ b/source/opt/libspirv.hpp @@ -19,6 +19,7 @@ #include #include +#include "message.h" #include "module.h" #include "spirv-tools/libspirv.h" @@ -36,7 +37,8 @@ class SpvTools { ~SpvTools() { spvContextDestroy(context_); } - // TODO(antiagainst): handle error message in the following APIs. + // Sets the message consumer to the given |consumer|. + void SetMessageConsumer(MessageConsumer consumer); // Assembles the given assembly |text| and writes the result to |binary|. // Returns SPV_SUCCESS on successful assembling. diff --git a/source/table.cpp b/source/table.cpp index 6bdbd9be..24ab5205 100644 --- a/source/table.cpp +++ b/source/table.cpp @@ -41,7 +41,13 @@ spv_context spvContextCreate(spv_target_env env) { spvOperandTableGet(&operand_table, env); spvExtInstTableGet(&ext_inst_table, env); - return new spv_context_t{env, opcode_table, operand_table, ext_inst_table}; + return new spv_context_t{env, opcode_table, operand_table, ext_inst_table, + nullptr /* a null default consumer */}; } void spvContextDestroy(spv_context context) { delete context; } + +void SetContextMessageConsumer(spv_context context, + spvtools::MessageConsumer consumer) { + context->consumer = std::move(consumer); +} diff --git a/source/table.h b/source/table.h index 8eedd0ae..abce443b 100644 --- a/source/table.h +++ b/source/table.h @@ -18,6 +18,7 @@ #include "spirv/1.1/spirv.h" #include "enum_set.h" +#include "message.h" #include "spirv-tools/libspirv.h" typedef struct spv_opcode_desc_t { @@ -87,8 +88,14 @@ struct spv_context_t { const spv_opcode_table opcode_table; const spv_operand_table operand_table; const spv_ext_inst_table ext_inst_table; + spvtools::MessageConsumer consumer; }; +// Sets the message consumer to |consumer| in the given |context|. The original +// message consumer will be overwritten. +void SetContextMessageConsumer(spv_context context, + spvtools::MessageConsumer consumer); + // Populates *table with entries for env. spv_result_t spvOpcodeTableGet(spv_opcode_table* table, spv_target_env env); diff --git a/source/text.cpp b/source/text.cpp index d317264b..6e68ac20 100644 --- a/source/text.cpp +++ b/source/text.cpp @@ -31,6 +31,7 @@ #include "diagnostic.h" #include "ext_inst.h" #include "instruction.h" +#include "message.h" #include "opcode.h" #include "operand.h" #include "spirv-tools/libspirv.h" @@ -662,10 +663,9 @@ spv_result_t SetHeader(spv_target_env env, const uint32_t bound, // If a diagnostic is generated, it is not yet marked as being // for a text-based input. spv_result_t spvTextToBinaryInternal(const libspirv::AssemblyGrammar& grammar, - const spv_text text, spv_binary* pBinary, - spv_diagnostic* pDiagnostic) { - if (!pDiagnostic) return SPV_ERROR_INVALID_DIAGNOSTIC; - libspirv::AssemblyContext context(text, pDiagnostic); + const spvtools::MessageConsumer& consumer, + const spv_text text, spv_binary* pBinary) { + libspirv::AssemblyContext context(text, consumer); if (!text->str) return context.diagnostic() << "Missing assembly text."; if (!grammar.isValid()) { @@ -673,9 +673,6 @@ spv_result_t spvTextToBinaryInternal(const libspirv::AssemblyGrammar& grammar, } if (!pBinary) return SPV_ERROR_INVALID_POINTER; - // NOTE: Ensure diagnostic is zero initialised - *pDiagnostic = {}; - std::vector instructions; // Skip past whitespace and comments. @@ -728,11 +725,17 @@ spv_result_t spvTextToBinary(const spv_const_context context, const char* input_text, const size_t input_text_size, spv_binary* pBinary, spv_diagnostic* pDiagnostic) { + spv_context_t hijack_context = *context; + if (pDiagnostic) { + *pDiagnostic = nullptr; + libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); + } + spv_text_t text = {input_text, input_text_size}; - libspirv::AssemblyGrammar grammar(context); + libspirv::AssemblyGrammar grammar(&hijack_context); spv_result_t result = - spvTextToBinaryInternal(grammar, &text, pBinary, pDiagnostic); + spvTextToBinaryInternal(grammar, hijack_context.consumer, &text, pBinary); if (pDiagnostic && *pDiagnostic) (*pDiagnostic)->isTextSource = true; return result; diff --git a/source/text_handler.h b/source/text_handler.h index 9951643b..1bd004c1 100644 --- a/source/text_handler.h +++ b/source/text_handler.h @@ -22,6 +22,7 @@ #include "diagnostic.h" #include "instruction.h" +#include "message.h" #include "spirv-tools/libspirv.h" #include "text.h" @@ -116,11 +117,8 @@ class ClampToZeroIfUnsignedType< // Encapsulates the data used during the assembly of a SPIR-V module. class AssemblyContext { public: - AssemblyContext(spv_text text, spv_diagnostic* diagnostic_arg) - : current_position_({}), - pDiagnostic_(diagnostic_arg), - text_(text), - bound_(1) {} + AssemblyContext(spv_text text, const spvtools::MessageConsumer& consumer) + : current_position_({}), consumer_(consumer), text_(text), bound_(1) {} // Assigns a new integer value to the given text ID, or returns the previously // assigned integer value if the ID has been seen before. @@ -148,7 +146,7 @@ class AssemblyContext { // stream, and for the given error code. Any data written to this object will // show up in pDiagnsotic on destruction. DiagnosticStream diagnostic(spv_result_t error) { - return DiagnosticStream(current_position_, pDiagnostic_, error); + return DiagnosticStream(current_position_, consumer_, error); } // Returns a diagnostic object with the default assembly error code. @@ -227,7 +225,6 @@ class AssemblyContext { spv_ext_inst_type_t getExtInstTypeForId(uint32_t id) const; private: - // Maps ID names to their corresponding numerical ids. using spv_named_id_table = std::unordered_map; // Maps type-defining IDs to their IdType. @@ -241,7 +238,7 @@ class AssemblyContext { // Maps an extended instruction import Id to the extended instruction type. std::unordered_map import_id_to_ext_inst_type_; spv_position_t current_position_; - spv_diagnostic* pDiagnostic_; + spvtools::MessageConsumer consumer_; spv_text text_; uint32_t bound_; }; diff --git a/source/val/ValidationState.cpp b/source/val/ValidationState.cpp index c8735dcd..13684160 100644 --- a/source/val/ValidationState.cpp +++ b/source/val/ValidationState.cpp @@ -182,9 +182,8 @@ bool IsInstructionInLayoutSection(ModuleLayoutSection layout, SpvOp op) { } // anonymous namespace -ValidationState_t::ValidationState_t(spv_diagnostic* diagnostic, - const spv_const_context context) - : diagnostic_(diagnostic), +ValidationState_t::ValidationState_t(const spv_const_context ctx) + : context_(ctx), instruction_counter_(0), unresolved_forward_ids_{}, operand_names_{}, @@ -193,7 +192,7 @@ ValidationState_t::ValidationState_t(spv_diagnostic* diagnostic, module_capabilities_(), ordered_instructions_(), all_definitions_(), - grammar_(context), + grammar_(ctx), addressing_model_(SpvAddressingModelLogical), memory_model_(SpvMemoryModelSimple), in_function_(false) {} @@ -290,7 +289,7 @@ bool ValidationState_t::IsOpcodeInCurrentLayoutSection(SpvOp op) { DiagnosticStream ValidationState_t::diag(spv_result_t error_code) const { return libspirv::DiagnosticStream( - {0, 0, static_cast(instruction_counter_)}, diagnostic_, + {0, 0, static_cast(instruction_counter_)}, context_->consumer, error_code); } @@ -377,8 +376,8 @@ spv_result_t ValidationState_t::RegisterFunctionEnd() { void ValidationState_t::RegisterInstruction( const spv_parsed_instruction_t& inst) { if (in_function_body()) { - ordered_instructions_.emplace_back( - &inst, ¤t_function(), current_function().current_block()); + ordered_instructions_.emplace_back(&inst, ¤t_function(), + current_function().current_block()); } else { ordered_instructions_.emplace_back(&inst, nullptr, nullptr); } diff --git a/source/val/ValidationState.h b/source/val/ValidationState.h index 1f5c0010..ff23d05e 100644 --- a/source/val/ValidationState.h +++ b/source/val/ValidationState.h @@ -53,8 +53,10 @@ enum ModuleLayoutSection { /// This class manages the state of the SPIR-V validation as it is being parsed. class ValidationState_t { public: - ValidationState_t(spv_diagnostic* diagnostic, - const spv_const_context context); + ValidationState_t(const spv_const_context context); + + /// Returns the context + spv_const_context context() const { return context_; } /// Forward declares the id in the module spv_result_t ForwardDeclareId(uint32_t id); @@ -174,7 +176,8 @@ class ValidationState_t { private: ValidationState_t(const ValidationState_t&); - spv_diagnostic* diagnostic_; + const spv_const_context context_; + /// Tracks the number of instructions evaluated by the validator int instruction_counter_; @@ -191,7 +194,8 @@ class ValidationState_t { std::deque module_functions_; /// The capabilities available in the module - libspirv::CapabilitySet module_capabilities_; /// Module's declared capabilities. + libspirv::CapabilitySet + module_capabilities_; /// Module's declared capabilities. /// List of all instructions in the order they appear in the binary std::deque ordered_instructions_; diff --git a/source/validate.cpp b/source/validate.cpp index 1b50d9ac..b0699cca 100644 --- a/source/validate.cpp +++ b/source/validate.cpp @@ -50,15 +50,17 @@ using libspirv::ModuleLayoutPass; using libspirv::IdPass; using libspirv::ValidationState_t; -spv_result_t spvValidateIDs( - const spv_instruction_t* pInsts, const uint64_t count, - const spv_opcode_table opcodeTable, const spv_operand_table operandTable, - const spv_ext_inst_table extInstTable, const ValidationState_t& state, - spv_position position, spv_diagnostic* pDiagnostic) { +spv_result_t spvValidateIDs(const spv_instruction_t* pInsts, + const uint64_t count, + const spv_opcode_table opcodeTable, + const spv_operand_table operandTable, + const spv_ext_inst_table extInstTable, + const ValidationState_t& state, + spv_position position) { position->index = SPV_INDEX_INSTRUCTION; if (auto error = spvValidateInstructionIDs(pInsts, count, opcodeTable, operandTable, - extInstTable, state, position, pDiagnostic)) + extInstTable, state, position)) return error; return SPV_SUCCESS; } @@ -175,29 +177,33 @@ UNUSED(void PrintDotGraph(ValidationState_t& _, libspirv::Function func)) { spv_result_t spvValidate(const spv_const_context context, const spv_const_binary binary, spv_diagnostic* pDiagnostic) { - if (!pDiagnostic) return SPV_ERROR_INVALID_DIAGNOSTIC; + spv_context_t hijack_context = *context; + if (pDiagnostic) { + *pDiagnostic = nullptr; + libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); + } spv_endianness_t endian; spv_position_t position = {}; if (spvBinaryEndianness(binary, &endian)) { - return libspirv::DiagnosticStream(position, pDiagnostic, + return libspirv::DiagnosticStream(position, hijack_context.consumer, SPV_ERROR_INVALID_BINARY) << "Invalid SPIR-V magic number."; } spv_header_t header; if (spvBinaryHeaderGet(binary, endian, &header)) { - return libspirv::DiagnosticStream(position, pDiagnostic, + return libspirv::DiagnosticStream(position, hijack_context.consumer, SPV_ERROR_INVALID_BINARY) << "Invalid SPIR-V header."; } // NOTE: Parse the module and perform inline validation checks. These // checks do not require the the knowledge of the whole module. - ValidationState_t vstate(pDiagnostic, context); - if (auto error = - spvBinaryParse(context, &vstate, binary->code, binary->wordCount, - setHeader, ProcessInstruction, pDiagnostic)) + ValidationState_t vstate(&hijack_context); + if (auto error = spvBinaryParse(&hijack_context, &vstate, binary->code, + binary->wordCount, setHeader, + ProcessInstruction, pDiagnostic)) return error; if (vstate.in_function_body()) @@ -243,7 +249,7 @@ spv_result_t spvValidate(const spv_const_context context, position.index = SPV_INDEX_INSTRUCTION; return spvValidateIDs(instructions.data(), instructions.size(), - context->opcode_table, context->operand_table, - context->ext_inst_table, vstate, &position, - pDiagnostic); + hijack_context.opcode_table, + hijack_context.operand_table, + hijack_context.ext_inst_table, vstate, &position); } diff --git a/source/validate.h b/source/validate.h index 2ef341a4..a19e911a 100644 --- a/source/validate.h +++ b/source/validate.h @@ -20,6 +20,7 @@ #include #include "instruction.h" +#include "message.h" #include "spirv-tools/libspirv.h" #include "table.h" @@ -154,7 +155,6 @@ spv_result_t InstructionPass(ValidationState_t& _, /// @param[in] operandTable table of specified operands /// @param[in] usedefs use-def info from module parsing /// @param[in,out] position current position in the stream -/// @param[out] pDiag contains diagnostic on failure /// /// @return result code spv_result_t spvValidateInstructionIDs(const spv_instruction_t* pInsts, @@ -163,8 +163,7 @@ spv_result_t spvValidateInstructionIDs(const spv_instruction_t* pInsts, const spv_operand_table operandTable, const spv_ext_inst_table extInstTable, const libspirv::ValidationState_t& state, - spv_position position, - spv_diagnostic* pDiag); + spv_position position); /// @brief Validate the ID's within a SPIR-V binary /// @@ -174,7 +173,7 @@ spv_result_t spvValidateInstructionIDs(const spv_instruction_t* pInsts, /// @param[in] opcodeTable table of specified Opcodes /// @param[in] operandTable table of specified operands /// @param[in,out] position current word in the binary -/// @param[out] pDiagnostic contains diagnostic on failure +/// @param[in] consumer message consumer callback /// /// @return result code spv_result_t spvValidateIDs(const spv_instruction_t* pInstructions, @@ -182,6 +181,7 @@ spv_result_t spvValidateIDs(const spv_instruction_t* pInstructions, const spv_opcode_table opcodeTable, const spv_operand_table operandTable, const spv_ext_inst_table extInstTable, - spv_position position, spv_diagnostic* pDiagnostic); + spv_position position, + const spvtools::MessageConsumer& consumer); #endif // LIBSPIRV_VALIDATE_H_ diff --git a/source/validate_id.cpp b/source/validate_id.cpp index 1e105459..611d39fd 100644 --- a/source/validate_id.cpp +++ b/source/validate_id.cpp @@ -24,6 +24,7 @@ #include "diagnostic.h" #include "instruction.h" +#include "message.h" #include "opcode.h" #include "spirv-tools/libspirv.h" #include "val/Function.h" @@ -48,7 +49,7 @@ class idUsage { const SpvMemoryModel memoryModelArg, const SpvAddressingModel addressingModelArg, const ValidationState_t& module, const vector& entry_points, - spv_position positionArg, spv_diagnostic* pDiagnosticArg) + spv_position positionArg, const spvtools::MessageConsumer& consumer) : opcodeTable(opcodeTableArg), operandTable(operandTableArg), extInstTable(extInstTableArg), @@ -57,7 +58,7 @@ class idUsage { memoryModel(memoryModelArg), addressingModel(addressingModelArg), position(positionArg), - pDiagnostic(pDiagnosticArg), + consumer_(consumer), module_(module), entry_points_(entry_points) {} @@ -75,14 +76,14 @@ class idUsage { const SpvMemoryModel memoryModel; const SpvAddressingModel addressingModel; spv_position position; - spv_diagnostic* pDiagnostic; + const spvtools::MessageConsumer& consumer_; const ValidationState_t& module_; vector entry_points_; }; #define DIAG(INDEX) \ position->index += INDEX; \ - libspirv::DiagnosticStream helper(*position, pDiagnostic, \ + libspirv::DiagnosticStream helper(*position, consumer_, \ SPV_ERROR_INVALID_DIAGNOSTIC); \ helper @@ -2553,11 +2554,10 @@ spv_result_t spvValidateInstructionIDs(const spv_instruction_t* pInsts, const spv_operand_table operandTable, const spv_ext_inst_table extInstTable, const libspirv::ValidationState_t& state, - spv_position position, - spv_diagnostic* pDiag) { + spv_position position) { idUsage idUsage(opcodeTable, operandTable, extInstTable, pInsts, instCount, state.memory_model(), state.addressing_model(), state, - state.entry_points(), position, pDiag); + state.entry_points(), position, state.context()->consumer); for (uint64_t instIndex = 0; instIndex < instCount; ++instIndex) { if (!idUsage.isValid(&pInsts[instIndex])) return SPV_ERROR_INVALID_ID; position->index += pInsts[instIndex].words.size(); diff --git a/test/BinaryParse.cpp b/test/BinaryParse.cpp index 97c2104a..48e7edf7 100644 --- a/test/BinaryParse.cpp +++ b/test/BinaryParse.cpp @@ -19,6 +19,8 @@ #include "TestFixture.h" #include "UnitSPIRV.h" #include "gmock/gmock.h" +#include "source/message.h" +#include "source/table.h" #include "spirv/1.0/OpenCL.std.h" // Returns true if two spv_parsed_operand_t values are equal. @@ -258,6 +260,112 @@ TEST_F(BinaryParseTest, NullDiagnosticsIsOkForBadParse) { words.size(), invoke_header, invoke_instruction, nullptr)); } +// Make sure that we don't blow up when both the consumer and the diagnostic are +// null. +TEST_F(BinaryParseTest, NullConsumerNullDiagnosticsForBadParse) { + auto words = CompileSuccessfully(""); + + auto ctx = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + SetContextMessageConsumer(ctx, nullptr); + + words.push_back(0xffffffff); // Certainly invalid instruction header. + EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(_)).Times(0); // No instruction callback. + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + spvBinaryParse(ctx, &client_, words.data(), words.size(), + invoke_header, invoke_instruction, nullptr)); + + spvContextDestroy(ctx); +} + +TEST_F(BinaryParseTest, SpecifyConsumerNullDiagnosticsForGoodParse) { + const auto words = CompileSuccessfully(""); + + auto ctx = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + ctx, [&invocation](spvtools::MessageLevel, const char*, + const spv_position_t&, const char*) { ++invocation; }); + + EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(_)).Times(0); // No instruction callback. + EXPECT_EQ(SPV_SUCCESS, + spvBinaryParse(ctx, &client_, words.data(), words.size(), + invoke_header, invoke_instruction, nullptr)); + EXPECT_EQ(0, invocation); + + spvContextDestroy(ctx); +} + +TEST_F(BinaryParseTest, SpecifyConsumerNullDiagnosticsForBadParse) { + auto words = CompileSuccessfully(""); + + auto ctx = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + ctx, [&invocation](spvtools::MessageLevel level, const char* source, + const spv_position_t& position, const char* message) { + ++invocation; + EXPECT_EQ(spvtools::MessageLevel::Error, level); + EXPECT_STREQ("", source); + EXPECT_EQ(0u, position.line); + EXPECT_EQ(0u, position.column); + EXPECT_EQ(5u, position.index); + EXPECT_STREQ("Invalid opcode: 65535", message); + }); + + words.push_back(0xffffffff); // Certainly invalid instruction header. + EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(_)).Times(0); // No instruction callback. + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + spvBinaryParse(ctx, &client_, words.data(), words.size(), + invoke_header, invoke_instruction, nullptr)); + EXPECT_EQ(1, invocation); + + spvContextDestroy(ctx); +} + +TEST_F(BinaryParseTest, SpecifyConsumerSpecifyDiagnosticsForGoodParse) { + const auto words = CompileSuccessfully(""); + + auto ctx = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + ctx, [&invocation](spvtools::MessageLevel, const char*, + const spv_position_t&, const char*) { ++invocation; }); + + EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(_)).Times(0); // No instruction callback. + EXPECT_EQ(SPV_SUCCESS, + spvBinaryParse(ctx, &client_, words.data(), words.size(), + invoke_header, invoke_instruction, &diagnostic_)); + EXPECT_EQ(0, invocation); + EXPECT_EQ(nullptr, diagnostic_); + + spvContextDestroy(ctx); +} + +TEST_F(BinaryParseTest, SpecifyConsumerSpecifyDiagnosticsForBadParse) { + auto words = CompileSuccessfully(""); + + auto ctx = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + ctx, [&invocation](spvtools::MessageLevel, const char*, + const spv_position_t&, const char*) { ++invocation; }); + + words.push_back(0xffffffff); // Certainly invalid instruction header. + EXPECT_HEADER(1).WillOnce(Return(SPV_SUCCESS)); + EXPECT_CALL(client_, Instruction(_)).Times(0); // No instruction callback. + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + spvBinaryParse(ctx, &client_, words.data(), words.size(), + invoke_header, invoke_instruction, &diagnostic_)); + EXPECT_EQ(0, invocation); + EXPECT_STREQ("Invalid opcode: 65535", diagnostic_->error); + + spvContextDestroy(ctx); +} + TEST_F(BinaryParseTest, ModuleWithSingleInstructionHasValidHeaderAndInstructionCallback) { for (bool endian_swap : kSwapEndians) { diff --git a/test/BinaryToText.cpp b/test/BinaryToText.cpp index ec4e4112..1d28e9de 100644 --- a/test/BinaryToText.cpp +++ b/test/BinaryToText.cpp @@ -149,13 +149,6 @@ TEST_F(BinaryToText, InvalidMagicNumber) { spvDiagnosticDestroy(diagnostic); } -TEST_F(BinaryToText, InvalidDiagnostic) { - spv_text text; - ASSERT_EQ(SPV_ERROR_INVALID_DIAGNOSTIC, - spvBinaryToText(context, binary->code, binary->wordCount, - SPV_BINARY_TO_TEXT_OPTION_NONE, &text, nullptr)); -} - struct FailedDecodeCase { std::string source_text; std::vector appended_instruction; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b585a266..65fcfd3d 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -129,6 +129,11 @@ add_spvtools_unittest( SRCS diagnostic.cpp LIBS ${SPIRV_TOOLS}) +add_spvtools_unittest( + TARGET c_interface + SRCS c_interface.cpp + LIBS ${SPIRV_TOOLS}) + add_spvtools_unittest( TARGET cpp_interface SRCS cpp_interface.cpp diff --git a/test/TextLiteral.cpp b/test/TextLiteral.cpp index 4870abfa..c2f7704b 100644 --- a/test/TextLiteral.cpp +++ b/test/TextLiteral.cpp @@ -14,15 +14,15 @@ #include "UnitSPIRV.h" -#include "gmock/gmock.h" #include "TestFixture.h" +#include "gmock/gmock.h" +#include "message.h" #include using ::testing::Eq; namespace { - TEST(TextLiteral, GoodI32) { spv_literal_t l; @@ -119,7 +119,7 @@ INSTANTIATE_TEST_CASE_P( {"\"\xE4\xBA\xB2\"", "\xE4\xBA\xB2"}, {"\"\\\xE4\xBA\xB2\"", "\xE4\xBA\xB2"}, {"\"this \\\" and this \\\\ and \\\xE4\xBA\xB2\"", - "this \" and this \\ and \xE4\xBA\xB2"}}),); + "this \" and this \\ and \xE4\xBA\xB2"}}), ); TEST(TextLiteral, StringTooLong) { spv_literal_t l; @@ -168,31 +168,32 @@ using IntegerTest = std::vector successfulEncode(const TextLiteralCase& test, libspirv::IdTypeClass type) { spv_instruction_t inst; - spv_diagnostic diagnostic; + std::string message; + auto capture_message = [&message](spvtools::MessageLevel, const char*, + const spv_position_t&, + const char* m) { message = m; }; libspirv::IdType expected_type{test.bitwidth, test.is_signed, type}; EXPECT_EQ(SPV_SUCCESS, - libspirv::AssemblyContext(nullptr, &diagnostic) + libspirv::AssemblyContext(nullptr, capture_message) .binaryEncodeNumericLiteral(test.text, SPV_ERROR_INVALID_TEXT, expected_type, &inst)) - << diagnostic->error; + << message; return inst.words; } std::string failedEncode(const TextLiteralCase& test, libspirv::IdTypeClass type) { spv_instruction_t inst; - spv_diagnostic diagnostic; + std::string message; + auto capture_message = [&message](spvtools::MessageLevel, const char*, + const spv_position_t&, + const char* m) { message = m; }; libspirv::IdType expected_type{test.bitwidth, test.is_signed, type}; EXPECT_EQ(SPV_ERROR_INVALID_TEXT, - libspirv::AssemblyContext(nullptr, &diagnostic) + libspirv::AssemblyContext(nullptr, capture_message) .binaryEncodeNumericLiteral(test.text, SPV_ERROR_INVALID_TEXT, expected_type, &inst)); - std::string ret_val; - if (diagnostic) { - ret_val = diagnostic->error; - spvDiagnosticDestroy(diagnostic); - } - return ret_val; + return message; } TEST_P(IntegerTest, IntegerBounds) { diff --git a/test/TextToBinary.cpp b/test/TextToBinary.cpp index 0a1439ac..4db3ed23 100644 --- a/test/TextToBinary.cpp +++ b/test/TextToBinary.cpp @@ -126,14 +126,6 @@ TEST_F(TextToBinaryTest, InvalidPointer) { nullptr, &diagnostic)); } -TEST_F(TextToBinaryTest, InvalidDiagnostic) { - SetText( - "OpEntryPoint Kernel 0 \"\"\nOpExecutionMode 0 LocalSizeHint 1 1 1\n"); - ASSERT_EQ(SPV_ERROR_INVALID_DIAGNOSTIC, - spvTextToBinary(ScopedContext().context, text.str, text.length, - &binary, nullptr)); -} - TEST_F(TextToBinaryTest, InvalidPrefix) { EXPECT_EQ( "Expected or at the beginning of an instruction, " diff --git a/test/c_interface.cpp b/test/c_interface.cpp new file mode 100644 index 00000000..96fcb48e --- /dev/null +++ b/test/c_interface.cpp @@ -0,0 +1,278 @@ +// Copyright (c) 2016 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "message.h" +#include "spirv-tools/libspirv.h" +#include "table.h" + +namespace { + +using namespace spvtools; + +// The default consumer is a null std::function. +TEST(CInterface, DefaultConsumerNullDiagnosticForValidInput) { + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + const char input_text[] = + "OpCapability Shader\nOpMemoryModel Logical GLSL450"; + + spv_binary binary = nullptr; + EXPECT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text, + sizeof(input_text), &binary, nullptr)); + + { + // Sadly the compiler don't allow me to feed binary directly to + // spvValidate(). + spv_const_binary_t b{binary->code, binary->wordCount}; + EXPECT_EQ(SPV_SUCCESS, spvValidate(context, &b, nullptr)); + } + + spv_text text = nullptr; + EXPECT_EQ(SPV_SUCCESS, spvBinaryToText(context, binary->code, + binary->wordCount, 0, &text, nullptr)); + + spvTextDestroy(text); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +// The default consumer is a null std::function. +TEST(CInterface, DefaultConsumerNullDiagnosticForInvalidAssembling) { + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + const char input_text[] = "%1 = OpName"; + + spv_binary binary = nullptr; + EXPECT_EQ(SPV_ERROR_INVALID_TEXT, + spvTextToBinary(context, input_text, sizeof(input_text), &binary, + nullptr)); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +// The default consumer is a null std::function. +TEST(CInterface, DefaultConsumerNullDiagnosticForInvalidDiassembling) { + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + const char input_text[] = "OpNop"; + + spv_binary binary = nullptr; + ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text, + sizeof(input_text), &binary, nullptr)); + // Change OpNop to an invalid (wordcount|opcode) word. + binary->code[binary->wordCount - 1] = 0xffffffff; + + spv_text text = nullptr; + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + spvBinaryToText(context, binary->code, binary->wordCount, 0, &text, + nullptr)); + + spvTextDestroy(text); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +// The default consumer is a null std::function. +TEST(CInterface, DefaultConsumerNullDiagnosticForInvalidValidating) { + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + const char input_text[] = "OpNop"; + + spv_binary binary = nullptr; + ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text, + sizeof(input_text), &binary, nullptr)); + + spv_const_binary_t b{binary->code, binary->wordCount}; + EXPECT_EQ(SPV_ERROR_INVALID_LAYOUT, spvValidate(context, &b, nullptr)); + + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +TEST(CInterface, SpecifyConsumerNullDiagnosticForAssembling) { + const char input_text[] = "%1 = OpName\n"; + + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + // TODO(antiagainst): Use public C API for setting the consumer once exists. + SetContextMessageConsumer( + context, + [&invocation](MessageLevel level, const char* source, + const spv_position_t& position, const char* message) { + ++invocation; + EXPECT_EQ(MessageLevel::Error, level); + // The error happens at scanning the begining of second line. + EXPECT_STREQ("", source); + EXPECT_EQ(1u, position.line); + EXPECT_EQ(0u, position.column); + EXPECT_EQ(12u, position.index); + EXPECT_STREQ("Expected operand, found end of stream.", message); + }); + + spv_binary binary = nullptr; + EXPECT_EQ(SPV_ERROR_INVALID_TEXT, + spvTextToBinary(context, input_text, sizeof(input_text), &binary, + nullptr)); + EXPECT_EQ(1, invocation); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +TEST(CInterface, SpecifyConsumerNullDiagnosticForDisassembling) { + const char input_text[] = "OpNop"; + + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + context, + [&invocation](MessageLevel level, const char* source, + const spv_position_t& position, const char* message) { + ++invocation; + EXPECT_EQ(MessageLevel::Error, level); + EXPECT_STREQ("", source); + EXPECT_EQ(0u, position.line); + EXPECT_EQ(0u, position.column); + EXPECT_EQ(5u, position.index); + EXPECT_STREQ("Invalid opcode: 65535", message); + }); + + spv_binary binary = nullptr; + ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text, + sizeof(input_text), &binary, nullptr)); + // Change OpNop to an invalid (wordcount|opcode) word. + binary->code[binary->wordCount - 1] = 0xffffffff; + + spv_text text = nullptr; + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + spvBinaryToText(context, binary->code, binary->wordCount, 0, &text, + nullptr)); + EXPECT_EQ(1, invocation); + + spvTextDestroy(text); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +TEST(CInterface, SpecifyConsumerNullDiagnosticForValidating) { + const char input_text[] = "OpNop"; + + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + context, + [&invocation](MessageLevel level, const char* source, + const spv_position_t& position, const char* message) { + ++invocation; + EXPECT_EQ(MessageLevel::Error, level); + EXPECT_STREQ("", source); + EXPECT_EQ(0u, position.line); + EXPECT_EQ(0u, position.column); + // TODO(antiagainst): what validation reports is not a word offset here. + // It is inconsistent with diassembler. Should be fixed. + EXPECT_EQ(1u, position.index); + EXPECT_STREQ("Nop cannot appear before the memory model instruction", + message); + }); + + spv_binary binary = nullptr; + ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text, + sizeof(input_text), &binary, nullptr)); + + spv_const_binary_t b{binary->code, binary->wordCount}; + EXPECT_EQ(SPV_ERROR_INVALID_LAYOUT, spvValidate(context, &b, nullptr)); + EXPECT_EQ(1, invocation); + + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +// When having both a consumer and an diagnostic object, the diagnostic object +// should take priority. +TEST(CInterface, SpecifyConsumerSpecifyDiagnosticForAssembling) { + const char input_text[] = "%1 = OpName"; + + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + context, [&invocation](MessageLevel, const char*, const spv_position_t&, + const char*) { ++invocation; }); + + spv_binary binary = nullptr; + spv_diagnostic diagnostic = nullptr; + EXPECT_EQ(SPV_ERROR_INVALID_TEXT, + spvTextToBinary(context, input_text, sizeof(input_text), &binary, + &diagnostic)); + EXPECT_EQ(0, invocation); // Consumer should not be invoked at all. + EXPECT_STREQ("Expected operand, found end of stream.", diagnostic->error); + + spvDiagnosticDestroy(diagnostic); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +TEST(CInterface, SpecifyConsumerSpecifyDiagnosticForDisassembling) { + const char input_text[] = "OpNop"; + + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + context, [&invocation](MessageLevel, const char*, const spv_position_t&, + const char*) { ++invocation; }); + + spv_binary binary = nullptr; + ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text, + sizeof(input_text), &binary, nullptr)); + // Change OpNop to an invalid (wordcount|opcode) word. + binary->code[binary->wordCount - 1] = 0xffffffff; + + spv_diagnostic diagnostic = nullptr; + spv_text text = nullptr; + EXPECT_EQ(SPV_ERROR_INVALID_BINARY, + spvBinaryToText(context, binary->code, binary->wordCount, 0, &text, + &diagnostic)); + + EXPECT_EQ(0, invocation); // Consumer should not be invoked at all. + EXPECT_STREQ("Invalid opcode: 65535", diagnostic->error); + + spvTextDestroy(text); + spvDiagnosticDestroy(diagnostic); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +TEST(CInterface, SpecifyConsumerSpecifyDiagnosticForValidating) { + const char input_text[] = "OpNop"; + + auto context = spvContextCreate(SPV_ENV_UNIVERSAL_1_1); + int invocation = 0; + SetContextMessageConsumer( + context, [&invocation](MessageLevel, const char*, const spv_position_t&, + const char*) { ++invocation; }); + + spv_binary binary = nullptr; + ASSERT_EQ(SPV_SUCCESS, spvTextToBinary(context, input_text, + sizeof(input_text), &binary, nullptr)); + + spv_diagnostic diagnostic = nullptr; + spv_const_binary_t b{binary->code, binary->wordCount}; + EXPECT_EQ(SPV_ERROR_INVALID_LAYOUT, spvValidate(context, &b, &diagnostic)); + + EXPECT_EQ(0, invocation); // Consumer should not be invoked at all. + EXPECT_STREQ("Nop cannot appear before the memory model instruction", + diagnostic->error); + + spvDiagnosticDestroy(diagnostic); + spvBinaryDestroy(binary); + spvContextDestroy(context); +} + +} // anonymous namespace diff --git a/test/diagnostic.cpp b/test/diagnostic.cpp index fa35502e..feb5403b 100644 --- a/test/diagnostic.cpp +++ b/test/diagnostic.cpp @@ -63,16 +63,16 @@ TEST(Diagnostic, PrintInvalidDiagnostic) { TEST(DiagnosticStream, ConversionToResultType) { // Check after the DiagnosticStream object is destroyed. spv_result_t value; - { value = DiagnosticStream({}, 0, SPV_ERROR_INVALID_TEXT); } + { value = DiagnosticStream({}, nullptr, SPV_ERROR_INVALID_TEXT); } EXPECT_EQ(SPV_ERROR_INVALID_TEXT, value); // Check implicit conversion via plain assignment. - value = DiagnosticStream({}, 0, SPV_SUCCESS); + value = DiagnosticStream({}, nullptr, SPV_SUCCESS); EXPECT_EQ(SPV_SUCCESS, value); // Check conversion via constructor. EXPECT_EQ(SPV_FAILED_MATCH, - spv_result_t(DiagnosticStream({}, 0, SPV_FAILED_MATCH))); + spv_result_t(DiagnosticStream({}, nullptr, SPV_FAILED_MATCH))); } } // anonymous namespace diff --git a/test/val/ValidationState.cpp b/test/val/ValidationState.cpp index e75f8cdf..23df7ccc 100644 --- a/test/val/ValidationState.cpp +++ b/test/val/ValidationState.cpp @@ -36,11 +36,9 @@ using std::vector; class ValidationStateTest : public testing::Test { public: ValidationStateTest() - : context_(spvContextCreate(SPV_ENV_UNIVERSAL_1_0)), - state_(&diag_, context_) {} + : context_(spvContextCreate(SPV_ENV_UNIVERSAL_1_0)), state_(context_) {} protected: - spv_diagnostic diag_; spv_context context_; ValidationState_t state_; };