Generate substantially better code when there are a limited number of exits

from the extracted region.  If the return has 0 or 1 exit blocks, the new
function returns void.  If it has 2 exits, it returns bool, otherwise it
returns a ushort as before.

This allows us to use a conditional branch instruction when there are two
exit blocks, as often happens during block extraction.

llvm-svn: 13481
This commit is contained in:
Chris Lattner 2004-05-12 04:14:24 +00:00
parent 7f4cd3b0be
commit 67e58adb41

View File

@ -44,9 +44,11 @@ namespace {
std::set<BasicBlock*> BlocksToExtract; std::set<BasicBlock*> BlocksToExtract;
DominatorSet *DS; DominatorSet *DS;
bool AggregateArgs; bool AggregateArgs;
unsigned NumExitBlocks;
const Type *RetTy;
public: public:
CodeExtractor(DominatorSet *ds = 0, bool AggArgs = false) CodeExtractor(DominatorSet *ds = 0, bool AggArgs = false)
: DS(ds), AggregateArgs(AggregateArgsOpt) {} : DS(ds), AggregateArgs(AggregateArgsOpt), NumExitBlocks(~0U) {}
Function *ExtractCodeRegion(const std::vector<BasicBlock*> &code); Function *ExtractCodeRegion(const std::vector<BasicBlock*> &code);
@ -76,6 +78,7 @@ namespace {
void CodeExtractor::findInputsOutputs(Values &inputs, Values &outputs, void CodeExtractor::findInputsOutputs(Values &inputs, Values &outputs,
BasicBlock *newHeader, BasicBlock *newHeader,
BasicBlock *newRootNode) { BasicBlock *newRootNode) {
std::set<BasicBlock*> ExitBlocks;
for (std::set<BasicBlock*>::const_iterator ci = BlocksToExtract.begin(), for (std::set<BasicBlock*>::const_iterator ci = BlocksToExtract.begin(),
ce = BlocksToExtract.end(); ci != ce; ++ci) { ce = BlocksToExtract.end(); ci != ce; ++ci) {
BasicBlock *BB = *ci; BasicBlock *BB = *ci;
@ -116,7 +119,14 @@ void CodeExtractor::findInputsOutputs(Values &inputs, Values &outputs,
break; break;
} }
} // for: insts } // for: insts
TerminatorInst *TI = BB->getTerminator();
for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
if (!BlocksToExtract.count(TI->getSuccessor(i)))
ExitBlocks.insert(TI->getSuccessor(i));
} // for: basic blocks } // for: basic blocks
NumExitBlocks = ExitBlocks.size();
} }
/// constructFunction - make a function based on inputs and outputs, as follows: /// constructFunction - make a function based on inputs and outputs, as follows:
@ -133,7 +143,13 @@ Function *CodeExtractor::constructFunction(const Values &inputs,
DEBUG(std::cerr << "outputs: " << outputs.size() << "\n"); DEBUG(std::cerr << "outputs: " << outputs.size() << "\n");
// This function returns unsigned, outputs will go back by reference. // This function returns unsigned, outputs will go back by reference.
Type *retTy = Type::UShortTy; switch (NumExitBlocks) {
case 0:
case 1: RetTy = Type::VoidTy; break;
case 2: RetTy = Type::BoolTy; break;
default: RetTy = Type::UShortTy; break;
}
std::vector<const Type*> paramTy; std::vector<const Type*> paramTy;
// Add the types of the input values to the function's argument list // Add the types of the input values to the function's argument list
@ -154,7 +170,7 @@ Function *CodeExtractor::constructFunction(const Values &inputs,
paramTy.push_back(PointerType::get((*I)->getType())); paramTy.push_back(PointerType::get((*I)->getType()));
} }
DEBUG(std::cerr << "Function type: " << retTy << " f("); DEBUG(std::cerr << "Function type: " << RetTy << " f(");
DEBUG(for (std::vector<const Type*>::iterator i = paramTy.begin(), DEBUG(for (std::vector<const Type*>::iterator i = paramTy.begin(),
e = paramTy.end(); i != e; ++i) std::cerr << *i << ", "); e = paramTy.end(); i != e; ++i) std::cerr << *i << ", ");
DEBUG(std::cerr << ")\n"); DEBUG(std::cerr << ")\n");
@ -164,7 +180,7 @@ Function *CodeExtractor::constructFunction(const Values &inputs,
paramTy.clear(); paramTy.clear();
paramTy.push_back(StructPtr); paramTy.push_back(StructPtr);
} }
const FunctionType *funcType = FunctionType::get(retTy, paramTy, false); const FunctionType *funcType = FunctionType::get(RetTy, paramTy, false);
// Create the new function // Create the new function
Function *newFunction = new Function(funcType, Function *newFunction = new Function(funcType,
@ -296,7 +312,8 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
} }
// Emit the call to the function // Emit the call to the function
CallInst *call = new CallInst(newFunction, params, "targetBlock"); CallInst *call = new CallInst(newFunction, params,
NumExitBlocks > 1 ? "targetBlock": "");
codeReplacer->getInstList().push_back(call); codeReplacer->getInstList().push_back(call);
Function::aiterator OutputArgBegin = newFunction->abegin(); Function::aiterator OutputArgBegin = newFunction->abegin();
@ -330,7 +347,9 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
} }
// Now we can emit a switch statement using the call as a value. // Now we can emit a switch statement using the call as a value.
SwitchInst *TheSwitch = new SwitchInst(call, codeReplacer, codeReplacer); SwitchInst *TheSwitch =
new SwitchInst(ConstantUInt::getNullValue(Type::UShortTy),
codeReplacer, codeReplacer);
// Since there may be multiple exits from the original region, make the new // Since there may be multiple exits from the original region, make the new
// function return an unsigned, switch on that number. This loop iterates // function return an unsigned, switch on that number. This loop iterates
@ -353,12 +372,25 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
// destination, create one now! // destination, create one now!
NewTarget = new BasicBlock(OldTarget->getName() + ".exitStub", NewTarget = new BasicBlock(OldTarget->getName() + ".exitStub",
newFunction); newFunction);
unsigned SuccNum = switchVal++;
Value *brVal = 0;
switch (NumExitBlocks) {
case 0:
case 1: break; // No value needed.
case 2: // Conditional branch, return a bool
brVal = SuccNum ? ConstantBool::False : ConstantBool::True;
break;
default:
brVal = ConstantUInt::get(Type::UShortTy, SuccNum);
break;
}
ConstantUInt *brVal = ConstantUInt::get(Type::UShortTy, switchVal++);
ReturnInst *NTRet = new ReturnInst(brVal, NewTarget); ReturnInst *NTRet = new ReturnInst(brVal, NewTarget);
// Update the switch instruction. // Update the switch instruction.
TheSwitch->addCase(brVal, OldTarget); TheSwitch->addCase(ConstantUInt::get(Type::UShortTy, SuccNum),
OldTarget);
// Restore values just before we exit // Restore values just before we exit
Function::aiterator OAI = OutputArgBegin; Function::aiterator OAI = OutputArgBegin;
@ -391,20 +423,8 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
} }
// Now that we've done the deed, simplify the switch instruction. // Now that we've done the deed, simplify the switch instruction.
unsigned NumSuccs = TheSwitch->getNumSuccessors(); switch (NumExitBlocks) {
if (NumSuccs > 1) { case 0:
if (NumSuccs-1 == 1) {
// Only a single destination, change the switch into an unconditional
// branch.
new BranchInst(TheSwitch->getSuccessor(1), TheSwitch);
TheSwitch->getParent()->getInstList().erase(TheSwitch);
} else {
// Otherwise, make the default destination of the switch instruction be
// one of the other successors.
TheSwitch->setSuccessor(0, TheSwitch->getSuccessor(NumSuccs-1));
TheSwitch->removeCase(NumSuccs-1); // Remove redundant case
}
} else {
// There is only 1 successor (the block containing the switch itself), which // There is only 1 successor (the block containing the switch itself), which
// means that previously this was the last part of the function, and hence // means that previously this was the last part of the function, and hence
// this should be rewritten as a `ret' // this should be rewritten as a `ret'
@ -420,6 +440,25 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
new ReturnInst(0, TheSwitch); new ReturnInst(0, TheSwitch);
TheSwitch->getParent()->getInstList().erase(TheSwitch); TheSwitch->getParent()->getInstList().erase(TheSwitch);
break;
case 1:
// Only a single destination, change the switch into an unconditional
// branch.
new BranchInst(TheSwitch->getSuccessor(1), TheSwitch);
TheSwitch->getParent()->getInstList().erase(TheSwitch);
break;
case 2:
new BranchInst(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2),
call, TheSwitch);
TheSwitch->getParent()->getInstList().erase(TheSwitch);
break;
default:
// Otherwise, make the default destination of the switch instruction be one
// of the other successors.
TheSwitch->setOperand(0, call);
TheSwitch->setSuccessor(0, TheSwitch->getSuccessor(NumExitBlocks));
TheSwitch->removeCase(NumExitBlocks); // Remove redundant case
break;
} }
} }