mirror of
https://gitee.com/openharmony/third_party_spirv-tools
synced 2024-11-27 01:21:25 +00:00
spirv-opt: add pass for interface variable scalar replacement (#4779)
Replace shader's stage variables whose types are array or matrix with scalars/vectors. For example, ``` Before: %foo = OpVariable %_ptr_Output__arr_v2float_uint_4 Output After: %foo = OpVariable %_ptr_Output_v2float Output %foo_0 = OpVariable %_ptr_Output_v2float Output %foo_1 = OpVariable %_ptr_Output_v2float Output %foo_2 = OpVariable %_ptr_Output_v2float Output ```
This commit is contained in:
parent
ffc8f2d455
commit
ad3514b732
@ -128,6 +128,7 @@ SPVTOOLS_OPT_SRC_FILES := \
|
||||
source/opt/instruction.cpp \
|
||||
source/opt/instruction_list.cpp \
|
||||
source/opt/instrument_pass.cpp \
|
||||
source/opt/interface_var_sroa.cpp \
|
||||
source/opt/interp_fixup_pass.cpp \
|
||||
source/opt/ir_context.cpp \
|
||||
source/opt/ir_loader.cpp \
|
||||
|
2
BUILD.gn
2
BUILD.gn
@ -667,6 +667,8 @@ static_library("spvtools_opt") {
|
||||
"source/opt/instruction_list.h",
|
||||
"source/opt/instrument_pass.cpp",
|
||||
"source/opt/instrument_pass.h",
|
||||
"source/opt/interface_var_sroa.cpp",
|
||||
"source/opt/interface_var_sroa.h",
|
||||
"source/opt/interp_fixup_pass.cpp",
|
||||
"source/opt/interp_fixup_pass.h",
|
||||
"source/opt/ir_builder.h",
|
||||
|
@ -903,6 +903,11 @@ Optimizer::PassToken CreateConvertToSampledImagePass(
|
||||
const std::vector<opt::DescriptorSetAndBinding>&
|
||||
descriptor_set_binding_pairs);
|
||||
|
||||
// Create an interface-variable-scalar-replacement pass that replaces array or
|
||||
// matrix interface variables with a series of scalar or vector interface
|
||||
// variables. For example, it replaces `float3 foo[2]` with `float3 foo0, foo1`.
|
||||
Optimizer::PassToken CreateInterfaceVariableScalarReplacementPass();
|
||||
|
||||
// Creates a remove-dont-inline pass to remove the |DontInline| function control
|
||||
// from every function in the module. This is useful if you want the inliner to
|
||||
// inline these functions some reason.
|
||||
|
@ -68,6 +68,7 @@ set(SPIRV_TOOLS_OPT_SOURCES
|
||||
instruction.h
|
||||
instruction_list.h
|
||||
instrument_pass.h
|
||||
interface_var_sroa.h
|
||||
interp_fixup_pass.h
|
||||
ir_builder.h
|
||||
ir_context.h
|
||||
@ -182,6 +183,7 @@ set(SPIRV_TOOLS_OPT_SOURCES
|
||||
instruction.cpp
|
||||
instruction_list.cpp
|
||||
instrument_pass.cpp
|
||||
interface_var_sroa.cpp
|
||||
interp_fixup_pass.cpp
|
||||
ir_context.cpp
|
||||
ir_loader.cpp
|
||||
|
964
source/opt/interface_var_sroa.cpp
Normal file
964
source/opt/interface_var_sroa.cpp
Normal file
@ -0,0 +1,964 @@
|
||||
// Copyright (c) 2022 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "source/opt/interface_var_sroa.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "source/opt/decoration_manager.h"
|
||||
#include "source/opt/def_use_manager.h"
|
||||
#include "source/opt/function.h"
|
||||
#include "source/opt/log.h"
|
||||
#include "source/opt/type_manager.h"
|
||||
#include "source/util/make_unique.h"
|
||||
|
||||
const static uint32_t kOpDecorateDecorationInOperandIndex = 1;
|
||||
const static uint32_t kOpDecorateLiteralInOperandIndex = 2;
|
||||
const static uint32_t kOpEntryPointInOperandInterface = 3;
|
||||
const static uint32_t kOpVariableStorageClassInOperandIndex = 0;
|
||||
const static uint32_t kOpTypeArrayElemTypeInOperandIndex = 0;
|
||||
const static uint32_t kOpTypeArrayLengthInOperandIndex = 1;
|
||||
const static uint32_t kOpTypeMatrixColCountInOperandIndex = 1;
|
||||
const static uint32_t kOpTypeMatrixColTypeInOperandIndex = 0;
|
||||
const static uint32_t kOpTypePtrTypeInOperandIndex = 1;
|
||||
const static uint32_t kOpConstantValueInOperandIndex = 0;
|
||||
|
||||
namespace spvtools {
|
||||
namespace opt {
|
||||
namespace {
|
||||
|
||||
// Get the length of the OpTypeArray |array_type|.
|
||||
uint32_t GetArrayLength(analysis::DefUseManager* def_use_mgr,
|
||||
Instruction* array_type) {
|
||||
assert(array_type->opcode() == SpvOpTypeArray);
|
||||
uint32_t const_int_id =
|
||||
array_type->GetSingleWordInOperand(kOpTypeArrayLengthInOperandIndex);
|
||||
Instruction* array_length_inst = def_use_mgr->GetDef(const_int_id);
|
||||
assert(array_length_inst->opcode() == SpvOpConstant);
|
||||
return array_length_inst->GetSingleWordInOperand(
|
||||
kOpConstantValueInOperandIndex);
|
||||
}
|
||||
|
||||
// Get the element type instruction of the OpTypeArray |array_type|.
|
||||
Instruction* GetArrayElementType(analysis::DefUseManager* def_use_mgr,
|
||||
Instruction* array_type) {
|
||||
assert(array_type->opcode() == SpvOpTypeArray);
|
||||
uint32_t elem_type_id =
|
||||
array_type->GetSingleWordInOperand(kOpTypeArrayElemTypeInOperandIndex);
|
||||
return def_use_mgr->GetDef(elem_type_id);
|
||||
}
|
||||
|
||||
// Get the column type instruction of the OpTypeMatrix |matrix_type|.
|
||||
Instruction* GetMatrixColumnType(analysis::DefUseManager* def_use_mgr,
|
||||
Instruction* matrix_type) {
|
||||
assert(matrix_type->opcode() == SpvOpTypeMatrix);
|
||||
uint32_t column_type_id =
|
||||
matrix_type->GetSingleWordInOperand(kOpTypeMatrixColTypeInOperandIndex);
|
||||
return def_use_mgr->GetDef(column_type_id);
|
||||
}
|
||||
|
||||
// Traverses the component type of OpTypeArray or OpTypeMatrix. Repeats it
|
||||
// |depth_to_component| times recursively and returns the component type.
|
||||
// |type_id| is the result id of the OpTypeArray or OpTypeMatrix instruction.
|
||||
uint32_t GetComponentTypeOfArrayMatrix(analysis::DefUseManager* def_use_mgr,
|
||||
uint32_t type_id,
|
||||
uint32_t depth_to_component) {
|
||||
if (depth_to_component == 0) return type_id;
|
||||
|
||||
Instruction* type_inst = def_use_mgr->GetDef(type_id);
|
||||
if (type_inst->opcode() == SpvOpTypeArray) {
|
||||
uint32_t elem_type_id =
|
||||
type_inst->GetSingleWordInOperand(kOpTypeArrayElemTypeInOperandIndex);
|
||||
return GetComponentTypeOfArrayMatrix(def_use_mgr, elem_type_id,
|
||||
depth_to_component - 1);
|
||||
}
|
||||
|
||||
assert(type_inst->opcode() == SpvOpTypeMatrix);
|
||||
uint32_t column_type_id =
|
||||
type_inst->GetSingleWordInOperand(kOpTypeMatrixColTypeInOperandIndex);
|
||||
return GetComponentTypeOfArrayMatrix(def_use_mgr, column_type_id,
|
||||
depth_to_component - 1);
|
||||
}
|
||||
|
||||
// Creates an OpDecorate instruction whose Target is |var_id| and Decoration is
|
||||
// |decoration|. Adds |literal| as an extra operand of the instruction.
|
||||
void CreateDecoration(analysis::DecorationManager* decoration_mgr,
|
||||
uint32_t var_id, SpvDecoration decoration,
|
||||
uint32_t literal) {
|
||||
std::vector<Operand> operands({
|
||||
{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {var_id}},
|
||||
{spv_operand_type_t::SPV_OPERAND_TYPE_DECORATION,
|
||||
{static_cast<uint32_t>(decoration)}},
|
||||
{spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {literal}},
|
||||
});
|
||||
decoration_mgr->AddDecoration(SpvOpDecorate, std::move(operands));
|
||||
}
|
||||
|
||||
// Replaces load instructions with composite construct instructions in all the
|
||||
// users of the loads. |loads_to_composites| is the mapping from each load to
|
||||
// its corresponding OpCompositeConstruct.
|
||||
void ReplaceLoadWithCompositeConstruct(
|
||||
IRContext* context,
|
||||
const std::unordered_map<Instruction*, Instruction*>& loads_to_composites) {
|
||||
for (const auto& load_and_composite : loads_to_composites) {
|
||||
Instruction* load = load_and_composite.first;
|
||||
Instruction* composite_construct = load_and_composite.second;
|
||||
|
||||
std::vector<Instruction*> users;
|
||||
context->get_def_use_mgr()->ForEachUse(
|
||||
load, [&users, composite_construct](Instruction* user, uint32_t index) {
|
||||
user->GetOperand(index).words[0] = composite_construct->result_id();
|
||||
users.push_back(user);
|
||||
});
|
||||
|
||||
for (Instruction* user : users)
|
||||
context->get_def_use_mgr()->AnalyzeInstUse(user);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the storage class of the instruction |var|.
|
||||
SpvStorageClass GetStorageClass(Instruction* var) {
|
||||
return static_cast<SpvStorageClass>(
|
||||
var->GetSingleWordInOperand(kOpVariableStorageClassInOperandIndex));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool InterfaceVariableScalarReplacement::HasExtraArrayness(
|
||||
Instruction& entry_point, Instruction* var) {
|
||||
SpvExecutionModel execution_model =
|
||||
static_cast<SpvExecutionModel>(entry_point.GetSingleWordInOperand(0));
|
||||
if (execution_model != SpvExecutionModelTessellationEvaluation &&
|
||||
execution_model != SpvExecutionModelTessellationControl) {
|
||||
return false;
|
||||
}
|
||||
if (!context()->get_decoration_mgr()->HasDecoration(var->result_id(),
|
||||
SpvDecorationPatch)) {
|
||||
if (execution_model == SpvExecutionModelTessellationControl) return true;
|
||||
return GetStorageClass(var) != SpvStorageClassOutput;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool InterfaceVariableScalarReplacement::
|
||||
CheckExtraArraynessConflictBetweenEntries(Instruction* interface_var,
|
||||
bool has_extra_arrayness) {
|
||||
if (has_extra_arrayness) {
|
||||
return !ReportErrorIfHasNoExtraArraynessForOtherEntry(interface_var);
|
||||
}
|
||||
return !ReportErrorIfHasExtraArraynessForOtherEntry(interface_var);
|
||||
}
|
||||
|
||||
bool InterfaceVariableScalarReplacement::GetVariableLocation(
|
||||
Instruction* var, uint32_t* location) {
|
||||
return !context()->get_decoration_mgr()->WhileEachDecoration(
|
||||
var->result_id(), SpvDecorationLocation,
|
||||
[location](const Instruction& inst) {
|
||||
*location =
|
||||
inst.GetSingleWordInOperand(kOpDecorateLiteralInOperandIndex);
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
bool InterfaceVariableScalarReplacement::GetVariableComponent(
|
||||
Instruction* var, uint32_t* component) {
|
||||
return !context()->get_decoration_mgr()->WhileEachDecoration(
|
||||
var->result_id(), SpvDecorationComponent,
|
||||
[component](const Instruction& inst) {
|
||||
*component =
|
||||
inst.GetSingleWordInOperand(kOpDecorateLiteralInOperandIndex);
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<Instruction*>
|
||||
InterfaceVariableScalarReplacement::CollectInterfaceVariables(
|
||||
Instruction& entry_point) {
|
||||
std::vector<Instruction*> interface_vars;
|
||||
for (uint32_t i = kOpEntryPointInOperandInterface;
|
||||
i < entry_point.NumInOperands(); ++i) {
|
||||
Instruction* interface_var = context()->get_def_use_mgr()->GetDef(
|
||||
entry_point.GetSingleWordInOperand(i));
|
||||
assert(interface_var->opcode() == SpvOpVariable);
|
||||
|
||||
SpvStorageClass storage_class = GetStorageClass(interface_var);
|
||||
if (storage_class != SpvStorageClassInput &&
|
||||
storage_class != SpvStorageClassOutput) {
|
||||
continue;
|
||||
}
|
||||
|
||||
interface_vars.push_back(interface_var);
|
||||
}
|
||||
return interface_vars;
|
||||
}
|
||||
|
||||
void InterfaceVariableScalarReplacement::KillInstructionAndUsers(
|
||||
Instruction* inst) {
|
||||
if (inst->opcode() == SpvOpEntryPoint) {
|
||||
return;
|
||||
}
|
||||
if (inst->opcode() != SpvOpAccessChain) {
|
||||
context()->KillInst(inst);
|
||||
return;
|
||||
}
|
||||
context()->get_def_use_mgr()->ForEachUser(
|
||||
inst, [this](Instruction* user) { KillInstructionAndUsers(user); });
|
||||
context()->KillInst(inst);
|
||||
}
|
||||
|
||||
void InterfaceVariableScalarReplacement::KillInstructionsAndUsers(
|
||||
const std::vector<Instruction*>& insts) {
|
||||
for (Instruction* inst : insts) {
|
||||
KillInstructionAndUsers(inst);
|
||||
}
|
||||
}
|
||||
|
||||
void InterfaceVariableScalarReplacement::KillLocationAndComponentDecorations(
|
||||
uint32_t var_id) {
|
||||
context()->get_decoration_mgr()->RemoveDecorationsFrom(
|
||||
var_id, [](const Instruction& inst) {
|
||||
uint32_t decoration =
|
||||
inst.GetSingleWordInOperand(kOpDecorateDecorationInOperandIndex);
|
||||
return decoration == SpvDecorationLocation ||
|
||||
decoration == SpvDecorationComponent;
|
||||
});
|
||||
}
|
||||
|
||||
bool InterfaceVariableScalarReplacement::ReplaceInterfaceVariableWithScalars(
|
||||
Instruction* interface_var, Instruction* interface_var_type,
|
||||
uint32_t location, uint32_t component, uint32_t extra_array_length) {
|
||||
NestedCompositeComponents scalar_interface_vars =
|
||||
CreateScalarInterfaceVarsForReplacement(interface_var_type,
|
||||
GetStorageClass(interface_var),
|
||||
extra_array_length);
|
||||
|
||||
AddLocationAndComponentDecorations(scalar_interface_vars, &location,
|
||||
component);
|
||||
KillLocationAndComponentDecorations(interface_var->result_id());
|
||||
|
||||
if (!ReplaceInterfaceVarWith(interface_var, extra_array_length,
|
||||
scalar_interface_vars)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
context()->KillInst(interface_var);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarWith(
|
||||
Instruction* interface_var, uint32_t extra_array_length,
|
||||
const NestedCompositeComponents& scalar_interface_vars) {
|
||||
std::vector<Instruction*> users;
|
||||
context()->get_def_use_mgr()->ForEachUser(
|
||||
interface_var, [&users](Instruction* user) { users.push_back(user); });
|
||||
|
||||
std::vector<uint32_t> interface_var_component_indices;
|
||||
std::unordered_map<Instruction*, Instruction*> loads_to_composites;
|
||||
std::unordered_map<Instruction*, Instruction*>
|
||||
loads_for_access_chain_to_composites;
|
||||
if (extra_array_length != 0) {
|
||||
// Note that the extra arrayness is the first dimension of the array
|
||||
// interface variable.
|
||||
for (uint32_t index = 0; index < extra_array_length; ++index) {
|
||||
std::unordered_map<Instruction*, Instruction*> loads_to_component_values;
|
||||
if (!ReplaceComponentsOfInterfaceVarWith(
|
||||
interface_var, users, scalar_interface_vars,
|
||||
interface_var_component_indices, &index,
|
||||
&loads_to_component_values,
|
||||
&loads_for_access_chain_to_composites)) {
|
||||
return false;
|
||||
}
|
||||
AddComponentsToCompositesForLoads(loads_to_component_values,
|
||||
&loads_to_composites, 0);
|
||||
}
|
||||
} else if (!ReplaceComponentsOfInterfaceVarWith(
|
||||
interface_var, users, scalar_interface_vars,
|
||||
interface_var_component_indices, nullptr, &loads_to_composites,
|
||||
&loads_for_access_chain_to_composites)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ReplaceLoadWithCompositeConstruct(context(), loads_to_composites);
|
||||
ReplaceLoadWithCompositeConstruct(context(),
|
||||
loads_for_access_chain_to_composites);
|
||||
|
||||
KillInstructionsAndUsers(users);
|
||||
return true;
|
||||
}
|
||||
|
||||
void InterfaceVariableScalarReplacement::AddLocationAndComponentDecorations(
|
||||
const NestedCompositeComponents& vars, uint32_t* location,
|
||||
uint32_t component) {
|
||||
if (!vars.HasMultipleComponents()) {
|
||||
uint32_t var_id = vars.GetComponentVariable()->result_id();
|
||||
CreateDecoration(context()->get_decoration_mgr(), var_id,
|
||||
SpvDecorationLocation, *location);
|
||||
CreateDecoration(context()->get_decoration_mgr(), var_id,
|
||||
SpvDecorationComponent, component);
|
||||
++(*location);
|
||||
return;
|
||||
}
|
||||
for (const auto& var : vars.GetComponents()) {
|
||||
AddLocationAndComponentDecorations(var, location, component);
|
||||
}
|
||||
}
|
||||
|
||||
bool InterfaceVariableScalarReplacement::ReplaceComponentsOfInterfaceVarWith(
|
||||
Instruction* interface_var,
|
||||
const std::vector<Instruction*>& interface_var_users,
|
||||
const NestedCompositeComponents& scalar_interface_vars,
|
||||
std::vector<uint32_t>& interface_var_component_indices,
|
||||
const uint32_t* extra_array_index,
|
||||
std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
|
||||
std::unordered_map<Instruction*, Instruction*>*
|
||||
loads_for_access_chain_to_composites) {
|
||||
if (!scalar_interface_vars.HasMultipleComponents()) {
|
||||
for (Instruction* interface_var_user : interface_var_users) {
|
||||
if (!ReplaceComponentOfInterfaceVarWith(
|
||||
interface_var, interface_var_user,
|
||||
scalar_interface_vars.GetComponentVariable(),
|
||||
interface_var_component_indices, extra_array_index,
|
||||
loads_to_composites, loads_for_access_chain_to_composites)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return ReplaceMultipleComponentsOfInterfaceVarWith(
|
||||
interface_var, interface_var_users, scalar_interface_vars.GetComponents(),
|
||||
interface_var_component_indices, extra_array_index, loads_to_composites,
|
||||
loads_for_access_chain_to_composites);
|
||||
}
|
||||
|
||||
bool InterfaceVariableScalarReplacement::
|
||||
ReplaceMultipleComponentsOfInterfaceVarWith(
|
||||
Instruction* interface_var,
|
||||
const std::vector<Instruction*>& interface_var_users,
|
||||
const std::vector<NestedCompositeComponents>& components,
|
||||
std::vector<uint32_t>& interface_var_component_indices,
|
||||
const uint32_t* extra_array_index,
|
||||
std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
|
||||
std::unordered_map<Instruction*, Instruction*>*
|
||||
loads_for_access_chain_to_composites) {
|
||||
for (uint32_t i = 0; i < components.size(); ++i) {
|
||||
interface_var_component_indices.push_back(i);
|
||||
std::unordered_map<Instruction*, Instruction*> loads_to_component_values;
|
||||
std::unordered_map<Instruction*, Instruction*>
|
||||
loads_for_access_chain_to_component_values;
|
||||
if (!ReplaceComponentsOfInterfaceVarWith(
|
||||
interface_var, interface_var_users, components[i],
|
||||
interface_var_component_indices, extra_array_index,
|
||||
&loads_to_component_values,
|
||||
&loads_for_access_chain_to_component_values)) {
|
||||
return false;
|
||||
}
|
||||
interface_var_component_indices.pop_back();
|
||||
|
||||
uint32_t depth_to_component =
|
||||
static_cast<uint32_t>(interface_var_component_indices.size());
|
||||
AddComponentsToCompositesForLoads(
|
||||
loads_for_access_chain_to_component_values,
|
||||
loads_for_access_chain_to_composites, depth_to_component);
|
||||
if (extra_array_index) ++depth_to_component;
|
||||
AddComponentsToCompositesForLoads(loads_to_component_values,
|
||||
loads_to_composites, depth_to_component);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool InterfaceVariableScalarReplacement::ReplaceComponentOfInterfaceVarWith(
|
||||
Instruction* interface_var, Instruction* interface_var_user,
|
||||
Instruction* scalar_var,
|
||||
const std::vector<uint32_t>& interface_var_component_indices,
|
||||
const uint32_t* extra_array_index,
|
||||
std::unordered_map<Instruction*, Instruction*>* loads_to_component_values,
|
||||
std::unordered_map<Instruction*, Instruction*>*
|
||||
loads_for_access_chain_to_component_values) {
|
||||
SpvOp opcode = interface_var_user->opcode();
|
||||
if (opcode == SpvOpStore) {
|
||||
uint32_t value_id = interface_var_user->GetSingleWordInOperand(1);
|
||||
StoreComponentOfValueToScalarVar(value_id, interface_var_component_indices,
|
||||
scalar_var, extra_array_index,
|
||||
interface_var_user);
|
||||
return true;
|
||||
}
|
||||
if (opcode == SpvOpLoad) {
|
||||
Instruction* scalar_load =
|
||||
LoadScalarVar(scalar_var, extra_array_index, interface_var_user);
|
||||
loads_to_component_values->insert({interface_var_user, scalar_load});
|
||||
return true;
|
||||
}
|
||||
|
||||
// Copy OpName and annotation instructions only once. Therefore, we create
|
||||
// them only for the first element of the extra array.
|
||||
if (extra_array_index && *extra_array_index != 0) return true;
|
||||
|
||||
if (opcode == SpvOpDecorateId || opcode == SpvOpDecorateString ||
|
||||
opcode == SpvOpDecorate) {
|
||||
CloneAnnotationForVariable(interface_var_user, scalar_var->result_id());
|
||||
return true;
|
||||
}
|
||||
|
||||
if (opcode == SpvOpName) {
|
||||
std::unique_ptr<Instruction> new_inst(interface_var_user->Clone(context()));
|
||||
new_inst->SetInOperand(0, {scalar_var->result_id()});
|
||||
context()->AddDebug2Inst(std::move(new_inst));
|
||||
return true;
|
||||
}
|
||||
|
||||
if (opcode == SpvOpEntryPoint) {
|
||||
return ReplaceInterfaceVarInEntryPoint(interface_var, interface_var_user,
|
||||
scalar_var->result_id());
|
||||
}
|
||||
|
||||
if (opcode == SpvOpAccessChain) {
|
||||
ReplaceAccessChainWith(interface_var_user, interface_var_component_indices,
|
||||
scalar_var,
|
||||
loads_for_access_chain_to_component_values);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string message("Unhandled instruction");
|
||||
message += "\n " + interface_var_user->PrettyPrint(
|
||||
SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
|
||||
message +=
|
||||
"\nfor interface variable scalar replacement\n " +
|
||||
interface_var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
|
||||
context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
void InterfaceVariableScalarReplacement::UseBaseAccessChainForAccessChain(
|
||||
Instruction* access_chain, Instruction* base_access_chain) {
|
||||
assert(base_access_chain->opcode() == SpvOpAccessChain &&
|
||||
access_chain->opcode() == SpvOpAccessChain &&
|
||||
access_chain->GetSingleWordInOperand(0) ==
|
||||
base_access_chain->result_id());
|
||||
Instruction::OperandList new_operands;
|
||||
for (uint32_t i = 0; i < base_access_chain->NumInOperands(); ++i) {
|
||||
new_operands.emplace_back(base_access_chain->GetInOperand(i));
|
||||
}
|
||||
for (uint32_t i = 1; i < access_chain->NumInOperands(); ++i) {
|
||||
new_operands.emplace_back(access_chain->GetInOperand(i));
|
||||
}
|
||||
access_chain->SetInOperands(std::move(new_operands));
|
||||
}
|
||||
|
||||
Instruction* InterfaceVariableScalarReplacement::CreateAccessChainToVar(
|
||||
uint32_t var_type_id, Instruction* var,
|
||||
const std::vector<uint32_t>& index_ids, Instruction* insert_before,
|
||||
uint32_t* component_type_id) {
|
||||
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
|
||||
*component_type_id = GetComponentTypeOfArrayMatrix(
|
||||
def_use_mgr, var_type_id, static_cast<uint32_t>(index_ids.size()));
|
||||
|
||||
uint32_t ptr_type_id =
|
||||
GetPointerType(*component_type_id, GetStorageClass(var));
|
||||
|
||||
std::unique_ptr<Instruction> new_access_chain(
|
||||
new Instruction(context(), SpvOpAccessChain, ptr_type_id, TakeNextId(),
|
||||
std::initializer_list<Operand>{
|
||||
{SPV_OPERAND_TYPE_ID, {var->result_id()}}}));
|
||||
for (uint32_t index_id : index_ids) {
|
||||
new_access_chain->AddOperand({SPV_OPERAND_TYPE_ID, {index_id}});
|
||||
}
|
||||
|
||||
Instruction* inst = new_access_chain.get();
|
||||
def_use_mgr->AnalyzeInstDefUse(inst);
|
||||
insert_before->InsertBefore(std::move(new_access_chain));
|
||||
return inst;
|
||||
}
|
||||
|
||||
Instruction* InterfaceVariableScalarReplacement::CreateAccessChainWithIndex(
|
||||
uint32_t component_type_id, Instruction* var, uint32_t index,
|
||||
Instruction* insert_before) {
|
||||
uint32_t ptr_type_id =
|
||||
GetPointerType(component_type_id, GetStorageClass(var));
|
||||
uint32_t index_id = context()->get_constant_mgr()->GetUIntConst(index);
|
||||
std::unique_ptr<Instruction> new_access_chain(
|
||||
new Instruction(context(), SpvOpAccessChain, ptr_type_id, TakeNextId(),
|
||||
std::initializer_list<Operand>{
|
||||
{SPV_OPERAND_TYPE_ID, {var->result_id()}},
|
||||
{SPV_OPERAND_TYPE_ID, {index_id}},
|
||||
}));
|
||||
Instruction* inst = new_access_chain.get();
|
||||
context()->get_def_use_mgr()->AnalyzeInstDefUse(inst);
|
||||
insert_before->InsertBefore(std::move(new_access_chain));
|
||||
return inst;
|
||||
}
|
||||
|
||||
void InterfaceVariableScalarReplacement::ReplaceAccessChainWith(
|
||||
Instruction* access_chain,
|
||||
const std::vector<uint32_t>& interface_var_component_indices,
|
||||
Instruction* scalar_var,
|
||||
std::unordered_map<Instruction*, Instruction*>* loads_to_component_values) {
|
||||
std::vector<uint32_t> indexes;
|
||||
for (uint32_t i = 1; i < access_chain->NumInOperands(); ++i) {
|
||||
indexes.push_back(access_chain->GetSingleWordInOperand(i));
|
||||
}
|
||||
|
||||
// Note that we have a strong assumption that |access_chain| has only a single
|
||||
// index that is for the extra arrayness.
|
||||
context()->get_def_use_mgr()->ForEachUser(
|
||||
access_chain,
|
||||
[this, access_chain, &indexes, &interface_var_component_indices,
|
||||
scalar_var, loads_to_component_values](Instruction* user) {
|
||||
switch (user->opcode()) {
|
||||
case SpvOpAccessChain: {
|
||||
UseBaseAccessChainForAccessChain(user, access_chain);
|
||||
ReplaceAccessChainWith(user, interface_var_component_indices,
|
||||
scalar_var, loads_to_component_values);
|
||||
return;
|
||||
}
|
||||
case SpvOpStore: {
|
||||
uint32_t value_id = user->GetSingleWordInOperand(1);
|
||||
StoreComponentOfValueToAccessChainToScalarVar(
|
||||
value_id, interface_var_component_indices, scalar_var, indexes,
|
||||
user);
|
||||
return;
|
||||
}
|
||||
case SpvOpLoad: {
|
||||
Instruction* value =
|
||||
LoadAccessChainToVar(scalar_var, indexes, user);
|
||||
loads_to_component_values->insert({user, value});
|
||||
return;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void InterfaceVariableScalarReplacement::CloneAnnotationForVariable(
|
||||
Instruction* annotation_inst, uint32_t var_id) {
|
||||
assert(annotation_inst->opcode() == SpvOpDecorate ||
|
||||
annotation_inst->opcode() == SpvOpDecorateId ||
|
||||
annotation_inst->opcode() == SpvOpDecorateString);
|
||||
std::unique_ptr<Instruction> new_inst(annotation_inst->Clone(context()));
|
||||
new_inst->SetInOperand(0, {var_id});
|
||||
context()->AddAnnotationInst(std::move(new_inst));
|
||||
}
|
||||
|
||||
bool InterfaceVariableScalarReplacement::ReplaceInterfaceVarInEntryPoint(
|
||||
Instruction* interface_var, Instruction* entry_point,
|
||||
uint32_t scalar_var_id) {
|
||||
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
|
||||
uint32_t interface_var_id = interface_var->result_id();
|
||||
if (interface_vars_removed_from_entry_point_operands_.find(
|
||||
interface_var_id) !=
|
||||
interface_vars_removed_from_entry_point_operands_.end()) {
|
||||
entry_point->AddOperand({SPV_OPERAND_TYPE_ID, {scalar_var_id}});
|
||||
def_use_mgr->AnalyzeInstUse(entry_point);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool success = !entry_point->WhileEachInId(
|
||||
[&interface_var_id, &scalar_var_id](uint32_t* id) {
|
||||
if (*id == interface_var_id) {
|
||||
*id = scalar_var_id;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
if (!success) {
|
||||
std::string message(
|
||||
"interface variable is not an operand of the entry point");
|
||||
message += "\n " + interface_var->PrettyPrint(
|
||||
SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
|
||||
message += "\n " + entry_point->PrettyPrint(
|
||||
SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
|
||||
context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
def_use_mgr->AnalyzeInstUse(entry_point);
|
||||
interface_vars_removed_from_entry_point_operands_.insert(interface_var_id);
|
||||
return true;
|
||||
}
|
||||
|
||||
uint32_t InterfaceVariableScalarReplacement::GetPointeeTypeIdOfVar(
|
||||
Instruction* var) {
|
||||
assert(var->opcode() == SpvOpVariable);
|
||||
|
||||
uint32_t ptr_type_id = var->type_id();
|
||||
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
|
||||
Instruction* ptr_type_inst = def_use_mgr->GetDef(ptr_type_id);
|
||||
|
||||
assert(ptr_type_inst->opcode() == SpvOpTypePointer &&
|
||||
"Variable must have a pointer type.");
|
||||
return ptr_type_inst->GetSingleWordInOperand(kOpTypePtrTypeInOperandIndex);
|
||||
}
|
||||
|
||||
void InterfaceVariableScalarReplacement::StoreComponentOfValueToScalarVar(
|
||||
uint32_t value_id, const std::vector<uint32_t>& component_indices,
|
||||
Instruction* scalar_var, const uint32_t* extra_array_index,
|
||||
Instruction* insert_before) {
|
||||
uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var);
|
||||
Instruction* ptr = scalar_var;
|
||||
if (extra_array_index) {
|
||||
auto* ty_mgr = context()->get_type_mgr();
|
||||
analysis::Array* array_type = ty_mgr->GetType(component_type_id)->AsArray();
|
||||
assert(array_type != nullptr);
|
||||
component_type_id = ty_mgr->GetTypeInstruction(array_type->element_type());
|
||||
ptr = CreateAccessChainWithIndex(component_type_id, scalar_var,
|
||||
*extra_array_index, insert_before);
|
||||
}
|
||||
|
||||
StoreComponentOfValueTo(component_type_id, value_id, component_indices, ptr,
|
||||
extra_array_index, insert_before);
|
||||
}
|
||||
|
||||
Instruction* InterfaceVariableScalarReplacement::LoadScalarVar(
|
||||
Instruction* scalar_var, const uint32_t* extra_array_index,
|
||||
Instruction* insert_before) {
|
||||
uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var);
|
||||
Instruction* ptr = scalar_var;
|
||||
if (extra_array_index) {
|
||||
auto* ty_mgr = context()->get_type_mgr();
|
||||
analysis::Array* array_type = ty_mgr->GetType(component_type_id)->AsArray();
|
||||
assert(array_type != nullptr);
|
||||
component_type_id = ty_mgr->GetTypeInstruction(array_type->element_type());
|
||||
ptr = CreateAccessChainWithIndex(component_type_id, scalar_var,
|
||||
*extra_array_index, insert_before);
|
||||
}
|
||||
|
||||
return CreateLoad(component_type_id, ptr, insert_before);
|
||||
}
|
||||
|
||||
Instruction* InterfaceVariableScalarReplacement::CreateLoad(
|
||||
uint32_t type_id, Instruction* ptr, Instruction* insert_before) {
|
||||
std::unique_ptr<Instruction> load(
|
||||
new Instruction(context(), SpvOpLoad, type_id, TakeNextId(),
|
||||
std::initializer_list<Operand>{
|
||||
{SPV_OPERAND_TYPE_ID, {ptr->result_id()}}}));
|
||||
Instruction* load_inst = load.get();
|
||||
context()->get_def_use_mgr()->AnalyzeInstDefUse(load_inst);
|
||||
insert_before->InsertBefore(std::move(load));
|
||||
return load_inst;
|
||||
}
|
||||
|
||||
void InterfaceVariableScalarReplacement::StoreComponentOfValueTo(
|
||||
uint32_t component_type_id, uint32_t value_id,
|
||||
const std::vector<uint32_t>& component_indices, Instruction* ptr,
|
||||
const uint32_t* extra_array_index, Instruction* insert_before) {
|
||||
std::unique_ptr<Instruction> composite_extract(CreateCompositeExtract(
|
||||
component_type_id, value_id, component_indices, extra_array_index));
|
||||
|
||||
std::unique_ptr<Instruction> new_store(
|
||||
new Instruction(context(), SpvOpStore));
|
||||
new_store->AddOperand({SPV_OPERAND_TYPE_ID, {ptr->result_id()}});
|
||||
new_store->AddOperand(
|
||||
{SPV_OPERAND_TYPE_ID, {composite_extract->result_id()}});
|
||||
|
||||
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
|
||||
def_use_mgr->AnalyzeInstDefUse(composite_extract.get());
|
||||
def_use_mgr->AnalyzeInstDefUse(new_store.get());
|
||||
|
||||
insert_before->InsertBefore(std::move(composite_extract));
|
||||
insert_before->InsertBefore(std::move(new_store));
|
||||
}
|
||||
|
||||
Instruction* InterfaceVariableScalarReplacement::CreateCompositeExtract(
|
||||
uint32_t type_id, uint32_t composite_id,
|
||||
const std::vector<uint32_t>& indexes, const uint32_t* extra_first_index) {
|
||||
uint32_t component_id = TakeNextId();
|
||||
Instruction* composite_extract = new Instruction(
|
||||
context(), SpvOpCompositeExtract, type_id, component_id,
|
||||
std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {composite_id}}});
|
||||
if (extra_first_index) {
|
||||
composite_extract->AddOperand(
|
||||
{SPV_OPERAND_TYPE_LITERAL_INTEGER, {*extra_first_index}});
|
||||
}
|
||||
for (uint32_t index : indexes) {
|
||||
composite_extract->AddOperand({SPV_OPERAND_TYPE_LITERAL_INTEGER, {index}});
|
||||
}
|
||||
return composite_extract;
|
||||
}
|
||||
|
||||
void InterfaceVariableScalarReplacement::
|
||||
StoreComponentOfValueToAccessChainToScalarVar(
|
||||
uint32_t value_id, const std::vector<uint32_t>& component_indices,
|
||||
Instruction* scalar_var,
|
||||
const std::vector<uint32_t>& access_chain_indices,
|
||||
Instruction* insert_before) {
|
||||
uint32_t component_type_id = GetPointeeTypeIdOfVar(scalar_var);
|
||||
Instruction* ptr = scalar_var;
|
||||
if (!access_chain_indices.empty()) {
|
||||
ptr = CreateAccessChainToVar(component_type_id, scalar_var,
|
||||
access_chain_indices, insert_before,
|
||||
&component_type_id);
|
||||
}
|
||||
|
||||
StoreComponentOfValueTo(component_type_id, value_id, component_indices, ptr,
|
||||
nullptr, insert_before);
|
||||
}
|
||||
|
||||
Instruction* InterfaceVariableScalarReplacement::LoadAccessChainToVar(
|
||||
Instruction* var, const std::vector<uint32_t>& indexes,
|
||||
Instruction* insert_before) {
|
||||
uint32_t component_type_id = GetPointeeTypeIdOfVar(var);
|
||||
Instruction* ptr = var;
|
||||
if (!indexes.empty()) {
|
||||
ptr = CreateAccessChainToVar(component_type_id, var, indexes, insert_before,
|
||||
&component_type_id);
|
||||
}
|
||||
|
||||
return CreateLoad(component_type_id, ptr, insert_before);
|
||||
}
|
||||
|
||||
Instruction*
|
||||
InterfaceVariableScalarReplacement::CreateCompositeConstructForComponentOfLoad(
|
||||
Instruction* load, uint32_t depth_to_component) {
|
||||
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
|
||||
uint32_t type_id = load->type_id();
|
||||
if (depth_to_component != 0) {
|
||||
type_id = GetComponentTypeOfArrayMatrix(def_use_mgr, load->type_id(),
|
||||
depth_to_component);
|
||||
}
|
||||
uint32_t new_id = context()->TakeNextId();
|
||||
std::unique_ptr<Instruction> new_composite_construct(
|
||||
new Instruction(context(), SpvOpCompositeConstruct, type_id, new_id, {}));
|
||||
Instruction* composite_construct = new_composite_construct.get();
|
||||
def_use_mgr->AnalyzeInstDefUse(composite_construct);
|
||||
|
||||
// Insert |new_composite_construct| after |load|. When there are multiple
|
||||
// recursive composite construct instructions for a load, we have to place the
|
||||
// composite construct with a lower depth later because it constructs the
|
||||
// composite that contains other composites with lower depths.
|
||||
auto* insert_before = load->NextNode();
|
||||
while (true) {
|
||||
auto itr =
|
||||
composite_ids_to_component_depths.find(insert_before->result_id());
|
||||
if (itr == composite_ids_to_component_depths.end()) break;
|
||||
if (itr->second <= depth_to_component) break;
|
||||
insert_before = insert_before->NextNode();
|
||||
}
|
||||
insert_before->InsertBefore(std::move(new_composite_construct));
|
||||
composite_ids_to_component_depths.insert({new_id, depth_to_component});
|
||||
return composite_construct;
|
||||
}
|
||||
|
||||
void InterfaceVariableScalarReplacement::AddComponentsToCompositesForLoads(
|
||||
const std::unordered_map<Instruction*, Instruction*>&
|
||||
loads_to_component_values,
|
||||
std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
|
||||
uint32_t depth_to_component) {
|
||||
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
|
||||
for (auto& load_and_component_vale : loads_to_component_values) {
|
||||
Instruction* load = load_and_component_vale.first;
|
||||
Instruction* component_value = load_and_component_vale.second;
|
||||
Instruction* composite_construct = nullptr;
|
||||
auto itr = loads_to_composites->find(load);
|
||||
if (itr == loads_to_composites->end()) {
|
||||
composite_construct =
|
||||
CreateCompositeConstructForComponentOfLoad(load, depth_to_component);
|
||||
loads_to_composites->insert({load, composite_construct});
|
||||
} else {
|
||||
composite_construct = itr->second;
|
||||
}
|
||||
composite_construct->AddOperand(
|
||||
{SPV_OPERAND_TYPE_ID, {component_value->result_id()}});
|
||||
def_use_mgr->AnalyzeInstDefUse(composite_construct);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t InterfaceVariableScalarReplacement::GetArrayType(
|
||||
uint32_t elem_type_id, uint32_t array_length) {
|
||||
analysis::Type* elem_type = context()->get_type_mgr()->GetType(elem_type_id);
|
||||
uint32_t array_length_id =
|
||||
context()->get_constant_mgr()->GetUIntConst(array_length);
|
||||
analysis::Array array_type(
|
||||
elem_type,
|
||||
analysis::Array::LengthInfo{array_length_id, {0, array_length}});
|
||||
return context()->get_type_mgr()->GetTypeInstruction(&array_type);
|
||||
}
|
||||
|
||||
uint32_t InterfaceVariableScalarReplacement::GetPointerType(
|
||||
uint32_t type_id, SpvStorageClass storage_class) {
|
||||
analysis::Type* type = context()->get_type_mgr()->GetType(type_id);
|
||||
analysis::Pointer ptr_type(type, storage_class);
|
||||
return context()->get_type_mgr()->GetTypeInstruction(&ptr_type);
|
||||
}
|
||||
|
||||
InterfaceVariableScalarReplacement::NestedCompositeComponents
|
||||
InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForArray(
|
||||
Instruction* interface_var_type, SpvStorageClass storage_class,
|
||||
uint32_t extra_array_length) {
|
||||
assert(interface_var_type->opcode() == SpvOpTypeArray);
|
||||
|
||||
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
|
||||
uint32_t array_length = GetArrayLength(def_use_mgr, interface_var_type);
|
||||
Instruction* elem_type = GetArrayElementType(def_use_mgr, interface_var_type);
|
||||
|
||||
NestedCompositeComponents scalar_vars;
|
||||
while (array_length > 0) {
|
||||
NestedCompositeComponents scalar_vars_for_element =
|
||||
CreateScalarInterfaceVarsForReplacement(elem_type, storage_class,
|
||||
extra_array_length);
|
||||
scalar_vars.AddComponent(scalar_vars_for_element);
|
||||
--array_length;
|
||||
}
|
||||
return scalar_vars;
|
||||
}
|
||||
|
||||
InterfaceVariableScalarReplacement::NestedCompositeComponents
|
||||
InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForMatrix(
|
||||
Instruction* interface_var_type, SpvStorageClass storage_class,
|
||||
uint32_t extra_array_length) {
|
||||
assert(interface_var_type->opcode() == SpvOpTypeMatrix);
|
||||
|
||||
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
|
||||
uint32_t column_count = interface_var_type->GetSingleWordInOperand(
|
||||
kOpTypeMatrixColCountInOperandIndex);
|
||||
Instruction* column_type =
|
||||
GetMatrixColumnType(def_use_mgr, interface_var_type);
|
||||
|
||||
NestedCompositeComponents scalar_vars;
|
||||
while (column_count > 0) {
|
||||
NestedCompositeComponents scalar_vars_for_column =
|
||||
CreateScalarInterfaceVarsForReplacement(column_type, storage_class,
|
||||
extra_array_length);
|
||||
scalar_vars.AddComponent(scalar_vars_for_column);
|
||||
--column_count;
|
||||
}
|
||||
return scalar_vars;
|
||||
}
|
||||
|
||||
InterfaceVariableScalarReplacement::NestedCompositeComponents
|
||||
InterfaceVariableScalarReplacement::CreateScalarInterfaceVarsForReplacement(
|
||||
Instruction* interface_var_type, SpvStorageClass storage_class,
|
||||
uint32_t extra_array_length) {
|
||||
// Handle array case.
|
||||
if (interface_var_type->opcode() == SpvOpTypeArray) {
|
||||
return CreateScalarInterfaceVarsForArray(interface_var_type, storage_class,
|
||||
extra_array_length);
|
||||
}
|
||||
|
||||
// Handle matrix case.
|
||||
if (interface_var_type->opcode() == SpvOpTypeMatrix) {
|
||||
return CreateScalarInterfaceVarsForMatrix(interface_var_type, storage_class,
|
||||
extra_array_length);
|
||||
}
|
||||
|
||||
// Handle scalar or vector case.
|
||||
NestedCompositeComponents scalar_var;
|
||||
uint32_t type_id = interface_var_type->result_id();
|
||||
if (extra_array_length != 0) {
|
||||
type_id = GetArrayType(type_id, extra_array_length);
|
||||
}
|
||||
uint32_t ptr_type_id =
|
||||
context()->get_type_mgr()->FindPointerToType(type_id, storage_class);
|
||||
uint32_t id = TakeNextId();
|
||||
std::unique_ptr<Instruction> variable(
|
||||
new Instruction(context(), SpvOpVariable, ptr_type_id, id,
|
||||
std::initializer_list<Operand>{
|
||||
{SPV_OPERAND_TYPE_STORAGE_CLASS,
|
||||
{static_cast<uint32_t>(storage_class)}}}));
|
||||
scalar_var.SetSingleComponentVariable(variable.get());
|
||||
context()->AddGlobalValue(std::move(variable));
|
||||
return scalar_var;
|
||||
}
|
||||
|
||||
Instruction* InterfaceVariableScalarReplacement::GetTypeOfVariable(
|
||||
Instruction* var) {
|
||||
uint32_t pointee_type_id = GetPointeeTypeIdOfVar(var);
|
||||
analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
|
||||
return def_use_mgr->GetDef(pointee_type_id);
|
||||
}
|
||||
|
||||
Pass::Status InterfaceVariableScalarReplacement::Process() {
|
||||
Pass::Status status = Status::SuccessWithoutChange;
|
||||
for (Instruction& entry_point : get_module()->entry_points()) {
|
||||
status =
|
||||
CombineStatus(status, ReplaceInterfaceVarsWithScalars(entry_point));
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
bool InterfaceVariableScalarReplacement::
|
||||
ReportErrorIfHasExtraArraynessForOtherEntry(Instruction* var) {
|
||||
if (vars_with_extra_arrayness.find(var) == vars_with_extra_arrayness.end())
|
||||
return false;
|
||||
|
||||
std::string message(
|
||||
"A variable is arrayed for an entry point but it is not "
|
||||
"arrayed for another entry point");
|
||||
message +=
|
||||
"\n " + var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
|
||||
context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
|
||||
return true;
|
||||
}
|
||||
|
||||
bool InterfaceVariableScalarReplacement::
|
||||
ReportErrorIfHasNoExtraArraynessForOtherEntry(Instruction* var) {
|
||||
if (vars_without_extra_arrayness.find(var) ==
|
||||
vars_without_extra_arrayness.end())
|
||||
return false;
|
||||
|
||||
std::string message(
|
||||
"A variable is not arrayed for an entry point but it is "
|
||||
"arrayed for another entry point");
|
||||
message +=
|
||||
"\n " + var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
|
||||
context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str());
|
||||
return true;
|
||||
}
|
||||
|
||||
Pass::Status
|
||||
InterfaceVariableScalarReplacement::ReplaceInterfaceVarsWithScalars(
|
||||
Instruction& entry_point) {
|
||||
std::vector<Instruction*> interface_vars =
|
||||
CollectInterfaceVariables(entry_point);
|
||||
|
||||
Pass::Status status = Status::SuccessWithoutChange;
|
||||
for (Instruction* interface_var : interface_vars) {
|
||||
uint32_t location, component;
|
||||
if (!GetVariableLocation(interface_var, &location)) continue;
|
||||
if (!GetVariableComponent(interface_var, &component)) component = 0;
|
||||
|
||||
Instruction* interface_var_type = GetTypeOfVariable(interface_var);
|
||||
uint32_t extra_array_length = 0;
|
||||
if (HasExtraArrayness(entry_point, interface_var)) {
|
||||
extra_array_length =
|
||||
GetArrayLength(context()->get_def_use_mgr(), interface_var_type);
|
||||
interface_var_type =
|
||||
GetArrayElementType(context()->get_def_use_mgr(), interface_var_type);
|
||||
vars_with_extra_arrayness.insert(interface_var);
|
||||
} else {
|
||||
vars_without_extra_arrayness.insert(interface_var);
|
||||
}
|
||||
|
||||
if (!CheckExtraArraynessConflictBetweenEntries(interface_var,
|
||||
extra_array_length != 0)) {
|
||||
return Pass::Status::Failure;
|
||||
}
|
||||
|
||||
if (interface_var_type->opcode() != SpvOpTypeArray &&
|
||||
interface_var_type->opcode() != SpvOpTypeMatrix) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!ReplaceInterfaceVariableWithScalars(interface_var, interface_var_type,
|
||||
location, component,
|
||||
extra_array_length)) {
|
||||
return Pass::Status::Failure;
|
||||
}
|
||||
status = Pass::Status::SuccessWithChange;
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace spvtools
|
401
source/opt/interface_var_sroa.h
Normal file
401
source/opt/interface_var_sroa.h
Normal file
@ -0,0 +1,401 @@
|
||||
// Copyright (c) 2022 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SOURCE_OPT_INTERFACE_VAR_SROA_H_
|
||||
#define SOURCE_OPT_INTERFACE_VAR_SROA_H_
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
#include "source/opt/pass.h"
|
||||
|
||||
namespace spvtools {
|
||||
namespace opt {
|
||||
|
||||
// See optimizer.hpp for documentation.
|
||||
//
|
||||
// Note that the current implementation of this pass covers only store, load,
|
||||
// access chain instructions for the interface variables. Supporting other types
|
||||
// of instructions is a future work.
|
||||
class InterfaceVariableScalarReplacement : public Pass {
|
||||
public:
|
||||
InterfaceVariableScalarReplacement() {}
|
||||
|
||||
const char* name() const override {
|
||||
return "interface-variable-scalar-replacement";
|
||||
}
|
||||
Status Process() override;
|
||||
|
||||
IRContext::Analysis GetPreservedAnalyses() override {
|
||||
return IRContext::kAnalysisDecorations | IRContext::kAnalysisDefUse |
|
||||
IRContext::kAnalysisConstants | IRContext::kAnalysisTypes;
|
||||
}
|
||||
|
||||
private:
|
||||
// A struct containing components of a composite variable. If the composite
|
||||
// consists of multiple or recursive components, |component_variable| is
|
||||
// nullptr and |nested_composite_components| keeps the components. If it has a
|
||||
// single component, |nested_composite_components| is empty and
|
||||
// |component_variable| is the component. Note that each element of
|
||||
// |nested_composite_components| has the NestedCompositeComponents struct as
|
||||
// its type that can recursively keep the components.
|
||||
struct NestedCompositeComponents {
|
||||
NestedCompositeComponents() : component_variable(nullptr) {}
|
||||
|
||||
bool HasMultipleComponents() const {
|
||||
return !nested_composite_components.empty();
|
||||
}
|
||||
|
||||
const std::vector<NestedCompositeComponents>& GetComponents() const {
|
||||
return nested_composite_components;
|
||||
}
|
||||
|
||||
void AddComponent(const NestedCompositeComponents& component) {
|
||||
nested_composite_components.push_back(component);
|
||||
}
|
||||
|
||||
Instruction* GetComponentVariable() const { return component_variable; }
|
||||
|
||||
void SetSingleComponentVariable(Instruction* var) {
|
||||
component_variable = var;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<NestedCompositeComponents> nested_composite_components;
|
||||
Instruction* component_variable;
|
||||
};
|
||||
|
||||
// Collects all interface variables used by the |entry_point|.
|
||||
std::vector<Instruction*> CollectInterfaceVariables(Instruction& entry_point);
|
||||
|
||||
// Returns whether |var| has the extra arrayness for the entry point
|
||||
// |entry_point| or not.
|
||||
bool HasExtraArrayness(Instruction& entry_point, Instruction* var);
|
||||
|
||||
// Finds a Location BuiltIn decoration of |var| and returns it via
|
||||
// |location|. Returns true whether the location exists or not.
|
||||
bool GetVariableLocation(Instruction* var, uint32_t* location);
|
||||
|
||||
// Finds a Component BuiltIn decoration of |var| and returns it via
|
||||
// |component|. Returns true whether the component exists or not.
|
||||
bool GetVariableComponent(Instruction* var, uint32_t* component);
|
||||
|
||||
// Returns the interface variable instruction whose result id is
|
||||
// |interface_var_id|.
|
||||
Instruction* GetInterfaceVariable(uint32_t interface_var_id);
|
||||
|
||||
// Returns the type of |var| as an instruction.
|
||||
Instruction* GetTypeOfVariable(Instruction* var);
|
||||
|
||||
// Replaces an interface variable |interface_var| whose type is
|
||||
// |interface_var_type| with scalars and returns whether it succeeds or not.
|
||||
// |location| is the value of Location Decoration for |interface_var|.
|
||||
// |component| is the value of Component Decoration for |interface_var|.
|
||||
// If |extra_array_length| is 0, it means |interface_var| has a Patch
|
||||
// decoration. Otherwise, |extra_array_length| denotes the length of the extra
|
||||
// array of |interface_var|.
|
||||
bool ReplaceInterfaceVariableWithScalars(Instruction* interface_var,
|
||||
Instruction* interface_var_type,
|
||||
uint32_t location,
|
||||
uint32_t component,
|
||||
uint32_t extra_array_length);
|
||||
|
||||
// Creates scalar variables with the storage classe |storage_class| to replace
|
||||
// an interface variable whose type is |interface_var_type|. If
|
||||
// |extra_array_length| is not zero, adds the extra arrayness to the created
|
||||
// scalar variables.
|
||||
NestedCompositeComponents CreateScalarInterfaceVarsForReplacement(
|
||||
Instruction* interface_var_type, SpvStorageClass storage_class,
|
||||
uint32_t extra_array_length);
|
||||
|
||||
// Creates scalar variables with the storage classe |storage_class| to replace
|
||||
// the interface variable whose type is OpTypeArray |interface_var_type| with.
|
||||
// If |extra_array_length| is not zero, adds the extra arrayness to all the
|
||||
// scalar variables.
|
||||
NestedCompositeComponents CreateScalarInterfaceVarsForArray(
|
||||
Instruction* interface_var_type, SpvStorageClass storage_class,
|
||||
uint32_t extra_array_length);
|
||||
|
||||
// Creates scalar variables with the storage classe |storage_class| to replace
|
||||
// the interface variable whose type is OpTypeMatrix |interface_var_type|
|
||||
// with. If |extra_array_length| is not zero, adds the extra arrayness to all
|
||||
// the scalar variables.
|
||||
NestedCompositeComponents CreateScalarInterfaceVarsForMatrix(
|
||||
Instruction* interface_var_type, SpvStorageClass storage_class,
|
||||
uint32_t extra_array_length);
|
||||
|
||||
// Recursively adds Location and Component decorations to variables in
|
||||
// |vars| with |location| and |component|. Increases |location| by one after
|
||||
// it actually adds Location and Component decorations for a variable.
|
||||
void AddLocationAndComponentDecorations(const NestedCompositeComponents& vars,
|
||||
uint32_t* location,
|
||||
uint32_t component);
|
||||
|
||||
// Replaces the interface variable |interface_var| with
|
||||
// |scalar_interface_vars| and returns whether it succeeds or not.
|
||||
// |extra_arrayness| is the extra arrayness of the interface variable.
|
||||
// |scalar_interface_vars| contains the nested variables to replace the
|
||||
// interface variable with.
|
||||
bool ReplaceInterfaceVarWith(
|
||||
Instruction* interface_var, uint32_t extra_arrayness,
|
||||
const NestedCompositeComponents& scalar_interface_vars);
|
||||
|
||||
// Replaces |interface_var| in the operands of instructions
|
||||
// |interface_var_users| with |scalar_interface_vars|. This is a recursive
|
||||
// method and |interface_var_component_indices| is used to specify which
|
||||
// recursive component of |interface_var| is replaced. Returns composite
|
||||
// construct instructions to be replaced with load instructions of
|
||||
// |interface_var_users| via |loads_to_composites|. Returns composite
|
||||
// construct instructions to be replaced with load instructions of access
|
||||
// chain instructions in |interface_var_users| via
|
||||
// |loads_for_access_chain_to_composites|.
|
||||
bool ReplaceComponentsOfInterfaceVarWith(
|
||||
Instruction* interface_var,
|
||||
const std::vector<Instruction*>& interface_var_users,
|
||||
const NestedCompositeComponents& scalar_interface_vars,
|
||||
std::vector<uint32_t>& interface_var_component_indices,
|
||||
const uint32_t* extra_array_index,
|
||||
std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
|
||||
std::unordered_map<Instruction*, Instruction*>*
|
||||
loads_for_access_chain_to_composites);
|
||||
|
||||
// Replaces |interface_var| in the operands of instructions
|
||||
// |interface_var_users| with |components| that is a vector of components for
|
||||
// the interface variable |interface_var|. This is a recursive method and
|
||||
// |interface_var_component_indices| is used to specify which recursive
|
||||
// component of |interface_var| is replaced. Returns composite construct
|
||||
// instructions to be replaced with load instructions of |interface_var_users|
|
||||
// via |loads_to_composites|. Returns composite construct instructions to be
|
||||
// replaced with load instructions of access chain instructions in
|
||||
// |interface_var_users| via |loads_for_access_chain_to_composites|.
|
||||
bool ReplaceMultipleComponentsOfInterfaceVarWith(
|
||||
Instruction* interface_var,
|
||||
const std::vector<Instruction*>& interface_var_users,
|
||||
const std::vector<NestedCompositeComponents>& components,
|
||||
std::vector<uint32_t>& interface_var_component_indices,
|
||||
const uint32_t* extra_array_index,
|
||||
std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
|
||||
std::unordered_map<Instruction*, Instruction*>*
|
||||
loads_for_access_chain_to_composites);
|
||||
|
||||
// Replaces a component of |interface_var| that is used as an operand of
|
||||
// instruction |interface_var_user| with |scalar_var|.
|
||||
// |interface_var_component_indices| is a vector of recursive indices for
|
||||
// which recursive component of |interface_var| is replaced. If
|
||||
// |interface_var_user| is a load, returns the component value via
|
||||
// |loads_to_component_values|. If |interface_var_user| is an access chain,
|
||||
// returns the component value for loads of |interface_var_user| via
|
||||
// |loads_for_access_chain_to_component_values|.
|
||||
bool ReplaceComponentOfInterfaceVarWith(
|
||||
Instruction* interface_var, Instruction* interface_var_user,
|
||||
Instruction* scalar_var,
|
||||
const std::vector<uint32_t>& interface_var_component_indices,
|
||||
const uint32_t* extra_array_index,
|
||||
std::unordered_map<Instruction*, Instruction*>* loads_to_component_values,
|
||||
std::unordered_map<Instruction*, Instruction*>*
|
||||
loads_for_access_chain_to_component_values);
|
||||
|
||||
// Creates instructions to load |scalar_var| and inserts them before
|
||||
// |insert_before|. If |extra_array_index| is not null, they load
|
||||
// |extra_array_index| th component of |scalar_var| instead of |scalar_var|
|
||||
// itself.
|
||||
Instruction* LoadScalarVar(Instruction* scalar_var,
|
||||
const uint32_t* extra_array_index,
|
||||
Instruction* insert_before);
|
||||
|
||||
// Creates instructions to load an access chain to |var| and inserts them
|
||||
// before |insert_before|. |Indexes| will be Indexes operand of the access
|
||||
// chain.
|
||||
Instruction* LoadAccessChainToVar(Instruction* var,
|
||||
const std::vector<uint32_t>& indexes,
|
||||
Instruction* insert_before);
|
||||
|
||||
// Creates instructions to store a component of an aggregate whose id is
|
||||
// |value_id| to an access chain to |scalar_var| and inserts the created
|
||||
// instructions before |insert_before|. To get the component, recursively
|
||||
// traverses the aggregate with |component_indices| as indexes.
|
||||
// Numbers in |access_chain_indices| are the Indexes operand of the access
|
||||
// chain to |scalar_var|
|
||||
void StoreComponentOfValueToAccessChainToScalarVar(
|
||||
uint32_t value_id, const std::vector<uint32_t>& component_indices,
|
||||
Instruction* scalar_var,
|
||||
const std::vector<uint32_t>& access_chain_indices,
|
||||
Instruction* insert_before);
|
||||
|
||||
// Creates instructions to store a component of an aggregate whose id is
|
||||
// |value_id| to |scalar_var| and inserts the created instructions before
|
||||
// |insert_before|. To get the component, recursively traverses the aggregate
|
||||
// using |extra_array_index| and |component_indices| as indexes.
|
||||
void StoreComponentOfValueToScalarVar(
|
||||
uint32_t value_id, const std::vector<uint32_t>& component_indices,
|
||||
Instruction* scalar_var, const uint32_t* extra_array_index,
|
||||
Instruction* insert_before);
|
||||
|
||||
// Creates instructions to store a component of an aggregate whose id is
|
||||
// |value_id| to |ptr| and inserts the created instructions before
|
||||
// |insert_before|. To get the component, recursively traverses the aggregate
|
||||
// using |extra_array_index| and |component_indices| as indexes.
|
||||
// |component_type_id| is the id of the type instruction of the component.
|
||||
void StoreComponentOfValueTo(uint32_t component_type_id, uint32_t value_id,
|
||||
const std::vector<uint32_t>& component_indices,
|
||||
Instruction* ptr,
|
||||
const uint32_t* extra_array_index,
|
||||
Instruction* insert_before);
|
||||
|
||||
// Creates new OpCompositeExtract with |type_id| for Result Type,
|
||||
// |composite_id| for Composite operand, and |indexes| for Indexes operands.
|
||||
// If |extra_first_index| is not nullptr, uses it as the first Indexes
|
||||
// operand.
|
||||
Instruction* CreateCompositeExtract(uint32_t type_id, uint32_t composite_id,
|
||||
const std::vector<uint32_t>& indexes,
|
||||
const uint32_t* extra_first_index);
|
||||
|
||||
// Creates a new OpLoad whose Result Type is |type_id| and Pointer operand is
|
||||
// |ptr|. Inserts the new instruction before |insert_before|.
|
||||
Instruction* CreateLoad(uint32_t type_id, Instruction* ptr,
|
||||
Instruction* insert_before);
|
||||
|
||||
// Clones an annotation instruction |annotation_inst| and sets the target
|
||||
// operand of the new annotation instruction as |var_id|.
|
||||
void CloneAnnotationForVariable(Instruction* annotation_inst,
|
||||
uint32_t var_id);
|
||||
|
||||
// Replaces the interface variable |interface_var| in the operands of the
|
||||
// entry point |entry_point| with |scalar_var_id|. If it cannot find
|
||||
// |interface_var| from the operands of the entry point |entry_point|, adds
|
||||
// |scalar_var_id| as an operand of the entry point |entry_point|.
|
||||
bool ReplaceInterfaceVarInEntryPoint(Instruction* interface_var,
|
||||
Instruction* entry_point,
|
||||
uint32_t scalar_var_id);
|
||||
|
||||
// Creates an access chain instruction whose Base operand is |var| and Indexes
|
||||
// operand is |index|. |component_type_id| is the id of the type instruction
|
||||
// that is the type of component. Inserts the new access chain before
|
||||
// |insert_before|.
|
||||
Instruction* CreateAccessChainWithIndex(uint32_t component_type_id,
|
||||
Instruction* var, uint32_t index,
|
||||
Instruction* insert_before);
|
||||
|
||||
// Returns the pointee type of the type of variable |var|.
|
||||
uint32_t GetPointeeTypeIdOfVar(Instruction* var);
|
||||
|
||||
// Replaces the access chain |access_chain| and its users with a new access
|
||||
// chain that points |scalar_var| as the Base operand having
|
||||
// |interface_var_component_indices| as Indexes operands and users of the new
|
||||
// access chain. When some of the users are load instructions, returns the
|
||||
// original load instruction to the new instruction that loads a component of
|
||||
// the original load value via |loads_to_component_values|.
|
||||
void ReplaceAccessChainWith(
|
||||
Instruction* access_chain,
|
||||
const std::vector<uint32_t>& interface_var_component_indices,
|
||||
Instruction* scalar_var,
|
||||
std::unordered_map<Instruction*, Instruction*>*
|
||||
loads_to_component_values);
|
||||
|
||||
// Assuming that |access_chain| is an access chain instruction whose Base
|
||||
// operand is |base_access_chain|, replaces the operands of |access_chain|
|
||||
// with operands of |base_access_chain| and Indexes operands of
|
||||
// |access_chain|.
|
||||
void UseBaseAccessChainForAccessChain(Instruction* access_chain,
|
||||
Instruction* base_access_chain);
|
||||
|
||||
// Creates composite construct instructions for load instructions that are the
|
||||
// keys of |loads_to_component_values| if no such composite construct
|
||||
// instructions exist. Adds a component of the composite as an operand of the
|
||||
// created composite construct instruction. Each value of
|
||||
// |loads_to_component_values| is the component. Returns the created composite
|
||||
// construct instructions using |loads_to_composites|. |depth_to_component| is
|
||||
// the number of recursive access steps to get the component from the
|
||||
// composite.
|
||||
void AddComponentsToCompositesForLoads(
|
||||
const std::unordered_map<Instruction*, Instruction*>&
|
||||
loads_to_component_values,
|
||||
std::unordered_map<Instruction*, Instruction*>* loads_to_composites,
|
||||
uint32_t depth_to_component);
|
||||
|
||||
// Creates a composite construct instruction for a component of the value of
|
||||
// instruction |load| in |depth_to_component| th recursive depth and inserts
|
||||
// it after |load|.
|
||||
Instruction* CreateCompositeConstructForComponentOfLoad(
|
||||
Instruction* load, uint32_t depth_to_component);
|
||||
|
||||
// Creates a new access chain instruction that points to variable |var| whose
|
||||
// type is the instruction with |var_type_id| and inserts it before
|
||||
// |insert_before|. The new access chain will have |index_ids| for Indexes
|
||||
// operands. Returns the type id of the component that is pointed by the new
|
||||
// access chain via |component_type_id|.
|
||||
Instruction* CreateAccessChainToVar(uint32_t var_type_id, Instruction* var,
|
||||
const std::vector<uint32_t>& index_ids,
|
||||
Instruction* insert_before,
|
||||
uint32_t* component_type_id);
|
||||
|
||||
// Returns the result id of OpTypeArray instrunction whose Element Type
|
||||
// operand is |elem_type_id| and Length operand is |array_length|.
|
||||
uint32_t GetArrayType(uint32_t elem_type_id, uint32_t array_length);
|
||||
|
||||
// Returns the result id of OpTypePointer instrunction whose Type
|
||||
// operand is |type_id| and Storage Class operand is |storage_class|.
|
||||
uint32_t GetPointerType(uint32_t type_id, SpvStorageClass storage_class);
|
||||
|
||||
// Kills an instrunction |inst| and its users.
|
||||
void KillInstructionAndUsers(Instruction* inst);
|
||||
|
||||
// Kills a vector of instrunctions |insts| and their users.
|
||||
void KillInstructionsAndUsers(const std::vector<Instruction*>& insts);
|
||||
|
||||
// Kills all OpDecorate instructions for Location and Component of the
|
||||
// variable whose id is |var_id|.
|
||||
void KillLocationAndComponentDecorations(uint32_t var_id);
|
||||
|
||||
// If |var| has the extra arrayness for an entry point, reports an error and
|
||||
// returns true. Otherwise, returns false.
|
||||
bool ReportErrorIfHasExtraArraynessForOtherEntry(Instruction* var);
|
||||
|
||||
// If |var| does not have the extra arrayness for an entry point, reports an
|
||||
// error and returns true. Otherwise, returns false.
|
||||
bool ReportErrorIfHasNoExtraArraynessForOtherEntry(Instruction* var);
|
||||
|
||||
// If |interface_var| has the extra arrayness for an entry point but it does
|
||||
// not have one for another entry point, reports an error and returns false.
|
||||
// Otherwise, returns true. |has_extra_arrayness| denotes whether it has an
|
||||
// extra arrayness for an entry point or not.
|
||||
bool CheckExtraArraynessConflictBetweenEntries(Instruction* interface_var,
|
||||
bool has_extra_arrayness);
|
||||
|
||||
// Conducts the scalar replacement for the interface variables used by the
|
||||
// |entry_point|.
|
||||
Pass::Status ReplaceInterfaceVarsWithScalars(Instruction& entry_point);
|
||||
|
||||
// A set of interface variable ids that were already removed from operands of
|
||||
// the entry point.
|
||||
std::unordered_set<uint32_t>
|
||||
interface_vars_removed_from_entry_point_operands_;
|
||||
|
||||
// A mapping from ids of new composite construct instructions that load
|
||||
// instructions are replaced with to the recursive depth of the component of
|
||||
// load that the new component construct instruction is used for.
|
||||
std::unordered_map<uint32_t, uint32_t> composite_ids_to_component_depths;
|
||||
|
||||
// A set of interface variables with the extra arrayness for any of the entry
|
||||
// points.
|
||||
std::unordered_set<Instruction*> vars_with_extra_arrayness;
|
||||
|
||||
// A set of interface variables without the extra arrayness for any of the
|
||||
// entry points.
|
||||
std::unordered_set<Instruction*> vars_without_extra_arrayness;
|
||||
};
|
||||
|
||||
} // namespace opt
|
||||
} // namespace spvtools
|
||||
|
||||
#endif // SOURCE_OPT_INTERFACE_VAR_SROA_H_
|
@ -1094,6 +1094,9 @@ void IRContext::AddDebug2Inst(std::unique_ptr<Instruction>&& d) {
|
||||
id_to_name_->insert({d->GetSingleWordInOperand(0), d.get()});
|
||||
}
|
||||
}
|
||||
if (AreAnalysesValid(kAnalysisDefUse)) {
|
||||
get_def_use_mgr()->AnalyzeInstDefUse(d.get());
|
||||
}
|
||||
module()->AddDebug2Inst(std::move(d));
|
||||
}
|
||||
|
||||
|
@ -1020,6 +1020,11 @@ Optimizer::PassToken CreateConvertToSampledImagePass(
|
||||
MakeUnique<opt::ConvertToSampledImagePass>(descriptor_set_binding_pairs));
|
||||
}
|
||||
|
||||
Optimizer::PassToken CreateInterfaceVariableScalarReplacementPass() {
|
||||
return MakeUnique<Optimizer::PassToken::Impl>(
|
||||
MakeUnique<opt::InterfaceVariableScalarReplacement>());
|
||||
}
|
||||
|
||||
Optimizer::PassToken CreateRemoveDontInlinePass() {
|
||||
return MakeUnique<Optimizer::PassToken::Impl>(
|
||||
MakeUnique<opt::RemoveDontInline>());
|
||||
|
@ -49,6 +49,7 @@
|
||||
#include "source/opt/inst_bindless_check_pass.h"
|
||||
#include "source/opt/inst_buff_addr_check_pass.h"
|
||||
#include "source/opt/inst_debug_printf_pass.h"
|
||||
#include "source/opt/interface_var_sroa.h"
|
||||
#include "source/opt/interp_fixup_pass.h"
|
||||
#include "source/opt/licm_pass.h"
|
||||
#include "source/opt/local_access_chain_convert_pass.h"
|
||||
|
@ -62,6 +62,7 @@ add_spvtools_unittest(TARGET opt
|
||||
inst_debug_printf_test.cpp
|
||||
instruction_list_test.cpp
|
||||
instruction_test.cpp
|
||||
interface_var_sroa_test.cpp
|
||||
interp_fixup_test.cpp
|
||||
ir_builder.cpp
|
||||
ir_context_test.cpp
|
||||
|
410
test/opt/interface_var_sroa_test.cpp
Normal file
410
test/opt/interface_var_sroa_test.cpp
Normal file
@ -0,0 +1,410 @@
|
||||
// Copyright (c) 2022 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
#include "test/opt/assembly_builder.h"
|
||||
#include "test/opt/pass_fixture.h"
|
||||
#include "test/opt/pass_utils.h"
|
||||
|
||||
namespace spvtools {
|
||||
namespace opt {
|
||||
namespace {
|
||||
|
||||
using InterfaceVariableScalarReplacementTest = PassTest<::testing::Test>;
|
||||
|
||||
TEST_F(InterfaceVariableScalarReplacementTest,
|
||||
ReplaceInterfaceVarsWithScalars) {
|
||||
const std::string spirv = R"(
|
||||
OpCapability Shader
|
||||
OpCapability Tessellation
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint TessellationControl %func "shader" %x %y %z %w %u %v
|
||||
|
||||
; CHECK: OpName [[x:%\w+]] "x"
|
||||
; CHECK-NOT: OpName {{%\w+}} "x"
|
||||
; CHECK: OpName [[y:%\w+]] "y"
|
||||
; CHECK-NOT: OpName {{%\w+}} "y"
|
||||
; CHECK: OpName [[z0:%\w+]] "z"
|
||||
; CHECK: OpName [[z1:%\w+]] "z"
|
||||
; CHECK: OpName [[w0:%\w+]] "w"
|
||||
; CHECK: OpName [[w1:%\w+]] "w"
|
||||
; CHECK: OpName [[u0:%\w+]] "u"
|
||||
; CHECK: OpName [[u1:%\w+]] "u"
|
||||
; CHECK: OpName [[v0:%\w+]] "v"
|
||||
; CHECK: OpName [[v1:%\w+]] "v"
|
||||
; CHECK: OpName [[v2:%\w+]] "v"
|
||||
; CHECK: OpName [[v3:%\w+]] "v"
|
||||
; CHECK: OpName [[v4:%\w+]] "v"
|
||||
; CHECK: OpName [[v5:%\w+]] "v"
|
||||
OpName %x "x"
|
||||
OpName %y "y"
|
||||
OpName %z "z"
|
||||
OpName %w "w"
|
||||
OpName %u "u"
|
||||
OpName %v "v"
|
||||
|
||||
; CHECK-DAG: OpDecorate [[x]] Location 2
|
||||
; CHECK-DAG: OpDecorate [[y]] Location 0
|
||||
; CHECK-DAG: OpDecorate [[z0]] Location 0
|
||||
; CHECK-DAG: OpDecorate [[z0]] Component 0
|
||||
; CHECK-DAG: OpDecorate [[z1]] Location 1
|
||||
; CHECK-DAG: OpDecorate [[z1]] Component 0
|
||||
; CHECK-DAG: OpDecorate [[z0]] Patch
|
||||
; CHECK-DAG: OpDecorate [[z1]] Patch
|
||||
; CHECK-DAG: OpDecorate [[w0]] Location 2
|
||||
; CHECK-DAG: OpDecorate [[w0]] Component 0
|
||||
; CHECK-DAG: OpDecorate [[w1]] Location 3
|
||||
; CHECK-DAG: OpDecorate [[w1]] Component 0
|
||||
; CHECK-DAG: OpDecorate [[w0]] Patch
|
||||
; CHECK-DAG: OpDecorate [[w1]] Patch
|
||||
; CHECK-DAG: OpDecorate [[u0]] Location 3
|
||||
; CHECK-DAG: OpDecorate [[u0]] Component 2
|
||||
; CHECK-DAG: OpDecorate [[u1]] Location 4
|
||||
; CHECK-DAG: OpDecorate [[u1]] Component 2
|
||||
; CHECK-DAG: OpDecorate [[v0]] Location 3
|
||||
; CHECK-DAG: OpDecorate [[v0]] Component 3
|
||||
; CHECK-DAG: OpDecorate [[v1]] Location 4
|
||||
; CHECK-DAG: OpDecorate [[v1]] Component 3
|
||||
; CHECK-DAG: OpDecorate [[v2]] Location 5
|
||||
; CHECK-DAG: OpDecorate [[v2]] Component 3
|
||||
; CHECK-DAG: OpDecorate [[v3]] Location 6
|
||||
; CHECK-DAG: OpDecorate [[v3]] Component 3
|
||||
; CHECK-DAG: OpDecorate [[v4]] Location 7
|
||||
; CHECK-DAG: OpDecorate [[v4]] Component 3
|
||||
; CHECK-DAG: OpDecorate [[v5]] Location 8
|
||||
; CHECK-DAG: OpDecorate [[v5]] Component 3
|
||||
OpDecorate %z Patch
|
||||
OpDecorate %w Patch
|
||||
OpDecorate %z Location 0
|
||||
OpDecorate %x Location 2
|
||||
OpDecorate %v Location 3
|
||||
OpDecorate %v Component 3
|
||||
OpDecorate %y Location 0
|
||||
OpDecorate %w Location 2
|
||||
OpDecorate %u Location 3
|
||||
OpDecorate %u Component 2
|
||||
|
||||
%uint = OpTypeInt 32 0
|
||||
%uint_1 = OpConstant %uint 1
|
||||
%uint_2 = OpConstant %uint 2
|
||||
%uint_3 = OpConstant %uint 3
|
||||
%uint_4 = OpConstant %uint 4
|
||||
%_arr_uint_uint_2 = OpTypeArray %uint %uint_2
|
||||
%_ptr_Output__arr_uint_uint_2 = OpTypePointer Output %_arr_uint_uint_2
|
||||
%_ptr_Input__arr_uint_uint_2 = OpTypePointer Input %_arr_uint_uint_2
|
||||
%_ptr_Input_uint = OpTypePointer Input %uint
|
||||
%_ptr_Output_uint = OpTypePointer Output %uint
|
||||
%_arr_arr_uint_uint_2_3 = OpTypeArray %_arr_uint_uint_2 %uint_3
|
||||
%_ptr_Input__arr_arr_uint_uint_2_3 = OpTypePointer Input %_arr_arr_uint_uint_2_3
|
||||
%_arr_arr_arr_uint_uint_2_3_4 = OpTypeArray %_arr_arr_uint_uint_2_3 %uint_4
|
||||
%_ptr_Output__arr_arr_arr_uint_uint_2_3_4 = OpTypePointer Output %_arr_arr_arr_uint_uint_2_3_4
|
||||
%_ptr_Output__arr_arr_uint_uint_2_3 = OpTypePointer Output %_arr_arr_uint_uint_2_3
|
||||
%z = OpVariable %_ptr_Output__arr_uint_uint_2 Output
|
||||
%x = OpVariable %_ptr_Output__arr_uint_uint_2 Output
|
||||
%y = OpVariable %_ptr_Input__arr_uint_uint_2 Input
|
||||
%w = OpVariable %_ptr_Input__arr_uint_uint_2 Input
|
||||
%u = OpVariable %_ptr_Input__arr_arr_uint_uint_2_3 Input
|
||||
%v = OpVariable %_ptr_Output__arr_arr_arr_uint_uint_2_3_4 Output
|
||||
|
||||
; CHECK-DAG: [[x]] = OpVariable %_ptr_Output__arr_uint_uint_2 Output
|
||||
; CHECK-DAG: [[y]] = OpVariable %_ptr_Input__arr_uint_uint_2 Input
|
||||
; CHECK-DAG: [[z0]] = OpVariable %_ptr_Output_uint Output
|
||||
; CHECK-DAG: [[z1]] = OpVariable %_ptr_Output_uint Output
|
||||
; CHECK-DAG: [[w0]] = OpVariable %_ptr_Input_uint Input
|
||||
; CHECK-DAG: [[w1]] = OpVariable %_ptr_Input_uint Input
|
||||
; CHECK-DAG: [[u0]] = OpVariable %_ptr_Input__arr_uint_uint_3 Input
|
||||
; CHECK-DAG: [[u1]] = OpVariable %_ptr_Input__arr_uint_uint_3 Input
|
||||
; CHECK-DAG: [[v0]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output
|
||||
; CHECK-DAG: [[v1]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output
|
||||
; CHECK-DAG: [[v2]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output
|
||||
; CHECK-DAG: [[v3]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output
|
||||
; CHECK-DAG: [[v4]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output
|
||||
; CHECK-DAG: [[v5]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output
|
||||
|
||||
%void = OpTypeVoid
|
||||
%void_f = OpTypeFunction %void
|
||||
%func = OpFunction %void None %void_f
|
||||
%label = OpLabel
|
||||
|
||||
; CHECK: [[w0_value:%\w+]] = OpLoad %uint [[w0]]
|
||||
; CHECK: [[w1_value:%\w+]] = OpLoad %uint [[w1]]
|
||||
; CHECK: [[w_value:%\w+]] = OpCompositeConstruct %_arr_uint_uint_2 [[w0_value]] [[w1_value]]
|
||||
; CHECK: [[w0:%\w+]] = OpCompositeExtract %uint [[w_value]] 0
|
||||
; CHECK: OpStore [[z0]] [[w0]]
|
||||
; CHECK: [[w1:%\w+]] = OpCompositeExtract %uint [[w_value]] 1
|
||||
; CHECK: OpStore [[z1]] [[w1]]
|
||||
%w_value = OpLoad %_arr_uint_uint_2 %w
|
||||
OpStore %z %w_value
|
||||
|
||||
; CHECK: [[u00_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u0]] %uint_0
|
||||
; CHECK: [[u00:%\w+]] = OpLoad %uint [[u00_ptr]]
|
||||
; CHECK: [[u10_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u1]] %uint_0
|
||||
; CHECK: [[u10:%\w+]] = OpLoad %uint [[u10_ptr]]
|
||||
; CHECK: [[u01_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u0]] %uint_1
|
||||
; CHECK: [[u01:%\w+]] = OpLoad %uint [[u01_ptr]]
|
||||
; CHECK: [[u11_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u1]] %uint_1
|
||||
; CHECK: [[u11:%\w+]] = OpLoad %uint [[u11_ptr]]
|
||||
; CHECK: [[u02_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u0]] %uint_2
|
||||
; CHECK: [[u02:%\w+]] = OpLoad %uint [[u02_ptr]]
|
||||
; CHECK: [[u12_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u1]] %uint_2
|
||||
; CHECK: [[u12:%\w+]] = OpLoad %uint [[u12_ptr]]
|
||||
|
||||
; CHECK-DAG: [[u0_val:%\w+]] = OpCompositeConstruct %_arr_uint_uint_2 [[u00]] [[u10]]
|
||||
; CHECK-DAG: [[u1_val:%\w+]] = OpCompositeConstruct %_arr_uint_uint_2 [[u01]] [[u11]]
|
||||
; CHECK-DAG: [[u2_val:%\w+]] = OpCompositeConstruct %_arr_uint_uint_2 [[u02]] [[u12]]
|
||||
|
||||
; CHECK: [[u_val:%\w+]] = OpCompositeConstruct %_arr__arr_uint_uint_2_uint_3 [[u0_val]] [[u1_val]] [[u2_val]]
|
||||
|
||||
; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v0]] %uint_1
|
||||
; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 0 0
|
||||
; CHECK: OpStore [[ptr]] [[val]]
|
||||
; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v1]] %uint_1
|
||||
; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 0 1
|
||||
; CHECK: OpStore [[ptr]] [[val]]
|
||||
; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v2]] %uint_1
|
||||
; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 1 0
|
||||
; CHECK: OpStore [[ptr]] [[val]]
|
||||
; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v3]] %uint_1
|
||||
; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 1 1
|
||||
; CHECK: OpStore [[ptr]] [[val]]
|
||||
; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v4]] %uint_1
|
||||
; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 2 0
|
||||
; CHECK: OpStore [[ptr]] [[val]]
|
||||
; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v5]] %uint_1
|
||||
; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 2 1
|
||||
; CHECK: OpStore [[ptr]] [[val]]
|
||||
%v_ptr = OpAccessChain %_ptr_Output__arr_arr_uint_uint_2_3 %v %uint_1
|
||||
%u_val = OpLoad %_arr_arr_uint_uint_2_3 %u
|
||||
OpStore %v_ptr %u_val
|
||||
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
SinglePassRunAndMatch<InterfaceVariableScalarReplacement>(spirv, true);
|
||||
}
|
||||
|
||||
TEST_F(InterfaceVariableScalarReplacementTest,
|
||||
CheckPatchDecorationPreservation) {
|
||||
// Make sure scalars for the variables with the extra arrayness have the extra
|
||||
// arrayness after running the pass while others do not have it.
|
||||
// Only "y" does not have the extra arrayness in the following SPIR-V.
|
||||
const std::string spirv = R"(
|
||||
OpCapability Shader
|
||||
OpCapability Tessellation
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint TessellationEvaluation %func "shader" %x %y %z %w
|
||||
OpDecorate %z Patch
|
||||
OpDecorate %w Patch
|
||||
OpDecorate %z Location 0
|
||||
OpDecorate %x Location 2
|
||||
OpDecorate %y Location 0
|
||||
OpDecorate %w Location 1
|
||||
OpName %x "x"
|
||||
OpName %y "y"
|
||||
OpName %z "z"
|
||||
OpName %w "w"
|
||||
|
||||
; CHECK: OpName [[y:%\w+]] "y"
|
||||
; CHECK-NOT: OpName {{%\w+}} "y"
|
||||
; CHECK-DAG: OpName [[z0:%\w+]] "z"
|
||||
; CHECK-DAG: OpName [[z1:%\w+]] "z"
|
||||
; CHECK-DAG: OpName [[w0:%\w+]] "w"
|
||||
; CHECK-DAG: OpName [[w1:%\w+]] "w"
|
||||
; CHECK-DAG: OpName [[x0:%\w+]] "x"
|
||||
; CHECK-DAG: OpName [[x1:%\w+]] "x"
|
||||
|
||||
%uint = OpTypeInt 32 0
|
||||
%uint_2 = OpConstant %uint 2
|
||||
%_arr_uint_uint_2 = OpTypeArray %uint %uint_2
|
||||
%_ptr_Output__arr_uint_uint_2 = OpTypePointer Output %_arr_uint_uint_2
|
||||
%_ptr_Input__arr_uint_uint_2 = OpTypePointer Input %_arr_uint_uint_2
|
||||
%z = OpVariable %_ptr_Output__arr_uint_uint_2 Output
|
||||
%x = OpVariable %_ptr_Output__arr_uint_uint_2 Output
|
||||
%y = OpVariable %_ptr_Input__arr_uint_uint_2 Input
|
||||
%w = OpVariable %_ptr_Input__arr_uint_uint_2 Input
|
||||
|
||||
; CHECK-DAG: [[y]] = OpVariable %_ptr_Input__arr_uint_uint_2 Input
|
||||
; CHECK-DAG: [[z0]] = OpVariable %_ptr_Output_uint Output
|
||||
; CHECK-DAG: [[z1]] = OpVariable %_ptr_Output_uint Output
|
||||
; CHECK-DAG: [[w0]] = OpVariable %_ptr_Input_uint Input
|
||||
; CHECK-DAG: [[w1]] = OpVariable %_ptr_Input_uint Input
|
||||
; CHECK-DAG: [[x0]] = OpVariable %_ptr_Output_uint Output
|
||||
; CHECK-DAG: [[x1]] = OpVariable %_ptr_Output_uint Output
|
||||
|
||||
%void = OpTypeVoid
|
||||
%void_f = OpTypeFunction %void
|
||||
%func = OpFunction %void None %void_f
|
||||
%label = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
SinglePassRunAndMatch<InterfaceVariableScalarReplacement>(spirv, true);
|
||||
}
|
||||
|
||||
TEST_F(InterfaceVariableScalarReplacementTest,
|
||||
CheckEntryPointInterfaceOperands) {
|
||||
const std::string spirv = R"(
|
||||
OpCapability Shader
|
||||
OpCapability Tessellation
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint TessellationEvaluation %tess "tess" %x %y
|
||||
OpEntryPoint Vertex %vert "vert" %w
|
||||
OpDecorate %z Location 0
|
||||
OpDecorate %x Location 2
|
||||
OpDecorate %y Location 0
|
||||
OpDecorate %w Location 1
|
||||
OpName %x "x"
|
||||
OpName %y "y"
|
||||
OpName %z "z"
|
||||
OpName %w "w"
|
||||
|
||||
; CHECK: OpName [[y:%\w+]] "y"
|
||||
; CHECK-NOT: OpName {{%\w+}} "y"
|
||||
; CHECK-DAG: OpName [[x0:%\w+]] "x"
|
||||
; CHECK-DAG: OpName [[x1:%\w+]] "x"
|
||||
; CHECK-DAG: OpName [[w0:%\w+]] "w"
|
||||
; CHECK-DAG: OpName [[w1:%\w+]] "w"
|
||||
; CHECK-DAG: OpName [[z:%\w+]] "z"
|
||||
; CHECK-NOT: OpName {{%\w+}} "z"
|
||||
|
||||
%uint = OpTypeInt 32 0
|
||||
%uint_2 = OpConstant %uint 2
|
||||
%_arr_uint_uint_2 = OpTypeArray %uint %uint_2
|
||||
%_ptr_Output__arr_uint_uint_2 = OpTypePointer Output %_arr_uint_uint_2
|
||||
%_ptr_Input__arr_uint_uint_2 = OpTypePointer Input %_arr_uint_uint_2
|
||||
%z = OpVariable %_ptr_Output__arr_uint_uint_2 Output
|
||||
%x = OpVariable %_ptr_Output__arr_uint_uint_2 Output
|
||||
%y = OpVariable %_ptr_Input__arr_uint_uint_2 Input
|
||||
%w = OpVariable %_ptr_Input__arr_uint_uint_2 Input
|
||||
|
||||
; CHECK-DAG: [[y]] = OpVariable %_ptr_Input__arr_uint_uint_2 Input
|
||||
; CHECK-DAG: [[z]] = OpVariable %_ptr_Output__arr_uint_uint_2 Output
|
||||
; CHECK-DAG: [[w0]] = OpVariable %_ptr_Input_uint Input
|
||||
; CHECK-DAG: [[w1]] = OpVariable %_ptr_Input_uint Input
|
||||
; CHECK-DAG: [[x0]] = OpVariable %_ptr_Output_uint Output
|
||||
; CHECK-DAG: [[x1]] = OpVariable %_ptr_Output_uint Output
|
||||
|
||||
%void = OpTypeVoid
|
||||
%void_f = OpTypeFunction %void
|
||||
%tess = OpFunction %void None %void_f
|
||||
%bb0 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
%vert = OpFunction %void None %void_f
|
||||
%bb1 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
SinglePassRunAndMatch<InterfaceVariableScalarReplacement>(spirv, true);
|
||||
}
|
||||
|
||||
class InterfaceVarSROAErrorTest : public PassTest<::testing::Test> {
|
||||
public:
|
||||
InterfaceVarSROAErrorTest()
|
||||
: consumer_([this](spv_message_level_t level, const char*,
|
||||
const spv_position_t& position, const char* message) {
|
||||
if (!error_message_.empty()) error_message_ += "\n";
|
||||
switch (level) {
|
||||
case SPV_MSG_FATAL:
|
||||
case SPV_MSG_INTERNAL_ERROR:
|
||||
case SPV_MSG_ERROR:
|
||||
error_message_ += "ERROR";
|
||||
break;
|
||||
case SPV_MSG_WARNING:
|
||||
error_message_ += "WARNING";
|
||||
break;
|
||||
case SPV_MSG_INFO:
|
||||
error_message_ += "INFO";
|
||||
break;
|
||||
case SPV_MSG_DEBUG:
|
||||
error_message_ += "DEBUG";
|
||||
break;
|
||||
}
|
||||
error_message_ +=
|
||||
": " + std::to_string(position.index) + ": " + message;
|
||||
}) {}
|
||||
|
||||
Pass::Status RunPass(const std::string& text) {
|
||||
std::unique_ptr<IRContext> context_ =
|
||||
spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_2, consumer_, text);
|
||||
if (!context_.get()) return Pass::Status::Failure;
|
||||
|
||||
PassManager manager;
|
||||
manager.SetMessageConsumer(consumer_);
|
||||
manager.AddPass<InterfaceVariableScalarReplacement>();
|
||||
|
||||
return manager.Run(context_.get());
|
||||
}
|
||||
|
||||
std::string GetErrorMessage() const { return error_message_; }
|
||||
|
||||
void TearDown() override { error_message_.clear(); }
|
||||
|
||||
private:
|
||||
spvtools::MessageConsumer consumer_;
|
||||
std::string error_message_;
|
||||
};
|
||||
|
||||
TEST_F(InterfaceVarSROAErrorTest, CheckConflictOfExtraArraynessBetweenEntries) {
|
||||
const std::string spirv = R"(
|
||||
OpCapability Shader
|
||||
OpCapability Tessellation
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint TessellationControl %tess "tess" %x %y %z
|
||||
OpEntryPoint Vertex %vert "vert" %z %w
|
||||
OpDecorate %z Location 0
|
||||
OpDecorate %x Location 2
|
||||
OpDecorate %y Location 0
|
||||
OpDecorate %w Location 1
|
||||
OpName %x "x"
|
||||
OpName %y "y"
|
||||
OpName %z "z"
|
||||
OpName %w "w"
|
||||
%uint = OpTypeInt 32 0
|
||||
%uint_2 = OpConstant %uint 2
|
||||
%_arr_uint_uint_2 = OpTypeArray %uint %uint_2
|
||||
%_ptr_Output__arr_uint_uint_2 = OpTypePointer Output %_arr_uint_uint_2
|
||||
%_ptr_Input__arr_uint_uint_2 = OpTypePointer Input %_arr_uint_uint_2
|
||||
%z = OpVariable %_ptr_Output__arr_uint_uint_2 Output
|
||||
%x = OpVariable %_ptr_Output__arr_uint_uint_2 Output
|
||||
%y = OpVariable %_ptr_Input__arr_uint_uint_2 Input
|
||||
%w = OpVariable %_ptr_Input__arr_uint_uint_2 Input
|
||||
%void = OpTypeVoid
|
||||
%void_f = OpTypeFunction %void
|
||||
%tess = OpFunction %void None %void_f
|
||||
%bb0 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
%vert = OpFunction %void None %void_f
|
||||
%bb1 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
|
||||
EXPECT_EQ(RunPass(spirv), Pass::Status::Failure);
|
||||
const char expected_error[] =
|
||||
"ERROR: 0: A variable is arrayed for an entry point but it is not "
|
||||
"arrayed for another entry point\n"
|
||||
" %z = OpVariable %_ptr_Output__arr_uint_uint_2 Output";
|
||||
EXPECT_STREQ(GetErrorMessage().c_str(), expected_error);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace opt
|
||||
} // namespace spvtools
|
Loading…
Reference in New Issue
Block a user