[NVPTX] Make tensor shape part of WMMA intrinsic's name.

This is needed for the upcoming implementation of the
new 8x32x16 and 32x8x16 variants of WMMA instructions
introduced in CUDA 9.1.

Differential Revision: https://reviews.llvm.org/D44719

llvm-svn: 328158
This commit is contained in:
Artem Belevich 2018-03-21 21:55:02 +00:00
parent a9a301a7ba
commit f8d3ed33cf
4 changed files with 313 additions and 220 deletions

View File

@ -3884,39 +3884,53 @@ def int_nvvm_match_all_sync_i64p :
//
// WMMA.LOAD
class NVVM_WMMA_LD_ALSTS<string Abc, string Layout, string Type,
LLVMType regty, int WithStride>
class NVVM_WMMA_LD_GALSTS<string Geometry, string Abc, string Layout,
string Type, LLVMType regty, int WithStride>
: Intrinsic<!if(!eq(Abc#Type,"cf16"),
[regty, regty, regty, regty],
[regty, regty, regty, regty,
regty, regty, regty, regty]),
!if(WithStride, [llvm_anyptr_ty, llvm_i32_ty], [llvm_anyptr_ty]),
[IntrReadMem, IntrArgMemOnly, ReadOnly<0>, NoCapture<0>],
"llvm.nvvm.wmma.load."#Abc#".sync."#Layout#".m16n16k16"
#!if(WithStride,".stride","")
#"."#Type>;
"llvm.nvvm.wmma."
# Geometry
# ".load"
# "." # Abc
# "." # Layout
# !if(WithStride, ".stride", "")
# "." # Type>;
multiclass NVVM_WMMA_LD_ALT<string Abc, string Layout, string Type,
LLVMType regty> {
def _stride: NVVM_WMMA_LD_ALSTS<Abc, Layout, Type, regty, 1>;
def NAME : NVVM_WMMA_LD_ALSTS<Abc, Layout, Type, regty, 0>;
multiclass NVVM_WMMA_LD_GALT<string Geometry, string Abc, string Layout,
string Type, LLVMType regty> {
def _stride: NVVM_WMMA_LD_GALSTS<Geometry, Abc, Layout, Type, regty, 1>;
def NAME : NVVM_WMMA_LD_GALSTS<Geometry, Abc, Layout, Type, regty, 0>;
}
multiclass NVVM_WMMA_LD_AT<string Abc, string Type, LLVMType regty> {
defm _row: NVVM_WMMA_LD_ALT<Abc, "row", Type, regty>;
defm _col: NVVM_WMMA_LD_ALT<Abc, "col", Type, regty>;
multiclass NVVM_WMMA_LD_GAT<string Geometry, string Abc,
string Type, LLVMType regty> {
defm _row: NVVM_WMMA_LD_GALT<Geometry, Abc, "row", Type, regty>;
defm _col: NVVM_WMMA_LD_GALT<Geometry, Abc, "col", Type, regty>;
}
defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;
multiclass NVVM_WMMA_LD_G<string Geometry> {
defm _a_f16: NVVM_WMMA_LD_GAT<Geometry, "a", "f16", llvm_v2f16_ty>;
defm _b_f16: NVVM_WMMA_LD_GAT<Geometry, "b", "f16", llvm_v2f16_ty>;
defm _c_f16: NVVM_WMMA_LD_GAT<Geometry, "c", "f16", llvm_v2f16_ty>;
defm _c_f32: NVVM_WMMA_LD_GAT<Geometry, "c", "f32", llvm_float_ty>;
}
multiclass NVVM_WMMA_LD {
defm _m16n16k16_load: NVVM_WMMA_LD_G<"m16n16k16">;
}
defm int_nvvm_wmma: NVVM_WMMA_LD;
// WMMA.STORE.D
class NVVM_WMMA_STD_LSTS<string Layout, string Type, LLVMType regty, int WithStride,
// This is only used to create a typed empty array we
// need to pass to !if below.
list<LLVMType>Empty=[]>
class NVVM_WMMA_STD_GLSTS<string Geometry, string Layout,
string Type, LLVMType regty, int WithStride,
// This is only used to create a typed empty array we
// need to pass to !if below.
list<LLVMType>Empty=[]>
: Intrinsic<[],
!listconcat(
[llvm_anyptr_ty],
@ -3926,29 +3940,40 @@ class NVVM_WMMA_STD_LSTS<string Layout, string Type, LLVMType regty, int WithStr
regty, regty, regty, regty]),
!if(WithStride, [llvm_i32_ty], Empty)),
[IntrWriteMem, IntrArgMemOnly, WriteOnly<0>, NoCapture<0>],
"llvm.nvvm.wmma.store.d.sync."#Layout
#".m16n16k16"
#!if(WithStride,".stride","")
#"."#Type>;
"llvm.nvvm.wmma."
# Geometry
# ".store.d"
# "." # Layout
# !if(WithStride, ".stride", "")
# "." # Type>;
multiclass NVVM_WMMA_STD_LT<string Layout, string Type, LLVMType regty> {
def _stride: NVVM_WMMA_STD_LSTS<Layout, Type, regty, 1>;
def NAME: NVVM_WMMA_STD_LSTS<Layout, Type, regty, 0>;
multiclass NVVM_WMMA_STD_GLT<string Geometry, string Layout,
string Type, LLVMType regty> {
def _stride: NVVM_WMMA_STD_GLSTS<Geometry, Layout, Type, regty, 1>;
def NAME: NVVM_WMMA_STD_GLSTS<Geometry, Layout, Type, regty, 0>;
}
multiclass NVVM_WMMA_STD_T<string Type, LLVMType regty> {
defm _row: NVVM_WMMA_STD_LT<"row", Type, regty>;
defm _col: NVVM_WMMA_STD_LT<"col", Type, regty>;
multiclass NVVM_WMMA_STD_GT<string Geometry, string Type, LLVMType regty> {
defm _row: NVVM_WMMA_STD_GLT<Geometry, "row", Type, regty>;
defm _col: NVVM_WMMA_STD_GLT<Geometry, "col", Type, regty>;
}
multiclass NVVM_WMMA_STD_G<string Geometry> {
defm _d_f16: NVVM_WMMA_STD_GT<Geometry, "f16", llvm_v2f16_ty>;
defm _d_f32: NVVM_WMMA_STD_GT<Geometry, "f32", llvm_float_ty>;
}
defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>;
multiclass NVVM_WMMA_STD {
defm _m16n16k16_store: NVVM_WMMA_STD_G<"m16n16k16">;
}
defm int_nvvm_wmma: NVVM_WMMA_STD;
// WMMA.MMA
class NVVM_WMMA_MMA_ABDCS<string ALayout, string BLayout,
string DType, LLVMType d_regty,
string CType, LLVMType c_regty,
string Satfinite = "">
class NVVM_WMMA_MMA_GABDCS<string Geometry,
string ALayout, string BLayout,
string DType, LLVMType d_regty,
string CType, LLVMType c_regty,
string Satfinite = "">
: Intrinsic<!if(!eq(DType,"f16"),
[d_regty, d_regty, d_regty, d_regty],
[d_regty, d_regty, d_regty, d_regty,
@ -3965,39 +3990,52 @@ class NVVM_WMMA_MMA_ABDCS<string ALayout, string BLayout,
[c_regty, c_regty, c_regty, c_regty,
c_regty, c_regty, c_regty, c_regty])),
[IntrNoMem],
"llvm.nvvm.wmma.mma.sync."#ALayout#"."#BLayout
#".m16n16k16."#DType#"."#CType#Satfinite>;
multiclass NVVM_WMMA_MMA_ABDC<string ALayout, string BLayout,
string DType, LLVMType d_regty,
string CType, LLVMType c_regty> {
def NAME : NVVM_WMMA_MMA_ABDCS<ALayout, BLayout,
DType, d_regty,
CType, c_regty>;
def _satfinite: NVVM_WMMA_MMA_ABDCS<ALayout, BLayout,
DType, d_regty,
CType, c_regty,".satfinite">;
"llvm.nvvm.wmma."
# Geometry
# ".mma"
# "." # ALayout
# "." # BLayout
# "." # DType
# "." # CType
# Satfinite> {
}
multiclass NVVM_WMMA_MMA_ABD<string ALayout, string BLayout,
multiclass NVVM_WMMA_MMA_GABDC<string Geometry, string ALayout, string BLayout,
string DType, LLVMType d_regty,
string CType, LLVMType c_regty> {
def NAME : NVVM_WMMA_MMA_GABDCS<Geometry, ALayout, BLayout,
DType, d_regty, CType, c_regty>;
def _satfinite: NVVM_WMMA_MMA_GABDCS<Geometry, ALayout, BLayout,
DType, d_regty, CType, c_regty,".satfinite">;
}
multiclass NVVM_WMMA_MMA_GABD<string Geometry, string ALayout, string BLayout,
string DType, LLVMType d_regty> {
defm _f16: NVVM_WMMA_MMA_ABDC<ALayout, BLayout, DType, d_regty,
defm _f16: NVVM_WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_regty,
"f16", llvm_v2f16_ty>;
defm _f32: NVVM_WMMA_MMA_ABDC<ALayout, BLayout, DType, d_regty,
defm _f32: NVVM_WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_regty,
"f32", llvm_float_ty>;
}
multiclass NVVM_WMMA_MMA_AB<string ALayout, string BLayout> {
defm _f16: NVVM_WMMA_MMA_ABD<ALayout, BLayout, "f16", llvm_v2f16_ty>;
defm _f32: NVVM_WMMA_MMA_ABD<ALayout, BLayout, "f32", llvm_float_ty>;
multiclass NVVM_WMMA_MMA_GAB<string Geometry, string ALayout, string BLayout> {
defm _f16: NVVM_WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f16", llvm_v2f16_ty>;
defm _f32: NVVM_WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f32", llvm_float_ty>;
}
multiclass NVVM_WMMA_MMA_A<string ALayout> {
defm _col: NVVM_WMMA_MMA_AB<ALayout, "col">;
defm _row: NVVM_WMMA_MMA_AB<ALayout, "row">;
multiclass NVVM_WMMA_MMA_GA<string Geometry, string ALayout> {
defm _col: NVVM_WMMA_MMA_GAB<Geometry, ALayout, "col">;
defm _row: NVVM_WMMA_MMA_GAB<Geometry, ALayout, "row">;
}
defm int_nvvm_wmma_mma_sync_col: NVVM_WMMA_MMA_A<"col">;
defm int_nvvm_wmma_mma_sync_row: NVVM_WMMA_MMA_A<"row">;
multiclass NVVM_WMMA_MMA_G<string Geometry> {
defm _col: NVVM_WMMA_MMA_GA<Geometry, "col">;
defm _row: NVVM_WMMA_MMA_GA<Geometry, "row">;
}
multiclass NVVM_WMMA_MMA {
defm _m16n16k16_mma : NVVM_WMMA_MMA_G<"m16n16k16">;
}
defm int_nvvm_wmma : NVVM_WMMA_MMA;
} // let TargetPrefix = "nvvm"

View File

@ -3323,14 +3323,14 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
// Our result depends on both our and other thread's arguments.
Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
return true;
case Intrinsic::nvvm_wmma_load_a_f16_col:
case Intrinsic::nvvm_wmma_load_a_f16_row:
case Intrinsic::nvvm_wmma_load_a_f16_col_stride:
case Intrinsic::nvvm_wmma_load_a_f16_row_stride:
case Intrinsic::nvvm_wmma_load_b_f16_col:
case Intrinsic::nvvm_wmma_load_b_f16_row:
case Intrinsic::nvvm_wmma_load_b_f16_col_stride:
case Intrinsic::nvvm_wmma_load_b_f16_row_stride: {
case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col:
case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row:
case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride:
case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v8f16;
Info.ptrVal = I.getArgOperand(0);
@ -3340,10 +3340,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
return true;
}
case Intrinsic::nvvm_wmma_load_c_f16_col:
case Intrinsic::nvvm_wmma_load_c_f16_row:
case Intrinsic::nvvm_wmma_load_c_f16_col_stride:
case Intrinsic::nvvm_wmma_load_c_f16_row_stride: {
case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride:
case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v4f16;
Info.ptrVal = I.getArgOperand(0);
@ -3353,10 +3353,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
return true;
}
case Intrinsic::nvvm_wmma_load_c_f32_col:
case Intrinsic::nvvm_wmma_load_c_f32_row:
case Intrinsic::nvvm_wmma_load_c_f32_col_stride:
case Intrinsic::nvvm_wmma_load_c_f32_row_stride: {
case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col:
case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row:
case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride:
case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v8f32;
Info.ptrVal = I.getArgOperand(0);
@ -3366,10 +3366,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
return true;
}
case Intrinsic::nvvm_wmma_store_d_f16_col:
case Intrinsic::nvvm_wmma_store_d_f16_row:
case Intrinsic::nvvm_wmma_store_d_f16_col_stride:
case Intrinsic::nvvm_wmma_store_d_f16_row_stride: {
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride: {
Info.opc = ISD::INTRINSIC_VOID;
Info.memVT = MVT::v4f16;
Info.ptrVal = I.getArgOperand(0);
@ -3379,10 +3379,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
return true;
}
case Intrinsic::nvvm_wmma_store_d_f32_col:
case Intrinsic::nvvm_wmma_store_d_f32_row:
case Intrinsic::nvvm_wmma_store_d_f32_col_stride:
case Intrinsic::nvvm_wmma_store_d_f32_row_stride: {
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col:
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row:
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride:
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride: {
Info.opc = ISD::INTRINSIC_VOID;
Info.memVT = MVT::v8f32;
Info.ptrVal = I.getArgOperand(0);

View File

@ -7375,16 +7375,15 @@ def INT_PTX_SREG_WARPSIZE :
class EmptyNVPTXInst : NVPTXInst<(outs), (ins), "?", []>;
class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space,
string Type, NVPTXRegClass regclass,
DAGOperand SrcOp, bit WithStride>
class WMMA_LOAD_GALSTOS<string Geometry, string Abc, string Layout,
string Space, string Type, NVPTXRegClass regclass,
DAGOperand SrcOp, bit WithStride>
: EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
// Pattern (created by WMMA_LOAD_INTR_HELPER below) that matches the intrinsic
// for this function.
PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_LOAD_"
# !subst("a", "A",
!subst("b", "B",
!subst("c", "C_" # Type, Abc)))
PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_"
# Geometry # "_load_"
# !subst("c", "c_" # Type, Abc)
# "_" # Layout
# !subst(".", "_", Space)
# !if(WithStride,"_stride", "")
@ -7419,23 +7418,28 @@ class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space,
let Pattern = [!con(PatOuts, (set PatArgs))];
let OutOperandList = Outs;
let InOperandList = Ins;
let AsmString = "wmma.load."#Abc#".sync."#Layout#".m16n16k16"#Space#"." #Type# " \t"
#!if(!eq(Abc#Type,"cf16"),
"{{$r0, $r1, $r2, $r3}}",
"{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
#", [$src]"
#!if(WithStride, ", $ldm", "")
#";";
let AsmString = "wmma.load."
# Abc
# ".sync."
# Layout
# ".m16n16k16"
# Space
# "." # Type # " \t"
# !if(!eq(Abc#Type, "cf16"),
"{{$r0, $r1, $r2, $r3}}",
"{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
# ", [$src]"
# !if(WithStride, ", $ldm", "")
# ";";
}
class WMMA_LOAD_INTR_HELPER<string Abc, string Layout, string Space,
string Type, bit WithStride>
class WMMA_LOAD_INTR_HELPER<string Geometry, string Abc, string Layout,
string Space, string Type, bit WithStride>
: PatFrag <(ops),(ops)> {
// Intrinsic that matches this instruction.
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_load_"
# Abc
# "_" # Type
# "_" # Layout
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma"
# "_" # Geometry # "_load_"
# Abc # "_" # Type # "_" # Layout
# !if(WithStride,"_stride", ""));
code match_generic = [{
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
@ -7453,62 +7457,81 @@ class WMMA_LOAD_INTR_HELPER<string Abc, string Layout, string Space,
!if(!eq(Space, ".global"), match_global, match_generic));
}
multiclass WMMA_LOAD_ALSTS<string Abc, string Layout, string Space,
string Type, NVPTXRegClass regclass, bit WithStride> {
def _avar: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, imem, WithStride>;
def _areg: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, Int32Regs, WithStride>;
def _areg64: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, Int64Regs, WithStride>;
def _ari: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, MEMri, WithStride>;
def _ari64: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, MEMri64, WithStride>;
multiclass WMMA_LOAD_GALSTS<string Geometry, string Abc, string Layout,
string Space, string Type, NVPTXRegClass regclass,
bit WithStride> {
def _avar: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
imem, WithStride>;
def _areg: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
Int32Regs, WithStride>;
def _areg64: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
Int64Regs, WithStride>;
def _ari: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
MEMri, WithStride>;
def _ari64: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
MEMri64, WithStride>;
}
multiclass WMMA_LOAD_ALSTSh<string Abc, string Layout, string Space,
string Type, NVPTXRegClass regclass, bit WithStride> {
multiclass WMMA_LOAD_GALSTSh<string Geometry, string Abc, string Layout,
string Space, string Type, NVPTXRegClass regclass,
bit WithStride> {
// Define a PatFrag that matches appropriate intrinsic that loads from the
// given address space.
def _Intr : WMMA_LOAD_INTR_HELPER<Abc, Layout, Space, Type, WithStride>;
defm NAME: WMMA_LOAD_ALSTS<Abc, Layout, Space, Type, regclass, WithStride>;
def _Intr: WMMA_LOAD_INTR_HELPER<Geometry, Abc, Layout, Space, Type,
WithStride>;
defm NAME: WMMA_LOAD_GALSTS<Geometry, Abc, Layout, Space, Type, regclass,
WithStride>;
}
multiclass WMMA_LOAD_ALST<string Abc, string Layout, string Space,
string Type, NVPTXRegClass regclass> {
defm _stride: WMMA_LOAD_ALSTSh<Abc, Layout, Space, Type, regclass, 1>;
defm NAME: WMMA_LOAD_ALSTSh<Abc, Layout, Space, Type, regclass, 0>;
multiclass WMMA_LOAD_GALST<string Geometry, string Abc, string Layout,
string Space, string Type, NVPTXRegClass regclass> {
defm _stride: WMMA_LOAD_GALSTSh<Geometry, Abc, Layout, Space, Type, regclass, 1>;
defm NAME: WMMA_LOAD_GALSTSh<Geometry, Abc, Layout, Space, Type, regclass, 0>;
}
multiclass WMMA_LOAD_ALT<string Abc, string Layout,
string Type, NVPTXRegClass regclass> {
defm _global: WMMA_LOAD_ALST<Abc, Layout, ".global", Type, regclass>;
defm _shared: WMMA_LOAD_ALST<Abc, Layout, ".shared", Type, regclass>;
defm NAME: WMMA_LOAD_ALST<Abc, Layout, "", Type, regclass>;
multiclass WMMA_LOAD_GALT<string Geometry, string Abc, string Layout,
string Type, NVPTXRegClass regclass> {
defm _global: WMMA_LOAD_GALST<Geometry, Abc, Layout, ".global",
Type, regclass>;
defm _shared: WMMA_LOAD_GALST<Geometry, Abc, Layout, ".shared",
Type, regclass>;
defm NAME: WMMA_LOAD_GALST<Geometry, Abc, Layout, "",
Type, regclass>;
}
multiclass WMMA_LOAD_AT<string Abc, string Type, NVPTXRegClass regclass> {
defm _row: WMMA_LOAD_ALT<Abc, "row", Type, regclass>;
defm _col: WMMA_LOAD_ALT<Abc, "col", Type, regclass>;
multiclass WMMA_LOAD_GAT<string Geometry, string Abc,
string Type, NVPTXRegClass regclass> {
defm _row: WMMA_LOAD_GALT<Geometry, Abc, "row", Type, regclass>;
defm _col: WMMA_LOAD_GALT<Geometry, Abc, "col", Type, regclass>;
}
defm INT_WMMA_LOAD_A: WMMA_LOAD_AT<"a", "f16", Float16x2Regs>;
defm INT_WMMA_LOAD_B: WMMA_LOAD_AT<"b", "f16", Float16x2Regs>;
defm INT_WMMA_LOAD_C_f16: WMMA_LOAD_AT<"c", "f16", Float16x2Regs>;
defm INT_WMMA_LOAD_C_f32: WMMA_LOAD_AT<"c", "f32", Float32Regs>;
multiclass WMMA_LOAD_G<string Geometry> {
defm _load_a: WMMA_LOAD_GAT<Geometry, "a", "f16", Float16x2Regs>;
defm _load_b: WMMA_LOAD_GAT<Geometry, "b", "f16", Float16x2Regs>;
defm _load_c_f16: WMMA_LOAD_GAT<Geometry, "c", "f16", Float16x2Regs>;
defm _load_c_f32: WMMA_LOAD_GAT<Geometry, "c", "f32", Float32Regs>;
}
defm INT_WMMA_m16n16k16: WMMA_LOAD_G<"m16n16k16">;
//
// wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
//
class WMMA_STORE_D_LSTSO<string Layout, string Space,
string Type, NVPTXRegClass regclass,
bit WithStride, DAGOperand DstOp>
class WMMA_STORE_D_GLSTSO<string Geometry, string Layout, string Space,
string Type, NVPTXRegClass regclass,
bit WithStride, DAGOperand DstOp>
: EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_STORE_D"
PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA"
# "_" # Geometry # "_store_d"
# "_" # Type
# "_" # Layout
# !subst(".", "_", Space)
# !if(WithStride,"_stride", "")
# "_Intr");
dag InsR03 = (ins DstOp:$src, regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3);
dag InsR47 = (ins regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7);
dag InsR03 = (ins DstOp:$src, regclass:$r0, regclass:$r1,
regclass:$r2, regclass:$r3);
dag InsR47 = (ins regclass:$r4, regclass:$r5,
regclass:$r6, regclass:$r7);
dag InsR = !if(!eq(Type,"f16"), InsR03, !con(InsR03, InsR47));
dag StrideArg = !if(WithStride, (ins Int32Regs:$ldm), (ins));
dag Ins = !con(InsR, StrideArg);
@ -7525,7 +7548,7 @@ class WMMA_STORE_D_LSTSO<string Layout, string Space,
let InOperandList = Ins;
let AsmString = "wmma.store.d.sync."
# Layout
# ".m16n16k16"
# "." # Geometry
# Space
# "." # Type
# " \t[$src],"
@ -7537,11 +7560,13 @@ class WMMA_STORE_D_LSTSO<string Layout, string Space,
}
class WMMA_STORE_INTR_HELPER<string Layout, string Space,
class WMMA_STORE_INTR_HELPER<string Geometry, string Layout, string Space,
string Type, bit WithStride>
: PatFrag <(ops),(ops)> {
// Intrinsic that matches this instruction.
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_store_d"
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_"
# Geometry
# "_store_d"
# "_" # Type
# "_" # Layout
# !if(WithStride, "_stride", ""));
@ -7566,57 +7591,77 @@ class WMMA_STORE_INTR_HELPER<string Layout, string Space,
!if(!eq(Space, ".global"), match_global, match_generic));
}
multiclass WMMA_STORE_D_LSTS<string Layout, string Space,
string Type, NVPTXRegClass regclass, bit WithStride> {
def _avar: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, imem>;
def _areg: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, Int32Regs>;
def _areg64: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, Int64Regs>;
def _ari: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, MEMri>;
def _ari64: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, MEMri64>;
multiclass WMMA_STORE_D_GLSTS<string Geometry, string Layout, string Space,
string Type, NVPTXRegClass regclass,
bit WithStride> {
def _avar: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
WithStride, imem>;
def _areg: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
WithStride, Int32Regs>;
def _areg64: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
WithStride, Int64Regs>;
def _ari: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
WithStride, MEMri>;
def _ari64: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
WithStride, MEMri64>;
}
multiclass WMMA_STORE_D_LSTSh<string Layout, string Space,
string Type, NVPTXRegClass regclass, bit WithStride> {
multiclass WMMA_STORE_D_GLSTSh<string Geometry, string Layout, string Space,
string Type, NVPTXRegClass regclass,
bit WithStride> {
// Define a PatFrag that matches appropriate intrinsic that loads from the
// given address space.
def _Intr: WMMA_STORE_INTR_HELPER<Layout, Space, Type, WithStride>;
defm NAME: WMMA_STORE_D_LSTS<Layout, Space, Type, regclass, WithStride>;
def _Intr: WMMA_STORE_INTR_HELPER<Geometry, Layout, Space, Type,
WithStride>;
defm NAME: WMMA_STORE_D_GLSTS<Geometry, Layout, Space, Type, regclass,
WithStride>;
}
multiclass WMMA_STORE_D_LST<string Layout, string Space,
multiclass WMMA_STORE_D_GLST<string Geometry, string Layout, string Space,
string Type, NVPTXRegClass regclass > {
defm _stride: WMMA_STORE_D_LSTSh<Layout, Space, Type, regclass, 1>;
defm NAME: WMMA_STORE_D_LSTSh<Layout, Space, Type, regclass, 0>;
defm _stride: WMMA_STORE_D_GLSTSh<Geometry, Layout, Space, Type, regclass, 1>;
defm NAME: WMMA_STORE_D_GLSTSh<Geometry, Layout, Space, Type, regclass, 0>;
}
multiclass WMMA_STORE_D_LT<string Layout,
multiclass WMMA_STORE_D_GLT<string Geometry, string Layout,
string Type, NVPTXRegClass regclass> {
defm _global: WMMA_STORE_D_LST<Layout, ".global", Type, regclass>;
defm _shared: WMMA_STORE_D_LST<Layout, ".shared", Type, regclass>;
defm NAME: WMMA_STORE_D_LST<Layout, "", Type, regclass>;
defm _global: WMMA_STORE_D_GLST<Geometry, Layout, ".global", Type, regclass>;
defm _shared: WMMA_STORE_D_GLST<Geometry, Layout, ".shared", Type, regclass>;
defm NAME: WMMA_STORE_D_GLST<Geometry, Layout, "", Type, regclass>;
}
multiclass WMMA_STORE_D_T<string Type, NVPTXRegClass regclass> {
defm _row: WMMA_STORE_D_LT<"row", Type, regclass>;
defm _col: WMMA_STORE_D_LT<"col", Type, regclass>;
multiclass WMMA_STORE_D_GT<string Geometry, string Type,
NVPTXRegClass regclass> {
defm _row: WMMA_STORE_D_GLT<Geometry, "row", Type, regclass>;
defm _col: WMMA_STORE_D_GLT<Geometry, "col", Type, regclass>;
}
defm INT_WMMA_STORE_D_f16: WMMA_STORE_D_T<"f16", Float16x2Regs>;
defm INT_WMMA_STORE_D_f32: WMMA_STORE_D_T<"f32", Float32Regs>;
multiclass WMMA_STORE_D_G<string Geometry> {
defm _store_d_f16: WMMA_STORE_D_GT<Geometry, "f16", Float16x2Regs>;
defm _store_d_f32: WMMA_STORE_D_GT<Geometry, "f32", Float32Regs>;
}
// multiclass WMMA_STORE_D {
// defm _m16n16k16: WMMA_STORE_D_G<"m16n16k16">;
// }
defm INT_WMMA_m16n16k16: WMMA_STORE_D_G<"m16n16k16">;
// WMMA.MMA
class WMMA_MMA_ABDCS<string ALayout, string BLayout,
class WMMA_MMA_GABDCS<string Geometry, string ALayout, string BLayout,
string DType, NVPTXRegClass d_reg,
string CType, NVPTXRegClass c_reg,
NVPTXRegClass ab_reg,
string Satfinite = "">
: EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_mma_sync_"
# ALayout
Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_"
# Geometry
# "_mma"
# "_" # ALayout
# "_" # BLayout
# "_" # DType
# "_" # CType
# !subst(".","_",Satfinite));
# !subst(".", "_", Satfinite));
dag Outs = !if(!eq(DType,"f16"),
(outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3),
(outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3,
@ -7655,33 +7700,38 @@ class WMMA_MMA_ABDCS<string ALayout, string BLayout,
"{{$c0, $c1, $c2, $c3, $c4, $c5, $c6, $c7}};");
}
multiclass WMMA_MMA_ABDC<string ALayout, string BLayout,
multiclass WMMA_MMA_GABDC<string Geometry, string ALayout, string BLayout,
string DType, NVPTXRegClass d_reg,
string CType, NVPTXRegClass c_reg> {
def _satfinite: WMMA_MMA_ABDCS<ALayout, BLayout,
def _satfinite: WMMA_MMA_GABDCS<Geometry, ALayout, BLayout,
DType, d_reg, CType, c_reg,
Float16x2Regs, ".satfinite">;
def NAME: WMMA_MMA_ABDCS<ALayout, BLayout,
def NAME: WMMA_MMA_GABDCS<Geometry, ALayout, BLayout,
DType, d_reg, CType, c_reg,
Float16x2Regs>;
}
multiclass WMMA_MMA_ABD<string ALayout, string BLayout,
multiclass WMMA_MMA_GABD<string Geometry, string ALayout, string BLayout,
string DType, NVPTXRegClass d_reg> {
defm _f16: WMMA_MMA_ABDC<ALayout, BLayout, DType, d_reg, "f16", Float16x2Regs>;
defm _f32: WMMA_MMA_ABDC<ALayout, BLayout, DType, d_reg, "f32", Float32Regs>;
defm _f16: WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_reg,
"f16", Float16x2Regs>;
defm _f32: WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_reg,
"f32", Float32Regs>;
}
multiclass WMMA_MMA_AB<string ALayout, string BLayout> {
defm _f16: WMMA_MMA_ABD<ALayout, BLayout, "f16", Float16x2Regs>;
defm _f32: WMMA_MMA_ABD<ALayout, BLayout, "f32", Float32Regs>;
multiclass WMMA_MMA_GAB<string Geometry, string ALayout, string BLayout> {
defm _f16: WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f16", Float16x2Regs>;
defm _f32: WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f32", Float32Regs>;
}
multiclass WMMA_MMA_A<string ALayout> {
defm _col: WMMA_MMA_AB<ALayout, "col">;
defm _row: WMMA_MMA_AB<ALayout, "row">;
multiclass WMMA_MMA_GA<string Geometry, string ALayout> {
defm _col: WMMA_MMA_GAB<Geometry, ALayout, "col">;
defm _row: WMMA_MMA_GAB<Geometry, ALayout, "row">;
}
defm INT_WMMA_MMA_col: WMMA_MMA_A<"col">;
defm INT_WMMA_MMA_row: WMMA_MMA_A<"row">;
multiclass WMMA_MMA_G<string Geometry> {
defm _col: WMMA_MMA_GA<Geometry, "col">;
defm _row: WMMA_MMA_GA<Geometry, "row">;
}
defm INT_WMMA_MMA_m16n16k16 : WMMA_MMA_G<"m16n16k16">;

View File

@ -38,29 +38,29 @@ check_f32_8 = "{{%s}}" % ", *".join(["%f[0-9]+"] * 8)
def gen_wmma_load_tests():
load_template = """
declare ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src ${extra_args});
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
; CHECK-LABEL: .func {{.*}}test_wmma_load_${function_suffix}(
define ${ret_ty} @test_wmma_load_${function_suffix}(i8 ${as}* %src ${extra_args}) {
; CHECK wmma.load.${intrinsic_suffix}
; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
; CHECK ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
%v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src ${extra_args});
%v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
ret ${ret_ty} %v0;
}
; CHECK-LABEL: .func{{.*}}test_wmma_load_${function_suffix}_o(
define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8 ${as}* %src ${extra_args}) {
; CHECK wmma.load.${intrinsic_suffix}
; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
; CHECK ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
%src1 = getelementptr i8, i8 ${as}* %src, i32 128;
%v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src1 ${extra_args});
%v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args});
ret ${ret_ty} %v0;
}
"""
suffix_template = "${abc}.sync.${layout}.m16n16k16${stride}.${itype}.${pspace}"
instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"
intrinsic_template = "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
instruction_template = "wmma.load.${abc}.sync.${geom}.${layout}${space}.${itype}"
for abc, layout, space, stride, itype in product(
"abc",
@ -76,16 +76,17 @@ define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8 ${as}* %src ${extra_arg
"stride" : stride,
"itype" : itype,
"pspace" : get_pspace(space),
"as" : "addrspace(%d)" % get_aspace(space)
"as" : "addrspace(%d)" % get_aspace(space),
"geom" : "m16n16k16",
}
if itype == "f32" and abc != "c":
continue
test_params = params
test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params)
test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".","_")
test_params["instruction_suffix"] = Template(instruction_template).substitute(params)
test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
test_params["function"] = test_params["intrinsic"].replace(".","_")
test_params["instruction"] = Template(instruction_template).substitute(params)
test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
if abc == "c" :
test_params["check_result"] = check_f16_4 if itype == "f16" else check_f32_8
@ -107,29 +108,29 @@ def make_wmma_slice_args(itype, abcd, prefix="v"):
def gen_wmma_store_tests():
store_template = """
declare void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src, ${args}${extra_args});
declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});
; CHECK-LABEL: .func {{.*}}test_wmma_store_${function_suffix}(
define void @test_wmma_store_${function_suffix}(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK wmma.store.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}
; CHECK-LABEL: .func {{.*}}test_${function}(
define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK ${instruction} {{.*}}[%rd{{[0-9+]}}
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src, ${args} ${extra_args});
call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args});
ret void
}
; CHECK-LABEL: .func{{.*}}test_wmma_store_${function_suffix}_o(
define void @test_wmma_store_${function_suffix}_o(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK wmma.store.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}+128]
; CHECK-LABEL: .func{{.*}}test_${function}_o(
define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK ${instruction} {{.*}}[%rd{{[0-9+]}}+128]
; CHECK: ${check_args}
; CHECK: ${stride_pattern}
%src1 = getelementptr i8, i8 ${as}* %src, i32 128;
call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src1, ${args}${extra_args});
call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args});
ret void
}
"""
suffix_template = "${abc}.sync.${layout}.m16n16k16${stride}.${itype}.${pspace}"
instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"
intrinsic_template = "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
instruction_template = "wmma.store.${abc}.sync.${geom}.${layout}${space}.${itype}"
for abc, layout, space, stride, itype in product(
"d",
@ -145,13 +146,14 @@ define void @test_wmma_store_${function_suffix}_o(i8 ${as}* %src, ${args}${extra
"stride" : stride,
"itype" : itype,
"pspace" : get_pspace(space),
"as" : "addrspace(%d)" % get_aspace(space)
"as" : "addrspace(%d)" % get_aspace(space),
"geom" : "m16n16k16",
}
test_params = params
test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params)
test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".","_")
test_params["instruction_suffix"] = Template(instruction_template).substitute(params)
test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
test_params["function"] = test_params["intrinsic"].replace(".","_")
test_params["instruction"] = Template(instruction_template).substitute(params)
test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
test_params["check_args"] = check_f16_4 if itype == "f16" else check_f32_8
if stride:
@ -166,23 +168,24 @@ define void @test_wmma_store_${function_suffix}_o(i8 ${as}* %src, ${args}${extra
def gen_wmma_mma_tests():
mma_template = """
declare ${ret_ty} @llvm.nvvm.wmma.mma.sync.$intrinsic_suffix(
declare ${ret_ty} @${intrinsic}(
${args});
; CHECK-LABEL: .func {{.*}}test_wmma_mma_${function_suffix}(
define ${ret_ty} @test_wmma_mma_${function_suffix}(
; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(
${args}) {
; CHECK wmma.mma.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}
; CHECK ${instruction} {{.*}}[%rd{{[0-9+]}}
; CHECK ${check_d}
; CHECK ${check_ab}
; CHECK ${check_ab}
; CHECK ${check_c}
%r = call ${ret_ty} @llvm.nvvm.wmma.mma.sync.${intrinsic_suffix}(
%r = call ${ret_ty} @${intrinsic}(
${args});
ret ${ret_ty} %r;
}
"""
suffix_template = "${alayout}.${blayout}.m16n16k16.${dtype}.${ctype}${satf}"
intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${dtype}.${ctype}${satf}"
instruction_template = "wmma.mma.sync.${alayout}.${blayout}.${geom}.${dtype}.${ctype}${satf}"
for alayout, blayout, ctype, dtype, satf in product(
["row","col"],
@ -196,12 +199,14 @@ define ${ret_ty} @test_wmma_mma_${function_suffix}(
"blayout" : blayout,
"ctype" : ctype,
"dtype" : dtype,
"satf" : satf
"satf" : satf,
"geom" : "m16n16k16",
}
test_params = params
test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params)
test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".", "_")
test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
test_params["function"] = test_params["intrinsic"].replace(".", "_")
test_params["instruction"] = Template(instruction_template).substitute(params)
test_params["ret_ty"] = make_wmma_ld_ret_ty("d", dtype)
test_params["check_ab"] = check_f16_8
test_params["check_c"] = check_f16_4 if ctype == "f16" else check_f32_8