Linker: Better type comparison for OpTypeArray and OpTypeForwardPointer (#2580)

* Types: Avoid comparing IDs for in Type::IsSameImpl

When linking, we end up with duplicate types for imported and exported
types, that needs to be removed. The current code would reject valid
import/export pairs of symbols due to IDs mismatch, even if the types or
constants behind those ID were the same.

Enabled remaining type_match_test

Fixes #2442
This commit is contained in:
Pierre Moreau 2019-05-29 22:12:02 +02:00 committed by Steven Perron
parent 0125b28ed4
commit e7866de4b1
10 changed files with 211 additions and 111 deletions

View File

@ -33,6 +33,7 @@
#include "source/opt/ir_loader.h"
#include "source/opt/pass_manager.h"
#include "source/opt/remove_duplicates_pass.h"
#include "source/opt/type_manager.h"
#include "source/spirv_target_env.h"
#include "source/util/make_unique.h"
#include "spirv-tools/libspirv.hpp"
@ -40,14 +41,15 @@
namespace spvtools {
namespace {
using opt::IRContext;
using opt::Instruction;
using opt::IRContext;
using opt::Module;
using opt::Operand;
using opt::PassManager;
using opt::RemoveDuplicatesPass;
using opt::analysis::DecorationManager;
using opt::analysis::DefUseManager;
using opt::analysis::Type;
using opt::analysis::TypeManager;
// Stores various information about an imported or exported symbol.
struct LinkageSymbolInfo {
@ -472,14 +474,15 @@ spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer,
opt::IRContext* context) {
spv_position_t position = {};
// Ensure th import and export types are the same.
const DefUseManager& def_use_manager = *context->get_def_use_mgr();
// Ensure the import and export types are the same.
const DecorationManager& decoration_manager = *context->get_decoration_mgr();
const TypeManager& type_manager = *context->get_type_mgr();
for (const auto& linking_entry : linkings_to_do) {
if (!RemoveDuplicatesPass::AreTypesEqual(
*def_use_manager.GetDef(linking_entry.imported_symbol.type_id),
*def_use_manager.GetDef(linking_entry.exported_symbol.type_id),
context))
Type* imported_symbol_type =
type_manager.GetType(linking_entry.imported_symbol.type_id);
Type* exported_symbol_type =
type_manager.GetType(linking_entry.exported_symbol.type_id);
if (!(*imported_symbol_type == *exported_symbol_type))
return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
<< "Type mismatch on symbol \""
<< linking_entry.imported_symbol.name

View File

@ -96,35 +96,67 @@ bool RemoveDuplicatesPass::RemoveDuplicateTypes() const {
return modified;
}
analysis::TypeManager type_manager(context()->consumer(), context());
std::vector<Instruction*> visited_types;
std::vector<analysis::ForwardPointer> visited_forward_pointers;
std::vector<Instruction*> to_delete;
for (auto* i = &*context()->types_values_begin(); i; i = i->NextNode()) {
const bool is_i_forward_pointer = i->opcode() == SpvOpTypeForwardPointer;
// We only care about types.
if (!spvOpcodeGeneratesType((i->opcode())) &&
i->opcode() != SpvOpTypeForwardPointer) {
if (!spvOpcodeGeneratesType(i->opcode()) && !is_i_forward_pointer) {
continue;
}
// Is the current type equal to one of the types we have aready visited?
SpvId id_to_keep = 0u;
// TODO(dneto0): Use a trie to avoid quadratic behaviour? Extract the
// ResultIdTrie from unify_const_pass.cpp for this.
for (auto j : visited_types) {
if (AreTypesEqual(*i, *j, context())) {
id_to_keep = j->result_id();
break;
if (!is_i_forward_pointer) {
// Is the current type equal to one of the types we have already visited?
SpvId id_to_keep = 0u;
analysis::Type* i_type = type_manager.GetType(i->result_id());
assert(i_type);
// TODO(dneto0): Use a trie to avoid quadratic behaviour? Extract the
// ResultIdTrie from unify_const_pass.cpp for this.
for (auto j : visited_types) {
analysis::Type* j_type = type_manager.GetType(j->result_id());
assert(j_type);
if (*i_type == *j_type) {
id_to_keep = j->result_id();
break;
}
}
}
if (id_to_keep == 0u) {
// This is a never seen before type, keep it around.
visited_types.emplace_back(i);
if (id_to_keep == 0u) {
// This is a never seen before type, keep it around.
visited_types.emplace_back(i);
} else {
// The same type has already been seen before, remove this one.
context()->KillNamesAndDecorates(i->result_id());
context()->ReplaceAllUsesWith(i->result_id(), id_to_keep);
modified = true;
to_delete.emplace_back(i);
}
} else {
// The same type has already been seen before, remove this one.
context()->KillNamesAndDecorates(i->result_id());
context()->ReplaceAllUsesWith(i->result_id(), id_to_keep);
modified = true;
to_delete.emplace_back(i);
analysis::ForwardPointer i_type(
i->GetSingleWordInOperand(0u),
(SpvStorageClass)i->GetSingleWordInOperand(1u));
i_type.SetTargetPointer(
type_manager.GetType(i_type.target_id())->AsPointer());
// TODO(dneto0): Use a trie to avoid quadratic behaviour? Extract the
// ResultIdTrie from unify_const_pass.cpp for this.
const bool found_a_match =
std::find(std::begin(visited_forward_pointers),
std::end(visited_forward_pointers),
i_type) != std::end(visited_forward_pointers);
if (!found_a_match) {
// This is a never seen before type, keep it around.
visited_forward_pointers.emplace_back(i_type);
} else {
// The same type has already been seen before, remove this one.
modified = true;
to_delete.emplace_back(i);
}
}
}
@ -151,8 +183,8 @@ bool RemoveDuplicatesPass::RemoveDuplicateDecorations() const {
analysis::DecorationManager decoration_manager(context()->module());
for (auto* i = &*context()->annotation_begin(); i;) {
// Is the current decoration equal to one of the decorations we have aready
// visited?
// Is the current decoration equal to one of the decorations we have
// already visited?
bool already_visited = false;
// TODO(dneto0): Use a trie to avoid quadratic behaviour? Extract the
// ResultIdTrie from unify_const_pass.cpp for this.
@ -177,20 +209,5 @@ bool RemoveDuplicatesPass::RemoveDuplicateDecorations() const {
return modified;
}
bool RemoveDuplicatesPass::AreTypesEqual(const Instruction& inst1,
const Instruction& inst2,
IRContext* context) {
if (inst1.opcode() != inst2.opcode()) return false;
if (!IsTypeInst(inst1.opcode())) return false;
const analysis::Type* type1 =
context->get_type_mgr()->GetType(inst1.result_id());
const analysis::Type* type2 =
context->get_type_mgr()->GetType(inst2.result_id());
if (type1 && type2 && *type1 == *type2) return true;
return false;
}
} // namespace opt
} // namespace spvtools

View File

@ -36,12 +36,6 @@ class RemoveDuplicatesPass : public Pass {
const char* name() const override { return "remove-duplicates"; }
Status Process() override;
// TODO(pierremoreau): Move this function somewhere else (e.g. pass.h or
// within the type manager)
// Returns whether two types are equal, and have the same decorations.
static bool AreTypesEqual(const Instruction& inst1, const Instruction& inst2,
IRContext* context);
private:
// Remove duplicate capabilities from the module
//

View File

@ -66,7 +66,13 @@ uint32_t TypeManager::GetId(const Type* type) const {
}
void TypeManager::AnalyzeTypes(const Module& module) {
// First pass through the types. Any types that reference a forward pointer
// First pass through the constants, as some will be needed when traversing
// the types in the next pass.
for (const auto* inst : module.GetConstants()) {
id_to_constant_inst_[inst->result_id()] = inst;
}
// Then pass through the types. Any types that reference a forward pointer
// (directly or indirectly) are incomplete, and are added to incomplete types.
for (const auto* inst : module.GetTypes()) {
RecordIfTypeDefinition(*inst);
@ -154,7 +160,7 @@ void TypeManager::AnalyzeTypes(const Module& module) {
#ifndef NDEBUG
// Check if the type pool contains two types that are the same. This
// is an indication that the hashing and comparision are wrong. It
// is an indication that the hashing and comparison are wrong. It
// will cause a problem if the type pool gets resized and everything
// is rehashed.
for (auto& i : type_pool_) {
@ -505,8 +511,15 @@ Type* TypeManager::RebuildType(const Type& type) {
case Type::kArray: {
const Array* array_ty = type.AsArray();
const Type* ele_ty = array_ty->element_type();
rebuilt_ty =
MakeUnique<Array>(RebuildType(*ele_ty), array_ty->LengthId());
if (array_ty->length_spec_id() != 0u)
rebuilt_ty =
MakeUnique<Array>(RebuildType(*ele_ty), array_ty->LengthId(),
array_ty->length_spec_id());
else
rebuilt_ty =
MakeUnique<Array>(RebuildType(*ele_ty), array_ty->LengthId(),
array_ty->length_constant_type(),
array_ty->length_constant_words());
break;
}
case Type::kRuntimeArray: {
@ -636,15 +649,39 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
case SpvOpTypeSampledImage:
type = new SampledImage(GetType(inst.GetSingleWordInOperand(0)));
break;
case SpvOpTypeArray:
type = new Array(GetType(inst.GetSingleWordInOperand(0)),
inst.GetSingleWordInOperand(1));
case SpvOpTypeArray: {
const uint32_t length_id = inst.GetSingleWordInOperand(1);
const Instruction* length_constant_inst = id_to_constant_inst_[length_id];
assert(length_constant_inst);
// If it is a specialised constants, retrieve its SpecId.
uint32_t spec_id = 0u;
Type* length_type = nullptr;
Operand::OperandData length_words;
if (spvOpcodeIsSpecConstant(length_constant_inst->opcode())) {
context()->get_decoration_mgr()->ForEachDecoration(
length_id, SpvDecorationSpecId,
[&spec_id](const Instruction& decoration) {
assert(decoration.opcode() == SpvOpDecorate);
spec_id = decoration.GetSingleWordOperand(2u);
});
} else {
length_type = GetType(length_constant_inst->type_id());
length_words = length_constant_inst->GetOperand(2u).words;
}
if (spec_id != 0u)
type = new Array(GetType(inst.GetSingleWordInOperand(0)), length_id,
spec_id);
else
type = new Array(GetType(inst.GetSingleWordInOperand(0)), length_id,
length_type, length_words);
if (id_to_incomplete_type_.count(inst.GetSingleWordInOperand(0))) {
incomplete_types_.emplace_back(inst.result_id(), type);
id_to_incomplete_type_[inst.result_id()] = type;
return type;
}
break;
} break;
case SpvOpTypeRuntimeArray:
type = new RuntimeArray(GetType(inst.GetSingleWordInOperand(0)));
if (id_to_incomplete_type_.count(inst.GetSingleWordInOperand(0))) {

View File

@ -209,6 +209,8 @@ class TypeManager {
IdToTypeMap id_to_incomplete_type_; // Maps ids to their type representations
// for incomplete types.
std::unordered_map<uint32_t, const Instruction*> id_to_constant_inst_;
};
} // namespace analysis

View File

@ -383,17 +383,46 @@ void SampledImage::GetExtraHashWords(
image_type_->GetHashWords(words, seen);
}
Array::Array(Type* type, uint32_t length_id)
: Type(kArray), element_type_(type), length_id_(length_id) {
Array::Array(Type* type, uint32_t length_id, uint32_t spec_id)
: Type(kArray),
element_type_(type),
length_id_(length_id),
length_spec_id_(spec_id),
length_constant_type_(nullptr),
length_constant_words_() {
assert(!type->AsVoid());
assert(spec_id != 0u);
}
Array::Array(Type* type, uint32_t length_id, const Type* constant_type,
Operand::OperandData constant_words)
: Type(kArray),
element_type_(type),
length_id_(length_id),
length_spec_id_(0u),
length_constant_type_(constant_type),
length_constant_words_(constant_words) {
assert(!type->AsVoid());
assert(constant_type && constant_type->AsInteger());
}
bool Array::IsSameImpl(const Type* that, IsSameCache* seen) const {
const Array* at = that->AsArray();
if (!at) return false;
return length_id_ == at->length_id_ &&
element_type_->IsSameImpl(at->element_type_, seen) &&
HasSameDecorations(that);
bool is_same = element_type_->IsSameImpl(at->element_type_, seen) &&
HasSameDecorations(that);
// If it is a specialized constant
if (length_spec_id_ != 0u) {
// ensure they have the same SpecId
is_same = is_same && length_spec_id_ == at->length_spec_id_;
} else {
// else, ensure they have the same length literal number.
is_same =
is_same &&
length_constant_type_->IsSameImpl(at->length_constant_type_, seen) &&
length_constant_words_ == at->length_constant_words_;
}
return is_same;
}
std::string Array::str() const {
@ -405,7 +434,13 @@ std::string Array::str() const {
void Array::GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* seen) const {
element_type_->GetHashWords(words, seen);
words->push_back(length_id_);
if (length_spec_id_ != 0u) {
words->push_back(length_spec_id_);
} else {
length_constant_type_->GetHashWords(words, seen);
words->insert(words->end(), length_constant_words_.begin(),
length_constant_words_.end());
}
}
void Array::ReplaceElementType(const Type* type) { element_type_ = type; }
@ -609,7 +644,8 @@ void Pipe::GetExtraHashWords(std::vector<uint32_t>* words,
bool ForwardPointer::IsSameImpl(const Type* that, IsSameCache*) const {
const ForwardPointer* fpt = that->AsForwardPointer();
if (!fpt) return false;
return target_id_ == fpt->target_id_ &&
return (pointer_ && fpt->pointer_ ? *pointer_ == *fpt->pointer_
: target_id_ == fpt->target_id_) &&
storage_class_ == fpt->storage_class_ && HasSameDecorations(that);
}

View File

@ -27,6 +27,7 @@
#include <vector>
#include "source/latest_version_spirv_header.h"
#include "source/opt/instruction.h"
#include "spirv-tools/libspirv.h"
namespace spvtools {
@ -356,12 +357,19 @@ class SampledImage : public Type {
class Array : public Type {
public:
Array(Type* element_type, uint32_t length_id);
Array(Type* element_type, uint32_t length_id, uint32_t spec_id);
Array(Type* element_type, uint32_t length_id, const Type* constant_type,
Operand::OperandData constant_words);
Array(const Array&) = default;
std::string str() const override;
const Type* element_type() const { return element_type_; }
uint32_t LengthId() const { return length_id_; }
uint32_t length_spec_id() const { return length_spec_id_; }
const Type* length_constant_type() const { return length_constant_type_; }
Operand::OperandData length_constant_words() const {
return length_constant_words_;
}
Array* AsArray() override { return this; }
const Array* AsArray() const override { return this; }
@ -376,6 +384,9 @@ class Array : public Type {
const Type* element_type_;
uint32_t length_id_;
uint32_t length_spec_id_;
const Type* length_constant_type_;
Operand::OperandData length_constant_words_;
};
class RuntimeArray : public Type {

View File

@ -21,34 +21,39 @@ namespace {
using TypeMatch = spvtest::LinkerTest;
// Basic types
#define PartInt(N) N " = OpTypeInt 32 0"
#define PartFloat(N) N " = OpTypeFloat 32"
#define PartOpaque(N) N " = OpTypeOpaque \"bar\""
#define PartSampler(N) N " = OpTypeSampler"
#define PartEvent(N) N " = OpTypeEvent"
#define PartDeviceEvent(N) N " = OpTypeDeviceEvent"
#define PartReserveId(N) N " = OpTypeReserveId"
#define PartQueue(N) N " = OpTypeQueue"
#define PartPipe(N) N " = OpTypePipe ReadWrite"
#define PartPipeStorage(N) N " = OpTypePipeStorage"
#define PartNamedBarrier(N) N " = OpTypeNamedBarrier"
#define PartInt(D, N) D(N) " = OpTypeInt 32 0"
#define PartFloat(D, N) D(N) " = OpTypeFloat 32"
#define PartOpaque(D, N) D(N) " = OpTypeOpaque \"bar\""
#define PartSampler(D, N) D(N) " = OpTypeSampler"
#define PartEvent(D, N) D(N) " = OpTypeEvent"
#define PartDeviceEvent(D, N) D(N) " = OpTypeDeviceEvent"
#define PartReserveId(D, N) D(N) " = OpTypeReserveId"
#define PartQueue(D, N) D(N) " = OpTypeQueue"
#define PartPipe(D, N) D(N) " = OpTypePipe ReadWrite"
#define PartPipeStorage(D, N) D(N) " = OpTypePipeStorage"
#define PartNamedBarrier(D, N) D(N) " = OpTypeNamedBarrier"
// Compound types
#define PartVector(N, T) N " = OpTypeVector " T " 3"
#define PartMatrix(N, T) N " = OpTypeMatrix " T " 4"
#define PartImage(N, T) N " = OpTypeImage " T " 2D 0 0 0 0 Rgba32f"
#define PartSampledImage(N, T) N " = OpTypeSampledImage " T
#define PartArray(N, T) N " = OpTypeArray " T " %const"
#define PartRuntimeArray(N, T) N " = OpTypeRuntimeArray " T
#define PartStruct(N, T) N " = OpTypeStruct " T " " T
#define PartPointer(N, T) N " = OpTypePointer Workgroup " T
#define PartFunction(N, T) N " = OpTypeFunction " T " " T
#define PartVector(DR, DA, N, T) DR(N) " = OpTypeVector " DA(T) " 3"
#define PartMatrix(DR, DA, N, T) DR(N) " = OpTypeMatrix " DA(T) " 4"
#define PartImage(DR, DA, N, T) \
DR(N) " = OpTypeImage " DA(T) " 2D 0 0 0 0 Rgba32f"
#define PartSampledImage(DR, DA, N, T) DR(N) " = OpTypeSampledImage " DA(T)
#define PartArray(DR, DA, N, T) DR(N) " = OpTypeArray " DA(T) " " DA(const)
#define PartRuntimeArray(DR, DA, N, T) DR(N) " = OpTypeRuntimeArray " DA(T)
#define PartStruct(DR, DA, N, T) DR(N) " = OpTypeStruct " DA(T) " " DA(T)
#define PartPointer(DR, DA, N, T) DR(N) " = OpTypePointer Workgroup " DA(T)
#define PartFunction(DR, DA, N, T) DR(N) " = OpTypeFunction " DA(T) " " DA(T)
#define CheckDecoRes(S) "[[" #S ":%\\w+]]"
#define CheckDecoArg(S) "[[" #S "]]"
#define InstDeco(S) "%" #S
#define MatchPart1(F, N) \
"; CHECK: " Part##F("[[" #N ":%\\w+]]") "\n" Part##F("%" #N) "\n"
#define MatchPart2(F, N, T) \
"; CHECK: " Part##F("[[" #N ":%\\w+]]", "[[" #T ":%\\w+]]") "\n" Part##F( \
"%" #N, "%" #T) "\n"
"; CHECK: " Part##F(CheckDecoRes, N) "\n" Part##F(InstDeco, N) "\n"
#define MatchPart2(F, N, T) \
"; CHECK: " Part##F(CheckDecoRes, CheckDecoArg, N, T) "\n" Part##F( \
InstDeco, InstDeco, N, T) "\n"
#define MatchF(N, CODE) \
TEST_F(TypeMatch, N) { \
@ -98,47 +103,42 @@ Match3(Matrix, Vector, Float);
Match2(Image, Float);
// Unrestricted compound types
// The following skip Array as it causes issues
#define MatchCompounds1(A) \
Match2(RuntimeArray, A); \
Match2(Struct, A); \
Match2(Pointer, A); \
Match2(Function, A); \
// Match2(Array, A); // Disabled as it fails currently
Match2(Array, A);
#define MatchCompounds2(A, B) \
Match3(RuntimeArray, A, B); \
Match3(Struct, A, B); \
Match3(Pointer, A, B); \
Match3(Function, A, B); \
// Match3(Array, A, B); // Disabled as it fails currently
Match3(Array, A, B);
MatchCompounds1(Float);
// MatchCompounds2(Array, Float);
MatchCompounds2(Array, Float);
MatchCompounds2(RuntimeArray, Float);
MatchCompounds2(Struct, Float);
MatchCompounds2(Pointer, Float);
MatchCompounds2(Function, Float);
// ForwardPointer tests, which don't fit into the previous mold
#define MatchFpF(N, CODE) \
MatchF(N, \
"; CHECK: [[type:%\\w+]] = OpTypeForwardPointer [[pointer:%\\w+]] " \
"Workgroup\n" \
"%type = OpTypeForwardPointer %pointer Workgroup\n" CODE \
"; CHECK: [[pointer]] = OpTypePointer Workgroup [[realtype]]\n" \
"%pointer = OpTypePointer Workgroup %realtype\n")
#define MatchFpF(N, CODE) \
MatchF(N, \
"; CHECK: OpTypeForwardPointer [[type:%\\w+]] Workgroup\n" \
"OpTypeForwardPointer %type Workgroup\n" CODE \
"; CHECK: [[type]] = OpTypePointer Workgroup [[realtype]]\n" \
"%type = OpTypePointer Workgroup %realtype\n")
#define MatchFp1(T) MatchFpF(ForwardPointerOf##T, MatchPart1(T, realtype))
#define MatchFp2(T, A) \
MatchFpF(ForwardPointerOf##T, MatchPart1(A, a) MatchPart2(T, realtype, a))
// Disabled currently, causes assertion failures
/*
MatchFp1(Float);
MatchFp2(Array, Float);
MatchFp2(RuntimeArray, Float);
MatchFp2(Struct, Float);
MatchFp2(Function, Float);
// */
} // namespace
} // namespace spvtools

View File

@ -117,10 +117,10 @@ std::vector<std::unique_ptr<Type>> GenerateAllTypes() {
types.emplace_back(new SampledImage(image2));
// Array
types.emplace_back(new Array(f32, 100));
types.emplace_back(new Array(f32, 42));
types.emplace_back(new Array(f32, 100, 1u));
types.emplace_back(new Array(f32, 42, 2u));
auto* a42f32 = types.back().get();
types.emplace_back(new Array(u64, 24));
types.emplace_back(new Array(u64, 24, s32, {42}));
// RuntimeArray
types.emplace_back(new RuntimeArray(v3f32));

View File

@ -72,7 +72,7 @@ TestMultipleInstancesOfTheSameType(Image, f64_t_.get(), SpvDimCube, 0, 0, 1, 1,
SpvAccessQualifierWriteOnly);
TestMultipleInstancesOfTheSameType(Sampler);
TestMultipleInstancesOfTheSameType(SampledImage, image_t_.get());
TestMultipleInstancesOfTheSameType(Array, u32_t_.get(), 10);
TestMultipleInstancesOfTheSameType(Array, u32_t_.get(), 10, 3);
TestMultipleInstancesOfTheSameType(RuntimeArray, u32_t_.get());
TestMultipleInstancesOfTheSameType(Struct, std::vector<const Type*>{
u32_t_.get(), f64_t_.get()});
@ -151,10 +151,10 @@ std::vector<std::unique_ptr<Type>> GenerateAllTypes() {
types.emplace_back(new SampledImage(image2));
// Array
types.emplace_back(new Array(f32, 100));
types.emplace_back(new Array(f32, 42));
types.emplace_back(new Array(f32, 100, 1u));
types.emplace_back(new Array(f32, 42, 2u));
auto* a42f32 = types.back().get();
types.emplace_back(new Array(u64, 24));
types.emplace_back(new Array(u64, 24, s32, {42}));
// RuntimeArray
types.emplace_back(new RuntimeArray(v3f32));