[AMDGPU] Move V_FMA_MIX pattern matching into tablegen. NFC

The matching for V_FMA_MIX was partially implemented with a C++
matcher (for fmas with 32 bit results and 16 bit inputs) and partially
in tablegen (for fmas with 16 bit results). Move the C++ matcher logic
into tablegen to make this more consistent and so we can remove the
duplication between SDAG and GISel.

Differential Revision: https://reviews.llvm.org/D144612
This commit is contained in:
Justin Bogner 2023-02-22 17:45:58 -08:00
parent a628ca4925
commit c083c89744
7 changed files with 61 additions and 112 deletions

View File

@ -153,6 +153,10 @@ def gi_vop3_mad_mix_mods :
GIComplexOperandMatcher<s64, "selectVOP3PMadMixMods">,
GIComplexPatternEquiv<VOP3PMadMixMods>;
def gi_vop3_mad_mix_mods_ext :
GIComplexOperandMatcher<s64, "selectVOP3PMadMixModsExt">,
GIComplexPatternEquiv<VOP3PMadMixModsExt>;
// Separate load nodes are defined to glue m0 initialization in
// SelectionDAG. The GISel selector can just insert m0 initialization
// directly before selecting a glue-less load, so hide this

View File

@ -663,10 +663,6 @@ void AMDGPUDAGToDAGISel::Select(SDNode *N) {
case ISD::BRCOND:
SelectBRCOND(N);
return;
case ISD::FMAD:
case ISD::FMA:
SelectFMAD_FMA(N);
return;
case AMDGPUISD::CVT_PKRTZ_F16_F32:
case AMDGPUISD::CVT_PKNORM_I16_F32:
case AMDGPUISD::CVT_PKNORM_U16_F32:
@ -2283,52 +2279,6 @@ void AMDGPUDAGToDAGISel::SelectBRCOND(SDNode *N) {
VCC.getValue(0));
}
void AMDGPUDAGToDAGISel::SelectFMAD_FMA(SDNode *N) {
MVT VT = N->getSimpleValueType(0);
bool IsFMA = N->getOpcode() == ISD::FMA;
if (VT != MVT::f32 || (!Subtarget->hasMadMixInsts() &&
!Subtarget->hasFmaMixInsts()) ||
((IsFMA && Subtarget->hasMadMixInsts()) ||
(!IsFMA && Subtarget->hasFmaMixInsts()))) {
SelectCode(N);
return;
}
SDValue Src0 = N->getOperand(0);
SDValue Src1 = N->getOperand(1);
SDValue Src2 = N->getOperand(2);
unsigned Src0Mods, Src1Mods, Src2Mods;
// Avoid using v_mad_mix_f32/v_fma_mix_f32 unless there is actually an operand
// using the conversion from f16.
bool Sel0 = SelectVOP3PMadMixModsImpl(Src0, Src0, Src0Mods);
bool Sel1 = SelectVOP3PMadMixModsImpl(Src1, Src1, Src1Mods);
bool Sel2 = SelectVOP3PMadMixModsImpl(Src2, Src2, Src2Mods);
assert((IsFMA || !Mode.allFP32Denormals()) &&
"fmad selected with denormals enabled");
// TODO: We can select this with f32 denormals enabled if all the sources are
// converted from f16 (in which case fmad isn't legal).
if (Sel0 || Sel1 || Sel2) {
// For dummy operands.
SDValue Zero = CurDAG->getTargetConstant(0, SDLoc(), MVT::i32);
SDValue Ops[] = {
CurDAG->getTargetConstant(Src0Mods, SDLoc(), MVT::i32), Src0,
CurDAG->getTargetConstant(Src1Mods, SDLoc(), MVT::i32), Src1,
CurDAG->getTargetConstant(Src2Mods, SDLoc(), MVT::i32), Src2,
CurDAG->getTargetConstant(0, SDLoc(), MVT::i1),
Zero, Zero
};
CurDAG->SelectNodeTo(N,
IsFMA ? AMDGPU::V_FMA_MIX_F32 : AMDGPU::V_MAD_MIX_F32,
MVT::f32, Ops);
} else {
SelectCode(N);
}
}
void AMDGPUDAGToDAGISel::SelectDSAppendConsume(SDNode *N, unsigned IntrID) {
// The address is assumed to be uniform, so if it ends up in a VGPR, it will
// be copied to an SGPR with readfirstlane.
@ -2883,6 +2833,15 @@ bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixModsImpl(SDValue In, SDValue &Src,
return false;
}
bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixModsExt(SDValue In, SDValue &Src,
SDValue &SrcMods) const {
unsigned Mods = 0;
if (!SelectVOP3PMadMixModsImpl(In, Src, Mods))
return false;
SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);
return true;
}
bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixMods(SDValue In, SDValue &Src,
SDValue &SrcMods) const {
unsigned Mods = 0;

View File

@ -248,6 +248,8 @@ private:
bool SelectVOP3OpSelMods(SDValue In, SDValue &Src, SDValue &SrcMods) const;
bool SelectVOP3PMadMixModsImpl(SDValue In, SDValue &Src,
unsigned &Mods) const;
bool SelectVOP3PMadMixModsExt(SDValue In, SDValue &Src,
SDValue &SrcMods) const;
bool SelectVOP3PMadMixMods(SDValue In, SDValue &Src, SDValue &SrcMods) const;
SDValue getHi16Elt(SDValue In) const;

View File

@ -523,60 +523,6 @@ bool AMDGPUInstructionSelector::selectG_EXTRACT(MachineInstr &I) const {
return true;
}
bool AMDGPUInstructionSelector::selectG_FMA_FMAD(MachineInstr &I) const {
assert(I.getOpcode() == AMDGPU::G_FMA || I.getOpcode() == AMDGPU::G_FMAD);
// Try to manually select MAD_MIX/FMA_MIX.
Register Dst = I.getOperand(0).getReg();
LLT ResultTy = MRI->getType(Dst);
bool IsFMA = I.getOpcode() == AMDGPU::G_FMA;
if (ResultTy != LLT::scalar(32) ||
(IsFMA ? !Subtarget->hasFmaMixInsts() : !Subtarget->hasMadMixInsts()))
return false;
// Avoid using v_mad_mix_f32/v_fma_mix_f32 unless there is actually an operand
// using the conversion from f16.
bool MatchedSrc0, MatchedSrc1, MatchedSrc2;
auto [Src0, Src0Mods] =
selectVOP3PMadMixModsImpl(I.getOperand(1), MatchedSrc0);
auto [Src1, Src1Mods] =
selectVOP3PMadMixModsImpl(I.getOperand(2), MatchedSrc1);
auto [Src2, Src2Mods] =
selectVOP3PMadMixModsImpl(I.getOperand(3), MatchedSrc2);
#ifndef NDEBUG
const SIMachineFunctionInfo *MFI =
I.getMF()->getInfo<SIMachineFunctionInfo>();
SIModeRegisterDefaults Mode = MFI->getMode();
assert((IsFMA || !Mode.allFP32Denormals()) &&
"fmad selected with denormals enabled");
#endif
// TODO: We can select this with f32 denormals enabled if all the sources are
// converted from f16 (in which case fmad isn't legal).
if (!MatchedSrc0 && !MatchedSrc1 && !MatchedSrc2)
return false;
const unsigned OpC = IsFMA ? AMDGPU::V_FMA_MIX_F32 : AMDGPU::V_MAD_MIX_F32;
MachineInstr *MixInst =
BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(OpC), Dst)
.addImm(Src0Mods)
.addReg(copyToVGPRIfSrcFolded(Src0, Src0Mods, I.getOperand(1), &I))
.addImm(Src1Mods)
.addReg(copyToVGPRIfSrcFolded(Src1, Src1Mods, I.getOperand(2), &I))
.addImm(Src2Mods)
.addReg(copyToVGPRIfSrcFolded(Src2, Src2Mods, I.getOperand(3), &I))
.addImm(0)
.addImm(0)
.addImm(0);
if (!constrainSelectedInstRegOperands(*MixInst, TII, TRI, RBI))
return false;
I.eraseFromParent();
return true;
}
bool AMDGPUInstructionSelector::selectG_MERGE_VALUES(MachineInstr &MI) const {
MachineBasicBlock *BB = MI.getParent();
Register DstReg = MI.getOperand(0).getReg();
@ -3405,11 +3351,6 @@ bool AMDGPUInstructionSelector::select(MachineInstr &I) {
return selectG_FABS(I);
case TargetOpcode::G_EXTRACT:
return selectG_EXTRACT(I);
case TargetOpcode::G_FMA:
case TargetOpcode::G_FMAD:
if (selectG_FMA_FMAD(I))
return true;
return selectImpl(I, *CoverageInfo);
case TargetOpcode::G_MERGE_VALUES:
case TargetOpcode::G_CONCAT_VECTORS:
return selectG_MERGE_VALUES(I);
@ -4987,6 +4928,22 @@ AMDGPUInstructionSelector::selectVOP3PMadMixModsImpl(MachineOperand &Root,
return {Src, Mods};
}
InstructionSelector::ComplexRendererFns
AMDGPUInstructionSelector::selectVOP3PMadMixModsExt(
MachineOperand &Root) const {
Register Src;
unsigned Mods;
bool Matched;
std::tie(Src, Mods) = selectVOP3PMadMixModsImpl(Root, Matched);
if (!Matched)
return {};
return {{
[=](MachineInstrBuilder &MIB) { MIB.addReg(Src); },
[=](MachineInstrBuilder &MIB) { MIB.addImm(Mods); } // src_mods
}};
}
InstructionSelector::ComplexRendererFns
AMDGPUInstructionSelector::selectVOP3PMadMixMods(MachineOperand &Root) const {
Register Src;

View File

@ -297,6 +297,7 @@ private:
std::pair<Register, unsigned> selectVOP3PMadMixModsImpl(MachineOperand &Root,
bool &Matched) const;
ComplexRendererFns selectVOP3PMadMixModsExt(MachineOperand &Root) const;
ComplexRendererFns selectVOP3PMadMixMods(MachineOperand &Root) const;
void renderTruncImm32(MachineInstrBuilder &MIB, const MachineInstr &MI,

View File

@ -1511,7 +1511,8 @@ def VOP3OpSel : ComplexPattern<untyped, 2, "SelectVOP3OpSel">;
def VOP3OpSelMods : ComplexPattern<untyped, 2, "SelectVOP3OpSelMods">;
def VOP3PMadMixMods : ComplexPattern<untyped, 2, "SelectVOP3PMadMixMods">;
def VOP3PMadMixModsExt : ComplexPattern<untyped, 2, "SelectVOP3PMadMixModsExt">;
def VOP3PMadMixMods : ComplexPattern<untyped, 2, "SelectVOP3PMadMixMods">;
def VINTERPMods : ComplexPattern<untyped, 2, "SelectVINTERPMods">;
def VINTERPModsHi : ComplexPattern<untyped, 2, "SelectVINTERPModsHi">;

View File

@ -142,9 +142,34 @@ def : VOP3PSatPat<usubsat, V_PK_SUB_U16>;
def : VOP3PSatPat<ssubsat, V_PK_SUB_I16>;
} // End SubtargetPredicate = HasVOP3PInsts
// TODO: Make sure we're doing the right thing with denormals. Note
// that FMA and MAD will differ.
multiclass MadFmaMixPats<SDPatternOperator fma_like,
Instruction mix_inst,
Instruction mixlo_inst,
Instruction mixhi_inst> {
// At least one of the operands needs to be an fpextend of an f16
// for this to be worthwhile, so we need three patterns here.
// TODO: Could we use a predicate to inspect src1/2/3 instead?
def : GCNPat <
(f32 (fma_like (f32 (VOP3PMadMixModsExt f16:$src0, i32:$src0_mods)),
(f32 (VOP3PMadMixMods f16:$src1, i32:$src1_mods)),
(f32 (VOP3PMadMixMods f16:$src2, i32:$src2_mods)))),
(mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
DSTCLAMP.NONE)>;
def : GCNPat <
(f32 (fma_like (f32 (VOP3PMadMixMods f16:$src0, i32:$src0_mods)),
(f32 (VOP3PMadMixModsExt f16:$src1, i32:$src1_mods)),
(f32 (VOP3PMadMixMods f32:$src2, i32:$src2_mods)))),
(mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
DSTCLAMP.NONE)>;
def : GCNPat <
(f32 (fma_like (f32 (VOP3PMadMixMods f16:$src0, i32:$src0_mods)),
(f32 (VOP3PMadMixMods f32:$src1, i32:$src1_mods)),
(f32 (VOP3PMadMixModsExt f16:$src2, i32:$src2_mods)))),
(mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
DSTCLAMP.NONE)>;
def : GCNPat <
(f16 (fpround (fma_like (f32 (VOP3PMadMixMods f16:$src0, i32:$src0_modifiers)),
(f32 (VOP3PMadMixMods f16:$src1, i32:$src1_modifiers)),
@ -222,7 +247,7 @@ defm V_MAD_MIXHI_F16 : VOP3_VOP3PInst<"v_mad_mixhi_f16", VOP3P_Mix_Profile<VOP_F
} // End FPDPRounding = 1
}
defm : MadFmaMixPats<fmad, V_MAD_MIXLO_F16, V_MAD_MIXHI_F16>;
defm : MadFmaMixPats<fmad, V_MAD_MIX_F32, V_MAD_MIXLO_F16, V_MAD_MIXHI_F16>;
} // End SubtargetPredicate = HasMadMixInsts
@ -243,7 +268,7 @@ defm V_FMA_MIXHI_F16 : VOP3_VOP3PInst<"v_fma_mixhi_f16", VOP3P_Mix_Profile<VOP_F
} // End FPDPRounding = 1
}
defm : MadFmaMixPats<fma, V_FMA_MIXLO_F16, V_FMA_MIXHI_F16>;
defm : MadFmaMixPats<fma, V_FMA_MIX_F32, V_FMA_MIXLO_F16, V_FMA_MIXHI_F16>;
}
// Defines patterns that extract signed 4bit from each Idx[0].