diff --git a/include/llvm/Analysis/ScalarEvolution.h b/include/llvm/Analysis/ScalarEvolution.h index c180ce37e39..1bd7fd0db55 100644 --- a/include/llvm/Analysis/ScalarEvolution.h +++ b/include/llvm/Analysis/ScalarEvolution.h @@ -193,7 +193,7 @@ namespace llvm { /// \brief Returns the estimated complexity of this predicate. /// This is roughly measured in the number of run-time checks required. - virtual unsigned getComplexity() { return 1; } + virtual unsigned getComplexity() const { return 1; } /// \brief Returns true if the predicate is always true. This means that no /// assumptions were made and nothing needs to be checked at run-time. @@ -303,7 +303,7 @@ namespace llvm { /// \brief We estimate the complexity of a union predicate as the size /// number of predicates in the union. - unsigned getComplexity() override { return Preds.size(); } + unsigned getComplexity() const override { return Preds.size(); } /// Methods for support type inquiry through isa, cast, and dyn_cast: static inline bool classof(const SCEVPredicate *P) { diff --git a/include/llvm/Transforms/Utils/LoopVersioning.h b/include/llvm/Transforms/Utils/LoopVersioning.h index 41eb50c7662..3b70594e0b6 100644 --- a/include/llvm/Transforms/Utils/LoopVersioning.h +++ b/include/llvm/Transforms/Utils/LoopVersioning.h @@ -17,6 +17,7 @@ #define LLVM_TRANSFORMS_UTILS_LOOPVERSIONING_H #include "llvm/Analysis/LoopAccessAnalysis.h" +#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -25,6 +26,7 @@ namespace llvm { class Loop; class LoopAccessInfo; class LoopInfo; +class ScalarEvolution; /// \brief This class emits a version of the loop where run-time checks ensure /// that may-alias pointers can't overlap. @@ -33,16 +35,13 @@ class LoopInfo; /// already has a preheader. class LoopVersioning { public: - /// \brief Expects MemCheck, LoopAccessInfo, Loop, LoopInfo, DominatorTree - /// as input. It uses runtime check provided by user. - LoopVersioning(SmallVector Checks, - const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, - DominatorTree *DT); - /// \brief Expects LoopAccessInfo, Loop, LoopInfo, DominatorTree as input. - /// It uses default runtime check provided by LoopAccessInfo. - LoopVersioning(const LoopAccessInfo &LAInfo, Loop *L, LoopInfo *LI, - DominatorTree *DT); + /// It uses runtime check provided by the user. If \p UseLAIChecks is true, + /// we will retain the default checks made by LAI. Otherwise, construct an + /// object having no checks and we expect the user to add them. + LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, + DominatorTree *DT, ScalarEvolution *SE, + bool UseLAIChecks = true); /// \brief Performs the CFG manipulation part of versioning the loop including /// the DominatorTree and LoopInfo updates. @@ -72,6 +71,13 @@ public: /// loop may alias (i.e. one of the memchecks failed). Loop *getNonVersionedLoop() { return NonVersionedLoop; } + /// \brief Sets the runtime alias checks for versioning the loop. + void setAliasChecks( + const SmallVector Checks); + + /// \brief Sets the runtime SCEV checks for versioning the loop. + void setSCEVChecks(SCEVUnionPredicate Check); + private: /// \brief Adds the necessary PHI nodes for the versioned loops based on the /// loop-defined values used outside of the loop. @@ -91,13 +97,17 @@ private: /// in NonVersionedLoop. ValueToValueMapTy VMap; - /// \brief The set of checks that we are versioning for. - SmallVector Checks; + /// \brief The set of alias checks that we are versioning for. + SmallVector AliasChecks; + + /// \brief The set of SCEV checks that we are versioning for. + SCEVUnionPredicate Preds; /// \brief Analyses used. const LoopAccessInfo &LAI; LoopInfo *LI; DominatorTree *DT; + ScalarEvolution *SE; }; } diff --git a/lib/Transforms/Scalar/LoopDistribute.cpp b/lib/Transforms/Scalar/LoopDistribute.cpp index 1584f0fa3eb..67ebd2532b1 100644 --- a/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/lib/Transforms/Scalar/LoopDistribute.cpp @@ -55,6 +55,11 @@ static cl::opt DistributeNonIfConvertible( "if-convertible by the loop vectorizer"), cl::init(false)); +static cl::opt DistributeSCEVCheckThreshold( + "loop-distribute-scev-check-threshold", cl::init(8), cl::Hidden, + cl::desc("The maximum number of SCEV checks allowed for Loop " + "Distribution")); + STATISTIC(NumLoopsDistributed, "Number of loops distributed"); namespace { @@ -577,6 +582,7 @@ public: LI = &getAnalysis().getLoopInfo(); LAA = &getAnalysis(); DT = &getAnalysis().getDomTree(); + SE = &getAnalysis().getSE(); // Build up a worklist of inner-loops to vectorize. This is necessary as the // act of distributing a loop creates new loops and can invalidate iterators @@ -599,6 +605,7 @@ public: } void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); AU.addRequired(); AU.addPreserved(); AU.addRequired(); @@ -753,6 +760,13 @@ private: return false; } + // Don't distribute the loop if we need too many SCEV run-time checks. + const SCEVUnionPredicate &Pred = LAI.Preds; + if (Pred.getComplexity() > DistributeSCEVCheckThreshold) { + DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n"); + return false; + } + DEBUG(dbgs() << "\nDistributing loop: " << *L << "\n"); // We're done forming the partitions set up the reverse mapping from // instructions to partitions. @@ -764,17 +778,19 @@ private: if (!PH->getSinglePredecessor() || &*PH->begin() != PH->getTerminator()) SplitBlock(PH, PH->getTerminator(), DT, LI); - // If we need run-time checks to disambiguate pointers are run-time, version - // the loop now. + // If we need run-time checks, version the loop now. auto PtrToPartition = Partitions.computePartitionSetForPointers(LAI); const auto *RtPtrChecking = LAI.getRuntimePointerChecking(); const auto &AllChecks = RtPtrChecking->getChecks(); auto Checks = includeOnlyCrossPartitionChecks(AllChecks, PtrToPartition, RtPtrChecking); - if (!Checks.empty()) { + + if (!Pred.isAlwaysTrue() || !Checks.empty()) { DEBUG(dbgs() << "\nPointers:\n"); DEBUG(LAI.getRuntimePointerChecking()->printChecks(dbgs(), Checks)); - LoopVersioning LVer(std::move(Checks), LAI, L, LI, DT); + LoopVersioning LVer(LAI, L, LI, DT, SE, false); + LVer.setAliasChecks(std::move(Checks)); + LVer.setSCEVChecks(LAI.Preds); LVer.versionLoop(DefsUsedOutside); } @@ -801,6 +817,7 @@ private: LoopInfo *LI; LoopAccessAnalysis *LAA; DominatorTree *DT; + ScalarEvolution *SE; }; } // anonymous namespace @@ -811,6 +828,7 @@ INITIALIZE_PASS_BEGIN(LoopDistribute, LDIST_NAME, ldist_name, false, false) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopAccessAnalysis) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) INITIALIZE_PASS_END(LoopDistribute, LDIST_NAME, ldist_name, false, false) namespace llvm { diff --git a/lib/Transforms/Scalar/LoopLoadElimination.cpp b/lib/Transforms/Scalar/LoopLoadElimination.cpp index e0456a2110d..7c7bf64ba79 100644 --- a/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -41,6 +41,12 @@ static cl::opt CheckPerElim( cl::desc("Max number of memchecks allowed per eliminated load on average"), cl::init(1)); +static cl::opt LoadElimSCEVCheckThreshold( + "loop-load-elimination-scev-check-threshold", cl::init(8), cl::Hidden, + cl::desc("The maximum number of SCEV checks allowed for Loop " + "Load Elimination")); + + STATISTIC(NumLoopLoadEliminted, "Number of loads eliminated by LLE"); namespace { @@ -453,10 +459,17 @@ public: return false; } + if (LAI.Preds.getComplexity() > LoadElimSCEVCheckThreshold) { + DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n"); + return false; + } + // Point of no-return, start the transformation. First, version the loop if // necessary. - if (!Checks.empty()) { - LoopVersioning LV(std::move(Checks), LAI, L, LI, DT); + if (!Checks.empty() || !LAI.Preds.isAlwaysTrue()) { + LoopVersioning LV(LAI, L, LI, DT, SE, false); + LV.setAliasChecks(std::move(Checks)); + LV.setSCEVChecks(LAI.Preds); LV.versionLoop(); } diff --git a/lib/Transforms/Utils/LoopVersioning.cpp b/lib/Transforms/Utils/LoopVersioning.cpp index bf7ed73ff01..a77c3642a56 100644 --- a/lib/Transforms/Utils/LoopVersioning.cpp +++ b/lib/Transforms/Utils/LoopVersioning.cpp @@ -17,46 +17,78 @@ #include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/IR/Dominators.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" using namespace llvm; -LoopVersioning::LoopVersioning( - SmallVector Checks, - const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, DominatorTree *DT) - : VersionedLoop(L), NonVersionedLoop(nullptr), Checks(std::move(Checks)), - LAI(LAI), LI(LI), DT(DT) { +LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, + DominatorTree *DT, ScalarEvolution *SE, + bool UseLAIChecks) + : VersionedLoop(L), NonVersionedLoop(nullptr), LAI(LAI), LI(LI), DT(DT), + SE(SE) { assert(L->getExitBlock() && "No single exit block"); assert(L->getLoopPreheader() && "No preheader"); + if (UseLAIChecks) { + setAliasChecks(LAI.getRuntimePointerChecking()->getChecks()); + setSCEVChecks(LAI.Preds); + } } -LoopVersioning::LoopVersioning(const LoopAccessInfo &LAInfo, Loop *L, - LoopInfo *LI, DominatorTree *DT) - : VersionedLoop(L), NonVersionedLoop(nullptr), - Checks(LAInfo.getRuntimePointerChecking()->getChecks()), LAI(LAInfo), - LI(LI), DT(DT) { - assert(L->getExitBlock() && "No single exit block"); - assert(L->getLoopPreheader() && "No preheader"); +void LoopVersioning::setAliasChecks( + const SmallVector Checks) { + AliasChecks = std::move(Checks); +} + +void LoopVersioning::setSCEVChecks(SCEVUnionPredicate Check) { + Preds = std::move(Check); } void LoopVersioning::versionLoop( const SmallVectorImpl &DefsUsedOutside) { Instruction *FirstCheckInst; Instruction *MemRuntimeCheck; + Value *SCEVRuntimeCheck; + Value *RuntimeCheck = nullptr; + // Add the memcheck in the original preheader (this is empty initially). - BasicBlock *MemCheckBB = VersionedLoop->getLoopPreheader(); + BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader(); std::tie(FirstCheckInst, MemRuntimeCheck) = - LAI.addRuntimeChecks(MemCheckBB->getTerminator(), Checks); + LAI.addRuntimeChecks(RuntimeCheckBB->getTerminator(), AliasChecks); assert(MemRuntimeCheck && "called even though needsAnyChecking = false"); + const SCEVUnionPredicate &Pred = LAI.Preds; + SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(), + "scev.check"); + SCEVRuntimeCheck = + Exp.expandCodeForPredicate(&Pred, RuntimeCheckBB->getTerminator()); + auto *CI = dyn_cast(SCEVRuntimeCheck); + + // Discard the SCEV runtime check if it is always true. + if (CI && CI->isZero()) + SCEVRuntimeCheck = nullptr; + + if (MemRuntimeCheck && SCEVRuntimeCheck) { + RuntimeCheck = BinaryOperator::Create(Instruction::Or, MemRuntimeCheck, + SCEVRuntimeCheck, "ldist.safe"); + if (auto *I = dyn_cast(RuntimeCheck)) + I->insertBefore(RuntimeCheckBB->getTerminator()); + } else + RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : SCEVRuntimeCheck; + + assert(RuntimeCheck && "called even though we don't need " + "any runtime checks"); + // Rename the block to make the IR more readable. - MemCheckBB->setName(VersionedLoop->getHeader()->getName() + ".lver.memcheck"); + RuntimeCheckBB->setName(VersionedLoop->getHeader()->getName() + + ".lver.check"); // Create empty preheader for the loop (and after cloning for the // non-versioned loop). - BasicBlock *PH = SplitBlock(MemCheckBB, MemCheckBB->getTerminator(), DT, LI); + BasicBlock *PH = + SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI); PH->setName(VersionedLoop->getHeader()->getName() + ".ph"); // Clone the loop including the preheader. @@ -65,20 +97,19 @@ void LoopVersioning::versionLoop( // block is a join between the two loops. SmallVector NonVersionedLoopBlocks; NonVersionedLoop = - cloneLoopWithPreheader(PH, MemCheckBB, VersionedLoop, VMap, ".lver.orig", - LI, DT, NonVersionedLoopBlocks); + cloneLoopWithPreheader(PH, RuntimeCheckBB, VersionedLoop, VMap, + ".lver.orig", LI, DT, NonVersionedLoopBlocks); remapInstructionsInBlocks(NonVersionedLoopBlocks, VMap); // Insert the conditional branch based on the result of the memchecks. - Instruction *OrigTerm = MemCheckBB->getTerminator(); + Instruction *OrigTerm = RuntimeCheckBB->getTerminator(); BranchInst::Create(NonVersionedLoop->getLoopPreheader(), - VersionedLoop->getLoopPreheader(), MemRuntimeCheck, - OrigTerm); + VersionedLoop->getLoopPreheader(), RuntimeCheck, OrigTerm); OrigTerm->eraseFromParent(); // The loops merge in the original exit block. This is now dominated by the // memchecking block. - DT->changeImmediateDominator(VersionedLoop->getExitBlock(), MemCheckBB); + DT->changeImmediateDominator(VersionedLoop->getExitBlock(), RuntimeCheckBB); // Adds the necessary PHI nodes for the versioned loops based on the // loop-defined values used outside of the loop. diff --git a/test/Transforms/LoopDistribute/basic-with-memchecks.ll b/test/Transforms/LoopDistribute/basic-with-memchecks.ll index 3aced485041..dce5698595a 100644 --- a/test/Transforms/LoopDistribute/basic-with-memchecks.ll +++ b/test/Transforms/LoopDistribute/basic-with-memchecks.ll @@ -36,7 +36,7 @@ entry: ; Since the checks to A and A + 4 get merged, this will give us a ; total of 8 compares. ; -; CHECK: for.body.lver.memcheck: +; CHECK: for.body.lver.check: ; CHECK: = icmp ; CHECK: = icmp diff --git a/test/Transforms/LoopLoadElim/forward.ll b/test/Transforms/LoopLoadElim/forward.ll index 1a77297a064..c2b1816530c 100644 --- a/test/Transforms/LoopLoadElim/forward.ll +++ b/test/Transforms/LoopLoadElim/forward.ll @@ -11,7 +11,7 @@ target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128" define void @f(i32* %A, i32* %B, i32* %C, i64 %N) { -; CHECK: for.body.lver.memcheck: +; CHECK: for.body.lver.check: ; CHECK: %found.conflict{{.*}} = ; CHECK-NOT: %found.conflict{{.*}} = diff --git a/test/Transforms/LoopLoadElim/memcheck.ll b/test/Transforms/LoopLoadElim/memcheck.ll index ebb52825754..8eadd437a5a 100644 --- a/test/Transforms/LoopLoadElim/memcheck.ll +++ b/test/Transforms/LoopLoadElim/memcheck.ll @@ -16,7 +16,7 @@ define void @f(i32* %A, i32* %B, i32* %C, i64 %N, i32* %D) { entry: br label %for.body -; AGGRESSIVE: for.body.lver.memcheck: +; AGGRESSIVE: for.body.lver.check: ; AGGRESSIVE: %found.conflict{{.*}} = ; AGGRESSIVE: %found.conflict{{.*}} = ; AGGRESSIVE-NOT: %found.conflict{{.*}} =