llvm/lib/Transforms/Coroutines/CoroSplit.cpp
Brian Gesiak b11ac3d4e1 [coroutines][PR40979] Ignore unreachable uses across suspend points
Summary:
Depends on https://reviews.llvm.org/D59069.

https://bugs.llvm.org/show_bug.cgi?id=40979 describes a bug in which the
-coro-split pass would hit an assert because a use was across a suspend
point from its definition. Normally this would mean that the value "spills"
across the suspend point and thus needs to be stored in the coroutine frame.
In this case, however, the use was unreachable, and so it was not necessary
to store the definition on the frame.

To prevent the assert, simply remove unreachable basic blocks from a
coroutine function before computing spills.
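
For illustration only (this is not the exact IR from the bug report, just
the general shape; the names are made up):

  %x = call i8* @gimme()                          ; defined before the suspend
  %sp = call i8 @llvm.coro.suspend(token %save, i1 false)
  ...
dead.bb:                                          ; no predecessors
  call void @use(i8* %x)                          ; "use" across the suspend

Since dead.bb can never execute, %x does not actually live across the
suspend point and need not be spilled to the coroutine frame.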

Reviewers: GorNishanov, tks2103

Reviewed By: GorNishanov

Subscribers: EricWF, jdoerfert, llvm-commits, lewissbaker

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D59068

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@355852 91177308-0d34-0410-b5e6-96231b3b80d8
2019-03-11 18:31:28 +00:00


//===- CoroSplit.cpp - Converts a coroutine into a state machine ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This pass builds the coroutine frame and outlines resume and destroy parts
// of the coroutine into separate functions.
//
// We present a coroutine to LLVM as an ordinary function with suspension
// points marked up with intrinsics. We let the optimizer party on the coroutine
// as a single function for as long as possible. Shortly before the coroutine is
// eligible to be inlined into its callers, we split the coroutine into parts
// corresponding to the initial, resume, and destroy invocations of the coroutine,
// add them to the current SCC and restart the IPO pipeline to optimize the
// coroutine subfunctions we extracted before proceeding to the caller of the
// coroutine.
//===----------------------------------------------------------------------===//
#include "CoroInstr.h"
#include "CoroInternal.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>
using namespace llvm;
#define DEBUG_TYPE "coro-split"
// Create an entry block for a resume function with a switch that will jump to
// suspend points.
static BasicBlock *createResumeEntryBlock(Function &F, coro::Shape &Shape) {
LLVMContext &C = F.getContext();
// resume.entry:
// %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32 0,
// i32 2
// %index = load i32, i32* %index.addr
// switch i32 %index, label %unreachable [
// i32 0, label %resume.0
// i32 1, label %resume.1
// ...
// ]
auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F);
auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F);
IRBuilder<> Builder(NewEntry);
auto *FramePtr = Shape.FramePtr;
auto *FrameTy = Shape.FrameTy;
auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
auto *Index = Builder.CreateLoad(Shape.getIndexType(), GepIndex, "index");
auto *Switch =
Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size());
Shape.ResumeSwitch = Switch;
size_t SuspendIndex = 0;
for (CoroSuspendInst *S : Shape.CoroSuspends) {
ConstantInt *IndexVal = Shape.getIndex(SuspendIndex);
// Replace CoroSave with a store to Index:
// %index.addr = getelementptr %f.frame... (index field number)
// store i32 0, i32* %index.addr1
auto *Save = S->getCoroSave();
Builder.SetInsertPoint(Save);
if (S->isFinal()) {
// The final suspend point is represented by storing null in ResumeFnAddr.
auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0,
0, "ResumeFn.addr");
auto *NullPtr = ConstantPointerNull::get(cast<PointerType>(
cast<PointerType>(GepIndex->getType())->getElementType()));
Builder.CreateStore(NullPtr, GepIndex);
} else {
auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(
FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr");
Builder.CreateStore(IndexVal, GepIndex);
}
Save->replaceAllUsesWith(ConstantTokenNone::get(C));
Save->eraseFromParent();
// Split the block before and after coro.suspend and add a jump from the
// entry switch:
//
// whateverBB:
// whatever
// %0 = call i8 @llvm.coro.suspend(token none, i1 false)
// switch i8 %0, label %suspend [i8 0, label %resume
// i8 1, label %cleanup]
// becomes:
//
// whateverBB:
// whatever
// br label %resume.0.landing
//
// resume.0: ; <--- jump from the switch in the resume.entry
// %0 = tail call i8 @llvm.coro.suspend(token none, i1 false)
// br label %resume.0.landing
//
// resume.0.landing:
// %1 = phi i8 [-1, %whateverBB], [%0, %resume.0]
// switch i8 %1, label %suspend [i8 0, label %resume
// i8 1, label %cleanup]
auto *SuspendBB = S->getParent();
auto *ResumeBB =
SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex));
auto *LandingBB = ResumeBB->splitBasicBlock(
S->getNextNode(), ResumeBB->getName() + Twine(".landing"));
Switch->addCase(IndexVal, ResumeBB);
cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB);
auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "", &LandingBB->front());
S->replaceAllUsesWith(PN);
PN->addIncoming(Builder.getInt8(-1), SuspendBB);
PN->addIncoming(S, ResumeBB);
++SuspendIndex;
}
Builder.SetInsertPoint(UnreachBB);
Builder.CreateUnreachable();
return NewEntry;
}
// In Resumers, we replace fallthrough coro.end with ret void and delete the
// rest of the block.
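// For example (illustrative IR; a resumer returns void):
//   %unused = call i1 @llvm.coro.end(i8* %frame, i1 false)
//   br label %cleanup.cont
// becomes:
//   ret void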
static void replaceFallthroughCoroEnd(IntrinsicInst *End,
ValueToValueMapTy &VMap) {
auto *NewE = cast<IntrinsicInst>(VMap[End]);
ReturnInst::Create(NewE->getContext(), nullptr, NewE);
// Remove the rest of the block by splitting it into an unreachable block.
auto *BB = NewE->getParent();
BB->splitBasicBlock(NewE);
BB->getTerminator()->eraseFromParent();
}
// In Resumers, we replace unwind coro.end with True to force the immediate
// unwind to the caller.
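// For example, on a funclet-based EH target (illustrative IR):
//   %ok = call i1 @llvm.coro.end(i8* %frame, i1 true) [ "funclet"(token %cpad) ]
// is turned into a cleanupret that resumes unwinding in the caller:
//   cleanupret from %cpad unwind to caller
// Without a funclet bundle, the intrinsic is simply replaced by 'true'.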
static void replaceUnwindCoroEnds(coro::Shape &Shape, ValueToValueMapTy &VMap) {
if (Shape.CoroEnds.empty())
return;
LLVMContext &Context = Shape.CoroEnds.front()->getContext();
auto *True = ConstantInt::getTrue(Context);
for (CoroEndInst *CE : Shape.CoroEnds) {
if (!CE->isUnwind())
continue;
auto *NewCE = cast<IntrinsicInst>(VMap[CE]);
// If coro.end has an associated bundle, add cleanupret instruction.
if (auto Bundle = NewCE->getOperandBundle(LLVMContext::OB_funclet)) {
Value *FromPad = Bundle->Inputs[0];
auto *CleanupRet = CleanupReturnInst::Create(FromPad, nullptr, NewCE);
NewCE->getParent()->splitBasicBlock(NewCE);
CleanupRet->getParent()->getTerminator()->eraseFromParent();
}
NewCE->replaceAllUsesWith(True);
NewCE->eraseFromParent();
}
}
// Rewrite final suspend point handling. We do not use the suspend index to
// represent the final suspend point. Instead we zero out ResumeFnAddr in the
// coroutine frame, since it is undefined behavior to resume a coroutine
// suspended at the final suspend point. Thus, in the resume function, we can
// simply remove the last case (when coro::Shape is built, the final suspend
// point (if present) is always the last element of the CoroSuspends array).
// In the destroy function, we add a code sequence to check if ResumeFnAddr
// is null, and if so, jump to the appropriate label to handle cleanup from
// the final suspend point.
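// In the destroy clone this produces roughly the following sequence
// (illustrative IR; field 0 of the frame holds the resume function pointer):
//   %ResumeFn.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr,
//                                           i32 0, i32 0
//   %ResumeFn = load void (%f.Frame*)*, void (%f.Frame*)** %ResumeFn.addr
//   %is.final = icmp eq void (%f.Frame*)* %ResumeFn, null
//   br i1 %is.final, label %final.cleanup, label %Switch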
static void handleFinalSuspend(IRBuilder<> &Builder, Value *FramePtr,
coro::Shape &Shape, SwitchInst *Switch,
bool IsDestroy) {
assert(Shape.HasFinalSuspend);
auto FinalCaseIt = std::prev(Switch->case_end());
BasicBlock *ResumeBB = FinalCaseIt->getCaseSuccessor();
Switch->removeCase(FinalCaseIt);
if (IsDestroy) {
BasicBlock *OldSwitchBB = Switch->getParent();
auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch");
Builder.SetInsertPoint(OldSwitchBB->getTerminator());
auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(Shape.FrameTy, FramePtr,
0, 0, "ResumeFn.addr");
auto *Load = Builder.CreateLoad(
Shape.FrameTy->getElementType(coro::Shape::ResumeField), GepIndex);
auto *NullPtr =
ConstantPointerNull::get(cast<PointerType>(Load->getType()));
auto *Cond = Builder.CreateICmpEQ(Load, NullPtr);
Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB);
OldSwitchBB->getTerminator()->eraseFromParent();
}
}
// Create a resume clone by cloning the body of the original function, setting
// a new entry block, and replacing each coro.suspend with an appropriate value
// to force control down the resume or cleanup path at every suspend point.
static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape,
BasicBlock *ResumeEntry, int8_t FnIndex) {
Module *M = F.getParent();
auto *FrameTy = Shape.FrameTy;
auto *FnPtrTy = cast<PointerType>(FrameTy->getElementType(0));
auto *FnTy = cast<FunctionType>(FnPtrTy->getElementType());
Function *NewF =
Function::Create(FnTy, GlobalValue::LinkageTypes::ExternalLinkage,
F.getName() + Suffix, M);
NewF->addParamAttr(0, Attribute::NonNull);
NewF->addParamAttr(0, Attribute::NoAlias);
ValueToValueMapTy VMap;
// Replace all args with undefs. The buildCoroutineFrame algorithm has already
// rewritten accesses to the args that occur after suspend points to use loads
// and stores to/from the coroutine frame.
for (Argument &A : F.args())
VMap[&A] = UndefValue::get(A.getType());
SmallVector<ReturnInst *, 4> Returns;
CloneFunctionInto(NewF, &F, VMap, /*ModuleLevelChanges=*/true, Returns);
NewF->setLinkage(GlobalValue::LinkageTypes::InternalLinkage);
// Remove old returns.
for (ReturnInst *Return : Returns)
changeToUnreachable(Return, /*UseLLVMTrap=*/false);
// Remove old return attributes.
NewF->removeAttributes(
AttributeList::ReturnIndex,
AttributeFuncs::typeIncompatible(NewF->getReturnType()));
// Make AllocaSpillBlock the new entry block.
auto *SwitchBB = cast<BasicBlock>(VMap[ResumeEntry]);
auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]);
Entry->moveBefore(&NewF->getEntryBlock());
Entry->getTerminator()->eraseFromParent();
BranchInst::Create(SwitchBB, Entry);
Entry->setName("entry" + Suffix);
// Clear all predecessors of the new entry block.
auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
Entry->replaceAllUsesWith(Switch->getDefaultDest());
IRBuilder<> Builder(&NewF->getEntryBlock().front());
// Remap frame pointer.
Argument *NewFramePtr = &*NewF->arg_begin();
Value *OldFramePtr = cast<Value>(VMap[Shape.FramePtr]);
NewFramePtr->takeName(OldFramePtr);
OldFramePtr->replaceAllUsesWith(NewFramePtr);
// Remap vFrame pointer.
auto *NewVFrame = Builder.CreateBitCast(
NewFramePtr, Type::getInt8PtrTy(Builder.getContext()), "vFrame");
Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]);
OldVFrame->replaceAllUsesWith(NewVFrame);
// Rewrite final suspend handling, as it is not done via the switch (this
// allows us to remove the final case from the switch, since it is undefined
// behavior to resume the coroutine suspended at the final suspend point).
if (Shape.HasFinalSuspend) {
auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]);
bool IsDestroy = FnIndex != 0;
handleFinalSuspend(Builder, NewFramePtr, Shape, Switch, IsDestroy);
}
// Replace coro.suspend with the appropriate resume index.
// Replacing coro.suspend with (0) will result in control flow proceeding to a
// resume label associated with the suspend point; replacing it with (1) will
// result in control flow proceeding to a cleanup label associated with this
// suspend point.
auto *NewValue = Builder.getInt8(FnIndex ? 1 : 0);
for (CoroSuspendInst *CS : Shape.CoroSuspends) {
auto *MappedCS = cast<CoroSuspendInst>(VMap[CS]);
MappedCS->replaceAllUsesWith(NewValue);
MappedCS->eraseFromParent();
}
// Remove coro.end intrinsics.
replaceFallthroughCoroEnd(Shape.CoroEnds.front(), VMap);
replaceUnwindCoroEnds(Shape, VMap);
// Eliminate coro.free from the clones, replacing it with 'null' in cleanup,
// to suppress deallocation code.
coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]),
/*Elide=*/FnIndex == 2);
NewF->setCallingConv(CallingConv::Fast);
return NewF;
}
static void removeCoroEnds(coro::Shape &Shape) {
if (Shape.CoroEnds.empty())
return;
LLVMContext &Context = Shape.CoroEnds.front()->getContext();
auto *False = ConstantInt::getFalse(Context);
for (CoroEndInst *CE : Shape.CoroEnds) {
CE->replaceAllUsesWith(False);
CE->eraseFromParent();
}
}
static void replaceFrameSize(coro::Shape &Shape) {
if (Shape.CoroSizes.empty())
return;
// In the same function all coro.sizes should have the same result type.
auto *SizeIntrin = Shape.CoroSizes.back();
Module *M = SizeIntrin->getModule();
const DataLayout &DL = M->getDataLayout();
auto Size = DL.getTypeAllocSize(Shape.FrameTy);
auto *SizeConstant = ConstantInt::get(SizeIntrin->getType(), Size);
for (CoroSizeInst *CS : Shape.CoroSizes) {
CS->replaceAllUsesWith(SizeConstant);
CS->eraseFromParent();
}
}
// Create a global constant array containing pointers to the functions provided,
// and set the Info parameter of CoroBegin to point at this constant. Example:
//
// @f.resumers = internal constant [2 x void(%f.frame*)*]
// [void(%f.frame*)* @f.resume, void(%f.frame*)* @f.destroy]
// define void @f() {
// ...
// call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
// i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*))
//
// Assumes that all the functions have the same signature.
static void setCoroInfo(Function &F, CoroBeginInst *CoroBegin,
std::initializer_list<Function *> Fns) {
SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end());
assert(!Args.empty());
Function *Part = *Fns.begin();
Module *M = Part->getParent();
auto *ArrTy = ArrayType::get(Part->getType(), Args.size());
auto *ConstVal = ConstantArray::get(ArrTy, Args);
auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true,
GlobalVariable::PrivateLinkage, ConstVal,
F.getName() + Twine(".resumers"));
// Update coro.begin instruction to refer to this constant.
LLVMContext &C = F.getContext();
auto *BC = ConstantExpr::getPointerCast(GV, Type::getInt8PtrTy(C));
CoroBegin->getId()->setInfo(BC);
}
// Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame.
static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn,
Function *DestroyFn, Function *CleanupFn) {
IRBuilder<> Builder(Shape.FramePtr->getNextNode());
auto *ResumeAddr = Builder.CreateConstInBoundsGEP2_32(
Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::ResumeField,
"resume.addr");
Builder.CreateStore(ResumeFn, ResumeAddr);
Value *DestroyOrCleanupFn = DestroyFn;
CoroIdInst *CoroId = Shape.CoroBegin->getId();
if (CoroAllocInst *CA = CoroId->getCoroAlloc()) {
// If there is a CoroAlloc and it returns false (meaning we elided the
// allocation), use CleanupFn instead of DestroyFn.
DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn);
}
auto *DestroyAddr = Builder.CreateConstInBoundsGEP2_32(
Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::DestroyField,
"destroy.addr");
Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr);
}
static void postSplitCleanup(Function &F) {
removeUnreachableBlocks(F);
legacy::FunctionPassManager FPM(F.getParent());
FPM.add(createVerifierPass());
FPM.add(createSCCPPass());
FPM.add(createCFGSimplificationPass());
FPM.add(createEarlyCSEPass());
FPM.add(createCFGSimplificationPass());
FPM.doInitialization();
FPM.run(F);
FPM.doFinalization();
}
// Assuming we arrived at the block NewBlock from the Prev instruction, store
// the PHIs' incoming values in the ResolvedValues map.
static void
scanPHIsAndUpdateValueMap(Instruction *Prev, BasicBlock *NewBlock,
DenseMap<Value *, Value *> &ResolvedValues) {
auto *PrevBB = Prev->getParent();
for (PHINode &PN : NewBlock->phis()) {
auto V = PN.getIncomingValueForBlock(PrevBB);
// See if we already resolved it.
auto VI = ResolvedValues.find(V);
if (VI != ResolvedValues.end())
V = VI->second;
// Remember the value.
ResolvedValues[&PN] = V;
}
}
// Replace a sequence of branches leading to a ret with a clone of the ret
// instruction. Since suspend points are represented by a switch, track the
// PHI values and select the correct case successor when possible.
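// For example (illustrative): starting at the unconditional branch in
//   entry:
//     br label %land
//   land:
//     %v = phi i8 [ 0, %entry ]
//     switch i8 %v, label %suspend [ i8 0, label %exit ]
//   exit:
//     ret void
// we resolve %v to 0 through the PHI, follow the switch to %exit, and replace
// the branch with a clone of the 'ret void'.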
static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
DenseMap<Value *, Value *> ResolvedValues;
Instruction *I = InitialInst;
while (I->isTerminator()) {
if (isa<ReturnInst>(I)) {
if (I != InitialInst)
ReplaceInstWithInst(InitialInst, I->clone());
return true;
}
if (auto *BR = dyn_cast<BranchInst>(I)) {
if (BR->isUnconditional()) {
BasicBlock *BB = BR->getSuccessor(0);
scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
I = BB->getFirstNonPHIOrDbgOrLifetime();
continue;
}
} else if (auto *SI = dyn_cast<SwitchInst>(I)) {
Value *V = SI->getCondition();
auto it = ResolvedValues.find(V);
if (it != ResolvedValues.end())
V = it->second;
if (ConstantInt *Cond = dyn_cast<ConstantInt>(V)) {
BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
I = BB->getFirstNonPHIOrDbgOrLifetime();
continue;
}
}
return false;
}
return false;
}
// Add musttail to any resume instructions that are immediately followed by a
// suspend (i.e. ret). We do this even at -O0 to support guaranteed tail calls
// for symmetrical coroutine control transfer (C++ Coroutines TS extension).
// This transformation is done only in the resume part of the coroutine, which
// has an identical signature and calling convention to the coro.resume call.
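// For example (illustrative IR in a .resume clone):
//   %addr = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0)
//   %fn = bitcast i8* %addr to void (i8*)*
//   call fastcc void %fn(i8* %hdl)
//   ret void
// The call is marked musttail, guaranteeing a tail call so that symmetric
// transfer between coroutines does not grow the stack.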
static void addMustTailToCoroResumes(Function &F) {
bool changed = false;
// Collect potential resume instructions.
SmallVector<CallInst *, 4> Resumes;
for (auto &I : instructions(F))
if (auto *Call = dyn_cast<CallInst>(&I))
if (auto *CalledValue = Call->getCalledValue())
// The CoroEarly pass replaced coro resumes with indirect calls to an
// address returned by the CoroSubFnInst intrinsic. See if it is one of those.
if (isa<CoroSubFnInst>(CalledValue->stripPointerCasts()))
Resumes.push_back(Call);
// Set musttail on those that are followed by a ret instruction.
for (CallInst *Call : Resumes)
if (simplifyTerminatorLeadingToRet(Call->getNextNode())) {
Call->setTailCallKind(CallInst::TCK_MustTail);
changed = true;
}
if (changed)
removeUnreachableBlocks(F);
}
// The coroutine has no suspend points. Remove the heap allocation for the
// coroutine frame if possible.
static void handleNoSuspendCoroutine(CoroBeginInst *CoroBegin, Type *FrameTy) {
auto *CoroId = CoroBegin->getId();
auto *AllocInst = CoroId->getCoroAlloc();
coro::replaceCoroFree(CoroId, /*Elide=*/AllocInst != nullptr);
if (AllocInst) {
IRBuilder<> Builder(AllocInst);
// FIXME: Need to handle overaligned members.
auto *Frame = Builder.CreateAlloca(FrameTy);
auto *VFrame = Builder.CreateBitCast(Frame, Builder.getInt8PtrTy());
AllocInst->replaceAllUsesWith(Builder.getFalse());
AllocInst->eraseFromParent();
CoroBegin->replaceAllUsesWith(VFrame);
} else {
CoroBegin->replaceAllUsesWith(CoroBegin->getMem());
}
CoroBegin->eraseFromParent();
}
// simplifySuspendPoint needs to check that there are no calls between
// coro.save and coro.suspend, since any of the calls may potentially resume
// the coroutine, and if that is the case we cannot eliminate the suspend point.
static bool hasCallsInBlockBetween(Instruction *From, Instruction *To) {
for (Instruction *I = From; I != To; I = I->getNextNode()) {
// Assume that no intrinsic can resume the coroutine.
if (isa<IntrinsicInst>(I))
continue;
if (CallSite(I))
return true;
}
return false;
}
static bool hasCallsInBlocksBetween(BasicBlock *SaveBB, BasicBlock *ResDesBB) {
SmallPtrSet<BasicBlock *, 8> Set;
SmallVector<BasicBlock *, 8> Worklist;
Set.insert(SaveBB);
Worklist.push_back(ResDesBB);
// Accumulate all blocks between SaveBB and ResDesBB. Because the coro.save
// intrinsic returns a token consumed by the suspend instruction, all blocks in
// between will eventually hit SaveBB when walking backwards from ResDesBB.
while (!Worklist.empty()) {
auto *BB = Worklist.pop_back_val();
Set.insert(BB);
for (auto *Pred : predecessors(BB))
if (Set.count(Pred) == 0)
Worklist.push_back(Pred);
}
// SaveBB and ResDesBB are checked separately in hasCallsBetween.
Set.erase(SaveBB);
Set.erase(ResDesBB);
for (auto *BB : Set)
if (hasCallsInBlockBetween(BB->getFirstNonPHI(), nullptr))
return true;
return false;
}
static bool hasCallsBetween(Instruction *Save, Instruction *ResumeOrDestroy) {
auto *SaveBB = Save->getParent();
auto *ResumeOrDestroyBB = ResumeOrDestroy->getParent();
if (SaveBB == ResumeOrDestroyBB)
return hasCallsInBlockBetween(Save->getNextNode(), ResumeOrDestroy);
// Any calls from Save to the end of the block?
if (hasCallsInBlockBetween(Save->getNextNode(), nullptr))
return true;
// Any calls from the beginning of the block up to ResumeOrDestroy?
if (hasCallsInBlockBetween(ResumeOrDestroyBB->getFirstNonPHI(),
ResumeOrDestroy))
return true;
// Any calls in all of the blocks between SaveBB and ResumeOrDestroyBB?
if (hasCallsInBlocksBetween(SaveBB, ResumeOrDestroyBB))
return true;
return false;
}
// If a suspend intrinsic is immediately preceded by a call that resumes or
// destroys this same coroutine, we can eliminate the suspend point and
// replace it with normal control flow.
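// For example (illustrative IR; index 0 selects the resume path):
//   %save = call token @llvm.coro.save(i8* %hdl)
//   %addr = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0)
//   %fn = bitcast i8* %addr to void (i8*)*
//   call fastcc void %fn(i8* %hdl)    ; resumes this very coroutine
//   %sp = call i8 @llvm.coro.suspend(token %save, i1 false)
// Control would immediately re-enter the coroutine, so %sp can be replaced
// with the constant 0, and the call, save, and suspend all removed.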
static bool simplifySuspendPoint(CoroSuspendInst *Suspend,
CoroBeginInst *CoroBegin) {
Instruction *Prev = Suspend->getPrevNode();
if (!Prev) {
auto *Pred = Suspend->getParent()->getSinglePredecessor();
if (!Pred)
return false;
Prev = Pred->getTerminator();
}
CallSite CS{Prev};
if (!CS)
return false;
auto *CallInstr = CS.getInstruction();
auto *Callee = CS.getCalledValue()->stripPointerCasts();
// See if the callsite is for resumption or destruction of the coroutine.
auto *SubFn = dyn_cast<CoroSubFnInst>(Callee);
if (!SubFn)
return false;
// If it does not refer to the current coroutine, we cannot do anything with it.
if (SubFn->getFrame() != CoroBegin)
return false;
// See if the transformation is safe. Specifically, see if there are any
// calls in between Save and CallInstr. They can potentially resume the
// coroutine, rendering this optimization unsafe.
auto *Save = Suspend->getCoroSave();
if (hasCallsBetween(Save, CallInstr))
return false;
// Replace llvm.coro.suspend with the value that steers control down the
// resume or cleanup path.
Suspend->replaceAllUsesWith(SubFn->getRawIndex());
Suspend->eraseFromParent();
Save->eraseFromParent();
// No longer need a call to coro.resume or coro.destroy.
if (auto *Invoke = dyn_cast<InvokeInst>(CallInstr)) {
BranchInst::Create(Invoke->getNormalDest(), Invoke);
}
// Grab the CalledValue from CS before erasing the CallInstr.
auto *CalledValue = CS.getCalledValue();
CallInstr->eraseFromParent();
// If it has no more users, remove it. Usually it is a bitcast of SubFn.
if (CalledValue != SubFn && CalledValue->user_empty())
if (auto *I = dyn_cast<Instruction>(CalledValue))
I->eraseFromParent();
// Now we are good to remove SubFn.
if (SubFn->user_empty())
SubFn->eraseFromParent();
return true;
}
// Remove suspend points that are simplified.
static void simplifySuspendPoints(coro::Shape &Shape) {
auto &S = Shape.CoroSuspends;
size_t I = 0, N = S.size();
if (N == 0)
return;
while (true) {
if (simplifySuspendPoint(S[I], Shape.CoroBegin)) {
if (--N == I)
break;
std::swap(S[I], S[N]);
continue;
}
if (++I == N)
break;
}
S.resize(N);
}
static SmallPtrSet<BasicBlock *, 4> getCoroBeginPredBlocks(CoroBeginInst *CB) {
// Collect all blocks that we need to look for instructions to relocate.
SmallPtrSet<BasicBlock *, 4> RelocBlocks;
SmallVector<BasicBlock *, 4> Work;
Work.push_back(CB->getParent());
do {
BasicBlock *Current = Work.pop_back_val();
for (BasicBlock *BB : predecessors(Current))
if (RelocBlocks.count(BB) == 0) {
RelocBlocks.insert(BB);
Work.push_back(BB);
}
} while (!Work.empty());
return RelocBlocks;
}
static SmallPtrSet<Instruction *, 8>
getNotRelocatableInstructions(CoroBeginInst *CoroBegin,
SmallPtrSetImpl<BasicBlock *> &RelocBlocks) {
SmallPtrSet<Instruction *, 8> DoNotRelocate;
// Collect all instructions that we should not relocate
SmallVector<Instruction *, 8> Work;
// Start with CoroBegin and terminators of all preceding blocks.
Work.push_back(CoroBegin);
BasicBlock *CoroBeginBB = CoroBegin->getParent();
for (BasicBlock *BB : RelocBlocks)
if (BB != CoroBeginBB)
Work.push_back(BB->getTerminator());
// For every instruction in the Work list, place its operands in DoNotRelocate
// set.
do {
Instruction *Current = Work.pop_back_val();
LLVM_DEBUG(dbgs() << "CoroSplit: Will not relocate: " << *Current << "\n");
DoNotRelocate.insert(Current);
for (Value *U : Current->operands()) {
auto *I = dyn_cast<Instruction>(U);
if (!I)
continue;
if (auto *A = dyn_cast<AllocaInst>(I)) {
// Stores to alloca instructions that occur before the coroutine frame
// is allocated should not be moved; the stored values may be used by
// the coroutine frame allocator. The operands to those stores must also
// remain in place.
for (const auto &User : A->users())
if (auto *SI = dyn_cast<llvm::StoreInst>(User))
if (RelocBlocks.count(SI->getParent()) != 0 &&
DoNotRelocate.count(SI) == 0) {
Work.push_back(SI);
DoNotRelocate.insert(SI);
}
continue;
}
if (DoNotRelocate.count(I) == 0) {
Work.push_back(I);
DoNotRelocate.insert(I);
}
}
} while (!Work.empty());
return DoNotRelocate;
}
static void relocateInstructionBefore(CoroBeginInst *CoroBegin, Function &F) {
// Analyze which non-alloca instructions are needed for allocation, and
// relocate the rest to after coro.begin. We need to do this because some
// operands of those instructions may end up placed in the coroutine frame,
// whose memory only becomes available after the coro.begin intrinsic.
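// For example (illustrative): in
//   entry:
//     %x = add i32 %n, 1      ; not used to compute the allocation size
//     %size = call i32 @llvm.coro.size.i32()
//     ...
//     %hdl = call i8* @llvm.coro.begin(token %id, i8* %mem)
// %x is moved to just after coro.begin, so that if %x must be spilled, its
// frame slot exists by the time the spill store is emitted.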
auto BlockSet = getCoroBeginPredBlocks(CoroBegin);
auto DoNotRelocateSet = getNotRelocatableInstructions(CoroBegin, BlockSet);
Instruction *InsertPt = CoroBegin->getNextNode();
BasicBlock &BB = F.getEntryBlock(); // TODO: Look at other blocks as well.
for (auto B = BB.begin(), E = BB.end(); B != E;) {
Instruction &I = *B++;
if (isa<AllocaInst>(&I))
continue;
if (&I == CoroBegin)
break;
if (DoNotRelocateSet.count(&I))
continue;
I.moveBefore(InsertPt);
}
}
static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) {
EliminateUnreachableBlocks(F);
coro::Shape Shape(F);
if (!Shape.CoroBegin)
return;
simplifySuspendPoints(Shape);
relocateInstructionBefore(Shape.CoroBegin, F);
buildCoroutineFrame(F, Shape);
replaceFrameSize(Shape);
// If there are no suspend points, no split is required; just remove
// the allocation and deallocation blocks, as they are not needed.
if (Shape.CoroSuspends.empty()) {
handleNoSuspendCoroutine(Shape.CoroBegin, Shape.FrameTy);
removeCoroEnds(Shape);
postSplitCleanup(F);
coro::updateCallGraph(F, {}, CG, SCC);
return;
}
auto *ResumeEntry = createResumeEntryBlock(F, Shape);
auto ResumeClone = createClone(F, ".resume", Shape, ResumeEntry, 0);
auto DestroyClone = createClone(F, ".destroy", Shape, ResumeEntry, 1);
auto CleanupClone = createClone(F, ".cleanup", Shape, ResumeEntry, 2);
// We no longer need coro.end in F.
removeCoroEnds(Shape);
postSplitCleanup(F);
postSplitCleanup(*ResumeClone);
postSplitCleanup(*DestroyClone);
postSplitCleanup(*CleanupClone);
addMustTailToCoroResumes(*ResumeClone);
// Store the addresses of the resume/destroy/cleanup functions in the coroutine frame.
updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);
// Create a constant array referring to the resume/destroy/cleanup functions,
// pointed to by the Info argument of coro.id, so that the CoroElide pass can
// determine the correct function to call.
setCoroInfo(F, Shape.CoroBegin, {ResumeClone, DestroyClone, CleanupClone});
// Update call graph and add the functions we created to the SCC.
coro::updateCallGraph(F, {ResumeClone, DestroyClone, CleanupClone}, CG, SCC);
}
// When we see the coroutine for the first time, we insert an indirect call to
// a devirt trigger function and mark the coroutine as now ready for split.
static void prepareForSplit(Function &F, CallGraph &CG) {
Module &M = *F.getParent();
LLVMContext &Context = F.getContext();
#ifndef NDEBUG
Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN);
assert(DevirtFn && "coro.devirt.trigger function not found");
#endif
F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT);
// Insert an indirect call sequence that will be devirtualized by CoroElide
// pass:
// %0 = call i8* @llvm.coro.subfn.addr(i8* null, i8 -1)
// %1 = bitcast i8* %0 to void(i8*)*
// call void %1(i8* null)
coro::LowererBase Lowerer(M);
Instruction *InsertPt = F.getEntryBlock().getTerminator();
auto *Null = ConstantPointerNull::get(Type::getInt8PtrTy(Context));
auto *DevirtFnAddr =
Lowerer.makeSubFnCall(Null, CoroSubFnInst::RestartTrigger, InsertPt);
FunctionType *FnTy = FunctionType::get(Type::getVoidTy(Context),
{Type::getInt8PtrTy(Context)}, false);
auto *IndirectCall = CallInst::Create(FnTy, DevirtFnAddr, Null, "", InsertPt);
// Update the call graph with the indirect call we just added.
CG[&F]->addCalledFunction(IndirectCall, CG.getCallsExternalNode());
}
// Make sure that there is a devirtualization trigger function that the
// CoroSplit pass uses to force a restart of the CGSCC pipeline. If the devirt
// trigger function is not found, we will create one and add it to the current
// SCC.
static void createDevirtTriggerFunc(CallGraph &CG, CallGraphSCC &SCC) {
Module &M = CG.getModule();
if (M.getFunction(CORO_DEVIRT_TRIGGER_FN))
return;
LLVMContext &C = M.getContext();
auto *FnTy = FunctionType::get(Type::getVoidTy(C), Type::getInt8PtrTy(C),
/*IsVarArgs=*/false);
Function *DevirtFn =
Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage,
CORO_DEVIRT_TRIGGER_FN, &M);
DevirtFn->addFnAttr(Attribute::AlwaysInline);
auto *Entry = BasicBlock::Create(C, "entry", DevirtFn);
ReturnInst::Create(C, Entry);
auto *Node = CG.getOrInsertFunction(DevirtFn);
SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end());
Nodes.push_back(Node);
SCC.initialize(Nodes);
}
//===----------------------------------------------------------------------===//
// Top Level Driver
//===----------------------------------------------------------------------===//
namespace {
struct CoroSplit : public CallGraphSCCPass {
static char ID; // Pass identification, replacement for typeid
CoroSplit() : CallGraphSCCPass(ID) {
initializeCoroSplitPass(*PassRegistry::getPassRegistry());
}
bool Run = false;
// A coroutine is identified by the presence of the coro.begin intrinsic; if
// we don't have any, this pass has nothing to do.
bool doInitialization(CallGraph &CG) override {
Run = coro::declaresIntrinsics(CG.getModule(), {"llvm.coro.begin"});
return CallGraphSCCPass::doInitialization(CG);
}
bool runOnSCC(CallGraphSCC &SCC) override {
if (!Run)
return false;
// Find coroutines for processing.
SmallVector<Function *, 4> Coroutines;
for (CallGraphNode *CGN : SCC)
if (auto *F = CGN->getFunction())
if (F->hasFnAttribute(CORO_PRESPLIT_ATTR))
Coroutines.push_back(F);
if (Coroutines.empty())
return false;
CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
createDevirtTriggerFunc(CG, SCC);
for (Function *F : Coroutines) {
Attribute Attr = F->getFnAttribute(CORO_PRESPLIT_ATTR);
StringRef Value = Attr.getValueAsString();
LLVM_DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F->getName()
<< "' state: " << Value << "\n");
if (Value == UNPREPARED_FOR_SPLIT) {
prepareForSplit(*F, CG);
continue;
}
F->removeFnAttr(CORO_PRESPLIT_ATTR);
splitCoroutine(*F, CG, SCC);
}
return true;
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
CallGraphSCCPass::getAnalysisUsage(AU);
}
StringRef getPassName() const override { return "Coroutine Splitting"; }
};
} // end anonymous namespace
char CoroSplit::ID = 0;
INITIALIZE_PASS(
CoroSplit, "coro-split",
"Split coroutine into a set of functions driving its state machine", false,
false)
Pass *llvm::createCoroSplitPass() { return new CoroSplit(); }