diff --git a/lib/Transforms/Scalar/TailRecursionElimination.cpp b/lib/Transforms/Scalar/TailRecursionElimination.cpp new file mode 100644 index 00000000000..f0c5ba73b91 --- /dev/null +++ b/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -0,0 +1,96 @@ +//===- TailRecursionElimination.cpp - Eliminate Tail Calls ----------------===// +// +// This file implements tail recursion elimination. +// +// Caveats: The algorithm implemented is trivially simple. There are several +// improvements that could be made: +// +// 1. If the function has any alloca instructions, these instructions will not +// remain in the entry block of the function. Doing this requires analysis +// to prove that the alloca is not reachable by the recursively invoked +// function call. +// 2. Tail recursion is only performed if the call immediately preceeds the +// return instruction. Would it be useful to generalize this somehow? +// 3. TRE is only performed if the function returns void or if the return +// returns the result returned by the call. It is possible, but unlikely, +// that the return returns something else (like constant 0), and can still +// be TRE'd. It can be TRE'd if ALL OTHER return instructions in the +// function return the exact same value. +// +//===----------------------------------------------------------------------===// + +#include "llvm/DerivedTypes.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/Pass.h" +#include "Support/Statistic.h" + +namespace { + Statistic<> NumEliminated("tailcallelim", "Number of tail calls removed"); + + struct TailCallElim : public FunctionPass { + virtual bool runOnFunction(Function &F); + }; + RegisterOpt X("tailcallelim", "Tail Call Elimination"); +} + + +bool TailCallElim::runOnFunction(Function &F) { + // If this function is a varargs function, we won't be able to PHI the args + // right, so don't even try to convert it... + if (F.getFunctionType()->isVarArg()) return false; + + BasicBlock *OldEntry = 0; + std::vector ArgumentPHIs; + bool MadeChange = false; + + // Loop over the function, looking for any returning blocks... + for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) + if (ReturnInst *Ret = dyn_cast(BB->getTerminator())) + if (Ret != BB->begin()) + if (CallInst *CI = dyn_cast(Ret->getPrev())) + // Make sure the tail call is to the current function, and that the + // return either returns void or returns the value computed by the + // call. + if (CI->getCalledFunction() == &F && + (Ret->getNumOperands() == 0 || Ret->getReturnValue() == CI)) { + // Ohh, it looks like we found a tail call, is this the first? + if (!OldEntry) { + // Ok, so this is the first tail call we have found in this + // function. Insert a new entry block into the function, allowing + // us to branch back to the old entry block. + OldEntry = &F.getEntryNode(); + BasicBlock *NewEntry = new BasicBlock("tailrecurse", OldEntry); + NewEntry->getInstList().push_back(new BranchInst(OldEntry)); + + // Now that we have created a new block, which jumps to the entry + // block, insert a PHI node for each argument of the function. + // For now, we initialize each PHI to only have the real arguments + // which are passed in. + Instruction *InsertPos = OldEntry->begin(); + for (Function::aiterator I = F.abegin(), E = F.aend(); I!=E; ++I){ + PHINode *PN = new PHINode(I->getType(), I->getName()+".tr", + InsertPos); + PN->addIncoming(I, NewEntry); + ArgumentPHIs.push_back(PN); + } + } + + // Ok, now that we know we have a pseudo-entry block WITH all of the + // required PHI nodes, add entries into the PHI node for the actual + // parameters passed into the tail-recursive call. + for (unsigned i = 0, e = CI->getNumOperands()-1; i != e; ++i) + ArgumentPHIs[i]->addIncoming(CI->getOperand(i+1), BB); + + // Now that all of the PHI nodes are in place, remove the call and + // ret instructions, replacing them with an unconditional branch. + new BranchInst(OldEntry, CI); + BB->getInstList().pop_back(); // Remove return. + BB->getInstList().pop_back(); // Remove call. + MadeChange = true; + NumEliminated++; + } + + return MadeChange; +} +