[mlir][nvgpu] Improve nvgpu->nvvm transformation of warpgroup.mma Op (NFC) (#67325)

This PR introduces substantial improvements to the readability and
maintainability of the `nvgpu.warpgroup.mma` Op transformation from
nvgpu->nvvm. This transformation plays a crucial role in GEMM lowering and
handles complex tasks such as generating multiple wgmma ops and iterating
over their descriptors. The prior code lacked clarity; this PR addresses
that.

**This PR does the following:**
**Introduces a helper class:** The `WarpgroupGemm` class encapsulates the
necessary functionality, making the code cleaner and easier to understand.

**Detailed Documentation:** Each function within the helper class is
thoroughly documented to provide clear insights into its purpose and
functionality.
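
At a high level, the new helper drives the lowering as sketched below. This is a condensed C++ sketch abridged from the new code in this diff (`generateWarpgroupMma` and `generateWgmmaGroup` merged into one body for brevity); member names match the patch, but the body is not verbatim:

```cpp
// Abridged member-function sketch of WarpgroupGemm: fence, emit one
// WgmmaMmaAsyncOp per (M, N, K) tile, then commit and wait for the group.
SmallVector<Value> generateWarpgroupMma() {
  b.create<NVVM::WgmmaFenceAlignedOp>();                   // fence before the group
  SmallVector<Value> wgmmaResults;
  for (int i = 0; i < iterationM; ++i) {                   // M is tiled by 64
    Value matrixC = adaptor.getMatrixC()[i];
    for (int j = 0; j < iterationN; ++j)
      for (int k = 0; k < iterationK; ++k)                 // K is tiled by the wgmma K
        matrixC = generateWgmma(i, j, k, matrixC, op.getMatrixD()[i]);
    wgmmaResults.push_back(matrixC);
  }
  b.create<NVVM::WgmmaGroupSyncAlignedOp>();               // commit the wgmma group
  b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup()); // wait for completion
  return wgmmaResults;
}
```

Each `generateWgmma` call advances the matrix descriptors via `iterateDescriptorA`/`iterateDescriptorB` and emits a single `NVVM::WgmmaMmaAsyncOp`.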
Guray Ozen, 2023-10-05 10:16:59 +02:00, committed by GitHub
commit b74cfc139a (parent 7eb2b99f16)

@@ -39,7 +39,7 @@ namespace mlir {
using namespace mlir;
/// Number of bits that needs to excluded when building matrix descriptor for
/// Number of bits that needs to be excluded when building matrix descriptor for
/// wgmma operations.
constexpr int exclude4LSB = 4;
@@ -1160,137 +1160,276 @@ struct NVGPUWarpgroupMmaOpLowering
: public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
int &wgmmaShapeM, int &wgmmaShapeN,
int &wgmmaShapeK) const {
wgmmaShapeM = 64;
wgmmaShapeN = sizeN;
if (inputElemType.isTF32()) {
wgmmaShapeK = 8;
} else if (inputElemType.isF16() || inputElemType.isBF16()) {
wgmmaShapeK = 16;
} else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
inputElemType.isInteger(16)) {
wgmmaShapeK = 32;
} else if (inputElemType.isInteger(1)) {
wgmmaShapeK = 256;
} else {
llvm_unreachable("msg: not supported K shape");
}
LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
<< ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
<< "]\n");
return success();
}
/// This is a helper class to generate required NVVM Ops for warp-group level
/// matrix multiplication.
/// When the given GEMM shape is larger than the shape of
/// a wgmma instruction in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp
/// Op(s), group and execute them asynchronously. The class also handles
/// waiting for completion and iterates through WarpgroupMatrixDescriptor to
/// create descriptors for each instruction.
///
/// For example this is the case when the shape of GEMM is 128x128x128
///
/// nvvm.wgmma.fence.aligned
///
/// nvvm.wgmma.mma.async descA, descB
/// iterate(descA, descB)
/// nvvm.wgmma.mma.async descA, descB
/// [6x times more]
///
/// nvvm.wgmma.group.sync.aligned
/// nvvm.wgmma.wait.group.sync [groupId]
///
class WarpgroupGemm {
nvgpu::WarpgroupMmaOp op;
ImplicitLocOpBuilder b;
OpAdaptor adaptor;
const LLVMTypeConverter &typeConverter;
Value generateNVVMWgmmaOp(ImplicitLocOpBuilder &b, int m, int n, int k,
Type resultStructType, Value inout,
Value descriptorA, Value descriptorB) const {
MLIRContext *ctx = b.getContext();
auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
// todo: handle other input and output types
auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
auto overflow =
NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
Value res = b.create<NVVM::WgmmaMmaAsyncOp>(
resultStructType, inout, descriptorA, descriptorB, shape, itype, itype,
scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
return res;
}
// Entire shape of the given Op
int64_t totalM, totalN, totalK;
LogicalResult
matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
// Shape of one wgmma instruction
int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A["
<< sizeM << "][" << sizeK << "] * B[" << sizeK << "]["
<< sizeN << "] ---===\n");
// Iteration counts for GEMM
int iterationM = 0, iterationN = 0, iterationK = 0;
int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM,
wgmmaShapeN, wgmmaShapeK))) {
return failure();
/// The function returns the shape of wgmma instruction that is defined in
/// PTX programming guide.
/// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
wgmmaM = 64;
wgmmaN = sizeN;
if (inputElemType.isTF32()) {
wgmmaK = 8;
} else if (inputElemType.isF16() || inputElemType.isBF16()) {
wgmmaK = 16;
} else if (inputElemType.isFloat8E4M3FN() ||
inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
wgmmaK = 32;
} else if (inputElemType.isInteger(1)) {
wgmmaK = 256;
} else {
llvm_unreachable("msg: not supported K shape");
}
LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
<< ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
}
Value descriptorA = adaptor.getDescriptorA();
Value descriptorB = adaptor.getDescriptorB();
/// Generates WGMMATypesAttr from MLIR Type
NVVM::WGMMATypesAttr generateWgmmaType(Type type) const {
auto getWgmmaType = [](Type elemType) {
if (elemType.isF32() || elemType.isTF32())
return NVVM::WGMMATypes::tf32;
if (elemType.isF16())
return NVVM::WGMMATypes::f16;
if (elemType.isBF16())
return NVVM::WGMMATypes::bf16;
if (elemType.isFloat8E4M3FN())
return NVVM::WGMMATypes::e4m3;
if (elemType.isFloat8E5M2())
return NVVM::WGMMATypes::e5m2;
if (elemType.isInteger(1))
return NVVM::WGMMATypes::b1;
if (elemType.isInteger(8))
return NVVM::WGMMATypes::s8;
if (elemType.isUnsignedInteger(8))
return NVVM::WGMMATypes::u8;
llvm_unreachable("unsupported type");
};
return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
}
// Generate wgmma group
MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
/// Generates layout attribute for the input matrix for wgmma instruction
NVVM::MMALayoutAttr
generateWgmmaLayout(std::optional<bool> transpose) const {
if (transpose.value_or(false))
return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
}
auto makeAdd = [&](Value lhs, Value rhs) -> Value {
/// Generates shape attribute for wgmma instruction
NVVM::MMAShapeAttr generateWgmmaShape() const {
return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
}
/// Generates scale attributes of output matrix for wgmma instruction
NVVM::WGMMAScaleOutAttr generateScaleOut() const {
return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
NVVM::WGMMAScaleOut::one);
}
/// Generates scale attributes of input matrix for wgmma instruction
NVVM::WGMMAScaleInAttr generateScaleIn() const {
return NVVM::WGMMAScaleInAttr::get(op->getContext(),
NVVM::WGMMAScaleIn::one);
}
/// Basic function to generate Add
Value makeAdd(Value lhs, Value rhs) {
return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
};
auto iterateDescA = [&](Value desc, int iterM, int iterN,
int iterK) -> Value {
// todo : Handle column major
int byte = typeTensorA.getElementTypeBitWidth() / 8;
int tileShapeA = typeTensorA.getDimSize(1);
int incrementVal =
((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte;
/// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
/// Currently, it only handles row-major.
///
/// It moves the pointer like below for [128][64] size:
/// +2 +4 +6
/// ↓ ↓ ↓
/// descA ---> +--+--+--+--+
/// |->|->|->|->|
/// | | | | |
/// | | | | |
/// | | | | |
/// descA+512---> +-----------+
/// | | | | |
/// | | | | |
/// | | | | |
/// | | | | |
/// +-----------+
///
Value iterateDescriptorA(Value desc, int i, int j, int k) {
MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
Type elemA = matrixTypeA.getElementType();
int byte = elemA.getIntOrFloatBitWidth() / 8;
int tileShapeA = matrixTypeA.getDimSize(1);
int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
incrementVal = incrementVal >> exclude4LSB;
LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: "
<< iterK << "] [wgmma descriptors] Descriptor A + "
LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
<< "] [wgmma descriptors] Descriptor A + "
<< incrementVal << " | \t ");
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
};
}
auto iterateDescB = [&](Value desc, int iterM, int iterN,
int iterK) -> Value {
// todo : Handle row major
int byte = typeTensorB.getElementTypeBitWidth() / 8;
int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte;
/// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
/// Currently, it only handles column-major.
///
/// It moves the pointer like below for [128][64] size:
/// descB ---> +--+--+--+--+--+--+--+--+
/// |↓ | | | | | | | |
/// |↓ | | | | | | | |
/// |↓ | | | | | | | |
/// |↓ | | | | | | | |
/// +--+--+--+--+--+--+--+--+
///
Value iterateDescriptorB(Value desc, int i, int j, int k) {
MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
Type elemB = matrixTypeB.getElementType();
int byte = elemB.getIntOrFloatBitWidth() / 8;
int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
incrementVal = incrementVal >> exclude4LSB;
LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
};
b.create<NVVM::WgmmaFenceAlignedOp>();
SmallVector<Value> wgmmaResults;
for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) {
Value matrixC = adaptor.getMatrixC()[iterM];
Value matrixD = op.getMatrixD()[iterM];
Type structType = getTypeConverter()->convertType(matrixD.getType());
LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":"
<< (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0
<< ":" << wgmmaShapeN << "] += \n");
for (int iterK = 0; iterK < (sizeK / wgmmaShapeK); iterK++) {
Value descA = iterateDescA(descriptorA, iterM, 0, iterK);
Value descB = iterateDescB(descriptorB, iterM, 0, iterK);
LLVM_DEBUG(DBGS() << "\t wgmma."
<< "m" << wgmmaShapeM << "n" << wgmmaShapeN << "k"
<< wgmmaShapeK << "(A[" << (iterM * wgmmaShapeM)
<< ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "]["
<< (iterK * wgmmaShapeK) << ":"
<< (iterK * wgmmaShapeK + wgmmaShapeK) << "] * "
<< " B[" << (iterK * wgmmaShapeK) << ":"
<< (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0
<< ":" << wgmmaShapeN << "])\n");
matrixC = generateNVVMWgmmaOp(b, wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
structType, matrixC, descA, descB);
}
wgmmaResults.push_back(matrixC);
}
b.create<NVVM::WgmmaGroupSyncAlignedOp>();
b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
ValueRange myres(wgmmaResults);
rewriter.replaceOp(op, myres);
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
Value generateWgmma(int i, int j, int k, Value matrixC, Value matrixD) {
LLVM_DEBUG(DBGS() << "\t wgmma."
<< "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
<< "(A[" << (iterationM * wgmmaM) << ":"
<< (iterationM * wgmmaM) + wgmmaM << "]["
<< (iterationK * wgmmaK) << ":"
<< (iterationK * wgmmaK + wgmmaK) << "] * "
<< " B[" << (iterationK * wgmmaK) << ":"
<< (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
<< wgmmaN << "])\n");
Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);
Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
NVVM::MMAShapeAttr shape = generateWgmmaShape();
NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(op.getTransposeB());
auto overflow = NVVM::MMAIntOverflowAttr::get(
op->getContext(), NVVM::MMAIntOverflow::wrapped);
Type resultStructType = typeConverter.convertType(matrixD.getType());
return b.create<NVVM::WgmmaMmaAsyncOp>(
resultStructType, matrixC, descriptorA, descriptorB, shape, itypeA,
itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
}
/// Generates multiple wgmma instructions to complete the given GEMM shape
SmallVector<Value> generateWgmmaGroup() {
SmallVector<Value> wgmmaResults;
// Perform GEMM
for (int i = 0; i < iterationM; ++i) {
Value matrixC = adaptor.getMatrixC()[i];
Value matrixD = op.getMatrixD()[i];
for (int j = 0; j < iterationN; ++j)
for (int k = 0; k < iterationK; ++k)
matrixC = generateWgmma(i, j, k, matrixC, matrixD);
wgmmaResults.push_back(matrixC);
}
return wgmmaResults;
}
public:
WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
: op(op), b(b), adaptor(adaptor), typeConverter(typeConverter) {
// Find the entire GEMM Shape
totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
<< "] += A[" << totalM << "][" << totalK << "] * B["
<< totalK << "][" << totalN << "] ---===\n");
// Find the shape for one wgmma instruction
findWgmmaShape(
totalM, totalN,
op.getDescriptorA().getType().getTensor().getElementType());
// Iterations counts to complete the given shape with wgmma shape
iterationM = totalM / wgmmaM;
iterationN = totalN / wgmmaN;
iterationK = totalK / wgmmaK;
}
/// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
/// includes generating a fence Op (WgmmaFenceAlignedOp) before the
/// instructions and group synchronization, as well as waiting
/// (WgmmaGroupSyncAlignedOp) for group synchronization
/// (WgmmaWaitGroupSyncOp) after the instructions.
SmallVector<Value> generateWarpgroupMma() {
b.create<NVVM::WgmmaFenceAlignedOp>();
SmallVector<Value> wgmmaResults = generateWgmmaGroup();
b.create<NVVM::WgmmaGroupSyncAlignedOp>();
b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
return wgmmaResults;
}
};
LogicalResult
matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
// Step 1. Build a helper class
WarpgroupGemm warpgroupGemm(op, b, adaptor, *this->getTypeConverter());
// Step 2. Get the entire GEMM Shape
SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
// Step 3. Replace fragmented result struct with the op results
rewriter.replaceOp(op, wgmmaResults);
return success();
}
};
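
For context (this is not part of the diff), such a pattern is registered into the NVGPU-to-NVVM conversion along with the other lowerings; a minimal sketch, assuming the pattern class is visible in the same translation unit:

```cpp
// Minimal registration sketch; the wrapper function name is illustrative.
// In-tree, NVGPUWarpgroupMmaOpLowering is added to the pattern set by
// populateNVGPUToNVVMConversionPatterns() in NVGPUToNVVM.cpp.
static void addWarpgroupMmaLowering(LLVMTypeConverter &converter,
                                    RewritePatternSet &patterns) {
  // ConvertOpToLLVMPattern-derived patterns take the LLVM type converter in
  // their constructor; RewritePatternSet::add forwards the argument.
  patterns.add<NVGPUWarpgroupMmaOpLowering>(converter);
}
```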