mirror of
https://github.com/RPCS3/llvm.git
synced 2024-12-23 12:40:17 +00:00
[NVPTX] Clean up argument lowering code and properly handle alignment for structs and vectors
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@211938 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
parent
3fb44103eb
commit
8992274412
@ -67,6 +67,17 @@ static bool IsPTXVectorType(MVT VT) {
|
||||
}
|
||||
}
|
||||
|
||||
static uint64_t GCD( int a, int b)
|
||||
{
|
||||
if (a < b) std::swap(a,b);
|
||||
while (b > 0) {
|
||||
uint64_t c = b;
|
||||
b = a % b;
|
||||
a = c;
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
|
||||
/// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors
|
||||
/// into their primitive components.
|
||||
@ -518,26 +529,12 @@ NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args,
|
||||
} else if (isa<PointerType>(retTy)) {
|
||||
O << ".param .b" << getPointerTy().getSizeInBits() << " _";
|
||||
} else {
|
||||
if ((retTy->getTypeID() == Type::StructTyID) || isa<VectorType>(retTy)) {
|
||||
SmallVector<EVT, 16> vtparts;
|
||||
ComputeValueVTs(*this, retTy, vtparts);
|
||||
unsigned totalsz = 0;
|
||||
for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
|
||||
unsigned elems = 1;
|
||||
EVT elemtype = vtparts[i];
|
||||
if (vtparts[i].isVector()) {
|
||||
elems = vtparts[i].getVectorNumElements();
|
||||
elemtype = vtparts[i].getVectorElementType();
|
||||
}
|
||||
// TODO: no need to loop
|
||||
for (unsigned j = 0, je = elems; j != je; ++j) {
|
||||
unsigned sz = elemtype.getSizeInBits();
|
||||
if (elemtype.isInteger() && (sz < 8))
|
||||
sz = 8;
|
||||
totalsz += sz / 8;
|
||||
}
|
||||
}
|
||||
O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]";
|
||||
if((retTy->getTypeID() == Type::StructTyID) ||
|
||||
isa<VectorType>(retTy)) {
|
||||
O << ".param .align "
|
||||
<< retAlignment
|
||||
<< " .b8 _["
|
||||
<< getDataLayout()->getTypeAllocSize(retTy) << "]";
|
||||
} else {
|
||||
assert(false && "Unknown return type");
|
||||
}
|
||||
@ -706,7 +703,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
|
||||
if (Ty->isAggregateType()) {
|
||||
// aggregate
|
||||
SmallVector<EVT, 16> vtparts;
|
||||
ComputeValueVTs(*this, Ty, vtparts);
|
||||
SmallVector<uint64_t, 16> Offsets;
|
||||
ComputePTXValueVTs(*this, Ty, vtparts, &Offsets, 0);
|
||||
|
||||
unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
|
||||
// declare .param .align <align> .b8 .param<n>[<size>];
|
||||
@ -718,34 +716,26 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
|
||||
Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
|
||||
DeclareParamOps);
|
||||
InFlag = Chain.getValue(1);
|
||||
unsigned curOffset = 0;
|
||||
for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
|
||||
unsigned elems = 1;
|
||||
EVT elemtype = vtparts[j];
|
||||
if (vtparts[j].isVector()) {
|
||||
elems = vtparts[j].getVectorNumElements();
|
||||
elemtype = vtparts[j].getVectorElementType();
|
||||
}
|
||||
for (unsigned k = 0, ke = elems; k != ke; ++k) {
|
||||
unsigned sz = elemtype.getSizeInBits();
|
||||
if (elemtype.isInteger() && (sz < 8))
|
||||
sz = 8;
|
||||
SDValue StVal = OutVals[OIdx];
|
||||
if (elemtype.getSizeInBits() < 16) {
|
||||
StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
|
||||
}
|
||||
SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
|
||||
SDValue CopyParamOps[] = { Chain,
|
||||
DAG.getConstant(paramCount, MVT::i32),
|
||||
DAG.getConstant(curOffset, MVT::i32),
|
||||
StVal, InFlag };
|
||||
Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
|
||||
CopyParamVTs, CopyParamOps,
|
||||
elemtype, MachinePointerInfo());
|
||||
InFlag = Chain.getValue(1);
|
||||
curOffset += sz / 8;
|
||||
++OIdx;
|
||||
unsigned ArgAlign = GCD(align, Offsets[j]);
|
||||
if (elemtype.isInteger() && (sz < 8))
|
||||
sz = 8;
|
||||
SDValue StVal = OutVals[OIdx];
|
||||
if (elemtype.getSizeInBits() < 16) {
|
||||
StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
|
||||
}
|
||||
SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
|
||||
SDValue CopyParamOps[] = { Chain,
|
||||
DAG.getConstant(paramCount, MVT::i32),
|
||||
DAG.getConstant(Offsets[j], MVT::i32),
|
||||
StVal, InFlag };
|
||||
Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
|
||||
CopyParamVTs, CopyParamOps,
|
||||
elemtype, MachinePointerInfo(),
|
||||
ArgAlign);
|
||||
InFlag = Chain.getValue(1);
|
||||
++OIdx;
|
||||
}
|
||||
if (vtparts.size() > 0)
|
||||
--OIdx;
|
||||
@ -930,13 +920,15 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
|
||||
}
|
||||
// struct or vector
|
||||
SmallVector<EVT, 16> vtparts;
|
||||
SmallVector<uint64_t, 16> Offsets;
|
||||
const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty);
|
||||
assert(PTy && "Type of a byval parameter should be pointer");
|
||||
ComputeValueVTs(*this, PTy->getElementType(), vtparts);
|
||||
ComputePTXValueVTs(*this, PTy->getElementType(), vtparts, &Offsets, 0);
|
||||
|
||||
// declare .param .align <align> .b8 .param<n>[<size>];
|
||||
unsigned sz = Outs[OIdx].Flags.getByValSize();
|
||||
SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
|
||||
unsigned ArgAlign = Outs[OIdx].Flags.getByValAlign();
|
||||
// The ByValAlign in the Outs[OIdx].Flags is alway set at this point,
|
||||
// so we don't need to worry about natural alignment or not.
|
||||
// See TargetLowering::LowerCallTo().
|
||||
@ -948,38 +940,28 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
|
||||
Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
|
||||
DeclareParamOps);
|
||||
InFlag = Chain.getValue(1);
|
||||
unsigned curOffset = 0;
|
||||
for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
|
||||
unsigned elems = 1;
|
||||
EVT elemtype = vtparts[j];
|
||||
if (vtparts[j].isVector()) {
|
||||
elems = vtparts[j].getVectorNumElements();
|
||||
elemtype = vtparts[j].getVectorElementType();
|
||||
int curOffset = Offsets[j];
|
||||
unsigned PartAlign = GCD(ArgAlign, curOffset);
|
||||
SDValue srcAddr =
|
||||
DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx],
|
||||
DAG.getConstant(curOffset, getPointerTy()));
|
||||
SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
|
||||
MachinePointerInfo(), false, false, false,
|
||||
PartAlign);
|
||||
if (elemtype.getSizeInBits() < 16) {
|
||||
theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal);
|
||||
}
|
||||
for (unsigned k = 0, ke = elems; k != ke; ++k) {
|
||||
unsigned sz = elemtype.getSizeInBits();
|
||||
if (elemtype.isInteger() && (sz < 8))
|
||||
sz = 8;
|
||||
SDValue srcAddr =
|
||||
DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx],
|
||||
DAG.getConstant(curOffset, getPointerTy()));
|
||||
SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
|
||||
MachinePointerInfo(), false, false, false,
|
||||
0);
|
||||
if (elemtype.getSizeInBits() < 16) {
|
||||
theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal);
|
||||
}
|
||||
SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
|
||||
SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
|
||||
DAG.getConstant(curOffset, MVT::i32), theVal,
|
||||
InFlag };
|
||||
Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
|
||||
CopyParamOps, elemtype,
|
||||
MachinePointerInfo());
|
||||
SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
|
||||
SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
|
||||
DAG.getConstant(curOffset, MVT::i32), theVal,
|
||||
InFlag };
|
||||
Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
|
||||
CopyParamOps, elemtype,
|
||||
MachinePointerInfo());
|
||||
|
||||
InFlag = Chain.getValue(1);
|
||||
curOffset += sz / 8;
|
||||
}
|
||||
InFlag = Chain.getValue(1);
|
||||
}
|
||||
++paramCount;
|
||||
}
|
||||
@ -1088,7 +1070,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
|
||||
|
||||
// Generate loads from param memory/moves from registers for result
|
||||
if (Ins.size() > 0) {
|
||||
unsigned resoffset = 0;
|
||||
if (retTy && retTy->isVectorTy()) {
|
||||
EVT ObjectVT = getValueType(retTy);
|
||||
unsigned NumElts = ObjectVT.getVectorNumElements();
|
||||
@ -1097,14 +1078,15 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
|
||||
ObjectVT) == NumElts &&
|
||||
"Vector was not scalarized");
|
||||
unsigned sz = EltVT.getSizeInBits();
|
||||
bool needTruncate = sz < 16 ? true : false;
|
||||
bool needTruncate = sz < 8 ? true : false;
|
||||
|
||||
if (NumElts == 1) {
|
||||
// Just a simple load
|
||||
SmallVector<EVT, 4> LoadRetVTs;
|
||||
if (needTruncate) {
|
||||
// If loading i1 result, generate
|
||||
// load i16
|
||||
if (EltVT == MVT::i1 || EltVT == MVT::i8) {
|
||||
// If loading i1/i8 result, generate
|
||||
// load.b8 i16
|
||||
// if i1
|
||||
// trunc i16 to i1
|
||||
LoadRetVTs.push_back(MVT::i16);
|
||||
} else
|
||||
@ -1128,9 +1110,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
|
||||
} else if (NumElts == 2) {
|
||||
// LoadV2
|
||||
SmallVector<EVT, 4> LoadRetVTs;
|
||||
if (needTruncate) {
|
||||
// If loading i1 result, generate
|
||||
// load i16
|
||||
if (EltVT == MVT::i1 || EltVT == MVT::i8) {
|
||||
// If loading i1/i8 result, generate
|
||||
// load.b8 i16
|
||||
// if i1
|
||||
// trunc i16 to i1
|
||||
LoadRetVTs.push_back(MVT::i16);
|
||||
LoadRetVTs.push_back(MVT::i16);
|
||||
@ -1173,9 +1156,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
|
||||
EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
|
||||
for (unsigned i = 0; i < NumElts; i += VecSize) {
|
||||
SmallVector<EVT, 8> LoadRetVTs;
|
||||
if (needTruncate) {
|
||||
// If loading i1 result, generate
|
||||
// load i16
|
||||
if (EltVT == MVT::i1 || EltVT == MVT::i8) {
|
||||
// If loading i1/i8 result, generate
|
||||
// load.b8 i16
|
||||
// if i1
|
||||
// trunc i16 to i1
|
||||
for (unsigned j = 0; j < VecSize; ++j)
|
||||
LoadRetVTs.push_back(MVT::i16);
|
||||
@ -1214,10 +1198,13 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
|
||||
}
|
||||
} else {
|
||||
SmallVector<EVT, 16> VTs;
|
||||
ComputePTXValueVTs(*this, retTy, VTs);
|
||||
SmallVector<uint64_t, 16> Offsets;
|
||||
ComputePTXValueVTs(*this, retTy, VTs, &Offsets, 0);
|
||||
assert(VTs.size() == Ins.size() && "Bad value decomposition");
|
||||
unsigned RetAlign = getArgumentAlignment(Callee, CS, retTy, 0);
|
||||
for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
|
||||
unsigned sz = VTs[i].getSizeInBits();
|
||||
unsigned AlignI = GCD(RetAlign, Offsets[i]);
|
||||
bool needTruncate = sz < 8 ? true : false;
|
||||
if (VTs[i].isInteger() && (sz < 8))
|
||||
sz = 8;
|
||||
@ -1243,19 +1230,18 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
|
||||
SmallVector<SDValue, 4> LoadRetOps;
|
||||
LoadRetOps.push_back(Chain);
|
||||
LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
|
||||
LoadRetOps.push_back(DAG.getConstant(resoffset, MVT::i32));
|
||||
LoadRetOps.push_back(DAG.getConstant(Offsets[i], MVT::i32));
|
||||
LoadRetOps.push_back(InFlag);
|
||||
SDValue retval = DAG.getMemIntrinsicNode(
|
||||
NVPTXISD::LoadParam, dl,
|
||||
DAG.getVTList(LoadRetVTs), LoadRetOps,
|
||||
TheLoadType, MachinePointerInfo());
|
||||
TheLoadType, MachinePointerInfo(), AlignI);
|
||||
Chain = retval.getValue(1);
|
||||
InFlag = retval.getValue(2);
|
||||
SDValue Ret0 = retval.getValue(0);
|
||||
if (needTruncate)
|
||||
Ret0 = DAG.getNode(ISD::TRUNCATE, dl, Ins[i].VT, Ret0);
|
||||
InVals.push_back(Ret0);
|
||||
resoffset += sz / 8;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
13
test/CodeGen/NVPTX/arg-lowering.ll
Normal file
13
test/CodeGen/NVPTX/arg-lowering.ll
Normal file
@ -0,0 +1,13 @@
|
||||
; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
|
||||
|
||||
; CHECK: .visible .func (.param .align 16 .b8 func_retval0[16]) foo0(
|
||||
; CHECK: .param .align 4 .b8 foo0_param_0[8]
|
||||
define <4 x float> @foo0({float, float} %arg0) {
|
||||
ret <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>
|
||||
}
|
||||
|
||||
; CHECK: .visible .func (.param .align 8 .b8 func_retval0[8]) foo1(
|
||||
; CHECK: .param .align 8 .b8 foo1_param_0[16]
|
||||
define <2 x float> @foo1({float, float, i64} %arg0) {
|
||||
ret <2 x float> <float 1.0, float 1.0>
|
||||
}
|
Loading…
Reference in New Issue
Block a user