ptx: add parameter passing to kernel functions

llvm-svn: 125279
This commit is contained in:
Che-Liang Chiou 2011-02-10 12:01:24 +00:00
parent f35d4d8c53
commit 762ff2a943
9 changed files with 104 additions and 62 deletions

View File

@ -38,12 +38,11 @@
using namespace llvm; using namespace llvm;
static cl::opt<std::string> static cl::opt<std::string>
OptPTXVersion("ptx-version", cl::desc("Set PTX version"), OptPTXVersion("ptx-version", cl::desc("Set PTX version"), cl::init("1.4"));
cl::init("1.4"));
static cl::opt<std::string> static cl::opt<std::string>
OptPTXTarget("ptx-target", cl::desc("Set GPU target (comma-separated list)"), OptPTXTarget("ptx-target", cl::desc("Set GPU target (comma-separated list)"),
cl::init("sm_10")); cl::init("sm_10"));
namespace { namespace {
class PTXAsmPrinter : public AsmPrinter { class PTXAsmPrinter : public AsmPrinter {
@ -67,6 +66,8 @@ public:
void printOperand(const MachineInstr *MI, int opNum, raw_ostream &OS); void printOperand(const MachineInstr *MI, int opNum, raw_ostream &OS);
void printMemOperand(const MachineInstr *MI, int opNum, raw_ostream &OS, void printMemOperand(const MachineInstr *MI, int opNum, raw_ostream &OS,
const char *Modifier = 0); const char *Modifier = 0);
void printParamOperand(const MachineInstr *MI, int opNum, raw_ostream &OS,
const char *Modifier = 0);
// autogen'd. // autogen'd.
void printInstruction(const MachineInstr *MI, raw_ostream &OS); void printInstruction(const MachineInstr *MI, raw_ostream &OS);
@ -231,6 +232,11 @@ void PTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum,
printOperand(MI, opNum+1, OS); printOperand(MI, opNum+1, OS);
} }
// Print a parameter operand: PARAM_PREFIX followed by the operand's immediate
// value plus one. Modifier is accepted for signature compatibility with the
// other operand printers and is not read here.
void PTXAsmPrinter::printParamOperand(const MachineInstr *MI, int opNum,
                                      raw_ostream &OS, const char *Modifier) {
  // The immediate is narrowed to int before the +1, as in the original
  // expression.
  const int paramNumber =
      static_cast<int>(MI->getOperand(opNum).getImm()) + 1;
  OS << PARAM_PREFIX;
  OS << paramNumber;
}
void PTXAsmPrinter::EmitVariableDeclaration(const GlobalVariable *gv) { void PTXAsmPrinter::EmitVariableDeclaration(const GlobalVariable *gv) {
// Check to see if this is a special global used by LLVM, if so, emit it. // Check to see if this is a special global used by LLVM, if so, emit it.
if (EmitSpecialLLVMGlobal(gv)) if (EmitSpecialLLVMGlobal(gv))

View File

@ -40,6 +40,8 @@ class PTXDAGToDAGISel : public SelectionDAGISel {
#include "PTXGenDAGISel.inc" #include "PTXGenDAGISel.inc"
private: private:
SDNode *SelectREAD_PARAM(SDNode *Node);
bool isImm(const SDValue &operand); bool isImm(const SDValue &operand);
bool SelectImm(const SDValue &operand, SDValue &imm); bool SelectImm(const SDValue &operand, SDValue &imm);
}; // class PTXDAGToDAGISel }; // class PTXDAGToDAGISel
@ -57,8 +59,21 @@ PTXDAGToDAGISel::PTXDAGToDAGISel(PTXTargetMachine &TM,
: SelectionDAGISel(TM, OptLevel) {} : SelectionDAGISel(TM, OptLevel) {}
SDNode *PTXDAGToDAGISel::Select(SDNode *Node) { SDNode *PTXDAGToDAGISel::Select(SDNode *Node) {
// SelectCode() is auto'gened if (Node->getOpcode() == PTXISD::READ_PARAM)
return SelectCode(Node); return SelectREAD_PARAM(Node);
else
return SelectCode(Node);
}
// Manually select a PTXISD::READ_PARAM node into a PTX::LDpi machine node.
// Operand 1 of the node is the parameter index; it is forwarded unchanged as
// the single explicit operand of the new machine node.
SDNode *PTXDAGToDAGISel::SelectREAD_PARAM(SDNode *Node) {
SDValue index = Node->getOperand(1);
DebugLoc dl = Node->getDebugLoc();
// The index must already be a target constant; anything else is treated as
// an unreachable state.
if (index.getOpcode() != ISD::TargetConstant)
llvm_unreachable("READ_PARAM: index is not ISD::TargetConstant");
// The result type is fixed to i32 here, regardless of the node's own type.
return PTXInstrInfo::
GetPTXMachineNode(CurDAG, PTX::LDpi, dl, MVT::i32, index);
} }
// Match memory operand of the form [reg+reg] // Match memory operand of the form [reg+reg]

View File

@ -47,9 +47,14 @@ SDValue PTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const { const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
switch (Opcode) { switch (Opcode) {
default: llvm_unreachable("Unknown opcode"); default:
case PTXISD::EXIT: return "PTXISD::EXIT"; llvm_unreachable("Unknown opcode");
case PTXISD::RET: return "PTXISD::RET"; case PTXISD::READ_PARAM:
return "PTXISD::READ_PARAM";
case PTXISD::EXIT:
return "PTXISD::EXIT";
case PTXISD::RET:
return "PTXISD::RET";
} }
} }
@ -86,42 +91,6 @@ struct argmap_entry {
}; };
} // end anonymous namespace } // end anonymous namespace
// Placeholder for lowering one incoming argument of a PTX kernel function.
// It has the same signature as lower_argument_func so it can be selected
// through that pointer type, but it is not implemented: every call reaches
// llvm_unreachable and none of the parameters are read.
static SDValue lower_kernel_argument(int i,
SDValue Chain,
DebugLoc dl,
MVT::SimpleValueType VT,
argmap_entry *entry,
SelectionDAG &DAG,
unsigned *argreg) {
// TODO
llvm_unreachable("Not implemented yet");
}
// Lower one incoming argument of a PTX device function.
//
// Takes the next physical register from the argmap entry's cursor, registers
// it as a function live-in paired with a fresh virtual register of the
// entry's register class, reports the physical register through *argreg, and
// returns a CopyFromReg of the virtual register with type VT.
// The parameter i is not used.
static SDValue lower_device_argument(int i,
                                     SDValue Chain,
                                     DebugLoc dl,
                                     MVT::SimpleValueType VT,
                                     argmap_entry *entry,
                                     SelectionDAG &DAG,
                                     unsigned *argreg) {
  // Advance the cursor before reading it: allocation starts from register 1.
  ++entry->loc;
  unsigned physReg = *entry->loc;

  MachineRegisterInfo &MRI = DAG.getMachineFunction().getRegInfo();
  unsigned virtReg = MRI.createVirtualRegister(entry->RC);
  MRI.addLiveIn(physReg, virtReg);

  *argreg = physReg;
  return DAG.getCopyFromReg(Chain, dl, virtReg, VT);
}
// Common signature of the per-argument lowering helpers
// (lower_kernel_argument and lower_device_argument). The caller picks one of
// them through a pointer of this type according to the calling convention.
typedef SDValue (*lower_argument_func)(int i,
SDValue Chain,
DebugLoc dl,
MVT::SimpleValueType VT,
argmap_entry *entry,
SelectionDAG &DAG,
unsigned *argreg);
SDValue PTXTargetLowering:: SDValue PTXTargetLowering::
LowerFormalArguments(SDValue Chain, LowerFormalArguments(SDValue Chain,
CallingConv::ID CallConv, CallingConv::ID CallConv,
@ -135,22 +104,22 @@ SDValue PTXTargetLowering::
MachineFunction &MF = DAG.getMachineFunction(); MachineFunction &MF = DAG.getMachineFunction();
PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>(); PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
lower_argument_func lower_argument;
switch (CallConv) { switch (CallConv) {
default: default:
llvm_unreachable("Unsupported calling convention"); llvm_unreachable("Unsupported calling convention");
break; break;
case CallingConv::PTX_Kernel: case CallingConv::PTX_Kernel:
MFI->setKernel(); MFI->setKernel(true);
lower_argument = lower_kernel_argument;
break; break;
case CallingConv::PTX_Device: case CallingConv::PTX_Device:
MFI->setKernel(false); MFI->setKernel(false);
lower_argument = lower_device_argument;
break; break;
} }
// Make sure we don't add argument registers twice
if (MFI->isDoneAddArg())
llvm_unreachable("cannot add argument registers twice");
// Reset argmap before allocation // Reset argmap before allocation
for (struct argmap_entry *i = argmap, *e = argmap + array_lengthof(argmap); for (struct argmap_entry *i = argmap, *e = argmap + array_lengthof(argmap);
i != e; ++ i) i != e; ++ i)
@ -164,17 +133,27 @@ SDValue PTXTargetLowering::
if (entry == argmap + array_lengthof(argmap)) if (entry == argmap + array_lengthof(argmap))
llvm_unreachable("Type of argument is not supported"); llvm_unreachable("Type of argument is not supported");
unsigned reg; if (MFI->isKernel() && entry->RC == PTX::PredsRegisterClass)
SDValue arg = lower_argument(i, Chain, dl, VT, entry, DAG, &reg); llvm_unreachable("cannot pass preds to kernel");
InVals.push_back(arg);
if (!MFI->isDoneAddArg()) MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();
MFI->addArgReg(reg);
unsigned preg = *++(entry->loc); // allocate start from register 1
unsigned vreg = RegInfo.createVirtualRegister(entry->RC);
RegInfo.addLiveIn(preg, vreg);
MFI->addArgReg(preg);
SDValue inval;
if (MFI->isKernel())
inval = DAG.getNode(PTXISD::READ_PARAM, dl, VT, Chain,
DAG.getTargetConstant(i, MVT::i32));
else
inval = DAG.getCopyFromReg(Chain, dl, vreg, VT);
InVals.push_back(inval);
} }
// Make sure we don't add argument registers twice MFI->doneAddArg();
if (!MFI->isDoneAddArg())
MFI->doneAddArg();
return Chain; return Chain;
} }

