spirv-opt: add pass for interface variable scalar replacement (#4779)

Replace shader's stage variables whose types are array or matrix
with scalars/vectors.
For example,
```
Before:
  %foo = OpVariable %_ptr_Output__arr_v2float_uint_4 Output
After:
  %foo = OpVariable %_ptr_Output_v2float Output
  %foo_0 = OpVariable %_ptr_Output_v2float Output
  %foo_1 = OpVariable %_ptr_Output_v2float Output
  %foo_2 = OpVariable %_ptr_Output_v2float Output
```
This commit is contained in:
Jaebaek Seo 2022-05-09 14:04:52 -04:00 committed by GitHub
parent ffc8f2d455
commit ad3514b732
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 1795 additions and 0 deletions

View File

@ -128,6 +128,7 @@ SPVTOOLS_OPT_SRC_FILES := \
source/opt/instruction.cpp \
source/opt/instruction_list.cpp \
source/opt/instrument_pass.cpp \
source/opt/interface_var_sroa.cpp \
source/opt/interp_fixup_pass.cpp \
source/opt/ir_context.cpp \
source/opt/ir_loader.cpp \

View File

@ -667,6 +667,8 @@ static_library("spvtools_opt") {
"source/opt/instruction_list.h",
"source/opt/instrument_pass.cpp",
"source/opt/instrument_pass.h",
"source/opt/interface_var_sroa.cpp",
"source/opt/interface_var_sroa.h",
"source/opt/interp_fixup_pass.cpp",
"source/opt/interp_fixup_pass.h",
"source/opt/ir_builder.h",

View File

@ -903,6 +903,11 @@ Optimizer::PassToken CreateConvertToSampledImagePass(
const std::vector<opt::DescriptorSetAndBinding>&
descriptor_set_binding_pairs);
// Create an interface-variable-scalar-replacement pass that replaces array or
// matrix interface variables with a series of scalar or vector interface
// variables. For example, it replaces `float3 foo[2]` with `float3 foo0, foo1`.
Optimizer::PassToken CreateInterfaceVariableScalarReplacementPass();
// Creates a remove-dont-inline pass to remove the |DontInline| function control
// from every function in the module. This is useful if you want the inliner to
// inline these functions some reason.

View File

@ -68,6 +68,7 @@ set(SPIRV_TOOLS_OPT_SOURCES
instruction.h
instruction_list.h
instrument_pass.h
interface_var_sroa.h
interp_fixup_pass.h
ir_builder.h
ir_context.h
@ -182,6 +183,7 @@ set(SPIRV_TOOLS_OPT_SOURCES
instruction.cpp
instruction_list.cpp
instrument_pass.cpp
interface_var_sroa.cpp
interp_fixup_pass.cpp
ir_context.cpp
ir_loader.cpp

View File

@ -0,0 +1,964 @@
// Copyright (c) 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/opt/interface_var_sroa.h"
#include <iostream>
#include "source/opt/decoration_manager.h"
#include "source/opt/def_use_manager.h"
#include "source/opt/function.h"
#include "source/opt/log.h"
#include "source/opt/type_manager.h"
#include "source/util/make_unique.h"
const static uint32_t kOpDecorateDecorationInOperandIndex = 1;
const static uint32_t kOpDecorateLiteralInOperandIndex = 2;
const static uint32_t kOpEntryPointInOperandInterface = 3;
const static uint32_t kOpVariableStorageClassInOperandIndex = 0;
const static uint32_t kOpTypeArrayElemTypeInOperandIndex = 0;
const static uint32_t kOpTypeArrayLengthInOperandIndex = 1;
const static uint32_t kOpTypeMatrixColCountInOperandIndex = 1;
const static uint32_t kOpTypeMatrixColTypeInOperandIndex = 0;
const static uint32_t kOpTypePtrTypeInOperandIndex = 1;
const static uint32_t kOpConstantValueInOperandIndex = 0;
namespace spvtools {
namespace opt {
namespace {
// Get the length of the OpTypeArray |array_type|.
uint32_t GetArrayLength(analysis::DefUseManager* def_use_mgr,
Instruction* array_type) {
assert(array_type->opcode() == SpvOpTypeArray);
uint32_t const_int_id =
array_type->GetSingleWordInOperand(kOpTypeArrayLengthInOperandIndex);
Instruction* array_length_inst = def_use_mgr->GetDef(const_int_id);
assert(array_length_inst->opcode() == SpvOpConstant);
return array_length_inst->GetSingleWordInOperand(
kOpConstantValueInOperandIndex);
}
// Get the element type instruction of the OpTypeArray |array_type|.
Instruction* GetArrayElementType(analysis::DefUseManager* def_use_mgr,
Instruction* array_type) {
assert(array_type->opcode() == SpvOpTypeArray);
uint32_t elem_type_id =
array_type->GetSingleWordInOperand(kOpTypeArrayElemTypeInOperandIndex);
return def_use_mgr->GetDef(elem_type_id);
}
// Get the column type instruction of the OpTypeMatrix |matrix_type|.
Instruction* GetMatrixColumnType(analysis::DefUseManager* def_use_mgr,
Instruction* matrix_type) {
assert(matrix_type->opcode() == SpvOpTypeMatrix);
uint32_t column_type_id =
matrix_type->GetSingleWordInOperand(kOpTypeMatrixColTypeInOperandIndex);
return def_use_mgr->GetDef(column_type_id);
}
// Traverses the component type of OpTypeArray or OpTypeMatrix. Repeats it
// |depth_to_component| times recursively and returns the component type.
// |type_id| is the result id of the OpTypeArray or OpTypeMatrix instruction.
uint32_t GetComponentTypeOfArrayMatrix(analysis::DefUseManager* def_use_mgr,
uint32_t type_id,
uint32_t depth_to_component) {
if (depth_to_component == 0) return type_id;
Instruction* type_inst = def_use_mgr->GetDef(type_id);
if (type_inst->opcode() == SpvOpTypeArray) {
uint32_t elem_type_id =
type_inst->GetSingleWordInOperand(kOpTypeArrayElemTypeInOperandIndex);
return GetComponentTypeOfArrayMatrix(def_use_mgr, elem_type_id,
depth_to_component - 1);
}
assert(type_inst->opcode() == SpvOpTypeMatrix);
uint32_t column_type_id =
type_inst->GetSingleWordInOperand(kOpTypeMatrixColTypeInOperandIndex);
return GetComponentTypeOfArrayMatrix(def_use_mgr, column_type_id,
depth_to_component - 1);
}
// Creates an OpDecorate instruction whose Target is |var_id| and Decoration is
// |decoration|. Adds |literal| as an extra operand of the instruction.
void CreateDecoration(analysis::DecorationManager* decoration_mgr,
uint32_t var_id, SpvDecoration decoration,
uint32_t literal) {
std::vector<Operand> operands({
{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {var_id}},
{spv_operand_type_t::SPV_OPERAND_TYPE_DECORATION,
{static_cast<uint32_t>(decoration)}},
{spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {literal}},
});
decoration_mgr->AddDecoration(SpvOpDecorate, std::move(operands));
}
// Replaces load instructions with composite construct instructions in all the
// users of the loads. |loads_to_composites| is the mapping from each load to
// its corresponding OpCompositeConstruct.
void ReplaceLoadWithCompositeConstruct(
IRContext* context,
const std::unordered_map<Instruction*, Instruction*>& loads_to_composites) {
for (const auto& load_and_composite : loads_to_composites) {
Instruction* load = load_and_composite.first;
Instruction* composite_construct = load_and_composite.second;
std::vector<Instruction*> users;
context->get_def_use_mgr()->ForEachUse(
load, [&users, composite_construct](Instruction* user, uint32_t index) {
user->GetOperand(index).words[0] = composite_construct->result_id();
users.push_back(user);
});
for (Instruction* user : users)
context->get_def_use_mgr()->AnalyzeInstUse(user);
}
}
// Returns the storage class of the instruction |var|.
SpvStorageClass GetStorageClass(Instruction* var) {
return static_cast<SpvStorageClass>(
var->GetSingleWordInOperand(kOpVariableStorageClassInOperandIndex));
}
} // namespace
bool InterfaceVariableScalarReplacement::HasExtraArrayness(
Instruction& entry_point, Instruction* var) {
SpvExecutionModel execution_model =
static_cast<SpvExecutionModel>(entry_point.GetSingleWordInOperand(0));
if (execution_model != SpvExecutionModelTessellationEvaluation &&
execution_model != SpvExecutionModelTessellationControl) {
return false;
}
if (!context()->get_decoration_mgr()->HasDecoration(var->result_id(),
SpvDecorationPatch)) {
if (execution_model == SpvExecutionModelTessellationControl) return true;
return GetStorageClass(var) != SpvStorageClassOutput;
}
return false;
}
bool InterfaceVariableScalarReplacement::
CheckExtraArraynessConflictBetweenEntries(Instruction* interface_var,
bool has_extra_arrayness) {
if (has_extra_arrayness) {
return !ReportErrorIfHasNoExtraArraynessForOtherEntry(interface_var);
}
return !ReportErrorIfHasExtraArraynessForOtherEntry(interface_var);
}
bool InterfaceVariableScalarReplacement::GetVariableLocation(
Instruction* var, uint32_t* location) {
return !context()->get_decoration_mgr()->WhileEachDecoration(
var->result_id(), SpvDecorationLocation,
[location](const Instruction& inst) {
*location =
inst.GetSingleWordInOperand(kOpDecorateLiteralInOperandIndex);
return false;
});
}
bool InterfaceVariableScalarReplacement::GetVariableComponent(
Instruction* var, uint32_t* component) {
return !context()->get_decoration_mgr()->WhileEachDecoration(
var->result_id(), SpvDecorationComponent,
[component](const Instruction& inst) {
*component =
inst.GetSingleWordInOperand(kOpDecorateLiteralInOperandIndex);
return false;
});
}
std::vector<Instruction*>
InterfaceVariableScalarReplacement::CollectInterfaceVariables(
Instruction& entry_point) {
std::vector<Instruction*> interface_vars;
for (uint32_t i = kOpEntryPointInOperandInterface;
i < entry_point.NumInOperands(); ++i) {
Instruction* interface_var = context()->get_def_use_mgr()->GetDef(
entry_point.GetSingleWordInOperand(i));
assert(interface_var->opcode() == SpvOpVariable);
SpvStorageClass storage_class = GetStorageClass(interface_var);
if (storage_class != SpvStorageClassInput &&
storage_class != SpvStorageClassOutput) {
continue;
}
interface_vars.push_back(interface_var);
}
return interface_vars;
}
void InterfaceVariableScalarReplacement::KillInstructionAndUsers(
Instruction* inst) {
if (inst->opcode() == SpvOpEntryPoint) {
return;
}
if (inst->opcode() != SpvOpAccessChain) {
context()->KillInst(inst);
return;
}
context()->get_def_use_mgr()->ForEachUser(
inst, [this](Instruction* user) { KillInstructionAndUsers(user); });
context()->KillInst(inst);
}
void InterfaceVariableScalarReplacement::KillInstructionsAndUsers(
const std::vector<Instruction*>& insts) {
for (Instruction* inst : insts) {
KillInstructionAndUsers(inst);
}
}
void InterfaceVariableScalarReplacement::KillLocationAndComponentDecorations(
uint32_t var_id) {
context()->get_decoration_mgr()->RemoveDecorationsFrom(
var_id, [](const Instruction& inst) {
uint32_t decoration =
inst.GetSingleWordInOperand(kOpDecorateDecorationInOperandIndex);
return decoration == SpvDecorationLocation ||
decoration == SpvDecorationComponent;
});
}
bool InterfaceVariableScalarReplacement::ReplaceInterfaceVariableWithScalars(
Instruction* interface_var, Instruction* interface_var_type,
uint32_t location, uint32_t component, uint32_t extra_array_length) {
NestedCompositeComponents scalar_interface_vars =
CreateScalarInterfaceVarsForReplacement(interface_var_type,
GetStorageClass(interface_var),
extra_array_length);
AddLocationAndComponentDecorations(scalar_interface_vars, &location,
component);
KillLocationAndComponentDecorations(interface_var->result_id());
if (!ReplaceInterfaceVarWith(interface_var, extra_array_length,
scalar_interface_vars)) {
return false;
}
context()->KillInst(interface_var);
return true;
}
bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarWith(
Instruction* interface_var, uint32_t extra_array_length,
const NestedCompositeComponents& scalar_interface_vars) {
std::vector<Instruction*> users;
context()->get_def_use_mgr()->ForEachUser(
interface_var, [&users](Instruction* user) { users.push_back(user); });
std::vector<uint32_t> interface_var_component_indices;
std::unordered_map<Instruction*, Instruction*> loads_to_composites;
std::unordered_map<Instruction*, Instruction*>
loads_for_access_chain_to_composites;
if (extra_array_length != 0) {
// Note that the extra arrayness is the first dimension of the array
// interface variable.
for (uint32_t index = 0; index < extra_array_length; ++index) {
std::unordered_map<Instruction*, Instruction*> loads_to_component_values;
if (!ReplaceComponentsOfInterfaceVarWith(
interface_var, users, scalar_interface_vars,
interface_var_component_indices, &index,
&loads_to_component_values,
&loads_for_access_chain_to_composites)) {
return false;
}
AddComponentsToCompositesForLoads(loads_to_component_values,
&loads_to_composites, 0);
}
} else if (!ReplaceComponentsOfInterfaceVarWith(
interface_var, users, scalar_interface_vars,
interface_var_component_indices, nullptr, &loads_to_composites,
&loads_for_access_chain_to_composites)) {
return false;
}
ReplaceLoadWithCompositeConstruct(context(), loads_to_composites);
ReplaceLoadWithCompositeConstruct(context(),
loads_for_access_chain_to_composites);
KillInstructionsAndUsers(users);
return true;
}
void InterfaceVariableScalarReplacement::AddLocationAndComponentDecorations(
const NestedCompositeComponents& vars, uint32_t* location,
uint32_t component) {
if (!vars.HasMultipleComponents()) {
uint32_t var_id = vars.GetComponentVariable()->result_id();
CreateDecoration(context()->get_decoration_mgr(), var_id,
SpvDecorationLocation, *location);
CreateDecoration(context()->get_decoration_mgr(), var_id,
SpvDecorationComponent, component);
++(*location);
return;
}
for (const auto& var : vars.GetComponents()) {
AddLocationAndComponentDecorations(var, location, component);
}
}
bool InterfaceVariableScalarReplacement::ReplaceComponentsOfInterfaceVarWith(
Instruction* interface_var,
const std::vector<Instruction*>& interface_var_users,
const NestedCompositeComponents& scalar_interface_vars,
std::vector<uint32_t>& interface_var_component_indices,
const uint32_t* extra_array_index,
std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
std::unordered_map<Instruction*, Instruction*>*
loads_for_access_chain_to_composites) {
if (!scalar_interface_vars.HasMultipleComponents()) {
for (Instruction* interface_var_user : interface_var_users) {
if (!ReplaceComponentOfInterfaceVarWith(
interface_var, interface_var_user,
scalar_interface_vars.GetComponentVariable(),
interface_var_component_indices, extra_array_index,
loads_to_composites, loads_for_access_chain_to_composites)) {
return false;
}
}
return true;
}
return ReplaceMultipleComponentsOfInterfaceVarWith(
interface_var, interface_var_users, scalar_interface_vars.GetComponents(),
interface_var_component_indices, extra_array_index, loads_to_composites,
loads_for_access_chain_to_composites);
}
bool InterfaceVariableScalarReplacement::
ReplaceMultipleComponentsOfInterfaceVarWith(
Instruction* interface_var,
const std::vector<Instruction*>& interface_var_users,
const std::vector<NestedCompositeComponents>& components,
std::vector<uint32_t>& interface_var_component_indices,
const uint32_t* extra_array_index,
std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
std::unordered_map<Instruction*, Instruction*>*
loads_for_access_chain_to_composites) {
for (uint32_t i = 0; i < components.size(); ++i) {
interface_var_component_indices.push_back(i);
std::unordered_map<Instruction*, Instruction*> loads_to_component_values;
std::unordered_map<Instruction*, Instruction*>
loads_for_access_chain_to_component_values;
if (!ReplaceComponentsOfInterfaceVarWith(
interface_var, interface_var_users, components[i],
interface_var_component_indices, extra_array_index,
&loads_to_component_values,
&loads_for_access_chain_to_component_values)) {
return false;
}
interface_var_component_indices.pop_back();
uint32_t depth_to_component =
static_cast<uint32_t>(interface_var_component_indices.size());
AddComponentsToCompositesForLoads(
loads_for_access_chain_to_component_values,
loads_for_access_chain_to_composites, depth_to_component);
if (extra_array_index) ++depth_to_component;
AddComponentsToCompositesForLoads(loads_to_component_values,
loads_to_composites, depth_to_component);
}
return true;
}
bool InterfaceVariableScalarReplacement::ReplaceComponentOfInterfaceVarWith(
Instruction* interface_var, Instruction* interface_var_user,
Instruction* scalar_var,
const std::vector<uint32_t>& interface_var_component_indices,
const uint32_t* extra_array_index,
std::unordered_map<Instruction*, Instruction*>* loads_to_component_values,
std::unordered_map<Instruction*, Instruction*>*
loads_for_access_chain_to_component_values) {
SpvOp opcode = interface_var_user->opcode();
if (opcode == SpvOpStore) {
uint32_t value_id = interface_var_user->GetSingleWordInOperand(1);
StoreComponentOfValueToScalarVar(value_id, interface_var_component_indices,
scalar_var, extra_array_index,
interface_var_user);
return true;
}
if (opcode == SpvOpLoad) {
Instruction* scalar_load =
LoadScalarVar(scalar_var, extra_array_index, interface_var_user);
loads_to_component_values->insert({interface_var_user, scalar_load});
return true;
}
// Copy OpName and annotation instructions only once. Therefore, we create
// them only for the first element of the extra array.
if (extra_array_index && *extra_array_index != 0) return true;
if (opcode == SpvOpDecorateId || opcode == SpvOpDecorateString ||
opcode == SpvOpDecorate) {
CloneAnnotationForVariable(interface_var_user, scalar_var->result_id());
return true;
}
if (opcode == SpvOpName) {
std::unique_ptr<Instruction> new_inst(interface_var_user->Clone(context()));
new_inst->SetInOperand(0, {scalar_var->result_id()});
context()->AddDebug2Inst(std::move(new_inst));
return true;
}
if (opcode == SpvOpEntryPoint) {
return ReplaceInterfaceVarInEntryPoint(interface_var, interface_var_user,
scalar_var->result_id());
}
if (opcode == SpvOpAccessChain) {
ReplaceAccessChainWith(interface_var_user, interface_var_component_indices,
scalar_var,
loads_for_access_chain_to_component_values);
return true;
}
std::string message("Unhandled instruction");
message += "\n " + interface_var_user->PrettyPrint(
SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
message +=
"\nfor interface variable scalar replacement\n " +
interface_var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
return false;
}
void InterfaceVariableScalarReplacement::UseBaseAccessChainForAccessChain(
Instruction* access_chain, Instruction* base_access_chain) {
assert(base_access_chain->opcode() == SpvOpAccessChain &&
access_chain->opcode() == SpvOpAccessChain &&
access_chain->GetSingleWordInOperand(0) ==
base_access_chain->result_id());
Instruction::OperandList new_operands;
for (uint32_t i = 0; i < base_access_chain->NumInOperands(); ++i) {
new_operands.emplace_back(base_access_chain->GetInOperand(i));
}
for (uint32_t i = 1; i < access_chain->NumInOperands(); ++i) {
new_operands.emplace_back(access_chain->GetInOperand(i));
}
access_chain->SetInOperands(std::move(new_operands));
}
Instruction* InterfaceVariableScalarReplacement::CreateAccessChainToVar(
uint32_t var_type_id, Instruction* var,
const std::vector<uint32_t>& index_ids, Instruction* insert_before,
uint32_t* component_type_id) {
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
*component_type_id = GetComponentTypeOfArrayMatrix(
def_use_mgr, var_type_id, static_cast<uint32_t>(index_ids.size()));
uint32_t ptr_type_id =
GetPointerType(*component_type_id, GetStorageClass(var));
std::unique_ptr<Instruction> new_access_chain(
new Instruction(context(), SpvOpAccessChain, ptr_type_id, TakeNextId(),
std::initializer_list<Operand>{
{SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
for (uint32_t index_id : index_ids) {
new_access_chain->AddOperand({SPV_OPERAND_TYPE_ID, {index_id}});
}
Instruction* inst = new_access_chain.get();
def_use_mgr->AnalyzeInstDefUse(inst);
insert_before->InsertBefore(std::move(new_access_chain));
return inst;
}
Instruction* InterfaceVariableScalarReplacement::CreateAccessChainWithIndex(
uint32_t component_type_id, Instruction* var, uint32_t index,
Instruction* insert_before) {
uint32_t ptr_type_id =
GetPointerType(component_type_id, GetStorageClass(var));
uint32_t index_id = context()->get_constant_mgr()->GetUIntConst(index);
std::unique_ptr<Instruction> new_access_chain(
new Instruction(context(), SpvOpAccessChain, ptr_type_id, TakeNextId(),
std::initializer_list<Operand>{
{SPV_OPERAND_TYPE_ID, {var->result_id()}},
{SPV_OPERAND_TYPE_ID, {index_id}},
}));
Instruction* inst = new_access_chain.get();
context()->get_def_use_mgr()->AnalyzeInstDefUse(inst);
insert_before->InsertBefore(std::move(new_access_chain));
return inst;
}
void InterfaceVariableScalarReplacement::ReplaceAccessChainWith(
Instruction* access_chain,
const std::vector<uint32_t>& interface_var_component_indices,
Instruction* scalar_var,
std::unordered_map<Instruction*, Instruction*>* loads_to_component_values) {
std::vector<uint32_t> indexes;
for (uint32_t i = 1; i < access_chain->NumInOperands(); ++i) {
indexes.push_back(access_chain->GetSingleWordInOperand(i));
}
// Note that we have a strong assumption that |access_chain| has only a single
// index that is for the extra arrayness.
context()->get_def_use_mgr()->ForEachUser(
access_chain,
[this, access_chain, &indexes, &interface_var_component_indices,
scalar_var, loads_to_component_values](Instruction* user) {
switch (user->opcode()) {
case SpvOpAccessChain: {
UseBaseAccessChainForAccessChain(user, access_chain);
ReplaceAccessChainWith(user, interface_var_component_indices,
scalar_var, loads_to_component_values);
return;
}
case SpvOpStore: {
uint32_t value_id = user->GetSingleWordInOperand(1);
StoreComponentOfValueToAccessChainToScalarVar(
value_id, interface_var_component_indices, scalar_var, indexes,
user);
return;
}
case SpvOpLoad: {
Instruction* value =
LoadAccessChainToVar(scalar_var, indexes, user);
loads_to_component_values->insert({user, value});
return;
}
default:
break;
}
});
}
void InterfaceVariableScalarReplacement::CloneAnnotationForVariable(
Instruction* annotation_inst, uint32_t var_id) {
assert(annotation_inst->opcode() == SpvOpDecorate ||
annotation_inst->opcode() == SpvOpDecorateId ||
annotation_inst->opcode() == SpvOpDecorateString);
std::unique_ptr<Instruction> new_inst(annotation_inst->Clone(context()));
new_inst->SetInOperand(0, {var_id});
context()->AddAnnotationInst(std::move(new_inst));
}
bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarInEntryPoint(
Instruction* interface_var, Instruction* entry_point,
uint32_t scalar_var_id) {
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
uint32_t interface_var_id = interface_var->result_id();
if (interface_vars_removed_from_entry_point_operands_.find(
interface_var_id) !=
interface_vars_removed_from_entry_point_operands_.end()) {
entry_point->AddOperand({SPV_OPERAND_TYPE_ID, {scalar_var_id}});
def_use_mgr->AnalyzeInstUse(entry_point);
return true;
}
bool success = !entry_point->WhileEachInId(
[&interface_var_id, &scalar_var_id](uint32_t* id) {
if (*id == interface_var_id) {
*id = scalar_var_id;
return false;
}
return true;
});
if (!success) {
std::string message(
"interface variable is not an operand of the entry point");
message += "\n " + interface_var->PrettyPrint(
SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
message += "\n " + entry_point->PrettyPrint(
SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
return false;
}
def_use_mgr->AnalyzeInstUse(entry_point);
interface_vars_removed_from_entry_point_operands_.insert(interface_var_id);
return true;
}
uint32_t InterfaceVariableScalarReplacement::GetPointeeTypeIdOfVar(
Instruction* var) {
assert(var->opcode() == SpvOpVariable);
uint32_t ptr_type_id = var->type_id();
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
Instruction* ptr_type_inst = def_use_mgr->GetDef(ptr_type_id);
assert(ptr_type_inst->opcode() == SpvOpTypePointer &&
"Variable must have a pointer type.");
return ptr_type_inst->GetSingleWordInOperand(kOpTypePtrTypeInOperandIndex);
}
void InterfaceVariableScalarReplacement::StoreComponentOfValueToScalarVar(
uint32_t value_id, const std::vector<uint32_t>& component_indices,
Instruction* scalar_var, const uint32_t* extra_array_index,
Instruction* insert_before) {
uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var);
Instruction* ptr = scalar_var;
if (extra_array_index) {
auto* ty_mgr = context()->get_type_mgr();
analysis::Array* array_type = ty_mgr->GetType(component_type_id)->AsArray();
assert(array_type != nullptr);
component_type_id = ty_mgr->GetTypeInstruction(array_type->element_type());
ptr = CreateAccessChainWithIndex(component_type_id, scalar_var,
*extra_array_index, insert_before);
}
StoreComponentOfValueTo(component_type_id, value_id, component_indices, ptr,
extra_array_index, insert_before);
}
Instruction* InterfaceVariableScalarReplacement::LoadScalarVar(
Instruction* scalar_var, const uint32_t* extra_array_index,
Instruction* insert_before) {
uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var);
Instruction* ptr = scalar_var;
if (extra_array_index) {
auto* ty_mgr = context()->get_type_mgr();
analysis::Array* array_type = ty_mgr->GetType(component_type_id)->AsArray();
assert(array_type != nullptr);
component_type_id = ty_mgr->GetTypeInstruction(array_type->element_type());
ptr = CreateAccessChainWithIndex(component_type_id, scalar_var,
*extra_array_index, insert_before);
}
return CreateLoad(component_type_id, ptr, insert_before);
}
Instruction* InterfaceVariableScalarReplacement::CreateLoad(
uint32_t type_id, Instruction* ptr, Instruction* insert_before) {
std::unique_ptr<Instruction> load(
new Instruction(context(), SpvOpLoad, type_id, TakeNextId(),
std::initializer_list<Operand>{
{SPV_OPERAND_TYPE_ID, {ptr->result_id()}}}));
Instruction* load_inst = load.get();
context()->get_def_use_mgr()->AnalyzeInstDefUse(load_inst);
insert_before->InsertBefore(std::move(load));
return load_inst;
}
void InterfaceVariableScalarReplacement::StoreComponentOfValueTo(
uint32_t component_type_id, uint32_t value_id,
const std::vector<uint32_t>& component_indices, Instruction* ptr,
const uint32_t* extra_array_index, Instruction* insert_before) {
std::unique_ptr<Instruction> composite_extract(CreateCompositeExtract(
component_type_id, value_id, component_indices, extra_array_index));
std::unique_ptr<Instruction> new_store(
new Instruction(context(), SpvOpStore));
new_store->AddOperand({SPV_OPERAND_TYPE_ID, {ptr->result_id()}});
new_store->AddOperand(
{SPV_OPERAND_TYPE_ID, {composite_extract->result_id()}});
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
def_use_mgr->AnalyzeInstDefUse(composite_extract.get());
def_use_mgr->AnalyzeInstDefUse(new_store.get());
insert_before->InsertBefore(std::move(composite_extract));
insert_before->InsertBefore(std::move(new_store));
}
Instruction* InterfaceVariableScalarReplacement::CreateCompositeExtract(
uint32_t type_id, uint32_t composite_id,
const std::vector<uint32_t>& indexes, const uint32_t* extra_first_index) {
uint32_t component_id = TakeNextId();
Instruction* composite_extract = new Instruction(
context(), SpvOpCompositeExtract, type_id, component_id,
std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {composite_id}}});
if (extra_first_index) {
composite_extract->AddOperand(
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {*extra_first_index}});
}
for (uint32_t index : indexes) {
composite_extract->AddOperand({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}});
}
return composite_extract;
}
void InterfaceVariableScalarReplacement::
StoreComponentOfValueToAccessChainToScalarVar(
uint32_t value_id, const std::vector<uint32_t>& component_indices,
Instruction* scalar_var,
const std::vector<uint32_t>& access_chain_indices,
Instruction* insert_before) {
uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var);
Instruction* ptr = scalar_var;
if (!access_chain_indices.empty()) {
ptr = CreateAccessChainToVar(component_type_id, scalar_var,
access_chain_indices, insert_before,
&component_type_id);
}
StoreComponentOfValueTo(component_type_id, value_id, component_indices, ptr,
nullptr, insert_before);
}
Instruction* InterfaceVariableScalarReplacement::LoadAccessChainToVar(
Instruction* var, const std::vector<uint32_t>& indexes,
Instruction* insert_before) {
uint32_t component_type_id = GetPointeeTypeIdOfVar(var);
Instruction* ptr = var;
if (!indexes.empty()) {
ptr = CreateAccessChainToVar(component_type_id, var, indexes, insert_before,
&component_type_id);
}
return CreateLoad(component_type_id, ptr, insert_before);
}
Instruction*
InterfaceVariableScalarReplacement::CreateCompositeConstructForComponentOfLoad(
Instruction* load, uint32_t depth_to_component) {
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
uint32_t type_id = load->type_id();
if (depth_to_component != 0) {
type_id = GetComponentTypeOfArrayMatrix(def_use_mgr, load->type_id(),
depth_to_component);
}
uint32_t new_id = context()->TakeNextId();
std::unique_ptr<Instruction> new_composite_construct(
new Instruction(context(), SpvOpCompositeConstruct, type_id, new_id, {}));
Instruction* composite_construct = new_composite_construct.get();
def_use_mgr->AnalyzeInstDefUse(composite_construct);
// Insert |new_composite_construct| after |load|. When there are multiple
// recursive composite construct instructions for a load, we have to place the
// composite construct with a lower depth later because it constructs the
// composite that contains other composites with lower depths.
auto* insert_before = load->NextNode();
while (true) {
auto itr =
composite_ids_to_component_depths.find(insert_before->result_id());
if (itr == composite_ids_to_component_depths.end()) break;
if (itr->second <= depth_to_component) break;
insert_before = insert_before->NextNode();
}
insert_before->InsertBefore(std::move(new_composite_construct));
composite_ids_to_component_depths.insert({new_id, depth_to_component});
return composite_construct;
}
void InterfaceVariableScalarReplacement::AddComponentsToCompositesForLoads(
const std::unordered_map<Instruction*, Instruction*>&
loads_to_component_values,
std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
uint32_t depth_to_component) {
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
for (auto& load_and_component_vale : loads_to_component_values) {
Instruction* load = load_and_component_vale.first;
Instruction* component_value = load_and_component_vale.second;
Instruction* composite_construct = nullptr;
auto itr = loads_to_composites->find(load);
if (itr == loads_to_composites->end()) {
composite_construct =
CreateCompositeConstructForComponentOfLoad(load, depth_to_component);
loads_to_composites->insert({load, composite_construct});
} else {
composite_construct = itr->second;
}
composite_construct->AddOperand(
{SPV_OPERAND_TYPE_ID, {component_value->result_id()}});
def_use_mgr->AnalyzeInstDefUse(composite_construct);
}
}
uint32_t InterfaceVariableScalarReplacement::GetArrayType(
uint32_t elem_type_id, uint32_t array_length) {
analysis::Type* elem_type = context()->get_type_mgr()->GetType(elem_type_id);
uint32_t array_length_id =
context()->get_constant_mgr()->GetUIntConst(array_length);
analysis::Array array_type(
elem_type,
analysis::Array::LengthInfo{array_length_id, {0, array_length}});
return context()->get_type_mgr()->GetTypeInstruction(&array_type);
}
uint32_t InterfaceVariableScalarReplacement::GetPointerType(
uint32_t type_id, SpvStorageClass storage_class) {
analysis::Type* type = context()->get_type_mgr()->GetType(type_id);
analysis::Pointer ptr_type(type, storage_class);
return context()->get_type_mgr()->GetTypeInstruction(&ptr_type);
}
InterfaceVariableScalarReplacement::NestedCompositeComponents
InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForArray(
Instruction* interface_var_type, SpvStorageClass storage_class,
uint32_t extra_array_length) {
assert(interface_var_type->opcode() == SpvOpTypeArray);
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
uint32_t array_length = GetArrayLength(def_use_mgr, interface_var_type);
Instruction* elem_type = GetArrayElementType(def_use_mgr, interface_var_type);
NestedCompositeComponents scalar_vars;
while (array_length > 0) {
NestedCompositeComponents scalar_vars_for_element =
CreateScalarInterfaceVarsForReplacement(elem_type, storage_class,
extra_array_length);
scalar_vars.AddComponent(scalar_vars_for_element);
--array_length;
}
return scalar_vars;
}
InterfaceVariableScalarReplacement::NestedCompositeComponents
InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForMatrix(
Instruction* interface_var_type, SpvStorageClass storage_class,
uint32_t extra_array_length) {
assert(interface_var_type->opcode() == SpvOpTypeMatrix);
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
uint32_t column_count = interface_var_type->GetSingleWordInOperand(
kOpTypeMatrixColCountInOperandIndex);
Instruction* column_type =
GetMatrixColumnType(def_use_mgr, interface_var_type);
NestedCompositeComponents scalar_vars;
while (column_count > 0) {
NestedCompositeComponents scalar_vars_for_column =
CreateScalarInterfaceVarsForReplacement(column_type, storage_class,
extra_array_length);
scalar_vars.AddComponent(scalar_vars_for_column);
--column_count;
}
return scalar_vars;
}
InterfaceVariableScalarReplacement::NestedCompositeComponents
InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForReplacement(
Instruction* interface_var_type, SpvStorageClass storage_class,
uint32_t extra_array_length) {
// Handle array case.
if (interface_var_type->opcode() == SpvOpTypeArray) {
return CreateScalarInterfaceVarsForArray(interface_var_type, storage_class,
extra_array_length);
}
// Handle matrix case.
if (interface_var_type->opcode() == SpvOpTypeMatrix) {
return CreateScalarInterfaceVarsForMatrix(interface_var_type, storage_class,
extra_array_length);
}
// Handle scalar or vector case.
NestedCompositeComponents scalar_var;
uint32_t type_id = interface_var_type->result_id();
if (extra_array_length != 0) {
type_id = GetArrayType(type_id, extra_array_length);
}
uint32_t ptr_type_id =
context()->get_type_mgr()->FindPointerToType(type_id, storage_class);
uint32_t id = TakeNextId();
std::unique_ptr<Instruction> variable(
new Instruction(context(), SpvOpVariable, ptr_type_id, id,
std::initializer_list<Operand>{
{SPV_OPERAND_TYPE_STORAGE_CLASS,
{static_cast<uint32_t>(storage_class)}}}));
scalar_var.SetSingleComponentVariable(variable.get());
context()->AddGlobalValue(std::move(variable));
return scalar_var;
}
Instruction* InterfaceVariableScalarReplacement::GetTypeOfVariable(
Instruction* var) {
uint32_t pointee_type_id = GetPointeeTypeIdOfVar(var);
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
return def_use_mgr->GetDef(pointee_type_id);
}
Pass::Status InterfaceVariableScalarReplacement::Process() {
Pass::Status status = Status::SuccessWithoutChange;
for (Instruction& entry_point : get_module()->entry_points()) {
status =
CombineStatus(status, ReplaceInterfaceVarsWithScalars(entry_point));
}
return status;
}
bool InterfaceVariableScalarReplacement::
ReportErrorIfHasExtraArraynessForOtherEntry(Instruction* var) {
if (vars_with_extra_arrayness.find(var) == vars_with_extra_arrayness.end())
return false;
std::string message(
"A variable is arrayed for an entry point but it is not "
"arrayed for another entry point");
message +=
"\n " + var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
return true;
}
bool InterfaceVariableScalarReplacement::
ReportErrorIfHasNoExtraArraynessForOtherEntry(Instruction* var) {
if (vars_without_extra_arrayness.find(var) ==
vars_without_extra_arrayness.end())
return false;
std::string message(
"A variable is not arrayed for an entry point but it is "
"arrayed for another entry point");
message +=
"\n " + var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
return true;
}
Pass::Status
InterfaceVariableScalarReplacement::ReplaceInterfaceVarsWithScalars(
Instruction& entry_point) {
std::vector<Instruction*> interface_vars =
CollectInterfaceVariables(entry_point);
Pass::Status status = Status::SuccessWithoutChange;
for (Instruction* interface_var : interface_vars) {
uint32_t location, component;
if (!GetVariableLocation(interface_var, &location)) continue;
if (!GetVariableComponent(interface_var, &component)) component = 0;
Instruction* interface_var_type = GetTypeOfVariable(interface_var);
uint32_t extra_array_length = 0;
if (HasExtraArrayness(entry_point, interface_var)) {
extra_array_length =
GetArrayLength(context()->get_def_use_mgr(), interface_var_type);
interface_var_type =
GetArrayElementType(context()->get_def_use_mgr(), interface_var_type);
vars_with_extra_arrayness.insert(interface_var);
} else {
vars_without_extra_arrayness.insert(interface_var);
}
if (!CheckExtraArraynessConflictBetweenEntries(interface_var,
extra_array_length != 0)) {
return Pass::Status::Failure;
}
if (interface_var_type->opcode() != SpvOpTypeArray &&
interface_var_type->opcode() != SpvOpTypeMatrix) {
continue;
}
if (!ReplaceInterfaceVariableWithScalars(interface_var, interface_var_type,
location, component,
extra_array_length)) {
return Pass::Status::Failure;
}
status = Pass::Status::SuccessWithChange;
}
return status;
}
} // namespace opt
} // namespace spvtools

View File

@ -0,0 +1,401 @@
// Copyright (c) 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_OPT_INTERFACE_VAR_SROA_H_
#define SOURCE_OPT_INTERFACE_VAR_SROA_H_
#include <unordered_set>
#include "source/opt/pass.h"
namespace spvtools {
namespace opt {
// See optimizer.hpp for documentation.
//
// Note that the current implementation of this pass covers only store, load,
// access chain instructions for the interface variables. Supporting other types
// of instructions is a future work.
class InterfaceVariableScalarReplacement : public Pass {
public:
InterfaceVariableScalarReplacement() {}
const char* name() const override {
return "interface-variable-scalar-replacement";
}
Status Process() override;
IRContext::Analysis GetPreservedAnalyses() override {
return IRContext::kAnalysisDecorations | IRContext::kAnalysisDefUse |
IRContext::kAnalysisConstants | IRContext::kAnalysisTypes;
}
private:
// A struct containing components of a composite variable. If the composite
// consists of multiple or recursive components, |component_variable| is
// nullptr and |nested_composite_components| keeps the components. If it has a
// single component, |nested_composite_components| is empty and
// |component_variable| is the component. Note that each element of
// |nested_composite_components| has the NestedCompositeComponents struct as
// its type that can recursively keep the components.
struct NestedCompositeComponents {
NestedCompositeComponents() : component_variable(nullptr) {}
bool HasMultipleComponents() const {
return !nested_composite_components.empty();
}
const std::vector<NestedCompositeComponents>& GetComponents() const {
return nested_composite_components;
}
void AddComponent(const NestedCompositeComponents& component) {
nested_composite_components.push_back(component);
}
Instruction* GetComponentVariable() const { return component_variable; }
void SetSingleComponentVariable(Instruction* var) {
component_variable = var;
}
private:
std::vector<NestedCompositeComponents> nested_composite_components;
Instruction* component_variable;
};
// Collects all interface variables used by the |entry_point|.
std::vector<Instruction*> CollectInterfaceVariables(Instruction& entry_point);
// Returns whether |var| has the extra arrayness for the entry point
// |entry_point| or not.
bool HasExtraArrayness(Instruction& entry_point, Instruction* var);
// Finds a Location BuiltIn decoration of |var| and returns it via
// |location|. Returns true whether the location exists or not.
bool GetVariableLocation(Instruction* var, uint32_t* location);
// Finds a Component BuiltIn decoration of |var| and returns it via
// |component|. Returns true whether the component exists or not.
bool GetVariableComponent(Instruction* var, uint32_t* component);
// Returns the interface variable instruction whose result id is
// |interface_var_id|.
Instruction* GetInterfaceVariable(uint32_t interface_var_id);
// Returns the type of |var| as an instruction.
Instruction* GetTypeOfVariable(Instruction* var);
// Replaces an interface variable |interface_var| whose type is
// |interface_var_type| with scalars and returns whether it succeeds or not.
// |location| is the value of Location Decoration for |interface_var|.
// |component| is the value of Component Decoration for |interface_var|.
// If |extra_array_length| is 0, it means |interface_var| has a Patch
// decoration. Otherwise, |extra_array_length| denotes the length of the extra
// array of |interface_var|.
bool ReplaceInterfaceVariableWithScalars(Instruction* interface_var,
Instruction* interface_var_type,
uint32_t location,
uint32_t component,
uint32_t extra_array_length);
// Creates scalar variables with the storage classe |storage_class| to replace
// an interface variable whose type is |interface_var_type|. If
// |extra_array_length| is not zero, adds the extra arrayness to the created
// scalar variables.
NestedCompositeComponents CreateScalarInterfaceVarsForReplacement(
Instruction* interface_var_type, SpvStorageClass storage_class,
uint32_t extra_array_length);
// Creates scalar variables with the storage classe |storage_class| to replace
// the interface variable whose type is OpTypeArray |interface_var_type| with.
// If |extra_array_length| is not zero, adds the extra arrayness to all the
// scalar variables.
NestedCompositeComponents CreateScalarInterfaceVarsForArray(
Instruction* interface_var_type, SpvStorageClass storage_class,
uint32_t extra_array_length);
// Creates scalar variables with the storage classe |storage_class| to replace
// the interface variable whose type is OpTypeMatrix |interface_var_type|
// with. If |extra_array_length| is not zero, adds the extra arrayness to all
// the scalar variables.
NestedCompositeComponents CreateScalarInterfaceVarsForMatrix(
Instruction* interface_var_type, SpvStorageClass storage_class,
uint32_t extra_array_length);
// Recursively adds Location and Component decorations to variables in
// |vars| with |location| and |component|. Increases |location| by one after
// it actually adds Location and Component decorations for a variable.
void AddLocationAndComponentDecorations(const NestedCompositeComponents& vars,
uint32_t* location,
uint32_t component);
// Replaces the interface variable |interface_var| with
// |scalar_interface_vars| and returns whether it succeeds or not.
// |extra_arrayness| is the extra arrayness of the interface variable.
// |scalar_interface_vars| contains the nested variables to replace the
// interface variable with.
bool ReplaceInterfaceVarWith(
Instruction* interface_var, uint32_t extra_arrayness,
const NestedCompositeComponents& scalar_interface_vars);
// Replaces |interface_var| in the operands of instructions
// |interface_var_users| with |scalar_interface_vars|. This is a recursive
// method and |interface_var_component_indices| is used to specify which
// recursive component of |interface_var| is replaced. Returns composite
// construct instructions to be replaced with load instructions of
// |interface_var_users| via |loads_to_composites|. Returns composite
// construct instructions to be replaced with load instructions of access
// chain instructions in |interface_var_users| via
// |loads_for_access_chain_to_composites|.
bool ReplaceComponentsOfInterfaceVarWith(
Instruction* interface_var,
const std::vector<Instruction*>& interface_var_users,
const NestedCompositeComponents& scalar_interface_vars,
std::vector<uint32_t>& interface_var_component_indices,
const uint32_t* extra_array_index,
std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
std::unordered_map<Instruction*, Instruction*>*
loads_for_access_chain_to_composites);
// Replaces |interface_var| in the operands of instructions
// |interface_var_users| with |components| that is a vector of components for
// the interface variable |interface_var|. This is a recursive method and
// |interface_var_component_indices| is used to specify which recursive
// component of |interface_var| is replaced. Returns composite construct
// instructions to be replaced with load instructions of |interface_var_users|
// via |loads_to_composites|. Returns composite construct instructions to be
// replaced with load instructions of access chain instructions in
// |interface_var_users| via |loads_for_access_chain_to_composites|.
bool ReplaceMultipleComponentsOfInterfaceVarWith(
Instruction* interface_var,
const std::vector<Instruction*>& interface_var_users,
const std::vector<NestedCompositeComponents>& components,
std::vector<uint32_t>& interface_var_component_indices,
const uint32_t* extra_array_index,
std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
std::unordered_map<Instruction*, Instruction*>*
loads_for_access_chain_to_composites);
// Replaces a component of |interface_var| that is used as an operand of
// instruction |interface_var_user| with |scalar_var|.
// |interface_var_component_indices| is a vector of recursive indices for
// which recursive component of |interface_var| is replaced. If
// |interface_var_user| is a load, returns the component value via
// |loads_to_component_values|. If |interface_var_user| is an access chain,
// returns the component value for loads of |interface_var_user| via
// |loads_for_access_chain_to_component_values|.
bool ReplaceComponentOfInterfaceVarWith(
Instruction* interface_var, Instruction* interface_var_user,
Instruction* scalar_var,
const std::vector<uint32_t>& interface_var_component_indices,
const uint32_t* extra_array_index,
std::unordered_map<Instruction*, Instruction*>* loads_to_component_values,
std::unordered_map<Instruction*, Instruction*>*
loads_for_access_chain_to_component_values);
// Creates instructions to load |scalar_var| and inserts them before
// |insert_before|. If |extra_array_index| is not null, they load
// |extra_array_index| th component of |scalar_var| instead of |scalar_var|
// itself.
Instruction* LoadScalarVar(Instruction* scalar_var,
const uint32_t* extra_array_index,
Instruction* insert_before);
// Creates instructions to load an access chain to |var| and inserts them
// before |insert_before|. |Indexes| will be Indexes operand of the access
// chain.
Instruction* LoadAccessChainToVar(Instruction* var,
const std::vector<uint32_t>& indexes,
Instruction* insert_before);
// Creates instructions to store a component of an aggregate whose id is
// |value_id| to an access chain to |scalar_var| and inserts the created
// instructions before |insert_before|. To get the component, recursively
// traverses the aggregate with |component_indices| as indexes.
// Numbers in |access_chain_indices| are the Indexes operand of the access
// chain to |scalar_var|
void StoreComponentOfValueToAccessChainToScalarVar(
uint32_t value_id, const std::vector<uint32_t>& component_indices,
Instruction* scalar_var,
const std::vector<uint32_t>& access_chain_indices,
Instruction* insert_before);
// Creates instructions to store a component of an aggregate whose id is
// |value_id| to |scalar_var| and inserts the created instructions before
// |insert_before|. To get the component, recursively traverses the aggregate
// using |extra_array_index| and |component_indices| as indexes.
void StoreComponentOfValueToScalarVar(
uint32_t value_id, const std::vector<uint32_t>& component_indices,
Instruction* scalar_var, const uint32_t* extra_array_index,
Instruction* insert_before);
// Creates instructions to store a component of an aggregate whose id is
// |value_id| to |ptr| and inserts the created instructions before
// |insert_before|. To get the component, recursively traverses the aggregate
// using |extra_array_index| and |component_indices| as indexes.
// |component_type_id| is the id of the type instruction of the component.
void StoreComponentOfValueTo(uint32_t component_type_id, uint32_t value_id,
const std::vector<uint32_t>& component_indices,
Instruction* ptr,
const uint32_t* extra_array_index,
Instruction* insert_before);
// Creates new OpCompositeExtract with |type_id| for Result Type,
// |composite_id| for Composite operand, and |indexes| for Indexes operands.
// If |extra_first_index| is not nullptr, uses it as the first Indexes
// operand.
Instruction* CreateCompositeExtract(uint32_t type_id, uint32_t composite_id,
const std::vector<uint32_t>& indexes,
const uint32_t* extra_first_index);
// Creates a new OpLoad whose Result Type is |type_id| and Pointer operand is
// |ptr|. Inserts the new instruction before |insert_before|.
Instruction* CreateLoad(uint32_t type_id, Instruction* ptr,
Instruction* insert_before);
// Clones an annotation instruction |annotation_inst| and sets the target
// operand of the new annotation instruction as |var_id|.
void CloneAnnotationForVariable(Instruction* annotation_inst,
uint32_t var_id);
// Replaces the interface variable |interface_var| in the operands of the
// entry point |entry_point| with |scalar_var_id|. If it cannot find
// |interface_var| from the operands of the entry point |entry_point|, adds
// |scalar_var_id| as an operand of the entry point |entry_point|.
bool ReplaceInterfaceVarInEntryPoint(Instruction* interface_var,
Instruction* entry_point,
uint32_t scalar_var_id);
// Creates an access chain instruction whose Base operand is |var| and Indexes
// operand is |index|. |component_type_id| is the id of the type instruction
// that is the type of component. Inserts the new access chain before
// |insert_before|.
Instruction* CreateAccessChainWithIndex(uint32_t component_type_id,
Instruction* var, uint32_t index,
Instruction* insert_before);
// Returns the pointee type of the type of variable |var|.
uint32_t GetPointeeTypeIdOfVar(Instruction* var);
// Replaces the access chain |access_chain| and its users with a new access
// chain that points |scalar_var| as the Base operand having
// |interface_var_component_indices| as Indexes operands and users of the new
// access chain. When some of the users are load instructions, returns the
// original load instruction to the new instruction that loads a component of
// the original load value via |loads_to_component_values|.
void ReplaceAccessChainWith(
Instruction* access_chain,
const std::vector<uint32_t>& interface_var_component_indices,
Instruction* scalar_var,
std::unordered_map<Instruction*, Instruction*>*
loads_to_component_values);
// Assuming that |access_chain| is an access chain instruction whose Base
// operand is |base_access_chain|, replaces the operands of |access_chain|
// with operands of |base_access_chain| and Indexes operands of
// |access_chain|.
void UseBaseAccessChainForAccessChain(Instruction* access_chain,
Instruction* base_access_chain);
// Creates composite construct instructions for load instructions that are the
// keys of |loads_to_component_values| if no such composite construct
// instructions exist. Adds a component of the composite as an operand of the
// created composite construct instruction. Each value of
// |loads_to_component_values| is the component. Returns the created composite
// construct instructions using |loads_to_composites|. |depth_to_component| is
// the number of recursive access steps to get the component from the
// composite.
void AddComponentsToCompositesForLoads(
const std::unordered_map<Instruction*, Instruction*>&
loads_to_component_values,
std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
uint32_t depth_to_component);
// Creates a composite construct instruction for a component of the value of
// instruction |load| in |depth_to_component| th recursive depth and inserts
// it after |load|.
Instruction* CreateCompositeConstructForComponentOfLoad(
Instruction* load, uint32_t depth_to_component);
// Creates a new access chain instruction that points to variable |var| whose
// type is the instruction with |var_type_id| and inserts it before
// |insert_before|. The new access chain will have |index_ids| for Indexes
// operands. Returns the type id of the component that is pointed by the new
// access chain via |component_type_id|.
Instruction* CreateAccessChainToVar(uint32_t var_type_id, Instruction* var,
const std::vector<uint32_t>& index_ids,
Instruction* insert_before,
uint32_t* component_type_id);
// Returns the result id of OpTypeArray instrunction whose Element Type
// operand is |elem_type_id| and Length operand is |array_length|.
uint32_t GetArrayType(uint32_t elem_type_id, uint32_t array_length);
// Returns the result id of OpTypePointer instrunction whose Type
// operand is |type_id| and Storage Class operand is |storage_class|.
uint32_t GetPointerType(uint32_t type_id, SpvStorageClass storage_class);
// Kills an instrunction |inst| and its users.
void KillInstructionAndUsers(Instruction* inst);
// Kills a vector of instrunctions |insts| and their users.
void KillInstructionsAndUsers(const std::vector<Instruction*>& insts);
// Kills all OpDecorate instructions for Location and Component of the
// variable whose id is |var_id|.
void KillLocationAndComponentDecorations(uint32_t var_id);
// If |var| has the extra arrayness for an entry point, reports an error and
// returns true. Otherwise, returns false.
bool ReportErrorIfHasExtraArraynessForOtherEntry(Instruction* var);
// If |var| does not have the extra arrayness for an entry point, reports an
// error and returns true. Otherwise, returns false.
bool ReportErrorIfHasNoExtraArraynessForOtherEntry(Instruction* var);
// If |interface_var| has the extra arrayness for an entry point but it does
// not have one for another entry point, reports an error and returns false.
// Otherwise, returns true. |has_extra_arrayness| denotes whether it has an
// extra arrayness for an entry point or not.
bool CheckExtraArraynessConflictBetweenEntries(Instruction* interface_var,
bool has_extra_arrayness);
// Conducts the scalar replacement for the interface variables used by the
// |entry_point|.
Pass::Status ReplaceInterfaceVarsWithScalars(Instruction& entry_point);
// A set of interface variable ids that were already removed from operands of
// the entry point.
std::unordered_set<uint32_t>
interface_vars_removed_from_entry_point_operands_;
// A mapping from ids of new composite construct instructions that load
// instructions are replaced with to the recursive depth of the component of
// load that the new component construct instruction is used for.
std::unordered_map<uint32_t, uint32_t> composite_ids_to_component_depths;
// A set of interface variables with the extra arrayness for any of the entry
// points.
std::unordered_set<Instruction*> vars_with_extra_arrayness;
// A set of interface variables without the extra arrayness for any of the
// entry points.
std::unordered_set<Instruction*> vars_without_extra_arrayness;
};
} // namespace opt
} // namespace spvtools
#endif // SOURCE_OPT_INTERFACE_VAR_SROA_H_

View File

@ -1094,6 +1094,9 @@ void IRContext::AddDebug2Inst(std::unique_ptr<Instruction>&& d) {
id_to_name_->insert({d->GetSingleWordInOperand(0), d.get()});
}
}
if (AreAnalysesValid(kAnalysisDefUse)) {
get_def_use_mgr()->AnalyzeInstDefUse(d.get());
}
module()->AddDebug2Inst(std::move(d));
}

View File

@ -1020,6 +1020,11 @@ Optimizer::PassToken CreateConvertToSampledImagePass(
MakeUnique<opt::ConvertToSampledImagePass>(descriptor_set_binding_pairs));
}
Optimizer::PassToken CreateInterfaceVariableScalarReplacementPass() {
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::InterfaceVariableScalarReplacement>());
}
Optimizer::PassToken CreateRemoveDontInlinePass() {
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::RemoveDontInline>());

