[NVPTX] Add isel patterns for bit-field extract (bfe)

llvm-svn: 211932
This commit is contained in:
Justin Holewinski 2014-06-27 18:35:27 +00:00
parent 7c9cd5f566
commit 2ffa2d24b0
4 changed files with 270 additions and 0 deletions

View File

@ -253,6 +253,12 @@ SDNode *NVPTXDAGToDAGISel::Select(SDNode *N) {
case NVPTXISD::Suld3DV4I32Trap:
ResNode = SelectSurfaceIntrinsic(N);
break;
case ISD::AND:
case ISD::SRA:
case ISD::SRL:
// Try to select BFE
ResNode = SelectBFE(N);
break;
case ISD::ADDRSPACECAST:
ResNode = SelectAddrSpaceCast(N);
break;
@ -2959,6 +2965,214 @@ SDNode *NVPTXDAGToDAGISel::SelectSurfaceIntrinsic(SDNode *N) {
return Ret;
}
/// SelectBFE - Look for instruction sequences that can be made more efficient
/// by using the 'bfe' (bit-field extract) PTX instruction.
///
/// \p N is expected to be an ISD::AND, ISD::SRL or ISD::SRA node.  Three
/// shapes are recognised (all shift amounts and masks must be constants):
///
///   (and (srl/sra val, start), mask)   where mask is a low-order run of ones
///   (srl/sra (and val, mask), start)   where mask is a contiguous run of ones
///   (srl/sra (shl val, NN), MM)        where MM >= NN
///
/// \returns the newly created BFE machine node (immediate start/length form),
/// or NULL when \p N does not match or when the transformation would not be
/// profitable, in which case the caller falls back to normal selection.
SDNode *NVPTXDAGToDAGISel::SelectBFE(SDNode *N) {
  SDValue LHS = N->getOperand(0);
  SDValue RHS = N->getOperand(1);
  SDValue Len;
  SDValue Start;
  SDValue Val;
  bool IsSigned = false;

  if (N->getOpcode() == ISD::AND) {
    // Canonicalize the operands
    // We want 'and %val, %mask'
    if (isa<ConstantSDNode>(LHS) && !isa<ConstantSDNode>(RHS)) {
      std::swap(LHS, RHS);
    }

    ConstantSDNode *Mask = dyn_cast<ConstantSDNode>(RHS);
    if (!Mask) {
      // We need a constant mask on the RHS of the AND
      return NULL;
    }

    // Extract the mask bits
    uint64_t MaskVal = Mask->getZExtValue();
    if (!isMask_64(MaskVal)) {
      // We *could* handle shifted masks here, but doing so would require an
      // 'and' operation to fix up the low-order bits so we would trade
      // shr+and for bfe+and, which has the same throughput
      return NULL;
    }

    // How many bits are in our mask?
    uint64_t NumBits = CountTrailingOnes_64(MaskVal);

    if (LHS.getOpcode() != ISD::SRL && LHS.getOpcode() != ISD::SRA) {
      // Do not handle the case where the LHS of the and is not a shift. While
      // it would be trivial to handle this case, it would just transform
      // 'and' -> 'bfe', but 'and' has higher-throughput.
      return NULL;
    }

    // We have a 'srl/and' pair, extract the effective start bit and length
    Val = LHS.getNode()->getOperand(0);
    ConstantSDNode *StartConst =
        dyn_cast<ConstantSDNode>(LHS.getNode()->getOperand(1));
    if (!StartConst) {
      // Do not handle the case where the shift amount is not constant. We
      // could handle this case, but it would require run-time logic that would
      // be more expensive than just emitting the srl/and pair.
      return NULL;
    }

    uint64_t StartVal = StartConst->getZExtValue();
    // The width that matters is that of the value being shifted, not that of
    // the shift-amount operand (the two differ for 64-bit values shifted by a
    // 32-bit amount).
    uint64_t ValBits = Val.getValueType().getSizeInBits();
    if (StartVal >= ValBits) {
      // Nothing of the original value survives the shift; bailing out here
      // also keeps the unsigned subtraction below from wrapping around.
      return NULL;
    }

    // How many "good" bits do we have left?  "good" is defined here as bits
    // that exist in the original value, not shifted in.
    uint64_t GoodBits = ValBits - StartVal;
    if (NumBits > GoodBits) {
      // Do not handle the case where bits have been shifted in. In theory
      // we could handle this, but the cost is likely higher than just
      // emitting the srl/and pair.
      return NULL;
    }

    Start = CurDAG->getTargetConstant(StartVal, MVT::i32);
    Len = CurDAG->getTargetConstant(NumBits, MVT::i32);
  } else if (N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) {
    if (LHS->getOpcode() == ISD::AND) {
      ConstantSDNode *ShiftCnst = dyn_cast<ConstantSDNode>(RHS);
      if (!ShiftCnst) {
        // Shift amount must be constant
        return NULL;
      }

      uint64_t ShiftAmt = ShiftCnst->getZExtValue();

      SDValue AndLHS = LHS->getOperand(0);
      SDValue AndRHS = LHS->getOperand(1);

      // Canonicalize the AND to have the mask on the RHS
      if (isa<ConstantSDNode>(AndLHS)) {
        std::swap(AndLHS, AndRHS);
      }

      ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(AndRHS);
      if (!MaskCnst) {
        // Mask must be constant
        return NULL;
      }

      uint64_t MaskVal = MaskCnst->getZExtValue();
      uint64_t NumZeros;
      uint64_t NumOnes;
      if (isMask_64(MaskVal)) {
        NumZeros = 0;
        NumOnes = CountTrailingOnes_64(MaskVal);
      } else if (isShiftedMask_64(MaskVal)) {
        NumZeros = countTrailingZeros(MaskVal);
        NumOnes = CountTrailingOnes_64(MaskVal >> NumZeros);
      } else {
        // This is not a mask we can handle
        return NULL;
      }

      if (ShiftAmt < NumZeros) {
        // Handling this case would require extra logic that would make this
        // transformation non-profitable
        return NULL;
      }

      // One past the most significant bit kept by the mask.
      uint64_t FieldEnd = NumZeros + NumOnes;
      if (ShiftAmt >= FieldEnd) {
        // The shift discards every bit kept by the mask, so there is no
        // bitfield to extract (and the length below would underflow).
        return NULL;
      }

      // The number of bits in the result bitfield is the number of trailing
      // zeros plus the number of set bits in the mask minus the number of
      // bits we shift off
      uint64_t NumBits = FieldEnd - ShiftAmt;

      Val = AndLHS;
      Start = CurDAG->getTargetConstant(ShiftAmt, MVT::i32);
      Len = CurDAG->getTargetConstant(NumBits, MVT::i32);

      // If the mask keeps the sign bit, an arithmetic shift replicates that
      // bit into the upper part of the result, so the signed bfe variant is
      // required.  Otherwise the sign bit was cleared by the 'and' and the
      // unsigned variant gives the same zero-filled result.
      if (N->getOpcode() == ISD::SRA &&
          FieldEnd == Val.getValueType().getSizeInBits()) {
        IsSigned = true;
      }
    } else if (LHS->getOpcode() == ISD::SHL) {
      // Here, we have a pattern like:
      //
      // (sra (shl val, NN), MM)
      // or
      // (srl (shl val, NN), MM)
      //
      // If MM >= NN, we can efficiently optimize this with bfe
      Val = LHS->getOperand(0);

      SDValue ShlRHS = LHS->getOperand(1);
      ConstantSDNode *ShlCnst = dyn_cast<ConstantSDNode>(ShlRHS);
      if (!ShlCnst) {
        // Shift amount must be constant
        return NULL;
      }
      uint64_t InnerShiftAmt = ShlCnst->getZExtValue();

      SDValue ShrRHS = RHS;
      ConstantSDNode *ShrCnst = dyn_cast<ConstantSDNode>(ShrRHS);
      if (!ShrCnst) {
        // Shift amount must be constant
        return NULL;
      }
      uint64_t OuterShiftAmt = ShrCnst->getZExtValue();

      // To avoid extra codegen and be profitable, we need Outer >= Inner
      if (OuterShiftAmt < InnerShiftAmt) {
        return NULL;
      }

      // If the outer shift is more than the type size, we have no bitfield to
      // extract (since we also check that the inner shift is <= the outer shift
      // then this also implies that the inner shift is < the type size)
      if (OuterShiftAmt >= Val.getValueType().getSizeInBits()) {
        return NULL;
      }

      Start =
          CurDAG->getTargetConstant(OuterShiftAmt - InnerShiftAmt, MVT::i32);
      Len =
          CurDAG->getTargetConstant(Val.getValueType().getSizeInBits() -
                                    OuterShiftAmt, MVT::i32);

      if (N->getOpcode() == ISD::SRA) {
        // If we have a arithmetic right shift, we need to use the signed bfe
        // variant
        IsSigned = true;
      }
    } else {
      // No can do...
      return NULL;
    }
  } else {
    // No can do...
    return NULL;
  }

  unsigned Opc;
  // Pick the opcode from the width of the extracted value; the signed variant
  // is used only when the matched pattern has to replicate the sign bit.
  if (Val.getValueType() == MVT::i32) {
    if (IsSigned) {
      Opc = NVPTX::BFE_S32rii;
    } else {
      Opc = NVPTX::BFE_U32rii;
    }
  } else if (Val.getValueType() == MVT::i64) {
    if (IsSigned) {
      Opc = NVPTX::BFE_S64rii;
    } else {
      Opc = NVPTX::BFE_U64rii;
    }
  } else {
    // We cannot handle this type
    return NULL;
  }

  // Operand order matches the instruction definition: value, start, length.
  SDValue Ops[] = {
    Val, Start, Len
  };

  SDNode *Ret =
    CurDAG->getMachineNode(Opc, SDLoc(N), N->getVTList(), Ops);

  return Ret;
}
// SelectDirectAddr - Match a direct address for DAG.
// A direct address could be a globaladdress or externalsymbol.
bool NVPTXDAGToDAGISel::SelectDirectAddr(SDValue N, SDValue &Address) {

View File

@ -71,6 +71,7 @@ private:
SDNode *SelectAddrSpaceCast(SDNode *N);
SDNode *SelectTextureIntrinsic(SDNode *N);
SDNode *SelectSurfaceIntrinsic(SDNode *N);
/// Try to select a PTX 'bfe' (bit-field extract) instruction for the
/// and/srl/sra node \p N; returns NULL when no BFE pattern matches.
SDNode *SelectBFE(SDNode *N);
inline SDValue getI32Imm(unsigned Imm) {
return CurDAG->getTargetConstant(Imm, MVT::i32);

View File

@ -1179,6 +1179,29 @@ def ROTR64reg_sw : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src,
!strconcat("}}", ""))))))))),
[(set Int64Regs:$dst, (rotr Int64Regs:$src, Int32Regs:$amt))]>;
// BFE - bit-field extract
//
// Defines the operand forms of 'bfe.<TyStr>' for values held in register
// class RC.  $d is the result, $a the source value, and $b / $c the second and
// third source operands (SelectBFE passes the start bit and the field length
// there, in that order).  The suffix letters give the kind of $a, $b and $c:
// 'r' is a register operand, 'i' a 32-bit immediate.
//
// All pattern lists are empty, so nothing here is matched automatically.
multiclass BFE<string TyStr, RegisterClass RC> {
// BFE supports both 32-bit and 64-bit values, but the start and length
// operands are always 32-bit
// Register start, register length.
def rrr
: NVPTXInst<(outs RC:$d),
(ins RC:$a, Int32Regs:$b, Int32Regs:$c),
!strconcat("bfe.", TyStr, " \t$d, $a, $b, $c;"), []>;
// Register start, immediate length.
def rri
: NVPTXInst<(outs RC:$d),
(ins RC:$a, Int32Regs:$b, i32imm:$c),
!strconcat("bfe.", TyStr, " \t$d, $a, $b, $c;"), []>;
// Immediate start, immediate length.
def rii
: NVPTXInst<(outs RC:$d),
(ins RC:$a, i32imm:$b, i32imm:$c),
!strconcat("bfe.", TyStr, " \t$d, $a, $b, $c;"), []>;
}
// Instantiate the signed (s32/s64) and unsigned (u32/u64) variants for the
// 32-bit and 64-bit register classes.  Each defm yields the rrr/rri/rii
// records of the multiclass, e.g. BFE_S32rii.
defm BFE_S32 : BFE<"s32", Int32Regs>;
defm BFE_U32 : BFE<"u32", Int32Regs>;
defm BFE_S64 : BFE<"s64", Int64Regs>;
defm BFE_U64 : BFE<"u64", Int64Regs>;
//-----------------------------------
// General Comparison

