diff --git a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td index 9ef59a4e9c49..0f3e3c0c2c7e 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUGISel.td +++ b/llvm/lib/Target/AMDGPU/AMDGPUGISel.td @@ -153,6 +153,10 @@ def gi_vop3_mad_mix_mods : GIComplexOperandMatcher, GIComplexPatternEquiv; +def gi_vop3_mad_mix_mods_ext : + GIComplexOperandMatcher, + GIComplexPatternEquiv; + // Separate load nodes are defined to glue m0 initialization in // SelectionDAG. The GISel selector can just insert m0 initialization // directly before selecting a glue-less load, so hide this diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp index 382eeeb9bde1..28c26b2998f0 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp @@ -663,10 +663,6 @@ void AMDGPUDAGToDAGISel::Select(SDNode *N) { case ISD::BRCOND: SelectBRCOND(N); return; - case ISD::FMAD: - case ISD::FMA: - SelectFMAD_FMA(N); - return; case AMDGPUISD::CVT_PKRTZ_F16_F32: case AMDGPUISD::CVT_PKNORM_I16_F32: case AMDGPUISD::CVT_PKNORM_U16_F32: @@ -2283,52 +2279,6 @@ void AMDGPUDAGToDAGISel::SelectBRCOND(SDNode *N) { VCC.getValue(0)); } -void AMDGPUDAGToDAGISel::SelectFMAD_FMA(SDNode *N) { - MVT VT = N->getSimpleValueType(0); - bool IsFMA = N->getOpcode() == ISD::FMA; - if (VT != MVT::f32 || (!Subtarget->hasMadMixInsts() && - !Subtarget->hasFmaMixInsts()) || - ((IsFMA && Subtarget->hasMadMixInsts()) || - (!IsFMA && Subtarget->hasFmaMixInsts()))) { - SelectCode(N); - return; - } - - SDValue Src0 = N->getOperand(0); - SDValue Src1 = N->getOperand(1); - SDValue Src2 = N->getOperand(2); - unsigned Src0Mods, Src1Mods, Src2Mods; - - // Avoid using v_mad_mix_f32/v_fma_mix_f32 unless there is actually an operand - // using the conversion from f16. - bool Sel0 = SelectVOP3PMadMixModsImpl(Src0, Src0, Src0Mods); - bool Sel1 = SelectVOP3PMadMixModsImpl(Src1, Src1, Src1Mods); - bool Sel2 = SelectVOP3PMadMixModsImpl(Src2, Src2, Src2Mods); - - assert((IsFMA || !Mode.allFP32Denormals()) && - "fmad selected with denormals enabled"); - // TODO: We can select this with f32 denormals enabled if all the sources are - // converted from f16 (in which case fmad isn't legal). - - if (Sel0 || Sel1 || Sel2) { - // For dummy operands. - SDValue Zero = CurDAG->getTargetConstant(0, SDLoc(), MVT::i32); - SDValue Ops[] = { - CurDAG->getTargetConstant(Src0Mods, SDLoc(), MVT::i32), Src0, - CurDAG->getTargetConstant(Src1Mods, SDLoc(), MVT::i32), Src1, - CurDAG->getTargetConstant(Src2Mods, SDLoc(), MVT::i32), Src2, - CurDAG->getTargetConstant(0, SDLoc(), MVT::i1), - Zero, Zero - }; - - CurDAG->SelectNodeTo(N, - IsFMA ? AMDGPU::V_FMA_MIX_F32 : AMDGPU::V_MAD_MIX_F32, - MVT::f32, Ops); - } else { - SelectCode(N); - } -} - void AMDGPUDAGToDAGISel::SelectDSAppendConsume(SDNode *N, unsigned IntrID) { // The address is assumed to be uniform, so if it ends up in a VGPR, it will // be copied to an SGPR with readfirstlane. @@ -2883,6 +2833,15 @@ bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixModsImpl(SDValue In, SDValue &Src, return false; } +bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixModsExt(SDValue In, SDValue &Src, + SDValue &SrcMods) const { + unsigned Mods = 0; + if (!SelectVOP3PMadMixModsImpl(In, Src, Mods)) + return false; + SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32); + return true; +} + bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixMods(SDValue In, SDValue &Src, SDValue &SrcMods) const { unsigned Mods = 0; diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h index 8c4e378e72e2..12912b77edaf 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h +++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h @@ -248,6 +248,8 @@ private: bool SelectVOP3OpSelMods(SDValue In, SDValue &Src, SDValue &SrcMods) const; bool SelectVOP3PMadMixModsImpl(SDValue In, SDValue &Src, unsigned &Mods) const; + bool SelectVOP3PMadMixModsExt(SDValue In, SDValue &Src, + SDValue &SrcMods) const; bool SelectVOP3PMadMixMods(SDValue In, SDValue &Src, SDValue &SrcMods) const; SDValue getHi16Elt(SDValue In) const; diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp index 17490707e3d2..7d3536df7a0a 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp @@ -523,60 +523,6 @@ bool AMDGPUInstructionSelector::selectG_EXTRACT(MachineInstr &I) const { return true; } -bool AMDGPUInstructionSelector::selectG_FMA_FMAD(MachineInstr &I) const { - assert(I.getOpcode() == AMDGPU::G_FMA || I.getOpcode() == AMDGPU::G_FMAD); - - // Try to manually select MAD_MIX/FMA_MIX. - Register Dst = I.getOperand(0).getReg(); - LLT ResultTy = MRI->getType(Dst); - bool IsFMA = I.getOpcode() == AMDGPU::G_FMA; - if (ResultTy != LLT::scalar(32) || - (IsFMA ? !Subtarget->hasFmaMixInsts() : !Subtarget->hasMadMixInsts())) - return false; - - // Avoid using v_mad_mix_f32/v_fma_mix_f32 unless there is actually an operand - // using the conversion from f16. - bool MatchedSrc0, MatchedSrc1, MatchedSrc2; - auto [Src0, Src0Mods] = - selectVOP3PMadMixModsImpl(I.getOperand(1), MatchedSrc0); - auto [Src1, Src1Mods] = - selectVOP3PMadMixModsImpl(I.getOperand(2), MatchedSrc1); - auto [Src2, Src2Mods] = - selectVOP3PMadMixModsImpl(I.getOperand(3), MatchedSrc2); - -#ifndef NDEBUG - const SIMachineFunctionInfo *MFI = - I.getMF()->getInfo(); - SIModeRegisterDefaults Mode = MFI->getMode(); - assert((IsFMA || !Mode.allFP32Denormals()) && - "fmad selected with denormals enabled"); -#endif - - // TODO: We can select this with f32 denormals enabled if all the sources are - // converted from f16 (in which case fmad isn't legal). - if (!MatchedSrc0 && !MatchedSrc1 && !MatchedSrc2) - return false; - - const unsigned OpC = IsFMA ? AMDGPU::V_FMA_MIX_F32 : AMDGPU::V_MAD_MIX_F32; - MachineInstr *MixInst = - BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(OpC), Dst) - .addImm(Src0Mods) - .addReg(copyToVGPRIfSrcFolded(Src0, Src0Mods, I.getOperand(1), &I)) - .addImm(Src1Mods) - .addReg(copyToVGPRIfSrcFolded(Src1, Src1Mods, I.getOperand(2), &I)) - .addImm(Src2Mods) - .addReg(copyToVGPRIfSrcFolded(Src2, Src2Mods, I.getOperand(3), &I)) - .addImm(0) - .addImm(0) - .addImm(0); - - if (!constrainSelectedInstRegOperands(*MixInst, TII, TRI, RBI)) - return false; - - I.eraseFromParent(); - return true; -} - bool AMDGPUInstructionSelector::selectG_MERGE_VALUES(MachineInstr &MI) const { MachineBasicBlock *BB = MI.getParent(); Register DstReg = MI.getOperand(0).getReg(); @@ -3405,11 +3351,6 @@ bool AMDGPUInstructionSelector::select(MachineInstr &I) { return selectG_FABS(I); case TargetOpcode::G_EXTRACT: return selectG_EXTRACT(I); - case TargetOpcode::G_FMA: - case TargetOpcode::G_FMAD: - if (selectG_FMA_FMAD(I)) - return true; - return selectImpl(I, *CoverageInfo); case TargetOpcode::G_MERGE_VALUES: case TargetOpcode::G_CONCAT_VECTORS: return selectG_MERGE_VALUES(I); @@ -4987,6 +4928,22 @@ AMDGPUInstructionSelector::selectVOP3PMadMixModsImpl(MachineOperand &Root, return {Src, Mods}; } +InstructionSelector::ComplexRendererFns +AMDGPUInstructionSelector::selectVOP3PMadMixModsExt( + MachineOperand &Root) const { + Register Src; + unsigned Mods; + bool Matched; + std::tie(Src, Mods) = selectVOP3PMadMixModsImpl(Root, Matched); + if (!Matched) + return {}; + + return {{ + [=](MachineInstrBuilder &MIB) { MIB.addReg(Src); }, + [=](MachineInstrBuilder &MIB) { MIB.addImm(Mods); } // src_mods + }}; +} + InstructionSelector::ComplexRendererFns AMDGPUInstructionSelector::selectVOP3PMadMixMods(MachineOperand &Root) const { Register Src; diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h index 99af9ddb048d..0ccf02ba41cf 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h +++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h @@ -297,6 +297,7 @@ private: std::pair selectVOP3PMadMixModsImpl(MachineOperand &Root, bool &Matched) const; + ComplexRendererFns selectVOP3PMadMixModsExt(MachineOperand &Root) const; ComplexRendererFns selectVOP3PMadMixMods(MachineOperand &Root) const; void renderTruncImm32(MachineInstrBuilder &MIB, const MachineInstr &MI, diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td index 8253641957a8..e0fea7d300ad 100644 --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td @@ -1511,7 +1511,8 @@ def VOP3OpSel : ComplexPattern; def VOP3OpSelMods : ComplexPattern; -def VOP3PMadMixMods : ComplexPattern; +def VOP3PMadMixModsExt : ComplexPattern; +def VOP3PMadMixMods : ComplexPattern; def VINTERPMods : ComplexPattern; def VINTERPModsHi : ComplexPattern; diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td index 2c7888e58761..8f8c4489454d 100644 --- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td +++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td @@ -142,9 +142,34 @@ def : VOP3PSatPat; def : VOP3PSatPat; } // End SubtargetPredicate = HasVOP3PInsts +// TODO: Make sure we're doing the right thing with denormals. Note +// that FMA and MAD will differ. multiclass MadFmaMixPats { + // At least one of the operands needs to be an fpextend of an f16 + // for this to be worthwhile, so we need three patterns here. + // TODO: Could we use a predicate to inspect src1/2/3 instead? + def : GCNPat < + (f32 (fma_like (f32 (VOP3PMadMixModsExt f16:$src0, i32:$src0_mods)), + (f32 (VOP3PMadMixMods f16:$src1, i32:$src1_mods)), + (f32 (VOP3PMadMixMods f16:$src2, i32:$src2_mods)))), + (mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2, + DSTCLAMP.NONE)>; + def : GCNPat < + (f32 (fma_like (f32 (VOP3PMadMixMods f16:$src0, i32:$src0_mods)), + (f32 (VOP3PMadMixModsExt f16:$src1, i32:$src1_mods)), + (f32 (VOP3PMadMixMods f32:$src2, i32:$src2_mods)))), + (mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2, + DSTCLAMP.NONE)>; + def : GCNPat < + (f32 (fma_like (f32 (VOP3PMadMixMods f16:$src0, i32:$src0_mods)), + (f32 (VOP3PMadMixMods f32:$src1, i32:$src1_mods)), + (f32 (VOP3PMadMixModsExt f16:$src2, i32:$src2_mods)))), + (mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2, + DSTCLAMP.NONE)>; + def : GCNPat < (f16 (fpround (fma_like (f32 (VOP3PMadMixMods f16:$src0, i32:$src0_modifiers)), (f32 (VOP3PMadMixMods f16:$src1, i32:$src1_modifiers)), @@ -222,7 +247,7 @@ defm V_MAD_MIXHI_F16 : VOP3_VOP3PInst<"v_mad_mixhi_f16", VOP3P_Mix_Profile; +defm : MadFmaMixPats; } // End SubtargetPredicate = HasMadMixInsts @@ -243,7 +268,7 @@ defm V_FMA_MIXHI_F16 : VOP3_VOP3PInst<"v_fma_mixhi_f16", VOP3P_Mix_Profile; +defm : MadFmaMixPats; } // Defines patterns that extract signed 4bit from each Idx[0].