mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-24 06:10:12 +00:00
[mlir][NVGPU] Handle native mma.sync and ldmatrix(x4) sizes
This patch handles native `mma.sync` sizes and enables issuing `ldmatrix` on largest possible tiles for matrixB. It requires handling `vector.extract_strided_slice` from vector to ngpu lowering. Differential Revision: https://reviews.llvm.org/D135749
This commit is contained in:
parent
97196a2d92
commit
114ba722c1
@ -13,25 +13,22 @@
|
||||
#ifndef MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
|
||||
#define MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace vector {
|
||||
enum class IteratorType : uint32_t;
|
||||
class ContractionOp;
|
||||
} // namespace vector
|
||||
|
||||
namespace NVVM {
|
||||
enum class MMALayout : uint32_t;
|
||||
} // namespace NVVM
|
||||
|
||||
namespace nvgpu {
|
||||
|
||||
/// Represents the role of an operand in an MMA instruction:
|
||||
/// `result := matmul(A, B) + C`
|
||||
enum class MatMulOperandRole : int32_t { A = 0, B, C };
|
||||
|
||||
/// Returns the first user of the `op` that is vector.contract. If no
|
||||
/// vector.contract user exists, return failure.
|
||||
FailureOr<vector::ContractionOp> getUserContract(Operation *op);
|
||||
|
||||
/// Collects information about a warp-level matrix operand represented by a
|
||||
/// VectorType.
|
||||
struct WarpMatrixInfo {
|
||||
|
@ -192,6 +192,33 @@ static bool elementwiseSupportsMMAMatrixType(Operation *op) {
|
||||
return convertElementwiseOpToMMA(op).has_value();
|
||||
}
|
||||
|
||||
/// Returns true if the extract strided slice op is supported with `mma.sync`
|
||||
/// path.
|
||||
static bool
|
||||
extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
|
||||
|
||||
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
|
||||
nvgpu::getWarpMatrixInfo(op);
|
||||
if (failed(warpMatrixInfo))
|
||||
return false;
|
||||
|
||||
FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
|
||||
if (failed(contractOp))
|
||||
return false;
|
||||
|
||||
// Handle vector.extract_strided_slice on registers containing
|
||||
// matrixB and matrixC operands. vector.extract_strided_slice op
|
||||
// is not supported on registers containing matrixA operands.
|
||||
if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
|
||||
return (op->getResult(0).getType().cast<VectorType>() ==
|
||||
(*contractOp).getRhs().getType().cast<VectorType>());
|
||||
else if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
|
||||
return (op->getResult(0).getType().cast<VectorType>() ==
|
||||
(*contractOp).getAcc().getType().cast<VectorType>());
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
|
||||
if (isa<scf::ForOp, scf::YieldOp>(op))
|
||||
return true;
|
||||
@ -199,6 +226,9 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
|
||||
return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
|
||||
if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
|
||||
return transferWriteSupportsMMAMatrixType(transferWrite);
|
||||
if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
|
||||
return useNvGpu &&
|
||||
extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
|
||||
if (auto contract = dyn_cast<vector::ContractionOp>(op))
|
||||
return contractSupportsMMAMatrixType(contract, useNvGpu);
|
||||
if (auto constant = dyn_cast<arith::ConstantOp>(op))
|
||||
@ -338,8 +368,10 @@ struct PrepareContractToGPUMMA
|
||||
}
|
||||
};
|
||||
|
||||
// Merge transpose op into the transfer read op. Transpose are not supported on
|
||||
// MMA types but MMA load can transpose the matrix when loading.
|
||||
// Fold transpose op into the transfer read op. Nvgpu mma.sync op only supports
|
||||
// row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
|
||||
// respectively. We can fold the transpose operation when loading the data from
|
||||
// Shared Memory to registers.
|
||||
struct CombineTransferReadOpTranspose final
|
||||
: public OpRewritePattern<vector::TransposeOp> {
|
||||
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
|
||||
@ -620,7 +652,7 @@ convertTransferReadToLoads(vector::TransferReadOp op,
|
||||
int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
|
||||
|
||||
// When we are transposing the B operand, ldmatrix will only work if we have
|
||||
// at least 8 rows to read and the width to read for the transpose is 128
|
||||
// at least 8 rows to read and the width to read for the transpose is 128
|
||||
// bits.
|
||||
if (!op.getPermutationMap().isMinorIdentity() &&
|
||||
(bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
|
||||
@ -671,6 +703,83 @@ convertTransferWriteToStores(vector::TransferWriteOp op,
|
||||
return success();
|
||||
}
|
||||
|
||||
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
|
||||
SmallVectorImpl<int64_t> &results) {
|
||||
for (auto attr : arrayAttr)
|
||||
results.push_back(attr.cast<IntegerAttr>().getInt());
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
|
||||
llvm::DenseMap<Value, Value> &valueMapping) {
|
||||
|
||||
OpBuilder b(op);
|
||||
Location loc = op->getLoc();
|
||||
|
||||
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
|
||||
nvgpu::getWarpMatrixInfo(op);
|
||||
if (failed(warpMatrixInfo))
|
||||
return failure();
|
||||
|
||||
FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
|
||||
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
|
||||
if (failed(mmaSyncFragmentInfo))
|
||||
return failure();
|
||||
|
||||
// Find the vector.transer_read whose result vector is being sliced.
|
||||
auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
|
||||
if (!transferReadOp)
|
||||
return failure();
|
||||
|
||||
warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
|
||||
if (failed(warpMatrixInfo))
|
||||
return failure();
|
||||
|
||||
FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
|
||||
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
|
||||
if (failed(ldFragmentInfo))
|
||||
return failure();
|
||||
|
||||
assert(
|
||||
(mmaSyncFragmentInfo->elementsPerRegister ==
|
||||
ldFragmentInfo->elementsPerRegister) &&
|
||||
"Number of elements per register should be same for load and mma.sync");
|
||||
|
||||
// Create vector.extract_strided_slice op for thread-owned fragments.
|
||||
std::array<int64_t, 2> strides = {1,
|
||||
1}; // stride for extract slice is always 1.
|
||||
std::array<int64_t, 2> sliceShape = {
|
||||
mmaSyncFragmentInfo->numRegistersPerFragment,
|
||||
mmaSyncFragmentInfo->elementsPerRegister};
|
||||
auto sourceVector = valueMapping.find(transferReadOp)->second;
|
||||
|
||||
// offset and sizes at warp-level of onwership.
|
||||
SmallVector<int64_t> offsets;
|
||||
populateFromInt64AttrArray(op.getOffsets(), offsets);
|
||||
|
||||
SmallVector<int64_t> sizes;
|
||||
populateFromInt64AttrArray(op.getSizes(), sizes);
|
||||
ArrayRef<int64_t> warpVectorShape = op.getVectorType().getShape();
|
||||
|
||||
// Compute offset in vector registers. Note that the mma.sync vector registers
|
||||
// are shaped as numberOfFragments x numberOfRegistersPerfFragment. The vector
|
||||
// registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
|
||||
std::array<int64_t, 2> sliceOffset = {0, 0};
|
||||
|
||||
if (offsets[0] && offsets[1])
|
||||
return op->emitError() << "Slicing fragments in 2D is not supported. ";
|
||||
else if (offsets[0])
|
||||
sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
|
||||
else if (offsets[1])
|
||||
sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
|
||||
|
||||
Value newOp = b.create<vector::ExtractStridedSliceOp>(
|
||||
loc, sourceVector, sliceOffset, sliceShape, strides);
|
||||
|
||||
valueMapping[op] = newOp;
|
||||
return success();
|
||||
}
|
||||
|
||||
static void convertContractOp(vector::ContractionOp op,
|
||||
llvm::DenseMap<Value, Value> &valueMapping) {
|
||||
OpBuilder b(op);
|
||||
@ -858,6 +967,10 @@ LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(Operation *rootOp) {
|
||||
return convertTransferWriteToStores(transferWriteOp,
|
||||
valueMapping);
|
||||
})
|
||||
.Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
|
||||
return convertExtractStridedSlice(extractStridedSliceOp,
|
||||
valueMapping);
|
||||
})
|
||||
.Case([&](vector::ContractionOp contractionOp) {
|
||||
return convertContractOpToMmaSync(contractionOp, valueMapping);
|
||||
})
|
||||
|
@ -45,14 +45,24 @@ static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
|
||||
lineSizeBits};
|
||||
}
|
||||
|
||||
/// Returns the first user of the `op` that is vector.contract. If no
|
||||
/// vector.contract user exists, return failure.
|
||||
FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op) {
|
||||
for (Operation *user : op->getUsers()) {
|
||||
if (auto contractOp = dyn_cast<vector::ContractionOp>(user))
|
||||
return contractOp;
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
|
||||
WarpMatrixInfo info;
|
||||
|
||||
// Determine the vector type.
|
||||
// Determine the vector type at warp-level.
|
||||
if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
|
||||
info.vectorType = writeOp.getVectorType();
|
||||
} else if (isa<vector::TransferReadOp, vector::ContractionOp,
|
||||
arith::ConstantOp>(op)) {
|
||||
vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
|
||||
info.vectorType = op->getResult(0).getType().cast<VectorType>();
|
||||
} else {
|
||||
return op->emitError()
|
||||
@ -62,19 +72,15 @@ FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
|
||||
// Determine the operand role. We assume it is an accumulator/result unless it
|
||||
// is directly consumed by a `vector.contract` op.
|
||||
info.operandRole = MatMulOperandRole::C;
|
||||
for (Operation *user : op->getUsers()) {
|
||||
auto contract = dyn_cast<vector::ContractionOp>(user);
|
||||
if (!contract)
|
||||
continue;
|
||||
if (contract.getLhs() == op->getResult(0)) {
|
||||
info.operandRole = MatMulOperandRole::A;
|
||||
break;
|
||||
}
|
||||
if (contract.getRhs() == op->getResult(0)) {
|
||||
info.operandRole = MatMulOperandRole::B;
|
||||
break;
|
||||
}
|
||||
}
|
||||
FailureOr<vector::ContractionOp> contractOp = getUserContract(op);
|
||||
if (failed(contractOp))
|
||||
return info;
|
||||
|
||||
if ((*contractOp).getLhs() == op->getResult(0))
|
||||
info.operandRole = MatMulOperandRole::A;
|
||||
else if ((*contractOp).getRhs() == op->getResult(0))
|
||||
info.operandRole = MatMulOperandRole::B;
|
||||
|
||||
return info;
|
||||
}
|
||||
|
||||
|
@ -164,9 +164,9 @@ func.func @m8n8k4_f64_row_row_row(%arg0: memref<128x128xf64>, %arg1: memref<128x
|
||||
|
||||
// -----
|
||||
|
||||
//#########################################################
|
||||
// FP16 row-row-row
|
||||
//#########################################################
|
||||
//#########################################################################
|
||||
// FP16 row-row-row (ldmatrix x4 for matrixA and ldmatrix x2 for matrixB)
|
||||
//#########################################################################
|
||||
|
||||
#map0 = affine_map<(d0, d1) -> (d1, d0)>
|
||||
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||
@ -203,6 +203,62 @@ func.func @m16n8k16_fp16_row_row_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<
|
||||
|
||||
// -----
|
||||
|
||||
//#########################################################################
|
||||
// FP16 row-row-row (ldmatrix x4 for matrixA and ldmatrix x4 for matrixB)
|
||||
//#########################################################################
|
||||
|
||||
// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
|
||||
// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8)>
|
||||
// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 8 - ((s0 floordiv 8) floordiv 2) * 16)>
|
||||
// CHECK-DAG: [[$colB_map:#.+]] = affine_map<()[s0] -> (s0 mod 8 + ((s0 floordiv 8) floordiv 2) * 8)>
|
||||
|
||||
#map0 = affine_map<(d0, d1) -> (d1, d0)>
|
||||
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
|
||||
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
|
||||
// CHECK-LABEL: func @m16n16k16_mmasync16816_fp16_f16_row_row_row
|
||||
func.func @m16n16k16_mmasync16816_fp16_f16_row_row_row(%arg0: memref<42x32xf16, 3>, %arg1: memref<32x64xf16, 3>, %arg2: memref<42x64xf16, 3>) {
|
||||
%cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
|
||||
%c0 = arith.constant 0 : index
|
||||
%c8 = arith.constant 8 : index
|
||||
%cst = arith.constant 0.000000e+00 : f16
|
||||
|
||||
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
|
||||
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
|
||||
// CHECK: [[fragmentA:%.+]] = nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
|
||||
%A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<42x32xf16, 3>, vector<16x16xf16>
|
||||
|
||||
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
|
||||
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
|
||||
// CHECK-DAG: [[fragmentB:%.+]] = nvgpu.ldmatrix %arg1[[[col]], [[row]]] {numTiles = 4 : i32, transpose = true}
|
||||
%B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<32x64xf16, 3>, vector<16x16xf16>
|
||||
|
||||
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
|
||||
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
|
||||
// CHECK-DAG: [[fragmentC:%.*]] = nvgpu.ldmatrix %arg2[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
|
||||
%C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<42x64xf16, 3>, vector<16x16xf16>
|
||||
|
||||
// CHECK-DAG: [[fragmentB0:%.+]] = vector.extract_strided_slice [[fragmentB]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
|
||||
// CHECK-DAG: [[fragmentC0:%.+]] = vector.extract_strided_slice [[fragmentC]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
|
||||
// CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB0]], [[fragmentC0]]) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
|
||||
%B0 = vector.extract_strided_slice %B {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
|
||||
%C0 = vector.extract_strided_slice %C {offsets = [0, 0], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf16> to vector<16x8xf16>
|
||||
%D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B0, %C0 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
|
||||
vector.transfer_write %D0, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<42x64xf16, 3>
|
||||
|
||||
// CHECK-DAG: [[fragmentB1:%.+]] = vector.extract_strided_slice [[fragmentB]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
|
||||
// CHECK-DAG: [[fragmentC1:%.+]] = vector.extract_strided_slice [[fragmentC]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
|
||||
// CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB1]], [[fragmentC1]]) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
|
||||
%B1 = vector.extract_strided_slice %B {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
|
||||
%C1 = vector.extract_strided_slice %C {offsets = [0, 8], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf16> to vector<16x8xf16>
|
||||
%D1 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B1, %C1 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
|
||||
vector.transfer_write %D1, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<42x64xf16, 3>
|
||||
|
||||
return
|
||||
}
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: [[$Arow_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
|
||||
// CHECK-DAG: [[$Acol_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
|
||||
// CHECK-DAG: [[$Bcol_map:#.+]] = affine_map<() -> (3)>
|
||||
|
Loading…
Reference in New Issue
Block a user