Allow non-device function calls in PTX when natively handling device-side printf (calls to printf and puts are emitted as vprintf)

llvm-svn: 144388
This commit is contained in:
Dan Bailey 2011-11-11 14:45:12 +00:00
parent 6c29989135
commit ad6c209a79
5 changed files with 154 additions and 23 deletions

View File

@ -96,9 +96,23 @@ void PTXInstPrinter::printCall(const MCInst *MI, raw_ostream &O) {
O << "), ";
}
O << *(MI->getOperand(Index++).getExpr()) << ", (";
const MCExpr* Expr = MI->getOperand(Index++).getExpr();
unsigned NumArgs = MI->getOperand(Index++).getImm();
// if the function call is to printf or puts, change to vprintf
if (const MCSymbolRefExpr *SymRefExpr = dyn_cast<MCSymbolRefExpr>(Expr)) {
const MCSymbol &Sym = SymRefExpr->getSymbol();
if (Sym.getName() == "printf" || Sym.getName() == "puts") {
O << "vprintf";
} else {
O << Sym.getName();
}
} else {
O << *Expr;
}
O << ", (";
if (NumArgs > 0) {
printOperand(MI, Index++, O);
for (unsigned i = 1; i < NumArgs; ++i) {

View File

@ -165,6 +165,11 @@ void PTXAsmPrinter::EmitStartOfAsmFile(Module &M)
OutStreamer.AddBlankLine();
// declare external functions
for (Module::const_iterator i = M.begin(), e = M.end();
i != e; ++i)
EmitFunctionDeclaration(i);
// declare global variables
for (Module::const_global_iterator i = M.global_begin(), e = M.global_end();
i != e; ++i)
@ -454,6 +459,31 @@ void PTXAsmPrinter::EmitFunctionEntryLabel() {
OutStreamer.EmitRawText(os.str());
}
void PTXAsmPrinter::EmitFunctionDeclaration(const Function* func)
{
const PTXSubtarget& ST = TM.getSubtarget<PTXSubtarget>();
std::string decl = "";
// hard-coded emission of extern vprintf function
if (func->getName() == "printf" || func->getName() == "puts") {
decl += ".extern .func (.param .b32 __param_1) vprintf (.param .b";
if (ST.is64Bit())
decl += "64";
else
decl += "32";
decl += " __param_2, .param .b";
if (ST.is64Bit())
decl += "64";
else
decl += "32";
decl += " __param_3)\n";
}
OutStreamer.EmitRawText(Twine(decl));
}
unsigned PTXAsmPrinter::GetOrCreateSourceID(StringRef FileName,
StringRef DirName) {
// If FE did not provide a file name, then assume stdin.

View File

@ -47,7 +47,7 @@ public:
private:
void EmitVariableDeclaration(const GlobalVariable *gv);
void EmitFunctionDeclaration();
void EmitFunctionDeclaration(const Function* func);
StringMap<unsigned> SourceIdMap;
}; // class PTXAsmPrinter

View File

@ -20,6 +20,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
@ -352,40 +353,101 @@ PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
SmallVectorImpl<SDValue> &InVals) const {
MachineFunction& MF = DAG.getMachineFunction();
PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
PTXParamManager &PM = MFI->getParamManager();
PTXMachineFunctionInfo *PTXMFI = MF.getInfo<PTXMachineFunctionInfo>();
PTXParamManager &PM = PTXMFI->getParamManager();
MachineFrameInfo *MFI = MF.getFrameInfo();
assert(getTargetMachine().getSubtarget<PTXSubtarget>().callsAreHandled() &&
"Calls are not handled for the target device");
// Identify the callee function
const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal();
const Function *function = cast<Function>(GV);
// allow non-device calls only for printf
bool isPrintf = function->getName() == "printf" || function->getName() == "puts";
assert((isPrintf || function->getCallingConv() == CallingConv::PTX_Device) &&
"PTX function calls must be to PTX device functions");
unsigned outSize = isPrintf ? 2 : Outs.size();
std::vector<SDValue> Ops;
// The layout of the ops will be [Chain, #Ins, Ins, Callee, #Outs, Outs]
Ops.resize(Outs.size() + Ins.size() + 4);
Ops.resize(outSize + Ins.size() + 4);
Ops[0] = Chain;
// Identify the callee function
const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal();
assert(cast<Function>(GV)->getCallingConv() == CallingConv::PTX_Device &&
"PTX function calls must be to PTX device functions");
Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
Ops[Ins.size()+2] = Callee;
// Generate STORE_PARAM nodes for each function argument. In PTX, function
// arguments are explicitly stored into .param variables and passed as
// arguments. There is no register/stack-based calling convention in PTX.
Ops[Ins.size()+3] = DAG.getTargetConstant(OutVals.size(), MVT::i32);
for (unsigned i = 0; i != OutVals.size(); ++i) {
unsigned Size = OutVals[i].getValueType().getSizeInBits();
unsigned Param = PM.addLocalParam(Size);
const std::string &ParamName = PM.getParamName(Param);
SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
MVT::Other);
// #Outs
Ops[Ins.size()+3] = DAG.getTargetConstant(outSize, MVT::i32);
if (isPrintf) {
// first argument is the address of the global string variable in memory
unsigned Param0 = PM.addLocalParam(getPointerTy().getSizeInBits());
SDValue ParamValue0 = DAG.getTargetExternalSymbol(PM.getParamName(Param0).c_str(),
MVT::Other);
Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
ParamValue, OutVals[i]);
Ops[i+Ins.size()+4] = ParamValue;
}
ParamValue0, OutVals[0]);
Ops[Ins.size()+4] = ParamValue0;
// alignment is the maximum size of all the arguments
unsigned alignment = 0;
for (unsigned i = 1; i < OutVals.size(); ++i) {
alignment = std::max(alignment,
OutVals[i].getValueType().getSizeInBits());
}
// size is the alignment multiplied by the number of arguments
unsigned size = alignment * (OutVals.size() - 1);
// second argument is the address of the stack object (unless no arguments)
unsigned Param1 = PM.addLocalParam(getPointerTy().getSizeInBits());
SDValue ParamValue1 = DAG.getTargetExternalSymbol(PM.getParamName(Param1).c_str(),
MVT::Other);
Ops[Ins.size()+5] = ParamValue1;
if (size > 0)
{
// create a local stack object to store the arguments
unsigned StackObject = MFI->CreateStackObject(size / 8, alignment / 8, false);
SDValue FrameIndex = DAG.getFrameIndex(StackObject, getPointerTy());
// store each of the arguments to the stack in turn
for (unsigned int i = 1; i != OutVals.size(); i++) {
SDValue FrameAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), FrameIndex, DAG.getTargetConstant((i - 1) * 8, getPointerTy()));
Chain = DAG.getStore(Chain, dl, OutVals[i], FrameAddr,
MachinePointerInfo(),
false, false, 0);
}
// copy the address of the local frame index to get the address in non-local space
SDValue genericAddr = DAG.getNode(PTXISD::COPY_ADDRESS, dl, getPointerTy(), FrameIndex);
// store this address in the second argument
Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, ParamValue1, genericAddr);
}
}
else
{
// Generate STORE_PARAM nodes for each function argument. In PTX, function
// arguments are explicitly stored into .param variables and passed as
// arguments. There is no register/stack-based calling convention in PTX.
for (unsigned i = 0; i != OutVals.size(); ++i) {
unsigned Size = OutVals[i].getValueType().getSizeInBits();
unsigned Param = PM.addLocalParam(Size);
const std::string &ParamName = PM.getParamName(Param);
SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
MVT::Other);
Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
ParamValue, OutVals[i]);
Ops[i+Ins.size()+4] = ParamValue;
}
}
std::vector<SDValue> InParams;
// Generate list of .param variables to hold the return value(s).

