[NVPTX] Fix handling of vector arguments

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@177847 91177308-0d34-0410-b5e6-96231b3b80d8
This commit is contained in:
Justin Holewinski 2013-03-24 21:17:47 +00:00
parent d28e30fcf4
commit 1ce53cb526
3 changed files with 83 additions and 7 deletions

View File

@ -1481,7 +1481,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F,
O << "(\n";
for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
const Type *Ty = I->getType();
Type *Ty = I->getType();
if (!first)
O << ",\n";
@ -1504,6 +1504,22 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F,
}
if (PAL.hasAttribute(paramIndex+1, Attribute::ByVal) == false) {
if (Ty->isVectorTy()) {
// Just print .param .b8 .align <a> .param[size];
// <a> = PAL.getparamalignment
// size = typeallocsize of element type
unsigned align = PAL.getParamAlignment(paramIndex+1);
if (align == 0)
align = TD->getABITypeAlignment(Ty);
unsigned sz = TD->getTypeAllocSize(Ty);
O << "\t.param .align " << align
<< " .b8 ";
printParamName(I, paramIndex, O);
O << "[" << sz << "]";
continue;
}
// Just a scalar
const PointerType *PTy = dyn_cast<PointerType>(Ty);
if (isKernelFunc) {

View File

@ -1058,15 +1058,15 @@ NVPTXTargetLowering::LowerFormalArguments(SDValue Chain,
theArgs.push_back(I);
argTypes.push_back(I->getType());
}
assert(argTypes.size() == Ins.size() &&
"Ins types and function types did not match");
//assert(argTypes.size() == Ins.size() &&
// "Ins types and function types did not match");
int idx = 0;
for (unsigned i=0, e=Ins.size(); i!=e; ++i, ++idx) {
for (unsigned i=0, e=argTypes.size(); i!=e; ++i, ++idx) {
Type *Ty = argTypes[i];
EVT ObjectVT = getValueType(Ty);
assert(ObjectVT == Ins[i].VT &&
"Ins type did not match function type");
//assert(ObjectVT == Ins[i].VT &&
// "Ins type did not match function type");
// If the kernel argument is image*_t or sampler_t, convert it to
// a i32 constant holding the parameter position. This can later
@ -1081,7 +1081,15 @@ NVPTXTargetLowering::LowerFormalArguments(SDValue Chain,
if (theArgs[i]->use_empty()) {
// argument is dead
InVals.push_back(DAG.getNode(ISD::UNDEF, dl, ObjectVT));
if (ObjectVT.isVector()) {
EVT EltVT = ObjectVT.getVectorElementType();
unsigned NumElts = ObjectVT.getVectorNumElements();
for (unsigned vi = 0; vi < NumElts; ++vi) {
InVals.push_back(DAG.getNode(ISD::UNDEF, dl, EltVT));
}
} else {
InVals.push_back(DAG.getNode(ISD::UNDEF, dl, ObjectVT));
}
continue;
}
@ -1090,6 +1098,31 @@ NVPTXTargetLowering::LowerFormalArguments(SDValue Chain,
// appear in the same order as their order of appearance
// in the original function. "idx+1" holds that order.
if (PAL.hasAttribute(i+1, Attribute::ByVal) == false) {
if (ObjectVT.isVector()) {
unsigned NumElts = ObjectVT.getVectorNumElements();
EVT EltVT = ObjectVT.getVectorElementType();
unsigned Offset = 0;
for (unsigned vi = 0; vi < NumElts; ++vi) {
SDValue A = getParamSymbol(DAG, idx, getPointerTy());
SDValue B = DAG.getIntPtrConstant(Offset);
SDValue Addr = DAG.getNode(ISD::ADD, dl, getPointerTy(),
//getParamSymbol(DAG, idx, EltVT),
//DAG.getConstant(Offset, getPointerTy()));
A, B);
Value *SrcValue = Constant::getNullValue(PointerType::get(
EltVT.getTypeForEVT(F->getContext()),
llvm::ADDRESS_SPACE_PARAM));
SDValue Ld = DAG.getLoad(EltVT, dl, Root, Addr,
MachinePointerInfo(SrcValue),
false, false, false,
TD->getABITypeAlignment(EltVT.getTypeForEVT(
F->getContext())));
Offset += EltVT.getStoreSizeInBits()/8;
InVals.push_back(Ld);
}
continue;
}
// A plain scalar.
if (isABI || isKernel) {
// If ABI, load from the param symbol

View File

@ -0,0 +1,27 @@
; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
define float @foo(<2 x float> %a) {
; CHECK: .func (.param .b32 func_retval0) foo
; CHECK: .param .align 8 .b8 foo_param_0[8]
; CHECK: ld.param.f32 %f{{[0-9]+}}
; CHECK: ld.param.f32 %f{{[0-9]+}}
%t1 = fmul <2 x float> %a, %a
%t2 = extractelement <2 x float> %t1, i32 0
%t3 = extractelement <2 x float> %t1, i32 1
%t4 = fadd float %t2, %t3
ret float %t4
}
define float @bar(<4 x float> %a) {
; CHECK: .func (.param .b32 func_retval0) bar
; CHECK: .param .align 16 .b8 bar_param_0[16]
; CHECK: ld.param.f32 %f{{[0-9]+}}
; CHECK: ld.param.f32 %f{{[0-9]+}}
%t1 = fmul <4 x float> %a, %a
%t2 = extractelement <4 x float> %t1, i32 0
%t3 = extractelement <4 x float> %t1, i32 1
%t4 = fadd float %t2, %t3
ret float %t4
}