2017-09-08 16:08:03 +00:00
|
|
|
// Copyright (c) 2017 Google Inc.
|
|
|
|
//
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
//
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
//
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
|
|
|
|
#include "strength_reduction_pass.h"
|
|
|
|
|
|
|
|
#include <algorithm>
|
|
|
|
#include <cstdio>
|
|
|
|
#include <cstring>
|
|
|
|
#include <unordered_map>
|
|
|
|
#include <unordered_set>
|
|
|
|
|
|
|
|
#include "def_use_manager.h"
|
2017-11-08 17:40:02 +00:00
|
|
|
#include "ir_context.h"
|
2017-09-08 16:08:03 +00:00
|
|
|
#include "log.h"
|
|
|
|
#include "reflect.h"
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
// Returns the number of trailing zero bits in the binary representation of
// |constVal|.
//
// When |constVal| is 0 there is no set bit to stop at, so the bit width (32)
// is returned instead of scanning forever.
uint32_t CountTrailingZeros(uint32_t constVal) {
  // Without this guard the loop below never terminates for 0, because
  // shifting 0 right never produces a set low bit.
  if (constVal == 0) return 32;

  // Faster if we use the hardware count trailing zeros instruction.
  // If not available, we could create a table.
  uint32_t shiftAmount = 0;
  while ((constVal & 1) == 0) {
    ++shiftAmount;
    constVal >>= 1;
  }
  return shiftAmount;
}
|
|
|
|
|
|
|
|
// Returns true when exactly one bit of |val| is set, i.e. |val| is a power
// of 2. Zero is not considered a power of 2.
bool IsPowerOf2(uint32_t val) {
  // Subtracting one flips the lowest set bit and everything below it, so
  // and-ing with the original value clears that lowest set bit. Nothing is
  // left over only if that was the sole bit set.
  return val != 0 && (val & (val - 1)) == 0;
}
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
namespace spvtools {
|
|
|
|
namespace opt {
|
|
|
|
|
2017-10-30 15:13:24 +00:00
|
|
|
// Entry point of the pass. Resets the cached type/constant ids, collects the
// 32-bit integer types and small constants already present in the module, and
// then scans every function. Reports whether anything was rewritten.
Pass::Status StrengthReductionPass::Process(ir::IRContext* c) {
  InitializeProcessing(c);

  // The cached ids are per module, so start from a clean slate.
  int32_type_id_ = 0;
  uint32_type_id_ = 0;
  std::memset(constant_ids_, 0, sizeof(constant_ids_));

  FindIntTypesAndConstants();

  if (ScanFunctions()) {
    return Status::SuccessWithChange;
  }
  return Status::SuccessWithoutChange;
}
|
|
|
|
|
|
|
|
bool StrengthReductionPass::ReplaceMultiplyByPowerOf2(
|
2017-11-21 19:47:46 +00:00
|
|
|
ir::BasicBlock::iterator* inst) {
|
|
|
|
assert((*inst)->opcode() == SpvOp::SpvOpIMul &&
|
2017-09-08 16:08:03 +00:00
|
|
|
"Only works for multiplication of integers.");
|
|
|
|
bool modified = false;
|
|
|
|
|
|
|
|
// Currently only works on 32-bit integers.
|
2017-11-21 19:47:46 +00:00
|
|
|
if ((*inst)->type_id() != int32_type_id_ &&
|
|
|
|
(*inst)->type_id() != uint32_type_id_) {
|
2017-09-08 16:08:03 +00:00
|
|
|
return modified;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Check the operands for a constant that is a power of 2.
|
|
|
|
for (int i = 0; i < 2; i++) {
|
2017-11-21 19:47:46 +00:00
|
|
|
uint32_t opId = (*inst)->GetSingleWordInOperand(i);
|
2017-10-25 17:26:25 +00:00
|
|
|
ir::Instruction* opInst = get_def_use_mgr()->GetDef(opId);
|
2017-09-08 16:08:03 +00:00
|
|
|
if (opInst->opcode() == SpvOp::SpvOpConstant) {
|
|
|
|
// We found a constant operand.
|
|
|
|
uint32_t constVal = opInst->GetSingleWordOperand(2);
|
|
|
|
|
|
|
|
if (IsPowerOf2(constVal)) {
|
|
|
|
modified = true;
|
|
|
|
uint32_t shiftAmount = CountTrailingZeros(constVal);
|
|
|
|
uint32_t shiftConstResultId = GetConstantId(shiftAmount);
|
|
|
|
|
|
|
|
// Create the new instruction.
|
2017-10-25 17:26:25 +00:00
|
|
|
uint32_t newResultId = TakeNextId();
|
2017-09-08 16:08:03 +00:00
|
|
|
std::vector<ir::Operand> newOperands;
|
2017-11-21 19:47:46 +00:00
|
|
|
newOperands.push_back((*inst)->GetInOperand(1 - i));
|
2017-09-08 16:08:03 +00:00
|
|
|
ir::Operand shiftOperand(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
|
|
|
|
{shiftConstResultId});
|
|
|
|
newOperands.push_back(shiftOperand);
|
|
|
|
std::unique_ptr<ir::Instruction> newInstruction(
|
2017-11-27 15:16:41 +00:00
|
|
|
new ir::Instruction(context(), SpvOp::SpvOpShiftLeftLogical,
|
2017-11-21 19:47:46 +00:00
|
|
|
(*inst)->type_id(), newResultId, newOperands));
|
2017-09-08 16:08:03 +00:00
|
|
|
|
|
|
|
// Insert the new instruction and update the data structures.
|
2017-11-21 19:47:46 +00:00
|
|
|
(*inst) = (*inst).InsertBefore(std::move(newInstruction));
|
|
|
|
get_def_use_mgr()->AnalyzeInstDefUse(&*(*inst));
|
|
|
|
++(*inst);
|
|
|
|
context()->ReplaceAllUsesWith((*inst)->result_id(), newResultId);
|
2017-09-08 16:08:03 +00:00
|
|
|
|
|
|
|
// Remove the old instruction.
|
2017-11-21 19:47:46 +00:00
|
|
|
ir::Instruction* inst_to_delete = &*(*inst);
|
|
|
|
--(*inst);
|
|
|
|
context()->KillInst(inst_to_delete);
|
2017-09-08 16:08:03 +00:00
|
|
|
|
|
|
|
// We do not want to replace the instruction twice if both operands
|
|
|
|
// are constants that are a power of 2. So we break here.
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return modified;
|
|
|
|
}
|
|
|
|
|
|
|
|
void StrengthReductionPass::FindIntTypesAndConstants() {
|
2017-10-25 17:26:25 +00:00
|
|
|
for (auto iter = get_module()->types_values_begin();
|
|
|
|
iter != get_module()->types_values_end(); ++iter) {
|
2017-09-08 16:08:03 +00:00
|
|
|
switch (iter->opcode()) {
|
|
|
|
case SpvOp::SpvOpTypeInt:
|
|
|
|
if (iter->GetSingleWordOperand(1) == 32) {
|
|
|
|
if (iter->GetSingleWordOperand(2) == 1) {
|
|
|
|
int32_type_id_ = iter->result_id();
|
|
|
|
} else {
|
|
|
|
uint32_type_id_ = iter->result_id();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
break;
|
|
|
|
case SpvOp::SpvOpConstant:
|
|
|
|
if (iter->type_id() == uint32_type_id_) {
|
|
|
|
uint32_t value = iter->GetSingleWordOperand(2);
|
|
|
|
if (value <= 32) constant_ids_[value] = iter->result_id();
|
|
|
|
}
|
|
|
|
break;
|
|
|
|
default:
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
uint32_t StrengthReductionPass::GetConstantId(uint32_t val) {
|
|
|
|
assert(val <= 32 &&
|
|
|
|
"This function does not handle constants larger than 32.");
|
|
|
|
|
|
|
|
if (constant_ids_[val] == 0) {
|
|
|
|
if (uint32_type_id_ == 0) {
|
|
|
|
uint32_type_id_ = CreateUint32Type();
|
|
|
|
}
|
|
|
|
|
|
|
|
// Construct the constant.
|
2017-10-25 17:26:25 +00:00
|
|
|
uint32_t resultId = TakeNextId();
|
2017-09-08 16:08:03 +00:00
|
|
|
ir::Operand constant(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
|
|
|
|
{val});
|
2017-11-27 15:16:41 +00:00
|
|
|
std::unique_ptr<ir::Instruction> newConstant(
|
|
|
|
new ir::Instruction(context(), SpvOp::SpvOpConstant, uint32_type_id_,
|
|
|
|
resultId, {constant}));
|
2017-10-25 17:26:25 +00:00
|
|
|
get_module()->AddGlobalValue(std::move(newConstant));
|
2017-09-08 16:08:03 +00:00
|
|
|
|
2017-11-14 19:11:50 +00:00
|
|
|
// Notify the DefUseManager about this constant.
|
|
|
|
auto constantIter = --get_module()->types_values_end();
|
|
|
|
get_def_use_mgr()->AnalyzeInstDef(&*constantIter);
|
|
|
|
|
2017-09-08 16:08:03 +00:00
|
|
|
// Store the result id for next time.
|
|
|
|
constant_ids_[val] = resultId;
|
|
|
|
}
|
|
|
|
|
|
|
|
return constant_ids_[val];
|
|
|
|
}
|
|
|
|
|
|
|
|
bool StrengthReductionPass::ScanFunctions() {
|
|
|
|
// I did not use |ForEachInst| in the module because the function that acts on
|
|
|
|
// the instruction gets a pointer to the instruction. We cannot use that to
|
|
|
|
// insert a new instruction. I want an iterator.
|
|
|
|
bool modified = false;
|
2017-10-25 17:26:25 +00:00
|
|
|
for (auto& func : *get_module()) {
|
2017-09-08 16:08:03 +00:00
|
|
|
for (auto& bb : func) {
|
|
|
|
for (auto inst = bb.begin(); inst != bb.end(); ++inst) {
|
|
|
|
switch (inst->opcode()) {
|
|
|
|
case SpvOp::SpvOpIMul:
|
|
|
|
if (ReplaceMultiplyByPowerOf2(&inst)) modified = true;
|
|
|
|
break;
|
|
|
|
default:
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return modified;
|
|
|
|
}
|
|
|
|
|
|
|
|
uint32_t StrengthReductionPass::CreateUint32Type() {
|
2017-10-25 17:26:25 +00:00
|
|
|
uint32_t type_id = TakeNextId();
|
2017-09-08 16:08:03 +00:00
|
|
|
ir::Operand widthOperand(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
|
|
|
|
{32});
|
|
|
|
ir::Operand signOperand(spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER,
|
|
|
|
{0});
|
|
|
|
std::unique_ptr<ir::Instruction> newType(new ir::Instruction(
|
2017-11-14 19:11:50 +00:00
|
|
|
context(), SpvOp::SpvOpTypeInt, type_id, 0, {widthOperand, signOperand}));
|
2017-10-30 15:13:24 +00:00
|
|
|
context()->AddType(std::move(newType));
|
2017-09-08 16:08:03 +00:00
|
|
|
return type_id;
|
|
|
|
}
|
|
|
|
|
|
|
|
} // namespace opt
|
|
|
|
} // namespace spvtools
|