mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-23 22:00:10 +00:00
[mlir][nvvm] Fix support for tf32 data type in mma.sync
The NVVM dialect test coverage for all possible type/shape combinations in the `nvvm.mma.sync` op is mostly complete. However, there were tests missing for TF32 datatype support. This change adds tests for the one relevant shape/type combination. This uncovered a small bug in the op verifier, which this change also fixes. Differential Revision: https://reviews.llvm.org/D124975
This commit is contained in:
parent
6385c039b8
commit
22c6e7b277
@ -81,8 +81,10 @@ Optional<mlir::NVVM::MMATypes> MmaOp::inferOperandMMAType(Type operandElType,
|
||||
return NVVM::MMATypes::f64;
|
||||
if (operandElType.isF16() || operandElType == half2Type)
|
||||
return NVVM::MMATypes::f16;
|
||||
if (operandElType.isF32())
|
||||
if (operandElType.isF32() && isAccumulator)
|
||||
return NVVM::MMATypes::f32;
|
||||
if (operandElType.isF32() && !isAccumulator)
|
||||
return NVVM::MMATypes::tf32;
|
||||
if (operandElType.isa<IntegerType>()) {
|
||||
if (isAccumulator)
|
||||
return NVVM::MMATypes::s32;
|
||||
@ -291,7 +293,7 @@ ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
parser.getNameLoc(),
|
||||
"expected one type for each operand segment but got " +
|
||||
Twine(operandTypes.size()) + " types");
|
||||
for (const auto& iter : llvm::enumerate(operandTypes)) {
|
||||
for (const auto &iter : llvm::enumerate(operandTypes)) {
|
||||
auto &frag = frags[iter.index()];
|
||||
frag.regTypes.resize(frag.regs.size(), iter.value());
|
||||
if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
|
||||
@ -376,8 +378,9 @@ LogicalResult MmaOp::verify() {
|
||||
switch (multiplicandAPtxType().getValue()) {
|
||||
case MMATypes::tf32:
|
||||
kFactor = 4;
|
||||
multiplicandFragType = i32Ty;
|
||||
expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
|
||||
context, {i32Ty, i32Ty, i32Ty, i32Ty}));
|
||||
context, {f32Ty, f32Ty, f32Ty, f32Ty}));
|
||||
break;
|
||||
case MMATypes::f16:
|
||||
case MMATypes::bf16:
|
||||
|
@ -152,6 +152,17 @@ func.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
|
||||
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
|
||||
}
|
||||
|
||||
func.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32,
|
||||
%b0 : i32,
|
||||
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
|
||||
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
|
||||
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
|
||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
|
||||
multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>,
|
||||
shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
|
||||
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
|
||||
}
|
||||
|
||||
func.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32,
|
||||
%c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
|
||||
// CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
|
||||
|
@ -203,6 +203,17 @@ llvm.func @nvvm_mma_m8n8k4_f64_f64(%a0 : f64,
|
||||
llvm.return %0 : !llvm.struct<(f64, f64)>
|
||||
}
|
||||
|
||||
llvm.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32,
|
||||
%b0 : i32,
|
||||
%c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
|
||||
// CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32
|
||||
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
|
||||
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
|
||||
multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>,
|
||||
shape = {m = 16 : i32, n = 8 : i32, k = 4 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
|
||||
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
|
||||
}
|
||||
|
||||
// The test below checks the correct mapping of the nvvm.wmma.*.load.* op to the correct intrinsic
|
||||
// in the LLVM NVPTX backend.
|
||||
// CHECK-LABEL: @gpu_wmma_load_op
|
||||
|
Loading…
Reference in New Issue
Block a user