[mlir][NVGPU] Handle native mma.sync and ldmatrix(x4) sizes

This patch handles native `mma.sync` sizes and enables issuing `ldmatrix` on the
largest possible tiles for matrixB. It requires handling
`vector.extract_strided_slice` in the vector-to-nvgpu lowering.

Differential Revision: https://reviews.llvm.org/D135749
Manish Gupta 2022-10-12 05:17:32 +00:00
parent 97196a2d92
commit 114ba722c1
4 changed files with 202 additions and 30 deletions
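
At warp level, the IR shape this patch targets looks roughly like the following, a simplified sketch distilled from the new m16n16k16 test case added below (the %arg1, %c0, and %cst values stand for the test's memref operand, index, and padding constants):

// matrixB is read as one 16x16 tile (a single ldmatrix x4 after lowering) and then
// sliced into the two 8x16 operands consumed by two native 16x8x16 mma.sync ops.
%B  = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<32x64xf16, 3>, vector<16x16xf16>
%B0 = vector.extract_strided_slice %B {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
%B1 = vector.extract_strided_slice %B {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>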

mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h

@@ -13,25 +13,22 @@
#ifndef MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
#define MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
namespace mlir {
namespace vector {
enum class IteratorType : uint32_t;
class ContractionOp;
} // namespace vector
namespace NVVM {
enum class MMALayout : uint32_t;
} // namespace NVVM
namespace nvgpu {
/// Represents the role of an operand in an MMA instruction:
/// `result := matmul(A, B) + C`
enum class MatMulOperandRole : int32_t { A = 0, B, C };
/// Returns the first user of the `op` that is vector.contract. If no
/// vector.contract user exists, return failure.
FailureOr<vector::ContractionOp> getUserContract(Operation *op);
/// Collects information about a warp-level matrix operand represented by a
/// VectorType.
struct WarpMatrixInfo {

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

@@ -192,6 +192,33 @@ static bool elementwiseSupportsMMAMatrixType(Operation *op) {
return convertElementwiseOpToMMA(op).has_value();
}
/// Returns true if the extract strided slice op is supported with `mma.sync`
/// path.
static bool
extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
nvgpu::getWarpMatrixInfo(op);
if (failed(warpMatrixInfo))
return false;
FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
if (failed(contractOp))
return false;
// Handle vector.extract_strided_slice on registers containing
// matrixB and matrixC operands. vector.extract_strided_slice op
// is not supported on registers containing matrixA operands.
if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
return (op->getResult(0).getType().cast<VectorType>() ==
(*contractOp).getRhs().getType().cast<VectorType>());
else if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
return (op->getResult(0).getType().cast<VectorType>() ==
(*contractOp).getAcc().getType().cast<VectorType>());
return false;
}
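
For instance, with the shapes from the new test below, a slice of the matrixB tile is accepted because its result type equals the rhs type of the consuming vector.contract (a partial sketch; %A, %C0, and the #map attributes are those of that test):

%B0 = vector.extract_strided_slice %B {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
// Accepted: the slice type vector<8x16xf16> matches the rhs type of the contract.
%D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B0, %C0 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>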
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
if (isa<scf::ForOp, scf::YieldOp>(op))
return true;
@@ -199,6 +226,9 @@ static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
return transferWriteSupportsMMAMatrixType(transferWrite);
if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
return useNvGpu &&
extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
if (auto contract = dyn_cast<vector::ContractionOp>(op))
return contractSupportsMMAMatrixType(contract, useNvGpu);
if (auto constant = dyn_cast<arith::ConstantOp>(op))
@@ -338,8 +368,10 @@ struct PrepareContractToGPUMMA
}
};
// Fold transpose op into the transfer read op. Nvgpu mma.sync op only supports
// row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
// respectively. We can fold the transpose operation when loading the data from
// Shared Memory to registers.
struct CombineTransferReadOpTranspose final
: public OpRewritePattern<vector::TransposeOp> {
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
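
For example, using the operand shapes of the test below (a sketch; %arg1, %c0, and %cst are illustrative), a standalone vector.transpose of the B tile is folded into the read itself, which the nvgpu path can later lower to an ldmatrix with transpose = true:

// Before: the transpose is a separate op on the loaded tile.
%b  = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x64xf16, 3>, vector<16x16xf16>
%bT = vector.transpose %b, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
// After: the transpose is folded into a permutation_map on the transfer_read.
%bT2 = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<32x64xf16, 3>, vector<16x16xf16>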
@@ -620,7 +652,7 @@ convertTransferReadToLoads(vector::TransferReadOp op,
int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
// When we are transposing the B operand, ldmatrix will only work if we have
// at least 8 rows to read and the width to read for the transpose is 128
// bits.
if (!op.getPermutationMap().isMinorIdentity() &&
(bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
@@ -671,6 +703,83 @@ convertTransferWriteToStores(vector::TransferWriteOp op,
return success();
}
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
SmallVectorImpl<int64_t> &results) {
for (auto attr : arrayAttr)
results.push_back(attr.cast<IntegerAttr>().getInt());
}
static LogicalResult
convertExtractStridedSlice(vector::ExtractStridedSliceOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
OpBuilder b(op);
Location loc = op->getLoc();
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
nvgpu::getWarpMatrixInfo(op);
if (failed(warpMatrixInfo))
return failure();
FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
if (failed(mmaSyncFragmentInfo))
return failure();
// Find the vector.transfer_read whose result vector is being sliced.
auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
if (!transferReadOp)
return failure();
warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
if (failed(warpMatrixInfo))
return failure();
FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
if (failed(ldFragmentInfo))
return failure();
assert(
(mmaSyncFragmentInfo->elementsPerRegister ==
ldFragmentInfo->elementsPerRegister) &&
"Number of elements per register should be same for load and mma.sync");
// Create vector.extract_strided_slice op for thread-owned fragments.
std::array<int64_t, 2> strides = {1,
1}; // stride for extract slice is always 1.
std::array<int64_t, 2> sliceShape = {
mmaSyncFragmentInfo->numRegistersPerFragment,
mmaSyncFragmentInfo->elementsPerRegister};
auto sourceVector = valueMapping.find(transferReadOp)->second;
// Offsets and sizes at warp level of ownership.
SmallVector<int64_t> offsets;
populateFromInt64AttrArray(op.getOffsets(), offsets);
SmallVector<int64_t> sizes;
populateFromInt64AttrArray(op.getSizes(), sizes);
ArrayRef<int64_t> warpVectorShape = op.getVectorType().getShape();
// Compute the offset in vector registers. Note that the mma.sync vector registers
// are shaped as numberOfFragments x numberOfRegistersPerFragment. The vector
// registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
std::array<int64_t, 2> sliceOffset = {0, 0};
if (offsets[0] && offsets[1])
return op->emitError() << "Slicing fragments in 2D is not supported. ";
else if (offsets[0])
sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
else if (offsets[1])
sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
Value newOp = b.create<vector::ExtractStridedSliceOp>(
loc, sourceVector, sliceOffset, sliceShape, strides);
valueMapping[op] = newOp;
return success();
}
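
As a concrete instance of the offset computation above, with numbers taken from the m16n16k16 test added below (%fragmentB stands for the vector<4x2xf16> ldmatrix result captured in the CHECK lines), slicing the second 8x16 half of the 16x16 matrixB tile at warp level becomes a slice along the fragment dimension of the per-thread registers:

// Warp level: offsets = [8, 0] on the 16x16 operand.
%B1 = vector.extract_strided_slice %B {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
// Register level: sliceOffset[0] = warpVectorShape[0] / offsets[0] = 16 / 8 = 2.
%fragmentB1 = vector.extract_strided_slice %fragmentB {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>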
static void convertContractOp(vector::ContractionOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
OpBuilder b(op);
@@ -858,6 +967,10 @@ LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(Operation *rootOp) {
return convertTransferWriteToStores(transferWriteOp,
valueMapping);
})
.Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
return convertExtractStridedSlice(extractStridedSliceOp,
valueMapping);
})
.Case([&](vector::ContractionOp contractionOp) {
return convertContractOpToMmaSync(contractionOp, valueMapping);
})

mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp

@@ -45,14 +45,24 @@ static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
lineSizeBits};
}
/// Returns the first user of the `op` that is vector.contract. If no
/// vector.contract user exists, return failure.
FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op) {
for (Operation *user : op->getUsers()) {
if (auto contractOp = dyn_cast<vector::ContractionOp>(user))
return contractOp;
}
return failure();
}
FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
WarpMatrixInfo info;
// Determine the vector type at warp-level.
if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
info.vectorType = writeOp.getVectorType();
} else if (isa<vector::TransferReadOp, vector::ContractionOp,
vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
info.vectorType = op->getResult(0).getType().cast<VectorType>();
} else {
return op->emitError()
@@ -62,19 +72,15 @@ FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
// Determine the operand role. We assume it is an accumulator/result unless it
// is directly consumed by a `vector.contract` op.
info.operandRole = MatMulOperandRole::C;
FailureOr<vector::ContractionOp> contractOp = getUserContract(op);
if (failed(contractOp))
return info;
if ((*contractOp).getLhs() == op->getResult(0))
info.operandRole = MatMulOperandRole::A;
else if ((*contractOp).getRhs() == op->getResult(0))
info.operandRole = MatMulOperandRole::B;
return info;
}
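
In IR terms, the role assignment reduces to the following minimal sketch (the memref shapes, #mapA/#mapB/#mapC, and the %f0 padding constant are placeholders, not taken from this patch):

%A = vector.transfer_read %a[%c0, %c0], %f0 : memref<16x16xf16>, vector<16x16xf16>  // lhs of the user contract -> MatMulOperandRole::A
%B = vector.transfer_read %b[%c0, %c0], %f0 : memref<16x16xf16>, vector<8x16xf16>   // rhs of the user contract -> MatMulOperandRole::B
%C = vector.transfer_read %c[%c0, %c0], %f0 : memref<16x16xf16>, vector<16x8xf16>   // neither lhs nor rhs -> MatMulOperandRole::C (accumulator)
%D = vector.contract {indexing_maps = [#mapA, #mapB, #mapC], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>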

mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir

@@ -164,9 +164,9 @@ func.func @m8n8k4_f64_row_row_row(%arg0: memref<128x128xf64>, %arg1: memref<128x
// -----
//#########################################################################
// FP16 row-row-row (ldmatrix x4 for matrixA and ldmatrix x2 for matrixB)
//#########################################################################
#map0 = affine_map<(d0, d1) -> (d1, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
@@ -203,6 +203,62 @@ func.func @m16n8k16_fp16_row_row_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<
// -----
//#########################################################################
// FP16 row-row-row (ldmatrix x4 for matrixA and ldmatrix x4 for matrixB)
//#########################################################################
// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16)>
// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8)>
// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 8 - ((s0 floordiv 8) floordiv 2) * 16)>
// CHECK-DAG: [[$colB_map:#.+]] = affine_map<()[s0] -> (s0 mod 8 + ((s0 floordiv 8) floordiv 2) * 8)>
#map0 = affine_map<(d0, d1) -> (d1, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @m16n16k16_mmasync16816_fp16_f16_row_row_row
func.func @m16n16k16_mmasync16816_fp16_f16_row_row_row(%arg0: memref<42x32xf16, 3>, %arg1: memref<32x64xf16, 3>, %arg2: memref<42x64xf16, 3>) {
%cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
%c0 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%cst = arith.constant 0.000000e+00 : f16
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
// CHECK: [[fragmentA:%.+]] = nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
%A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<42x32xf16, 3>, vector<16x16xf16>
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
// CHECK-DAG: [[fragmentB:%.+]] = nvgpu.ldmatrix %arg1[[[col]], [[row]]] {numTiles = 4 : i32, transpose = true}
%B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<32x64xf16, 3>, vector<16x16xf16>
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
// CHECK-DAG: [[fragmentC:%.*]] = nvgpu.ldmatrix %arg2[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
%C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<42x64xf16, 3>, vector<16x16xf16>
// CHECK-DAG: [[fragmentB0:%.+]] = vector.extract_strided_slice [[fragmentB]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
// CHECK-DAG: [[fragmentC0:%.+]] = vector.extract_strided_slice [[fragmentC]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
// CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB0]], [[fragmentC0]]) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
%B0 = vector.extract_strided_slice %B {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
%C0 = vector.extract_strided_slice %C {offsets = [0, 0], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf16> to vector<16x8xf16>
%D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B0, %C0 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
vector.transfer_write %D0, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<42x64xf16, 3>
// CHECK-DAG: [[fragmentB1:%.+]] = vector.extract_strided_slice [[fragmentB]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
// CHECK-DAG: [[fragmentC1:%.+]] = vector.extract_strided_slice [[fragmentC]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf16> to vector<2x2xf16>
// CHECK: nvgpu.mma.sync([[fragmentA]], [[fragmentB1]], [[fragmentC1]]) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
%B1 = vector.extract_strided_slice %B {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
%C1 = vector.extract_strided_slice %C {offsets = [0, 8], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf16> to vector<16x8xf16>
%D1 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B1, %C1 : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
vector.transfer_write %D1, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<42x64xf16, 3>
return
}
// -----
// CHECK-DAG: [[$Arow_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
// CHECK-DAG: [[$Acol_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
// CHECK-DAG: [[$Bcol_map:#.+]] = affine_map<() -> (3)>