mirror of
synced 2025-02-21 12:51:20 +00:00
[NVPTX] Refactor generation of MMA intrinsics and instructions. NFC.
Generalized constructions of 'fragments' of MMA operations to provide common primitives for construction of the ops. This will make it easier to add new variants of the instructions that operate on integer types. Use nested foreach loops which makes it possible to better control naming of the intrinsics. This patch does not affect LLVM's output, so there are no test changes. Differential Revision: https://reviews.llvm.org/D59389 llvm-svn: 359245
This commit is contained in:
@ -37,6 +37,69 @@ def llvm_anyi64ptr_ty : LLVMAnyPointerType<llvm_i64_ty>; // (space)i64*
// Helper class for construction of an N-element list<LLVMType> [T,T,...,T].
// Built recursively: a non-zero N appends one T to the (N-1)-element list,
// and N == 0 yields the empty list, which terminates the recursion.
class RepLLVMType<int N, LLVMType T> {
  list<LLVMType> ret = !if(N, !listconcat(RepLLVMType<!add(N,-1), T>.ret, [T]), []);
}
// Helper class that represents a 'fragment' of an NVPTX *MMA instruction.
// Geom: m<M>n<N>k<K>. E.g. m8n32k16
// Frag: [abcd]
// PtxEltType: PTX type for the element.
class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
  // Template arguments re-exported as fields so users can read them back.
  string geom = Geom;
  string frag = Frag;
  string ptx_elt_type = PtxEltType;
  // "<frag>:<type>" key (e.g. "a:f16") used to select the register list below.
  string ft = frag#":"#ptx_elt_type;
  // LLVM types of the values that make up the fragment. Only the
  // {fragment, type} combinations listed here have an entry.
  list<LLVMType> regs = !cond(
    // fp16 -> fp16/fp32 @  m16n16k16/m8n32k16/m32n8k16
    // All currently supported geometries use the same fragment format,
    // so we only need to consider {fragment, type}.
    !eq(ft,"a:f16") : RepLLVMType<8, llvm_v2f16_ty>.ret,
    !eq(ft,"b:f16") : RepLLVMType<8, llvm_v2f16_ty>.ret,
    !eq(ft,"c:f16") : RepLLVMType<4, llvm_v2f16_ty>.ret,
    !eq(ft,"d:f16") : RepLLVMType<4, llvm_v2f16_ty>.ret,
    !eq(ft,"c:f32") : RepLLVMType<8, llvm_float_ty>.ret,
    !eq(ft,"d:f32") : RepLLVMType<8, llvm_float_ty>.ret);
}
// Builds the names used for a WMMA load/store intrinsic.
// Op: operation name spliced into both names ("load" or "store" at the
//     use sites).
// Frag: fragment the operation applies to; supplies geometry, fragment
//       letter and element type.
// Layout: layout string spliced into both names ("row" or "col" at the
//         use sites).
// WithStride: non-zero adds the "stride" component to both names.
class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> {
  // Dotted intrinsic name:
  //   llvm.nvvm.wmma.<geom>.<op>.<frag>.<layout>[.stride].<type>
  string intr = "llvm.nvvm.wmma."
                # Frag.geom
                # "." # Op
                # "." # Frag.frag
                # "." # Layout
                # !if(WithStride, ".stride", "")
                # "." # Frag.ptx_elt_type
                ;
  // TODO(tra): record name should ideally use the same field order as the intrinsic.
  // E.g. string record = !subst("llvm", "int",
  //                      !subst(".", "_", llvm));
  // Record name. Note the element type comes *before* the layout here,
  // unlike in `intr`:
  //   int_nvvm_wmma_<geom>_<op>_<frag>_<type>_<layout>[_stride]
  string record = "int_nvvm_wmma_"
                  # Frag.geom
                  # "_" # Op
                  # "_" # Frag.frag
                  # "_" # Frag.ptx_elt_type
                  # "_" # Layout
                  # !if(WithStride, "_stride", "");
}
// Builds the names used for a WMMA mma intrinsic.
// ALayout/BLayout: layout strings for the A and B operands.
// C, D: fragments describing the C input and D result; they supply the
//       geometry and the two element types encoded in the name.
// Satfinite: non-zero appends the "satfinite" component.
class WMMA_NAME_MMA<string ALayout, string BLayout,
                    WMMA_REGS C, WMMA_REGS D,
                    int Satfinite> {
  // Dotted intrinsic name:
  //   llvm.nvvm.wmma.<geom>.mma.<alayout>.<blayout>.<dtype>.<ctype>[.satfinite]
  string llvm = "llvm.nvvm.wmma."
                # C.geom
                # ".mma"
                # "." # ALayout
                # "." # BLayout
                # "." # D.ptx_elt_type  // Intrinsic encodes 'd' first.
                # "." # C.ptx_elt_type
                # !if(Satfinite, ".satfinite", "");
  // Record name derived from the intrinsic name: the "llvm." prefix becomes
  // "int_" and every remaining "." becomes "_".
  string record = !subst(".", "_",
                  !subst("llvm.", "int_", llvm));
}
let TargetPrefix = "nvvm" in {
def int_nvvm_prmt : GCCBuiltin<"__nvvm_prmt">,
Intrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
@ -3889,166 +3952,69 @@ def int_nvvm_match_all_sync_i64p :
// WMMA instructions
class NVVM_WMMA_LD_GALSTS<string Geometry, string Abc, string Layout,
string Type, LLVMType regty, int WithStride>
: Intrinsic<!if(!eq(Abc#Type,"cf16"),
[regty, regty, regty, regty],
[regty, regty, regty, regty,
regty, regty, regty, regty]),
class NVVM_WMMA_LD<WMMA_REGS Frag, string Layout, int WithStride>
: Intrinsic<Frag.regs,
!if(WithStride, [llvm_anyptr_ty, llvm_i32_ty], [llvm_anyptr_ty]),
[IntrReadMem, IntrArgMemOnly, ReadOnly<0>, NoCapture<0>],
# Geometry
# ".load"
# "." # Abc
# "." # Layout
# !if(WithStride, ".stride", "")
# "." # Type>;
multiclass NVVM_WMMA_LD_GALT<string Geometry, string Abc, string Layout,
string Type, LLVMType regty> {
def _stride: NVVM_WMMA_LD_GALSTS<Geometry, Abc, Layout, Type, regty, 1>;
def NAME : NVVM_WMMA_LD_GALSTS<Geometry, Abc, Layout, Type, regty, 0>;
multiclass NVVM_WMMA_LD_GAT<string Geometry, string Abc,
string Type, LLVMType regty> {
defm _row: NVVM_WMMA_LD_GALT<Geometry, Abc, "row", Type, regty>;
defm _col: NVVM_WMMA_LD_GALT<Geometry, Abc, "col", Type, regty>;
multiclass NVVM_WMMA_LD_G<string Geometry> {
defm _a_f16: NVVM_WMMA_LD_GAT<Geometry, "a", "f16", llvm_v2f16_ty>;
defm _b_f16: NVVM_WMMA_LD_GAT<Geometry, "b", "f16", llvm_v2f16_ty>;
defm _c_f16: NVVM_WMMA_LD_GAT<Geometry, "c", "f16", llvm_v2f16_ty>;
defm _c_f32: NVVM_WMMA_LD_GAT<Geometry, "c", "f32", llvm_float_ty>;
multiclass NVVM_WMMA_LD {
defm _m32n8k16_load: NVVM_WMMA_LD_G<"m32n8k16">;
defm _m16n16k16_load: NVVM_WMMA_LD_G<"m16n16k16">;
defm _m8n32k16_load: NVVM_WMMA_LD_G<"m8n32k16">;
defm int_nvvm_wmma: NVVM_WMMA_LD;
WMMA_NAME_LDST<"load", Frag, Layout, WithStride>.intr>;
class NVVM_WMMA_STD_GLSTS<string Geometry, string Layout,
string Type, LLVMType regty, int WithStride,
// This is only used to create a typed empty array we
// need to pass to !if below.
class NVVM_WMMA_ST<WMMA_REGS Frag, string Layout, int WithStride>
: Intrinsic<[],
[regty, regty, regty, regty],
[regty, regty, regty, regty,
regty, regty, regty, regty]),
!if(WithStride, [llvm_i32_ty], Empty)),
!if(WithStride, [llvm_i32_ty], [])),
[IntrWriteMem, IntrArgMemOnly, WriteOnly<0>, NoCapture<0>],
# Geometry
# ".store.d"
# "." # Layout
# !if(WithStride, ".stride", "")
# "." # Type>;
WMMA_NAME_LDST<"store", Frag, Layout, WithStride>.intr>;
multiclass NVVM_WMMA_STD_GLT<string Geometry, string Layout,
string Type, LLVMType regty> {
def _stride: NVVM_WMMA_STD_GLSTS<Geometry, Layout, Type, regty, 1>;
def NAME: NVVM_WMMA_STD_GLSTS<Geometry, Layout, Type, regty, 0>;
// Create all load/store variants
// One intrinsic record is defined per {geometry, layout, stride, fragment};
// the record name comes from WMMA_NAME_LDST<...>.record.
foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
  foreach layout = ["row", "col"] in {
    foreach stride = [0, 1] in {
      // Loads are defined for the a, b and c fragments.
      foreach frag = [WMMA_REGS<geom, "a", "f16">,
                      WMMA_REGS<geom, "b", "f16">,
                      WMMA_REGS<geom, "c", "f16">,
                      WMMA_REGS<geom, "c", "f32">] in {
        def WMMA_NAME_LDST<"load", frag, layout, stride>.record
            : NVVM_WMMA_LD<frag, layout, stride>;
      } // frag
      // Stores are defined for the d fragment only.
      foreach frag = [WMMA_REGS<geom, "d", "f16">,
                      WMMA_REGS<geom, "d", "f32">] in {
        def WMMA_NAME_LDST<"store", frag, layout, stride>.record
            : NVVM_WMMA_ST<frag, layout, stride>;
      } // frag
    } // stride
  } // layout
} // geom
multiclass NVVM_WMMA_STD_GT<string Geometry, string Type, LLVMType regty> {
defm _row: NVVM_WMMA_STD_GLT<Geometry, "row", Type, regty>;
defm _col: NVVM_WMMA_STD_GLT<Geometry, "col", Type, regty>;
multiclass NVVM_WMMA_STD_G<string Geometry> {
defm _d_f16: NVVM_WMMA_STD_GT<Geometry, "f16", llvm_v2f16_ty>;
defm _d_f32: NVVM_WMMA_STD_GT<Geometry, "f32", llvm_float_ty>;
multiclass NVVM_WMMA_STD {
defm _m32n8k16_store: NVVM_WMMA_STD_G<"m32n8k16">;
defm _m16n16k16_store: NVVM_WMMA_STD_G<"m16n16k16">;
defm _m8n32k16_store: NVVM_WMMA_STD_G<"m8n32k16">;
defm int_nvvm_wmma: NVVM_WMMA_STD;
class NVVM_WMMA_MMA_GABDCS<string Geometry,
string ALayout, string BLayout,
string DType, LLVMType d_regty,
string CType, LLVMType c_regty,
string Satfinite = "">
: Intrinsic<!if(!eq(DType,"f16"),
[d_regty, d_regty, d_regty, d_regty],
[d_regty, d_regty, d_regty, d_regty,
d_regty, d_regty, d_regty, d_regty]),
class NVVM_WMMA_MMA<string ALayout, string BLayout,
WMMA_REGS C, WMMA_REGS D, int Satfinite>
: Intrinsic<D.regs,
[// A
llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
// B
llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty],
[c_regty, c_regty, c_regty, c_regty],
[c_regty, c_regty, c_regty, c_regty,
c_regty, c_regty, c_regty, c_regty])),
WMMA_REGS<C.geom, "a", "f16">.regs,
WMMA_REGS<C.geom, "b", "f16">.regs,
# Geometry
# ".mma"
# "." # ALayout
# "." # BLayout
# "." # DType
# "." # CType
# Satfinite> {
WMMA_NAME_MMA<ALayout, BLayout, C, D, Satfinite>.llvm>;
multiclass NVVM_WMMA_MMA_GABDC<string Geometry, string ALayout, string BLayout,
string DType, LLVMType d_regty,
string CType, LLVMType c_regty> {
def NAME : NVVM_WMMA_MMA_GABDCS<Geometry, ALayout, BLayout,
DType, d_regty, CType, c_regty>;
def _satfinite: NVVM_WMMA_MMA_GABDCS<Geometry, ALayout, BLayout,
DType, d_regty, CType, c_regty,".satfinite">;
// Create all MMA variants.
// One intrinsic record is defined per
// {geometry, A layout, B layout, C type, D type, satfinite};
// the record name comes from WMMA_NAME_MMA<...>.record.
foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
  foreach layout_a = ["row", "col"] in {
    foreach layout_b = ["row", "col"] in {
      foreach frag_c = [WMMA_REGS<geom, "c", "f16">,
                        WMMA_REGS<geom, "c", "f32">] in {
        foreach frag_d = [WMMA_REGS<geom, "d", "f16">,
                          WMMA_REGS<geom, "d", "f32">] in {
          foreach satf = [0, 1] in {
            def WMMA_NAME_MMA<layout_a, layout_b, frag_c, frag_d, satf>.record
                : NVVM_WMMA_MMA<layout_a, layout_b, frag_c, frag_d, satf>;
          } // satf
        } // frag_d
      } // frag_c
    } // layout_b
  } // layout_a
} // geom
multiclass NVVM_WMMA_MMA_GABD<string Geometry, string ALayout, string BLayout,
string DType, LLVMType d_regty> {
defm _f16: NVVM_WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_regty,
"f16", llvm_v2f16_ty>;
defm _f32: NVVM_WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_regty,
"f32", llvm_float_ty>;
multiclass NVVM_WMMA_MMA_GAB<string Geometry, string ALayout, string BLayout> {
defm _f16: NVVM_WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f16", llvm_v2f16_ty>;
defm _f32: NVVM_WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f32", llvm_float_ty>;
multiclass NVVM_WMMA_MMA_GA<string Geometry, string ALayout> {
defm _col: NVVM_WMMA_MMA_GAB<Geometry, ALayout, "col">;
defm _row: NVVM_WMMA_MMA_GAB<Geometry, ALayout, "row">;
multiclass NVVM_WMMA_MMA_G<string Geometry> {
defm _col: NVVM_WMMA_MMA_GA<Geometry, "col">;
defm _row: NVVM_WMMA_MMA_GA<Geometry, "row">;
multiclass NVVM_WMMA_MMA {
defm _m32n8k16_mma : NVVM_WMMA_MMA_G<"m32n8k16">;
defm _m16n16k16_mma : NVVM_WMMA_MMA_G<"m16n16k16">;
defm _m8n32k16_mma : NVVM_WMMA_MMA_G<"m8n32k16">;
defm int_nvvm_wmma : NVVM_WMMA_MMA;
} // let TargetPrefix = "nvvm"
@ -26,7 +26,17 @@ def immDouble1 : PatLeaf<(fpimm), [{
return (d==1.0);
// Named predicate-code snippets, one per address space. Each passes the
// node N and the corresponding llvm::ADDRESS_SPACE_* constant to
// ChkMemSDNodeAddressSpace. They are referenced as AS_match.<space>
// wherever a PatFrag needs an address-space predicate.
def AS_match {
  code generic = [{
   return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
  }];
  code shared = [{
   return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
  }];
  code global = [{
   return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
  }];
}
// Synchronization and shuffle functions
@ -1006,17 +1016,11 @@ def INT_FNS_iii : INT_FNS_MBO<(ins i32imm:$mask, i32imm:$base, i32imm:$
class ATOMIC_GLOBAL_CHK <dag ops, dag frag>
: PatFrag<ops, frag, [{
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
: PatFrag<ops, frag, AS_match.global>;
class ATOMIC_SHARED_CHK <dag ops, dag frag>
: PatFrag<ops, frag, [{
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
: PatFrag<ops, frag, AS_match.shared>;
class ATOMIC_GENERIC_CHK <dag ops, dag frag>
: PatFrag<ops, frag, [{
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
: PatFrag<ops, frag, AS_match.generic>;
multiclass F_ATOMIC_2_imp<NVPTXRegClass ptrclass, NVPTXRegClass regclass,
string SpaceStr, string TypeStr, string OpcStr, PatFrag IntOp,
@ -7380,36 +7384,60 @@ def INT_PTX_SREG_WARPSIZE :
NVPTXInst<(outs Int32Regs:$dst), (ins), "mov.u32 \t$dst, WARP_SZ;",
[(set Int32Regs:$dst, (int_nvvm_read_ptx_sreg_warpsize))]>;
// wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
class EmptyNVPTXInst : NVPTXInst<(outs), (ins), "?", []>;
// Generates list of n sequential register names.
// E.g. RegSeq<3, "ra">.ret is ["ra0", "ra1", "ra2"].
// Built recursively: a non-zero n appends "<prefix><n-1>" to the list for
// n-1; n == 0 yields the empty list and terminates the recursion.
class RegSeq<int n, string prefix> {
  list<string> ret = !if(n, !listconcat(RegSeq<!add(n,-1), prefix>.ret,
                                        [prefix # !add(n, -1)]),
                            []);
}
class WMMA_LOAD_GALSTOS<string Geometry, string Abc, string Layout,
string Space, string Type, NVPTXRegClass regclass,
DAGOperand SrcOp, bit WithStride>
: EmptyNVPTXInst,
Requires<[!if(!eq(Geometry, "m16n16k16"),
hasSM70]> {
// Pattern (created by WMMA_LOAD_INTR_HELPER below) that matches the intrinsic
// for this function.
PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_"
# Geometry # "_load_"
# !subst("c", "c_" # Type, Abc)
# "_" # Layout
# !subst(".", "_", Space)
# !if(WithStride,"_stride", "")
# "_Intr");
dag OutsR03 = (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3);
dag OutsR47 = (outs regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7);
dag Outs = !if(!eq(Abc#Type,"cf16"), OutsR03, !con(OutsR03, OutsR47));
// Helper class that represents a 'fragment' of an NVPTX *MMA instruction.
// In addition to target-independent fields provided by WMMA_REGS, it adds
// the fields commonly used to implement specific PTX instruction -- register
// types and names, constraints, parts of assembly, etc.
class WMMA_REGINFO<string Geom, string Frag, string PtxEltType>
: WMMA_REGS<Geom, Frag, PtxEltType> {
// NVPTX register types used to carry fragment data.
NVPTXRegClass regclass = !cond(
!eq(PtxEltType, "f16") : Float16x2Regs,
!eq(PtxEltType, "f32") : Float32Regs);
dag StrideArg = !if(WithStride, (ins Int32Regs:$ldm), (ins));
dag Ins = !con((ins SrcOp:$src), StrideArg);
// Instruction input/output arguments for the fragment.
list<NVPTXRegClass> ptx_regs = !foreach(tmp, regs, regclass);
// List of register names for the fragment -- ["ra0", "ra1",...]
list<string> reg_names = RegSeq<!size(ptx_regs), "r"#frag>.ret;
// Generates "{{$r0, $r1,.... $rN-1}}" for use in asm string construction.
string regstring = "{{$" # !head(reg_names)
# !foldl("", !tail(reg_names), a, b,
!strconcat(a, ", $", b))
# "}}";
// Predicates for particular fragment variant. Technically those are
// per-instruction predicates, but currently all fragments that can be used in
// a given instruction are subject to the same constraints, so an instruction
// can use predicates from any of its fragments. If/when this is no
// longer the case, we can concat all per-fragment predicates to enforce that
// all fragments of the instruction are viable.
list<Predicate> Predicates = !cond(
// fp16 -> fp16/fp32 @ m16n16k16
!and(!eq(Geom, "m16n16k16"),
!or(!eq(PtxEltType, "f16"),
!eq(PtxEltType, "f32"))) : [hasSM70, hasPTX60],
// fp16 -> fp16/fp32 @ m8n32k16/m32n8k16
!and(!or(!eq(Geom, "m8n32k16"),
!eq(Geom, "m32n8k16")),
!or(!eq(PtxEltType, "f16"),
!eq(PtxEltType, "f32"))) : [hasSM70, hasPTX61]);
// template DAGs for instruction inputs/output.
dag Outs = !dag(outs, ptx_regs, reg_names);
dag Ins = !dag(ins, ptx_regs, reg_names);
class BuildPattern<dag Outs, PatFrag IntrMatcher, dag Ins> {
// Build a dag pattern that matches the intrinsic call.
// We want a dag that looks like this:
// (set <output args>, (intrinsic <input arguments>)) where input and
@ -7430,277 +7458,127 @@ class WMMA_LOAD_GALSTOS<string Geometry, string Abc, string Layout,
!subst(ins, IntrMatcher, tmp)))));
// Finally, concatenate both parts together. !con() requires both dags to have
// the same operator, so we wrap PatArgs in a (set ...) dag.
let Pattern = [!con(PatOuts, (set PatArgs))];
let OutOperandList = Outs;
dag ret = !con(PatOuts, (set PatArgs));
// wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
//
// PatFrag that matches the load intrinsic for the given fragment/layout and
// restricts the match to one address space.
// Frag: fragment being loaded.
// Layout: layout string used to look up the intrinsic record.
// Space: ".shared", ".global", or anything else for the generic check.
// WithStride: set when the intrinsic takes the extra $ldm operand.
class WMMA_LOAD_INTR_HELPER<WMMA_REGINFO Frag, string Layout, string Space,
                            bit WithStride>
    : PatFrag <(ops),(ops)> {
  // Intrinsic that matches this instruction.
  Intrinsic Intr = !cast<Intrinsic>(WMMA_NAME_LDST<"load", Frag, Layout,
                                                   WithStride>.record);
  let Operands = !if(WithStride, (ops node:$src, node:$ldm), (ops node:$src));
  // Same operand list as above, with the `ops` operator replaced by the
  // intrinsic.
  let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))];
  // Address-space check; any Space other than .shared/.global falls through
  // to the generic predicate.
  let PredicateCode = !cond(!eq(Space, ".shared"): AS_match.shared,
                            !eq(Space, ".global"): AS_match.global,
                            1: AS_match.generic);
}
class WMMA_LOAD<WMMA_REGINFO Frag, string Layout, string Space, bit WithStride,
DAGOperand SrcOp>
: EmptyNVPTXInst,
Requires<Frag.Predicates> {
// Pattern that matches the intrinsic for this instruction variant.
PatFrag IntrMatcher = WMMA_LOAD_INTR_HELPER<Frag, Layout, Space, WithStride>;
dag Ins = !con((ins SrcOp:$src), !if(WithStride, (ins Int32Regs:$ldm), (ins)));
let Pattern = [BuildPattern<Frag.Outs, IntrMatcher, Ins>.ret];
let OutOperandList = Frag.Outs;
let InOperandList = Ins;
let AsmString = "wmma.load."
# Abc
# Frag.frag
# ".sync"
# "." # Layout
# "." # Geometry
# "." # Frag.geom
# Space
# "." # Type # " \t"
# !if(!eq(Abc#Type, "cf16"),
"{{$r0, $r1, $r2, $r3}}",
"{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
# "." # Frag.ptx_elt_type # " \t"
# Frag.regstring
# ", [$src]"
# !if(WithStride, ", $ldm", "")
# ";";
class WMMA_LOAD_INTR_HELPER<string Geometry, string Abc, string Layout,
string Space, string Type, bit WithStride>
: PatFrag <(ops),(ops)> {
// Intrinsic that matches this instruction.
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma"
# "_" # Geometry # "_load_"
# Abc # "_" # Type # "_" # Layout
# !if(WithStride,"_stride", ""));
code match_generic = [{
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
code match_shared = [{
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
code match_global = [{
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
let Operands = !if(WithStride, (ops node:$src, node:$ldm), (ops node:$src));
let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))];
let PredicateCode = !if(!eq(Space, ".shared"), match_shared,
!if(!eq(Space, ".global"), match_global, match_generic));
multiclass WMMA_LOAD_GALSTS<string Geometry, string Abc, string Layout,
string Space, string Type, NVPTXRegClass regclass,
bit WithStride> {
def _avar: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
imem, WithStride>;
def _areg: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
Int32Regs, WithStride>;
def _areg64: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
Int64Regs, WithStride>;
def _ari: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
MEMri, WithStride>;
def _ari64: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
MEMri64, WithStride>;
multiclass WMMA_LOAD_GALSTSh<string Geometry, string Abc, string Layout,
string Space, string Type, NVPTXRegClass regclass,
bit WithStride> {
// Define a PatFrag that matches appropriate intrinsic that loads from the
// given address space.
def _Intr: WMMA_LOAD_INTR_HELPER<Geometry, Abc, Layout, Space, Type,
defm NAME: WMMA_LOAD_GALSTS<Geometry, Abc, Layout, Space, Type, regclass,
multiclass WMMA_LOAD_GALST<string Geometry, string Abc, string Layout,
string Space, string Type, NVPTXRegClass regclass> {
defm _stride: WMMA_LOAD_GALSTSh<Geometry, Abc, Layout, Space, Type, regclass, 1>;
defm NAME: WMMA_LOAD_GALSTSh<Geometry, Abc, Layout, Space, Type, regclass, 0>;
multiclass WMMA_LOAD_GALT<string Geometry, string Abc, string Layout,
string Type, NVPTXRegClass regclass> {
defm _global: WMMA_LOAD_GALST<Geometry, Abc, Layout, ".global",
Type, regclass>;
defm _shared: WMMA_LOAD_GALST<Geometry, Abc, Layout, ".shared",
Type, regclass>;
defm NAME: WMMA_LOAD_GALST<Geometry, Abc, Layout, "",
Type, regclass>;
multiclass WMMA_LOAD_GAT<string Geometry, string Abc,
string Type, NVPTXRegClass regclass> {
defm _row: WMMA_LOAD_GALT<Geometry, Abc, "row", Type, regclass>;
defm _col: WMMA_LOAD_GALT<Geometry, Abc, "col", Type, regclass>;
multiclass WMMA_LOAD_G<string Geometry> {
defm _load_a: WMMA_LOAD_GAT<Geometry, "a", "f16", Float16x2Regs>;
defm _load_b: WMMA_LOAD_GAT<Geometry, "b", "f16", Float16x2Regs>;
defm _load_c_f16: WMMA_LOAD_GAT<Geometry, "c", "f16", Float16x2Regs>;
defm _load_c_f32: WMMA_LOAD_GAT<Geometry, "c", "f32", Float32Regs>;
defm INT_WMMA_m32n8k16: WMMA_LOAD_G<"m32n8k16">;
defm INT_WMMA_m16n16k16: WMMA_LOAD_G<"m16n16k16">;
defm INT_WMMA_m8n32k16: WMMA_LOAD_G<"m8n32k16">;
// wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
class WMMA_STORE_D_GLSTSO<string Geometry, string Layout, string Space,
string Type, NVPTXRegClass regclass,
bit WithStride, DAGOperand DstOp>
: EmptyNVPTXInst,
Requires<[!if(!eq(Geometry, "m16n16k16"),
hasSM70]> {
PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA"
# "_" # Geometry # "_store_d"
# "_" # Type
# "_" # Layout
# !subst(".", "_", Space)
# !if(WithStride,"_stride", "")
# "_Intr");
dag InsR03 = (ins DstOp:$src, regclass:$r0, regclass:$r1,
regclass:$r2, regclass:$r3);
dag InsR47 = (ins regclass:$r4, regclass:$r5,
regclass:$r6, regclass:$r7);
dag InsR = !if(!eq(Type,"f16"), InsR03, !con(InsR03, InsR47));
dag StrideArg = !if(WithStride, (ins Int32Regs:$ldm), (ins));
dag Ins = !con(InsR, StrideArg);
// PatFrag that matches the store intrinsic for the given fragment/layout and
// restricts the match to one address space.
// Frag: fragment being stored; supplies the per-register operand names.
// Layout: layout string used to look up the intrinsic record.
// Space: ".shared", ".global", or anything else for the generic check.
// WithStride: set when the intrinsic takes the extra $ldm operand.
class WMMA_STORE_INTR_HELPER<WMMA_REGINFO Frag, string Layout, string Space,
                             bit WithStride>
    : PatFrag <(ops),(ops)> {
  // Intrinsic that matches this instruction.
  Intrinsic Intr = !cast<Intrinsic>(WMMA_NAME_LDST<"store", Frag, Layout,
                                                   WithStride>.record);
  // Operands are: $dst, one node per fragment register (named after
  // Frag.reg_names), and optionally $ldm.
  let Operands = !con((ops node:$dst),
                      !dag(ops, !foreach(tmp, Frag.regs, node), Frag.reg_names),
                      !if(WithStride, (ops node:$ldm), (ops)));
  // Same operand list as above, with the `ops` operator replaced by the
  // intrinsic.
  let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))];
  // Address-space check; any Space other than .shared/.global falls through
  // to the generic predicate.
  let PredicateCode = !cond(!eq(Space, ".shared"): AS_match.shared,
                            !eq(Space, ".global"): AS_match.global,
                            1: AS_match.generic);
}
// Construct the pattern to match corresponding intrinsic call. See the
// details in the comments in WMMA_LOAD_GALSTOS.
dag PatArgs = !foreach(tmp, Ins,
!subst(imem, ADDRvar,
!subst(MEMri64, ADDRri64,
!subst(MEMri, ADDRri,
!subst(ins, IntrMatcher, tmp)))));
let Pattern = [PatArgs];
class WMMA_STORE<WMMA_REGINFO Frag, string Layout, string Space, bit WithStride,
DAGOperand DstOp>
: EmptyNVPTXInst,
Requires<Frag.Predicates> {
PatFrag IntrMatcher = WMMA_STORE_INTR_HELPER<Frag, Layout, Space, WithStride>;
dag Ins = !con((ins DstOp:$src),
!if(WithStride, (ins Int32Regs:$ldm), (ins)));
let Pattern = [BuildPattern<(set), IntrMatcher, Ins>.ret];
let OutOperandList = (outs);
let InOperandList = Ins;
let AsmString = "wmma.store.d.sync."
# Layout
# "." # Geometry
# "." # Frag.geom
# Space
# "." # Type
# "." # Frag.ptx_elt_type
# " \t[$src],"
# !if(!eq(Type,"f16"),
"{{$r0, $r1, $r2, $r3}}",
"{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
# Frag.regstring
# !if(WithStride, ", $ldm", "")
# ";";
class WMMA_STORE_INTR_HELPER<string Geometry, string Layout, string Space,
string Type, bit WithStride>
: PatFrag <(ops),(ops)> {
// Intrinsic that matches this instruction.
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_"
# Geometry
# "_store_d"
# "_" # Type
# "_" # Layout
# !if(WithStride, "_stride", ""));
code match_generic = [{
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
code match_shared = [{
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_SHARED);
code match_global = [{
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
dag Args = !if(!eq(Type,"f16"),
(ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3),
(ops node:$dst, node:$r0, node:$r1, node:$r2, node:$r3,
node:$r4, node:$r5, node:$r6, node:$r7));
dag StrideArg = !if(WithStride, (ops node:$ldm), (ops));
let Operands = !con(Args, StrideArg);
let Fragments = [!foreach(tmp, Operands, !subst(ops, Intr, tmp))];
let PredicateCode = !if(!eq(Space, ".shared"), match_shared,
!if(!eq(Space, ".global"), match_global, match_generic));
multiclass WMMA_STORE_D_GLSTS<string Geometry, string Layout, string Space,
string Type, NVPTXRegClass regclass,
bit WithStride> {
def _avar: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
WithStride, imem>;
def _areg: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
WithStride, Int32Regs>;
def _areg64: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
WithStride, Int64Regs>;
def _ari: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
WithStride, MEMri>;
def _ari64: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
WithStride, MEMri64>;
multiclass WMMA_STORE_D_GLSTSh<string Geometry, string Layout, string Space,
string Type, NVPTXRegClass regclass,
bit WithStride> {
// Define a PatFrag that matches the appropriate intrinsic that stores to the
// given address space.
def _Intr: WMMA_STORE_INTR_HELPER<Geometry, Layout, Space, Type,
defm NAME: WMMA_STORE_D_GLSTS<Geometry, Layout, Space, Type, regclass,
multiclass WMMA_STORE_D_GLST<string Geometry, string Layout, string Space,
string Type, NVPTXRegClass regclass > {
defm _stride: WMMA_STORE_D_GLSTSh<Geometry, Layout, Space, Type, regclass, 1>;
defm NAME: WMMA_STORE_D_GLSTSh<Geometry, Layout, Space, Type, regclass, 0>;
multiclass WMMA_STORE_D_GLT<string Geometry, string Layout,
string Type, NVPTXRegClass regclass> {
defm _global: WMMA_STORE_D_GLST<Geometry, Layout, ".global", Type, regclass>;
defm _shared: WMMA_STORE_D_GLST<Geometry, Layout, ".shared", Type, regclass>;
defm NAME: WMMA_STORE_D_GLST<Geometry, Layout, "", Type, regclass>;
multiclass WMMA_STORE_D_GT<string Geometry, string Type,
NVPTXRegClass regclass> {
defm _row: WMMA_STORE_D_GLT<Geometry, "row", Type, regclass>;
defm _col: WMMA_STORE_D_GLT<Geometry, "col", Type, regclass>;
multiclass WMMA_STORE_D_G<string Geometry> {
defm _store_d_f16: WMMA_STORE_D_GT<Geometry, "f16", Float16x2Regs>;
defm _store_d_f32: WMMA_STORE_D_GT<Geometry, "f32", Float32Regs>;
defm INT_WMMA_m32n8k16: WMMA_STORE_D_G<"m32n8k16">;
defm INT_WMMA_m16n16k16: WMMA_STORE_D_G<"m16n16k16">;
defm INT_WMMA_m8n32k16: WMMA_STORE_D_G<"m8n32k16">;
// Create all load/store variants
// One anonymous instruction is defined per
// {geometry, layout, stride, address space, address operand kind, fragment}.
foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
  foreach layout = ["row", "col"] in {
    foreach stride = [0, 1] in {
      foreach space = [".global", ".shared", ""] in {
        foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
          // Loads are defined for the a, b and c fragments.
          foreach frag = [WMMA_REGINFO<geom, "a", "f16">,
                          WMMA_REGINFO<geom, "b", "f16">,
                          WMMA_REGINFO<geom, "c", "f16">,
                          WMMA_REGINFO<geom, "c", "f32">] in {
            def : WMMA_LOAD<frag, layout, space, stride, addr>;
          } // frag
          // Stores are defined for the d fragment only.
          foreach frag = [WMMA_REGINFO<geom, "d", "f16">,
                          WMMA_REGINFO<geom, "d", "f32">] in {
            def : WMMA_STORE<frag, layout, space, stride, addr>;
          } // frag
        } // addr
      } // space
    } // stride
  } // layout
} // geom
class WMMA_MMA_GABDCS<string Geometry, string ALayout, string BLayout,
string DType, NVPTXRegClass d_reg,
string CType, NVPTXRegClass c_reg,
NVPTXRegClass ab_reg,
string Satfinite = "">
string ALayout, string BLayout, int Satfinite>
: EmptyNVPTXInst,
Requires<[!if(!eq(Geometry, "m16n16k16"),
hasSM70]> {
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_"
# Geometry
# "_mma"
# "_" # ALayout
# "_" # BLayout
# "_" # DType
# "_" # CType
# !subst(".", "_", Satfinite));
dag Outs = !if(!eq(DType,"f16"),
(outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3),
(outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3,
d_reg:$d4, d_reg:$d5, d_reg:$d6, d_reg:$d7));
dag InsExtraCArgs = !if(!eq(CType,"f16"),
(ins c_reg:$c4, c_reg:$c5, c_reg:$c6, c_reg:$c7));
dag Ins = !con((ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3,
ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7,
ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3,
ab_reg:$b4, ab_reg:$b5, ab_reg:$b6, ab_reg:$b7,
c_reg:$c0, c_reg:$c1, c_reg:$c2, c_reg:$c3),
Requires<FragC.Predicates> {
//Intrinsic Intr = int_nvvm_suld_1d_v4i32_zero;
Intrinsic Intr = !cast<Intrinsic>(WMMA_NAME_MMA<ALayout, BLayout, FragC, FragD, Satfinite>.record);
dag Outs = FragD.Outs;
dag Ins = !con(FragA.Ins,
// Construct the pattern to match corresponding intrinsic call. See the
// details in the comments in WMMA_LOAD_GALSTOS.
// Construct the pattern to match corresponding intrinsic call.
// mma does not load/store anything, so we don't need complex operand matching here.
dag PatOuts = !foreach(tmp, Outs, !subst(outs, set, tmp));
dag PatArgs = !foreach(tmp, Ins, !subst(ins, Intr, tmp));
let Pattern = [!con(PatOuts, (set PatArgs))];
@ -7709,54 +7587,30 @@ class WMMA_MMA_GABDCS<string Geometry, string ALayout, string BLayout,
let AsmString = "wmma.mma.sync."
# ALayout
# "." # BLayout
# "." # Geometry
# "." # DType
# "." # CType
# Satfinite # "\n\t\t"
# !if(!eq(DType,"f16"),
"{{$d0, $d1, $d2, $d3}}, \n\t\t",
"{{$d0, $d1, $d2, $d3, $d4, $d5, $d6, $d7}},\n\t\t")
# "{{$a0, $a1, $a2, $a3, $a4, $a5, $a6, $a7}},\n\t\t"
# "{{$b0, $b1, $b2, $b3, $b4, $b5, $b6, $b7}},\n\t\t"
# !if(!eq(CType,"f16"),
"{{$c0, $c1, $c2, $c3}};",
"{{$c0, $c1, $c2, $c3, $c4, $c5, $c6, $c7}};");
# "." # FragA.geom
# "." # FragD.ptx_elt_type
# "." # FragC.ptx_elt_type
# !if(Satfinite, ".satfinite", "") # "\n\t\t"
# FragD.regstring # ",\n\t\t"
# FragA.regstring # ",\n\t\t"
# FragB.regstring # ",\n\t\t"
# FragC.regstring # ";";
multiclass WMMA_MMA_GABDC<string Geometry, string ALayout, string BLayout,
string DType, NVPTXRegClass d_reg,
string CType, NVPTXRegClass c_reg> {
def _satfinite: WMMA_MMA_GABDCS<Geometry, ALayout, BLayout,
DType, d_reg, CType, c_reg,
Float16x2Regs, ".satfinite">;
def NAME: WMMA_MMA_GABDCS<Geometry, ALayout, BLayout,
DType, d_reg, CType, c_reg,
multiclass WMMA_MMA_GABD<string Geometry, string ALayout, string BLayout,
string DType, NVPTXRegClass d_reg> {
defm _f16: WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_reg,
"f16", Float16x2Regs>;
defm _f32: WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_reg,
"f32", Float32Regs>;
multiclass WMMA_MMA_GAB<string Geometry, string ALayout, string BLayout> {
defm _f16: WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f16", Float16x2Regs>;
defm _f32: WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f32", Float32Regs>;
multiclass WMMA_MMA_GA<string Geometry, string ALayout> {
defm _col: WMMA_MMA_GAB<Geometry, ALayout, "col">;
defm _row: WMMA_MMA_GAB<Geometry, ALayout, "row">;
multiclass WMMA_MMA_G<string Geometry> {
defm _col: WMMA_MMA_GA<Geometry, "col">;
defm _row: WMMA_MMA_GA<Geometry, "row">;
defm INT_WMMA_MMA_m32n8k16 : WMMA_MMA_G<"m32n8k16">;
defm INT_WMMA_MMA_m16n16k16 : WMMA_MMA_G<"m16n16k16">;
defm INT_WMMA_MMA_m8n32k16 : WMMA_MMA_G<"m8n32k16">;
// Create all MMA instruction variants: one anonymous WMMA_MMA def per
// {geometry, A layout, B layout, C fragment type, D fragment type, satfinite}.
// The A and B fragments are always f16; only C and D vary between f16 and f32.
foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
foreach layout_a = ["row", "col"] in {
foreach layout_b = ["row", "col"] in {
foreach frag_c = [WMMA_REGINFO<geom, "c", "f16">,
WMMA_REGINFO<geom, "c", "f32">] in {
foreach frag_d = [WMMA_REGINFO<geom, "d", "f16">,
WMMA_REGINFO<geom, "d", "f32">] in {
foreach satf = [0, 1] in {
def : WMMA_MMA<WMMA_REGINFO<geom, "a", "f16">,
WMMA_REGINFO<geom, "b", "f16">,
frag_c, frag_d, layout_a, layout_b, satf>;
} // satf
} // frag_d
} // frag_c
} // layout_b
} // layout_a
} // geom
Reference in New Issue
Block a user