From 14b16e7ee1320933fd01f4ec6d59e516374d4423 Mon Sep 17 00:00:00 2001 From: Matthew Simpson Date: Thu, 21 Jan 2016 16:31:55 +0000 Subject: [PATCH] [SLP] Truncate expressions to minimum required bit width This change attempts to produce vectorized integer expressions in bit widths that are narrower than their scalar counterparts. The need for demotion arises especially on architectures in which the small integer types (e.g., i8 and i16) are not legal for scalar operations but can still be used in vectors. Like similar work done within the loop vectorizer, we rely on InstCombine to perform the actual type-shrinking. We use the DemandedBits analysis and ComputeNumSignBits from ValueTracking to determine the minimum required bit width of an expression. Differential revision: http://reviews.llvm.org/D15815 llvm-svn: 258404 --- lib/Transforms/Vectorize/SLPVectorizer.cpp | 154 ++++++++++++++++-- .../SLPVectorizer/AArch64/gather-reduce.ll | 31 ++-- 2 files changed, 161 insertions(+), 24 deletions(-) diff --git a/lib/Transforms/Vectorize/SLPVectorizer.cpp b/lib/Transforms/Vectorize/SLPVectorizer.cpp index 2520c78b538..e65429c403c 100644 --- a/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -15,21 +15,22 @@ // "Loop-Aware SLP in GCC" by Ira Rosen, Dorit Nuzman, Ayal Zaks. // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Vectorize.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/DemandedBits.h" +#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" @@ -44,7 +45,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Analysis/VectorUtils.h" +#include "llvm/Transforms/Vectorize.h" #include #include #include @@ -363,11 +364,12 @@ public: BoUpSLP(Function *Func, ScalarEvolution *Se, TargetTransformInfo *Tti, TargetLibraryInfo *TLi, AliasAnalysis *Aa, LoopInfo *Li, - DominatorTree *Dt, AssumptionCache *AC) + DominatorTree *Dt, AssumptionCache *AC, DemandedBits *DB) : NumLoadsWantToKeepOrder(0), NumLoadsWantToChangeOrder(0), F(Func), - SE(Se), TTI(Tti), TLI(TLi), AA(Aa), LI(Li), DT(Dt), + SE(Se), TTI(Tti), TLI(TLi), AA(Aa), LI(Li), DT(Dt), AC(AC), DB(DB), Builder(Se->getContext()) { CodeMetrics::collectEphemeralValues(F, AC, EphValues); + MaxRequiredIntegerTy = nullptr; } /// \brief Vectorize the tree that starts with the elements in \p VL. @@ -399,6 +401,7 @@ public: BlockScheduling *BS = Iter.second.get(); BS->clear(); } + MaxRequiredIntegerTy = nullptr; } /// \returns true if the memory operations A and B are consecutive. @@ -419,6 +422,10 @@ public: /// vectorization factors. unsigned getVectorElementSize(Value *V); + /// Compute the maximum width integer type required to represent the result + /// of a scalar expression, if such a type exists. + void computeMaxRequiredIntegerTy(); + private: struct TreeEntry; @@ -924,8 +931,13 @@ private: AliasAnalysis *AA; LoopInfo *LI; DominatorTree *DT; + AssumptionCache *AC; + DemandedBits *DB; /// Instruction builder to construct the vectorized tree. IRBuilder<> Builder; + + // The maximum width integer type required to represent a scalar expression. + IntegerType *MaxRequiredIntegerTy; }; #ifndef NDEBUG @@ -1481,6 +1493,15 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { ScalarTy = SI->getValueOperand()->getType(); VectorType *VecTy = VectorType::get(ScalarTy, VL.size()); + // If we have computed a smaller type for the expression, update VecTy so + // that the costs will be accurate. + if (MaxRequiredIntegerTy) { + auto *IT = dyn_cast(ScalarTy); + assert(IT && "Computed smaller type for non-integer value?"); + if (MaxRequiredIntegerTy->getBitWidth() < IT->getBitWidth()) + VecTy = VectorType::get(MaxRequiredIntegerTy, VL.size()); + } + if (E->NeedToGather) { if (allConstant(VL)) return 0; @@ -1809,9 +1830,17 @@ int BoUpSLP::getTreeCost() { if (EphValues.count(EU.User)) continue; - VectorType *VecTy = VectorType::get(EU.Scalar->getType(), BundleWidth); - ExtractCost += TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, - EU.Lane); + // If we plan to rewrite the tree in a smaller type, we will need to sign + // extend the extracted value back to the original type. Here, we account + // for the extract and the added cost of the sign extend if needed. + auto *VecTy = VectorType::get(EU.Scalar->getType(), BundleWidth); + if (MaxRequiredIntegerTy) { + VecTy = VectorType::get(MaxRequiredIntegerTy, BundleWidth); + ExtractCost += TTI->getCastInstrCost( + Instruction::SExt, EU.Scalar->getType(), MaxRequiredIntegerTy); + } + ExtractCost += + TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, EU.Lane); } Cost += getSpillCost(); @@ -2566,7 +2595,19 @@ Value *BoUpSLP::vectorizeTree() { } Builder.SetInsertPoint(&F->getEntryBlock().front()); - vectorizeTree(&VectorizableTree[0]); + auto *VectorRoot = vectorizeTree(&VectorizableTree[0]); + + // If the vectorized tree can be rewritten in a smaller type, we truncate the + // vectorized root. InstCombine will then rewrite the entire expression. We + // sign extend the extracted values below. + if (MaxRequiredIntegerTy) { + BasicBlock::iterator I(cast(VectorRoot)); + Builder.SetInsertPoint(&*++I); + auto BundleWidth = VectorizableTree[0].Scalars.size(); + auto *SmallerTy = VectorType::get(MaxRequiredIntegerTy, BundleWidth); + auto *Trunc = Builder.CreateTrunc(VectorRoot, SmallerTy); + VectorizableTree[0].VectorizedValue = Trunc; + } DEBUG(dbgs() << "SLP: Extracting " << ExternalUses.size() << " values .\n"); @@ -2599,6 +2640,8 @@ Value *BoUpSLP::vectorizeTree() { if (PH->getIncomingValue(i) == Scalar) { Builder.SetInsertPoint(PH->getIncomingBlock(i)->getTerminator()); Value *Ex = Builder.CreateExtractElement(Vec, Lane); + if (MaxRequiredIntegerTy) + Ex = Builder.CreateSExt(Ex, Scalar->getType()); CSEBlocks.insert(PH->getIncomingBlock(i)); PH->setOperand(i, Ex); } @@ -2606,12 +2649,16 @@ Value *BoUpSLP::vectorizeTree() { } else { Builder.SetInsertPoint(cast(User)); Value *Ex = Builder.CreateExtractElement(Vec, Lane); + if (MaxRequiredIntegerTy) + Ex = Builder.CreateSExt(Ex, Scalar->getType()); CSEBlocks.insert(cast(User)->getParent()); User->replaceUsesOfWith(Scalar, Ex); } } else { Builder.SetInsertPoint(&F->getEntryBlock().front()); Value *Ex = Builder.CreateExtractElement(Vec, Lane); + if (MaxRequiredIntegerTy) + Ex = Builder.CreateSExt(Ex, Scalar->getType()); CSEBlocks.insert(&F->getEntryBlock()); User->replaceUsesOfWith(Scalar, Ex); } @@ -3180,7 +3227,7 @@ unsigned BoUpSLP::getVectorElementSize(Value *V) { // If the current instruction is a load, update MaxWidth to reflect the // width of the loaded value. else if (isa(I)) - MaxWidth = std::max(MaxWidth, (unsigned)DL.getTypeSizeInBits(Ty)); + MaxWidth = std::max(MaxWidth, DL.getTypeSizeInBits(Ty)); // Otherwise, we need to visit the operands of the instruction. We only // handle the interesting cases from buildTree here. If an operand is an @@ -3207,6 +3254,85 @@ unsigned BoUpSLP::getVectorElementSize(Value *V) { return MaxWidth; } +void BoUpSLP::computeMaxRequiredIntegerTy() { + + // If there are no external uses, the expression tree must be rooted by a + // store. We can't demote in-memory values, so there is nothing to do here. + if (ExternalUses.empty()) + return; + + // If the expression is not rooted by a store, these roots should have + // external uses. We will rely on InstCombine to rewrite the expression in + // the narrower type. However, InstCombine only rewrites single-use values. + // This means that if a tree entry other than a root is used externally, it + // must have multiple uses and InstCombine will not rewrite it. The code + // below ensures that only the roots are used externally. + auto &TreeRoot = VectorizableTree[0].Scalars; + SmallPtrSet ScalarRoots(TreeRoot.begin(), TreeRoot.end()); + for (auto &EU : ExternalUses) + if (!ScalarRoots.erase(EU.Scalar)) + return; + if (!ScalarRoots.empty()) + return; + + // The maximum bit width required to represent all the instructions in the + // tree without loss of precision. It would be safe to truncate the + // expression to this width. + auto MaxBitWidth = 8u; + + // We first check if all the bits of the root are demanded. If they're not, + // we can truncate the root to this narrower type. + auto *Root = dyn_cast(TreeRoot[0]); + if (!Root || !isa(Root->getType()) || !Root->hasOneUse()) + return; + auto Mask = DB->getDemandedBits(Root); + if (Mask.countLeadingZeros() > 0) + MaxBitWidth = Mask.getBitWidth() - Mask.countLeadingZeros(); + + // If all the bits of the root are demanded, we can try a little harder to + // compute a narrower type. This can happen, for example, if the roots are + // getelementptr indices. InstCombine promotes these indices to the pointer + // width. Thus, all their bits are technically demanded even though the + // address computation might be vectorized in a smaller type. We start by + // looking at each entry in the tree. + else + for (auto &Entry : VectorizableTree) { + + // Get a representative value for the vectorizable bundle. All values in + // Entry.Scalars should be isomorphic. + auto *Scalar = Entry.Scalars[0]; + + // If the scalar is used more than once, InstCombine will not rewrite it, + // so we should give up. + if (!Scalar->hasOneUse()) + return; + + // We only compute smaller integer types. If the scalar has a different + // type, give up. + auto *IT = dyn_cast(Scalar->getType()); + if (!IT) + return; + + // Compute the maximum bit width required to store the scalar. We use + // ValueTracking to compute the number of high-order bits we can + // truncate. We then round up to the next power-of-two. + auto &DL = F->getParent()->getDataLayout(); + auto NumSignBits = ComputeNumSignBits(Scalar, DL, 0, AC, 0, DT); + auto NumTypeBits = IT->getBitWidth(); + MaxBitWidth = std::max(NumTypeBits - NumSignBits, MaxBitWidth); + } + + // Round up to the next power-of-two. + if (!isPowerOf2_64(MaxBitWidth)) + MaxBitWidth = NextPowerOf2(MaxBitWidth); + + // If the maximum bit width we compute is less than the with of the roots' + // type, we can proceed with the narrowing. Otherwise, do nothing. + auto *RootIT = cast(TreeRoot[0]->getType()); + if (MaxBitWidth > 0 && MaxBitWidth < RootIT->getBitWidth()) + MaxRequiredIntegerTy = IntegerType::get(F->getContext(), MaxBitWidth); +} + /// The SLPVectorizer Pass. struct SLPVectorizer : public FunctionPass { typedef SmallVector StoreList; @@ -3228,6 +3354,7 @@ struct SLPVectorizer : public FunctionPass { LoopInfo *LI; DominatorTree *DT; AssumptionCache *AC; + DemandedBits *DB; bool runOnFunction(Function &F) override { if (skipOptnoneFunction(F)) @@ -3241,6 +3368,7 @@ struct SLPVectorizer : public FunctionPass { LI = &getAnalysis().getLoopInfo(); DT = &getAnalysis().getDomTree(); AC = &getAnalysis().getAssumptionCache(F); + DB = &getAnalysis(); Stores.clear(); GEPs.clear(); @@ -3270,7 +3398,7 @@ struct SLPVectorizer : public FunctionPass { // Use the bottom up slp vectorizer to construct chains that start with // store instructions. - BoUpSLP R(&F, SE, TTI, TLI, AA, LI, DT, AC); + BoUpSLP R(&F, SE, TTI, TLI, AA, LI, DT, AC, DB); // A general note: the vectorizer must use BoUpSLP::eraseInstruction() to // delete instructions. @@ -3313,6 +3441,7 @@ struct SLPVectorizer : public FunctionPass { AU.addRequired(); AU.addRequired(); AU.addRequired(); + AU.addRequired(); AU.addPreserved(); AU.addPreserved(); AU.addPreserved(); @@ -3417,6 +3546,7 @@ bool SLPVectorizer::vectorizeStoreChain(ArrayRef Chain, ArrayRef Operands = Chain.slice(i, VF); R.buildTree(Operands); + R.computeMaxRequiredIntegerTy(); int Cost = R.getTreeCost(); @@ -3616,6 +3746,7 @@ bool SLPVectorizer::tryToVectorizeList(ArrayRef VL, BoUpSLP &R, Value *ReorderedOps[] = { Ops[1], Ops[0] }; R.buildTree(ReorderedOps, None); } + R.computeMaxRequiredIntegerTy(); int Cost = R.getTreeCost(); if (Cost < -SLPCostThreshold) { @@ -3882,6 +4013,7 @@ public: for (; i < NumReducedVals - ReduxWidth + 1; i += ReduxWidth) { V.buildTree(makeArrayRef(&ReducedVals[i], ReduxWidth), ReductionOps); + V.computeMaxRequiredIntegerTy(); // Estimate cost. int Cost = V.getTreeCost() + getReductionCost(TTI, ReducedVals[i]); diff --git a/test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll b/test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll index 59ceba1717a..9c06b24163a 100644 --- a/test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll +++ b/test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll @@ -1,4 +1,5 @@ -; RUN: opt -S -slp-vectorizer -dce -instcombine < %s | FileCheck %s +; RUN: opt -S -slp-vectorizer -dce -instcombine < %s | FileCheck %s --check-prefix=PROFITABLE +; RUN: opt -S -slp-vectorizer -slp-threshold=-12 -dce -instcombine < %s | FileCheck %s --check-prefix=UNPROFITABLE target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128" target triple = "aarch64--linux-gnu" @@ -18,13 +19,13 @@ target triple = "aarch64--linux-gnu" ; return sum; ; } -; CHECK-LABEL: @gather_reduce_8x16_i32 +; PROFITABLE-LABEL: @gather_reduce_8x16_i32 ; -; CHECK: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16> -; CHECK: zext <8 x i16> [[L]] to <8 x i32> -; CHECK: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32> -; CHECK: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]] -; CHECK: sext i32 [[X]] to i64 +; PROFITABLE: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16> +; PROFITABLE: zext <8 x i16> [[L]] to <8 x i32> +; PROFITABLE: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32> +; PROFITABLE: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]] +; PROFITABLE: sext i32 [[X]] to i64 ; define i32 @gather_reduce_8x16_i32(i16* nocapture readonly %a, i16* nocapture readonly %b, i16* nocapture readonly %g, i32 %n) { entry: @@ -137,14 +138,18 @@ for.body: br i1 %exitcond, label %for.cond.cleanup.loopexit, label %for.body } -; CHECK-LABEL: @gather_reduce_8x16_i64 +; UNPROFITABLE-LABEL: @gather_reduce_8x16_i64 ; -; CHECK-NOT: load <8 x i16> +; UNPROFITABLE: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16> +; UNPROFITABLE: zext <8 x i16> [[L]] to <8 x i32> +; UNPROFITABLE: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32> +; UNPROFITABLE: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]] +; UNPROFITABLE: sext i32 [[X]] to i64 ; -; FIXME: We are currently unable to vectorize the case with i64 subtraction -; because the zero extensions are too expensive. The solution here is to -; convert the i64 subtractions to i32 subtractions during vectorization. -; This would then match the case above. +; TODO: Although we can now vectorize this case while converting the i64 +; subtractions to i32, the cost model currently finds vectorization to be +; unprofitable. The cost model is penalizing the sign and zero +; extensions in the vectorized version, but they are actually free. ; define i32 @gather_reduce_8x16_i64(i16* nocapture readonly %a, i16* nocapture readonly %b, i16* nocapture readonly %g, i32 %n) { entry: