Optimize function code structure in runtime_core

Signed-off-by: ah <liangahui@h-partners.com>
This commit is contained in:
ah 2024-05-25 12:56:11 +08:00
parent b5d0c45d2d
commit acfa8d38c6
4 changed files with 254 additions and 186 deletions

View File

@ -28,6 +28,36 @@ constexpr size_t BIN_BASE = 2;
constexpr size_t MAX_DWORD = 65536;
inline bool IsHexNumber(const std::string_view &token)
{
for (auto i : token) {
if (!((i >= '0' && i <= '9') || (i >= 'A' && i <= 'F') || (i >= 'a' && i <= 'f'))) {
return false;
}
}
return true;
}
inline bool IsBinaryNumber(const std::string_view &token)
{
for (auto i : token) {
if (!(i == '0' || i == '1')) {
return false;
}
}
return true;
}
inline bool IsOctalNumber(const std::string_view &token)
{
for (auto i : token) {
if (!(i >= '0' && i <= '7')) {
return false;
}
}
return true;
}
inline bool ValidateInteger(const std::string_view &p)
{
constexpr size_t GENERAL_SHIFT = 2;
@ -45,40 +75,17 @@ inline bool ValidateInteger(const std::string_view &p)
if (token[0] == '0' && token.size() > 1 && token.find('.') == std::string::npos) {
if (token[1] == 'x') {
token.remove_prefix(GENERAL_SHIFT);
for (auto i : token) {
if (!((i >= '0' && i <= '9') || (i >= 'A' && i <= 'F') || (i >= 'a' && i <= 'f'))) {
return false;
}
}
return true;
return IsHexNumber(token);
}
if (token[1] == 'b') {
token.remove_prefix(GENERAL_SHIFT);
if (token.empty()) {
return false;
}
for (auto i : token) {
if (!(i == '0' || i == '1')) {
return false;
}
}
return true;
return (!token.empty() && IsBinaryNumber(token));
}
if (token[1] >= '0' && token[1] <= '9' && token.find('e') == std::string::npos) {
token.remove_prefix(1);
for (auto i : token) {
if (!(i >= '0' && i <= '7')) {
return false;
}
}
return true;
return IsOctalNumber(token);
}
}

View File

@ -185,6 +185,110 @@ void RegAccAlloc::SetNeedLda(compiler::Inst *inst, bool need)
inst->SetSrcReg(AccReadIndex(inst), reg);
}
void RegAccAlloc::InitializeSourceRegisters()
{
for (auto block : GetGraph()->GetBlocksRPO()) {
for (auto inst : block->Insts()) {
if (inst->IsSaveState() || inst->IsCatchPhi()) {
continue;
}
if (inst->IsConst()) {
inst->SetFlag(compiler::inst_flags::ACC_WRITE);
}
for (size_t i = 0; i < inst->GetInputsCount(); ++i) {
inst->SetSrcReg(i, compiler::INVALID_REG);
}
if (inst->IsConst()) {
inst->SetDstReg(compiler::INVALID_REG);
}
}
}
}
void RegAccAlloc::MarkAccForPhiInstructions()
{
for (auto block : GetGraph()->GetBlocksRPO()) {
for (auto phi : block->PhiInsts()) {
if (IsPhiAccReady(phi)) {
phi->SetMarker(acc_marker_);
}
}
}
}
void RegAccAlloc::MarkAccForInstructions(compiler::BasicBlock *block)
{
for (auto inst : block->AllInsts()) {
if (inst->NoDest() || !IsAccWrite(inst)) {
continue;
}
bool use_acc_dst_reg = true;
for (auto &user : inst->GetUsers()) {
compiler::Inst *uinst = user.GetInst();
if (uinst->IsSaveState()) {
continue;
}
if (CanUserReadAcc(inst, uinst)) {
SetNeedLda(uinst, false);
} else {
use_acc_dst_reg = false;
}
}
if (use_acc_dst_reg) {
inst->SetDstReg(compiler::ACC_REG_ID);
continue;
}
if (!inst->IsConst()) {
continue;
}
inst->ClearFlag(compiler::inst_flags::ACC_WRITE);
for (auto &user : inst->GetUsers()) {
compiler::Inst *uinst = user.GetInst();
if (uinst->IsSaveState()) {
continue;
}
SetNeedLda(uinst, true);
}
}
}
void RegAccAlloc::UpdateInstructionsAfterMark(compiler::BasicBlock *block)
{
for (auto inst : block->Insts()) {
if (inst->GetInputsCount() == 0) {
continue;
}
if (inst->IsCall()) {
continue;
}
compiler::Inst *input = inst->GetInput(AccReadIndex(inst)).GetInst();
if (!IsAccWriteBetween(input, inst)) {
continue;
}
input->SetDstReg(compiler::INVALID_REG);
SetNeedLda(inst, true);
if (input->IsConst()) {
input->ClearFlag(compiler::inst_flags::ACC_WRITE);
for (auto &user : input->GetUsers()) {
compiler::Inst *uinst = user.GetInst();
SetNeedLda(uinst, true);
}
}
}
}
/**
* Determine the accumulator usage between instructions.
* Eliminate unnecessary register allocations by applying
@ -197,91 +301,13 @@ bool RegAccAlloc::RunImpl()
{
GetGraph()->InitDefaultLocations();
// Initialize all source register of all instructions.
for (auto block : GetGraph()->GetBlocksRPO()) {
for (auto inst : block->Insts()) {
if (inst->IsSaveState() || inst->IsCatchPhi()) {
continue;
}
if (inst->IsConst()) {
inst->SetFlag(compiler::inst_flags::ACC_WRITE);
}
for (size_t i = 0; i < inst->GetInputsCount(); ++i) {
inst->SetSrcReg(i, compiler::INVALID_REG);
if (inst->IsConst()) {
inst->SetDstReg(compiler::INVALID_REG);
}
}
}
}
InitializeSourceRegisters();
// Mark Phi instructions if they can be optimized for acc.
for (auto block : GetGraph()->GetBlocksRPO()) {
for (auto phi : block->PhiInsts()) {
if (IsPhiAccReady(phi)) {
phi->SetMarker(acc_marker_);
}
}
}
MarkAccForPhiInstructions();
// Mark instructions if they can be optimized for acc.
for (auto block : GetGraph()->GetBlocksRPO()) {
for (auto inst : block->AllInsts()) {
if (inst->NoDest() || !IsAccWrite(inst)) {
continue;
}
bool use_acc_dst_reg = true;
for (auto &user : inst->GetUsers()) {
compiler::Inst *uinst = user.GetInst();
if (uinst->IsSaveState()) {
continue;
}
if (CanUserReadAcc(inst, uinst)) {
SetNeedLda(uinst, false);
} else {
use_acc_dst_reg = false;
}
}
if (use_acc_dst_reg) {
inst->SetDstReg(compiler::ACC_REG_ID);
} else if (inst->IsConst()) {
inst->ClearFlag(compiler::inst_flags::ACC_WRITE);
for (auto &user : inst->GetUsers()) {
compiler::Inst *uinst = user.GetInst();
if (uinst->IsSaveState()) {
continue;
}
SetNeedLda(uinst, true);
}
}
}
for (auto inst : block->Insts()) {
if (inst->GetInputsCount() == 0) {
continue;
}
if (inst->IsCall()) {
continue;
}
compiler::Inst *input = inst->GetInput(AccReadIndex(inst)).GetInst();
if (IsAccWriteBetween(input, inst)) {
input->SetDstReg(compiler::INVALID_REG);
SetNeedLda(inst, true);
if (input->IsConst()) {
input->ClearFlag(compiler::inst_flags::ACC_WRITE);
for (auto &user : input->GetUsers()) {
compiler::Inst *uinst = user.GetInst();
SetNeedLda(uinst, true);
}
}
}
}
MarkAccForInstructions(block);
UpdateInstructionsAfterMark(block);
}
#ifndef NDEBUG

View File

@ -51,6 +51,11 @@ private:
bool IsPhiAccReady(compiler::Inst *phi) const;
void SetNeedLda(compiler::Inst *inst, bool need);
void InitializeSourceRegisters();
void MarkAccForPhiInstructions();
void MarkAccForInstructions(compiler::BasicBlock *block);
void UpdateInstructionsAfterMark(compiler::BasicBlock *block);
compiler::Marker acc_marker_ {0};
};

View File

@ -82,38 +82,16 @@ public:
if (inst1->GetOpcode() != inst2->GetOpcode() || inst1->GetType() != inst2->GetType() ||
inst1->GetInputsCount() != inst2->GetInputsCount()) {
inst_compare_map_.erase(inst1);
return false;
}
bool result = (inst1->GetOpcode() != Opcode::Phi) ?
CompareNonPhiInputs(inst1, inst2) : ComparePhiInputs(inst1, inst2);
if (!result) {
inst_compare_map_.erase(inst1);
return false;
}
if (inst1->GetOpcode() != Opcode::Phi) {
auto inst1_begin = inst1->GetInputs().begin();
auto inst1_end = inst1->GetInputs().end();
auto inst2_begin = inst2->GetInputs().begin();
auto eq_lambda = [this](Input input1, Input input2) { return Compare(input1.GetInst(), input2.GetInst()); };
if (!std::equal(inst1_begin, inst1_end, inst2_begin, eq_lambda)) {
inst_compare_map_.erase(inst1);
return false;
}
} else {
if (inst1->GetInputsCount() != inst2->GetInputsCount()) {
inst_compare_map_.erase(inst1);
return false;
}
for (size_t index1 = 0; index1 < inst1->GetInputsCount(); index1++) {
auto input1 = inst1->GetInput(index1).GetInst();
auto bb1 = inst1->CastToPhi()->GetPhiInputBb(index1);
if (bb_map_.count(bb1) == 0) {
inst_compare_map_.erase(inst1);
return false;
}
auto bb2 = bb_map_.at(bb1);
auto input2 = inst2->CastToPhi()->GetPhiInput(bb2);
if (!Compare(input1, input2)) {
inst_compare_map_.erase(inst1);
return false;
}
}
}
// NOLINTNEXTLINE(cppcoreguidelines-macro-usage
#define CAST(Opc) CastTo##Opc()
@ -218,24 +196,14 @@ public:
// CHECK(LoadType, GetTypeId)
#undef CHECK
#undef CAST
if (inst1->GetOpcode() == Opcode::Constant) {
auto c1 = inst1->CastToConstant();
auto c2 = inst2->CastToConstant();
bool same = false;
switch (inst1->GetType()) {
case DataType::FLOAT32:
case DataType::INT32:
same = static_cast<uint32_t>(c1->GetRawValue()) == static_cast<uint32_t>(c2->GetRawValue());
break;
default:
same = c1->GetRawValue() == c2->GetRawValue();
break;
}
if (!same) {
if (!CompareConstantInst(inst1, inst2)) {
inst_compare_map_.erase(inst1);
return false;
}
}
if (inst1->GetOpcode() == Opcode::Cmp && IsFloatType(inst1->GetInput(0).GetInst()->GetType())) {
auto cmp1 = static_cast<CmpInst *>(inst1);
auto cmp2 = static_cast<CmpInst *>(inst2);
@ -244,51 +212,113 @@ public:
return false;
}
}
for (size_t i = 0; i < inst2->GetInputsCount(); i++) {
if (inst1->GetInputType(i) != inst2->GetInputType(i)) {
inst_compare_map_.erase(inst1);
return false;
}
}
if (inst1->IsSaveState()) {
auto *sv_st1 = static_cast<SaveStateInst *>(inst1);
auto *sv_st2 = static_cast<SaveStateInst *>(inst2);
if (sv_st1->GetImmediatesCount() != sv_st2->GetImmediatesCount()) {
inst_compare_map_.erase(inst1);
return false;
}
std::vector<VirtualRegister::ValueType> regs1;
std::vector<VirtualRegister::ValueType> regs2;
regs1.reserve(sv_st1->GetInputsCount());
regs2.reserve(sv_st2->GetInputsCount());
for (size_t i {0}; i < sv_st1->GetInputsCount(); ++i) {
regs1.emplace_back(sv_st1->GetVirtualRegister(i).Value());
regs2.emplace_back(sv_st2->GetVirtualRegister(i).Value());
}
std::sort(regs1.begin(), regs1.end());
std::sort(regs2.begin(), regs2.end());
if (regs1 != regs2) {
if (!CompareInputTypes(inst1, inst2) || !CompareSaveStateInst(inst1, inst2)) {
inst_compare_map_.erase(inst1);
return false;
}
return true;
}
private:
std::unordered_map<Inst *, Inst *> inst_compare_map_;
std::unordered_map<BasicBlock *, BasicBlock *> bb_map_;
bool CompareNonPhiInputs(Inst *inst1, Inst *inst2)
{
auto inst1_begin = inst1->GetInputs().begin();
auto inst1_end = inst1->GetInputs().end();
auto inst2_begin = inst2->GetInputs().begin();
auto eq_lambda = [this](Input input1, Input input2) {
return Compare(input1.GetInst(), input2.GetInst());
};
return std::equal(inst1_begin, inst1_end, inst2_begin, eq_lambda);
}
bool ComparePhiInputs(Inst *inst1, Inst *inst2)
{
if (inst1->GetInputsCount() != inst2->GetInputsCount()) {
return false;
}
for (size_t index1 = 0; index1 < inst1->GetInputsCount(); index1++) {
auto input1 = inst1->GetInput(index1).GetInst();
auto bb1 = inst1->CastToPhi()->GetPhiInputBb(index1);
if (bb_map_.count(bb1) == 0) {
return false;
}
if (sv_st1->GetImmediatesCount() != 0) {
auto eq_lambda = [](SaveStateImm i1, SaveStateImm i2) {
return i1.value == i2.value && i1.vreg == i2.vreg && i1.is_acc == i2.is_acc;
};
if (!std::equal(sv_st1->GetImmediates()->begin(), sv_st1->GetImmediates()->end(),
sv_st2->GetImmediates()->begin(), eq_lambda)) {
inst_compare_map_.erase(inst1);
return false;
}
auto bb2 = bb_map_.at(bb1);
auto input2 = inst2->CastToPhi()->GetPhiInput(bb2);
if (!Compare(input1, input2)) {
return false;
}
}
return true;
}
private:
std::unordered_map<Inst *, Inst *> inst_compare_map_;
std::unordered_map<BasicBlock *, BasicBlock *> bb_map_;
bool CompareConstantInst(Inst *inst1, Inst *inst2)
{
auto c1 = inst1->CastToConstant();
auto c2 = inst2->CastToConstant();
bool same = false;
switch (inst1->GetType()) {
case DataType::FLOAT32:
case DataType::INT32:
same = static_cast<uint32_t>(c1->GetRawValue()) == static_cast<uint32_t>(c2->GetRawValue());
break;
default:
same = c1->GetRawValue() == c2->GetRawValue();
break;
}
return same;
}
bool CompareInputTypes(Inst *inst1, Inst *inst2)
{
for (size_t i = 0; i < inst2->GetInputsCount(); i++) {
if (inst1->GetInputType(i) != inst2->GetInputType(i)) {
return false;
}
}
return true;
}
bool CompareSaveStateInst(Inst *inst1, Inst *inst2)
{
if (!inst1->IsSaveState()) {
return true;
}
auto *sv_st1 = static_cast<SaveStateInst *>(inst1);
auto *sv_st2 = static_cast<SaveStateInst *>(inst2);
if (sv_st1->GetImmediatesCount() != sv_st2->GetImmediatesCount()) {
return false;
}
std::vector<VirtualRegister::ValueType> regs1;
std::vector<VirtualRegister::ValueType> regs2;
regs1.reserve(sv_st1->GetInputsCount());
regs2.reserve(sv_st2->GetInputsCount());
for (size_t i {0}; i < sv_st1->GetInputsCount(); ++i) {
regs1.emplace_back(sv_st1->GetVirtualRegister(i).Value());
regs2.emplace_back(sv_st2->GetVirtualRegister(i).Value());
}
std::sort(regs1.begin(), regs1.end());
std::sort(regs2.begin(), regs2.end());
if (regs1 != regs2) {
return false;
}
if (sv_st1->GetImmediatesCount() != 0) {
auto eq_lambda = [](SaveStateImm i1, SaveStateImm i2) {
return i1.value == i2.value && i1.vreg == i2.vreg && i1.is_acc == i2.is_acc;
};
if (!std::equal(sv_st1->GetImmediates()->begin(), sv_st1->GetImmediates()->end(),
sv_st2->GetImmediates()->begin(), eq_lambda)) {
return false;
}
}
return true;
}
};
} // namespace panda::compiler