mirror of
https://github.com/RPCS3/llvm-mirror.git
synced 2024-12-28 22:43:29 +00:00
allow non-device function calls in PTX when natively handling device-side printf
llvm-svn: 144388
This commit is contained in:
parent
6c29989135
commit
ad6c209a79
@ -96,9 +96,23 @@ void PTXInstPrinter::printCall(const MCInst *MI, raw_ostream &O) {
|
||||
O << "), ";
|
||||
}
|
||||
|
||||
O << *(MI->getOperand(Index++).getExpr()) << ", (";
|
||||
|
||||
const MCExpr* Expr = MI->getOperand(Index++).getExpr();
|
||||
unsigned NumArgs = MI->getOperand(Index++).getImm();
|
||||
|
||||
// if the function call is to printf or puts, change to vprintf
|
||||
if (const MCSymbolRefExpr *SymRefExpr = dyn_cast<MCSymbolRefExpr>(Expr)) {
|
||||
const MCSymbol &Sym = SymRefExpr->getSymbol();
|
||||
if (Sym.getName() == "printf" || Sym.getName() == "puts") {
|
||||
O << "vprintf";
|
||||
} else {
|
||||
O << Sym.getName();
|
||||
}
|
||||
} else {
|
||||
O << *Expr;
|
||||
}
|
||||
|
||||
O << ", (";
|
||||
|
||||
if (NumArgs > 0) {
|
||||
printOperand(MI, Index++, O);
|
||||
for (unsigned i = 1; i < NumArgs; ++i) {
|
||||
|
@ -165,6 +165,11 @@ void PTXAsmPrinter::EmitStartOfAsmFile(Module &M)
|
||||
|
||||
OutStreamer.AddBlankLine();
|
||||
|
||||
// declare external functions
|
||||
for (Module::const_iterator i = M.begin(), e = M.end();
|
||||
i != e; ++i)
|
||||
EmitFunctionDeclaration(i);
|
||||
|
||||
// declare global variables
|
||||
for (Module::const_global_iterator i = M.global_begin(), e = M.global_end();
|
||||
i != e; ++i)
|
||||
@ -454,6 +459,31 @@ void PTXAsmPrinter::EmitFunctionEntryLabel() {
|
||||
OutStreamer.EmitRawText(os.str());
|
||||
}
|
||||
|
||||
void PTXAsmPrinter::EmitFunctionDeclaration(const Function* func)
|
||||
{
|
||||
const PTXSubtarget& ST = TM.getSubtarget<PTXSubtarget>();
|
||||
|
||||
std::string decl = "";
|
||||
|
||||
// hard-coded emission of extern vprintf function
|
||||
|
||||
if (func->getName() == "printf" || func->getName() == "puts") {
|
||||
decl += ".extern .func (.param .b32 __param_1) vprintf (.param .b";
|
||||
if (ST.is64Bit())
|
||||
decl += "64";
|
||||
else
|
||||
decl += "32";
|
||||
decl += " __param_2, .param .b";
|
||||
if (ST.is64Bit())
|
||||
decl += "64";
|
||||
else
|
||||
decl += "32";
|
||||
decl += " __param_3)\n";
|
||||
}
|
||||
|
||||
OutStreamer.EmitRawText(Twine(decl));
|
||||
}
|
||||
|
||||
unsigned PTXAsmPrinter::GetOrCreateSourceID(StringRef FileName,
|
||||
StringRef DirName) {
|
||||
// If FE did not provide a file name, then assume stdin.
|
||||
|
@ -47,7 +47,7 @@ public:
|
||||
|
||||
private:
|
||||
void EmitVariableDeclaration(const GlobalVariable *gv);
|
||||
void EmitFunctionDeclaration();
|
||||
void EmitFunctionDeclaration(const Function* func);
|
||||
|
||||
StringMap<unsigned> SourceIdMap;
|
||||
}; // class PTXAsmPrinter
|
||||
|
@ -20,6 +20,7 @@
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/CodeGen/CallingConvLower.h"
|
||||
#include "llvm/CodeGen/MachineFunction.h"
|
||||
#include "llvm/CodeGen/MachineFrameInfo.h"
|
||||
#include "llvm/CodeGen/MachineRegisterInfo.h"
|
||||
#include "llvm/CodeGen/SelectionDAG.h"
|
||||
#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
|
||||
@ -352,40 +353,101 @@ PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
|
||||
SmallVectorImpl<SDValue> &InVals) const {
|
||||
|
||||
MachineFunction& MF = DAG.getMachineFunction();
|
||||
PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
|
||||
PTXParamManager &PM = MFI->getParamManager();
|
||||
|
||||
PTXMachineFunctionInfo *PTXMFI = MF.getInfo<PTXMachineFunctionInfo>();
|
||||
PTXParamManager &PM = PTXMFI->getParamManager();
|
||||
MachineFrameInfo *MFI = MF.getFrameInfo();
|
||||
|
||||
assert(getTargetMachine().getSubtarget<PTXSubtarget>().callsAreHandled() &&
|
||||
"Calls are not handled for the target device");
|
||||
|
||||
// Identify the callee function
|
||||
const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal();
|
||||
const Function *function = cast<Function>(GV);
|
||||
|
||||
// allow non-device calls only for printf
|
||||
bool isPrintf = function->getName() == "printf" || function->getName() == "puts";
|
||||
|
||||
assert((isPrintf || function->getCallingConv() == CallingConv::PTX_Device) &&
|
||||
"PTX function calls must be to PTX device functions");
|
||||
|
||||
unsigned outSize = isPrintf ? 2 : Outs.size();
|
||||
|
||||
std::vector<SDValue> Ops;
|
||||
// The layout of the ops will be [Chain, #Ins, Ins, Callee, #Outs, Outs]
|
||||
Ops.resize(Outs.size() + Ins.size() + 4);
|
||||
Ops.resize(outSize + Ins.size() + 4);
|
||||
|
||||
Ops[0] = Chain;
|
||||
|
||||
// Identify the callee function
|
||||
const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal();
|
||||
assert(cast<Function>(GV)->getCallingConv() == CallingConv::PTX_Device &&
|
||||
"PTX function calls must be to PTX device functions");
|
||||
Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
|
||||
Ops[Ins.size()+2] = Callee;
|
||||
|
||||
// Generate STORE_PARAM nodes for each function argument. In PTX, function
|
||||
// arguments are explicitly stored into .param variables and passed as
|
||||
// arguments. There is no register/stack-based calling convention in PTX.
|
||||
Ops[Ins.size()+3] = DAG.getTargetConstant(OutVals.size(), MVT::i32);
|
||||
for (unsigned i = 0; i != OutVals.size(); ++i) {
|
||||
unsigned Size = OutVals[i].getValueType().getSizeInBits();
|
||||
unsigned Param = PM.addLocalParam(Size);
|
||||
const std::string &ParamName = PM.getParamName(Param);
|
||||
SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
|
||||
MVT::Other);
|
||||
// #Outs
|
||||
Ops[Ins.size()+3] = DAG.getTargetConstant(outSize, MVT::i32);
|
||||
|
||||
if (isPrintf) {
|
||||
// first argument is the address of the global string variable in memory
|
||||
unsigned Param0 = PM.addLocalParam(getPointerTy().getSizeInBits());
|
||||
SDValue ParamValue0 = DAG.getTargetExternalSymbol(PM.getParamName(Param0).c_str(),
|
||||
MVT::Other);
|
||||
Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
|
||||
ParamValue, OutVals[i]);
|
||||
Ops[i+Ins.size()+4] = ParamValue;
|
||||
}
|
||||
ParamValue0, OutVals[0]);
|
||||
Ops[Ins.size()+4] = ParamValue0;
|
||||
|
||||
// alignment is the maximum size of all the arguments
|
||||
unsigned alignment = 0;
|
||||
for (unsigned i = 1; i < OutVals.size(); ++i) {
|
||||
alignment = std::max(alignment,
|
||||
OutVals[i].getValueType().getSizeInBits());
|
||||
}
|
||||
|
||||
// size is the alignment multiplied by the number of arguments
|
||||
unsigned size = alignment * (OutVals.size() - 1);
|
||||
|
||||
// second argument is the address of the stack object (unless no arguments)
|
||||
unsigned Param1 = PM.addLocalParam(getPointerTy().getSizeInBits());
|
||||
SDValue ParamValue1 = DAG.getTargetExternalSymbol(PM.getParamName(Param1).c_str(),
|
||||
MVT::Other);
|
||||
Ops[Ins.size()+5] = ParamValue1;
|
||||
|
||||
if (size > 0)
|
||||
{
|
||||
// create a local stack object to store the arguments
|
||||
unsigned StackObject = MFI->CreateStackObject(size / 8, alignment / 8, false);
|
||||
SDValue FrameIndex = DAG.getFrameIndex(StackObject, getPointerTy());
|
||||
|
||||
// store each of the arguments to the stack in turn
|
||||
for (unsigned int i = 1; i != OutVals.size(); i++) {
|
||||
SDValue FrameAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), FrameIndex, DAG.getTargetConstant((i - 1) * 8, getPointerTy()));
|
||||
Chain = DAG.getStore(Chain, dl, OutVals[i], FrameAddr,
|
||||
MachinePointerInfo(),
|
||||
false, false, 0);
|
||||
}
|
||||
|
||||
// copy the address of the local frame index to get the address in non-local space
|
||||
SDValue genericAddr = DAG.getNode(PTXISD::COPY_ADDRESS, dl, getPointerTy(), FrameIndex);
|
||||
|
||||
// store this address in the second argument
|
||||
Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, ParamValue1, genericAddr);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Generate STORE_PARAM nodes for each function argument. In PTX, function
|
||||
// arguments are explicitly stored into .param variables and passed as
|
||||
// arguments. There is no register/stack-based calling convention in PTX.
|
||||
for (unsigned i = 0; i != OutVals.size(); ++i) {
|
||||
unsigned Size = OutVals[i].getValueType().getSizeInBits();
|
||||
unsigned Param = PM.addLocalParam(Size);
|
||||
const std::string &ParamName = PM.getParamName(Param);
|
||||
SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
|
||||
MVT::Other);
|
||||
Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
|
||||
ParamValue, OutVals[i]);
|
||||
Ops[i+Ins.size()+4] = ParamValue;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<SDValue> InParams;
|
||||
|
||||
// Generate list of .param variables to hold the return value(s).
|
||||
|
25
test/CodeGen/PTX/printf.ll
Normal file
25
test/CodeGen/PTX/printf.ll
Normal file
@ -0,0 +1,25 @@
|
||||
; RUN: llc < %s -march=ptx64 -mattr=+ptx20,+sm20 | FileCheck %s
|
||||
|
||||
declare i32 @printf(i8*, ...)
|
||||
|
||||
@str = private unnamed_addr constant [6 x i8] c"test\0A\00"
|
||||
|
||||
define ptx_device void @t1_printf() {
|
||||
; CHECK: mov.u64 %rd{{[0-9]+}}, $L__str;
|
||||
; CHECK: call.uni (__localparam_{{[0-9]+}}), vprintf, (__localparam_{{[0-9]+}}, __localparam_{{[0-9]+}});
|
||||
; CHECK: ret;
|
||||
%1 = call i32 (i8*, ...)* @printf(i8* getelementptr inbounds ([6 x i8]* @str, i64 0, i64 0))
|
||||
ret void
|
||||
}
|
||||
|
||||
@str2 = private unnamed_addr constant [11 x i8] c"test = %f\0A\00"
|
||||
|
||||
define ptx_device void @t2_printf() {
|
||||
; CHECK: .local .align 8 .b8 __local{{[0-9]+}}[{{[0-9]+}}];
|
||||
; CHECK: mov.u64 %rd{{[0-9]+}}, $L__str2;
|
||||
; CHECK: cvta.local.u64 %rd{{[0-9]+}}, __local{{[0-9+]}};
|
||||
; CHECK: call.uni (__localparam_{{[0-9]+}}), vprintf, (__localparam_{{[0-9]+}}, __localparam_{{[0-9]+}});
|
||||
; CHECK: ret;
|
||||
%1 = call i32 (i8*, ...)* @printf(i8* getelementptr inbounds ([11 x i8]* @str2, i64 0, i64 0), double 0x3FF3333340000000)
|
||||
ret void
|
||||
}
|
Loading…
Reference in New Issue
Block a user