From f7f88095a32d1ac5bc7778204fd9a37a9fb8082c Mon Sep 17 00:00:00 2001 From: Tim Northover Date: Mon, 1 Dec 2014 17:46:39 +0000 Subject: [PATCH] ARM: lower tail calls correctly when using GHC calling convention. Patch by Ben Gamari. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@223055 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Target/ARM/ARMFrameLowering.cpp | 91 ++++++++++++++---------- lib/Target/ARM/ARMFrameLowering.h | 2 + test/CodeGen/ARM/ghc-tcreturn-lowered.ll | 21 ++++++ 3 files changed, 76 insertions(+), 38 deletions(-) create mode 100644 test/CodeGen/ARM/ghc-tcreturn-lowered.ll diff --git a/lib/Target/ARM/ARMFrameLowering.cpp b/lib/Target/ARM/ARMFrameLowering.cpp index 80add7a62d3..4f3012e9aaf 100644 --- a/lib/Target/ARM/ARMFrameLowering.cpp +++ b/lib/Target/ARM/ARMFrameLowering.cpp @@ -612,11 +612,59 @@ void ARMFrameLowering::emitPrologue(MachineFunction &MF) const { AFI->setShouldRestoreSPFromFP(true); } +// Resolve TCReturn pseudo-instruction +void ARMFrameLowering::fixTCReturn(MachineFunction &MF, + MachineBasicBlock &MBB) const { + MachineBasicBlock::iterator MBBI = MBB.getLastNonDebugInstr(); + assert(MBBI->isReturn() && "Can only insert epilog into returning blocks"); + unsigned RetOpcode = MBBI->getOpcode(); + DebugLoc dl = MBBI->getDebugLoc(); + const ARMBaseInstrInfo &TII = + *static_cast(MF.getSubtarget().getInstrInfo()); + + if (!(RetOpcode == ARM::TCRETURNdi || RetOpcode == ARM::TCRETURNri)) + return; + + // Tail call return: adjust the stack pointer and jump to callee. + MBBI = MBB.getLastNonDebugInstr(); + MachineOperand &JumpTarget = MBBI->getOperand(0); + + // Jump to label or value in register. + if (RetOpcode == ARM::TCRETURNdi) { + unsigned TCOpcode = STI.isThumb() ? + (STI.isTargetMachO() ? ARM::tTAILJMPd : ARM::tTAILJMPdND) : + ARM::TAILJMPd; + MachineInstrBuilder MIB = BuildMI(MBB, MBBI, dl, TII.get(TCOpcode)); + if (JumpTarget.isGlobal()) + MIB.addGlobalAddress(JumpTarget.getGlobal(), JumpTarget.getOffset(), + JumpTarget.getTargetFlags()); + else { + assert(JumpTarget.isSymbol()); + MIB.addExternalSymbol(JumpTarget.getSymbolName(), + JumpTarget.getTargetFlags()); + } + + // Add the default predicate in Thumb mode. + if (STI.isThumb()) MIB.addImm(ARMCC::AL).addReg(0); + } else if (RetOpcode == ARM::TCRETURNri) { + BuildMI(MBB, MBBI, dl, + TII.get(STI.isThumb() ? ARM::tTAILJMPr : ARM::TAILJMPr)). + addReg(JumpTarget.getReg(), RegState::Kill); + } + + MachineInstr *NewMI = std::prev(MBBI); + for (unsigned i = 1, e = MBBI->getNumOperands(); i != e; ++i) + NewMI->addOperand(MBBI->getOperand(i)); + + // Delete the pseudo instruction TCRETURN. + MBB.erase(MBBI); + MBBI = NewMI; +} + void ARMFrameLowering::emitEpilogue(MachineFunction &MF, MachineBasicBlock &MBB) const { MachineBasicBlock::iterator MBBI = MBB.getLastNonDebugInstr(); assert(MBBI->isReturn() && "Can only insert epilog into returning blocks"); - unsigned RetOpcode = MBBI->getOpcode(); DebugLoc dl = MBBI->getDebugLoc(); MachineFrameInfo *MFI = MF.getFrameInfo(); ARMFunctionInfo *AFI = MF.getInfo(); @@ -637,8 +685,10 @@ void ARMFrameLowering::emitEpilogue(MachineFunction &MF, // All calls are tail calls in GHC calling conv, and functions have no // prologue/epilogue. - if (MF.getFunction()->getCallingConv() == CallingConv::GHC) + if (MF.getFunction()->getCallingConv() == CallingConv::GHC) { + fixTCReturn(MF, MBB); return; + } if (!AFI->hasStackFrame()) { if (NumBytes - ArgRegsSaveSize != 0) @@ -717,42 +767,7 @@ void ARMFrameLowering::emitEpilogue(MachineFunction &MF, if (AFI->getGPRCalleeSavedArea1Size()) MBBI++; } - if (RetOpcode == ARM::TCRETURNdi || RetOpcode == ARM::TCRETURNri) { - // Tail call return: adjust the stack pointer and jump to callee. - MBBI = MBB.getLastNonDebugInstr(); - MachineOperand &JumpTarget = MBBI->getOperand(0); - - // Jump to label or value in register. - if (RetOpcode == ARM::TCRETURNdi) { - unsigned TCOpcode = STI.isThumb() ? - (STI.isTargetMachO() ? ARM::tTAILJMPd : ARM::tTAILJMPdND) : - ARM::TAILJMPd; - MachineInstrBuilder MIB = BuildMI(MBB, MBBI, dl, TII.get(TCOpcode)); - if (JumpTarget.isGlobal()) - MIB.addGlobalAddress(JumpTarget.getGlobal(), JumpTarget.getOffset(), - JumpTarget.getTargetFlags()); - else { - assert(JumpTarget.isSymbol()); - MIB.addExternalSymbol(JumpTarget.getSymbolName(), - JumpTarget.getTargetFlags()); - } - - // Add the default predicate in Thumb mode. - if (STI.isThumb()) MIB.addImm(ARMCC::AL).addReg(0); - } else if (RetOpcode == ARM::TCRETURNri) { - BuildMI(MBB, MBBI, dl, - TII.get(STI.isThumb() ? ARM::tTAILJMPr : ARM::TAILJMPr)). - addReg(JumpTarget.getReg(), RegState::Kill); - } - - MachineInstr *NewMI = std::prev(MBBI); - for (unsigned i = 1, e = MBBI->getNumOperands(); i != e; ++i) - NewMI->addOperand(MBBI->getOperand(i)); - - // Delete the pseudo instruction TCRETURN. - MBB.erase(MBBI); - MBBI = NewMI; - } + fixTCReturn(MF, MBB); if (ArgRegsSaveSize) emitSPUpdate(isARM, MBB, MBBI, dl, TII, ArgRegsSaveSize); diff --git a/lib/Target/ARM/ARMFrameLowering.h b/lib/Target/ARM/ARMFrameLowering.h index a83b7730ff9..b7be43642ad 100644 --- a/lib/Target/ARM/ARMFrameLowering.h +++ b/lib/Target/ARM/ARMFrameLowering.h @@ -31,6 +31,8 @@ public: void emitPrologue(MachineFunction &MF) const override; void emitEpilogue(MachineFunction &MF, MachineBasicBlock &MBB) const override; + void fixTCReturn(MachineFunction &MF, MachineBasicBlock &MBB) const; + bool spillCalleeSavedRegisters(MachineBasicBlock &MBB, MachineBasicBlock::iterator MI, const std::vector &CSI, diff --git a/test/CodeGen/ARM/ghc-tcreturn-lowered.ll b/test/CodeGen/ARM/ghc-tcreturn-lowered.ll new file mode 100644 index 00000000000..6d2564ba1ab --- /dev/null +++ b/test/CodeGen/ARM/ghc-tcreturn-lowered.ll @@ -0,0 +1,21 @@ +; RUN: llc -mtriple=thumbv7-eabi -o - %s | FileCheck %s + +declare cc 10 void @g() + +define cc 10 void @test_direct_tail() { +; CHECK-LABEL: test_direct_tail: +; CHECK: b g + + tail call cc10 void @g() + ret void +} + +@ind_func = global void()* zeroinitializer + +define cc 10 void @test_indirect_tail() { +; CHECK-LABEL: test_indirect_tail: +; CHECK: bx {{r[0-9]+}} + %func = load void()** @ind_func + tail call cc10 void()* %func() + ret void +}