32
test/CodeGen/NVPTX/bfe.ll Normal file
View File

@ -0,0 +1,32 @@
; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
; An arithmetic shift right by 4 followed by a 4-bit mask (15) is expected to
; become a single unsigned bit-field extract with start 4 and length 4, with
; no separate shift or mask instruction left behind.
; CHECK: bfe0
define i32 @bfe0(i32 %a) {
; CHECK: bfe.u32 %r{{[0-9]+}}, %r{{[0-9]+}}, 4, 4
; CHECK-NOT: shr
; CHECK-NOT: and
%val0 = ashr i32 %a, 4
%val1 = and i32 %val0, 15
ret i32 %val1
}
; Same shape with an odd start: shift right by 3 and keep 3 bits (mask 7),
; expected as an unsigned bit-field extract with start 3 and length 3.
; CHECK: bfe1
define i32 @bfe1(i32 %a) {
; CHECK: bfe.u32 %r{{[0-9]+}}, %r{{[0-9]+}}, 3, 3
; CHECK-NOT: shr
; CHECK-NOT: and
%val0 = ashr i32 %a, 3
%val1 = and i32 %val0, 7
ret i32 %val1
}
; Start and length differ here: shift right by 5 and keep 3 bits (mask 7),
; expected as an unsigned bit-field extract with start 5 and length 3.
; CHECK: bfe2
define i32 @bfe2(i32 %a) {
; CHECK: bfe.u32 %r{{[0-9]+}}, %r{{[0-9]+}}, 5, 3
; CHECK-NOT: shr
; CHECK-NOT: and
%val0 = ashr i32 %a, 5
%val1 = and i32 %val0, 7
ret i32 %val1
}