diff --git a/include/llvm/Transforms/Utils/MemorySSA.h b/include/llvm/Transforms/Utils/MemorySSA.h index 84892e5d3ad..408c6a157cd 100644 --- a/include/llvm/Transforms/Utils/MemorySSA.h +++ b/include/llvm/Transforms/Utils/MemorySSA.h @@ -578,6 +578,15 @@ public: MemoryAccess *Definition, MemoryAccess *InsertPt); + // \brief Splice \p What to just before \p Where. + // + // In order to be efficient, the following conditions must be met: + // - \p Where dominates \p What, + // - All memory accesses in [\p Where, \p What) are no-alias with \p What. + // + // TODO: relax the MemoryDef requirement on Where. + void spliceMemoryAccessAbove(MemoryDef *Where, MemoryUseOrDef *What); + /// \brief Remove a MemoryAccess from MemorySSA, including updating all /// definitions and uses. /// This should be called when a memory instruction that has a MemoryAccess diff --git a/lib/Transforms/Utils/MemorySSA.cpp b/lib/Transforms/Utils/MemorySSA.cpp index c887f86ca17..1ce4225f09c 100644 --- a/lib/Transforms/Utils/MemorySSA.cpp +++ b/lib/Transforms/Utils/MemorySSA.cpp @@ -1623,6 +1623,29 @@ MemoryUseOrDef *MemorySSA::createMemoryAccessAfter(Instruction *I, return NewAccess; } +void MemorySSA::spliceMemoryAccessAbove(MemoryDef *Where, + MemoryUseOrDef *What) { + assert(What != getLiveOnEntryDef() && + Where != getLiveOnEntryDef() && "Can't splice (above) LOE."); + assert(dominates(Where, What) && "Only upwards splices are permitted."); + + if (Where == What) + return; + if (isa(What)) { + // TODO: possibly use removeMemoryAccess' more efficient RAUW + What->replaceAllUsesWith(What->getDefiningAccess()); + What->setDefiningAccess(Where->getDefiningAccess()); + Where->setDefiningAccess(What); + } + AccessList *Src = getWritableBlockAccesses(What->getBlock()); + AccessList *Dest = getWritableBlockAccesses(Where->getBlock()); + Dest->splice(AccessList::iterator(Where), *Src, What); + + BlockNumberingValid.erase(What->getBlock()); + if (What->getBlock() != Where->getBlock()) + BlockNumberingValid.erase(Where->getBlock()); +} + /// \brief Helper function to create new memory accesses MemoryUseOrDef *MemorySSA::createNewAccess(Instruction *I) { // The assume intrinsic has a control dependency which we model by claiming diff --git a/unittests/Transforms/Utils/MemorySSA.cpp b/unittests/Transforms/Utils/MemorySSA.cpp index c290e5f4073..945fe32c316 100644 --- a/unittests/Transforms/Utils/MemorySSA.cpp +++ b/unittests/Transforms/Utils/MemorySSA.cpp @@ -484,3 +484,51 @@ TEST_F(MemorySSATest, WalkerReopt) { EXPECT_EQ(Walker->getClobberingMemoryAccess(NewLoadAccess), LoadClobber); EXPECT_EQ(NewLoadAccess->getDefiningAccess(), LoadClobber); } + +// Test out MemorySSA::spliceMemoryAccessAbove. +TEST_F(MemorySSATest, SpliceAboveMemoryDef) { + F = Function::Create(FunctionType::get(B.getVoidTy(), {}, false), + GlobalValue::ExternalLinkage, "F", &M); + B.SetInsertPoint(BasicBlock::Create(C, "", F)); + + Type *Int8 = Type::getInt8Ty(C); + Value *A = B.CreateAlloca(Int8, ConstantInt::get(Int8, 1), "A"); + Value *B_ = B.CreateAlloca(Int8, ConstantInt::get(Int8, 1), "B"); + Value *C = B.CreateAlloca(Int8, ConstantInt::get(Int8, 1), "C"); + + StoreInst *StoreA0 = B.CreateStore(ConstantInt::get(Int8, 0), A); + StoreInst *StoreB = B.CreateStore(ConstantInt::get(Int8, 0), B_); + LoadInst *LoadB = B.CreateLoad(B_); + StoreInst *StoreA1 = B.CreateStore(ConstantInt::get(Int8, 4), A); + // splice this above StoreB + StoreInst *StoreC = B.CreateStore(ConstantInt::get(Int8, 4), C); + StoreInst *StoreA2 = B.CreateStore(ConstantInt::get(Int8, 4), A); + LoadInst *LoadC = B.CreateLoad(C); + + setupAnalyses(); + MemorySSA &MSSA = *Analyses->MSSA; + MemorySSAWalker &Walker = *Analyses->Walker; + + StoreC->moveBefore(StoreB); + MSSA.spliceMemoryAccessAbove(cast(MSSA.getMemoryAccess(StoreB)), + MSSA.getMemoryAccess(StoreC)); + + MSSA.verifyMemorySSA(); + + EXPECT_EQ(MSSA.getMemoryAccess(StoreB)->getDefiningAccess(), + MSSA.getMemoryAccess(StoreC)); + EXPECT_EQ(MSSA.getMemoryAccess(StoreC)->getDefiningAccess(), + MSSA.getMemoryAccess(StoreA0)); + EXPECT_EQ(MSSA.getMemoryAccess(StoreA2)->getDefiningAccess(), + MSSA.getMemoryAccess(StoreA1)); + EXPECT_EQ(Walker.getClobberingMemoryAccess(LoadB), + MSSA.getMemoryAccess(StoreB)); + EXPECT_EQ(Walker.getClobberingMemoryAccess(LoadC), + MSSA.getMemoryAccess(StoreC)); + + // exercise block numbering + EXPECT_TRUE(MSSA.locallyDominates(MSSA.getMemoryAccess(StoreC), + MSSA.getMemoryAccess(StoreB))); + EXPECT_TRUE(MSSA.locallyDominates(MSSA.getMemoryAccess(StoreA1), + MSSA.getMemoryAccess(StoreA2))); +}