View File

@ -49,6 +49,7 @@
#include "source/opt/inst_bindless_check_pass.h"
#include "source/opt/inst_buff_addr_check_pass.h"
#include "source/opt/inst_debug_printf_pass.h"
#include "source/opt/interface_var_sroa.h"
#include "source/opt/interp_fixup_pass.h"
#include "source/opt/licm_pass.h"
#include "source/opt/local_access_chain_convert_pass.h"

View File

@ -62,6 +62,7 @@ add_spvtools_unittest(TARGET opt
inst_debug_printf_test.cpp
instruction_list_test.cpp
instruction_test.cpp
interface_var_sroa_test.cpp
interp_fixup_test.cpp
ir_builder.cpp
ir_context_test.cpp

View File

@ -0,0 +1,410 @@
// Copyright (c) 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "gmock/gmock.h"
#include "test/opt/assembly_builder.h"
#include "test/opt/pass_fixture.h"
#include "test/opt/pass_utils.h"
namespace spvtools {
namespace opt {
namespace {
using InterfaceVariableScalarReplacementTest = PassTest<::testing::Test>;
TEST_F(InterfaceVariableScalarReplacementTest,
ReplaceInterfaceVarsWithScalars) {
const std::string spirv = R"(
OpCapability Shader
OpCapability Tessellation
OpMemoryModel Logical GLSL450
OpEntryPoint TessellationControl %func "shader" %x %y %z %w %u %v
; CHECK: OpName [[x:%\w+]] "x"
; CHECK-NOT: OpName {{%\w+}} "x"
; CHECK: OpName [[y:%\w+]] "y"
; CHECK-NOT: OpName {{%\w+}} "y"
; CHECK: OpName [[z0:%\w+]] "z"
; CHECK: OpName [[z1:%\w+]] "z"
; CHECK: OpName [[w0:%\w+]] "w"
; CHECK: OpName [[w1:%\w+]] "w"
; CHECK: OpName [[u0:%\w+]] "u"
; CHECK: OpName [[u1:%\w+]] "u"
; CHECK: OpName [[v0:%\w+]] "v"
; CHECK: OpName [[v1:%\w+]] "v"
; CHECK: OpName [[v2:%\w+]] "v"
; CHECK: OpName [[v3:%\w+]] "v"
; CHECK: OpName [[v4:%\w+]] "v"
; CHECK: OpName [[v5:%\w+]] "v"
OpName %x "x"
OpName %y "y"
OpName %z "z"
OpName %w "w"
OpName %u "u"
OpName %v "v"
; CHECK-DAG: OpDecorate [[x]] Location 2
; CHECK-DAG: OpDecorate [[y]] Location 0
; CHECK-DAG: OpDecorate [[z0]] Location 0
; CHECK-DAG: OpDecorate [[z0]] Component 0
; CHECK-DAG: OpDecorate [[z1]] Location 1
; CHECK-DAG: OpDecorate [[z1]] Component 0
; CHECK-DAG: OpDecorate [[z0]] Patch
; CHECK-DAG: OpDecorate [[z1]] Patch
; CHECK-DAG: OpDecorate [[w0]] Location 2
; CHECK-DAG: OpDecorate [[w0]] Component 0
; CHECK-DAG: OpDecorate [[w1]] Location 3
; CHECK-DAG: OpDecorate [[w1]] Component 0
; CHECK-DAG: OpDecorate [[w0]] Patch
; CHECK-DAG: OpDecorate [[w1]] Patch
; CHECK-DAG: OpDecorate [[u0]] Location 3
; CHECK-DAG: OpDecorate [[u0]] Component 2
; CHECK-DAG: OpDecorate [[u1]] Location 4
; CHECK-DAG: OpDecorate [[u1]] Component 2
; CHECK-DAG: OpDecorate [[v0]] Location 3
; CHECK-DAG: OpDecorate [[v0]] Component 3
; CHECK-DAG: OpDecorate [[v1]] Location 4
; CHECK-DAG: OpDecorate [[v1]] Component 3
; CHECK-DAG: OpDecorate [[v2]] Location 5
; CHECK-DAG: OpDecorate [[v2]] Component 3
; CHECK-DAG: OpDecorate [[v3]] Location 6
; CHECK-DAG: OpDecorate [[v3]] Component 3
; CHECK-DAG: OpDecorate [[v4]] Location 7
; CHECK-DAG: OpDecorate [[v4]] Component 3
; CHECK-DAG: OpDecorate [[v5]] Location 8
; CHECK-DAG: OpDecorate [[v5]] Component 3
OpDecorate %z Patch
OpDecorate %w Patch
OpDecorate %z Location 0
OpDecorate %x Location 2
OpDecorate %v Location 3
OpDecorate %v Component 3
OpDecorate %y Location 0
OpDecorate %w Location 2
OpDecorate %u Location 3
OpDecorate %u Component 2
%uint = OpTypeInt 32 0
%uint_1 = OpConstant %uint 1
%uint_2 = OpConstant %uint 2
%uint_3 = OpConstant %uint 3
%uint_4 = OpConstant %uint 4
%_arr_uint_uint_2 = OpTypeArray %uint %uint_2
%_ptr_Output__arr_uint_uint_2 = OpTypePointer Output %_arr_uint_uint_2
%_ptr_Input__arr_uint_uint_2 = OpTypePointer Input %_arr_uint_uint_2
%_ptr_Input_uint = OpTypePointer Input %uint
%_ptr_Output_uint = OpTypePointer Output %uint
%_arr_arr_uint_uint_2_3 = OpTypeArray %_arr_uint_uint_2 %uint_3
%_ptr_Input__arr_arr_uint_uint_2_3 = OpTypePointer Input %_arr_arr_uint_uint_2_3
%_arr_arr_arr_uint_uint_2_3_4 = OpTypeArray %_arr_arr_uint_uint_2_3 %uint_4
%_ptr_Output__arr_arr_arr_uint_uint_2_3_4 = OpTypePointer Output %_arr_arr_arr_uint_uint_2_3_4
%_ptr_Output__arr_arr_uint_uint_2_3 = OpTypePointer Output %_arr_arr_uint_uint_2_3
%z = OpVariable %_ptr_Output__arr_uint_uint_2 Output
%x = OpVariable %_ptr_Output__arr_uint_uint_2 Output
%y = OpVariable %_ptr_Input__arr_uint_uint_2 Input
%w = OpVariable %_ptr_Input__arr_uint_uint_2 Input
%u = OpVariable %_ptr_Input__arr_arr_uint_uint_2_3 Input
%v = OpVariable %_ptr_Output__arr_arr_arr_uint_uint_2_3_4 Output
; CHECK-DAG: [[x]] = OpVariable %_ptr_Output__arr_uint_uint_2 Output
; CHECK-DAG: [[y]] = OpVariable %_ptr_Input__arr_uint_uint_2 Input
; CHECK-DAG: [[z0]] = OpVariable %_ptr_Output_uint Output
; CHECK-DAG: [[z1]] = OpVariable %_ptr_Output_uint Output
; CHECK-DAG: [[w0]] = OpVariable %_ptr_Input_uint Input
; CHECK-DAG: [[w1]] = OpVariable %_ptr_Input_uint Input
; CHECK-DAG: [[u0]] = OpVariable %_ptr_Input__arr_uint_uint_3 Input
; CHECK-DAG: [[u1]] = OpVariable %_ptr_Input__arr_uint_uint_3 Input
; CHECK-DAG: [[v0]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output
; CHECK-DAG: [[v1]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output
; CHECK-DAG: [[v2]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output
; CHECK-DAG: [[v3]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output
; CHECK-DAG: [[v4]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output
; CHECK-DAG: [[v5]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output
%void = OpTypeVoid
%void_f = OpTypeFunction %void
%func = OpFunction %void None %void_f
%label = OpLabel
; CHECK: [[w0_value:%\w+]] = OpLoad %uint [[w0]]
; CHECK: [[w1_value:%\w+]] = OpLoad %uint [[w1]]
; CHECK: [[w_value:%\w+]] = OpCompositeConstruct %_arr_uint_uint_2 [[w0_value]] [[w1_value]]
; CHECK: [[w0:%\w+]] = OpCompositeExtract %uint [[w_value]] 0
; CHECK: OpStore [[z0]] [[w0]]
; CHECK: [[w1:%\w+]] = OpCompositeExtract %uint [[w_value]] 1
; CHECK: OpStore [[z1]] [[w1]]
%w_value = OpLoad %_arr_uint_uint_2 %w
OpStore %z %w_value
; CHECK: [[u00_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u0]] %uint_0
; CHECK: [[u00:%\w+]] = OpLoad %uint [[u00_ptr]]
; CHECK: [[u10_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u1]] %uint_0
; CHECK: [[u10:%\w+]] = OpLoad %uint [[u10_ptr]]
; CHECK: [[u01_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u0]] %uint_1
; CHECK: [[u01:%\w+]] = OpLoad %uint [[u01_ptr]]
; CHECK: [[u11_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u1]] %uint_1
; CHECK: [[u11:%\w+]] = OpLoad %uint [[u11_ptr]]
; CHECK: [[u02_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u0]] %uint_2
; CHECK: [[u02:%\w+]] = OpLoad %uint [[u02_ptr]]
; CHECK: [[u12_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u1]] %uint_2
; CHECK: [[u12:%\w+]] = OpLoad %uint [[u12_ptr]]
; CHECK-DAG: [[u0_val:%\w+]] = OpCompositeConstruct %_arr_uint_uint_2 [[u00]] [[u10]]
; CHECK-DAG: [[u1_val:%\w+]] = OpCompositeConstruct %_arr_uint_uint_2 [[u01]] [[u11]]
; CHECK-DAG: [[u2_val:%\w+]] = OpCompositeConstruct %_arr_uint_uint_2 [[u02]] [[u12]]
; CHECK: [[u_val:%\w+]] = OpCompositeConstruct %_arr__arr_uint_uint_2_uint_3 [[u0_val]] [[u1_val]] [[u2_val]]
; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v0]] %uint_1
; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 0 0
; CHECK: OpStore [[ptr]] [[val]]
; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v1]] %uint_1
; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 0 1
; CHECK: OpStore [[ptr]] [[val]]
; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v2]] %uint_1
; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 1 0
; CHECK: OpStore [[ptr]] [[val]]
; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v3]] %uint_1
; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 1 1
; CHECK: OpStore [[ptr]] [[val]]
; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v4]] %uint_1
; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 2 0
; CHECK: OpStore [[ptr]] [[val]]
; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v5]] %uint_1
; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 2 1
; CHECK: OpStore [[ptr]] [[val]]
%v_ptr = OpAccessChain %_ptr_Output__arr_arr_uint_uint_2_3 %v %uint_1
%u_val = OpLoad %_arr_arr_uint_uint_2_3 %u
OpStore %v_ptr %u_val
OpReturn
OpFunctionEnd
)";
SinglePassRunAndMatch<InterfaceVariableScalarReplacement>(spirv, true);
}
TEST_F(InterfaceVariableScalarReplacementTest,
CheckPatchDecorationPreservation) {
// Make sure scalars for the variables with the extra arrayness have the extra
// arrayness after running the pass while others do not have it.
// Only "y" does not have the extra arrayness in the following SPIR-V.
const std::string spirv = R"(
OpCapability Shader
OpCapability Tessellation
OpMemoryModel Logical GLSL450
OpEntryPoint TessellationEvaluation %func "shader" %x %y %z %w
OpDecorate %z Patch
OpDecorate %w Patch
OpDecorate %z Location 0
OpDecorate %x Location 2
OpDecorate %y Location 0
OpDecorate %w Location 1
OpName %x "x"
OpName %y "y"
OpName %z "z"
OpName %w "w"
; CHECK: OpName [[y:%\w+]] "y"
; CHECK-NOT: OpName {{%\w+}} "y"
; CHECK-DAG: OpName [[z0:%\w+]] "z"
; CHECK-DAG: OpName [[z1:%\w+]] "z"
; CHECK-DAG: OpName [[w0:%\w+]] "w"
; CHECK-DAG: OpName [[w1:%\w+]] "w"
; CHECK-DAG: OpName [[x0:%\w+]] "x"
; CHECK-DAG: OpName [[x1:%\w+]] "x"
%uint = OpTypeInt 32 0
%uint_2 = OpConstant %uint 2
%_arr_uint_uint_2 = OpTypeArray %uint %uint_2
%_ptr_Output__arr_uint_uint_2 = OpTypePointer Output %_arr_uint_uint_2
%_ptr_Input__arr_uint_uint_2 = OpTypePointer Input %_arr_uint_uint_2
%z = OpVariable %_ptr_Output__arr_uint_uint_2 Output
%x = OpVariable %_ptr_Output__arr_uint_uint_2 Output
%y = OpVariable %_ptr_Input__arr_uint_uint_2 Input
%w = OpVariable %_ptr_Input__arr_uint_uint_2 Input
; CHECK-DAG: [[y]] = OpVariable %_ptr_Input__arr_uint_uint_2 Input
; CHECK-DAG: [[z0]] = OpVariable %_ptr_Output_uint Output
; CHECK-DAG: [[z1]] = OpVariable %_ptr_Output_uint Output
; CHECK-DAG: [[w0]] = OpVariable %_ptr_Input_uint Input
; CHECK-DAG: [[w1]] = OpVariable %_ptr_Input_uint Input
; CHECK-DAG: [[x0]] = OpVariable %_ptr_Output_uint Output
; CHECK-DAG: [[x1]] = OpVariable %_ptr_Output_uint Output
%void = OpTypeVoid
%void_f = OpTypeFunction %void
%func = OpFunction %void None %void_f
%label = OpLabel
OpReturn
OpFunctionEnd
)";
SinglePassRunAndMatch<InterfaceVariableScalarReplacement>(spirv, true);
}
TEST_F(InterfaceVariableScalarReplacementTest,
CheckEntryPointInterfaceOperands) {
const std::string spirv = R"(
OpCapability Shader
OpCapability Tessellation
OpMemoryModel Logical GLSL450
OpEntryPoint TessellationEvaluation %tess "tess" %x %y
OpEntryPoint Vertex %vert "vert" %w
OpDecorate %z Location 0
OpDecorate %x Location 2
OpDecorate %y Location 0
OpDecorate %w Location 1
OpName %x "x"
OpName %y "y"
OpName %z "z"
OpName %w "w"
; CHECK: OpName [[y:%\w+]] "y"
; CHECK-NOT: OpName {{%\w+}} "y"
; CHECK-DAG: OpName [[x0:%\w+]] "x"
; CHECK-DAG: OpName [[x1:%\w+]] "x"
; CHECK-DAG: OpName [[w0:%\w+]] "w"
; CHECK-DAG: OpName [[w1:%\w+]] "w"
; CHECK-DAG: OpName [[z:%\w+]] "z"
; CHECK-NOT: OpName {{%\w+}} "z"
%uint = OpTypeInt 32 0
%uint_2 = OpConstant %uint 2
%_arr_uint_uint_2 = OpTypeArray %uint %uint_2
%_ptr_Output__arr_uint_uint_2 = OpTypePointer Output %_arr_uint_uint_2
%_ptr_Input__arr_uint_uint_2 = OpTypePointer Input %_arr_uint_uint_2
%z = OpVariable %_ptr_Output__arr_uint_uint_2 Output
%x = OpVariable %_ptr_Output__arr_uint_uint_2 Output
%y = OpVariable %_ptr_Input__arr_uint_uint_2 Input
%w = OpVariable %_ptr_Input__arr_uint_uint_2 Input
; CHECK-DAG: [[y]] = OpVariable %_ptr_Input__arr_uint_uint_2 Input
; CHECK-DAG: [[z]] = OpVariable %_ptr_Output__arr_uint_uint_2 Output
; CHECK-DAG: [[w0]] = OpVariable %_ptr_Input_uint Input
; CHECK-DAG: [[w1]] = OpVariable %_ptr_Input_uint Input
; CHECK-DAG: [[x0]] = OpVariable %_ptr_Output_uint Output
; CHECK-DAG: [[x1]] = OpVariable %_ptr_Output_uint Output
%void = OpTypeVoid
%void_f = OpTypeFunction %void
%tess = OpFunction %void None %void_f
%bb0 = OpLabel
OpReturn
OpFunctionEnd
%vert = OpFunction %void None %void_f
%bb1 = OpLabel
OpReturn
OpFunctionEnd
)";
SinglePassRunAndMatch<InterfaceVariableScalarReplacement>(spirv, true);
}
class InterfaceVarSROAErrorTest : public PassTest<::testing::Test> {
public:
InterfaceVarSROAErrorTest()
: consumer_([this](spv_message_level_t level, const char*,
const spv_position_t& position, const char* message) {
if (!error_message_.empty()) error_message_ += "\n";
switch (level) {
case SPV_MSG_FATAL:
case SPV_MSG_INTERNAL_ERROR:
case SPV_MSG_ERROR:
error_message_ += "ERROR";
break;
case SPV_MSG_WARNING:
error_message_ += "WARNING";
break;
case SPV_MSG_INFO:
error_message_ += "INFO";
break;
case SPV_MSG_DEBUG:
error_message_ += "DEBUG";
break;
}
error_message_ +=
": " + std::to_string(position.index) + ": " + message;
}) {}
Pass::Status RunPass(const std::string& text) {
std::unique_ptr<IRContext> context_ =
spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_2, consumer_, text);
if (!context_.get()) return Pass::Status::Failure;
PassManager manager;
manager.SetMessageConsumer(consumer_);
manager.AddPass<InterfaceVariableScalarReplacement>();
return manager.Run(context_.get());
}
std::string GetErrorMessage() const { return error_message_; }
void TearDown() override { error_message_.clear(); }
private:
spvtools::MessageConsumer consumer_;
std::string error_message_;
};
TEST_F(InterfaceVarSROAErrorTest, CheckConflictOfExtraArraynessBetweenEntries) {
const std::string spirv = R"(
OpCapability Shader
OpCapability Tessellation
OpMemoryModel Logical GLSL450
OpEntryPoint TessellationControl %tess "tess" %x %y %z
OpEntryPoint Vertex %vert "vert" %z %w
OpDecorate %z Location 0
OpDecorate %x Location 2
OpDecorate %y Location 0
OpDecorate %w Location 1
OpName %x "x"
OpName %y "y"
OpName %z "z"
OpName %w "w"
%uint = OpTypeInt 32 0
%uint_2 = OpConstant %uint 2
%_arr_uint_uint_2 = OpTypeArray %uint %uint_2
%_ptr_Output__arr_uint_uint_2 = OpTypePointer Output %_arr_uint_uint_2
%_ptr_Input__arr_uint_uint_2 = OpTypePointer Input %_arr_uint_uint_2
%z = OpVariable %_ptr_Output__arr_uint_uint_2 Output
%x = OpVariable %_ptr_Output__arr_uint_uint_2 Output
%y = OpVariable %_ptr_Input__arr_uint_uint_2 Input
%w = OpVariable %_ptr_Input__arr_uint_uint_2 Input
%void = OpTypeVoid
%void_f = OpTypeFunction %void
%tess = OpFunction %void None %void_f
%bb0 = OpLabel
OpReturn
OpFunctionEnd
%vert = OpFunction %void None %void_f
%bb1 = OpLabel
OpReturn
OpFunctionEnd
)";
EXPECT_EQ(RunPass(spirv), Pass::Status::Failure);
const char expected_error[] =
"ERROR: 0: A variable is arrayed for an entry point but it is not "
"arrayed for another entry point\n"
" %z = OpVariable %_ptr_Output__arr_uint_uint_2 Output";
EXPECT_STREQ(GetErrorMessage().c_str(), expected_error);
}
} // namespace
} // namespace opt
} // namespace spvtools