mirror of
https://github.com/RPCSX/llvm.git
synced 2024-11-27 13:40:30 +00:00
[MVT][SVE] Scalable vector MVTs (3/3)
Adds MVT::ElementCount to represent the length of a vector which may be scalable, then adds helper functions that work with it. Patch by Graham Hunter. Differential Revision: https://reviews.llvm.org/D32019 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@300842 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
parent
0e3700625d
commit
0f69ba8243
@ -232,6 +232,42 @@ class MVT {
|
||||
|
||||
SimpleValueType SimpleTy;
|
||||
|
||||
|
||||
// A class to represent the number of elements in a vector
|
||||
//
|
||||
// For fixed-length vectors, the total number of elements is equal to 'Min'
|
||||
// For scalable vectors, the total number of elements is a multiple of 'Min'
|
||||
class ElementCount {
|
||||
public:
|
||||
unsigned Min;
|
||||
bool Scalable;
|
||||
|
||||
ElementCount(unsigned Min, bool Scalable)
|
||||
: Min(Min), Scalable(Scalable) {}
|
||||
|
||||
ElementCount operator*(unsigned RHS) {
|
||||
return { Min * RHS, Scalable };
|
||||
}
|
||||
|
||||
ElementCount& operator*=(unsigned RHS) {
|
||||
Min *= RHS;
|
||||
return *this;
|
||||
}
|
||||
|
||||
ElementCount operator/(unsigned RHS) {
|
||||
return { Min / RHS, Scalable };
|
||||
}
|
||||
|
||||
ElementCount& operator/=(unsigned RHS) {
|
||||
Min /= RHS;
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool operator==(const ElementCount& RHS) {
|
||||
return Min == RHS.Min && Scalable == RHS.Scalable;
|
||||
}
|
||||
};
|
||||
|
||||
constexpr MVT() : SimpleTy(INVALID_SIMPLE_VALUE_TYPE) {}
|
||||
constexpr MVT(SimpleValueType SVT) : SimpleTy(SVT) {}
|
||||
|
||||
@ -276,6 +312,15 @@ class MVT {
|
||||
SimpleTy <= MVT::LAST_VECTOR_VALUETYPE);
|
||||
}
|
||||
|
||||
/// Return true if this is a vector value type where the
|
||||
/// runtime length is machine dependent
|
||||
bool isScalableVector() const {
|
||||
return ((SimpleTy >= MVT::FIRST_INTEGER_SCALABLE_VALUETYPE &&
|
||||
SimpleTy <= MVT::LAST_INTEGER_SCALABLE_VALUETYPE) ||
|
||||
(SimpleTy >= MVT::FIRST_FP_SCALABLE_VALUETYPE &&
|
||||
SimpleTy <= MVT::LAST_FP_SCALABLE_VALUETYPE));
|
||||
}
|
||||
|
||||
/// Return true if this is a 16-bit vector type.
|
||||
bool is16BitVector() const {
|
||||
return (SimpleTy == MVT::v2i8 || SimpleTy == MVT::v1i16 ||
|
||||
@ -560,6 +605,10 @@ class MVT {
|
||||
}
|
||||
}
|
||||
|
||||
MVT::ElementCount getVectorElementCount() const {
|
||||
return { getVectorNumElements(), isScalableVector() };
|
||||
}
|
||||
|
||||
unsigned getSizeInBits() const {
|
||||
switch (SimpleTy) {
|
||||
default:
|
||||
@ -837,6 +886,83 @@ class MVT {
|
||||
return (MVT::SimpleValueType)(MVT::INVALID_SIMPLE_VALUE_TYPE);
|
||||
}
|
||||
|
||||
static MVT getScalableVectorVT(MVT VT, unsigned NumElements) {
|
||||
switch(VT.SimpleTy) {
|
||||
default:
|
||||
break;
|
||||
case MVT::i1:
|
||||
if (NumElements == 2) return MVT::nxv2i1;
|
||||
if (NumElements == 4) return MVT::nxv4i1;
|
||||
if (NumElements == 8) return MVT::nxv8i1;
|
||||
if (NumElements == 16) return MVT::nxv16i1;
|
||||
if (NumElements == 32) return MVT::nxv32i1;
|
||||
break;
|
||||
case MVT::i8:
|
||||
if (NumElements == 1) return MVT::nxv1i8;
|
||||
if (NumElements == 2) return MVT::nxv2i8;
|
||||
if (NumElements == 4) return MVT::nxv4i8;
|
||||
if (NumElements == 8) return MVT::nxv8i8;
|
||||
if (NumElements == 16) return MVT::nxv16i8;
|
||||
if (NumElements == 32) return MVT::nxv32i8;
|
||||
break;
|
||||
case MVT::i16:
|
||||
if (NumElements == 1) return MVT::nxv1i16;
|
||||
if (NumElements == 2) return MVT::nxv2i16;
|
||||
if (NumElements == 4) return MVT::nxv4i16;
|
||||
if (NumElements == 8) return MVT::nxv8i16;
|
||||
if (NumElements == 16) return MVT::nxv16i16;
|
||||
if (NumElements == 32) return MVT::nxv32i16;
|
||||
break;
|
||||
case MVT::i32:
|
||||
if (NumElements == 1) return MVT::nxv1i32;
|
||||
if (NumElements == 2) return MVT::nxv2i32;
|
||||
if (NumElements == 4) return MVT::nxv4i32;
|
||||
if (NumElements == 8) return MVT::nxv8i32;
|
||||
if (NumElements == 16) return MVT::nxv16i32;
|
||||
if (NumElements == 32) return MVT::nxv32i32;
|
||||
break;
|
||||
case MVT::i64:
|
||||
if (NumElements == 1) return MVT::nxv1i64;
|
||||
if (NumElements == 2) return MVT::nxv2i64;
|
||||
if (NumElements == 4) return MVT::nxv4i64;
|
||||
if (NumElements == 8) return MVT::nxv8i64;
|
||||
if (NumElements == 16) return MVT::nxv16i64;
|
||||
if (NumElements == 32) return MVT::nxv32i64;
|
||||
break;
|
||||
case MVT::f16:
|
||||
if (NumElements == 2) return MVT::nxv2f16;
|
||||
if (NumElements == 4) return MVT::nxv4f16;
|
||||
if (NumElements == 8) return MVT::nxv8f16;
|
||||
break;
|
||||
case MVT::f32:
|
||||
if (NumElements == 1) return MVT::nxv1f32;
|
||||
if (NumElements == 2) return MVT::nxv2f32;
|
||||
if (NumElements == 4) return MVT::nxv4f32;
|
||||
if (NumElements == 8) return MVT::nxv8f32;
|
||||
if (NumElements == 16) return MVT::nxv16f32;
|
||||
break;
|
||||
case MVT::f64:
|
||||
if (NumElements == 1) return MVT::nxv1f64;
|
||||
if (NumElements == 2) return MVT::nxv2f64;
|
||||
if (NumElements == 4) return MVT::nxv4f64;
|
||||
if (NumElements == 8) return MVT::nxv8f64;
|
||||
break;
|
||||
}
|
||||
return (MVT::SimpleValueType)(MVT::INVALID_SIMPLE_VALUE_TYPE);
|
||||
}
|
||||
|
||||
static MVT getVectorVT(MVT VT, unsigned NumElements, bool IsScalable) {
|
||||
if (IsScalable)
|
||||
return getScalableVectorVT(VT, NumElements);
|
||||
return getVectorVT(VT, NumElements);
|
||||
}
|
||||
|
||||
static MVT getVectorVT(MVT VT, MVT::ElementCount EC) {
|
||||
if (EC.Scalable)
|
||||
return getScalableVectorVT(VT, EC.Min);
|
||||
return getVectorVT(VT, EC.Min);
|
||||
}
|
||||
|
||||
/// Return the value type corresponding to the specified type. This returns
|
||||
/// all pointers as iPTR. If HandleUnknown is true, unknown types are
|
||||
/// returned as Other, otherwise they are invalid.
|
||||
@ -887,6 +1013,14 @@ class MVT {
|
||||
MVT::FIRST_FP_VECTOR_VALUETYPE,
|
||||
(MVT::SimpleValueType)(MVT::LAST_FP_VECTOR_VALUETYPE + 1));
|
||||
}
|
||||
static mvt_range integer_scalable_vector_valuetypes() {
|
||||
return mvt_range(MVT::FIRST_INTEGER_SCALABLE_VALUETYPE,
|
||||
(MVT::SimpleValueType)(MVT::LAST_INTEGER_SCALABLE_VALUETYPE + 1));
|
||||
}
|
||||
static mvt_range fp_scalable_vector_valuetypes() {
|
||||
return mvt_range(MVT::FIRST_FP_SCALABLE_VALUETYPE,
|
||||
(MVT::SimpleValueType)(MVT::LAST_FP_SCALABLE_VALUETYPE + 1));
|
||||
}
|
||||
/// @}
|
||||
};
|
||||
|
||||
|
@ -67,24 +67,41 @@ namespace llvm {
|
||||
|
||||
/// Returns the EVT that represents a vector NumElements in length, where
|
||||
/// each element is of type VT.
|
||||
static EVT getVectorVT(LLVMContext &Context, EVT VT, unsigned NumElements) {
|
||||
MVT M = MVT::getVectorVT(VT.V, NumElements);
|
||||
static EVT getVectorVT(LLVMContext &Context, EVT VT, unsigned NumElements,
|
||||
bool IsScalable = false) {
|
||||
MVT M = MVT::getVectorVT(VT.V, NumElements, IsScalable);
|
||||
if (M.SimpleTy != MVT::INVALID_SIMPLE_VALUE_TYPE)
|
||||
return M;
|
||||
|
||||
assert(!IsScalable && "We don't support extended scalable types yet");
|
||||
return getExtendedVectorVT(Context, VT, NumElements);
|
||||
}
|
||||
|
||||
/// Returns the EVT that represents a vector EC.Min elements in length,
|
||||
/// where each element is of type VT.
|
||||
static EVT getVectorVT(LLVMContext &Context, EVT VT, MVT::ElementCount EC) {
|
||||
MVT M = MVT::getVectorVT(VT.V, EC);
|
||||
if (M.SimpleTy != MVT::INVALID_SIMPLE_VALUE_TYPE)
|
||||
return M;
|
||||
assert (!EC.Scalable && "We don't support extended scalable types yet");
|
||||
return getExtendedVectorVT(Context, VT, EC.Min);
|
||||
}
|
||||
|
||||
/// Return a vector with the same number of elements as this vector, but
|
||||
/// with the element type converted to an integer type with the same
|
||||
/// bitwidth.
|
||||
EVT changeVectorElementTypeToInteger() const {
|
||||
if (!isSimple())
|
||||
if (!isSimple()) {
|
||||
assert (!isScalableVector() &&
|
||||
"We don't support extended scalable types yet");
|
||||
return changeExtendedVectorElementTypeToInteger();
|
||||
}
|
||||
MVT EltTy = getSimpleVT().getVectorElementType();
|
||||
unsigned BitWidth = EltTy.getSizeInBits();
|
||||
MVT IntTy = MVT::getIntegerVT(BitWidth);
|
||||
MVT VecTy = MVT::getVectorVT(IntTy, getVectorNumElements());
|
||||
assert(VecTy.SimpleTy >= 0 &&
|
||||
MVT VecTy = MVT::getVectorVT(IntTy, getVectorNumElements(),
|
||||
isScalableVector());
|
||||
assert(VecTy.SimpleTy != MVT::INVALID_SIMPLE_VALUE_TYPE &&
|
||||
"Simple vector VT not representable by simple integer vector VT!");
|
||||
return VecTy;
|
||||
}
|
||||
@ -132,6 +149,17 @@ namespace llvm {
|
||||
return isSimple() ? V.isVector() : isExtendedVector();
|
||||
}
|
||||
|
||||
/// Return true if this is a vector type where the runtime
|
||||
/// length is machine dependent
|
||||
bool isScalableVector() const {
|
||||
// FIXME: We don't support extended scalable types yet, because the
|
||||
// matching IR type doesn't exist. Once it has been added, this can
|
||||
// be changed to call isExtendedScalableVector.
|
||||
if (!isSimple())
|
||||
return false;
|
||||
return V.isScalableVector();
|
||||
}
|
||||
|
||||
/// Return true if this is a 16-bit vector type.
|
||||
bool is16BitVector() const {
|
||||
return isSimple() ? V.is16BitVector() : isExtended16BitVector();
|
||||
@ -247,6 +275,17 @@ namespace llvm {
|
||||
return getExtendedVectorNumElements();
|
||||
}
|
||||
|
||||
// Given a (possibly scalable) vector type, return the ElementCount
|
||||
MVT::ElementCount getVectorElementCount() const {
|
||||
assert((isVector()) && "Invalid vector type!");
|
||||
if (isSimple())
|
||||
return V.getVectorElementCount();
|
||||
|
||||
assert(!isScalableVector() &&
|
||||
"We don't support extended scalable types yet");
|
||||
return {getExtendedVectorNumElements(), false};
|
||||
}
|
||||
|
||||
/// Return the size of the specified value type in bits.
|
||||
unsigned getSizeInBits() const {
|
||||
if (isSimple())
|
||||
@ -301,7 +340,7 @@ namespace llvm {
|
||||
EVT widenIntegerVectorElementType(LLVMContext &Context) const {
|
||||
EVT EltVT = getVectorElementType();
|
||||
EltVT = EVT::getIntegerVT(Context, 2 * EltVT.getSizeInBits());
|
||||
return EVT::getVectorVT(Context, EltVT, getVectorNumElements());
|
||||
return EVT::getVectorVT(Context, EltVT, getVectorElementCount());
|
||||
}
|
||||
|
||||
// Return a VT for a vector type with the same element type but
|
||||
@ -309,9 +348,8 @@ namespace llvm {
|
||||
// extended type.
|
||||
EVT getHalfNumVectorElementsVT(LLVMContext &Context) const {
|
||||
EVT EltVT = getVectorElementType();
|
||||
auto EltCnt = getVectorNumElements();
|
||||
assert(!(getVectorNumElements() & 1) &&
|
||||
"Splitting vector, but not in half!");
|
||||
auto EltCnt = getVectorElementCount();
|
||||
assert(!(EltCnt.Min & 1) && "Splitting vector, but not in half!");
|
||||
return EVT::getVectorVT(Context, EltVT, EltCnt / 2);
|
||||
}
|
||||
|
||||
@ -327,7 +365,8 @@ namespace llvm {
|
||||
if (!isPow2VectorType()) {
|
||||
unsigned NElts = getVectorNumElements();
|
||||
unsigned Pow2NElts = 1 << Log2_32_Ceil(NElts);
|
||||
return EVT::getVectorVT(Context, getVectorElementType(), Pow2NElts);
|
||||
return EVT::getVectorVT(Context, getVectorElementType(), Pow2NElts,
|
||||
isScalableVector());
|
||||
}
|
||||
else {
|
||||
return *this;
|
||||
|
@ -925,9 +925,9 @@ SDValue DAGTypeLegalizer::BitConvertVectorToIntegerVector(SDValue Op) {
|
||||
assert(Op.getValueType().isVector() && "Only applies to vectors!");
|
||||
unsigned EltWidth = Op.getScalarValueSizeInBits();
|
||||
EVT EltNVT = EVT::getIntegerVT(*DAG.getContext(), EltWidth);
|
||||
unsigned NumElts = Op.getValueType().getVectorNumElements();
|
||||
auto EltCnt = Op.getValueType().getVectorElementCount();
|
||||
return DAG.getNode(ISD::BITCAST, SDLoc(Op),
|
||||
EVT::getVectorVT(*DAG.getContext(), EltNVT, NumElts), Op);
|
||||
EVT::getVectorVT(*DAG.getContext(), EltNVT, EltCnt), Op);
|
||||
}
|
||||
|
||||
SDValue DAGTypeLegalizer::CreateStackStoreLoad(SDValue Op,
|
||||
|
@ -9,6 +9,7 @@ set(CodeGenSources
|
||||
DIEHashTest.cpp
|
||||
LowLevelTypeTest.cpp
|
||||
MachineInstrBundleIteratorTest.cpp
|
||||
ScalableVectorMVTsTest.cpp
|
||||
)
|
||||
|
||||
add_llvm_unittest(CodeGenTests
|
||||
|
88
unittests/CodeGen/ScalableVectorMVTsTest.cpp
Normal file
88
unittests/CodeGen/ScalableVectorMVTsTest.cpp
Normal file
@ -0,0 +1,88 @@
|
||||
//===-------- llvm/unittest/CodeGen/ScalableVectorMVTsTest.cpp ------------===//
|
||||
//
|
||||
// The LLVM Compiler Infrastructure
|
||||
//
|
||||
// This file is distributed under the University of Illinois Open Source
|
||||
// License. See LICENSE.TXT for details.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "llvm/CodeGen/MachineValueType.h"
|
||||
#include "llvm/CodeGen/ValueTypes.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
namespace {
|
||||
|
||||
TEST(ScalableVectorMVTsTest, IntegerMVTs) {
|
||||
for (auto VecTy : MVT::integer_scalable_vector_valuetypes()) {
|
||||
ASSERT_TRUE(VecTy.isValid());
|
||||
ASSERT_TRUE(VecTy.isInteger());
|
||||
ASSERT_TRUE(VecTy.isVector());
|
||||
ASSERT_TRUE(VecTy.isScalableVector());
|
||||
ASSERT_TRUE(VecTy.getScalarType().isValid());
|
||||
|
||||
ASSERT_FALSE(VecTy.isFloatingPoint());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ScalableVectorMVTsTest, FloatMVTs) {
|
||||
for (auto VecTy : MVT::fp_scalable_vector_valuetypes()) {
|
||||
ASSERT_TRUE(VecTy.isValid());
|
||||
ASSERT_TRUE(VecTy.isFloatingPoint());
|
||||
ASSERT_TRUE(VecTy.isVector());
|
||||
ASSERT_TRUE(VecTy.isScalableVector());
|
||||
ASSERT_TRUE(VecTy.getScalarType().isValid());
|
||||
|
||||
ASSERT_FALSE(VecTy.isInteger());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ScalableVectorMVTsTest, HelperFuncs) {
|
||||
LLVMContext Ctx;
|
||||
|
||||
// Create with scalable flag
|
||||
EVT Vnx4i32 = EVT::getVectorVT(Ctx, MVT::i32, 4, /*Scalable=*/true);
|
||||
ASSERT_TRUE(Vnx4i32.isScalableVector());
|
||||
|
||||
// Create with separate MVT::ElementCount
|
||||
auto EltCnt = MVT::ElementCount(2, true);
|
||||
EVT Vnx2i32 = EVT::getVectorVT(Ctx, MVT::i32, EltCnt);
|
||||
ASSERT_TRUE(Vnx2i32.isScalableVector());
|
||||
|
||||
// Create with inline MVT::ElementCount
|
||||
EVT Vnx2i64 = EVT::getVectorVT(Ctx, MVT::i64, {2, true});
|
||||
ASSERT_TRUE(Vnx2i64.isScalableVector());
|
||||
|
||||
// Check that changing scalar types/element count works
|
||||
EXPECT_EQ(Vnx2i32.widenIntegerVectorElementType(Ctx), Vnx2i64);
|
||||
EXPECT_EQ(Vnx4i32.getHalfNumVectorElementsVT(Ctx), Vnx2i32);
|
||||
|
||||
// Check that overloaded '*' and '/' operators work
|
||||
EXPECT_EQ(EVT::getVectorVT(Ctx, MVT::i64, EltCnt * 2), MVT::nxv4i64);
|
||||
EXPECT_EQ(EVT::getVectorVT(Ctx, MVT::i64, EltCnt / 2), MVT::nxv1i64);
|
||||
|
||||
// Check that float->int conversion works
|
||||
EVT Vnx2f64 = EVT::getVectorVT(Ctx, MVT::f64, {2, true});
|
||||
EXPECT_EQ(Vnx2f64.changeTypeToInteger(), Vnx2i64);
|
||||
|
||||
// Check fields inside MVT::ElementCount
|
||||
EltCnt = Vnx4i32.getVectorElementCount();
|
||||
EXPECT_EQ(EltCnt.Min, 4);
|
||||
ASSERT_TRUE(EltCnt.Scalable);
|
||||
|
||||
// Check that fixed-length vector types aren't scalable.
|
||||
EVT V8i32 = EVT::getVectorVT(Ctx, MVT::i32, 8);
|
||||
ASSERT_FALSE(V8i32.isScalableVector());
|
||||
EVT V4f64 = EVT::getVectorVT(Ctx, MVT::f64, {4, false});
|
||||
ASSERT_FALSE(V4f64.isScalableVector());
|
||||
|
||||
// Check that MVT::ElementCount works for fixed-length types.
|
||||
EltCnt = V8i32.getVectorElementCount();
|
||||
EXPECT_EQ(EltCnt.Min, 8);
|
||||
ASSERT_FALSE(EltCnt.Scalable);
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user