[NVPTX] Add support for vectorized function return values

llvm-svn: 185173
This commit is contained in:
Justin Holewinski 2013-06-28 17:57:55 +00:00
parent 7332dc0027
commit 9ae87e685a
2 changed files with 147 additions and 27 deletions

View File

@ -1338,37 +1338,147 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
}
SDValue NVPTXTargetLowering::LowerReturn(
SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
const SmallVectorImpl<ISD::OutputArg> &Outs,
const SmallVectorImpl<SDValue> &OutVals, SDLoc dl,
SelectionDAG &DAG) const {
SDValue
NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
bool isVarArg,
const SmallVectorImpl<ISD::OutputArg> &Outs,
const SmallVectorImpl<SDValue> &OutVals,
SDLoc dl, SelectionDAG &DAG) const {
MachineFunction &MF = DAG.getMachineFunction();
const Function *F = MF.getFunction();
const Type *RetTy = F->getReturnType();
const DataLayout *TD = getDataLayout();
bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
assert(isABI && "Non-ABI compilation is not supported");
if (!isABI)
return Chain;
unsigned sizesofar = 0;
unsigned idx = 0;
for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
SDValue theVal = OutVals[i];
EVT theValType = theVal.getValueType();
unsigned numElems = 1;
if (theValType.isVector())
numElems = theValType.getVectorNumElements();
for (unsigned j = 0, je = numElems; j != je; ++j) {
SDValue tmpval = theVal;
if (const VectorType *VTy = dyn_cast<const VectorType>(RetTy)) {
// If we have a vector type, the OutVals array will be the scalarized
// components and we have to combine them into 1 or more vector stores.
unsigned NumElts = VTy->getNumElements();
assert(NumElts == Outs.size() && "Bad scalarization of return value");
// V1 store
if (NumElts == 1) {
SDValue StoreVal = OutVals[0];
// We only have one element, so just directly store it
if (StoreVal.getValueType().getSizeInBits() < 8)
StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
Chain = DAG.getNode(NVPTXISD::StoreRetval, dl, MVT::Other, Chain,
DAG.getConstant(0, MVT::i32), StoreVal);
} else if (NumElts == 2) {
// V2 store
SDValue StoreVal0 = OutVals[0];
SDValue StoreVal1 = OutVals[1];
if (StoreVal0.getValueType().getSizeInBits() < 8) {
StoreVal0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal0);
StoreVal1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal1);
}
Chain = DAG.getNode(NVPTXISD::StoreRetvalV2, dl, MVT::Other, Chain,
DAG.getConstant(0, MVT::i32), StoreVal0, StoreVal1);
} else {
// V4 stores
// We have at least 4 elements (<3 x Ty> expands to 4 elements) and the
// vector will be expanded to a power of 2 elements, so we know we can
// always round up to the next multiple of 4 when creating the vector
// stores.
// e.g. 4 elem => 1 st.v4
// 6 elem => 2 st.v4
// 8 elem => 2 st.v4
// 11 elem => 3 st.v4
unsigned VecSize = 4;
if (OutVals[0].getValueType().getSizeInBits() == 64)
VecSize = 2;
unsigned Offset = 0;
EVT VecVT =
EVT::getVectorVT(F->getContext(), OutVals[0].getValueType(), VecSize);
unsigned PerStoreOffset =
TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
bool Extend = false;
if (OutVals[0].getValueType().getSizeInBits() < 8)
Extend = true;
for (unsigned i = 0; i < NumElts; i += VecSize) {
// Get values
SDValue StoreVal;
SmallVector<SDValue, 8> Ops;
Ops.push_back(Chain);
Ops.push_back(DAG.getConstant(Offset, MVT::i32));
unsigned Opc = NVPTXISD::StoreRetvalV2;
EVT ExtendedVT = (Extend) ? MVT::i8 : OutVals[0].getValueType();
StoreVal = OutVals[i];
if (Extend)
StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
Ops.push_back(StoreVal);
if (i + 1 < NumElts) {
StoreVal = OutVals[i + 1];
if (Extend)
StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
} else {
StoreVal = DAG.getUNDEF(ExtendedVT);
}
Ops.push_back(StoreVal);
if (VecSize == 4) {
Opc = NVPTXISD::StoreRetvalV4;
if (i + 2 < NumElts) {
StoreVal = OutVals[i + 2];
if (Extend)
StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
} else {
StoreVal = DAG.getUNDEF(ExtendedVT);
}
Ops.push_back(StoreVal);
if (i + 3 < NumElts) {
StoreVal = OutVals[i + 3];
if (Extend)
StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
} else {
StoreVal = DAG.getUNDEF(ExtendedVT);
}
Ops.push_back(StoreVal);
}
Chain = DAG.getNode(Opc, dl, MVT::Other, &Ops[0], Ops.size());
Offset += PerStoreOffset;
}
}
} else {
unsigned sizesofar = 0;
for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
SDValue theVal = OutVals[i];
EVT theValType = theVal.getValueType();
unsigned numElems = 1;
if (theValType.isVector())
tmpval = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
theValType.getVectorElementType(), tmpval,
DAG.getIntPtrConstant(j));
Chain = DAG.getNode(
isABI ? NVPTXISD::StoreRetval : NVPTXISD::MoveToRetval, dl,
MVT::Other, Chain, DAG.getConstant(isABI ? sizesofar : idx, MVT::i32),
tmpval);
if (theValType.isVector())
sizesofar += theValType.getVectorElementType().getStoreSizeInBits() / 8;
else
sizesofar += theValType.getStoreSizeInBits() / 8;
++idx;
numElems = theValType.getVectorNumElements();
for (unsigned j = 0, je = numElems; j != je; ++j) {
SDValue tmpval = theVal;
if (theValType.isVector())
tmpval = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
theValType.getVectorElementType(), tmpval,
DAG.getIntPtrConstant(j));
EVT theStoreType = tmpval.getValueType();
if (theStoreType.getSizeInBits() < 8)
tmpval = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, tmpval);
Chain = DAG.getNode(NVPTXISD::StoreRetval, dl, MVT::Other, Chain,
DAG.getConstant(sizesofar, MVT::i32), tmpval);
if (theValType.isVector())
sizesofar +=
theValType.getVectorElementType().getStoreSizeInBits() / 8;
else
sizesofar += theValType.getStoreSizeInBits() / 8;
}
}
}

View File

@ -23,3 +23,13 @@ define float @bar(<4 x float> %a) {
%t4 = fadd float %t2, %t3
ret float %t4
}
; Vector return value test: @baz multiplies its <4 x float> argument by itself
; element-wise and returns the <4 x float> result. The FileCheck directives
; below expect the return value to be declared as a 16-byte, 16-byte-aligned
; .param byte array (func_retval0), the argument to be read with
; ld.param.v4.f32, and the result to be written with st.param.v4.f32 at
; offset 0 of func_retval0.
define <4 x float> @baz(<4 x float> %a) {
; CHECK: .func (.param .align 16 .b8 func_retval0[16]) baz
; CHECK: .param .align 16 .b8 baz_param_0[16]
; CHECK: ld.param.v4.f32 {%f{{[0-9]+}}, %f{{[0-9]+}}, %f{{[0-9]+}}, %f{{[0-9]+}}}
; CHECK: st.param.v4.f32 [func_retval0+0], {%f{{[0-9]+}}, %f{{[0-9]+}}, %f{{[0-9]+}}, %f{{[0-9]+}}}
%t1 = fmul <4 x float> %a, %a
ret <4 x float> %t1
}