From 5422a0f166bc86c72fafaa547435e18578add3b9 Mon Sep 17 00:00:00 2001 From: Justin Holewinski Date: Thu, 22 Sep 2011 16:45:46 +0000 Subject: [PATCH] PTX: Use .param space for device function return values on SM 2.0+, and attempt to fix up parameter passing on SM < 2.0 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@140309 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Target/PTX/PTXAsmPrinter.cpp | 89 ++++++---- lib/Target/PTX/PTXISelDAGToDAG.cpp | 83 +++++++++ lib/Target/PTX/PTXISelLowering.cpp | 215 ++++++++++++++++-------- lib/Target/PTX/PTXISelLowering.h | 2 + lib/Target/PTX/PTXInstrInfo.cpp | 6 +- lib/Target/PTX/PTXInstrInfo.td | 32 +++- lib/Target/PTX/PTXMFInfoExtract.cpp | 2 +- lib/Target/PTX/PTXMachineFunctionInfo.h | 113 +++++++++---- 8 files changed, 394 insertions(+), 148 deletions(-) diff --git a/lib/Target/PTX/PTXAsmPrinter.cpp b/lib/Target/PTX/PTXAsmPrinter.cpp index f936d4bb3c4..6337ee99705 100644 --- a/lib/Target/PTX/PTXAsmPrinter.cpp +++ b/lib/Target/PTX/PTXAsmPrinter.cpp @@ -91,9 +91,13 @@ private: static const char PARAM_PREFIX[] = "__param_"; static const char RETURN_PREFIX[] = "__ret_"; -static const char *getRegisterTypeName(unsigned RegNo) { -#define TEST_REGCLS(cls, clsstr) \ - if (PTX::cls ## RegisterClass->contains(RegNo)) return # clsstr; +static const char *getRegisterTypeName(unsigned RegNo, + const MachineRegisterInfo& MRI) { + const TargetRegisterClass *TRC = MRI.getRegClass(RegNo); + +#define TEST_REGCLS(cls, clsstr) \ + if (PTX::cls ## RegisterClass == TRC) return # clsstr; + TEST_REGCLS(RegPred, pred); TEST_REGCLS(RegI16, b16); TEST_REGCLS(RegI32, b32); @@ -288,18 +292,18 @@ void PTXAsmPrinter::EmitFunctionBodyStart() { } } - unsigned Index = 1; + //unsigned Index = 1; // Print parameter passing params - for (PTXMachineFunctionInfo::param_iterator - i = MFI->paramBegin(), e = MFI->paramEnd(); i != e; ++i) { - std::string def = "\t.param .b"; - def += utostr(*i); - def += " __ret_"; - def += utostr(Index); - Index++; - def += ";"; - OutStreamer.EmitRawText(Twine(def)); - } + //for (PTXMachineFunctionInfo::param_iterator + // i = MFI->paramBegin(), e = MFI->paramEnd(); i != e; ++i) { + // std::string def = "\t.param .b"; + // def += utostr(*i); + // def += " __ret_"; + // def += utostr(Index); + // Index++; + // def += ";"; + // OutStreamer.EmitRawText(Twine(def)); + //} } void PTXAsmPrinter::EmitInstruction(const MachineInstr *MI) { @@ -436,7 +440,8 @@ void PTXAsmPrinter::printParamOperand(const MachineInstr *MI, int opNum, void PTXAsmPrinter::printReturnOperand(const MachineInstr *MI, int opNum, raw_ostream &OS, const char *Modifier) { - OS << RETURN_PREFIX << (int) MI->getOperand(opNum).getImm() + 1; + //OS << RETURN_PREFIX << (int) MI->getOperand(opNum).getImm() + 1; + OS << "__ret"; } void PTXAsmPrinter::EmitVariableDeclaration(const GlobalVariable *gv) { @@ -559,6 +564,7 @@ void PTXAsmPrinter::EmitFunctionDeclaration() { const PTXMachineFunctionInfo *MFI = MF->getInfo(); const bool isKernel = MFI->isKernel(); const PTXSubtarget& ST = TM.getSubtarget(); + const MachineRegisterInfo& MRI = MF->getRegInfo(); std::string decl = isKernel ? ".entry" : ".func"; @@ -566,16 +572,22 @@ void PTXAsmPrinter::EmitFunctionDeclaration() { if (!isKernel) { decl += " ("; - for (PTXMachineFunctionInfo::ret_iterator - i = MFI->retRegBegin(), e = MFI->retRegEnd(), b = i; - i != e; ++i) { - if (i != b) { - decl += ", "; + if (ST.useParamSpaceForDeviceArgs() && MFI->getRetParamSize() != 0) { + decl += ".param .b"; + decl += utostr(MFI->getRetParamSize()); + decl += " __ret"; + } else { + for (PTXMachineFunctionInfo::ret_iterator + i = MFI->retRegBegin(), e = MFI->retRegEnd(), b = i; + i != e; ++i) { + if (i != b) { + decl += ", "; + } + decl += ".reg ."; + decl += getRegisterTypeName(*i, MRI); + decl += " "; + decl += MFI->getRegisterName(*i); } - decl += ".reg ."; - decl += getRegisterTypeName(*i); - decl += " "; - decl += getRegisterName(*i); } decl += ")"; } @@ -589,23 +601,32 @@ void PTXAsmPrinter::EmitFunctionDeclaration() { cnt = 0; // Print parameters - for (PTXMachineFunctionInfo::reg_iterator - i = MFI->argRegBegin(), e = MFI->argRegEnd(), b = i; - i != e; ++i) { - if (i != b) { - decl += ", "; - } - if (isKernel || ST.useParamSpaceForDeviceArgs()) { + if (isKernel || ST.useParamSpaceForDeviceArgs()) { + for (PTXMachineFunctionInfo::argparam_iterator + i = MFI->argParamBegin(), e = MFI->argParamEnd(), b = i; + i != e; ++i) { + if (i != b) { + decl += ", "; + } + decl += ".param .b"; decl += utostr(*i); decl += " "; decl += PARAM_PREFIX; decl += utostr(++cnt); - } else { + } + } else { + for (PTXMachineFunctionInfo::reg_iterator + i = MFI->argRegBegin(), e = MFI->argRegEnd(), b = i; + i != e; ++i) { + if (i != b) { + decl += ", "; + } + decl += ".reg ."; - decl += getRegisterTypeName(*i); + decl += getRegisterTypeName(*i, MRI); decl += " "; - decl += getRegisterName(*i); + decl += MFI->getRegisterName(*i); } } decl += ")"; diff --git a/lib/Target/PTX/PTXISelDAGToDAG.cpp b/lib/Target/PTX/PTXISelDAGToDAG.cpp index 9adfa624b29..685b24ecfcc 100644 --- a/lib/Target/PTX/PTXISelDAGToDAG.cpp +++ b/lib/Target/PTX/PTXISelDAGToDAG.cpp @@ -46,6 +46,9 @@ class PTXDAGToDAGISel : public SelectionDAGISel { // pattern (PTXbrcond bb:$d, ...) in PTXInstrInfo.td SDNode *SelectBRCOND(SDNode *Node); + SDNode *SelectREADPARAM(SDNode *Node); + SDNode *SelectWRITEPARAM(SDNode *Node); + bool isImm(const SDValue &operand); bool SelectImm(const SDValue &operand, SDValue &imm); @@ -68,6 +71,10 @@ SDNode *PTXDAGToDAGISel::Select(SDNode *Node) { switch (Node->getOpcode()) { case ISD::BRCOND: return SelectBRCOND(Node); + case PTXISD::READ_PARAM: + return SelectREADPARAM(Node); + case PTXISD::WRITE_PARAM: + return SelectWRITEPARAM(Node); default: return SelectCode(Node); } @@ -90,6 +97,82 @@ SDNode *PTXDAGToDAGISel::SelectBRCOND(SDNode *Node) { return CurDAG->getMachineNode(PTX::BRAdp, dl, MVT::Other, Ops, 4); } +SDNode *PTXDAGToDAGISel::SelectREADPARAM(SDNode *Node) { + SDValue Chain = Node->getOperand(0); + SDValue Index = Node->getOperand(1); + + int OpCode; + + // Get the type of parameter we are reading + EVT VT = Node->getValueType(0); + assert(VT.isSimple() && "READ_PARAM only implemented for MVT types"); + + MVT Type = VT.getSimpleVT(); + + if (Type == MVT::i1) + OpCode = PTX::READPARAMPRED; + else if (Type == MVT::i16) + OpCode = PTX::READPARAMI16; + else if (Type == MVT::i32) + OpCode = PTX::READPARAMI32; + else if (Type == MVT::i64) + OpCode = PTX::READPARAMI64; + else if (Type == MVT::f32) + OpCode = PTX::READPARAMF32; + else if (Type == MVT::f64) + OpCode = PTX::READPARAMF64; + + SDValue Pred = CurDAG->getRegister(PTX::NoRegister, MVT::i1); + SDValue PredOp = CurDAG->getTargetConstant(PTX::PRED_NORMAL, MVT::i32); + DebugLoc dl = Node->getDebugLoc(); + + SDValue Ops[] = { Index, Pred, PredOp, Chain }; + return CurDAG->getMachineNode(OpCode, dl, VT, Ops, 4); +} + +SDNode *PTXDAGToDAGISel::SelectWRITEPARAM(SDNode *Node) { + + SDValue Chain = Node->getOperand(0); + SDValue Value = Node->getOperand(1); + + int OpCode; + + //Node->dumpr(CurDAG); + + // Get the type of parameter we are writing + EVT VT = Value->getValueType(0); + assert(VT.isSimple() && "WRITE_PARAM only implemented for MVT types"); + + MVT Type = VT.getSimpleVT(); + + if (Type == MVT::i1) + OpCode = PTX::WRITEPARAMPRED; + else if (Type == MVT::i16) + OpCode = PTX::WRITEPARAMI16; + else if (Type == MVT::i32) + OpCode = PTX::WRITEPARAMI32; + else if (Type == MVT::i64) + OpCode = PTX::WRITEPARAMI64; + else if (Type == MVT::f32) + OpCode = PTX::WRITEPARAMF32; + else if (Type == MVT::f64) + OpCode = PTX::WRITEPARAMF64; + else + llvm_unreachable("Invalid type in SelectWRITEPARAM"); + + SDValue Pred = CurDAG->getRegister(PTX::NoRegister, MVT::i1); + SDValue PredOp = CurDAG->getTargetConstant(PTX::PRED_NORMAL, MVT::i32); + DebugLoc dl = Node->getDebugLoc(); + + SDValue Ops[] = { Value, Pred, PredOp, Chain }; + SDNode* Ret = CurDAG->getMachineNode(OpCode, dl, MVT::Other, Ops, 4); + + //dbgs() << "SelectWRITEPARAM produced:\n\t"; + //Ret->dumpr(CurDAG); + + return Ret; +} + // Match memory operand of the form [reg+reg] bool PTXDAGToDAGISel::SelectADDRrr(SDValue &Addr, SDValue &R1, SDValue &R2) { if (Addr.getOpcode() != ISD::ADD || Addr.getNumOperands() < 2 || diff --git a/lib/Target/PTX/PTXISelLowering.cpp b/lib/Target/PTX/PTXISelLowering.cpp index a05a55b19c9..424c5a1e87d 100644 --- a/lib/Target/PTX/PTXISelLowering.cpp +++ b/lib/Target/PTX/PTXISelLowering.cpp @@ -132,6 +132,10 @@ const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const { return "PTXISD::LOAD_PARAM"; case PTXISD::STORE_PARAM: return "PTXISD::STORE_PARAM"; + case PTXISD::READ_PARAM: + return "PTXISD::READ_PARAM"; + case PTXISD::WRITE_PARAM: + return "PTXISD::WRITE_PARAM"; case PTXISD::EXIT: return "PTXISD::EXIT"; case PTXISD::RET: @@ -220,7 +224,6 @@ SDValue PTXTargetLowering:: if (MFI->isKernel() || ST.useParamSpaceForDeviceArgs()) { // We just need to emit the proper LOAD_PARAM ISDs for (unsigned i = 0, e = Ins.size(); i != e; ++i) { - assert((!MFI->isKernel() || Ins[i].VT != MVT::i1) && "Kernels cannot take pred operands"); @@ -231,26 +234,140 @@ SDValue PTXTargetLowering:: // Instead of storing a physical register in our argument list, we just // store the total size of the parameter, in bits. The ASM printer // knows how to process this. - MFI->addArgReg(Ins[i].VT.getStoreSizeInBits()); + MFI->addArgParam(Ins[i].VT.getStoreSizeInBits()); } } else { // For device functions, we use the PTX calling convention to do register // assignments then create CopyFromReg ISDs for the allocated registers - SmallVector ArgLocs; - CCState CCInfo(CallConv, isVarArg, MF, getTargetMachine(), ArgLocs, - *DAG.getContext()); + //SmallVector ArgLocs; + //CCState CCInfo(CallConv, isVarArg, MF, getTargetMachine(), ArgLocs, + // *DAG.getContext()); - CCInfo.AnalyzeFormalArguments(Ins, CC_PTX); + //CCInfo.AnalyzeFormalArguments(Ins, CC_PTX); - for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) { + //for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) { + for (unsigned i = 0, e = Ins.size(); i != e; ++i) { - CCValAssign& VA = ArgLocs[i]; - EVT RegVT = VA.getLocVT(); + EVT RegVT = Ins[i].VT; TargetRegisterClass* TRC = 0; + int OpCode; - assert(VA.isRegLoc() && "CCValAssign must be RegLoc"); + //assert(VA.isRegLoc() && "CCValAssign must be RegLoc"); + + // Determine which register class we need + if (RegVT == MVT::i1) { + TRC = PTX::RegPredRegisterClass; + OpCode = PTX::READPARAMPRED; + } + else if (RegVT == MVT::i16) { + TRC = PTX::RegI16RegisterClass; + OpCode = PTX::READPARAMI16; + } + else if (RegVT == MVT::i32) { + TRC = PTX::RegI32RegisterClass; + OpCode = PTX::READPARAMI32; + } + else if (RegVT == MVT::i64) { + TRC = PTX::RegI64RegisterClass; + OpCode = PTX::READPARAMI64; + } + else if (RegVT == MVT::f32) { + TRC = PTX::RegF32RegisterClass; + OpCode = PTX::READPARAMF32; + } + else if (RegVT == MVT::f64) { + TRC = PTX::RegF64RegisterClass; + OpCode = PTX::READPARAMF64; + } + else { + llvm_unreachable("Unknown parameter type"); + } + + // Use a unique index in the instruction to prevent instruction folding. + // Yes, this is a hack. + SDValue Index = DAG.getTargetConstant(i, MVT::i32); + unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC); + SDValue ArgValue = DAG.getNode(PTXISD::READ_PARAM, dl, RegVT, Chain, + Index); + + SDValue Flag = ArgValue.getValue(1); + + SDValue Copy = DAG.getCopyFromReg(Chain, dl, Reg, RegVT); + SDValue RegValue = DAG.getRegister(Reg, RegVT); + InVals.push_back(ArgValue); + + MFI->addArgReg(Reg); + } + } + + return Chain; +} + +SDValue PTXTargetLowering:: + LowerReturn(SDValue Chain, + CallingConv::ID CallConv, + bool isVarArg, + const SmallVectorImpl &Outs, + const SmallVectorImpl &OutVals, + DebugLoc dl, + SelectionDAG &DAG) const { + if (isVarArg) llvm_unreachable("PTX does not support varargs"); + + switch (CallConv) { + default: + llvm_unreachable("Unsupported calling convention."); + case CallingConv::PTX_Kernel: + assert(Outs.size() == 0 && "Kernel must return void."); + return DAG.getNode(PTXISD::EXIT, dl, MVT::Other, Chain); + case CallingConv::PTX_Device: + assert(Outs.size() <= 1 && "Can at most return one value."); + break; + } + + MachineFunction& MF = DAG.getMachineFunction(); + PTXMachineFunctionInfo *MFI = MF.getInfo(); + + SDValue Flag; + + // Even though we could use the .param space for return arguments for + // device functions if SM >= 2.0 and the number of return arguments is + // only 1, we just always use registers since this makes the codegen + // easier. + + const PTXSubtarget& ST = getTargetMachine().getSubtarget(); + + if (ST.useParamSpaceForDeviceArgs()) { + assert(Outs.size() < 2 && "Device functions can return at most one value"); + + if (Outs.size() == 1) { + unsigned Size = OutVals[0].getValueType().getSizeInBits(); + SDValue Index = DAG.getTargetConstant(MFI->getNextParam(Size), MVT::i32); + Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, + Index, OutVals[0]); + + //Flag = Chain.getValue(1); + MFI->setRetParamSize(Outs[0].VT.getStoreSizeInBits()); + } + } else { + //SmallVector RVLocs; + //CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(), + //getTargetMachine(), RVLocs, *DAG.getContext()); + + //CCInfo.AnalyzeReturn(Outs, RetCC_PTX); + + //for (unsigned i = 0, e = RVLocs.size(); i != e; ++i) { + //CCValAssign& VA = RVLocs[i]; + + for (unsigned i = 0, e = Outs.size(); i != e; ++i) { + + //assert(VA.isRegLoc() && "CCValAssign must be RegLoc"); + + //unsigned Reg = VA.getLocReg(); + + EVT RegVT = Outs[i].VT; + TargetRegisterClass* TRC = 0; // Determine which register class we need if (RegVT == MVT::i1) { @@ -276,72 +393,28 @@ SDValue PTXTargetLowering:: } unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC); - MF.getRegInfo().addLiveIn(VA.getLocReg(), Reg); - SDValue ArgValue = DAG.getCopyFromReg(Chain, dl, Reg, RegVT); - InVals.push_back(ArgValue); + //DAG.getMachineFunction().getRegInfo().addLiveOut(Reg); - MFI->addArgReg(VA.getLocReg()); + //Chain = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i], Flag); + //SDValue Copy = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i]/*, Flag*/); + + // Guarantee that all emitted copies are stuck together, + // avoiding something bad + //Flag = Chain.getValue(1); + + SDValue Copy = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i]/*, Flag*/); + SDValue OutReg = DAG.getRegister(Reg, RegVT); + + Chain = DAG.getNode(PTXISD::WRITE_PARAM, dl, MVT::Other, Copy, OutReg); + //Flag = Chain.getValue(1); + + MFI->addRetReg(Reg); + + //MFI->addRetReg(Reg); } } - return Chain; -} - -SDValue PTXTargetLowering:: - LowerReturn(SDValue Chain, - CallingConv::ID CallConv, - bool isVarArg, - const SmallVectorImpl &Outs, - const SmallVectorImpl &OutVals, - DebugLoc dl, - SelectionDAG &DAG) const { - if (isVarArg) llvm_unreachable("PTX does not support varargs"); - - switch (CallConv) { - default: - llvm_unreachable("Unsupported calling convention."); - case CallingConv::PTX_Kernel: - assert(Outs.size() == 0 && "Kernel must return void."); - return DAG.getNode(PTXISD::EXIT, dl, MVT::Other, Chain); - case CallingConv::PTX_Device: - //assert(Outs.size() <= 1 && "Can at most return one value."); - break; - } - - MachineFunction& MF = DAG.getMachineFunction(); - PTXMachineFunctionInfo *MFI = MF.getInfo(); - - SDValue Flag; - - // Even though we could use the .param space for return arguments for - // device functions if SM >= 2.0 and the number of return arguments is - // only 1, we just always use registers since this makes the codegen - // easier. - SmallVector RVLocs; - CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(), - getTargetMachine(), RVLocs, *DAG.getContext()); - - CCInfo.AnalyzeReturn(Outs, RetCC_PTX); - - for (unsigned i = 0, e = RVLocs.size(); i != e; ++i) { - CCValAssign& VA = RVLocs[i]; - - assert(VA.isRegLoc() && "CCValAssign must be RegLoc"); - - unsigned Reg = VA.getLocReg(); - - DAG.getMachineFunction().getRegInfo().addLiveOut(Reg); - - Chain = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i], Flag); - - // Guarantee that all emitted copies are stuck together, - // avoiding something bad - Flag = Chain.getValue(1); - - MFI->addRetReg(Reg); - } - if (Flag.getNode() == 0) { return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain); } diff --git a/lib/Target/PTX/PTXISelLowering.h b/lib/Target/PTX/PTXISelLowering.h index 3112b03d4b1..f88349f865c 100644 --- a/lib/Target/PTX/PTXISelLowering.h +++ b/lib/Target/PTX/PTXISelLowering.h @@ -26,6 +26,8 @@ namespace PTXISD { FIRST_NUMBER = ISD::BUILTIN_OP_END, LOAD_PARAM, STORE_PARAM, + READ_PARAM, + WRITE_PARAM, EXIT, RET, COPY_ADDRESS, diff --git a/lib/Target/PTX/PTXInstrInfo.cpp b/lib/Target/PTX/PTXInstrInfo.cpp index 4d4bde40881..cf6a89973e1 100644 --- a/lib/Target/PTX/PTXInstrInfo.cpp +++ b/lib/Target/PTX/PTXInstrInfo.cpp @@ -50,11 +50,11 @@ void PTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB, bool KillSrc) const { const MachineRegisterInfo& MRI = MBB.getParent()->getRegInfo(); - assert(MRI.getRegClass(SrcReg) == MRI.getRegClass(DstReg) && - "Invalid register copy between two register classes"); + //assert(MRI.getRegClass(SrcReg) == MRI.getRegClass(DstReg) && + // "Invalid register copy between two register classes"); for (int i = 0, e = sizeof(map)/sizeof(map[0]); i != e; ++i) { - if (map[i].cls == MRI.getRegClass(SrcReg)) { + if (map[i].cls == MRI.getRegClass(DstReg)) { const MCInstrDesc &MCID = get(map[i].opcode); MachineInstr *MI = BuildMI(MBB, I, DL, MCID, DstReg). addReg(SrcReg, getKillRegState(KillSrc)); diff --git a/lib/Target/PTX/PTXInstrInfo.td b/lib/Target/PTX/PTXInstrInfo.td index 11caa7f1f9d..088142b2724 100644 --- a/lib/Target/PTX/PTXInstrInfo.td +++ b/lib/Target/PTX/PTXInstrInfo.td @@ -209,6 +209,13 @@ def PTXstoreparam : SDNode<"PTXISD::STORE_PARAM", SDTypeProfile<0, 2, [SDTCisVT<0, i32>]>, [SDNPHasChain, SDNPOutGlue, SDNPOptInGlue]>; +def PTXreadparam + : SDNode<"PTXISD::READ_PARAM", SDTypeProfile<1, 1, [SDTCisVT<1, i32>]>, + [SDNPHasChain, SDNPOutGlue, SDNPOptInGlue]>; +def PTXwriteparam + : SDNode<"PTXISD::WRITE_PARAM", SDTypeProfile<0, 1, []>, + [SDNPHasChain, SDNPOutGlue, SDNPOptInGlue]>; + //===----------------------------------------------------------------------===// // Instruction Class Templates //===----------------------------------------------------------------------===// @@ -617,7 +624,7 @@ defm FMUL : PTX_FLOAT_3OP<"mul.rn", fmul>; // SM_13+ defaults to .rn for f32 and f64, // SM10 must *not* provide a rounding -// TODO: +// TODO: // - Allow user selection of rounding modes for fdiv // - Add support for -prec-div=false (.approx) @@ -1045,7 +1052,7 @@ def CVT_f32_f64 // Conversion to f64 def CVT_f64_pred - : InstPTX<(outs RegF64:$d), (ins RegPred:$a), + : InstPTX<(outs RegF64:$d), (ins RegPred:$a), "selp.f64\t$d, 0D3F80000000000000, 0D0000000000000000, $a", // 1.0 [(set RegF64:$d, (uint_to_fp RegPred:$a))]>; @@ -1114,6 +1121,27 @@ def STACKLOADF32 : InstPTX<(outs), (ins RegF32:$d, i32imm:$a), def STACKLOADF64 : InstPTX<(outs), (ins RegF64:$d, i32imm:$a), "mov.f64\t$d, s$a", []>; +///===- Parameter Passing Pseudo-Instructions -----------------------------===// + +def READPARAMPRED : InstPTX<(outs RegPred:$a), (ins i32imm:$b), + "mov.pred\t$a, %param$b", []>; +def READPARAMI16 : InstPTX<(outs RegI16:$a), (ins i32imm:$b), + "mov.b16\t$a, %param$b", []>; +def READPARAMI32 : InstPTX<(outs RegI32:$a), (ins i32imm:$b), + "mov.b32\t$a, %param$b", []>; +def READPARAMI64 : InstPTX<(outs RegI64:$a), (ins i32imm:$b), + "mov.b64\t$a, %param$b", []>; +def READPARAMF32 : InstPTX<(outs RegF32:$a), (ins i32imm:$b), + "mov.f32\t$a, %param$b", []>; +def READPARAMF64 : InstPTX<(outs RegF64:$a), (ins i32imm:$b), + "mov.f64\t$a, %param$b", []>; + +def WRITEPARAMPRED : InstPTX<(outs), (ins RegPred:$a), "//w", []>; +def WRITEPARAMI16 : InstPTX<(outs), (ins RegI16:$a), "//w", []>; +def WRITEPARAMI32 : InstPTX<(outs), (ins RegI32:$a), "//w", []>; +def WRITEPARAMI64 : InstPTX<(outs), (ins RegI64:$a), "//w", []>; +def WRITEPARAMF32 : InstPTX<(outs), (ins RegF32:$a), "//w", []>; +def WRITEPARAMF64 : InstPTX<(outs), (ins RegF64:$a), "//w", []>; // Call handling // def ADJCALLSTACKUP : diff --git a/lib/Target/PTX/PTXMFInfoExtract.cpp b/lib/Target/PTX/PTXMFInfoExtract.cpp index 0a41520fcc2..f27505803ff 100644 --- a/lib/Target/PTX/PTXMFInfoExtract.cpp +++ b/lib/Target/PTX/PTXMFInfoExtract.cpp @@ -66,7 +66,7 @@ bool PTXMFInfoExtract::runOnMachineFunction(MachineFunction &MF) { // FIXME: This is a slow linear scanning for (unsigned reg = PTX::NoRegister + 1; reg < PTX::NUM_TARGET_REGS; ++reg) if (MRI.isPhysRegUsed(reg) && - !MFI->isRetReg(reg) && + //!MFI->isRetReg(reg) && (MFI->isKernel() || !MFI->isArgReg(reg))) MFI->addLocalVarReg(reg); diff --git a/lib/Target/PTX/PTXMachineFunctionInfo.h b/lib/Target/PTX/PTXMachineFunctionInfo.h index 16e5e7ba7fa..93189bbf62c 100644 --- a/lib/Target/PTX/PTXMachineFunctionInfo.h +++ b/lib/Target/PTX/PTXMachineFunctionInfo.h @@ -20,16 +20,20 @@ #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/StringExtras.h" #include "llvm/CodeGen/MachineFunction.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" namespace llvm { + /// PTXMachineFunctionInfo - This class is derived from MachineFunction and /// contains private PTX target-specific information for each MachineFunction. /// class PTXMachineFunctionInfo : public MachineFunctionInfo { private: bool is_kernel; - std::vector reg_arg, reg_local_var; - std::vector reg_ret; + DenseSet reg_local_var; + DenseSet reg_arg; + DenseSet reg_ret; std::vector call_params; bool _isDoneAddArg; @@ -40,29 +44,28 @@ private: RegisterMap usedRegs; RegisterNameMap regNames; + SmallVector argParams; + + unsigned retParamSize; + public: PTXMachineFunctionInfo(MachineFunction &MF) : is_kernel(false), reg_ret(PTX::NoRegister), _isDoneAddArg(false) { - reg_arg.reserve(8); - reg_local_var.reserve(32); - usedRegs[PTX::RegPredRegisterClass] = RegisterList(); usedRegs[PTX::RegI16RegisterClass] = RegisterList(); usedRegs[PTX::RegI32RegisterClass] = RegisterList(); usedRegs[PTX::RegI64RegisterClass] = RegisterList(); usedRegs[PTX::RegF32RegisterClass] = RegisterList(); usedRegs[PTX::RegF64RegisterClass] = RegisterList(); + + retParamSize = 0; } void setKernel(bool _is_kernel=true) { is_kernel = _is_kernel; } - void addArgReg(unsigned reg) { reg_arg.push_back(reg); } - void addLocalVarReg(unsigned reg) { reg_local_var.push_back(reg); } - void addRetReg(unsigned reg) { - if (!isRetReg(reg)) { - reg_ret.push_back(reg); - } - } + + void addLocalVarReg(unsigned reg) { reg_local_var.insert(reg); } + void doneAddArg(void) { _isDoneAddArg = true; @@ -71,17 +74,20 @@ public: bool isKernel() const { return is_kernel; } - typedef std::vector::const_iterator reg_iterator; - typedef std::vector::const_reverse_iterator reg_reverse_iterator; - typedef std::vector::const_iterator ret_iterator; + typedef DenseSet::const_iterator reg_iterator; + //typedef DenseSet::const_reverse_iterator reg_reverse_iterator; + typedef DenseSet::const_iterator ret_iterator; typedef std::vector::const_iterator param_iterator; + typedef SmallVector::const_iterator argparam_iterator; bool argRegEmpty() const { return reg_arg.empty(); } int getNumArg() const { return reg_arg.size(); } reg_iterator argRegBegin() const { return reg_arg.begin(); } reg_iterator argRegEnd() const { return reg_arg.end(); } - reg_reverse_iterator argRegReverseBegin() const { return reg_arg.rbegin(); } - reg_reverse_iterator argRegReverseEnd() const { return reg_arg.rend(); } + argparam_iterator argParamBegin() const { return argParams.begin(); } + argparam_iterator argParamEnd() const { return argParams.end(); } + //reg_reverse_iterator argRegReverseBegin() const { return reg_arg.rbegin(); } + //reg_reverse_iterator argRegReverseEnd() const { return reg_arg.rend(); } bool localVarRegEmpty() const { return reg_local_var.empty(); } reg_iterator localVarRegBegin() const { return reg_local_var.begin(); } @@ -103,42 +109,75 @@ public: return std::find(reg_arg.begin(), reg_arg.end(), reg) != reg_arg.end(); } - bool isRetReg(unsigned reg) const { + /*bool isRetReg(unsigned reg) const { return std::find(reg_ret.begin(), reg_ret.end(), reg) != reg_ret.end(); - } + }*/ bool isLocalVarReg(unsigned reg) const { return std::find(reg_local_var.begin(), reg_local_var.end(), reg) != reg_local_var.end(); } - void addVirtualRegister(const TargetRegisterClass *TRC, unsigned Reg) { - usedRegs[TRC].push_back(Reg); + void addRetReg(unsigned Reg) { + if (!reg_ret.count(Reg)) { + reg_ret.insert(Reg); + std::string name; + name = "%ret"; + name += utostr(reg_ret.size() - 1); + regNames[Reg] = name; + } + } + void setRetParamSize(unsigned SizeInBits) { + retParamSize = SizeInBits; + } + + unsigned getRetParamSize() const { + return retParamSize; + } + + void addArgReg(unsigned Reg) { + reg_arg.insert(Reg); + std::string name; + name = "%param"; + name += utostr(reg_arg.size() - 1); + regNames[Reg] = name; + } + + void addArgParam(unsigned SizeInBits) { + argParams.push_back(SizeInBits); + } + + void addVirtualRegister(const TargetRegisterClass *TRC, unsigned Reg) { std::string name; - if (TRC == PTX::RegPredRegisterClass) - name = "%p"; - else if (TRC == PTX::RegI16RegisterClass) - name = "%rh"; - else if (TRC == PTX::RegI32RegisterClass) - name = "%r"; - else if (TRC == PTX::RegI64RegisterClass) - name = "%rd"; - else if (TRC == PTX::RegF32RegisterClass) - name = "%f"; - else if (TRC == PTX::RegF64RegisterClass) - name = "%fd"; - else - llvm_unreachable("Invalid register class"); + if (!reg_ret.count(Reg) && !reg_arg.count(Reg)) { + usedRegs[TRC].push_back(Reg); + if (TRC == PTX::RegPredRegisterClass) + name = "%p"; + else if (TRC == PTX::RegI16RegisterClass) + name = "%rh"; + else if (TRC == PTX::RegI32RegisterClass) + name = "%r"; + else if (TRC == PTX::RegI64RegisterClass) + name = "%rd"; + else if (TRC == PTX::RegF32RegisterClass) + name = "%f"; + else if (TRC == PTX::RegF64RegisterClass) + name = "%fd"; + else + llvm_unreachable("Invalid register class"); - name += utostr(usedRegs[TRC].size() - 1); - regNames[Reg] = name; + name += utostr(usedRegs[TRC].size() - 1); + regNames[Reg] = name; + } } std::string getRegisterName(unsigned Reg) const { if (regNames.count(Reg)) return regNames.lookup(Reg); + else if (Reg == PTX::NoRegister) + return "%noreg"; else llvm_unreachable("Register not in register name map"); }