diff --git a/lib/Transforms/Utils/CodeExtractor.cpp b/lib/Transforms/Utils/CodeExtractor.cpp index 84b149cd02e..b132b563031 100644 --- a/lib/Transforms/Utils/CodeExtractor.cpp +++ b/lib/Transforms/Utils/CodeExtractor.cpp @@ -44,9 +44,11 @@ namespace { std::set BlocksToExtract; DominatorSet *DS; bool AggregateArgs; + unsigned NumExitBlocks; + const Type *RetTy; public: CodeExtractor(DominatorSet *ds = 0, bool AggArgs = false) - : DS(ds), AggregateArgs(AggregateArgsOpt) {} + : DS(ds), AggregateArgs(AggregateArgsOpt), NumExitBlocks(~0U) {} Function *ExtractCodeRegion(const std::vector &code); @@ -76,6 +78,7 @@ namespace { void CodeExtractor::findInputsOutputs(Values &inputs, Values &outputs, BasicBlock *newHeader, BasicBlock *newRootNode) { + std::set ExitBlocks; for (std::set::const_iterator ci = BlocksToExtract.begin(), ce = BlocksToExtract.end(); ci != ce; ++ci) { BasicBlock *BB = *ci; @@ -116,7 +119,14 @@ void CodeExtractor::findInputsOutputs(Values &inputs, Values &outputs, break; } } // for: insts + + TerminatorInst *TI = BB->getTerminator(); + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) + if (!BlocksToExtract.count(TI->getSuccessor(i))) + ExitBlocks.insert(TI->getSuccessor(i)); } // for: basic blocks + + NumExitBlocks = ExitBlocks.size(); } /// constructFunction - make a function based on inputs and outputs, as follows: @@ -133,7 +143,13 @@ Function *CodeExtractor::constructFunction(const Values &inputs, DEBUG(std::cerr << "outputs: " << outputs.size() << "\n"); // This function returns unsigned, outputs will go back by reference. - Type *retTy = Type::UShortTy; + switch (NumExitBlocks) { + case 0: + case 1: RetTy = Type::VoidTy; break; + case 2: RetTy = Type::BoolTy; break; + default: RetTy = Type::UShortTy; break; + } + std::vector paramTy; // Add the types of the input values to the function's argument list @@ -154,7 +170,7 @@ Function *CodeExtractor::constructFunction(const Values &inputs, paramTy.push_back(PointerType::get((*I)->getType())); } - DEBUG(std::cerr << "Function type: " << retTy << " f("); + DEBUG(std::cerr << "Function type: " << RetTy << " f("); DEBUG(for (std::vector::iterator i = paramTy.begin(), e = paramTy.end(); i != e; ++i) std::cerr << *i << ", "); DEBUG(std::cerr << ")\n"); @@ -164,7 +180,7 @@ Function *CodeExtractor::constructFunction(const Values &inputs, paramTy.clear(); paramTy.push_back(StructPtr); } - const FunctionType *funcType = FunctionType::get(retTy, paramTy, false); + const FunctionType *funcType = FunctionType::get(RetTy, paramTy, false); // Create the new function Function *newFunction = new Function(funcType, @@ -296,7 +312,8 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, } // Emit the call to the function - CallInst *call = new CallInst(newFunction, params, "targetBlock"); + CallInst *call = new CallInst(newFunction, params, + NumExitBlocks > 1 ? "targetBlock": ""); codeReplacer->getInstList().push_back(call); Function::aiterator OutputArgBegin = newFunction->abegin(); @@ -330,7 +347,9 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, } // Now we can emit a switch statement using the call as a value. - SwitchInst *TheSwitch = new SwitchInst(call, codeReplacer, codeReplacer); + SwitchInst *TheSwitch = + new SwitchInst(ConstantUInt::getNullValue(Type::UShortTy), + codeReplacer, codeReplacer); // Since there may be multiple exits from the original region, make the new // function return an unsigned, switch on that number. This loop iterates @@ -353,12 +372,25 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, // destination, create one now! NewTarget = new BasicBlock(OldTarget->getName() + ".exitStub", newFunction); + unsigned SuccNum = switchVal++; + + Value *brVal = 0; + switch (NumExitBlocks) { + case 0: + case 1: break; // No value needed. + case 2: // Conditional branch, return a bool + brVal = SuccNum ? ConstantBool::False : ConstantBool::True; + break; + default: + brVal = ConstantUInt::get(Type::UShortTy, SuccNum); + break; + } - ConstantUInt *brVal = ConstantUInt::get(Type::UShortTy, switchVal++); ReturnInst *NTRet = new ReturnInst(brVal, NewTarget); // Update the switch instruction. - TheSwitch->addCase(brVal, OldTarget); + TheSwitch->addCase(ConstantUInt::get(Type::UShortTy, SuccNum), + OldTarget); // Restore values just before we exit Function::aiterator OAI = OutputArgBegin; @@ -391,20 +423,8 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, } // Now that we've done the deed, simplify the switch instruction. - unsigned NumSuccs = TheSwitch->getNumSuccessors(); - if (NumSuccs > 1) { - if (NumSuccs-1 == 1) { - // Only a single destination, change the switch into an unconditional - // branch. - new BranchInst(TheSwitch->getSuccessor(1), TheSwitch); - TheSwitch->getParent()->getInstList().erase(TheSwitch); - } else { - // Otherwise, make the default destination of the switch instruction be - // one of the other successors. - TheSwitch->setSuccessor(0, TheSwitch->getSuccessor(NumSuccs-1)); - TheSwitch->removeCase(NumSuccs-1); // Remove redundant case - } - } else { + switch (NumExitBlocks) { + case 0: // There is only 1 successor (the block containing the switch itself), which // means that previously this was the last part of the function, and hence // this should be rewritten as a `ret' @@ -420,6 +440,25 @@ CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, new ReturnInst(0, TheSwitch); TheSwitch->getParent()->getInstList().erase(TheSwitch); + break; + case 1: + // Only a single destination, change the switch into an unconditional + // branch. + new BranchInst(TheSwitch->getSuccessor(1), TheSwitch); + TheSwitch->getParent()->getInstList().erase(TheSwitch); + break; + case 2: + new BranchInst(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2), + call, TheSwitch); + TheSwitch->getParent()->getInstList().erase(TheSwitch); + break; + default: + // Otherwise, make the default destination of the switch instruction be one + // of the other successors. + TheSwitch->setOperand(0, call); + TheSwitch->setSuccessor(0, TheSwitch->getSuccessor(NumExitBlocks)); + TheSwitch->removeCase(NumExitBlocks); // Remove redundant case + break; } }