diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h index 66a722ef8e15..59791738fdf6 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h @@ -215,12 +215,27 @@ LegalityPredicate isPointer(unsigned TypeIdx, unsigned AddrSpace); /// True iff the specified type index is a scalar that's narrower than the given /// size. LegalityPredicate narrowerThan(unsigned TypeIdx, unsigned Size); + /// True iff the specified type index is a scalar that's wider than the given /// size. LegalityPredicate widerThan(unsigned TypeIdx, unsigned Size); + +/// True iff the specified type index is a scalar or vector with an element type +/// that's narrower than the given size. +LegalityPredicate scalarOrEltNarrowerThan(unsigned TypeIdx, unsigned Size); + +/// True iff the specified type index is a scalar or a vector with an element +/// type that's wider than the given size. +LegalityPredicate scalarOrEltWiderThan(unsigned TypeIdx, unsigned Size); + /// True iff the specified type index is a scalar whose size is not a power of /// 2. LegalityPredicate sizeNotPow2(unsigned TypeIdx); + +/// True iff the specified type index is a scalar or vector whose element size +/// is not a power of 2. +LegalityPredicate scalarOrEltSizeNotPow2(unsigned TypeIdx); + /// True iff the specified type indices are both the same bit size. LegalityPredicate sameSize(unsigned TypeIdx0, unsigned TypeIdx1); /// True iff the specified MMO index has a size that is not a power of 2 @@ -237,10 +252,20 @@ LegalityPredicate atomicOrderingAtLeastOrStrongerThan(unsigned MMOIdx, namespace LegalizeMutations { /// Select this specific type for the given type index. LegalizeMutation changeTo(unsigned TypeIdx, LLT Ty); + /// Keep the same type as the given type index. LegalizeMutation changeTo(unsigned TypeIdx, unsigned FromTypeIdx); -/// Widen the type for the given type index to the next power of 2. -LegalizeMutation widenScalarToNextPow2(unsigned TypeIdx, unsigned Min = 0); + +/// Keep the same scalar or element type as the given type index. +LegalizeMutation changeElementTo(unsigned TypeIdx, unsigned FromTypeIdx); + +/// Keep the same scalar or element type as the given type. +LegalizeMutation changeElementTo(unsigned TypeIdx, LLT Ty); + +/// Widen the scalar type or vector element type for the given type index to the +/// next power of 2. +LegalizeMutation widenScalarOrEltToNextPow2(unsigned TypeIdx, unsigned Min = 0); + /// Add more elements to the type for the given type index to the next power of /// 2. LegalizeMutation moreElementsToNextPow2(unsigned TypeIdx, unsigned Min = 0); @@ -618,8 +643,19 @@ public: LegalizeRuleSet &widenScalarToNextPow2(unsigned TypeIdx, unsigned MinSize = 0) { using namespace LegalityPredicates; - return actionIf(LegalizeAction::WidenScalar, sizeNotPow2(typeIdx(TypeIdx)), - LegalizeMutations::widenScalarToNextPow2(TypeIdx, MinSize)); + return actionIf( + LegalizeAction::WidenScalar, sizeNotPow2(typeIdx(TypeIdx)), + LegalizeMutations::widenScalarOrEltToNextPow2(TypeIdx, MinSize)); + } + + /// Widen the scalar or vector element type to the next power of two that is + /// at least MinSize. No effect if the scalar size is a power of two. + LegalizeRuleSet &widenScalarOrEltToNextPow2(unsigned TypeIdx, + unsigned MinSize = 0) { + using namespace LegalityPredicates; + return actionIf( + LegalizeAction::WidenScalar, scalarOrEltSizeNotPow2(typeIdx(TypeIdx)), + LegalizeMutations::widenScalarOrEltToNextPow2(TypeIdx, MinSize)); } LegalizeRuleSet &narrowScalar(unsigned TypeIdx, LegalizeMutation Mutation) { @@ -634,6 +670,15 @@ public: LegalizeMutations::scalarize(TypeIdx)); } + /// Ensure the scalar is at least as wide as Ty. + LegalizeRuleSet &minScalarOrElt(unsigned TypeIdx, const LLT &Ty) { + using namespace LegalityPredicates; + using namespace LegalizeMutations; + return actionIf(LegalizeAction::WidenScalar, + scalarOrEltNarrowerThan(TypeIdx, Ty.getScalarSizeInBits()), + changeElementTo(typeIdx(TypeIdx), Ty)); + } + /// Ensure the scalar is at least as wide as Ty. LegalizeRuleSet &minScalar(unsigned TypeIdx, const LLT &Ty) { using namespace LegalityPredicates; @@ -643,6 +688,15 @@ public: changeTo(typeIdx(TypeIdx), Ty)); } + /// Ensure the scalar is at most as wide as Ty. + LegalizeRuleSet &maxScalarOrElt(unsigned TypeIdx, const LLT &Ty) { + using namespace LegalityPredicates; + using namespace LegalizeMutations; + return actionIf(LegalizeAction::NarrowScalar, + scalarOrEltWiderThan(TypeIdx, Ty.getScalarSizeInBits()), + changeElementTo(typeIdx(TypeIdx), Ty)); + } + /// Ensure the scalar is at most as wide as Ty. LegalizeRuleSet &maxScalar(unsigned TypeIdx, const LLT &Ty) { using namespace LegalityPredicates; @@ -659,12 +713,12 @@ public: const LLT &Ty) { using namespace LegalityPredicates; using namespace LegalizeMutations; - return actionIf(LegalizeAction::NarrowScalar, - [=](const LegalityQuery &Query) { - return widerThan(TypeIdx, Ty.getSizeInBits()) && - Predicate(Query); - }, - changeTo(typeIdx(TypeIdx), Ty)); + return actionIf( + LegalizeAction::NarrowScalar, + [=](const LegalityQuery &Query) { + return widerThan(TypeIdx, Ty.getSizeInBits()) && Predicate(Query); + }, + changeElementTo(typeIdx(TypeIdx), Ty)); } /// Limit the range of scalar sizes to MinTy and MaxTy. @@ -674,6 +728,12 @@ public: return minScalar(TypeIdx, MinTy).maxScalar(TypeIdx, MaxTy); } + /// Limit the range of scalar sizes to MinTy and MaxTy. + LegalizeRuleSet &clampScalarOrElt(unsigned TypeIdx, const LLT &MinTy, + const LLT &MaxTy) { + return minScalarOrElt(TypeIdx, MinTy).maxScalarOrElt(TypeIdx, MaxTy); + } + /// Widen the scalar to match the size of another. LegalizeRuleSet &minScalarSameAs(unsigned TypeIdx, unsigned LargeTypeIdx) { typeIdx(TypeIdx); diff --git a/llvm/include/llvm/Support/LowLevelTypeImpl.h b/llvm/include/llvm/Support/LowLevelTypeImpl.h index 86422500ac51..efe5c51d1a55 100644 --- a/llvm/include/llvm/Support/LowLevelTypeImpl.h +++ b/llvm/include/llvm/Support/LowLevelTypeImpl.h @@ -115,6 +115,22 @@ public: return isVector() ? getElementType() : *this; } + /// If this type is a vector, return a vector with the same number of elements + /// but the new element type. Otherwise, return the new element type. + LLT changeElementType(LLT NewEltTy) const { + return isVector() ? LLT::vector(getNumElements(), NewEltTy) : NewEltTy; + } + + /// If this type is a vector, return a vector with the same number of elements + /// but the new element size. Otherwise, return the new element type. Invalid + /// for pointer types. For pointer types, use changeElementType. + LLT changeElementSize(unsigned NewEltSize) const { + assert(!getScalarType().isPointer() && + "invalid to directly change element size for pointers"); + return isVector() ? LLT::vector(getNumElements(), NewEltSize) + : LLT::scalar(NewEltSize); + } + unsigned getScalarSizeInBits() const { assert(RawData != 0 && "Invalid Type"); if (!IsVector) { diff --git a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp index c2817a12352a..07e0cb662b58 100644 --- a/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp +++ b/llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp @@ -79,7 +79,7 @@ LegalityPredicate LegalityPredicates::isPointer(unsigned TypeIdx, LegalityPredicate LegalityPredicates::narrowerThan(unsigned TypeIdx, unsigned Size) { return [=](const LegalityQuery &Query) { - const LLT &QueryTy = Query.Types[TypeIdx]; + const LLT QueryTy = Query.Types[TypeIdx]; return QueryTy.isScalar() && QueryTy.getSizeInBits() < Size; }; } @@ -87,14 +87,37 @@ LegalityPredicate LegalityPredicates::narrowerThan(unsigned TypeIdx, LegalityPredicate LegalityPredicates::widerThan(unsigned TypeIdx, unsigned Size) { return [=](const LegalityQuery &Query) { - const LLT &QueryTy = Query.Types[TypeIdx]; + const LLT QueryTy = Query.Types[TypeIdx]; return QueryTy.isScalar() && QueryTy.getSizeInBits() > Size; }; } +LegalityPredicate LegalityPredicates::scalarOrEltNarrowerThan(unsigned TypeIdx, + unsigned Size) { + return [=](const LegalityQuery &Query) { + const LLT QueryTy = Query.Types[TypeIdx]; + return QueryTy.getScalarSizeInBits() < Size; + }; +} + +LegalityPredicate LegalityPredicates::scalarOrEltWiderThan(unsigned TypeIdx, + unsigned Size) { + return [=](const LegalityQuery &Query) { + const LLT QueryTy = Query.Types[TypeIdx]; + return QueryTy.getScalarSizeInBits() > Size; + }; +} + +LegalityPredicate LegalityPredicates::scalarOrEltSizeNotPow2(unsigned TypeIdx) { + return [=](const LegalityQuery &Query) { + const LLT QueryTy = Query.Types[TypeIdx]; + return !isPowerOf2_32(QueryTy.getScalarSizeInBits()); + }; +} + LegalityPredicate LegalityPredicates::sizeNotPow2(unsigned TypeIdx) { return [=](const LegalityQuery &Query) { - const LLT &QueryTy = Query.Types[TypeIdx]; + const LLT QueryTy = Query.Types[TypeIdx]; return QueryTy.isScalar() && !isPowerOf2_32(QueryTy.getSizeInBits()); }; } diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizeMutations.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizeMutations.cpp index 33228abcfb85..fcbecf90a845 100644 --- a/llvm/lib/CodeGen/GlobalISel/LegalizeMutations.cpp +++ b/llvm/lib/CodeGen/GlobalISel/LegalizeMutations.cpp @@ -26,14 +26,30 @@ LegalizeMutation LegalizeMutations::changeTo(unsigned TypeIdx, }; } -LegalizeMutation LegalizeMutations::widenScalarToNextPow2(unsigned TypeIdx, - unsigned Min) { +LegalizeMutation LegalizeMutations::changeElementTo(unsigned TypeIdx, + unsigned FromTypeIdx) { return [=](const LegalityQuery &Query) { - unsigned NewSizeInBits = - 1 << Log2_32_Ceil(Query.Types[TypeIdx].getSizeInBits()); - if (NewSizeInBits < Min) - NewSizeInBits = Min; - return std::make_pair(TypeIdx, LLT::scalar(NewSizeInBits)); + const LLT OldTy = Query.Types[TypeIdx]; + const LLT NewTy = Query.Types[FromTypeIdx]; + return std::make_pair(TypeIdx, OldTy.changeElementType(NewTy)); + }; +} + +LegalizeMutation LegalizeMutations::changeElementTo(unsigned TypeIdx, + LLT NewEltTy) { + return [=](const LegalityQuery &Query) { + const LLT OldTy = Query.Types[TypeIdx]; + return std::make_pair(TypeIdx, OldTy.changeElementType(NewEltTy)); + }; +} + +LegalizeMutation LegalizeMutations::widenScalarOrEltToNextPow2(unsigned TypeIdx, + unsigned Min) { + return [=](const LegalityQuery &Query) { + const LLT Ty = Query.Types[TypeIdx]; + unsigned NewEltSizeInBits = + std::max(1u << Log2_32_Ceil(Ty.getScalarSizeInBits()), Min); + return std::make_pair(TypeIdx, Ty.changeElementSize(NewEltSizeInBits)); }; } diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp index d25ed987526a..5ffab5ca96e4 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp @@ -435,29 +435,18 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST, // FIXME: Doesn't handle extract of illegal sizes. getActionDefinitionsBuilder({G_EXTRACT, G_INSERT}) - .legalIf([=](const LegalityQuery &Query) { + .legalIf([=](const LegalityQuery &Query) { const LLT &Ty0 = Query.Types[0]; const LLT &Ty1 = Query.Types[1]; return (Ty0.getSizeInBits() % 16 == 0) && (Ty1.getSizeInBits() % 16 == 0); }) - .widenScalarIf( - [=](const LegalityQuery &Query) { - const LLT &Ty1 = Query.Types[1]; - return (Ty1.getScalarSizeInBits() < 16); - }, - // TODO Use generic LegalizeMutation - [](const LegalityQuery &Query) { - LLT Ty1 = Query.Types[1]; - unsigned NewEltSizeInBits = - std::max(1 << Log2_32_Ceil(Ty1.getScalarSizeInBits()), 16); - if (Ty1.isVector()) { - return std::make_pair(1, LLT::vector(Ty1.getNumElements(), - NewEltSizeInBits)); - } - - return std::make_pair(1, LLT::scalar(NewEltSizeInBits)); - }); + .widenScalarIf( + [=](const LegalityQuery &Query) { + const LLT Ty1 = Query.Types[1]; + return (Ty1.getScalarSizeInBits() < 16); + }, + LegalizeMutations::widenScalarOrEltToNextPow2(1, 16)); // TODO: vectors of pointers getActionDefinitionsBuilder(G_BUILD_VECTOR) diff --git a/llvm/unittests/CodeGen/GlobalISel/LegalizerInfoTest.cpp b/llvm/unittests/CodeGen/GlobalISel/LegalizerInfoTest.cpp index 4578b95c203e..3617388c04e9 100644 --- a/llvm/unittests/CodeGen/GlobalISel/LegalizerInfoTest.cpp +++ b/llvm/unittests/CodeGen/GlobalISel/LegalizerInfoTest.cpp @@ -206,11 +206,21 @@ TEST(LegalizerInfoTest, SizeChangeStrategy) { TEST(LegalizerInfoTest, RuleSets) { using namespace TargetOpcode; + const LLT s5 = LLT::scalar(5); + const LLT s8 = LLT::scalar(8); + const LLT s16 = LLT::scalar(16); const LLT s32 = LLT::scalar(32); + const LLT s33 = LLT::scalar(33); + const LLT s64 = LLT::scalar(64); + const LLT v2s5 = LLT::vector(2, 5); + const LLT v2s8 = LLT::vector(2, 8); + const LLT v2s16 = LLT::vector(2, 16); const LLT v2s32 = LLT::vector(2, 32); const LLT v3s32 = LLT::vector(3, 32); const LLT v4s32 = LLT::vector(4, 32); + const LLT v2s33 = LLT::vector(2, 33); + const LLT v2s64 = LLT::vector(2, 64); const LLT p0 = LLT::pointer(0, 32); const LLT v3p0 = LLT::vector(3, p0); @@ -229,4 +239,120 @@ TEST(LegalizerInfoTest, RuleSets) { EXPECT_ACTION(MoreElements, 0, v4p0, LegalityQuery(G_IMPLICIT_DEF, {v3p0})); EXPECT_ACTION(MoreElements, 0, v4s32, LegalityQuery(G_IMPLICIT_DEF, {v3s32})); } + + // Test minScalarOrElt + { + LegalizerInfo LI; + LI.getActionDefinitionsBuilder(G_OR) + .legalFor({s32}) + .minScalarOrElt(0, s32); + LI.computeTables(); + + EXPECT_ACTION(WidenScalar, 0, s32, LegalityQuery(G_OR, {s16})); + EXPECT_ACTION(WidenScalar, 0, v2s32, LegalityQuery(G_OR, {v2s16})); + } + + // Test maxScalarOrELt + { + LegalizerInfo LI; + LI.getActionDefinitionsBuilder(G_AND) + .legalFor({s16}) + .maxScalarOrElt(0, s16); + LI.computeTables(); + + EXPECT_ACTION(NarrowScalar, 0, s16, LegalityQuery(G_AND, {s32})); + EXPECT_ACTION(NarrowScalar, 0, v2s16, LegalityQuery(G_AND, {v2s32})); + } + + // Test clampScalarOrElt + { + LegalizerInfo LI; + LI.getActionDefinitionsBuilder(G_XOR) + .legalFor({s16}) + .clampScalarOrElt(0, s16, s32); + LI.computeTables(); + + EXPECT_ACTION(NarrowScalar, 0, s32, LegalityQuery(G_XOR, {s64})); + EXPECT_ACTION(WidenScalar, 0, s16, LegalityQuery(G_XOR, {s8})); + + // Make sure the number of elements is preserved. + EXPECT_ACTION(NarrowScalar, 0, v2s32, LegalityQuery(G_XOR, {v2s64})); + EXPECT_ACTION(WidenScalar, 0, v2s16, LegalityQuery(G_XOR, {v2s8})); + } + + // Test minScalar + { + LegalizerInfo LI; + LI.getActionDefinitionsBuilder(G_OR) + .legalFor({s32}) + .minScalar(0, s32); + LI.computeTables(); + + // Only handle scalars, ignore vectors. + EXPECT_ACTION(WidenScalar, 0, s32, LegalityQuery(G_OR, {s16})); + EXPECT_ACTION(Unsupported, 0, LLT(), LegalityQuery(G_OR, {v2s16})); + } + + // Test maxScalar + { + LegalizerInfo LI; + LI.getActionDefinitionsBuilder(G_AND) + .legalFor({s16}) + .maxScalar(0, s16); + LI.computeTables(); + + // Only handle scalars, ignore vectors. + EXPECT_ACTION(NarrowScalar, 0, s16, LegalityQuery(G_AND, {s32})); + EXPECT_ACTION(Unsupported, 0, LLT(), LegalityQuery(G_AND, {v2s32})); + } + + // Test clampScalar + { + LegalizerInfo LI; + + LI.getActionDefinitionsBuilder(G_XOR) + .legalFor({s16}) + .clampScalar(0, s16, s32); + LI.computeTables(); + + EXPECT_ACTION(NarrowScalar, 0, s32, LegalityQuery(G_XOR, {s64})); + EXPECT_ACTION(WidenScalar, 0, s16, LegalityQuery(G_XOR, {s8})); + + // Only handle scalars, ignore vectors. + EXPECT_ACTION(Unsupported, 0, LLT(), LegalityQuery(G_XOR, {v2s64})); + EXPECT_ACTION(Unsupported, 0, LLT(), LegalityQuery(G_XOR, {v2s8})); + } + + // Test widenScalarOrEltToNextPow2 + { + LegalizerInfo LI; + + LI.getActionDefinitionsBuilder(G_AND) + .legalFor({s32}) + .widenScalarOrEltToNextPow2(0, 32); + LI.computeTables(); + + // Handle scalars and vectors + EXPECT_ACTION(WidenScalar, 0, s32, LegalityQuery(G_AND, {s5})); + EXPECT_ACTION(WidenScalar, 0, v2s32, LegalityQuery(G_AND, {v2s5})); + EXPECT_ACTION(WidenScalar, 0, s64, LegalityQuery(G_AND, {s33})); + EXPECT_ACTION(WidenScalar, 0, v2s64, LegalityQuery(G_AND, {v2s33})); + } + + // Test widenScalarToNextPow2 + { + LegalizerInfo LI; + + LI.getActionDefinitionsBuilder(G_AND) + .legalFor({s32}) + .widenScalarToNextPow2(0, 32); + LI.computeTables(); + + EXPECT_ACTION(WidenScalar, 0, s32, LegalityQuery(G_AND, {s5})); + EXPECT_ACTION(WidenScalar, 0, s64, LegalityQuery(G_AND, {s33})); + + // Do nothing for vectors. + EXPECT_ACTION(Unsupported, 0, LLT(), LegalityQuery(G_AND, {v2s5})); + EXPECT_ACTION(Unsupported, 0, LLT(), LegalityQuery(G_AND, {v2s33})); + } } diff --git a/llvm/unittests/CodeGen/LowLevelTypeTest.cpp b/llvm/unittests/CodeGen/LowLevelTypeTest.cpp index 52df852b4032..bf4277d82fd2 100644 --- a/llvm/unittests/CodeGen/LowLevelTypeTest.cpp +++ b/llvm/unittests/CodeGen/LowLevelTypeTest.cpp @@ -106,6 +106,61 @@ TEST(LowLevelTypeTest, ScalarOrVector) { LLT::scalarOrVector(2, LLT::pointer(1, 32))); } +TEST(LowLevelTypeTest, ChangeElementType) { + const LLT P0 = LLT::pointer(0, 32); + const LLT P1 = LLT::pointer(1, 64); + + const LLT S32 = LLT::scalar(32); + const LLT S64 = LLT::scalar(64); + + const LLT V2S32 = LLT::vector(2, 32); + const LLT V2S64 = LLT::vector(2, 64); + + const LLT V2P0 = LLT::vector(2, P0); + const LLT V2P1 = LLT::vector(2, P1); + + EXPECT_EQ(S64, S32.changeElementType(S64)); + EXPECT_EQ(S32, S32.changeElementType(S32)); + + EXPECT_EQ(S32, S64.changeElementSize(32)); + EXPECT_EQ(S32, S32.changeElementSize(32)); + + EXPECT_EQ(V2S64, V2S32.changeElementType(S64)); + EXPECT_EQ(V2S32, V2S64.changeElementType(S32)); + + EXPECT_EQ(V2S64, V2S32.changeElementSize(64)); + EXPECT_EQ(V2S32, V2S64.changeElementSize(32)); + + EXPECT_EQ(P0, S32.changeElementType(P0)); + EXPECT_EQ(S32, P0.changeElementType(S32)); + + EXPECT_EQ(V2P1, V2P0.changeElementType(P1)); + EXPECT_EQ(V2S32, V2P0.changeElementType(S32)); +} + +#ifdef GTEST_HAS_DEATH_TEST +#ifndef NDEBUG + +// Invalid to directly change the element size for pointers. +TEST(LowLevelTypeTest, ChangeElementTypeDeath) { + const LLT P0 = LLT::pointer(0, 32); + const LLT V2P0 = LLT::vector(2, P0); + + EXPECT_DEATH(P0.changeElementSize(64), + "invalid to directly change element size for pointers"); + EXPECT_DEATH(V2P0.changeElementSize(64), + "invalid to directly change element size for pointers"); + + // Make sure this still fails even without a change in size. + EXPECT_DEATH(P0.changeElementSize(32), + "invalid to directly change element size for pointers"); + EXPECT_DEATH(V2P0.changeElementSize(32), + "invalid to directly change element size for pointers"); +} + +#endif +#endif + TEST(LowLevelTypeTest, Pointer) { LLVMContext C; DataLayout DL("p64:64:64-p127:512:512:512-p16777215:65528:8");