[RISCV][NFC] Make Reduction scheduler resources SEW aware

Create SchedWrites, WriteRes for reduction instructions that
are SEW specific. Future patches can use these resources
to customize the behavior of these resources depending on SEW.

Differential Revision: https://reviews.llvm.org/D151470
This commit is contained in:
Michael Maitland 2023-05-25 10:09:37 -07:00
parent 6042a1ac18
commit d70573b18e
3 changed files with 119 additions and 84 deletions

View File

@ -115,8 +115,14 @@ defvar MxListF = [V_MF4, V_MF2, V_M1, V_M2, V_M4, V_M8];
// Used for widening and narrowing instructions as it doesn't contain M8.
defvar MxListW = [V_MF8, V_MF4, V_MF2, V_M1, V_M2, V_M4];
// Used for widening reductions. It can contain M8 because wider operands are
// scalar operands.
defvar MxListWRed = MxList;
// For floating point which don't need MF8.
defvar MxListFW = [V_MF4, V_MF2, V_M1, V_M2, V_M4];
// For widening floating-point Reduction as it doesn't contain MF8. It can
// contain M8 because wider operands are scalar operands.
defvar MxListFWRed = [V_MF4, V_MF2, V_M1, V_M2, V_M4, V_M8];
// Use for zext/sext.vf2
defvar MxListVF2 = [V_MF4, V_MF2, V_M1, V_M2, V_M4, V_M8];
@ -3180,16 +3186,14 @@ multiclass VPseudoTernaryWithTailPolicy_E<VReg RetClass,
RegisterClass Op1Class,
DAGOperand Op2Class,
LMULInfo MInfo,
int sew,
string Constraint = "",
bit Commutable = 0> {
let VLMul = MInfo.value in {
defvar mx = MInfo.MX;
defvar sews = SchedSEWSet<mx>.val;
foreach e = sews in {
let isCommutable = Commutable in
def "_" # mx # "_E" # e : VPseudoTernaryNoMaskWithPolicy<RetClass, Op1Class, Op2Class, Constraint>;
def "_" # mx # "_E" # e # "_MASK" : VPseudoBinaryTailPolicy<RetClass, Op1Class, Op2Class, Constraint>;
}
def "_" # mx # "_E" # sew : VPseudoTernaryNoMaskWithPolicy<RetClass, Op1Class, Op2Class, Constraint>;
def "_" # mx # "_E" # sew # "_MASK" : VPseudoBinaryTailPolicy<RetClass, Op1Class, Op2Class, Constraint>;
}
}
@ -3448,50 +3452,60 @@ multiclass VPseudoVCMPM_VX_VI {
multiclass VPseudoVRED_VS {
foreach m = MxList in {
defvar mx = m.MX;
defvar WriteVIRedV_From_MX = !cast<SchedWrite>("WriteVIRedV_From_" # mx);
defm _VS : VPseudoTernaryWithTailPolicy_E<V_M1.vrclass, m.vrclass, V_M1.vrclass, m>,
Sched<[WriteVIRedV_From_MX, ReadVIRedV, ReadVIRedV, ReadVIRedV,
ReadVMask]>;
foreach e = SchedSEWSet<mx>.val in {
defvar WriteVIRedV_From_MX_E = !cast<SchedWrite>("WriteVIRedV_From_" # mx # "_E" # e);
defm _VS : VPseudoTernaryWithTailPolicy_E<V_M1.vrclass, m.vrclass, V_M1.vrclass, m, e>,
Sched<[WriteVIRedV_From_MX_E, ReadVIRedV, ReadVIRedV, ReadVIRedV,
ReadVMask]>;
}
}
}
multiclass VPseudoVWRED_VS {
foreach m = MxList in {
foreach m = MxListWRed in {
defvar mx = m.MX;
defvar WriteVIWRedV_From_MX = !cast<SchedWrite>("WriteVIWRedV_From_" # mx);
defm _VS : VPseudoTernaryWithTailPolicy_E<V_M1.vrclass, m.vrclass, V_M1.vrclass, m>,
Sched<[WriteVIWRedV_From_MX, ReadVIWRedV, ReadVIWRedV,
ReadVIWRedV, ReadVMask]>;
foreach e = SchedSEWSet<mx, 1>.val in {
defvar WriteVIWRedV_From_MX_E = !cast<SchedWrite>("WriteVIWRedV_From_" # mx # "_E" # e);
defm _VS : VPseudoTernaryWithTailPolicy_E<V_M1.vrclass, m.vrclass, V_M1.vrclass, m, e>,
Sched<[WriteVIWRedV_From_MX_E, ReadVIWRedV, ReadVIWRedV,
ReadVIWRedV, ReadVMask]>;
}
}
}
multiclass VPseudoVFRED_VS {
foreach m = MxListF in {
defvar mx = m.MX;
defvar WriteVFRedV_From_MX = !cast<SchedWrite>("WriteVFRedV_From_" # mx);
defm _VS : VPseudoTernaryWithTailPolicy_E<V_M1.vrclass, m.vrclass, V_M1.vrclass, m>,
Sched<[WriteVFRedV_From_MX, ReadVFRedV, ReadVFRedV, ReadVFRedV,
ReadVMask]>;
foreach e = SchedSEWSetF<mx>.val in {
defvar WriteVFRedV_From_MX_E = !cast<SchedWrite>("WriteVFRedV_From_" # mx # "_E" # e);
defm _VS : VPseudoTernaryWithTailPolicy_E<V_M1.vrclass, m.vrclass, V_M1.vrclass, m, e>,
Sched<[WriteVFRedV_From_MX_E, ReadVFRedV, ReadVFRedV, ReadVFRedV,
ReadVMask]>;
}
}
}
multiclass VPseudoVFREDO_VS {
foreach m = MxListF in {
defvar mx = m.MX;
defvar WriteVFRedOV_From_MX = !cast<SchedWrite>("WriteVFRedOV_From_" # mx);
defm _VS : VPseudoTernaryWithTailPolicy_E<V_M1.vrclass, m.vrclass, V_M1.vrclass, m>,
Sched<[WriteVFRedOV_From_MX, ReadVFRedOV, ReadVFRedOV,
ReadVFRedOV, ReadVMask]>;
foreach e = SchedSEWSetF<mx>.val in {
defvar WriteVFRedOV_From_MX_E = !cast<SchedWrite>("WriteVFRedOV_From_" # mx # "_E" # e);
defm _VS : VPseudoTernaryWithTailPolicy_E<V_M1.vrclass, m.vrclass, V_M1.vrclass, m, e>,
Sched<[WriteVFRedOV_From_MX_E, ReadVFRedOV, ReadVFRedOV,
ReadVFRedOV, ReadVMask]>;
}
}
}
multiclass VPseudoVFWRED_VS {
foreach m = MxListF in {
foreach m = MxListFWRed in {
defvar mx = m.MX;
defvar WriteVFWRedV_From_MX = !cast<SchedWrite>("WriteVFWRedV_From_" # mx);
defm _VS : VPseudoTernaryWithTailPolicy_E<V_M1.vrclass, m.vrclass, V_M1.vrclass, m>,
Sched<[WriteVFWRedV_From_MX, ReadVFWRedV, ReadVFWRedV,
ReadVFWRedV, ReadVMask]>;
foreach e = SchedSEWSetF<mx, 1>.val in {
defvar WriteVFWRedV_From_MX_E = !cast<SchedWrite>("WriteVFWRedV_From_" # mx # "_E" # e);
defm _VS : VPseudoTernaryWithTailPolicy_E<V_M1.vrclass, m.vrclass, V_M1.vrclass, m, e>,
Sched<[WriteVFWRedV_From_MX_E, ReadVFWRedV, ReadVFWRedV,
ReadVFWRedV, ReadVMask]>;
}
}
}

View File

@ -620,12 +620,12 @@ foreach mx = SchedMxListFW in {
// 14. Vector Reduction Operations
let Latency = 32 in {
defm "" : LMULWriteRes<"WriteVIRedV_From", [SiFive7VA]>;
defm "" : LMULWriteRes<"WriteVIWRedV_From", [SiFive7VA]>;
defm "" : LMULWriteRes<"WriteVFRedV_From", [SiFive7VA]>;
defm "" : LMULWriteRes<"WriteVFRedOV_From", [SiFive7VA]>;
defm "" : LMULWriteResFWRed<"WriteVFWRedV_From", [SiFive7VA]>;
defm "" : LMULWriteResFWRed<"WriteVFWRedOV_From", [SiFive7VA]>;
defm "" : LMULSEWWriteRes<"WriteVIRedV_From", [SiFive7VA]>;
defm "" : LMULSEWWriteRes<"WriteVIWRedV_From", [SiFive7VA]>;
defm "" : LMULSEWWriteRes<"WriteVFRedV_From", [SiFive7VA]>;
defm "" : LMULSEWWriteRes<"WriteVFRedOV_From", [SiFive7VA]>;
defm "" : LMULSEWWriteResFWRed<"WriteVFWRedV_From", [SiFive7VA]>;
defm "" : LMULSEWWriteResFWRed<"WriteVFWRedOV_From", [SiFive7VA]>;
}
// 15. Vector Mask Instructions

View File

@ -12,30 +12,35 @@
defvar SchedMxList = ["MF8", "MF4", "MF2", "M1", "M2", "M4", "M8"];
// Used for widening and narrowing instructions as it doesn't contain M8.
defvar SchedMxListW = !listremove(SchedMxList, ["M8"]);
// Used for widening reductions, which does contain M8.
defvar SchedMxListWRed = SchedMxList;
defvar SchedMxListFW = !listremove(SchedMxList, ["M8", "MF8"]);
// Used for floating-point as it doesn't contain MF8.
defvar SchedMxListF = !listremove(SchedMxList, ["MF8"]);
// Used for widening floating-point Reduction as it doesn't contain MF8.
defvar SchedMxListFWRed = SchedMxListF;
class SchedSEWSet<string mx> {
list<int> val = !cond(!eq(mx, "M1"): [8, 16, 32, 64],
!eq(mx, "M2"): [8, 16, 32, 64],
!eq(mx, "M4"): [8, 16, 32, 64],
!eq(mx, "M8"): [8, 16, 32, 64],
!eq(mx, "MF2"): [8, 16, 32],
!eq(mx, "MF4"): [8, 16],
!eq(mx, "MF8"): [8]);
// For widening instructions, SEW will not be 64.
class SchedSEWSet<string mx, bit isWidening = 0> {
defvar t = !cond(!eq(mx, "M1"): [8, 16, 32, 64],
!eq(mx, "M2"): [8, 16, 32, 64],
!eq(mx, "M4"): [8, 16, 32, 64],
!eq(mx, "M8"): [8, 16, 32, 64],
!eq(mx, "MF2"): [8, 16, 32],
!eq(mx, "MF4"): [8, 16],
!eq(mx, "MF8"): [8]);
list<int> val = !if(isWidening, !listremove(t, [64]), t);
}
// For floating-point instructions, SEW won't be 8.
class SchedSEWSetF<string mx> {
list<int> val = !cond(!eq(mx, "M1"): [16, 32, 64],
!eq(mx, "M2"): [16, 32, 64],
!eq(mx, "M4"): [16, 32, 64],
!eq(mx, "M8"): [16, 32, 64],
!eq(mx, "MF2"): [16, 32],
!eq(mx, "MF4"): [16]);
class SchedSEWSetF<string mx, bit isWidening = 0> {
defvar t = !cond(!eq(mx, "M1"): [16, 32, 64],
!eq(mx, "M2"): [16, 32, 64],
!eq(mx, "M4"): [16, 32, 64],
!eq(mx, "M8"): [16, 32, 64],
!eq(mx, "MF2"): [16, 32],
!eq(mx, "MF4"): [16]);
list<int> val = !if(isWidening, !listremove(t, [64]), t);
}
// Helper function to get the largest LMUL from MxList
@ -102,34 +107,46 @@ multiclass LMULReadAdvanceImpl<string name, int val,
// ReadAdvance for each (name, LMUL, SEW) tuple for each LMUL in each of the
// SchedMxList variants above. Each multiclass is responsible for defining
// a record that represents the WorseCase behavior for name.
multiclass LMULSEWSchedWritesImpl<string name, list<string> MxList, bit isF = 0> {
multiclass LMULSEWSchedWritesImpl<string name, list<string> MxList, bit isF = 0,
bit isWidening = 0> {
def name # "_WorstCase" : SchedWrite;
foreach mx = MxList in {
foreach sew = !if(isF, SchedSEWSetF<mx>.val, SchedSEWSet<mx>.val) in
foreach sew = !if(isF, SchedSEWSetF<mx, isWidening>.val,
SchedSEWSet<mx, isWidening>.val) in
def name # "_" # mx # "_E" # sew : SchedWrite;
}
}
multiclass LMULSEWSchedReadsImpl<string name, list<string> MxList, bit isF = 0> {
multiclass LMULSEWSchedReadsImpl<string name, list<string> MxList, bit isF = 0,
bit isWidening = 0> {
def name # "_WorstCase" : SchedRead;
foreach mx = MxList in {
foreach sew = !if(isF, SchedSEWSetF<mx>.val, SchedSEWSet<mx>.val) in
foreach sew = !if(isF,SchedSEWSetF<mx, isWidening>.val,
SchedSEWSet<mx, isWidening>.val) in
def name # "_" # mx # "_E" # sew : SchedRead;
}
}
multiclass LMULSEWWriteResImpl<string name, list<ProcResourceKind> resources,
bit isF = 0> {
def : WriteRes<!cast<SchedWrite>(name # "_WorstCase"), resources>;
foreach mx = !if(isF, SchedMxListF, SchedMxList) in {
foreach sew = !if(isF, SchedSEWSetF<mx>.val, SchedSEWSet<mx>.val) in
def : WriteRes<!cast<SchedWrite>(name # "_" # mx # "_E" # sew), resources>;
list<string> MxList, bit isF = 0,
bit isWidening = 0> {
if !exists<SchedWrite>(name # "_WorstCase") then
def : WriteRes<!cast<SchedWrite>(name # "_WorstCase"), resources>;
foreach mx = MxList in {
foreach sew = !if(isF,SchedSEWSetF<mx, isWidening>.val,
SchedSEWSet<mx, isWidening>.val) in
if !exists<SchedWrite>(name # "_" # mx # "_E" # sew) then
def : WriteRes<!cast<SchedWrite>(name # "_" # mx # "_E" # sew), resources>;
}
}
multiclass LMULSEWReadAdvanceImpl<string name, int val, list<SchedWrite> writes = [],
bit isF = 0> {
def : ReadAdvance<!cast<SchedRead>(name # "_WorstCase"), val, writes>;
foreach mx = !if(isF, SchedMxListF, SchedMxList) in {
foreach sew = !if(isF, SchedSEWSetF<mx>.val, SchedSEWSet<mx>.val) in
def : ReadAdvance<!cast<SchedRead>(name # "_" # mx # "_E" # sew), val, writes>;
list<string> MxList, bit isF = 0,
bit isWidening = 0> {
if !exists<SchedRead>(name # "_WorstCase") then
def : ReadAdvance<!cast<SchedRead>(name # "_WorstCase"), val, writes>;
foreach mx = MxList in {
foreach sew = !if(isF,SchedSEWSetF<mx, isWidening>.val,
SchedSEWSet<mx, isWidening>.val) in
if !exists<SchedRead>(name # "_" # mx # "_E" # sew) then
def : ReadAdvance<!cast<SchedRead>(name # "_" # mx # "_E" # sew), val, writes>;
}
}
// Define classes to define list containing all SchedWrites for each (name, LMUL)
@ -159,16 +176,26 @@ class LMULSchedWriteList<list<string> names> : LMULSchedWriteListImpl<names, Sch
multiclass LMULSEWSchedWrites<string name> : LMULSEWSchedWritesImpl<name, SchedMxList>;
multiclass LMULSEWSchedReads<string name> : LMULSEWSchedReadsImpl<name, SchedMxList>;
multiclass LMULSEWWriteRes<string name, list<ProcResourceKind> resources>
: LMULSEWWriteResImpl<name, resources>;
: LMULSEWWriteResImpl<name, resources, SchedMxList>;
multiclass LMULSEWReadAdvance<string name, int val, list<SchedWrite> writes = []>
: LMULSEWReadAdvanceImpl<name, val, writes>;
: LMULSEWReadAdvanceImpl<name, val, writes, SchedMxList>;
multiclass LMULSEWSchedWritesWRed<string name>
: LMULSEWSchedWritesImpl<name, SchedMxListWRed, 0, 1>;
multiclass LMULSEWWriteResWRed<string name, list<ProcResourceKind> resources>
: LMULSEWWriteResImpl<name, resources, SchedMxListWRed, 0, 1>;
multiclass LMULSEWSchedWritesFWRed<string name>
: LMULSEWSchedWritesImpl<name, SchedMxListFWRed, 1, 1>;
multiclass LMULSEWWriteResFWRed<string name, list<ProcResourceKind> resources>
: LMULSEWWriteResImpl<name, resources, SchedMxListFWRed, 1, 1>;
multiclass LMULSEWSchedWritesF<string name> : LMULSEWSchedWritesImpl<name, SchedMxListF, 1>;
multiclass LMULSEWSchedReadsF<string name> : LMULSEWSchedReadsImpl<name, SchedMxListF, 1>;
multiclass LMULSEWWriteResF<string name, list<ProcResourceKind> resources>
: LMULSEWWriteResImpl<name, resources, 1>;
: LMULSEWWriteResImpl<name, resources, SchedMxListF, 1>;
multiclass LMULSEWReadAdvanceF<string name, int val, list<SchedWrite> writes = []>
: LMULSEWReadAdvanceImpl<name, val, writes, 1>;
: LMULSEWReadAdvanceImpl<name, val, writes, SchedMxListF, 1>;
multiclass LMULSchedWritesW<string name> : LMULSchedWritesImpl<name, SchedMxListW>;
multiclass LMULSchedReadsW<string name> : LMULSchedReadsImpl<name, SchedMxListW>;
@ -186,12 +213,6 @@ multiclass LMULReadAdvanceFW<string name, int val, list<SchedWrite> writes = []>
: LMULReadAdvanceImpl<name, val, writes>;
class LMULSchedWriteListFW<list<string> names> : LMULSchedWriteListImpl<names, SchedMxListFW>;
multiclass LMULSchedWritesFWRed<string name> : LMULSchedWritesImpl<name, SchedMxListFWRed>;
multiclass LMULWriteResFWRed<string name, list<ProcResourceKind> resources>
: LMULWriteResImpl<name, resources>;
class LMULSchedWriteListFWRed<list<string> names> : LMULSchedWriteListImpl<names, SchedMxListFWRed>;
// 3.6 Vector Byte Length vlenb
def WriteRdVLENB : SchedWrite;
@ -389,15 +410,15 @@ defm "" : LMULSchedWritesFW<"WriteVFNCvtFToFV">;
// MF8 and M8. Use the _From suffix to indicate the number of the
// LMUL from VS2.
// 14.1. Vector Single-Width Integer Reduction Instructions
defm "" : LMULSchedWrites<"WriteVIRedV_From">;
defm "" : LMULSEWSchedWrites<"WriteVIRedV_From">;
// 14.2. Vector Widening Integer Reduction Instructions
defm "" : LMULSchedWrites<"WriteVIWRedV_From">;
defm "" : LMULSEWSchedWritesWRed<"WriteVIWRedV_From">;
// 14.3. Vector Single-Width Floating-Point Reduction Instructions
defm "" : LMULSchedWrites<"WriteVFRedV_From">;
defm "" : LMULSchedWrites<"WriteVFRedOV_From">;
defm "" : LMULSEWSchedWritesF<"WriteVFRedV_From">;
defm "" : LMULSEWSchedWritesF<"WriteVFRedOV_From">;
// 14.4. Vector Widening Floating-Point Reduction Instructions
defm "" : LMULSchedWritesFWRed<"WriteVFWRedV_From">;
defm "" : LMULSchedWritesFWRed<"WriteVFWRedOV_From">;
defm "" : LMULSEWSchedWritesFWRed<"WriteVFWRedV_From">;
defm "" : LMULSEWSchedWritesFWRed<"WriteVFWRedOV_From">;
// 15. Vector Mask Instructions
// 15.1. Vector Mask-Register Logical Instructions
@ -821,12 +842,12 @@ defm "" : LMULWriteResW<"WriteVFNCvtFToIV", []>;
defm "" : LMULWriteResFW<"WriteVFNCvtFToFV", []>;
// 14. Vector Reduction Operations
defm "" : LMULWriteRes<"WriteVIRedV_From", []>;
defm "" : LMULWriteRes<"WriteVIWRedV_From", []>;
defm "" : LMULWriteRes<"WriteVFRedV_From", []>;
defm "" : LMULWriteRes<"WriteVFRedOV_From", []>;
defm "" : LMULWriteResFWRed<"WriteVFWRedV_From", []>;
defm "" : LMULWriteResFWRed<"WriteVFWRedOV_From", []>;
defm "" : LMULSEWWriteRes<"WriteVIRedV_From", []>;
defm "" : LMULSEWWriteResWRed<"WriteVIWRedV_From", []>;
defm "" : LMULSEWWriteResF<"WriteVFRedV_From", []>;
defm "" : LMULSEWWriteResF<"WriteVFRedOV_From", []>;
defm "" : LMULSEWWriteResFWRed<"WriteVFWRedV_From", []>;
defm "" : LMULSEWWriteResFWRed<"WriteVFWRedOV_From", []>;
// 15. Vector Mask Instructions
defm "" : LMULWriteRes<"WriteVMALUV", []>;