[RISCV] Group the legal vector types into lists we can iterator over in the RISCVISelLowering constructor

Remove the RISCVVMVTs namespace because I don't think it provides
a lot of value. If we change the mappings we'd likely have to add
or remove things from the list anyway.

Add a wrapper around addRegisterClass that can determine the
register class from the fixed size of the type.

Reviewed By: frasercrmck, rogfer01

Differential Revision: https://reviews.llvm.org/D95491
This commit is contained in:
Craig Topper 2021-01-27 09:48:27 -08:00
parent f30c523660
commit 04570e98c8
2 changed files with 74 additions and 146 deletions

View File

@ -18,7 +18,6 @@
#include "llvm/ADT/StringSwitch.h"
#include "llvm/MC/MCInstrDesc.h"
#include "llvm/MC/SubtargetFeature.h"
#include "llvm/Support/MachineValueType.h"
namespace llvm {
@ -257,62 +256,6 @@ void validate(const Triple &TT, const FeatureBitset &FeatureBits);
} // namespace RISCVFeatures
namespace RISCVVMVTs {
constexpr MVT vint8mf8_t = MVT::nxv1i8;
constexpr MVT vint8mf4_t = MVT::nxv2i8;
constexpr MVT vint8mf2_t = MVT::nxv4i8;
constexpr MVT vint8m1_t = MVT::nxv8i8;
constexpr MVT vint8m2_t = MVT::nxv16i8;
constexpr MVT vint8m4_t = MVT::nxv32i8;
constexpr MVT vint8m8_t = MVT::nxv64i8;
constexpr MVT vint16mf4_t = MVT::nxv1i16;
constexpr MVT vint16mf2_t = MVT::nxv2i16;
constexpr MVT vint16m1_t = MVT::nxv4i16;
constexpr MVT vint16m2_t = MVT::nxv8i16;
constexpr MVT vint16m4_t = MVT::nxv16i16;
constexpr MVT vint16m8_t = MVT::nxv32i16;
constexpr MVT vint32mf2_t = MVT::nxv1i32;
constexpr MVT vint32m1_t = MVT::nxv2i32;
constexpr MVT vint32m2_t = MVT::nxv4i32;
constexpr MVT vint32m4_t = MVT::nxv8i32;
constexpr MVT vint32m8_t = MVT::nxv16i32;
constexpr MVT vint64m1_t = MVT::nxv1i64;
constexpr MVT vint64m2_t = MVT::nxv2i64;
constexpr MVT vint64m4_t = MVT::nxv4i64;
constexpr MVT vint64m8_t = MVT::nxv8i64;
constexpr MVT vfloat16mf4_t = MVT::nxv1f16;
constexpr MVT vfloat16mf2_t = MVT::nxv2f16;
constexpr MVT vfloat16m1_t = MVT::nxv4f16;
constexpr MVT vfloat16m2_t = MVT::nxv8f16;
constexpr MVT vfloat16m4_t = MVT::nxv16f16;
constexpr MVT vfloat16m8_t = MVT::nxv32f16;
constexpr MVT vfloat32mf2_t = MVT::nxv1f32;
constexpr MVT vfloat32m1_t = MVT::nxv2f32;
constexpr MVT vfloat32m2_t = MVT::nxv4f32;
constexpr MVT vfloat32m4_t = MVT::nxv8f32;
constexpr MVT vfloat32m8_t = MVT::nxv16f32;
constexpr MVT vfloat64m1_t = MVT::nxv1f64;
constexpr MVT vfloat64m2_t = MVT::nxv2f64;
constexpr MVT vfloat64m4_t = MVT::nxv4f64;
constexpr MVT vfloat64m8_t = MVT::nxv8f64;
constexpr MVT vbool1_t = MVT::nxv64i1;
constexpr MVT vbool2_t = MVT::nxv32i1;
constexpr MVT vbool4_t = MVT::nxv16i1;
constexpr MVT vbool8_t = MVT::nxv8i1;
constexpr MVT vbool16_t = MVT::nxv4i1;
constexpr MVT vbool32_t = MVT::nxv2i1;
constexpr MVT vbool64_t = MVT::nxv1i1;
} // namespace RISCVVMVTs
enum class RISCVVSEW {
SEW_8 = 0,
SEW_16,

View File

@ -90,64 +90,56 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
if (Subtarget.hasStdExtD())
addRegisterClass(MVT::f64, &RISCV::FPR64RegClass);
static const MVT::SimpleValueType BoolVecVTs[] = {
MVT::nxv1i1, MVT::nxv2i1, MVT::nxv4i1, MVT::nxv8i1,
MVT::nxv16i1, MVT::nxv32i1, MVT::nxv64i1};
static const MVT::SimpleValueType IntVecVTs[] = {
MVT::nxv1i8, MVT::nxv2i8, MVT::nxv4i8, MVT::nxv8i8, MVT::nxv16i8,
MVT::nxv32i8, MVT::nxv64i8, MVT::nxv1i16, MVT::nxv2i16, MVT::nxv4i16,
MVT::nxv8i16, MVT::nxv16i16, MVT::nxv32i16, MVT::nxv1i32, MVT::nxv2i32,
MVT::nxv4i32, MVT::nxv8i32, MVT::nxv16i32, MVT::nxv1i64, MVT::nxv2i64,
MVT::nxv4i64, MVT::nxv8i64};
static const MVT::SimpleValueType F16VecVTs[] = {
MVT::nxv1f16, MVT::nxv2f16, MVT::nxv4f16,
MVT::nxv8f16, MVT::nxv16f16, MVT::nxv32f16};
static const MVT::SimpleValueType F32VecVTs[] = {
MVT::nxv1f32, MVT::nxv2f32, MVT::nxv4f32, MVT::nxv8f32, MVT::nxv16f32};
static const MVT::SimpleValueType F64VecVTs[] = {
MVT::nxv1f64, MVT::nxv2f64, MVT::nxv4f64, MVT::nxv8f64};
if (Subtarget.hasStdExtV()) {
addRegisterClass(RISCVVMVTs::vbool64_t, &RISCV::VRRegClass);
addRegisterClass(RISCVVMVTs::vbool32_t, &RISCV::VRRegClass);
addRegisterClass(RISCVVMVTs::vbool16_t, &RISCV::VRRegClass);
addRegisterClass(RISCVVMVTs::vbool8_t, &RISCV::VRRegClass);
addRegisterClass(RISCVVMVTs::vbool4_t, &RISCV::VRRegClass);
addRegisterClass(RISCVVMVTs::vbool2_t, &RISCV::VRRegClass);
addRegisterClass(RISCVVMVTs::vbool1_t, &RISCV::VRRegClass);
auto addRegClassForRVV = [this](MVT VT) {
unsigned Size = VT.getSizeInBits().getKnownMinValue();
assert(Size <= 512 && isPowerOf2_32(Size));
const TargetRegisterClass *RC;
if (Size <= 64)
RC = &RISCV::VRRegClass;
else if (Size == 128)
RC = &RISCV::VRM2RegClass;
else if (Size == 256)
RC = &RISCV::VRM4RegClass;
else
RC = &RISCV::VRM8RegClass;
addRegisterClass(RISCVVMVTs::vint8mf8_t, &RISCV::VRRegClass);
addRegisterClass(RISCVVMVTs::vint8mf4_t, &RISCV::VRRegClass);
addRegisterClass(RISCVVMVTs::vint8mf2_t, &RISCV::VRRegClass);
addRegisterClass(RISCVVMVTs::vint8m1_t, &RISCV::VRRegClass);
addRegisterClass(RISCVVMVTs::vint8m2_t, &RISCV::VRM2RegClass);
addRegisterClass(RISCVVMVTs::vint8m4_t, &RISCV::VRM4RegClass);
addRegisterClass(RISCVVMVTs::vint8m8_t, &RISCV::VRM8RegClass);
addRegisterClass(VT, RC);
};
addRegisterClass(RISCVVMVTs::vint16mf4_t, &RISCV::VRRegClass);
addRegisterClass(RISCVVMVTs::vint16mf2_t, &RISCV::VRRegClass);
addRegisterClass(RISCVVMVTs::vint16m1_t, &RISCV::VRRegClass);
addRegisterClass(RISCVVMVTs::vint16m2_t, &RISCV::VRM2RegClass);
addRegisterClass(RISCVVMVTs::vint16m4_t, &RISCV::VRM4RegClass);
addRegisterClass(RISCVVMVTs::vint16m8_t, &RISCV::VRM8RegClass);
for (MVT VT : BoolVecVTs)
addRegClassForRVV(VT);
for (MVT VT : IntVecVTs)
addRegClassForRVV(VT);
addRegisterClass(RISCVVMVTs::vint32mf2_t, &RISCV::VRRegClass);
addRegisterClass(RISCVVMVTs::vint32m1_t, &RISCV::VRRegClass);
addRegisterClass(RISCVVMVTs::vint32m2_t, &RISCV::VRM2RegClass);
addRegisterClass(RISCVVMVTs::vint32m4_t, &RISCV::VRM4RegClass);
addRegisterClass(RISCVVMVTs::vint32m8_t, &RISCV::VRM8RegClass);
if (Subtarget.hasStdExtZfh())
for (MVT VT : F16VecVTs)
addRegClassForRVV(VT);
addRegisterClass(RISCVVMVTs::vint64m1_t, &RISCV::VRRegClass);
addRegisterClass(RISCVVMVTs::vint64m2_t, &RISCV::VRM2RegClass);
addRegisterClass(RISCVVMVTs::vint64m4_t, &RISCV::VRM4RegClass);
addRegisterClass(RISCVVMVTs::vint64m8_t, &RISCV::VRM8RegClass);
if (Subtarget.hasStdExtF())
for (MVT VT : F32VecVTs)
addRegClassForRVV(VT);
if (Subtarget.hasStdExtZfh()) {
addRegisterClass(RISCVVMVTs::vfloat16mf4_t, &RISCV::VRRegClass);
addRegisterClass(RISCVVMVTs::vfloat16mf2_t, &RISCV::VRRegClass);
addRegisterClass(RISCVVMVTs::vfloat16m1_t, &RISCV::VRRegClass);
addRegisterClass(RISCVVMVTs::vfloat16m2_t, &RISCV::VRM2RegClass);
addRegisterClass(RISCVVMVTs::vfloat16m4_t, &RISCV::VRM4RegClass);
addRegisterClass(RISCVVMVTs::vfloat16m8_t, &RISCV::VRM8RegClass);
}
if (Subtarget.hasStdExtF()) {
addRegisterClass(RISCVVMVTs::vfloat32mf2_t, &RISCV::VRRegClass);
addRegisterClass(RISCVVMVTs::vfloat32m1_t, &RISCV::VRRegClass);
addRegisterClass(RISCVVMVTs::vfloat32m2_t, &RISCV::VRM2RegClass);
addRegisterClass(RISCVVMVTs::vfloat32m4_t, &RISCV::VRM4RegClass);
addRegisterClass(RISCVVMVTs::vfloat32m8_t, &RISCV::VRM8RegClass);
}
if (Subtarget.hasStdExtD()) {
addRegisterClass(RISCVVMVTs::vfloat64m1_t, &RISCV::VRRegClass);
addRegisterClass(RISCVVMVTs::vfloat64m2_t, &RISCV::VRM2RegClass);
addRegisterClass(RISCVVMVTs::vfloat64m4_t, &RISCV::VRM4RegClass);
addRegisterClass(RISCVVMVTs::vfloat64m8_t, &RISCV::VRM8RegClass);
}
if (Subtarget.hasStdExtD())
for (MVT VT : F64VecVTs)
addRegClassForRVV(VT);
}
// Compute derived properties from the register classes.
@ -379,9 +371,22 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
if (Subtarget.is64Bit()) {
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i64, Custom);
setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i64, Custom);
} else {
// We must custom-lower certain vXi64 operations on RV32 due to the vector
// element type being illegal.
setOperationAction(ISD::SPLAT_VECTOR, MVT::i64, Custom);
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::i64, Custom);
setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::i64, Custom);
}
for (auto VT : MVT::integer_scalable_vector_valuetypes()) {
for (MVT VT : BoolVecVTs) {
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
// Mask VTs are custom-expanded into a series of standard nodes
setOperationAction(ISD::TRUNCATE, VT, Custom);
}
for (MVT VT : IntVecVTs) {
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
setOperationAction(ISD::SMIN, VT, Legal);
@ -392,30 +397,18 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::ROTL, VT, Expand);
setOperationAction(ISD::ROTR, VT, Expand);
if (isTypeLegal(VT)) {
// Custom-lower extensions and truncations from/to mask types.
setOperationAction(ISD::ANY_EXTEND, VT, Custom);
setOperationAction(ISD::SIGN_EXTEND, VT, Custom);
setOperationAction(ISD::ZERO_EXTEND, VT, Custom);
// Custom-lower extensions and truncations from/to mask types.
setOperationAction(ISD::ANY_EXTEND, VT, Custom);
setOperationAction(ISD::SIGN_EXTEND, VT, Custom);
setOperationAction(ISD::ZERO_EXTEND, VT, Custom);
// We custom-lower all legally-typed vector truncates:
// 1. Mask VTs are custom-expanded into a series of standard nodes
// 2. Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR"
// nodes which truncate by one power of two at a time.
setOperationAction(ISD::TRUNCATE, VT, Custom);
// Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR"
// nodes which truncate by one power of two at a time.
setOperationAction(ISD::TRUNCATE, VT, Custom);
// Custom-lower insert/extract operations to simplify patterns.
setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
}
}
// We must custom-lower certain vXi64 operations on RV32 due to the vector
// element type being illegal.
if (!Subtarget.is64Bit()) {
setOperationAction(ISD::SPLAT_VECTOR, MVT::i64, Custom);
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::i64, Custom);
setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::i64, Custom);
// Custom-lower insert/extract operations to simplify patterns.
setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
}
// Expand various CCs to best match the RVV ISA, which natively supports UNE
@ -441,25 +434,17 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setCondCodeAction(CC, VT, Expand);
};
if (Subtarget.hasStdExtZfh()) {
for (auto VT : {RISCVVMVTs::vfloat16mf4_t, RISCVVMVTs::vfloat16mf2_t,
RISCVVMVTs::vfloat16m1_t, RISCVVMVTs::vfloat16m2_t,
RISCVVMVTs::vfloat16m4_t, RISCVVMVTs::vfloat16m8_t})
if (Subtarget.hasStdExtZfh())
for (MVT VT : F16VecVTs)
SetCommonVFPActions(VT);
}
if (Subtarget.hasStdExtF()) {
for (auto VT : {RISCVVMVTs::vfloat32mf2_t, RISCVVMVTs::vfloat32m1_t,
RISCVVMVTs::vfloat32m2_t, RISCVVMVTs::vfloat32m4_t,
RISCVVMVTs::vfloat32m8_t})
if (Subtarget.hasStdExtF())
for (MVT VT : F32VecVTs)
SetCommonVFPActions(VT);
}
if (Subtarget.hasStdExtD()) {
for (auto VT : {RISCVVMVTs::vfloat64m1_t, RISCVVMVTs::vfloat64m2_t,
RISCVVMVTs::vfloat64m4_t, RISCVVMVTs::vfloat64m8_t})
if (Subtarget.hasStdExtD())
for (MVT VT : F64VecVTs)
SetCommonVFPActions(VT);
}
}
// Function alignments.