View File

@ -0,0 +1,25 @@
; RUN: llc < %s -march=ptx64 -mattr=+ptx20,+sm20 | FileCheck %s
declare i32 @printf(i8*, ...)
@str = private unnamed_addr constant [6 x i8] c"test\0A\00"
; printf called with only a format string: the address of @str is expected to
; be moved into a 64-bit register and the call emitted as a two-parameter
; vprintf call.
define ptx_device void @t1_printf() {
; CHECK: mov.u64 %rd{{[0-9]+}}, $L__str;
; CHECK: call.uni (__localparam_{{[0-9]+}}), vprintf, (__localparam_{{[0-9]+}}, __localparam_{{[0-9]+}});
; CHECK: ret;
%1 = call i32 (i8*, ...)* @printf(i8* getelementptr inbounds ([6 x i8]* @str, i64 0, i64 0))
ret void
}
@str2 = private unnamed_addr constant [11 x i8] c"test = %f\0A\00"
; printf called with a format string and one double argument: a .local buffer
; is expected to be declared, its address converted with cvta.local, and the
; call emitted as a two-parameter vprintf call.
define ptx_device void @t2_printf() {
; CHECK: .local .align 8 .b8 __local{{[0-9]+}}[{{[0-9]+}}];
; CHECK: mov.u64 %rd{{[0-9]+}}, $L__str2;
; CHECK: cvta.local.u64 %rd{{[0-9]+}}, __local{{[0-9]+}};
; CHECK: call.uni (__localparam_{{[0-9]+}}), vprintf, (__localparam_{{[0-9]+}}, __localparam_{{[0-9]+}});
; CHECK: ret;
%1 = call i32 (i8*, ...)* @printf(i8* getelementptr inbounds ([11 x i8]* @str2, i64 0, i64 0), double 0x3FF3333340000000)
ret void
}