[decompiler] Fix deref bug and add some more new type pass stuff (#606)

* add copy on write and clean up some register stuff

* fix bug in multiple field lookup

* format
This commit is contained in:
water111 2021-06-18 21:10:00 -04:00 committed by GitHub
parent 409c1f5a7d
commit bc87c4426f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 256 additions and 78 deletions

View File

@ -196,7 +196,10 @@ void try_reverse_lookup_array_like(const FieldReverseLookupInput& input,
vec.push_back(tok);
output->results.emplace_back(false, array_data_type, vec);
} else {
output->results.emplace_back(false, input.base_type, parent->to_vector());
auto parent_vector = parent->to_vector();
if (!parent_vector.empty()) {
output->results.emplace_back(false, input.base_type, parent_vector);
}
}
}
@ -277,7 +280,11 @@ void try_reverse_lookup_inline_array(const FieldReverseLookupInput& input,
// can we just return the array?
if (expected_offset_into_elt == offset_into_elt && !input.deref.has_value() && elt_idx == 0) {
output->results.emplace_back(false, input.base_type, parent->to_vector());
auto parent_vec = parent->to_vector();
if (!parent_vec.empty()) {
output->results.emplace_back(false, input.base_type, parent->to_vector());
}
if ((int)output->results.size() >= max_count) {
return;
}

122
common/util/CopyOnWrite.h Normal file
View File

@ -0,0 +1,122 @@
#pragma once

#include <type_traits>
#include <utility>

#include "common/util/assert.h"
/*
template<typename T, typename... Args>
std::unique_ptr<T> make_unique(Args&&... args)
{
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}
*/
/*!
 * The CopyOnWrite class acts like a value, but internally uses references to avoid copying
 * when it is possible to avoid it.
 *
 * It is used like a shared pointer.
 * But, if you try to modify an existing object with multiple owners, it will make a copy
 * so the other owners don't see any changes. In this way, it does not act like a reference.
 *
 * To construct a new object, use CopyOnWrite<T>(args...). This is different from the usual smart
 * pointer pattern.
 *
 * Like shared pointers, a CopyOnWrite can be null. A default-constructed or moved-from
 * CopyOnWrite is null. On a null CopyOnWrite, get() and mut() return a null pointer, and
 * operator* must not be used.
 *
 * The default .get(), ->, and * operators give you const access. If you need to modify,
 * use .mut(). It will create a copy if needed, then give you a mutable pointer.
 *
 * The reference count is a plain int, so sharing one object between threads needs external
 * synchronization.
 */
template <typename T>
class CopyOnWrite {
 private:
  // we store the object and its reference count in the same heap allocation.
  struct ObjectAndCount {
    T object;

    // construct the object in-place, or copy construct from an existing.
    template <typename... Args>
    explicit ObjectAndCount(Args&&... args) : object(std::forward<Args>(args)...) {}
    explicit ObjectAndCount(const T& existing) : object(existing) {}

    // in case we ever want this to have locks.
    void add_ref() { m_count++; }
    void remove_ref() { m_count--; }
    bool unique() const { return m_count == 1; }
    bool dead() const { return m_count == 0; }

   private:
    int m_count = 0;
  };

 public:
  CopyOnWrite() = default;  // allow nulls.

  /*!
   * Construct a new object, forwarding the arguments to T's constructor.
   *
   * This overload is disabled when the only argument is itself a CopyOnWrite<T>. Otherwise
   * the forwarding constructor would be a better match than the copy/move constructors for
   * non-const lvalues and for rvalues, and CopyOnWrite<T> b(a) would fail to compile.
   */
  template <typename First,
            typename... Rest,
            typename = typename std::enable_if<
                (sizeof...(Rest) != 0) ||
                !std::is_same<typename std::decay<First>::type, CopyOnWrite<T>>::value>::type>
  explicit CopyOnWrite(First&& first, Rest&&... rest) {
    acquire_object(new ObjectAndCount(std::forward<First>(first), std::forward<Rest>(rest)...));
  }

  /*!
   * Copy: share the other's object (no copy of T is made).
   */
  CopyOnWrite(const CopyOnWrite<T>& other) { acquire_object(other.m_data); }

  /*!
   * Move: take over the other's reference. The other becomes null.
   */
  CopyOnWrite(CopyOnWrite<T>&& other) noexcept : m_data(other.m_data) { other.m_data = nullptr; }

  CopyOnWrite<T>& operator=(const CopyOnWrite<T>& other) {
    if (this == &other) {
      return *this;
    }

    // already sharing the same object: nothing to do.
    if (m_data != other.m_data) {
      clear_my_object();
      acquire_object(other.m_data);
    }
    return *this;
  }

  CopyOnWrite<T>& operator=(CopyOnWrite<T>&& other) noexcept {
    if (this != &other) {
      clear_my_object();
      m_data = other.m_data;
      other.m_data = nullptr;
    }
    return *this;
  }

  ~CopyOnWrite() { clear_my_object(); }

  // constant access
  const T* get() const { return m_data ? &m_data->object : nullptr; }
  const T* operator->() const { return get(); }
  const T& operator*() const {
    assert(m_data);
    return m_data->object;
  }
  explicit operator bool() const { return m_data != nullptr; }

  /*!
   * Get a mutable pointer to the object, or nullptr if this CopyOnWrite is null.
   * If the object is shared with other owners, this first switches to a private copy so
   * the other owners do not see the modification.
   */
  T* mut() {
    if (!m_data) {
      return nullptr;
    }

    if (!m_data->unique()) {
      assert(!m_data->dead());
      // Make the private copy before giving up our reference. If T's copy constructor
      // throws, we still own our reference and the count stays correct.
      const T& shared_object = m_data->object;
      ObjectAndCount* private_copy = new ObjectAndCount(shared_object);
      m_data->remove_ref();  // don't need to check for dead here, there's another ref somewhere.
      assert(!m_data->dead());
      m_data = private_copy;
      m_data->add_ref();
    }
    return &m_data->object;
  }

 private:
  // drop our reference (deleting the object if we were the last owner) and become null.
  void clear_my_object() {
    if (m_data) {
      m_data->remove_ref();
      if (m_data->dead()) {
        delete m_data;
      }
    }
    m_data = nullptr;
  }

  // start referencing obj (which may be null). We must currently be null.
  void acquire_object(ObjectAndCount* obj) {
    assert(!m_data);
    m_data = obj;
    if (obj) {
      m_data->add_ref();
    }
  }

  ObjectAndCount* m_data = nullptr;
};

View File

@ -1086,7 +1086,7 @@ Instruction decode_instruction(LinkedWord& word, LinkedObjectFile& file, int seg
atom.set_reg(Register(Reg::COP0, value));
break;
case DecodeType::PCR:
atom.set_reg(Register(Reg::PCR, value));
atom.set_reg(Register(Reg::SPECIAL, Reg::PCR0 + value));
break;
case DecodeType::IMM:
atom.set_imm(value);

View File

@ -56,9 +56,7 @@ const static char* vi_names[32] = {
"Status", "MAC", "Clipping", "INVALID3", "vi_R", "vi_I", "vi_Q", "INVALID7",
"INVALID8", "INVALID9", "TPC", "CMSAR0", "FBRST", "VPU-STAT", "INVALID14", "CMSAR1"};
const static char* pcr_names[2] = {"pcr0", "pcr1"};
const static char* cop2_macro_special[2] = {"Q", "ACC"};
const static char* special_names[Reg::MAX_SPECIAL] = {"pcr0", "pcr1", "Q", "ACC"};
/////////////////////////////
// Register Names Conversion
@ -90,14 +88,9 @@ const char* vi_to_charp(uint32_t vi) {
return vi_names[vi];
}
const char* pcr_to_charp(uint32_t pcr) {
assert(pcr < 2);
return pcr_names[pcr];
}
const char* cop2_macro_special_to_charp(uint32_t reg) {
assert(reg < 2);
return cop2_macro_special[reg];
const char* special_to_charp(uint32_t special) {
assert(special < Reg::MAX_SPECIAL);
return special_names[special];
}
} // namespace
@ -111,11 +104,17 @@ const char* cop2_macro_special_to_charp(uint32_t reg) {
// Note: VI / COP2 are separate "kinds" of registers, each with 16 registers.
// It might make sense to make this a single "kind" instead?
namespace {
// A register ID stores the register kind in the bits above REG_CATEGORY_SHIFT and the
// index within that kind in the low REG_CATEGORY_SHIFT bits.
constexpr int REG_CATEGORY_SHIFT = 5;
// Mask selecting the low REG_CATEGORY_SHIFT bits (the index within the kind).
constexpr int REG_IDX_MASK = (1 << REG_CATEGORY_SHIFT) - 1;
}  // namespace
/*!
* Create a register. The kind and num must both be valid.
*/
Register::Register(Reg::RegisterKind kind, uint32_t num) {
id = (kind << 8) | num;
// 32 regs/category at most.
id = (kind << REG_CATEGORY_SHIFT) | num;
// check range:
switch (kind) {
@ -126,9 +125,8 @@ Register::Register(Reg::RegisterKind kind, uint32_t num) {
case Reg::VI:
assert(num < 32);
break;
case Reg::PCR:
case Reg::COP2_MACRO_SPECIAL:
assert(num < 2);
case Reg::SPECIAL:
assert(num < Reg::MAX_SPECIAL);
break;
default:
assert(false);
@ -139,7 +137,7 @@ Register::Register(const std::string& name) {
// first try gprs,
for (int i = 0; i < Reg::MAX_GPR; i++) {
if (name == gpr_names[i]) {
id = (Reg::GPR << 8) | i;
id = (Reg::GPR << REG_CATEGORY_SHIFT) | i;
return;
}
}
@ -147,7 +145,7 @@ Register::Register(const std::string& name) {
// next fprs
for (int i = 0; i < 32; i++) {
if (name == fpr_names[i]) {
id = (Reg::FPR << 8) | i;
id = (Reg::FPR << REG_CATEGORY_SHIFT) | i;
return;
}
}
@ -170,10 +168,8 @@ const char* Register::to_charp() const {
return vi_to_charp(get_vi());
case Reg::COP0:
return cop0_to_charp(get_cop0());
case Reg::PCR:
return pcr_to_charp(get_pcr());
case Reg::COP2_MACRO_SPECIAL:
return cop2_macro_special_to_charp(get_cop2_macro_special());
case Reg::SPECIAL:
return special_to_charp(get_special());
default:
throw std::runtime_error("Unsupported Register");
}
@ -190,7 +186,7 @@ std::string Register::to_string() const {
* Get the register kind.
*/
Reg::RegisterKind Register::get_kind() const {
uint16_t kind = id >> 8;
uint16_t kind = id >> REG_CATEGORY_SHIFT;
assert(kind < Reg::MAX_KIND);
return (Reg::RegisterKind)kind;
}
@ -200,7 +196,7 @@ Reg::RegisterKind Register::get_kind() const {
*/
Reg::Gpr Register::get_gpr() const {
assert(get_kind() == Reg::GPR);
uint16_t kind = id & 0xff;
uint16_t kind = id & REG_IDX_MASK;
assert(kind < Reg::MAX_GPR);
return (Reg::Gpr)(kind);
}
@ -210,7 +206,7 @@ Reg::Gpr Register::get_gpr() const {
*/
uint32_t Register::get_fpr() const {
assert(get_kind() == Reg::FPR);
uint16_t kind = id & 0xff;
uint16_t kind = id & REG_IDX_MASK;
assert(kind < 32);
return kind;
}
@ -220,7 +216,7 @@ uint32_t Register::get_fpr() const {
*/
uint32_t Register::get_vf() const {
assert(get_kind() == Reg::VF);
uint16_t kind = id & 0xff;
uint16_t kind = id & REG_IDX_MASK;
assert(kind < 32);
return kind;
}
@ -230,7 +226,7 @@ uint32_t Register::get_vf() const {
*/
uint32_t Register::get_vi() const {
assert(get_kind() == Reg::VI);
uint16_t kind = id & 0xff;
uint16_t kind = id & REG_IDX_MASK;
assert(kind < 32);
return kind;
}
@ -240,7 +236,7 @@ uint32_t Register::get_vi() const {
*/
Reg::Cop0 Register::get_cop0() const {
assert(get_kind() == Reg::COP0);
uint16_t kind = id & 0xff;
uint16_t kind = id & REG_IDX_MASK;
assert(kind < Reg::MAX_COP0);
return (Reg::Cop0)(kind);
}
@ -248,20 +244,13 @@ Reg::Cop0 Register::get_cop0() const {
/*!
* Get the special register index (PCR0, PCR1, Q, or ACC). Must be a SPECIAL register.
*/
uint32_t Register::get_pcr() const {
assert(get_kind() == Reg::PCR);
uint16_t kind = id & 0xff;
assert(kind < 2);
uint32_t Register::get_special() const {
assert(get_kind() == Reg::SPECIAL);
uint16_t kind = id & REG_IDX_MASK;
assert(kind < Reg::MAX_SPECIAL);
return kind;
}
Reg::Cop2MacroSpecial Register::get_cop2_macro_special() const {
assert(get_kind() == Reg::COP2_MACRO_SPECIAL);
uint16_t k = id & 0xff;
assert(k < 2);
return (Reg::Cop2MacroSpecial)k;
}
bool Register::operator==(const Register& other) const {
return id == other.id;
}

View File

@ -11,17 +11,22 @@
namespace decompiler {
// Namespace for register name constants
// Note on registers:
// Registers are assigned a unique Register ID as an integer from 0 to 164 (not including 164).
// Don't change these enums without updating the indexing scheme.
// It is important that each register has a unique register ID, and that there are no gaps in the IDs.
namespace Reg {
enum RegisterKind {
GPR = 0, // EE General purpose registers, these have nicknames.
FPR = 1, // EE Floating point registers, just called f0 - f31
VF = 2, // VU0 Floating point vector registers from EE, just called vf0 - vf31
VI =
3, // VU0 Integer registers from EE, the first 16 are vi00 - vi15, the rest are control regs.
COP0 = 4, // EE COP0 Control Registers: full of fancy names (there are 32 of them)
PCR = 5, // Performance Counter registers (PCR0, PCR1)
COP2_MACRO_SPECIAL = 6, // COP2 Q, ACC accessed from macro mode instructions.
MAX_KIND = 7
GPR = 0, // EE General purpose registers, these have nicknames (32 regs)
FPR = 1, // EE Floating point registers, just called f0 - f31 (32 regs)
VF = 2, // VU0 Floating point vector registers from EE, just called vf0 - vf31 (32 regs)
VI = 3, // VU0 Integer registers from EE, the first 16 are vi00 - vi15, the rest are control
// regs. (32 regs)
COP0 = 4, // EE COP0 Control Registers: full of fancy names (there are 32 of them) (32 regs)
SPECIAL = 5, // COP2 Q, ACC accessed from macro mode instructions and PCR
MAX_KIND = 6
};
// nicknames for GPRs
@ -121,35 +126,52 @@ enum Vi {
MAX_COP2 = 32
};
enum Cop2MacroSpecial {
MACRO_Q = 0,
MACRO_ACC = 1,
enum SpecialRegisters {
PCR0 = 0,
PCR1 = 1,
MACRO_Q = 2,
MACRO_ACC = 3,
MAX_SPECIAL = 4,
};
const extern bool allowed_local_gprs[Reg::MAX_GPR];
// Exclusive upper bound on register IDs: five kinds of 32 registers each (GPR, FPR, VF, VI,
// COP0), followed by the MAX_SPECIAL special registers.
constexpr int MAX_REG_ID = 32 * 5 + MAX_SPECIAL;
// Exclusive upper bound on the IDs of the first two kinds (GPRs and FPRs, 32 registers each).
constexpr int MAX_VAR_REG_ID = 32 * 2; // gprs/fprs.
} // namespace Reg
// Representation of a register. Uses a 32-bit integer internally.
// Representation of a register. Uses a 16-bit integer internally.
class Register {
public:
Register() = default;
Register(Reg::RegisterKind kind, uint32_t num);
/*!
 * Create a register directly from its register ID.
 * The ID must be in the range [0, Reg::MAX_REG_ID).
 */
explicit Register(int reg_id) {
  // check the lower bound too: a negative ID would otherwise pass the range check and be
  // stored as a meaningless ID.
  assert(reg_id >= 0 && reg_id < Reg::MAX_REG_ID);
  id = reg_id;
}
Register(const std::string& name);
// Build the GPR register numbered Reg::A0 + idx. Only idx values 0 through 7 are accepted.
static Register get_arg_reg(int idx) {
  assert(0 <= idx && idx < 8);
  const uint32_t gpr_index = Reg::A0 + idx;
  return Register(Reg::GPR, gpr_index);
}
uint16_t reg_id() const { return id; }
const char* to_charp() const;
std::string to_string() const;
Reg::RegisterKind get_kind() const;
bool is_vu_float() const {
return get_kind() == Reg::VF ||
(get_kind() == Reg::SPECIAL &&
(get_special() == Reg::MACRO_Q || get_special() == Reg::MACRO_ACC));
}
Reg::Gpr get_gpr() const;
uint32_t get_fpr() const;
uint32_t get_vf() const;
uint32_t get_vi() const;
Reg::Cop0 get_cop0() const;
uint32_t get_pcr() const;
Reg::Cop2MacroSpecial get_cop2_macro_special() const;
uint32_t get_special() const;
bool allowed_local_gpr() const;
bool operator==(const Register& other) const;

View File

@ -573,20 +573,20 @@ void AsmOp::update_register_info() {
if (m_instr.kind >= FIRST_COP2_MACRO && m_instr.kind <= LAST_COP2_MACRO) {
switch (m_instr.kind) {
case InstructionKind::VMSUBQ:
m_read_regs.push_back(Register(Reg::COP2_MACRO_SPECIAL, Reg::MACRO_Q));
m_read_regs.push_back(Register(Reg::COP2_MACRO_SPECIAL, Reg::MACRO_ACC));
m_read_regs.push_back(Register(Reg::SPECIAL, Reg::MACRO_Q));
m_read_regs.push_back(Register(Reg::SPECIAL, Reg::MACRO_ACC));
break;
case InstructionKind::VMULAQ:
m_read_regs.push_back(Register(Reg::COP2_MACRO_SPECIAL, Reg::MACRO_Q));
m_write_regs.push_back(Register(Reg::COP2_MACRO_SPECIAL, Reg::MACRO_ACC));
m_read_regs.push_back(Register(Reg::SPECIAL, Reg::MACRO_Q));
m_write_regs.push_back(Register(Reg::SPECIAL, Reg::MACRO_ACC));
break;
// Read Q register
case InstructionKind::VADDQ:
case InstructionKind::VSUBQ:
case InstructionKind::VMULQ:
m_read_regs.push_back(Register(Reg::COP2_MACRO_SPECIAL, Reg::MACRO_Q));
m_read_regs.push_back(Register(Reg::SPECIAL, Reg::MACRO_Q));
break;
// Write ACC register
@ -595,14 +595,14 @@ void AsmOp::update_register_info() {
case InstructionKind::VMULA:
case InstructionKind::VMULA_BC:
case InstructionKind::VOPMULA:
m_write_regs.push_back(Register(Reg::COP2_MACRO_SPECIAL, Reg::MACRO_ACC));
m_write_regs.push_back(Register(Reg::SPECIAL, Reg::MACRO_ACC));
break;
// Write Q register
case InstructionKind::VDIV:
case InstructionKind::VSQRT:
case InstructionKind::VRSQRT:
m_write_regs.push_back(Register(Reg::COP2_MACRO_SPECIAL, Reg::MACRO_Q));
m_write_regs.push_back(Register(Reg::SPECIAL, Reg::MACRO_Q));
break;
// Read acc register
@ -610,18 +610,18 @@ void AsmOp::update_register_info() {
case InstructionKind::VMADD_BC:
case InstructionKind::VMSUB:
case InstructionKind::VMSUB_BC:
m_read_regs.push_back(Register(Reg::COP2_MACRO_SPECIAL, Reg::MACRO_ACC));
m_read_regs.push_back(Register(Reg::SPECIAL, Reg::MACRO_ACC));
break;
case InstructionKind::VOPMSUB:
m_read_regs.push_back(Register(Reg::COP2_MACRO_SPECIAL, Reg::MACRO_ACC));
m_read_regs.push_back(Register(Reg::SPECIAL, Reg::MACRO_ACC));
break;
// Read/Write acc register
case InstructionKind::VMADDA:
case InstructionKind::VMADDA_BC:
case InstructionKind::VMSUBA_BC:
m_write_regs.push_back(Register(Reg::COP2_MACRO_SPECIAL, Reg::MACRO_ACC));
m_read_regs.push_back(Register(Reg::COP2_MACRO_SPECIAL, Reg::MACRO_ACC));
m_write_regs.push_back(Register(Reg::SPECIAL, Reg::MACRO_ACC));
m_read_regs.push_back(Register(Reg::SPECIAL, Reg::MACRO_ACC));
break;
case InstructionKind::VMOVE:

View File

@ -571,22 +571,19 @@ void OpenGoalAsmOpElement::collect_vars(RegAccessSet& vars, bool) const {
void OpenGoalAsmOpElement::collect_vf_regs(RegSet& regs) const {
for (auto r : m_op->read_regs()) {
if (r.get_kind() == Reg::RegisterKind::VF ||
r.get_kind() == Reg::RegisterKind::COP2_MACRO_SPECIAL) {
if (r.is_vu_float()) {
regs.insert(r);
}
}
for (auto r : m_op->write_regs()) {
if (r.get_kind() == Reg::RegisterKind::VF ||
r.get_kind() == Reg::RegisterKind::COP2_MACRO_SPECIAL) {
if (r.is_vu_float()) {
regs.insert(r);
}
}
for (auto r : m_op->clobber_regs()) {
if (r.get_kind() == Reg::RegisterKind::VF ||
r.get_kind() == Reg::RegisterKind::COP2_MACRO_SPECIAL) {
if (r.is_vu_float()) {
regs.insert(r);
}
}
@ -947,8 +944,7 @@ void RLetElement::apply(const std::function<void(FormElement*)>& f) {
goos::Object RLetElement::reg_list() const {
std::vector<goos::Object> regs;
for (auto& reg : sorted_regs) {
if (reg.get_kind() == Reg::RegisterKind::VF ||
reg.get_kind() == Reg::RegisterKind::COP2_MACRO_SPECIAL) {
if (reg.is_vu_float()) {
std::string reg_name = reg.to_string() == "ACC" ? "acc" : reg.to_string();
regs.push_back(
pretty_print::build_list(pretty_print::to_symbol(fmt::format("{} :class vf", reg_name))));

View File

@ -416,7 +416,7 @@ void ObjectFileDB::ir2_register_usage_pass() {
func.warnings.bad_vf_dependency("{}", x.to_string());
}
if (x.get_kind() == Reg::COP2_MACRO_SPECIAL) {
if (x.get_kind() == Reg::SPECIAL) {
lg::error("Bad vf dependency on {} in {}", x.to_charp(), func.guessed_name.to_string());
func.warnings.bad_vf_dependency("{}", x.to_string());
}

View File

@ -79,7 +79,7 @@ Config read_config_file(const std::string& path_to_config_file) {
for (auto idx : idx_range) {
RegisterTypeCast type_cast;
type_cast.atomic_op_idx = idx;
type_cast.reg = Register(cast.at(1));
type_cast.reg = Register(cast.at(1).get<std::string>());
type_cast.type_name = cast.at(2).get<std::string>();
config.register_type_casts_by_function_by_atomic_op_idx[function_name][idx].push_back(
type_cast);

View File

@ -40,7 +40,7 @@ std::unordered_map<int, std::vector<decompiler::RegisterTypeCast>> parse_cast_hi
for (auto idx : idx_range) {
RegisterTypeCast type_cast;
type_cast.atomic_op_idx = idx;
type_cast.reg = Register(cast.at(1));
type_cast.reg = Register(cast.at(1).get<std::string>());
type_cast.type_name = cast.at(2).get<std::string>();
out[idx].push_back(type_cast);
}

View File

@ -11,6 +11,7 @@
#include "common/util/Range.h"
#include "third-party/fmt/core.h"
#include "common/util/print_float.h"
#include "common/util/CopyOnWrite.h"
TEST(CommonUtil, get_file_path) {
std::vector<std::string> test = {"cabbage", "banana", "apple"};
@ -140,4 +141,45 @@ TEST(CommonUtil, PowerOfTwo) {
EXPECT_EQ(get_power_of_two(3), std::nullopt);
EXPECT_EQ(get_power_of_two(4), 2);
EXPECT_EQ(get_power_of_two(u64(1) << 63), 63);
}
TEST(CommonUtil, CopyOnWrite) {
  // a single owner can be modified through mut().
  CopyOnWrite<int> first(2);
  EXPECT_EQ(*first, 2);
  *first.mut() = 3;
  EXPECT_EQ(*first, 3);

  // a copy shares the same underlying object.
  CopyOnWrite<int> second = first;
  EXPECT_EQ(*first, 3);
  EXPECT_EQ(*second, 3);
  EXPECT_EQ(first.get(), second.get());

  // modifying one owner must not be visible through the other.
  *first.mut() = 12;
  EXPECT_EQ(*first, 12);
  EXPECT_EQ(*second, 3);

  // assignment makes the two share again.
  first = second;
  EXPECT_EQ(*first, 3);
  EXPECT_EQ(*second, 3);
  EXPECT_EQ(first.get(), second.get());

  // assigning between owners that already share changes nothing.
  second = first;
  EXPECT_EQ(*first, 3);
  EXPECT_EQ(*second, 3);
  EXPECT_EQ(first.get(), second.get());

  EXPECT_TRUE(first);
  EXPECT_TRUE(second);

  // a default-constructed CopyOnWrite is null until something is assigned to it.
  CopyOnWrite<int> third;
  EXPECT_FALSE(third);
  third = first;
  EXPECT_TRUE(third);
  EXPECT_EQ(first.get(), third.get());

  // with three owners, modifying one leaves the other two untouched.
  *third.mut() = 15;
  EXPECT_EQ(*first, 3);
  EXPECT_EQ(*second, 3);
  EXPECT_EQ(*third, 15);
}