Add 'musttail' marker to call instructions

This is similar to the 'tail' marker, except that it guarantees that
tail call optimization will occur.  It also comes with convervative IR
verification rules that ensure that tail call optimization is possible.

Reviewers: nicholas

Differential Revision: http://llvm-reviews.chandlerc.com/D3240

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@207143 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Reid Kleckner 2014-04-24 20:14:34 +00:00
parent 870200a833
commit 710c1a449d
21 changed files with 334 additions and 30 deletions

View File

@ -6161,7 +6161,7 @@ Syntax:
::
<result> = [tail] call [cconv] [ret attrs] <ty> [<fnty>*] <fnptrval>(<function args>) [fn attrs]
<result> = [tail | musttail] call [cconv] [ret attrs] <ty> [<fnty>*] <fnptrval>(<function args>) [fn attrs]
Overview:
"""""""""
@ -6173,17 +6173,34 @@ Arguments:
This instruction requires several arguments:
#. The optional "tail" marker indicates that the callee function does
not access any allocas or varargs in the caller. Note that calls may
be marked "tail" even if they do not occur before a
:ref:`ret <i_ret>` instruction. If the "tail" marker is present, the
function call is eligible for tail call optimization, but `might not
in fact be optimized into a jump <CodeGenerator.html#tailcallopt>`_.
The code generator may optimize calls marked "tail" with either 1)
automatic `sibling call
optimization <CodeGenerator.html#sibcallopt>`_ when the caller and
callee have matching signatures, or 2) forced tail call optimization
when the following extra requirements are met:
#. The optional ``tail`` and ``musttail`` markers indicate that the optimizers
should perform tail call optimization. The ``tail`` marker is a hint that
`can be ignored <CodeGenerator.html#sibcallopt>`_. The ``musttail`` marker
means that the call must be tail call optimized in order for the program to
be correct. The ``musttail`` marker provides these guarantees:
#. The call will not cause unbounded stack growth if it is part of a
recursive cycle in the call graph.
#. Arguments with the :ref:`inalloca <attr_inalloca>` attribute are
forwarded in place.
Both markers imply that the callee does not access allocas or varargs from
the caller. Calls marked ``musttail`` must obey the following additional
rules:
- The call must immediately precede a :ref:`ret <i_ret>` instruction,
or a pointer bitcast followed by a ret instruction.
- The ret instruction must return the (possibly bitcasted) value
produced by the call or void.
- The caller and callee prototypes must match. Pointer types of
parameters or return types may differ in pointee type, but not
in address space.
- The calling conventions of the caller and callee must match.
- All ABI-impacting function attributes, such as sret, byval, inreg,
returned, and inalloca, must match.
Tail call optimization for calls marked ``tail`` is guaranteed to occur if
the following conditions are met:
- Caller and callee both have the calling convention ``fastcc``.
- The call is in tail position (ret immediately follows call and ret

View File

@ -160,6 +160,15 @@ public:
///
FunTy *getCaller() const { return (*this)->getParent()->getParent(); }
/// \brief Tests if this is a tail call. Only a CallInst can be a tail call.
bool isTailCall() const { return isCall() && cast<CallInst>->isTailCall(); }
/// \brief Tests if this call site must be tail call optimized. Only a
/// CallInst can be tail call optimized.
bool isMustTailCall() const {
return isCall() && cast<CallInst>(getInstruction())->isMustTailCall();
}
#define CALLSITE_DELEGATE_GETTER(METHOD) \
InstrTy *II = getInstruction(); \
return isCall() \

View File

@ -1279,10 +1279,24 @@ public:
~CallInst();
bool isTailCall() const { return getSubclassDataFromInstruction() & 1; }
// Note that 'musttail' implies 'tail'.
enum TailCallKind { TCK_None = 0, TCK_Tail = 1, TCK_MustTail = 2 };
TailCallKind getTailCallKind() const {
return TailCallKind(getSubclassDataFromInstruction() & 3);
}
bool isTailCall() const {
return (getSubclassDataFromInstruction() & 3) != TCK_None;
}
bool isMustTailCall() const {
return (getSubclassDataFromInstruction() & 3) == TCK_MustTail;
}
void setTailCall(bool isTC = true) {
setInstructionSubclassData((getSubclassDataFromInstruction() & ~1) |
unsigned(isTC));
setInstructionSubclassData((getSubclassDataFromInstruction() & ~3) |
unsigned(isTC ? TCK_Tail : TCK_None));
}
void setTailCallKind(TailCallKind TCK) {
setInstructionSubclassData((getSubclassDataFromInstruction() & ~3) |
unsigned(TCK));
}
/// Provide fast operand accessors
@ -1316,11 +1330,11 @@ public:
/// getCallingConv/setCallingConv - Get or set the calling convention of this
/// function call.
CallingConv::ID getCallingConv() const {
return static_cast<CallingConv::ID>(getSubclassDataFromInstruction() >> 1);
return static_cast<CallingConv::ID>(getSubclassDataFromInstruction() >> 2);
}
void setCallingConv(CallingConv::ID CC) {
setInstructionSubclassData((getSubclassDataFromInstruction() & 1) |
(static_cast<unsigned>(CC) << 1));
setInstructionSubclassData((getSubclassDataFromInstruction() & 3) |
(static_cast<unsigned>(CC) << 2));
}
/// getAttributes - Return the parameter attributes for this call.

View File

@ -512,6 +512,7 @@ lltok::Kind LLLexer::LexIdentifier() {
KEYWORD(null);
KEYWORD(to);
KEYWORD(tail);
KEYWORD(musttail);
KEYWORD(target);
KEYWORD(triple);
KEYWORD(unwind);

View File

@ -3367,8 +3367,10 @@ int LLParser::ParseInstruction(Instruction *&Inst, BasicBlock *BB,
case lltok::kw_shufflevector: return ParseShuffleVector(Inst, PFS);
case lltok::kw_phi: return ParsePHI(Inst, PFS);
case lltok::kw_landingpad: return ParseLandingPad(Inst, PFS);
case lltok::kw_call: return ParseCall(Inst, PFS, false);
case lltok::kw_tail: return ParseCall(Inst, PFS, true);
// Call.
case lltok::kw_call: return ParseCall(Inst, PFS, CallInst::TCK_None);
case lltok::kw_tail: return ParseCall(Inst, PFS, CallInst::TCK_Tail);
case lltok::kw_musttail: return ParseCall(Inst, PFS, CallInst::TCK_MustTail);
// Memory.
case lltok::kw_alloca: return ParseAlloc(Inst, PFS);
case lltok::kw_load: return ParseLoad(Inst, PFS);
@ -3984,10 +3986,14 @@ bool LLParser::ParseLandingPad(Instruction *&Inst, PerFunctionState &PFS) {
}
/// ParseCall
/// ::= 'tail'? 'call' OptionalCallingConv OptionalAttrs Type Value
/// ::= 'call' OptionalCallingConv OptionalAttrs Type Value
/// ParameterList OptionalAttrs
/// ::= 'tail' 'call' OptionalCallingConv OptionalAttrs Type Value
/// ParameterList OptionalAttrs
/// ::= 'musttail' 'call' OptionalCallingConv OptionalAttrs Type Value
/// ParameterList OptionalAttrs
bool LLParser::ParseCall(Instruction *&Inst, PerFunctionState &PFS,
bool isTail) {
CallInst::TailCallKind TCK) {
AttrBuilder RetAttrs, FnAttrs;
std::vector<unsigned> FwdRefAttrGrps;
LocTy BuiltinLoc;
@ -3998,7 +4004,8 @@ bool LLParser::ParseCall(Instruction *&Inst, PerFunctionState &PFS,
SmallVector<ParamInfo, 16> ArgList;
LocTy CallLoc = Lex.getLoc();
if ((isTail && ParseToken(lltok::kw_call, "expected 'tail call'")) ||
if ((TCK != CallInst::TCK_None &&
ParseToken(lltok::kw_call, "expected 'tail call'")) ||
ParseOptionalCallingConv(CC) ||
ParseOptionalReturnAttrs(RetAttrs) ||
ParseType(RetType, RetTypeLoc, true /*void allowed*/) ||
@ -4074,7 +4081,7 @@ bool LLParser::ParseCall(Instruction *&Inst, PerFunctionState &PFS,
AttributeSet PAL = AttributeSet::get(Context, Attrs);
CallInst *CI = CallInst::Create(Callee, Args);
CI->setTailCall(isTail);
CI->setTailCallKind(TCK);
CI->setCallingConv(CC);
CI->setAttributes(PAL);
ForwardRefAttrGroups[CI] = FwdRefAttrGrps;

View File

@ -372,6 +372,8 @@ namespace llvm {
bool ParseFunctionBody(Function &Fn);
bool ParseBasicBlock(PerFunctionState &PFS);
enum TailCallType { TCT_None, TCT_Tail, TCT_MustTail };
// Instruction Parsing. Each instruction parsing routine can return with a
// normal result, an error result, or return having eaten an extra comma.
enum InstResult { InstNormal = 0, InstError = 1, InstExtraComma = 2 };
@ -398,7 +400,8 @@ namespace llvm {
bool ParseShuffleVector(Instruction *&I, PerFunctionState &PFS);
int ParsePHI(Instruction *&I, PerFunctionState &PFS);
bool ParseLandingPad(Instruction *&I, PerFunctionState &PFS);
bool ParseCall(Instruction *&I, PerFunctionState &PFS, bool isTail);
bool ParseCall(Instruction *&I, PerFunctionState &PFS,
CallInst::TailCallKind IsTail);
int ParseAlloc(Instruction *&I, PerFunctionState &PFS);
int ParseLoad(Instruction *&I, PerFunctionState &PFS);
int ParseStore(Instruction *&I, PerFunctionState &PFS);

View File

@ -54,6 +54,7 @@ namespace lltok {
kw_undef, kw_null,
kw_to,
kw_tail,
kw_musttail,
kw_target,
kw_triple,
kw_unwind,

View File

@ -2994,8 +2994,13 @@ error_code BitcodeReader::ParseFunctionBody(Function *F) {
I = CallInst::Create(Callee, Args);
InstructionList.push_back(I);
cast<CallInst>(I)->setCallingConv(
static_cast<CallingConv::ID>(CCInfo>>1));
cast<CallInst>(I)->setTailCall(CCInfo & 1);
static_cast<CallingConv::ID>((~(1U << 14) & CCInfo) >> 1));
CallInst::TailCallKind TCK = CallInst::TCK_None;
if (CCInfo & 1)
TCK = CallInst::TCK_Tail;
if (CCInfo & (1 << 14))
TCK = CallInst::TCK_MustTail;
cast<CallInst>(I)->setTailCallKind(TCK);
cast<CallInst>(I)->setAttributes(PAL);
break;
}

View File

@ -1469,7 +1469,8 @@ static void WriteInstruction(const Instruction &I, unsigned InstID,
Code = bitc::FUNC_CODE_INST_CALL;
Vals.push_back(VE.getAttributeID(CI.getAttributes()));
Vals.push_back((CI.getCallingConv() << 1) | unsigned(CI.isTailCall()));
Vals.push_back((CI.getCallingConv() << 1) | unsigned(CI.isTailCall()) |
unsigned(CI.isMustTailCall()) << 14);
PushValueAndType(CI.getCalledValue(), InstID, Vals, VE); // Callee
// Emit value #'s for the fixed parameters.

View File

@ -1768,8 +1768,12 @@ void AssemblyWriter::printInstruction(const Instruction &I) {
Out << '%' << SlotNum << " = ";
}
if (isa<CallInst>(I) && cast<CallInst>(I).isTailCall())
Out << "tail ";
if (const CallInst *CI = dyn_cast<CallInst>(&I)) {
if (CI->isMustTailCall())
Out << "musttail ";
else if (CI->isTailCall())
Out << "tail ";
}
// Print out the opcode...
Out << I.getOpcodeName();

View File

@ -301,6 +301,7 @@ private:
void visitLandingPadInst(LandingPadInst &LPI);
void VerifyCallSite(CallSite CS);
void verifyMustTailCall(CallInst &CI);
bool PerformTypeCheck(Intrinsic::ID ID, Function *F, Type *Ty, int VT,
unsigned ArgNo, std::string &Suffix);
bool VerifyIntrinsicType(Type *Ty, ArrayRef<Intrinsic::IITDescriptor> &Infos,
@ -1545,9 +1546,97 @@ void Verifier::VerifyCallSite(CallSite CS) {
visitInstruction(*I);
}
/// Two types are "congruent" if they are identical, or if they are both pointer
/// types with different pointee types and the same address space.
static bool isTypeCongruent(Type *L, Type *R) {
if (L == R)
return true;
PointerType *PL = dyn_cast<PointerType>(L);
PointerType *PR = dyn_cast<PointerType>(R);
if (!PL || !PR)
return false;
return PL->getAddressSpace() == PR->getAddressSpace();
}
void Verifier::verifyMustTailCall(CallInst &CI) {
Assert1(!CI.isInlineAsm(), "cannot use musttail call with inline asm", &CI);
// - The caller and callee prototypes must match. Pointer types of
// parameters or return types may differ in pointee type, but not
// address space.
Function *F = CI.getParent()->getParent();
auto GetFnTy = [](Value *V) {
return cast<FunctionType>(
cast<PointerType>(V->getType())->getElementType());
};
FunctionType *CallerTy = GetFnTy(F);
FunctionType *CalleeTy = GetFnTy(CI.getCalledValue());
Assert1(CallerTy->getNumParams() == CalleeTy->getNumParams(),
"cannot guarantee tail call due to mismatched parameter counts", &CI);
Assert1(CallerTy->isVarArg() == CalleeTy->isVarArg(),
"cannot guarantee tail call due to mismatched varargs", &CI);
Assert1(isTypeCongruent(CallerTy->getReturnType(), CalleeTy->getReturnType()),
"cannot guarantee tail call due to mismatched return types", &CI);
for (int I = 0, E = CallerTy->getNumParams(); I != E; ++I) {
Assert1(
isTypeCongruent(CallerTy->getParamType(I), CalleeTy->getParamType(I)),
"cannot guarantee tail call due to mismatched parameter types", &CI);
}
// - The calling conventions of the caller and callee must match.
Assert1(F->getCallingConv() == CI.getCallingConv(),
"cannot guarantee tail call due to mismatched calling conv", &CI);
// - All ABI-impacting function attributes, such as sret, byval, inreg,
// returned, and inalloca, must match.
static const Attribute::AttrKind ABIAttrs[] = {
Attribute::Alignment, Attribute::StructRet, Attribute::ByVal,
Attribute::InAlloca, Attribute::InReg, Attribute::Returned};
AttributeSet CallerAttrs = F->getAttributes();
AttributeSet CalleeAttrs = CI.getAttributes();
for (int I = 0, E = CallerTy->getNumParams(); I != E; ++I) {
AttrBuilder CallerABIAttrs;
AttrBuilder CalleeABIAttrs;
for (auto AK : ABIAttrs) {
if (CallerAttrs.hasAttribute(I + 1, AK))
CallerABIAttrs.addAttribute(AK);
if (CalleeAttrs.hasAttribute(I + 1, AK))
CalleeABIAttrs.addAttribute(AK);
}
Assert2(CallerABIAttrs == CalleeABIAttrs,
"cannot guarantee tail call due to mismatched ABI impacting "
"function attributes", &CI, CI.getOperand(I));
}
// - The call must immediately precede a :ref:`ret <i_ret>` instruction,
// or a pointer bitcast followed by a ret instruction.
// - The ret instruction must return the (possibly bitcasted) value
// produced by the call or void.
Value *RetVal = &CI;
Instruction *Next = CI.getNextNode();
// Handle the optional bitcast.
if (BitCastInst *BI = dyn_cast_or_null<BitCastInst>(Next)) {
Assert1(BI->getOperand(0) == RetVal,
"bitcast following musttail call must use the call", BI);
RetVal = BI;
Next = BI->getNextNode();
}
// Check the return.
ReturnInst *Ret = dyn_cast_or_null<ReturnInst>(Next);
Assert1(Ret, "musttail call must be precede a ret with an optional bitcast",
&CI);
Assert1(!Ret->getReturnValue() || Ret->getReturnValue() == RetVal,
"musttail call result must be returned", Ret);
}
void Verifier::visitCallInst(CallInst &CI) {
VerifyCallSite(&CI);
if (CI.isMustTailCall())
verifyMustTailCall(CI);
if (Function *F = CI.getCalledFunction())
if (Intrinsic::ID ID = (Intrinsic::ID)F->getIntrinsicID())
visitIntrinsicFunctionCall(ID, CI);

View File

@ -1524,6 +1524,10 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
IsVarArg, IsStructRet, MF.getFunction()->hasStructRetAttr(),
Outs, OutVals, Ins, DAG);
if (!IsTailCall && CLI.CS && CLI.CS->isMustTailCall())
report_fatal_error("failed to perform tail call elimination on a call "
"site marked musttail");
// A sibling call is one where we're under the usual C ABI and not planning
// to change that but can still do a tail call:
if (!TailCallOpt && IsTailCall)

View File

@ -1400,6 +1400,9 @@ ARMTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
isTailCall = IsEligibleForTailCallOptimization(Callee, CallConv,
isVarArg, isStructRet, MF.getFunction()->hasStructRetAttr(),
Outs, OutVals, Ins, DAG);
if (!isTailCall && CLI.CS && CLI.CS->isMustTailCall())
report_fatal_error("failed to perform tail call elimination on a call "
"site marked musttail");
// We don't support GuaranteedTailCallOpt for ARM, only automatically
// detected sibcalls.
if (isTailCall) {

View File

@ -1957,6 +1957,9 @@ SDValue ARM64TargetLowering::LowerCall(CallLoweringInfo &CLI,
IsTailCall = isEligibleForTailCallOptimization(
Callee, CallConv, IsVarArg, IsStructRet,
MF.getFunction()->hasStructRetAttr(), Outs, OutVals, Ins, DAG);
if (!IsTailCall && CLI.CS && CLI.CS->isMustTailCall())
report_fatal_error("failed to perform tail call elimination on a call "
"site marked musttail");
// We don't support GuaranteedTailCallOpt, only automatically
// detected sibcalls.
// FIXME: Re-evaluate. Is this true? Should it be true?

View File

@ -2339,6 +2339,10 @@ MipsTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
isEligibleForTailCallOptimization(MipsCCInfo, NextStackOffset,
*MF.getInfo<MipsFunctionInfo>());
if (!IsTailCall && CLI.CS && CLI.CS->isMustTailCall())
report_fatal_error("failed to perform tail call elimination on a call "
"site marked musttail");
if (IsTailCall)
++NumTailCalls;

View File

@ -3720,6 +3720,10 @@ PPCTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
isTailCall = IsEligibleForTailCallOptimization(Callee, CallConv, isVarArg,
Ins, DAG);
if (!isTailCall && CLI.CS && CLI.CS->isMustTailCall())
report_fatal_error("failed to perform tail call elimination on a call "
"site marked musttail");
if (PPCSubTarget.isSVR4ABI()) {
if (PPCSubTarget.isPPC64())
return LowerCall_64SVR4(Chain, Callee, CallConv, isVarArg,

View File

@ -2544,6 +2544,10 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
MF.getFunction()->hasStructRetAttr(), CLI.RetTy,
Outs, OutVals, Ins, DAG);
if (!isTailCall && CLI.CS && CLI.CS->isMustTailCall())
report_fatal_error("failed to perform tail call elimination on a call "
"site marked musttail");
// Sibcalls are automatically detected tailcalls which do not require
// ABI changes.
if (!MF.getTarget().Options.GuaranteedTailCallOpt && isTailCall)

17
test/Bitcode/tailcall.ll Normal file
View File

@ -0,0 +1,17 @@
; RUN: llvm-as < %s | llvm-dis | FileCheck %s
; Check that musttail and tail roundtrip.
declare cc8191 void @t1_callee()
define cc8191 void @t1() {
; CHECK: tail call cc8191 void @t1_callee()
tail call cc8191 void @t1_callee()
ret void
}
declare cc8191 void @t2_callee()
define cc8191 void @t2() {
; CHECK: musttail call cc8191 void @t2_callee()
musttail call cc8191 void @t2_callee()
ret void
}

View File

@ -0,0 +1,23 @@
; RUN: llc -march=x86 < %s | FileCheck %s
; FIXME: Eliminate this tail call at -O0, since musttail is a correctness
; requirement.
; RUN: not llc -march=x86 -O0 < %s
declare void @t1_callee(i8*)
define void @t1(i32* %a) {
; CHECK-LABEL: t1:
; CHECK: jmp {{_?}}t1_callee
%b = bitcast i32* %a to i8*
musttail call void @t1_callee(i8* %b)
ret void
}
declare i8* @t2_callee()
define i32* @t2() {
; CHECK-LABEL: t2:
; CHECK: jmp {{_?}}t2_callee
%v = musttail call i8* @t2_callee()
%w = bitcast i8* %v to i32*
ret i32* %w
}

View File

@ -0,0 +1,75 @@
; RUN: not llvm-as %s -o /dev/null 2>&1 | FileCheck %s
; Each musttail call should fail to validate.
declare x86_stdcallcc void @cc_mismatch_callee()
define void @cc_mismatch() {
; CHECK: mismatched calling conv
musttail call x86_stdcallcc void @cc_mismatch_callee()
ret void
}
declare void @more_parms_callee(i32)
define void @more_parms() {
; CHECK: mismatched parameter counts
musttail call void @more_parms_callee(i32 0)
ret void
}
declare void @mismatched_intty_callee(i8)
define void @mismatched_intty(i32) {
; CHECK: mismatched parameter types
musttail call void @mismatched_intty_callee(i8 0)
ret void
}
declare void @mismatched_vararg_callee(i8*, ...)
define void @mismatched_vararg(i8*) {
; CHECK: mismatched varargs
musttail call void (i8*, ...)* @mismatched_vararg_callee(i8* null)
ret void
}
; We would make this an implicit sret parameter, which would disturb the
; tail call.
declare { i32, i32, i32 } @mismatched_retty_callee(i32)
define void @mismatched_retty(i32) {
; CHECK: mismatched return types
musttail call { i32, i32, i32 } @mismatched_retty_callee(i32 0)
ret void
}
declare void @mismatched_byval_callee({ i32 }*)
define void @mismatched_byval({ i32 }* byval %a) {
; CHECK: mismatched ABI impacting function attributes
musttail call void @mismatched_byval_callee({ i32 }* %a)
ret void
}
declare void @mismatched_inreg_callee(i32 inreg)
define void @mismatched_inreg(i32 %a) {
; CHECK: mismatched ABI impacting function attributes
musttail call void @mismatched_inreg_callee(i32 inreg %a)
ret void
}
declare void @mismatched_sret_callee(i32* sret)
define void @mismatched_sret(i32* %a) {
; CHECK: mismatched ABI impacting function attributes
musttail call void @mismatched_sret_callee(i32* sret %a)
ret void
}
declare i32 @not_tail_pos_callee()
define i32 @not_tail_pos() {
; CHECK: musttail call must be precede a ret with an optional bitcast
%v = musttail call i32 @not_tail_pos_callee()
%w = add i32 %v, 1
ret i32 %w
}
define void @inline_asm() {
; CHECK: cannot use musttail call with inline asm
musttail call void asm "ret", ""()
ret void
}

View File

@ -0,0 +1,16 @@
; RUN: llvm-as %s -o /dev/null
; Should assemble without error.
declare void @similar_param_ptrty_callee(i8*)
define void @similar_param_ptrty(i32*) {
musttail call void @similar_param_ptrty_callee(i8* null)
ret void
}
declare i8* @similar_ret_ptrty_callee()
define i32* @similar_ret_ptrty() {
%v = musttail call i8* @similar_ret_ptrty_callee()
%w = bitcast i8* %v to i32*
ret i32* %w
}