View File

@ -24,6 +24,7 @@ class PTXTargetMachine;
namespace PTXISD { namespace PTXISD {
enum NodeType { enum NodeType {
FIRST_NUMBER = ISD::BUILTIN_OP_END, FIRST_NUMBER = ISD::BUILTIN_OP_END,
READ_PARAM,
EXIT, EXIT,
RET RET
}; };

View File

@ -15,6 +15,8 @@
#define PTX_INSTR_INFO_H #define PTX_INSTR_INFO_H
#include "PTXRegisterInfo.h" #include "PTXRegisterInfo.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/Target/TargetInstrInfo.h" #include "llvm/Target/TargetInstrInfo.h"
namespace llvm { namespace llvm {
@ -45,6 +47,28 @@ class PTXInstrInfo : public TargetInstrInfoImpl {
virtual bool isMoveInstr(const MachineInstr& MI, virtual bool isMoveInstr(const MachineInstr& MI,
unsigned &SrcReg, unsigned &DstReg, unsigned &SrcReg, unsigned &DstReg,
unsigned &SrcSubIdx, unsigned &DstSubIdx) const; unsigned &SrcSubIdx, unsigned &DstSubIdx) const;
// static helper routines
// Create a machine node for Opcode with a single explicit operand.
// Two trailing operands are always appended after Op1: register 0 typed i1
// and a zero i32 target constant.
// NOTE(review): these look like a "no predicate" pair — confirm against the
// instruction definitions.
static MachineSDNode *GetPTXMachineNode(SelectionDAG *DAG, unsigned Opcode,
                                        DebugLoc dl, EVT VT,
                                        SDValue Op1) {
  SDValue PredReg = DAG->getRegister(0, MVT::i1);
  SDValue PredImm = DAG->getTargetConstant(0, MVT::i32);

  SDValue Operands[] = { Op1, PredReg, PredImm };
  return DAG->getMachineNode(Opcode, dl, VT,
                             Operands, array_lengthof(Operands));
}
// Create a machine node for Opcode with two explicit operands.
// Two trailing operands are always appended after Op1 and Op2: register 0
// typed i1 and a zero i32 target constant.
// NOTE(review): these look like a "no predicate" pair — confirm against the
// instruction definitions.
static MachineSDNode *GetPTXMachineNode(SelectionDAG *DAG, unsigned Opcode,
                                        DebugLoc dl, EVT VT,
                                        SDValue Op1,
                                        SDValue Op2) {
  SDValue PredReg = DAG->getRegister(0, MVT::i1);
  SDValue PredImm = DAG->getTargetConstant(0, MVT::i32);

  SDValue Operands[] = { Op1, Op2, PredReg, PredImm };
  return DAG->getMachineNode(Opcode, dl, VT,
                             Operands, array_lengthof(Operands));
}
}; // class PTXInstrInfo }; // class PTXInstrInfo
} // namespace llvm } // namespace llvm

View File

@ -120,6 +120,10 @@ def MEMii : Operand<i32> {
let PrintMethod = "printMemOperand"; let PrintMethod = "printMemOperand";
let MIOperandInfo = (ops i32imm, i32imm); let MIOperandInfo = (ops i32imm, i32imm);
} }
// Parameter operand: a single i32 immediate, printed by the asm printer's
// printParamOperand method.
def MEMpi : Operand<i32> {
let PrintMethod = "printParamOperand";
let MIOperandInfo = (ops i32imm);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// PTX Specific Node Definitions // PTX Specific Node Definitions
@ -236,9 +240,13 @@ defm LDl : PTX_LD<"ld.local", RRegs32, load_local>;
defm LDp : PTX_LD<"ld.param", RRegs32, load_parameter>; defm LDp : PTX_LD<"ld.param", RRegs32, load_parameter>;
defm LDs : PTX_LD<"ld.shared", RRegs32, load_shared>; defm LDs : PTX_LD<"ld.shared", RRegs32, load_shared>;
// ld.param from a MEMpi parameter operand into a 32-bit register.
// The pattern list is empty, so this instruction is not matched by the
// auto-generated selector; it has to be created explicitly by the instruction
// selector.
def LDpi : InstPTX<(outs RRegs32:$d), (ins MEMpi:$a),
"ld.param.%type\t$d, [$a]", []>;
defm STg : PTX_ST<"st.global", RRegs32, store_global>; defm STg : PTX_ST<"st.global", RRegs32, store_global>;
defm STl : PTX_ST<"st.local", RRegs32, store_local>; defm STl : PTX_ST<"st.local", RRegs32, store_local>;
defm STp : PTX_ST<"st.param", RRegs32, store_parameter>; // Store to parameter state space requires PTX 2.0 or higher?
// defm STp : PTX_ST<"st.param", RRegs32, store_parameter>;
defm STs : PTX_ST<"st.shared", RRegs32, store_shared>; defm STs : PTX_ST<"st.shared", RRegs32, store_shared>;
///===- Control Flow Instructions -----------------------------------------===// ///===- Control Flow Instructions -----------------------------------------===//

View File

@ -67,7 +67,9 @@ bool PTXMFInfoExtract::runOnMachineFunction(MachineFunction &MF) {
// FIXME: This is a slow linear scanning // FIXME: This is a slow linear scanning
for (unsigned reg = PTX::NoRegister + 1; reg < PTX::NUM_TARGET_REGS; ++reg) for (unsigned reg = PTX::NoRegister + 1; reg < PTX::NUM_TARGET_REGS; ++reg)
if (MRI.isPhysRegUsed(reg) && reg != retreg && !MFI->isArgReg(reg)) if (MRI.isPhysRegUsed(reg) &&
reg != retreg &&
(MFI->isKernel() || !MFI->isArgReg(reg)))
MFI->addLocalVarReg(reg); MFI->addLocalVarReg(reg);
// Notify MachineFunctionInfo that I've done adding local var reg // Notify MachineFunctionInfo that I've done adding local var reg

View File

@ -31,8 +31,8 @@ private:
public: public:
PTXMachineFunctionInfo(MachineFunction &MF) PTXMachineFunctionInfo(MachineFunction &MF)
: is_kernel(false), reg_ret(PTX::NoRegister), _isDoneAddArg(false) { : is_kernel(false), reg_ret(PTX::NoRegister), _isDoneAddArg(false) {
reg_arg.reserve(32); reg_arg.reserve(8);
reg_local_var.reserve(64); reg_local_var.reserve(32);
} }
void setKernel(bool _is_kernel=true) { is_kernel = _is_kernel; } void setKernel(bool _is_kernel=true) { is_kernel = _is_kernel; }

View File

@ -3,5 +3,12 @@
define ptx_kernel void @t1() { define ptx_kernel void @t1() {
; CHECK: exit; ; CHECK: exit;
; CHECK-NOT: ret; ; CHECK-NOT: ret;
ret void ret void
}
define ptx_kernel void @t2(i32* %p, i32 %x) {
store i32 %x, i32* %p
; CHECK: exit;
; CHECK-NOT: ret;
ret void
} }