PTX: Use .param space for device function return values on SM 2.0+, and attempt to fix up parameter passing on SM < 2.0

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@140309 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Justin Holewinski 2011-09-22 16:45:46 +00:00
parent 05591be5ed
commit 5422a0f166
8 changed files with 394 additions and 148 deletions

View File

@ -91,9 +91,13 @@ private:
static const char PARAM_PREFIX[] = "__param_";
static const char RETURN_PREFIX[] = "__ret_";
static const char *getRegisterTypeName(unsigned RegNo) {
#define TEST_REGCLS(cls, clsstr) \
if (PTX::cls ## RegisterClass->contains(RegNo)) return # clsstr;
static const char *getRegisterTypeName(unsigned RegNo,
const MachineRegisterInfo& MRI) {
const TargetRegisterClass *TRC = MRI.getRegClass(RegNo);
#define TEST_REGCLS(cls, clsstr) \
if (PTX::cls ## RegisterClass == TRC) return # clsstr;
TEST_REGCLS(RegPred, pred);
TEST_REGCLS(RegI16, b16);
TEST_REGCLS(RegI32, b32);
@ -288,18 +292,18 @@ void PTXAsmPrinter::EmitFunctionBodyStart() {
}
}
unsigned Index = 1;
//unsigned Index = 1;
// Print parameter passing params
for (PTXMachineFunctionInfo::param_iterator
i = MFI->paramBegin(), e = MFI->paramEnd(); i != e; ++i) {
std::string def = "\t.param .b";
def += utostr(*i);
def += " __ret_";
def += utostr(Index);
Index++;
def += ";";
OutStreamer.EmitRawText(Twine(def));
}
//for (PTXMachineFunctionInfo::param_iterator
// i = MFI->paramBegin(), e = MFI->paramEnd(); i != e; ++i) {
// std::string def = "\t.param .b";
// def += utostr(*i);
// def += " __ret_";
// def += utostr(Index);
// Index++;
// def += ";";
// OutStreamer.EmitRawText(Twine(def));
//}
}
void PTXAsmPrinter::EmitInstruction(const MachineInstr *MI) {
@ -436,7 +440,8 @@ void PTXAsmPrinter::printParamOperand(const MachineInstr *MI, int opNum,
void PTXAsmPrinter::printReturnOperand(const MachineInstr *MI, int opNum,
raw_ostream &OS, const char *Modifier) {
OS << RETURN_PREFIX << (int) MI->getOperand(opNum).getImm() + 1;
//OS << RETURN_PREFIX << (int) MI->getOperand(opNum).getImm() + 1;
OS << "__ret";
}
void PTXAsmPrinter::EmitVariableDeclaration(const GlobalVariable *gv) {
@ -559,6 +564,7 @@ void PTXAsmPrinter::EmitFunctionDeclaration() {
const PTXMachineFunctionInfo *MFI = MF->getInfo<PTXMachineFunctionInfo>();
const bool isKernel = MFI->isKernel();
const PTXSubtarget& ST = TM.getSubtarget<PTXSubtarget>();
const MachineRegisterInfo& MRI = MF->getRegInfo();
std::string decl = isKernel ? ".entry" : ".func";
@ -566,16 +572,22 @@ void PTXAsmPrinter::EmitFunctionDeclaration() {
if (!isKernel) {
decl += " (";
for (PTXMachineFunctionInfo::ret_iterator
i = MFI->retRegBegin(), e = MFI->retRegEnd(), b = i;
i != e; ++i) {
if (i != b) {
decl += ", ";
if (ST.useParamSpaceForDeviceArgs() && MFI->getRetParamSize() != 0) {
decl += ".param .b";
decl += utostr(MFI->getRetParamSize());
decl += " __ret";
} else {
for (PTXMachineFunctionInfo::ret_iterator
i = MFI->retRegBegin(), e = MFI->retRegEnd(), b = i;
i != e; ++i) {
if (i != b) {
decl += ", ";
}
decl += ".reg .";
decl += getRegisterTypeName(*i, MRI);
decl += " ";
decl += MFI->getRegisterName(*i);
}
decl += ".reg .";
decl += getRegisterTypeName(*i);
decl += " ";
decl += getRegisterName(*i);
}
decl += ")";
}
@ -589,23 +601,32 @@ void PTXAsmPrinter::EmitFunctionDeclaration() {
cnt = 0;
// Print parameters
for (PTXMachineFunctionInfo::reg_iterator
i = MFI->argRegBegin(), e = MFI->argRegEnd(), b = i;
i != e; ++i) {
if (i != b) {
decl += ", ";
}
if (isKernel || ST.useParamSpaceForDeviceArgs()) {
if (isKernel || ST.useParamSpaceForDeviceArgs()) {
for (PTXMachineFunctionInfo::argparam_iterator
i = MFI->argParamBegin(), e = MFI->argParamEnd(), b = i;
i != e; ++i) {
if (i != b) {
decl += ", ";
}
decl += ".param .b";
decl += utostr(*i);
decl += " ";
decl += PARAM_PREFIX;
decl += utostr(++cnt);
} else {
}
} else {
for (PTXMachineFunctionInfo::reg_iterator
i = MFI->argRegBegin(), e = MFI->argRegEnd(), b = i;
i != e; ++i) {
if (i != b) {
decl += ", ";
}
decl += ".reg .";
decl += getRegisterTypeName(*i);
decl += getRegisterTypeName(*i, MRI);
decl += " ";
decl += getRegisterName(*i);
decl += MFI->getRegisterName(*i);
}
}
decl += ")";

View File

@ -46,6 +46,9 @@ class PTXDAGToDAGISel : public SelectionDAGISel {
// pattern (PTXbrcond bb:$d, ...) in PTXInstrInfo.td
SDNode *SelectBRCOND(SDNode *Node);
SDNode *SelectREADPARAM(SDNode *Node);
SDNode *SelectWRITEPARAM(SDNode *Node);
bool isImm(const SDValue &operand);
bool SelectImm(const SDValue &operand, SDValue &imm);
@ -68,6 +71,10 @@ SDNode *PTXDAGToDAGISel::Select(SDNode *Node) {
switch (Node->getOpcode()) {
case ISD::BRCOND:
return SelectBRCOND(Node);
case PTXISD::READ_PARAM:
return SelectREADPARAM(Node);
case PTXISD::WRITE_PARAM:
return SelectWRITEPARAM(Node);
default:
return SelectCode(Node);
}
@ -90,6 +97,82 @@ SDNode *PTXDAGToDAGISel::SelectBRCOND(SDNode *Node) {
return CurDAG->getMachineNode(PTX::BRAdp, dl, MVT::Other, Ops, 4);
}
/// Select a PTXISD::READ_PARAM node into the READPARAM* machine instruction
/// that matches the node's result type.
///
/// \param Node the READ_PARAM node; operand 0 is used as the chain and
///             operand 1 as the parameter index.
/// \returns the newly created machine node, whose single result has the same
///          type as result 0 of \p Node.
SDNode *PTXDAGToDAGISel::SelectREADPARAM(SDNode *Node) {
  SDValue Chain = Node->getOperand(0);
  SDValue Index = Node->getOperand(1);

  int OpCode;

  // Get the type of parameter we are reading
  EVT VT = Node->getValueType(0);
  assert(VT.isSimple() && "READ_PARAM only implemented for MVT types");

  MVT Type = VT.getSimpleVT();

  if (Type == MVT::i1)
    OpCode = PTX::READPARAMPRED;
  else if (Type == MVT::i16)
    OpCode = PTX::READPARAMI16;
  else if (Type == MVT::i32)
    OpCode = PTX::READPARAMI32;
  else if (Type == MVT::i64)
    OpCode = PTX::READPARAMI64;
  else if (Type == MVT::f32)
    OpCode = PTX::READPARAMF32;
  else if (Type == MVT::f64)
    OpCode = PTX::READPARAMF64;
  else
    // Without this branch OpCode would be read uninitialized below for any
    // other type; SelectWRITEPARAM rejects unknown types the same way.
    llvm_unreachable("Invalid type in SelectREADPARAM");

  // Predicate operands appended to every operand list here.
  // NOTE(review): NoRegister + PRED_NORMAL presumably means "unpredicated";
  // confirm against the InstPTX operand layout.
  SDValue Pred = CurDAG->getRegister(PTX::NoRegister, MVT::i1);
  SDValue PredOp = CurDAG->getTargetConstant(PTX::PRED_NORMAL, MVT::i32);
  DebugLoc dl = Node->getDebugLoc();

  SDValue Ops[] = { Index, Pred, PredOp, Chain };
  return CurDAG->getMachineNode(OpCode, dl, VT, Ops, 4);
}
/// Select a PTXISD::WRITE_PARAM node into the WRITEPARAM* machine instruction
/// that matches the type of the value being written.
///
/// \param Node the WRITE_PARAM node; operand 0 is used as the chain and
///             operand 1 as the value to write.
/// \returns the newly created machine node, which produces only MVT::Other.
SDNode *PTXDAGToDAGISel::SelectWRITEPARAM(SDNode *Node) {
  SDValue InChain = Node->getOperand(0);
  SDValue Written = Node->getOperand(1);

  // The opcode is keyed on the type of the written value, not on the node's
  // own (chain-only) result.
  EVT ValueVT = Written->getValueType(0);
  assert(ValueVT.isSimple() && "WRITE_PARAM only implemented for MVT types");

  int Opc;
  switch (ValueVT.getSimpleVT().SimpleTy) {
  case MVT::i1:
    Opc = PTX::WRITEPARAMPRED;
    break;
  case MVT::i16:
    Opc = PTX::WRITEPARAMI16;
    break;
  case MVT::i32:
    Opc = PTX::WRITEPARAMI32;
    break;
  case MVT::i64:
    Opc = PTX::WRITEPARAMI64;
    break;
  case MVT::f32:
    Opc = PTX::WRITEPARAMF32;
    break;
  case MVT::f64:
    Opc = PTX::WRITEPARAMF64;
    break;
  default:
    llvm_unreachable("Invalid type in SelectWRITEPARAM");
  }

  // Predicate operands appended after the value, followed by the chain.
  SDValue PredReg = CurDAG->getRegister(PTX::NoRegister, MVT::i1);
  SDValue PredKind = CurDAG->getTargetConstant(PTX::PRED_NORMAL, MVT::i32);
  DebugLoc Loc = Node->getDebugLoc();

  SDValue Operands[] = { Written, PredReg, PredKind, InChain };
  return CurDAG->getMachineNode(Opc, Loc, MVT::Other, Operands, 4);
}
// Match memory operand of the form [reg+reg]
bool PTXDAGToDAGISel::SelectADDRrr(SDValue &Addr, SDValue &R1, SDValue &R2) {
if (Addr.getOpcode() != ISD::ADD || Addr.getNumOperands() < 2 ||

View File

@ -132,6 +132,10 @@ const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
return "PTXISD::LOAD_PARAM";
case PTXISD::STORE_PARAM:
return "PTXISD::STORE_PARAM";
case PTXISD::READ_PARAM:
return "PTXISD::READ_PARAM";
case PTXISD::WRITE_PARAM:
return "PTXISD::WRITE_PARAM";
case PTXISD::EXIT:
return "PTXISD::EXIT";
case PTXISD::RET:
@ -220,7 +224,6 @@ SDValue PTXTargetLowering::
if (MFI->isKernel() || ST.useParamSpaceForDeviceArgs()) {
// We just need to emit the proper LOAD_PARAM ISDs
for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
assert((!MFI->isKernel() || Ins[i].VT != MVT::i1) &&
"Kernels cannot take pred operands");
@ -231,26 +234,140 @@ SDValue PTXTargetLowering::
// Instead of storing a physical register in our argument list, we just
// store the total size of the parameter, in bits. The ASM printer
// knows how to process this.
MFI->addArgReg(Ins[i].VT.getStoreSizeInBits());
MFI->addArgParam(Ins[i].VT.getStoreSizeInBits());
}
}
else {
// For device functions, we use the PTX calling convention to do register
// assignments then create CopyFromReg ISDs for the allocated registers
SmallVector<CCValAssign, 16> ArgLocs;
CCState CCInfo(CallConv, isVarArg, MF, getTargetMachine(), ArgLocs,
*DAG.getContext());
//SmallVector<CCValAssign, 16> ArgLocs;
//CCState CCInfo(CallConv, isVarArg, MF, getTargetMachine(), ArgLocs,
// *DAG.getContext());
CCInfo.AnalyzeFormalArguments(Ins, CC_PTX);
//CCInfo.AnalyzeFormalArguments(Ins, CC_PTX);
for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) {
//for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) {
for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
CCValAssign& VA = ArgLocs[i];
EVT RegVT = VA.getLocVT();
EVT RegVT = Ins[i].VT;
TargetRegisterClass* TRC = 0;
int OpCode;
assert(VA.isRegLoc() && "CCValAssign must be RegLoc");
//assert(VA.isRegLoc() && "CCValAssign must be RegLoc");
// Determine which register class we need
if (RegVT == MVT::i1) {
TRC = PTX::RegPredRegisterClass;
OpCode = PTX::READPARAMPRED;
}
else if (RegVT == MVT::i16) {
TRC = PTX::RegI16RegisterClass;
OpCode = PTX::READPARAMI16;
}
else if (RegVT == MVT::i32) {
TRC = PTX::RegI32RegisterClass;
OpCode = PTX::READPARAMI32;
}
else if (RegVT == MVT::i64) {
TRC = PTX::RegI64RegisterClass;
OpCode = PTX::READPARAMI64;
}
else if (RegVT == MVT::f32) {
TRC = PTX::RegF32RegisterClass;
OpCode = PTX::READPARAMF32;
}
else if (RegVT == MVT::f64) {
TRC = PTX::RegF64RegisterClass;
OpCode = PTX::READPARAMF64;
}
else {
llvm_unreachable("Unknown parameter type");
}
// Use a unique index in the instruction to prevent instruction folding.
// Yes, this is a hack.
SDValue Index = DAG.getTargetConstant(i, MVT::i32);
unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC);
SDValue ArgValue = DAG.getNode(PTXISD::READ_PARAM, dl, RegVT, Chain,
Index);
SDValue Flag = ArgValue.getValue(1);
SDValue Copy = DAG.getCopyFromReg(Chain, dl, Reg, RegVT);
SDValue RegValue = DAG.getRegister(Reg, RegVT);
InVals.push_back(ArgValue);
MFI->addArgReg(Reg);
}
}
return Chain;
}
SDValue PTXTargetLowering::
LowerReturn(SDValue Chain,
CallingConv::ID CallConv,
bool isVarArg,
const SmallVectorImpl<ISD::OutputArg> &Outs,
const SmallVectorImpl<SDValue> &OutVals,
DebugLoc dl,
SelectionDAG &DAG) const {
if (isVarArg) llvm_unreachable("PTX does not support varargs");
switch (CallConv) {
default:
llvm_unreachable("Unsupported calling convention.");
case CallingConv::PTX_Kernel:
assert(Outs.size() == 0 && "Kernel must return void.");
return DAG.getNode(PTXISD::EXIT, dl, MVT::Other, Chain);
case CallingConv::PTX_Device:
assert(Outs.size() <= 1 && "Can at most return one value.");
break;
}
MachineFunction& MF = DAG.getMachineFunction();
PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
SDValue Flag;
// Even though we could use the .param space for return arguments for
// device functions if SM >= 2.0 and the number of return arguments is
// only 1, we just always use registers since this makes the codegen
// easier.
const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();
if (ST.useParamSpaceForDeviceArgs()) {
assert(Outs.size() < 2 && "Device functions can return at most one value");
if (Outs.size() == 1) {
unsigned Size = OutVals[0].getValueType().getSizeInBits();
SDValue Index = DAG.getTargetConstant(MFI->getNextParam(Size), MVT::i32);
Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
Index, OutVals[0]);
//Flag = Chain.getValue(1);
MFI->setRetParamSize(Outs[0].VT.getStoreSizeInBits());
}
} else {
//SmallVector<CCValAssign, 16> RVLocs;
//CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(),
//getTargetMachine(), RVLocs, *DAG.getContext());
//CCInfo.AnalyzeReturn(Outs, RetCC_PTX);
//for (unsigned i = 0, e = RVLocs.size(); i != e; ++i) {
//CCValAssign& VA = RVLocs[i];
for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
//assert(VA.isRegLoc() && "CCValAssign must be RegLoc");
//unsigned Reg = VA.getLocReg();
EVT RegVT = Outs[i].VT;
TargetRegisterClass* TRC = 0;
// Determine which register class we need
if (RegVT == MVT::i1) {
@ -276,72 +393,28 @@ SDValue PTXTargetLowering::
}
unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC);
MF.getRegInfo().addLiveIn(VA.getLocReg(), Reg);
SDValue ArgValue = DAG.getCopyFromReg(Chain, dl, Reg, RegVT);
InVals.push_back(ArgValue);
//DAG.getMachineFunction().getRegInfo().addLiveOut(Reg);
MFI->addArgReg(VA.getLocReg());
//Chain = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i], Flag);
//SDValue Copy = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i]/*, Flag*/);
// Guarantee that all emitted copies are stuck together,
// avoiding something bad
//Flag = Chain.getValue(1);
SDValue Copy = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i]/*, Flag*/);
SDValue OutReg = DAG.getRegister(Reg, RegVT);
Chain = DAG.getNode(PTXISD::WRITE_PARAM, dl, MVT::Other, Copy, OutReg);
//Flag = Chain.getValue(1);
MFI->addRetReg(Reg);
//MFI->addRetReg(Reg);
}
}
return Chain;
}
SDValue PTXTargetLowering::
LowerReturn(SDValue Chain,
CallingConv::ID CallConv,
bool isVarArg,
const SmallVectorImpl<ISD::OutputArg> &Outs,
const SmallVectorImpl<SDValue> &OutVals,
DebugLoc dl,
SelectionDAG &DAG) const {
if (isVarArg) llvm_unreachable("PTX does not support varargs");
switch (CallConv) {
default:
llvm_unreachable("Unsupported calling convention.");
case CallingConv::PTX_Kernel:
assert(Outs.size() == 0 && "Kernel must return void.");
return DAG.getNode(PTXISD::EXIT, dl, MVT::Other, Chain);
case CallingConv::PTX_Device:
//assert(Outs.size() <= 1 && "Can at most return one value.");
break;
}
MachineFunction& MF = DAG.getMachineFunction();
PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
SDValue Flag;
// Even though we could use the .param space for return arguments for
// device functions if SM >= 2.0 and the number of return arguments is
// only 1, we just always use registers since this makes the codegen
// easier.
SmallVector<CCValAssign, 16> RVLocs;
CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(),
getTargetMachine(), RVLocs, *DAG.getContext());
CCInfo.AnalyzeReturn(Outs, RetCC_PTX);
for (unsigned i = 0, e = RVLocs.size(); i != e; ++i) {
CCValAssign& VA = RVLocs[i];
assert(VA.isRegLoc() && "CCValAssign must be RegLoc");
unsigned Reg = VA.getLocReg();
DAG.getMachineFunction().getRegInfo().addLiveOut(Reg);
Chain = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i], Flag);
// Guarantee that all emitted copies are stuck together,
// avoiding something bad
Flag = Chain.getValue(1);
MFI->addRetReg(Reg);
}
if (Flag.getNode() == 0) {
return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain);
}

View File

@ -26,6 +26,8 @@ namespace PTXISD {
FIRST_NUMBER = ISD::BUILTIN_OP_END,
LOAD_PARAM,
STORE_PARAM,
READ_PARAM,
WRITE_PARAM,
EXIT,
RET,
COPY_ADDRESS,

View File

@ -50,11 +50,11 @@ void PTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
bool KillSrc) const {
const MachineRegisterInfo& MRI = MBB.getParent()->getRegInfo();
assert(MRI.getRegClass(SrcReg) == MRI.getRegClass(DstReg) &&
"Invalid register copy between two register classes");
//assert(MRI.getRegClass(SrcReg) == MRI.getRegClass(DstReg) &&
// "Invalid register copy between two register classes");
for (int i = 0, e = sizeof(map)/sizeof(map[0]); i != e; ++i) {
if (map[i].cls == MRI.getRegClass(SrcReg)) {
if (map[i].cls == MRI.getRegClass(DstReg)) {
const MCInstrDesc &MCID = get(map[i].opcode);
MachineInstr *MI = BuildMI(MBB, I, DL, MCID, DstReg).
addReg(SrcReg, getKillRegState(KillSrc));

View File

@ -209,6 +209,13 @@ def PTXstoreparam
: SDNode<"PTXISD::STORE_PARAM", SDTypeProfile<0, 2, [SDTCisVT<0, i32>]>,
[SDNPHasChain, SDNPOutGlue, SDNPOptInGlue]>;
// SelectionDAG node definition bound to PTXISD::READ_PARAM.
// Type profile: one result and one operand; value #1 (the operand, since
// results are numbered first) is constrained to i32.
// Properties: has a chain, may consume glue, and produces glue.
def PTXreadparam
: SDNode<"PTXISD::READ_PARAM", SDTypeProfile<1, 1, [SDTCisVT<1, i32>]>,
[SDNPHasChain, SDNPOutGlue, SDNPOptInGlue]>;
// SelectionDAG node definition bound to PTXISD::WRITE_PARAM.
// Type profile: no results and one operand, with no type constraints.
// Properties: has a chain, may consume glue, and produces glue.
def PTXwriteparam
: SDNode<"PTXISD::WRITE_PARAM", SDTypeProfile<0, 1, []>,
[SDNPHasChain, SDNPOutGlue, SDNPOptInGlue]>;
//===----------------------------------------------------------------------===//
// Instruction Class Templates
//===----------------------------------------------------------------------===//
@ -617,7 +624,7 @@ defm FMUL : PTX_FLOAT_3OP<"mul.rn", fmul>;
// SM_13+ defaults to .rn for f32 and f64,
// SM10 must *not* provide a rounding
// TODO:
// TODO:
// - Allow user selection of rounding modes for fdiv
// - Add support for -prec-div=false (.approx)
@ -1045,7 +1052,7 @@ def CVT_f32_f64
// Conversion to f64
def CVT_f64_pred
: InstPTX<(outs RegF64:$d), (ins RegPred:$a),
: InstPTX<(outs RegF64:$d), (ins RegPred:$a),
"selp.f64\t$d, 0D3F80000000000000, 0D0000000000000000, $a", // 1.0
[(set RegF64:$d, (uint_to_fp RegPred:$a))]>;
@ -1114,6 +1121,27 @@ def STACKLOADF32 : InstPTX<(outs), (ins RegF32:$d, i32imm:$a),
def STACKLOADF64 : InstPTX<(outs), (ins RegF64:$d, i32imm:$a),
"mov.f64\t$d, s$a", []>;
///===- Parameter Passing Pseudo-Instructions -----------------------------===//
// READPARAM<type>: one definition per register class. Each takes an i32
// immediate $b and defines a register $a, printing "mov.<type> $a, %param$b".
// The pattern lists are empty, so these are not matched by generated
// selection patterns; the opcodes are referenced by name from C++ code.
def READPARAMPRED : InstPTX<(outs RegPred:$a), (ins i32imm:$b),
"mov.pred\t$a, %param$b", []>;
def READPARAMI16 : InstPTX<(outs RegI16:$a), (ins i32imm:$b),
"mov.b16\t$a, %param$b", []>;
def READPARAMI32 : InstPTX<(outs RegI32:$a), (ins i32imm:$b),
"mov.b32\t$a, %param$b", []>;
def READPARAMI64 : InstPTX<(outs RegI64:$a), (ins i32imm:$b),
"mov.b64\t$a, %param$b", []>;
def READPARAMF32 : InstPTX<(outs RegF32:$a), (ins i32imm:$b),
"mov.f32\t$a, %param$b", []>;
def READPARAMF64 : InstPTX<(outs RegF64:$a), (ins i32imm:$b),
"mov.f64\t$a, %param$b", []>;
// WRITEPARAM<type>: one definition per register class. Each takes a single
// register input $a, defines nothing, and has an empty pattern list.
// The assembly string is the fixed text "//w" and does not mention $a.
// NOTE(review): "//w" presumably prints as a comment so that no real
// instruction is emitted -- confirm this is the intended placeholder.
def WRITEPARAMPRED : InstPTX<(outs), (ins RegPred:$a), "//w", []>;
def WRITEPARAMI16 : InstPTX<(outs), (ins RegI16:$a), "//w", []>;
def WRITEPARAMI32 : InstPTX<(outs), (ins RegI32:$a), "//w", []>;
def WRITEPARAMI64 : InstPTX<(outs), (ins RegI64:$a), "//w", []>;
def WRITEPARAMF32 : InstPTX<(outs), (ins RegF32:$a), "//w", []>;
def WRITEPARAMF64 : InstPTX<(outs), (ins RegF64:$a), "//w", []>;
// Call handling
// def ADJCALLSTACKUP :

View File

@ -66,7 +66,7 @@ bool PTXMFInfoExtract::runOnMachineFunction(MachineFunction &MF) {
// FIXME: This is a slow linear scanning
for (unsigned reg = PTX::NoRegister + 1; reg < PTX::NUM_TARGET_REGS; ++reg)
if (MRI.isPhysRegUsed(reg) &&
!MFI->isRetReg(reg) &&
//!MFI->isRetReg(reg) &&
(MFI->isKernel() || !MFI->isArgReg(reg)))
MFI->addLocalVarReg(reg);

View File

@ -20,16 +20,20 @@
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
namespace llvm {
/// PTXMachineFunctionInfo - This class is derived from MachineFunction and
/// contains private PTX target-specific information for each MachineFunction.
///
class PTXMachineFunctionInfo : public MachineFunctionInfo {
private:
bool is_kernel;
std::vector<unsigned> reg_arg, reg_local_var;
std::vector<unsigned> reg_ret;
DenseSet<unsigned> reg_local_var;
DenseSet<unsigned> reg_arg;
DenseSet<unsigned> reg_ret;
std::vector<unsigned> call_params;
bool _isDoneAddArg;
@ -40,29 +44,28 @@ private:
RegisterMap usedRegs;
RegisterNameMap regNames;
SmallVector<unsigned, 8> argParams;
unsigned retParamSize;
public:
PTXMachineFunctionInfo(MachineFunction &MF)
: is_kernel(false), reg_ret(PTX::NoRegister), _isDoneAddArg(false) {
reg_arg.reserve(8);
reg_local_var.reserve(32);
usedRegs[PTX::RegPredRegisterClass] = RegisterList();
usedRegs[PTX::RegI16RegisterClass] = RegisterList();
usedRegs[PTX::RegI32RegisterClass] = RegisterList();
usedRegs[PTX::RegI64RegisterClass] = RegisterList();
usedRegs[PTX::RegF32RegisterClass] = RegisterList();
usedRegs[PTX::RegF64RegisterClass] = RegisterList();
retParamSize = 0;
}
void setKernel(bool _is_kernel=true) { is_kernel = _is_kernel; }
void addArgReg(unsigned reg) { reg_arg.push_back(reg); }
void addLocalVarReg(unsigned reg) { reg_local_var.push_back(reg); }
void addRetReg(unsigned reg) {
if (!isRetReg(reg)) {
reg_ret.push_back(reg);
}
}
void addLocalVarReg(unsigned reg) { reg_local_var.insert(reg); }
void doneAddArg(void) {
_isDoneAddArg = true;
@ -71,17 +74,20 @@ public:
bool isKernel() const { return is_kernel; }
typedef std::vector<unsigned>::const_iterator reg_iterator;
typedef std::vector<unsigned>::const_reverse_iterator reg_reverse_iterator;
typedef std::vector<unsigned>::const_iterator ret_iterator;
typedef DenseSet<unsigned>::const_iterator reg_iterator;
//typedef DenseSet<unsigned>::const_reverse_iterator reg_reverse_iterator;
typedef DenseSet<unsigned>::const_iterator ret_iterator;
typedef std::vector<unsigned>::const_iterator param_iterator;
typedef SmallVector<unsigned, 8>::const_iterator argparam_iterator;
bool argRegEmpty() const { return reg_arg.empty(); }
int getNumArg() const { return reg_arg.size(); }
reg_iterator argRegBegin() const { return reg_arg.begin(); }
reg_iterator argRegEnd() const { return reg_arg.end(); }
reg_reverse_iterator argRegReverseBegin() const { return reg_arg.rbegin(); }
reg_reverse_iterator argRegReverseEnd() const { return reg_arg.rend(); }
argparam_iterator argParamBegin() const { return argParams.begin(); }
argparam_iterator argParamEnd() const { return argParams.end(); }
//reg_reverse_iterator argRegReverseBegin() const { return reg_arg.rbegin(); }
//reg_reverse_iterator argRegReverseEnd() const { return reg_arg.rend(); }
bool localVarRegEmpty() const { return reg_local_var.empty(); }
reg_iterator localVarRegBegin() const { return reg_local_var.begin(); }
@ -103,42 +109,75 @@ public:
return std::find(reg_arg.begin(), reg_arg.end(), reg) != reg_arg.end();
}
bool isRetReg(unsigned reg) const {
/*bool isRetReg(unsigned reg) const {
return std::find(reg_ret.begin(), reg_ret.end(), reg) != reg_ret.end();
}
}*/
bool isLocalVarReg(unsigned reg) const {
return std::find(reg_local_var.begin(), reg_local_var.end(), reg)
!= reg_local_var.end();
}
void addVirtualRegister(const TargetRegisterClass *TRC, unsigned Reg) {
usedRegs[TRC].push_back(Reg);
void addRetReg(unsigned Reg) {
if (!reg_ret.count(Reg)) {
reg_ret.insert(Reg);
std::string name;
name = "%ret";
name += utostr(reg_ret.size() - 1);
regNames[Reg] = name;
}
}
/// Store the size, in bits, of the return parameter.
void setRetParamSize(unsigned SizeInBits) { retParamSize = SizeInBits; }
/// Return the stored return-parameter size in bits (the constructor
/// initializes it to 0).
unsigned getRetParamSize() const { return retParamSize; }
/// Record \p Reg as an argument register and (re)assign its name to
/// "%param<N>", where N is the size of the argument-register set after the
/// insertion, minus one.
void addArgReg(unsigned Reg) {
  reg_arg.insert(Reg);

  // NOTE(review): unlike addRetReg there is no "already present" check, so
  // adding the same register twice renames it using the current set size.
  std::string ParamName("%param");
  ParamName += utostr(reg_arg.size() - 1);
  regNames[Reg] = ParamName;
}
/// Append the size, in bits, of one more argument parameter to the list.
void addArgParam(unsigned SizeInBits) { argParams.push_back(SizeInBits); }
void addVirtualRegister(const TargetRegisterClass *TRC, unsigned Reg) {
std::string name;
if (TRC == PTX::RegPredRegisterClass)
name = "%p";
else if (TRC == PTX::RegI16RegisterClass)
name = "%rh";
else if (TRC == PTX::RegI32RegisterClass)
name = "%r";
else if (TRC == PTX::RegI64RegisterClass)
name = "%rd";
else if (TRC == PTX::RegF32RegisterClass)
name = "%f";
else if (TRC == PTX::RegF64RegisterClass)
name = "%fd";
else
llvm_unreachable("Invalid register class");
if (!reg_ret.count(Reg) && !reg_arg.count(Reg)) {
usedRegs[TRC].push_back(Reg);
if (TRC == PTX::RegPredRegisterClass)
name = "%p";
else if (TRC == PTX::RegI16RegisterClass)
name = "%rh";
else if (TRC == PTX::RegI32RegisterClass)
name = "%r";
else if (TRC == PTX::RegI64RegisterClass)
name = "%rd";
else if (TRC == PTX::RegF32RegisterClass)
name = "%f";
else if (TRC == PTX::RegF64RegisterClass)
name = "%fd";
else
llvm_unreachable("Invalid register class");
name += utostr(usedRegs[TRC].size() - 1);
regNames[Reg] = name;
name += utostr(usedRegs[TRC].size() - 1);
regNames[Reg] = name;
}
}
/// Look up the printable name recorded for \p Reg.
///
/// \returns the name stored in the register-name map if there is one;
///          otherwise "%noreg" when \p Reg is PTX::NoRegister. Any other
///          register is treated as unreachable.
std::string getRegisterName(unsigned Reg) const {
  // A recorded name always wins, even for NoRegister.
  if (regNames.count(Reg))
    return regNames.lookup(Reg);

  if (Reg == PTX::NoRegister)
    return "%noreg";

  llvm_unreachable("Register not in register name map");
}