diff --git a/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp index acec94ecd05..0d6ae4c72fd 100644 --- a/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp +++ b/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp @@ -43,8 +43,8 @@ namespace { // TODO: Remove this static const unsigned TargetBaseAlign = 4; -typedef SmallVector ValueList; -typedef MapVector ValueListMap; +typedef SmallVector InstrList; +typedef MapVector InstrListMap; class Vectorizer { Function &F; @@ -92,17 +92,17 @@ private: /// Returns the first and the last instructions in Chain. std::pair - getBoundaryInstrs(ArrayRef Chain); + getBoundaryInstrs(ArrayRef Chain); /// Erases the original instructions after vectorizing. - void eraseInstructions(ArrayRef Chain); + void eraseInstructions(ArrayRef Chain); /// "Legalize" the vector type that would be produced by combining \p /// ElementSizeBits elements in \p Chain. Break into two pieces such that the /// total size of each piece is 1, 2 or a multiple of 4 bytes. \p Chain is /// expected to have more than 4 elements. - std::pair, ArrayRef> - splitOddVectorElts(ArrayRef Chain, unsigned ElementSizeBits); + std::pair, ArrayRef> + splitOddVectorElts(ArrayRef Chain, unsigned ElementSizeBits); /// Finds the largest prefix of Chain that's vectorizable, checking for /// intervening instructions which may affect the memory accessed by the @@ -110,25 +110,27 @@ private: /// /// The elements of \p Chain must be all loads or all stores and must be in /// address order. - ArrayRef getVectorizablePrefix(ArrayRef Chain); + ArrayRef getVectorizablePrefix(ArrayRef Chain); /// Collects load and store instructions to vectorize. - std::pair collectInstructions(BasicBlock *BB); + std::pair collectInstructions(BasicBlock *BB); - /// Processes the collected instructions, the \p Map. The elements of \p Map + /// Processes the collected instructions, the \p Map. The values of \p Map /// should be all loads or all stores. - bool vectorizeChains(ValueListMap &Map); + bool vectorizeChains(InstrListMap &Map); /// Finds the load/stores to consecutive memory addresses and vectorizes them. - bool vectorizeInstructions(ArrayRef Instrs); + bool vectorizeInstructions(ArrayRef Instrs); /// Vectorizes the load instructions in Chain. - bool vectorizeLoadChain(ArrayRef Chain, - SmallPtrSet *InstructionsProcessed); + bool + vectorizeLoadChain(ArrayRef Chain, + SmallPtrSet *InstructionsProcessed); /// Vectorizes the store instructions in Chain. - bool vectorizeStoreChain(ArrayRef Chain, - SmallPtrSet *InstructionsProcessed); + bool + vectorizeStoreChain(ArrayRef Chain, + SmallPtrSet *InstructionsProcessed); /// Check if this load/store access is misaligned accesses bool accessIsMisaligned(unsigned SzInBytes, unsigned AddressSpace, @@ -175,6 +177,13 @@ Pass *llvm::createLoadStoreVectorizerPass() { return new LoadStoreVectorizer(); } +// The real propagateMetadata expects a SmallVector, but we deal in +// vectors of Instructions. +static void propagateMetadata(Instruction *I, ArrayRef IL) { + SmallVector VL(IL.begin(), IL.end()); + propagateMetadata(I, VL); +} + bool LoadStoreVectorizer::runOnFunction(Function &F) { // Don't vectorize when the attribute NoImplicitFloat is used. if (skipFunction(F) || F.hasFnAttribute(Attribute::NoImplicitFloat)) @@ -196,7 +205,7 @@ bool Vectorizer::run() { // Scan the blocks in the function in post order. for (BasicBlock *BB : post_order(&F)) { - ValueListMap LoadRefs, StoreRefs; + InstrListMap LoadRefs, StoreRefs; std::tie(LoadRefs, StoreRefs) = collectInstructions(BB); Changed |= vectorizeChains(LoadRefs); Changed |= vectorizeChains(StoreRefs); @@ -371,8 +380,8 @@ void Vectorizer::reorder(Instruction *I) { } std::pair -Vectorizer::getBoundaryInstrs(ArrayRef Chain) { - Instruction *C0 = cast(Chain[0]); +Vectorizer::getBoundaryInstrs(ArrayRef Chain) { + Instruction *C0 = Chain[0]; BasicBlock::iterator FirstInstr = C0->getIterator(); BasicBlock::iterator LastInstr = C0->getIterator(); @@ -396,26 +405,24 @@ Vectorizer::getBoundaryInstrs(ArrayRef Chain) { return std::make_pair(FirstInstr, ++LastInstr); } -void Vectorizer::eraseInstructions(ArrayRef Chain) { +void Vectorizer::eraseInstructions(ArrayRef Chain) { SmallVector Instrs; - for (Value *V : Chain) { - Value *PtrOperand = getPointerOperand(V); + for (Instruction *I : Chain) { + Value *PtrOperand = getPointerOperand(I); assert(PtrOperand && "Instruction must have a pointer operand."); - Instrs.push_back(cast(V)); + Instrs.push_back(I); if (GetElementPtrInst *GEP = dyn_cast(PtrOperand)) Instrs.push_back(GEP); } // Erase instructions. - for (Value *V : Instrs) { - Instruction *Instr = cast(V); - if (Instr->use_empty()) - Instr->eraseFromParent(); - } + for (Instruction *I : Instrs) + if (I->use_empty()) + I->eraseFromParent(); } -std::pair, ArrayRef> -Vectorizer::splitOddVectorElts(ArrayRef Chain, +std::pair, ArrayRef> +Vectorizer::splitOddVectorElts(ArrayRef Chain, unsigned ElementSizeBits) { unsigned ElemSizeInBytes = ElementSizeBits / 8; unsigned SizeInBytes = ElemSizeInBytes * Chain.size(); @@ -424,19 +431,20 @@ Vectorizer::splitOddVectorElts(ArrayRef Chain, return std::make_pair(Chain.slice(0, NumLeft), Chain.slice(NumLeft)); } -ArrayRef Vectorizer::getVectorizablePrefix(ArrayRef Chain) { +ArrayRef +Vectorizer::getVectorizablePrefix(ArrayRef Chain) { // These are in BB order, unlike Chain, which is in address order. - SmallVector, 16> MemoryInstrs; - SmallVector, 16> ChainInstrs; + SmallVector, 16> MemoryInstrs; + SmallVector, 16> ChainInstrs; bool IsLoadChain = isa(Chain[0]); DEBUG({ - for (Value *V : Chain) { + for (Instruction *I : Chain) { if (IsLoadChain) - assert(isa(V) && + assert(isa(I) && "All elements of Chain must be loads, or all must be stores."); else - assert(isa(V) && + assert(isa(I) && "All elements of Chain must be loads, or all must be stores."); } }); @@ -463,11 +471,11 @@ ArrayRef Vectorizer::getVectorizablePrefix(ArrayRef Chain) { unsigned ChainInstrIdx, ChainInstrsLen; for (ChainInstrIdx = 0, ChainInstrsLen = ChainInstrs.size(); ChainInstrIdx < ChainInstrsLen; ++ChainInstrIdx) { - Value *ChainInstr = ChainInstrs[ChainInstrIdx].first; + Instruction *ChainInstr = ChainInstrs[ChainInstrIdx].first; unsigned ChainInstrLoc = ChainInstrs[ChainInstrIdx].second; bool AliasFound = false; for (auto EntryMem : MemoryInstrs) { - Value *MemInstr = EntryMem.first; + Instruction *MemInstr = EntryMem.first; unsigned MemInstrLoc = EntryMem.second; if (isa(MemInstr) && isa(ChainInstr)) continue; @@ -485,20 +493,16 @@ ArrayRef Vectorizer::getVectorizablePrefix(ArrayRef Chain) { ChainInstrLoc > MemInstrLoc) continue; - Instruction *M0 = cast(MemInstr); - Instruction *M1 = cast(ChainInstr); - - if (!AA.isNoAlias(MemoryLocation::get(M0), MemoryLocation::get(M1))) { + if (!AA.isNoAlias(MemoryLocation::get(MemInstr), + MemoryLocation::get(ChainInstr))) { DEBUG({ - Value *Ptr0 = getPointerOperand(M0); - Value *Ptr1 = getPointerOperand(M1); dbgs() << "LSV: Found alias:\n" " Aliasing instruction and pointer:\n" << " " << *MemInstr << '\n' - << " " << *Ptr0 << '\n' + << " " << *getPointerOperand(MemInstr) << '\n' << " Aliased instruction and pointer:\n" << " " << *ChainInstr << '\n' - << " " << *Ptr1 << '\n'; + << " " << *getPointerOperand(ChainInstr) << '\n'; }); AliasFound = true; break; @@ -516,18 +520,20 @@ ArrayRef Vectorizer::getVectorizablePrefix(ArrayRef Chain) { makeArrayRef(ChainInstrs.data(), ChainInstrIdx); unsigned ChainIdx, ChainLen; for (ChainIdx = 0, ChainLen = Chain.size(); ChainIdx < ChainLen; ++ChainIdx) { - Value *V = Chain[ChainIdx]; + Instruction *I = Chain[ChainIdx]; if (!any_of(VectorizableChainInstrs, - [V](std::pair CI) { return V == CI.first; })) + [I](std::pair CI) { + return I == CI.first; + })) break; } return Chain.slice(0, ChainIdx); } -std::pair +std::pair Vectorizer::collectInstructions(BasicBlock *BB) { - ValueListMap LoadRefs; - ValueListMap StoreRefs; + InstrListMap LoadRefs; + InstrListMap StoreRefs; for (Instruction &I : *BB) { if (!I.mayReadOrWriteMemory()) @@ -557,9 +563,8 @@ Vectorizer::collectInstructions(BasicBlock *BB) { // Make sure all the users of a vector are constant-index extracts. if (isa(Ty) && !all_of(LI->users(), [LI](const User *U) { - const Instruction *UI = cast(U); - return isa(UI) && - isa(UI->getOperand(1)); + const ExtractElementInst *EEI = dyn_cast(U); + return EEI && isa(EEI->getOperand(1)); })) continue; @@ -590,9 +595,8 @@ Vectorizer::collectInstructions(BasicBlock *BB) { continue; if (isa(Ty) && !all_of(SI->users(), [SI](const User *U) { - const Instruction *UI = cast(U); - return isa(UI) && - isa(UI->getOperand(1)); + const ExtractElementInst *EEI = dyn_cast(U); + return EEI && isa(EEI->getOperand(1)); })) continue; @@ -605,10 +609,10 @@ Vectorizer::collectInstructions(BasicBlock *BB) { return {LoadRefs, StoreRefs}; } -bool Vectorizer::vectorizeChains(ValueListMap &Map) { +bool Vectorizer::vectorizeChains(InstrListMap &Map) { bool Changed = false; - for (const std::pair &Chain : Map) { + for (const std::pair &Chain : Map) { unsigned Size = Chain.second.size(); if (Size < 2) continue; @@ -618,7 +622,7 @@ bool Vectorizer::vectorizeChains(ValueListMap &Map) { // Process the stores in chunks of 64. for (unsigned CI = 0, CE = Size; CI < CE; CI += 64) { unsigned Len = std::min(CE - CI, 64); - ArrayRef Chunk(&Chain.second[CI], Len); + ArrayRef Chunk(&Chain.second[CI], Len); Changed |= vectorizeInstructions(Chunk); } } @@ -626,7 +630,7 @@ bool Vectorizer::vectorizeChains(ValueListMap &Map) { return Changed; } -bool Vectorizer::vectorizeInstructions(ArrayRef Instrs) { +bool Vectorizer::vectorizeInstructions(ArrayRef Instrs) { DEBUG(dbgs() << "LSV: Vectorizing " << Instrs.size() << " instructions.\n"); SmallSetVector Heads, Tails; int ConsecutiveChain[64]; @@ -655,7 +659,7 @@ bool Vectorizer::vectorizeInstructions(ArrayRef Instrs) { } bool Changed = false; - SmallPtrSet InstructionsProcessed; + SmallPtrSet InstructionsProcessed; for (int Head : Heads) { if (InstructionsProcessed.count(Instrs[Head])) @@ -672,7 +676,7 @@ bool Vectorizer::vectorizeInstructions(ArrayRef Instrs) { // We found an instr that starts a chain. Now follow the chain and try to // vectorize it. - SmallVector Operands; + SmallVector Operands; int I = Head; while (I != -1 && (Tails.count(I) || Heads.count(I))) { if (InstructionsProcessed.count(Instrs[I])) @@ -695,13 +699,14 @@ bool Vectorizer::vectorizeInstructions(ArrayRef Instrs) { } bool Vectorizer::vectorizeStoreChain( - ArrayRef Chain, SmallPtrSet *InstructionsProcessed) { + ArrayRef Chain, + SmallPtrSet *InstructionsProcessed) { StoreInst *S0 = cast(Chain[0]); // If the vector has an int element, default to int for the whole load. Type *StoreTy; - for (const auto &V : Chain) { - StoreTy = cast(V)->getValueOperand()->getType(); + for (Instruction *I : Chain) { + StoreTy = cast(I)->getValueOperand()->getType(); if (StoreTy->isIntOrIntVectorTy()) break; @@ -723,7 +728,7 @@ bool Vectorizer::vectorizeStoreChain( return false; } - ArrayRef NewChain = getVectorizablePrefix(Chain); + ArrayRef NewChain = getVectorizablePrefix(Chain); if (NewChain.empty()) { // No vectorization possible. InstructionsProcessed->insert(Chain.begin(), Chain.end()); @@ -773,8 +778,8 @@ bool Vectorizer::vectorizeStoreChain( DEBUG({ dbgs() << "LSV: Stores to vectorize:\n"; - for (Value *V : Chain) - dbgs() << " " << *V << "\n"; + for (Instruction *I : Chain) + dbgs() << " " << *I << "\n"; }); // We won't try again to vectorize the elements of the chain, regardless of @@ -836,9 +841,11 @@ bool Vectorizer::vectorizeStoreChain( } } - Value *Bitcast = - Builder.CreateBitCast(S0->getPointerOperand(), VecTy->getPointerTo(AS)); - StoreInst *SI = cast(Builder.CreateStore(Vec, Bitcast)); + // This cast is safe because Builder.CreateStore() always creates a bona fide + // StoreInst. + StoreInst *SI = cast( + Builder.CreateStore(Vec, Builder.CreateBitCast(S0->getPointerOperand(), + VecTy->getPointerTo(AS)))); propagateMetadata(SI, Chain); SI->setAlignment(Alignment); @@ -849,7 +856,8 @@ bool Vectorizer::vectorizeStoreChain( } bool Vectorizer::vectorizeLoadChain( - ArrayRef Chain, SmallPtrSet *InstructionsProcessed) { + ArrayRef Chain, + SmallPtrSet *InstructionsProcessed) { LoadInst *L0 = cast(Chain[0]); // If the vector has an int element, default to int for the whole load. @@ -877,7 +885,7 @@ bool Vectorizer::vectorizeLoadChain( return false; } - ArrayRef NewChain = getVectorizablePrefix(Chain); + ArrayRef NewChain = getVectorizablePrefix(Chain); if (NewChain.empty()) { // No vectorization possible. InstructionsProcessed->insert(Chain.begin(), Chain.end()); @@ -949,8 +957,8 @@ bool Vectorizer::vectorizeLoadChain( DEBUG({ dbgs() << "LSV: Loads to vectorize:\n"; - for (Value *V : Chain) - V->dump(); + for (Instruction *I : Chain) + I->dump(); }); // getVectorizablePrefix already computed getBoundaryInstrs. The value of @@ -962,7 +970,8 @@ bool Vectorizer::vectorizeLoadChain( Value *Bitcast = Builder.CreateBitCast(L0->getPointerOperand(), VecTy->getPointerTo(AS)); - + // This cast is safe because Builder.CreateLoad always creates a bona fide + // LoadInst. LoadInst *LI = cast(Builder.CreateLoad(Bitcast)); propagateMetadata(LI, Chain); LI->setAlignment(Alignment); @@ -973,17 +982,17 @@ bool Vectorizer::vectorizeLoadChain( unsigned VecWidth = VecLoadTy->getNumElements(); for (unsigned I = 0, E = Chain.size(); I != E; ++I) { for (auto Use : Chain[I]->users()) { + // All users of vector loads are ExtractElement instructions with + // constant indices, otherwise we would have bailed before now. Instruction *UI = cast(Use); unsigned Idx = cast(UI->getOperand(1))->getZExtValue(); unsigned NewIdx = Idx + I * VecWidth; Value *V = Builder.CreateExtractElement(LI, Builder.getInt32(NewIdx)); - Instruction *Extracted = cast(V); - if (Extracted->getType() != UI->getType()) - Extracted = cast( - Builder.CreateBitCast(Extracted, UI->getType())); + if (V->getType() != UI->getType()) + V = Builder.CreateBitCast(V, UI->getType()); // Replace the old instruction. - UI->replaceAllUsesWith(Extracted); + UI->replaceAllUsesWith(V); InstrsToErase.push_back(UI); } } @@ -998,15 +1007,13 @@ bool Vectorizer::vectorizeLoadChain( } else { for (unsigned I = 0, E = Chain.size(); I != E; ++I) { Value *V = Builder.CreateExtractElement(LI, Builder.getInt32(I)); - Instruction *Extracted = cast(V); - Instruction *UI = cast(Chain[I]); - if (Extracted->getType() != UI->getType()) { - Extracted = cast( - Builder.CreateBitOrPointerCast(Extracted, UI->getType())); + Value *CV = Chain[I]; + if (V->getType() != CV->getType()) { + V = Builder.CreateBitOrPointerCast(V, CV->getType()); } // Replace the old instruction. - UI->replaceAllUsesWith(Extracted); + CV->replaceAllUsesWith(V); } if (Instruction *BitcastInst = dyn_cast(Bitcast))