diff --git a/lib/Target/NVPTX/NVPTXISelLowering.cpp b/lib/Target/NVPTX/NVPTXISelLowering.cpp index b324cdb7d66..292e8e173b0 100644 --- a/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -67,6 +67,17 @@ static bool IsPTXVectorType(MVT VT) { } } +static uint64_t GCD( int a, int b) +{ + if (a < b) std::swap(a,b); + while (b > 0) { + uint64_t c = b; + b = a % b; + a = c; + } + return a; +} + /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive /// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors /// into their primitive components. @@ -518,26 +529,12 @@ NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args, } else if (isa(retTy)) { O << ".param .b" << getPointerTy().getSizeInBits() << " _"; } else { - if ((retTy->getTypeID() == Type::StructTyID) || isa(retTy)) { - SmallVector vtparts; - ComputeValueVTs(*this, retTy, vtparts); - unsigned totalsz = 0; - for (unsigned i = 0, e = vtparts.size(); i != e; ++i) { - unsigned elems = 1; - EVT elemtype = vtparts[i]; - if (vtparts[i].isVector()) { - elems = vtparts[i].getVectorNumElements(); - elemtype = vtparts[i].getVectorElementType(); - } - // TODO: no need to loop - for (unsigned j = 0, je = elems; j != je; ++j) { - unsigned sz = elemtype.getSizeInBits(); - if (elemtype.isInteger() && (sz < 8)) - sz = 8; - totalsz += sz / 8; - } - } - O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]"; + if((retTy->getTypeID() == Type::StructTyID) || + isa(retTy)) { + O << ".param .align " + << retAlignment + << " .b8 _[" + << getDataLayout()->getTypeAllocSize(retTy) << "]"; } else { assert(false && "Unknown return type"); } @@ -706,7 +703,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, if (Ty->isAggregateType()) { // aggregate SmallVector vtparts; - ComputeValueVTs(*this, Ty, vtparts); + SmallVector Offsets; + ComputePTXValueVTs(*this, Ty, vtparts, &Offsets, 0); unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1); // declare .param .align .b8 .param[]; @@ -718,34 +716,26 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs, DeclareParamOps); InFlag = Chain.getValue(1); - unsigned curOffset = 0; for (unsigned j = 0, je = vtparts.size(); j != je; ++j) { - unsigned elems = 1; EVT elemtype = vtparts[j]; - if (vtparts[j].isVector()) { - elems = vtparts[j].getVectorNumElements(); - elemtype = vtparts[j].getVectorElementType(); - } - for (unsigned k = 0, ke = elems; k != ke; ++k) { - unsigned sz = elemtype.getSizeInBits(); - if (elemtype.isInteger() && (sz < 8)) - sz = 8; - SDValue StVal = OutVals[OIdx]; - if (elemtype.getSizeInBits() < 16) { - StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal); - } - SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); - SDValue CopyParamOps[] = { Chain, - DAG.getConstant(paramCount, MVT::i32), - DAG.getConstant(curOffset, MVT::i32), - StVal, InFlag }; - Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, - CopyParamVTs, CopyParamOps, - elemtype, MachinePointerInfo()); - InFlag = Chain.getValue(1); - curOffset += sz / 8; - ++OIdx; + unsigned ArgAlign = GCD(align, Offsets[j]); + if (elemtype.isInteger() && (sz < 8)) + sz = 8; + SDValue StVal = OutVals[OIdx]; + if (elemtype.getSizeInBits() < 16) { + StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal); } + SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); + SDValue CopyParamOps[] = { Chain, + DAG.getConstant(paramCount, MVT::i32), + DAG.getConstant(Offsets[j], MVT::i32), + StVal, InFlag }; + Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, + CopyParamVTs, CopyParamOps, + elemtype, MachinePointerInfo(), + ArgAlign); + InFlag = Chain.getValue(1); + ++OIdx; } if (vtparts.size() > 0) --OIdx; @@ -930,13 +920,15 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, } // struct or vector SmallVector vtparts; + SmallVector Offsets; const PointerType *PTy = dyn_cast(Args[i].Ty); assert(PTy && "Type of a byval parameter should be pointer"); - ComputeValueVTs(*this, PTy->getElementType(), vtparts); + ComputePTXValueVTs(*this, PTy->getElementType(), vtparts, &Offsets, 0); // declare .param .align .b8 .param[]; unsigned sz = Outs[OIdx].Flags.getByValSize(); SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); + unsigned ArgAlign = Outs[OIdx].Flags.getByValAlign(); // The ByValAlign in the Outs[OIdx].Flags is alway set at this point, // so we don't need to worry about natural alignment or not. // See TargetLowering::LowerCallTo(). @@ -948,38 +940,28 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs, DeclareParamOps); InFlag = Chain.getValue(1); - unsigned curOffset = 0; for (unsigned j = 0, je = vtparts.size(); j != je; ++j) { - unsigned elems = 1; EVT elemtype = vtparts[j]; - if (vtparts[j].isVector()) { - elems = vtparts[j].getVectorNumElements(); - elemtype = vtparts[j].getVectorElementType(); + int curOffset = Offsets[j]; + unsigned PartAlign = GCD(ArgAlign, curOffset); + SDValue srcAddr = + DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx], + DAG.getConstant(curOffset, getPointerTy())); + SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr, + MachinePointerInfo(), false, false, false, + PartAlign); + if (elemtype.getSizeInBits() < 16) { + theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal); } - for (unsigned k = 0, ke = elems; k != ke; ++k) { - unsigned sz = elemtype.getSizeInBits(); - if (elemtype.isInteger() && (sz < 8)) - sz = 8; - SDValue srcAddr = - DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx], - DAG.getConstant(curOffset, getPointerTy())); - SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr, - MachinePointerInfo(), false, false, false, - 0); - if (elemtype.getSizeInBits() < 16) { - theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal); - } - SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); - SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32), - DAG.getConstant(curOffset, MVT::i32), theVal, - InFlag }; - Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, CopyParamVTs, - CopyParamOps, elemtype, - MachinePointerInfo()); + SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); + SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32), + DAG.getConstant(curOffset, MVT::i32), theVal, + InFlag }; + Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, CopyParamVTs, + CopyParamOps, elemtype, + MachinePointerInfo()); - InFlag = Chain.getValue(1); - curOffset += sz / 8; - } + InFlag = Chain.getValue(1); } ++paramCount; } @@ -1088,7 +1070,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, // Generate loads from param memory/moves from registers for result if (Ins.size() > 0) { - unsigned resoffset = 0; if (retTy && retTy->isVectorTy()) { EVT ObjectVT = getValueType(retTy); unsigned NumElts = ObjectVT.getVectorNumElements(); @@ -1097,14 +1078,15 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, ObjectVT) == NumElts && "Vector was not scalarized"); unsigned sz = EltVT.getSizeInBits(); - bool needTruncate = sz < 16 ? true : false; + bool needTruncate = sz < 8 ? true : false; if (NumElts == 1) { // Just a simple load SmallVector LoadRetVTs; - if (needTruncate) { - // If loading i1 result, generate - // load i16 + if (EltVT == MVT::i1 || EltVT == MVT::i8) { + // If loading i1/i8 result, generate + // load.b8 i16 + // if i1 // trunc i16 to i1 LoadRetVTs.push_back(MVT::i16); } else @@ -1128,9 +1110,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, } else if (NumElts == 2) { // LoadV2 SmallVector LoadRetVTs; - if (needTruncate) { - // If loading i1 result, generate - // load i16 + if (EltVT == MVT::i1 || EltVT == MVT::i8) { + // If loading i1/i8 result, generate + // load.b8 i16 + // if i1 // trunc i16 to i1 LoadRetVTs.push_back(MVT::i16); LoadRetVTs.push_back(MVT::i16); @@ -1173,9 +1156,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize); for (unsigned i = 0; i < NumElts; i += VecSize) { SmallVector LoadRetVTs; - if (needTruncate) { - // If loading i1 result, generate - // load i16 + if (EltVT == MVT::i1 || EltVT == MVT::i8) { + // If loading i1/i8 result, generate + // load.b8 i16 + // if i1 // trunc i16 to i1 for (unsigned j = 0; j < VecSize; ++j) LoadRetVTs.push_back(MVT::i16); @@ -1214,10 +1198,13 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, } } else { SmallVector VTs; - ComputePTXValueVTs(*this, retTy, VTs); + SmallVector Offsets; + ComputePTXValueVTs(*this, retTy, VTs, &Offsets, 0); assert(VTs.size() == Ins.size() && "Bad value decomposition"); + unsigned RetAlign = getArgumentAlignment(Callee, CS, retTy, 0); for (unsigned i = 0, e = Ins.size(); i != e; ++i) { unsigned sz = VTs[i].getSizeInBits(); + unsigned AlignI = GCD(RetAlign, Offsets[i]); bool needTruncate = sz < 8 ? true : false; if (VTs[i].isInteger() && (sz < 8)) sz = 8; @@ -1243,19 +1230,18 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, SmallVector LoadRetOps; LoadRetOps.push_back(Chain); LoadRetOps.push_back(DAG.getConstant(1, MVT::i32)); - LoadRetOps.push_back(DAG.getConstant(resoffset, MVT::i32)); + LoadRetOps.push_back(DAG.getConstant(Offsets[i], MVT::i32)); LoadRetOps.push_back(InFlag); SDValue retval = DAG.getMemIntrinsicNode( NVPTXISD::LoadParam, dl, DAG.getVTList(LoadRetVTs), LoadRetOps, - TheLoadType, MachinePointerInfo()); + TheLoadType, MachinePointerInfo(), AlignI); Chain = retval.getValue(1); InFlag = retval.getValue(2); SDValue Ret0 = retval.getValue(0); if (needTruncate) Ret0 = DAG.getNode(ISD::TRUNCATE, dl, Ins[i].VT, Ret0); InVals.push_back(Ret0); - resoffset += sz / 8; } } } diff --git a/test/CodeGen/NVPTX/arg-lowering.ll b/test/CodeGen/NVPTX/arg-lowering.ll new file mode 100644 index 00000000000..f7b8a1491d3 --- /dev/null +++ b/test/CodeGen/NVPTX/arg-lowering.ll @@ -0,0 +1,13 @@ +; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s + +; CHECK: .visible .func (.param .align 16 .b8 func_retval0[16]) foo0( +; CHECK: .param .align 4 .b8 foo0_param_0[8] +define <4 x float> @foo0({float, float} %arg0) { + ret <4 x float> +} + +; CHECK: .visible .func (.param .align 8 .b8 func_retval0[8]) foo1( +; CHECK: .param .align 8 .b8 foo1_param_0[16] +define <2 x float> @foo1({float, float, i64} %arg0) { + ret <2 x float> +}