mirror of
https://github.com/RPCS3/llvm-mirror.git
synced 2025-01-21 03:05:15 +00:00
[NVPTX] Add support for vectorized function return values
llvm-svn: 185173
This commit is contained in:
parent
7332dc0027
commit
9ae87e685a
@ -1338,37 +1338,147 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
|
||||
}
|
||||
|
||||
|
||||
SDValue NVPTXTargetLowering::LowerReturn(
|
||||
SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
|
||||
const SmallVectorImpl<ISD::OutputArg> &Outs,
|
||||
const SmallVectorImpl<SDValue> &OutVals, SDLoc dl,
|
||||
SelectionDAG &DAG) const {
|
||||
SDValue
|
||||
NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
|
||||
bool isVarArg,
|
||||
const SmallVectorImpl<ISD::OutputArg> &Outs,
|
||||
const SmallVectorImpl<SDValue> &OutVals,
|
||||
SDLoc dl, SelectionDAG &DAG) const {
|
||||
MachineFunction &MF = DAG.getMachineFunction();
|
||||
const Function *F = MF.getFunction();
|
||||
const Type *RetTy = F->getReturnType();
|
||||
const DataLayout *TD = getDataLayout();
|
||||
|
||||
bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
|
||||
assert(isABI && "Non-ABI compilation is not supported");
|
||||
if (!isABI)
|
||||
return Chain;
|
||||
|
||||
unsigned sizesofar = 0;
|
||||
unsigned idx = 0;
|
||||
for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
|
||||
SDValue theVal = OutVals[i];
|
||||
EVT theValType = theVal.getValueType();
|
||||
unsigned numElems = 1;
|
||||
if (theValType.isVector())
|
||||
numElems = theValType.getVectorNumElements();
|
||||
for (unsigned j = 0, je = numElems; j != je; ++j) {
|
||||
SDValue tmpval = theVal;
|
||||
if (const VectorType *VTy = dyn_cast<const VectorType>(RetTy)) {
|
||||
// If we have a vector type, the OutVals array will be the scalarized
|
||||
// components and we have to combine them into 1 or more vector stores.
|
||||
unsigned NumElts = VTy->getNumElements();
|
||||
assert(NumElts == Outs.size() && "Bad scalarization of return value");
|
||||
|
||||
// V1 store
|
||||
if (NumElts == 1) {
|
||||
SDValue StoreVal = OutVals[0];
|
||||
// We only have one element, so just directly store it
|
||||
if (StoreVal.getValueType().getSizeInBits() < 8)
|
||||
StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
|
||||
Chain = DAG.getNode(NVPTXISD::StoreRetval, dl, MVT::Other, Chain,
|
||||
DAG.getConstant(0, MVT::i32), StoreVal);
|
||||
} else if (NumElts == 2) {
|
||||
// V2 store
|
||||
SDValue StoreVal0 = OutVals[0];
|
||||
SDValue StoreVal1 = OutVals[1];
|
||||
|
||||
if (StoreVal0.getValueType().getSizeInBits() < 8) {
|
||||
StoreVal0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal0);
|
||||
StoreVal1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal1);
|
||||
}
|
||||
|
||||
Chain = DAG.getNode(NVPTXISD::StoreRetvalV2, dl, MVT::Other, Chain,
|
||||
DAG.getConstant(0, MVT::i32), StoreVal0, StoreVal1);
|
||||
} else {
|
||||
// V4 stores
|
||||
// We have at least 4 elements (<3 x Ty> expands to 4 elements) and the
|
||||
// vector will be expanded to a power of 2 elements, so we know we can
|
||||
// always round up to the next multiple of 4 when creating the vector
|
||||
// stores.
|
||||
// e.g. 4 elem => 1 st.v4
|
||||
// 6 elem => 2 st.v4
|
||||
// 8 elem => 2 st.v4
|
||||
// 11 elem => 3 st.v4
|
||||
|
||||
unsigned VecSize = 4;
|
||||
if (OutVals[0].getValueType().getSizeInBits() == 64)
|
||||
VecSize = 2;
|
||||
|
||||
unsigned Offset = 0;
|
||||
|
||||
EVT VecVT =
|
||||
EVT::getVectorVT(F->getContext(), OutVals[0].getValueType(), VecSize);
|
||||
unsigned PerStoreOffset =
|
||||
TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
|
||||
|
||||
bool Extend = false;
|
||||
if (OutVals[0].getValueType().getSizeInBits() < 8)
|
||||
Extend = true;
|
||||
|
||||
for (unsigned i = 0; i < NumElts; i += VecSize) {
|
||||
// Get values
|
||||
SDValue StoreVal;
|
||||
SmallVector<SDValue, 8> Ops;
|
||||
Ops.push_back(Chain);
|
||||
Ops.push_back(DAG.getConstant(Offset, MVT::i32));
|
||||
unsigned Opc = NVPTXISD::StoreRetvalV2;
|
||||
EVT ExtendedVT = (Extend) ? MVT::i8 : OutVals[0].getValueType();
|
||||
|
||||
StoreVal = OutVals[i];
|
||||
if (Extend)
|
||||
StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
|
||||
Ops.push_back(StoreVal);
|
||||
|
||||
if (i + 1 < NumElts) {
|
||||
StoreVal = OutVals[i + 1];
|
||||
if (Extend)
|
||||
StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
|
||||
} else {
|
||||
StoreVal = DAG.getUNDEF(ExtendedVT);
|
||||
}
|
||||
Ops.push_back(StoreVal);
|
||||
|
||||
if (VecSize == 4) {
|
||||
Opc = NVPTXISD::StoreRetvalV4;
|
||||
if (i + 2 < NumElts) {
|
||||
StoreVal = OutVals[i + 2];
|
||||
if (Extend)
|
||||
StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
|
||||
} else {
|
||||
StoreVal = DAG.getUNDEF(ExtendedVT);
|
||||
}
|
||||
Ops.push_back(StoreVal);
|
||||
|
||||
if (i + 3 < NumElts) {
|
||||
StoreVal = OutVals[i + 3];
|
||||
if (Extend)
|
||||
StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
|
||||
} else {
|
||||
StoreVal = DAG.getUNDEF(ExtendedVT);
|
||||
}
|
||||
Ops.push_back(StoreVal);
|
||||
}
|
||||
|
||||
Chain = DAG.getNode(Opc, dl, MVT::Other, &Ops[0], Ops.size());
|
||||
Offset += PerStoreOffset;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
unsigned sizesofar = 0;
|
||||
for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
|
||||
SDValue theVal = OutVals[i];
|
||||
EVT theValType = theVal.getValueType();
|
||||
unsigned numElems = 1;
|
||||
if (theValType.isVector())
|
||||
tmpval = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
|
||||
theValType.getVectorElementType(), tmpval,
|
||||
DAG.getIntPtrConstant(j));
|
||||
Chain = DAG.getNode(
|
||||
isABI ? NVPTXISD::StoreRetval : NVPTXISD::MoveToRetval, dl,
|
||||
MVT::Other, Chain, DAG.getConstant(isABI ? sizesofar : idx, MVT::i32),
|
||||
tmpval);
|
||||
if (theValType.isVector())
|
||||
sizesofar += theValType.getVectorElementType().getStoreSizeInBits() / 8;
|
||||
else
|
||||
sizesofar += theValType.getStoreSizeInBits() / 8;
|
||||
++idx;
|
||||
numElems = theValType.getVectorNumElements();
|
||||
for (unsigned j = 0, je = numElems; j != je; ++j) {
|
||||
SDValue tmpval = theVal;
|
||||
if (theValType.isVector())
|
||||
tmpval = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
|
||||
theValType.getVectorElementType(), tmpval,
|
||||
DAG.getIntPtrConstant(j));
|
||||
EVT theStoreType = tmpval.getValueType();
|
||||
if (theStoreType.getSizeInBits() < 8)
|
||||
tmpval = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, tmpval);
|
||||
Chain = DAG.getNode(NVPTXISD::StoreRetval, dl, MVT::Other, Chain,
|
||||
DAG.getConstant(sizesofar, MVT::i32), tmpval);
|
||||
if (theValType.isVector())
|
||||
sizesofar +=
|
||||
theValType.getVectorElementType().getStoreSizeInBits() / 8;
|
||||
else
|
||||
sizesofar += theValType.getStoreSizeInBits() / 8;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -23,3 +23,13 @@ define float @bar(<4 x float> %a) {
|
||||
%t4 = fadd float %t2, %t3
|
||||
ret float %t4
|
||||
}
|
||||
|
||||
|
||||
; Test for a function that both takes and returns <4 x float>: it multiplies
; the argument by itself element-wise and returns the product. The FileCheck
; directives in the body expect a 16-byte, 16-aligned return parameter
; (func_retval0) and input parameter (baz_param_0), a ld.param.v4.f32 of four
; %f registers, and a st.param.v4.f32 of four %f registers to func_retval0 at
; offset 0.
define <4 x float> @baz(<4 x float> %a) {
|
||||
; CHECK: .func (.param .align 16 .b8 func_retval0[16]) baz
|
||||
; CHECK: .param .align 16 .b8 baz_param_0[16]
|
||||
; CHECK: ld.param.v4.f32 {%f{{[0-9]+}}, %f{{[0-9]+}}, %f{{[0-9]+}}, %f{{[0-9]+}}}
|
||||
; CHECK: st.param.v4.f32 [func_retval0+0], {%f{{[0-9]+}}, %f{{[0-9]+}}, %f{{[0-9]+}}, %f{{[0-9]+}}}
|
||||
%t1 = fmul <4 x float> %a, %a
|
||||
ret <4 x float> %t1
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user