mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-02-10 11:23:52 +00:00
[mlir][NVVM] Add ldmatrix op to NVVM dialect
Differential Revision: https://reviews.llvm.org/D121347
This commit is contained in:
parent
c7f25b6fd4
commit
2f33f11428
@ -634,4 +634,48 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
|
||||
Results<(outs AnyType:$res)>,
|
||||
Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> {
|
||||
|
||||
let summary = "cooperative matrix load";
|
||||
|
||||
string llvmBuilder = [{
|
||||
auto operands = moduleTranslation.lookupValues(opInst.getOperands());
|
||||
auto intId = getLdMatrixIntrinsicId($layout, $num);
|
||||
$res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()});
|
||||
}];
|
||||
|
||||
string baseDescription = [{
|
||||
The `nvvm.ldmatrix` operation collectively loads one or more matrices across
|
||||
all threads in a warp from the location indicated by the address operand
|
||||
`ptr` from shared memory.
|
||||
|
||||
The attribute `num` indicates how many 8x8 16-bit matrices are to be loaded.
|
||||
|
||||
All the threads in the warp must execute the same ldmatrix operations.
|
||||
|
||||
Each row of 8 elements needs to be consecutive in memory. Each lane of the
|
||||
warp contains the start address of a row of 8 elements laid out as below:
|
||||
|
||||
```
|
||||
num | Threads 0--7 | Threads 8--15 | Threads 16--31
|
||||
1 | addr0--addr7 | |
|
||||
2 | addr0--addr7 | addr8--addr15 |
|
||||
4 | addr0--addr7 | addr8--addr15 | addr16--addr31
|
||||
```
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
%l1 = nvvm.ldmatrix %ptr {num = 1 : i32, layout = #nvvm.mma_layout<row>} :
|
||||
(!llvm.ptr<i32, 3>) -> i32
|
||||
%l2 = nvvm.ldmatrix %ptr {num = 4 : i32, layout = #nvvm.mma_layout<row>} :
|
||||
(!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
|
||||
```
|
||||
}];
|
||||
|
||||
let assemblyFormat = "$ptr attr-dict `:` functional-type($ptr, $res)";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
#endif // NVVMIR_OPS
|
||||
|
@ -219,6 +219,28 @@ LogicalResult NVVM::WMMAMmaOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Verifies the constraints of `nvvm.ldmatrix`, in this order:
///  1. the source pointer must live in address space 3;
///  2. the `num` attribute must be 1, 2 or 4;
///  3. the result type must be `i32` when `num` is 1, and a literal LLVM
///     struct made of `num` `i32` members when `num` is 2 or 4.
/// Emits an op error and fails on the first violated constraint.
LogicalResult NVVM::LdMatrixOp::verify() {
  auto srcPtrType = ptr().getType().cast<LLVM::LLVMPointerType>();
  if (srcPtrType.getAddressSpace() != 3)
    return emitOpError("expected source pointer in memory space 3");

  const auto matrixCount = num();
  const bool isSingle = matrixCount == 1;
  if (!isSingle && matrixCount != 2 && matrixCount != 4)
    return emitOpError("expected num attribute to be 1, 2 or 4");

  // Build the one result type that is legal for this matrix count and compare
  // it against the actual result type.
  Type elemType = IntegerType::get(getContext(), 32);
  Type expectedType = elemType;
  if (!isSingle)
    expectedType = LLVM::LLVMStructType::getLiteral(
        getContext(), SmallVector<Type>(matrixCount, elemType));
  if (getType() == expectedType)
    return success();

  if (isSingle)
    return emitOpError("expected destination type is i32");
  return emitOpError("expected destination type is a structure of ")
         << matrixCount << " elements of type i32";
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NVVMDialect initialization, type parsing, and registration.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -64,6 +64,35 @@ static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
|
||||
llvm_unreachable("unknown shuffle kind");
|
||||
}
|
||||
|
||||
/// Return the intrinsic ID associated with ldmatrix for the given paramters.
|
||||
static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
|
||||
int32_t num) {
|
||||
if (layout == NVVM::MMALayout::col) {
|
||||
switch (num) {
|
||||
case 1:
|
||||
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
|
||||
case 2:
|
||||
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
|
||||
case 4:
|
||||
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
|
||||
default:
|
||||
llvm_unreachable("unsupported number of matrix");
|
||||
}
|
||||
|
||||
} else {
|
||||
switch (num) {
|
||||
case 1:
|
||||
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
|
||||
case 2:
|
||||
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
|
||||
case 4:
|
||||
return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
|
||||
default:
|
||||
llvm_unreachable("unsupported number of matrix");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Implementation of the dialect interface that converts operations belonging
|
||||
/// to the NVVM dialect to LLVM IR.
|
||||
|
@ -1191,6 +1191,38 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @wmmald_matrix(%arg0: !llvm.ptr<i32>) {
|
||||
// expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}}
|
||||
%l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32>) -> i32
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @wmmald_matrix(%arg0: !llvm.ptr<i32, 3>) {
|
||||
// expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}}
|
||||
%l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> i32
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @wmmald_matrix(%arg0: !llvm.ptr<i32, 3>) {
|
||||
// expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}}
|
||||
%l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32)>
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @wmmald_matrix(%arg0: !llvm.ptr<i32, 3>) {
|
||||
// expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}}
|
||||
%l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32)>
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @caller() {
|
||||
// expected-error @below {{expected function call to produce a value}}
|
||||
llvm.call @callee() : () -> ()
|
||||
|
@ -105,6 +105,16 @@ llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: llvm.func @ld_matrix
|
||||
llvm.func @ld_matrix(%arg0: !llvm.ptr<i32, 3>) {
|
||||
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} : (!llvm.ptr<i32, 3>) -> i32
|
||||
%l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> i32
|
||||
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 2 : i32} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32)>
|
||||
%l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32)>
|
||||
// CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
|
||||
%l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
|
||||
llvm.return
|
||||
}
|
||||
// -----
|
||||
|
||||
// expected-error@below {{attribute attached to unexpected op}}
|
||||
|
@ -176,6 +176,17 @@ llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @ld_matrix(
|
||||
llvm.func @ld_matrix(%arg0: !llvm.ptr<i32, 3>) {
|
||||
// CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3i32(i32 addrspace(3)* %{{.*}})
|
||||
%l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> i32
|
||||
// CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3i32(i32 addrspace(3)* %{{.*}})
|
||||
%l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32)>
|
||||
// CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3i32(i32 addrspace(3)* %{{.*}})
|
||||
%l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// This function has the "kernel" attribute attached and should appear in the
|
||||
// NVVM annotations after conversion.
|
||||
llvm.func @kernel_func() attributes {nvvm.kernel} {
|
||||
|
Loading…
x
Reference in New Issue
Block a user