Add more folding for composite instructions (#4802)

* Add more folding for composite instructions

Fold chains of insert into construct

If a chain of OpCompositeInsert instructions writes to every element of a
composite object, then we can replace it with an OpCompositeConstruct.

Fold a construct fed by extracts to a single extract

We already fold an OpCompositeConstruct when it is simply reconstructing
an object that was decomposed by a series of OpCompositeExtract
instructions.  However, we do not do that if that object is an element
of a larger object.

I have updated the rule, so that if the original object is an element
of a larger object, then the OpCompositeConstruct is replaced with a
single OpCompositeExtract from the larger object.

Fixes #4371.
This commit is contained in:
Steven Perron 2022-05-26 10:29:02 -04:00 committed by GitHub
parent c267127846
commit 088cb1a5c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 358 additions and 8 deletions

View File

@ -1631,6 +1631,57 @@ bool CompositeConstructFeedingExtract(
return true;
}
// Walks the chain of indexes in [|start|, |end|) of an OpCompositeInsert or
// OpCompositeExtract instruction, starting from the type with id |type_id|,
// and returns the type of the final element being accessed.  Every operand in
// the range must be a single-word literal integer.
//
// Returns nullptr as soon as an index would have to be applied to something
// that is not an array, a matrix, or a struct (this includes a null type), or
// when a struct index is past the last member.  If the range is empty, the
// result of looking up |type_id| in |type_mgr| is returned unchanged.
const analysis::Type* GetElementType(uint32_t type_id,
                                     Instruction::iterator start,
                                     Instruction::iterator end,
                                     const analysis::TypeManager* type_mgr) {
  const analysis::Type* type = type_mgr->GetType(type_id);
  for (const auto& index : make_range(std::move(start), std::move(end))) {
    // Stop before dereferencing a null type; there is nothing left to index.
    if (type == nullptr) {
      return nullptr;
    }
    assert(index.type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
           index.words.size() == 1);
    if (auto* array_type = type->AsArray()) {
      type = array_type->element_type();
    } else if (auto* matrix_type = type->AsMatrix()) {
      type = matrix_type->element_type();
    } else if (auto* struct_type = type->AsStruct()) {
      const auto& member_types = struct_type->element_types();
      if (index.words[0] >= member_types.size()) {
        return nullptr;
      }
      type = member_types[index.words[0]];
    } else {
      // The current type cannot be indexed by the remaining indexes.
      return nullptr;
    }
  }
  return type;
}
// Returns true if |inst_1| and |inst_2| use the same indexes to walk into a
// composite object, ignoring the last index of each.  Both instructions must
// have the same opcode, which must be either OpCompositeExtract or
// OpCompositeInsert.  Instructions with a different number of in-operands
// never match.
bool HaveSameIndexesExceptForLast(Instruction* inst_1, Instruction* inst_2) {
  assert(inst_1->opcode() == inst_2->opcode() &&
         "Expecting the opcodes to be the same.");
  assert((inst_1->opcode() == SpvOpCompositeInsert ||
          inst_1->opcode() == SpvOpCompositeExtract) &&
         "Instructions must be OpCompositeInsert or OpCompositeExtract.");

  const uint32_t operand_count = inst_1->NumInOperands();
  if (operand_count != inst_2->NumInOperands()) {
    return false;
  }

  // The indexes start at in-operand 2 for OpCompositeInsert and at in-operand
  // 1 for OpCompositeExtract.
  uint32_t position = 1;
  if (inst_1->opcode() == SpvOpCompositeInsert) {
    position = 2;
  }

  // Compare every index but the last one.
  while (position < operand_count - 1) {
    const uint32_t lhs_index = inst_1->GetSingleWordInOperand(position);
    const uint32_t rhs_index = inst_2->GetSingleWordInOperand(position);
    if (lhs_index != rhs_index) {
      return false;
    }
    ++position;
  }
  return true;
}
// If the OpCompositeConstruct is simply putting back together elements that
// were extracted from the same source, we can simply reuse the source.
//
@ -1653,19 +1704,24 @@ bool CompositeExtractFeedingConstruct(
// - extractions
// - extracting the same position they are inserting
// - all extract from the same id.
Instruction* first_element_inst = nullptr;
for (uint32_t i = 0; i < inst->NumInOperands(); ++i) {
const uint32_t element_id = inst->GetSingleWordInOperand(i);
Instruction* element_inst = def_use_mgr->GetDef(element_id);
if (first_element_inst == nullptr) {
first_element_inst = element_inst;
}
if (element_inst->opcode() != SpvOpCompositeExtract) {
return false;
}
if (element_inst->NumInOperands() != 2) {
if (!HaveSameIndexesExceptForLast(element_inst, first_element_inst)) {
return false;
}
if (element_inst->GetSingleWordInOperand(1) != i) {
if (element_inst->GetSingleWordInOperand(element_inst->NumInOperands() -
1) != i) {
return false;
}
@ -1681,13 +1737,31 @@ bool CompositeExtractFeedingConstruct(
// The last check it to see that the object being extracted from is the
// correct type.
Instruction* original_inst = def_use_mgr->GetDef(original_id);
if (original_inst->type_id() != inst->type_id()) {
analysis::TypeManager* type_mgr = context->get_type_mgr();
const analysis::Type* original_type =
GetElementType(original_inst->type_id(), first_element_inst->begin() + 3,
first_element_inst->end() - 1, type_mgr);
if (original_type == nullptr) {
return false;
}
// Simplify by using the original object.
inst->SetOpcode(SpvOpCopyObject);
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
if (inst->type_id() != type_mgr->GetId(original_type)) {
return false;
}
if (first_element_inst->NumInOperands() == 2) {
// Simplify by using the original object.
inst->SetOpcode(SpvOpCopyObject);
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {original_id}}});
return true;
}
// Copies the original id and all indexes except for the last to the new
// extract instruction.
inst->SetOpcode(SpvOpCompositeExtract);
inst->SetInOperands(std::vector<Operand>(first_element_inst->begin() + 2,
first_element_inst->end() - 1));
return true;
}
@ -1891,6 +1965,139 @@ FoldingRule FMixFeedingExtract() {
};
}
// Returns the number of elements in the composite type |type|.  Returns 0 if
// |type| is not a vector, matrix, struct, or array.  Returns UINT32_MAX for an
// array whose length is not a plain 32-bit constant, because its size is not
// known here.
uint32_t GetNumberOfElements(const analysis::Type* type) {
  if (auto* vector_type = type->AsVector()) {
    return vector_type->element_count();
  }
  if (auto* matrix_type = type->AsMatrix()) {
    return matrix_type->element_count();
  }
  if (auto* struct_type = type->AsStruct()) {
    return static_cast<uint32_t>(struct_type->element_types().size());
  }
  if (auto* array_type = type->AsArray()) {
    // The first word of the length info is a tag describing how the length
    // was defined; for a constant length the literal value follows it.
    // Reading words[0] as the length would yield the tag, not the length.
    const auto& length_words = array_type->length_info().words;
    if (length_words.size() == 2 &&
        length_words[0] == analysis::Array::LengthInfo::kConstant) {
      return length_words[1];
    }
    return UINT32_MAX;
  }
  return 0;
}
// Returns a map with the set of values that were inserted into an object by
// the chain of OpCompositeInsert instructions ending at |inst|.  The chain is
// walked backwards from |inst| through the composite operand of each
// OpCompositeInsert.  Only inserts that write into the same container as
// |inst| (same indexes except for the last) are recorded.
//
// The map will map the index to the value inserted at that index.  When the
// same index is written more than once, the write closest to |inst| is kept.
// Returns an empty map if an element was only partially written before being
// used, since a single construct cannot reproduce that.
std::map<uint32_t, uint32_t> GetInsertedValues(Instruction* inst) {
  // In-operand position of the first index of an OpCompositeInsert.
  const uint32_t first_index_in_idx = 2;

  analysis::DefUseManager* def_use_mgr = inst->context()->get_def_use_mgr();
  std::map<uint32_t, uint32_t> values_inserted;
  Instruction* current_inst = inst;
  while (current_inst->opcode() == SpvOpCompositeInsert) {
    if (current_inst->NumInOperands() > inst->NumInOperands()) {
      // This is to catch the case
      //   %2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0
      //   %3 = OpCompositeInsert %m2x2int %int_4 %2 0 0
      //   %4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1
      // In this case we cannot do a single construct to get the matrix.
      uint32_t partially_inserted_element_index =
          current_inst->GetSingleWordInOperand(inst->NumInOperands() - 1);
      if (values_inserted.count(partially_inserted_element_index) == 0)
        return {};
    } else if (current_inst->NumInOperands() < inst->NumInOperands()) {
      // This is to catch the case
      //   %2 = OpCompositeInsert %m2x2int %int_4 %m2x2int_undef 0 0
      //   %3 = OpCompositeInsert %m2x2int %v2int_2_3 %2 0
      //   %4 = OpCompositeInsert %m2x2int %int_1 %3 0 1
      // %3 replaces the whole container that %4 writes into, so anything
      // inserted into that container earlier in the chain (%2) is no longer
      // part of the object.  Stop here and keep only the values written after
      // the container was replaced.
      bool overwrites_container = true;
      for (uint32_t i = first_index_in_idx; i < current_inst->NumInOperands();
           ++i) {
        if (current_inst->GetSingleWordInOperand(i) !=
            inst->GetSingleWordInOperand(i)) {
          overwrites_container = false;
          break;
        }
      }
      if (overwrites_container) {
        break;
      }
    }

    if (HaveSameIndexesExceptForLast(inst, current_inst)) {
      // std::map::insert keeps the existing entry, so a later write to the
      // same index (seen first while walking backwards) is not replaced.
      values_inserted.insert(
          {current_inst->GetSingleWordInOperand(current_inst->NumInOperands() -
                                                1),
           current_inst->GetSingleWordInOperand(kInsertObjectIdInIdx)});
    }
    current_inst = def_use_mgr->GetDef(
        current_inst->GetSingleWordInOperand(kInsertCompositeIdInIdx));
  }
  return values_inserted;
}
// Returns true if there is an entry in |values_inserted| for every element of
// |type|.  An empty |values_inserted| never covers an object.
bool DoInsertedValuesCoverEntireObject(
    const analysis::Type* type, std::map<uint32_t, uint32_t>& values_inserted) {
  // Nothing was inserted.  This must be checked before looking at the largest
  // key below, because there is no largest key in an empty map.
  if (values_inserted.empty()) {
    return false;
  }
  uint32_t container_size = GetNumberOfElements(type);
  if (container_size != values_inserted.size()) {
    return false;
  }
  // The keys are unique and sorted.  With exactly |container_size| keys, the
  // largest one being in range means the keys are 0..container_size-1.
  return values_inserted.rbegin()->first < container_size;
}
// Returns the type of the object that directly holds the element written by
// the OpCompositeInsert instruction |inst|.  This is the type reached by
// applying every index of |inst| except the last one to the result type of
// |inst|.
const analysis::Type* GetContainerType(Instruction* inst) {
  assert(inst->opcode() == SpvOpCompositeInsert);
  analysis::TypeManager* type_mgr = inst->context()->get_type_mgr();
  // The walk starts at the fifth operand and stops before the last operand.
  // NOTE(review): presumably the first four operands are the result type, the
  // result id, the object, and the composite -- confirm.
  const auto first_index = inst->begin() + 4;
  const auto last_index = inst->end() - 1;
  const uint32_t result_type_id = inst->type_id();
  return GetElementType(result_type_id, first_index, last_index, type_mgr);
}
// Builds and returns an OpCompositeConstruct instruction that creates an
// object of type |type_id| from the values in |values_inserted|.  The values
// are passed to the construct in increasing order of their index (the map
// key).  The new instruction is placed before |insert_before|.
Instruction* BuildCompositeConstruct(
    uint32_t type_id, const std::map<uint32_t, uint32_t>& values_inserted,
    Instruction* insert_before) {
  // Iterating a std::map visits the keys in increasing order, so the ids end
  // up ordered by the index they were inserted at.
  std::vector<uint32_t> ordered_ids;
  ordered_ids.reserve(values_inserted.size());
  for (auto entry = values_inserted.begin(); entry != values_inserted.end();
       ++entry) {
    ordered_ids.push_back(entry->second);
  }

  InstructionBuilder builder(
      insert_before->context(), insert_before,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  return builder.AddCompositeConstruct(type_id, ordered_ids);
}
// Rewrites the OpCompositeInsert |inst| so that it inserts the result of
// |construct| using the indexes of |inst| with the final index removed.  If
// |inst| has only a single index, no index would remain, so |inst| is turned
// into an OpCopyObject of |construct| instead.
void InsertConstructedObject(Instruction* inst, const Instruction* construct) {
  const uint32_t constructed_id = construct->result_id();

  if (inst->NumInOperands() != 3) {
    // More than one index: insert the constructed object one level up.
    inst->SetInOperand(kInsertObjectIdInIdx, {constructed_id});
    inst->RemoveOperand(inst->NumOperands() - 1);
    return;
  }

  // Exactly three in-operands, i.e. a single index.
  inst->SetOpcode(SpvOpCopyObject);
  inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {constructed_id}}});
}
// Folding rule for OpCompositeInsert.  When the chain of OpCompositeInsert
// instructions ending at |inst| writes every element of the container that
// |inst| inserts into, a new OpCompositeConstruct building that container is
// added before |inst|, and |inst| is rewritten to use it.  Returns true if
// |inst| was changed.
bool CompositeInsertToCompositeConstruct(
    IRContext* context, Instruction* inst,
    const std::vector<const analysis::Constant*>&) {
  assert(inst->opcode() == SpvOpCompositeInsert &&
         "Wrong opcode.  Should be OpCompositeInsert.");
  // Nothing to do when |inst| has fewer than three in-operands.
  if (inst->NumInOperands() < 3) {
    return false;
  }

  std::map<uint32_t, uint32_t> inserted = GetInsertedValues(inst);
  const analysis::Type* container = GetContainerType(inst);
  if (container == nullptr ||
      !DoInsertedValuesCoverEntireObject(container, inserted)) {
    return false;
  }

  const uint32_t container_type_id = context->get_type_mgr()->GetId(container);
  Instruction* new_construct =
      BuildCompositeConstruct(container_type_id, inserted, inst);
  InsertConstructedObject(inst, new_construct);
  return true;
}
FoldingRule RedundantPhi() {
// An OpPhi instruction where all values are the same or the result of the phi
// itself, can be replaced by the value itself.
@ -2591,6 +2798,8 @@ void FoldingRules::AddFoldingRules() {
rules_[SpvOpCompositeExtract].push_back(VectorShuffleFeedingExtract());
rules_[SpvOpCompositeExtract].push_back(FMixFeedingExtract());
rules_[SpvOpCompositeInsert].push_back(CompositeInsertToCompositeConstruct);
rules_[SpvOpDot].push_back(DotProductDoingExtract());
rules_[SpvOpEntryPoint].push_back(RemoveRedundantOperands());

View File

@ -147,6 +147,7 @@ OpName %main "main"
%v2double = OpTypeVector %double 2
%v2half = OpTypeVector %half 2
%v2bool = OpTypeVector %bool 2
%m2x2int = OpTypeMatrix %v2int 2
%struct_v2int_int_int = OpTypeStruct %v2int %int %int
%_ptr_int = OpTypePointer Function %int
%_ptr_uint = OpTypePointer Function %uint
@ -218,7 +219,9 @@ OpName %main "main"
%struct_v2int_int_int_null = OpConstantNull %struct_v2int_int_int
%v2int_null = OpConstantNull %v2int
%102 = OpConstantComposite %v2int %103 %103
%v4int_undef = OpUndef %v4int
%v4int_0_0_0_0 = OpConstantComposite %v4int %int_0 %int_0 %int_0 %int_0
%m2x2int_undef = OpUndef %m2x2int
%struct_undef_0_0 = OpConstantComposite %struct_v2int_int_int %v2int_undef %int_0 %int_0
%float_n1 = OpConstant %float -1
%104 = OpConstant %float 0 ; Need a def with an numerical id to define id maps.
@ -6862,7 +6865,7 @@ INSTANTIATE_TEST_SUITE_P(SelectFoldingTest, MatchingInstructionFoldingTest,
4, true)
));
INSTANTIATE_TEST_SUITE_P(CompositeExtractMatchingTest, MatchingInstructionFoldingTest,
INSTANTIATE_TEST_SUITE_P(CompositeExtractOrInsertMatchingTest, MatchingInstructionFoldingTest,
::testing::Values(
// Test case 0: Extracting from result of consecutive shuffles of differing
// size.
@ -7002,7 +7005,145 @@ INSTANTIATE_TEST_SUITE_P(CompositeExtractMatchingTest, MatchingInstructionFoldin
"%4 = OpCompositeExtract %int %3 1\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true)
4, true),
// Test case 8: Inserting every element of a vector turns into a composite construct.
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK-DAG: [[v4:%\\w+]] = OpTypeVector [[int]] 4\n" +
"; CHECK-DAG: [[int1:%\\w+]] = OpConstant [[int]] 1\n" +
"; CHECK-DAG: [[int2:%\\w+]] = OpConstant [[int]] 2\n" +
"; CHECK-DAG: [[int3:%\\w+]] = OpConstant [[int]] 3\n" +
"; CHECK: [[construct:%\\w+]] = OpCompositeConstruct [[v4]] %100 [[int1]] [[int2]] [[int3]]\n" +
"; CHECK: %5 = OpCopyObject [[v4]] [[construct]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpCompositeInsert %v4int %100 %v4int_undef 0\n" +
"%3 = OpCompositeInsert %v4int %int_1 %2 1\n" +
"%4 = OpCompositeInsert %v4int %int_2 %3 2\n" +
"%5 = OpCompositeInsert %v4int %int_3 %4 3\n" +
"OpReturn\n" +
"OpFunctionEnd",
5, true),
// Test case 9: Inserting every element of a vector turns into a composite construct in a different order.
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK-DAG: [[v4:%\\w+]] = OpTypeVector [[int]] 4\n" +
"; CHECK-DAG: [[int1:%\\w+]] = OpConstant [[int]] 1\n" +
"; CHECK-DAG: [[int2:%\\w+]] = OpConstant [[int]] 2\n" +
"; CHECK-DAG: [[int3:%\\w+]] = OpConstant [[int]] 3\n" +
"; CHECK: [[construct:%\\w+]] = OpCompositeConstruct [[v4]] %100 [[int1]] [[int2]] [[int3]]\n" +
"; CHECK: %5 = OpCopyObject [[v4]] [[construct]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpCompositeInsert %v4int %100 %v4int_undef 0\n" +
"%4 = OpCompositeInsert %v4int %int_2 %2 2\n" +
"%3 = OpCompositeInsert %v4int %int_1 %4 1\n" +
"%5 = OpCompositeInsert %v4int %int_3 %3 3\n" +
"OpReturn\n" +
"OpFunctionEnd",
5, true),
// Test case 10: Check multiple inserts to the same position are handled correctly.
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK-DAG: [[v4:%\\w+]] = OpTypeVector [[int]] 4\n" +
"; CHECK-DAG: [[int1:%\\w+]] = OpConstant [[int]] 1\n" +
"; CHECK-DAG: [[int2:%\\w+]] = OpConstant [[int]] 2\n" +
"; CHECK-DAG: [[int3:%\\w+]] = OpConstant [[int]] 3\n" +
"; CHECK: [[construct:%\\w+]] = OpCompositeConstruct [[v4]] %100 [[int1]] [[int2]] [[int3]]\n" +
"; CHECK: %6 = OpCopyObject [[v4]] [[construct]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpCompositeInsert %v4int %100 %v4int_undef 0\n" +
"%3 = OpCompositeInsert %v4int %int_2 %2 2\n" +
"%4 = OpCompositeInsert %v4int %int_4 %3 1\n" +
"%5 = OpCompositeInsert %v4int %int_1 %4 1\n" +
"%6 = OpCompositeInsert %v4int %int_3 %5 3\n" +
"OpReturn\n" +
"OpFunctionEnd",
6, true),
// Test case 11: The last indexes are 0 and 1, but they have different first indexes. This should not be folded.
InstructionFoldingCase<bool>(
Header() +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpCompositeInsert %m2x2int %100 %m2x2int_undef 0 0\n" +
"%3 = OpCompositeInsert %m2x2int %int_1 %2 1 1\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, false),
// Test case 12: Don't fold when there is a partial insertion.
InstructionFoldingCase<bool>(
Header() +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpCompositeInsert %m2x2int %v2int_1_0 %m2x2int_undef 0\n" +
"%3 = OpCompositeInsert %m2x2int %int_4 %2 0 0\n" +
"%4 = OpCompositeInsert %m2x2int %v2int_2_3 %3 1\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, false),
// Test case 13: Insert into a column of a matrix
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK-DAG: [[v2:%\\w+]] = OpTypeVector [[int]] 2\n" +
"; CHECK: [[m2x2:%\\w+]] = OpTypeMatrix [[v2]] 2\n" +
"; CHECK-DAG: [[m2x2_undef:%\\w+]] = OpUndef [[m2x2]]\n" +
"; CHECK-DAG: [[int1:%\\w+]] = OpConstant [[int]] 1\n" +
// We keep this insert in the chain. DeadInsertElimPass should remove it.
"; CHECK: [[insert:%\\w+]] = OpCompositeInsert [[m2x2]] %100 [[m2x2_undef]] 0 0\n" +
"; CHECK: [[construct:%\\w+]] = OpCompositeConstruct [[v2]] %100 [[int1]]\n" +
"; CHECK: %3 = OpCompositeInsert [[m2x2]] [[construct]] [[insert]] 0\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpCompositeInsert %m2x2int %100 %m2x2int_undef 0 0\n" +
"%3 = OpCompositeInsert %m2x2int %int_1 %2 0 1\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test case 14: Insert all elements of the matrix.
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK-DAG: [[v2:%\\w+]] = OpTypeVector [[int]] 2\n" +
"; CHECK: [[m2x2:%\\w+]] = OpTypeMatrix [[v2]] 2\n" +
"; CHECK-DAG: [[m2x2_undef:%\\w+]] = OpUndef [[m2x2]]\n" +
"; CHECK-DAG: [[int1:%\\w+]] = OpConstant [[int]] 1\n" +
"; CHECK-DAG: [[int2:%\\w+]] = OpConstant [[int]] 2\n" +
"; CHECK-DAG: [[int3:%\\w+]] = OpConstant [[int]] 3\n" +
"; CHECK: [[c0:%\\w+]] = OpCompositeConstruct [[v2]] %100 [[int1]]\n" +
"; CHECK: [[c1:%\\w+]] = OpCompositeConstruct [[v2]] [[int2]] [[int3]]\n" +
"; CHECK: [[matrix:%\\w+]] = OpCompositeConstruct [[m2x2]] [[c0]] [[c1]]\n" +
"; CHECK: %5 = OpCopyObject [[m2x2]] [[matrix]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpCompositeConstruct %v2int %100 %int_1\n" +
"%3 = OpCompositeInsert %m2x2int %2 %m2x2int_undef 0\n" +
"%4 = OpCompositeInsert %m2x2int %int_2 %3 1 0\n" +
"%5 = OpCompositeInsert %m2x2int %int_3 %4 1 1\n" +
"OpReturn\n" +
"OpFunctionEnd",
5, true),
// Test case 15: Replace construct with extract when reconstructing a member
// of another object.
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[v2:%\\w+]] = OpTypeVector [[int]] 2\n" +
"; CHECK: [[m2x2:%\\w+]] = OpTypeMatrix [[v2]] 2\n" +
"; CHECK: [[m2x2_undef:%\\w+]] = OpUndef [[m2x2]]\n" +
"; CHECK: %5 = OpCompositeExtract [[v2]] [[m2x2_undef]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%3 = OpCompositeExtract %int %m2x2int_undef 1 0\n" +
"%4 = OpCompositeExtract %int %m2x2int_undef 1 1\n" +
"%5 = OpCompositeConstruct %v2int %3 %4\n" +
"OpReturn\n" +
"OpFunctionEnd",
5, true)
));
INSTANTIATE_TEST_SUITE_P(DotProductMatchingTest, MatchingInstructionFoldingTest,