[mlir][nvgpu] Improve nvgpu->nvvm transformation of warpgroup.mma Op (NFC) (#67325)
This PR improves the readability and maintainability of the `nvgpu.warpgroup.mma` Op transformation from nvgpu->nvvm. This transformation is central to GEMM lowering: it generates multiple wgmma ops and iterates over their matrix descriptors. The prior code was hard to follow; this PR restructures it. A condensed sketch of how the lowering pattern uses the new helper is shown below the commit metadata.

**This PR does the following:**

**Introduces a helper class:** the `WarpgroupGemm` class encapsulates the required functionality, making the code cleaner and easier to understand.

**Detailed documentation:** every function in the helper class is documented to explain its purpose and behavior.
parent 7eb2b99f16
commit b74cfc139a
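For orientation, the sketch below condenses the new `matchAndRewrite` of `NVGPUWarpgroupMmaOpLowering` as it looks after this change. The comments are added for this summary; it is an excerpt, not a standalone program, and relies on the MLIR conversion headers already included by NVGPUToNVVM.cpp.

```cpp
// Condensed from the diff below: the lowering now delegates all wgmma
// generation to the WarpgroupGemm helper class.
LogicalResult
matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  ImplicitLocOpBuilder b(op->getLoc(), rewriter);

  // Step 1. The helper's constructor derives the total GEMM shape, the
  // per-instruction wgmma shape, and the iteration counts from the op.
  WarpgroupGemm warpgroupGemm(op, b, adaptor, *this->getTypeConverter());

  // Step 2. Emit wgmma.fence, the group of WgmmaMmaAsyncOps, group.sync.aligned,
  // and wait.group.sync, collecting the accumulator results.
  SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();

  // Step 3. Replace the nvgpu op with the fragmented result structs.
  rewriter.replaceOp(op, wgmmaResults);
  return success();
}
```

For the 128x128x128 f16 example in the class documentation, `findWgmmaShape` picks m64n128k16, so by the code's iteration counts the helper emits 2 x 1 x 8 = 16 `NVVM::WgmmaMmaAsyncOp`s between a single fence and a single wait.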
@@ -39,7 +39,7 @@ namespace mlir {

using namespace mlir;

/// Number of bits that needs to excluded when building matrix descriptor for
/// Number of bits that needs to be excluded when building matrix descriptor for
/// wgmma operations.
constexpr int exclude4LSB = 4;

@@ -1160,137 +1160,276 @@ struct NVGPUWarpgroupMmaOpLowering
: public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;

LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
int &wgmmaShapeM, int &wgmmaShapeN,
int &wgmmaShapeK) const {
wgmmaShapeM = 64;
wgmmaShapeN = sizeN;
if (inputElemType.isTF32()) {
wgmmaShapeK = 8;
} else if (inputElemType.isF16() || inputElemType.isBF16()) {
wgmmaShapeK = 16;
} else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
inputElemType.isInteger(16)) {
wgmmaShapeK = 32;
} else if (inputElemType.isInteger(1)) {
wgmmaShapeK = 256;
} else {
llvm_unreachable("msg: not supported K shape");
}
LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
<< ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
<< "]\n");
return success();
}
/// This is a helper class to generate required NVVM Ops for warp-group level
/// matrix multiplication.
/// When the given GEMM shape is larger than the shape of
/// a wgmma instrution in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp
/// Op(s), group and execute them asynchronously. The class also handles
/// waiting for completion and iterates through WarpgroupMatrixDescriptor to
/// create descriptors for each instruction.
///
/// For example this is the case when the shape of GEMM is 128x128x128
///
/// nvvm.wgmma.fence.aligned
///
/// nvvm.wgmma.mma.async descA, descB
/// iterate(descA, descB)
/// nvvm.wgmma.mma.async descA, descB
/// [6x times more]
///
/// nvvm.wgmma.group.sync.aligned
/// nvvm.wgmma.wait.group.sync [groupId]
///
class WarpgroupGemm {
nvgpu::WarpgroupMmaOp op;
ImplicitLocOpBuilder b;
OpAdaptor adaptor;
const LLVMTypeConverter &typeConverter;

Value generateNVVMWgmmaOp(ImplicitLocOpBuilder &b, int m, int n, int k,
Type resultStructType, Value inout,
Value descriptorA, Value descriptorB) const {
MLIRContext *ctx = b.getContext();
auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
// todo: handle other input and output types
auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
auto overflow =
NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
Value res = b.create<NVVM::WgmmaMmaAsyncOp>(
resultStructType, inout, descriptorA, descriptorB, shape, itype, itype,
scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
return res;
}
// Entire shape of the given Op
int64_t totalM, totalN, totalK;

LogicalResult
matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
// Shape of one wgmma instruction
int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;

LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A["
<< sizeM << "][" << sizeK << "] * B[" << sizeK << "]["
<< sizeN << "] ---===\n");
// Iteration counts for GEMM
int iterationM = 0, iterationN = 0, iterationK = 0;

int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM,
wgmmaShapeN, wgmmaShapeK))) {
return failure();
/// The function returns the shape of wgmma instruction that is defined in
/// PTX programming guide.
/// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
wgmmaM = 64;
wgmmaN = sizeN;
if (inputElemType.isTF32()) {
wgmmaK = 8;
} else if (inputElemType.isF16() || inputElemType.isBF16()) {
wgmmaK = 16;
} else if (inputElemType.isFloat8E4M3FN() ||
inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
wgmmaK = 32;
} else if (inputElemType.isInteger(1)) {
wgmmaK = 256;
} else {
llvm_unreachable("msg: not supported K shape");
}
LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
<< ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
}

Value descriptorA = adaptor.getDescriptorA();
Value descriptorB = adaptor.getDescriptorB();
/// Generates WGMMATypesAttr from MLIR Type
NVVM::WGMMATypesAttr generateWgmmaType(Type type) const {
auto getWgmmaType = [](Type elemType) {
if (elemType.isF32() || elemType.isTF32())
return NVVM::WGMMATypes::tf32;
if (elemType.isF16())
return NVVM::WGMMATypes::f16;
if (elemType.isBF16())
return NVVM::WGMMATypes::bf16;
if (elemType.isFloat8E4M3FN())
return NVVM::WGMMATypes::e4m3;
if (elemType.isFloat8E5M2())
return NVVM::WGMMATypes::e5m2;
if (elemType.isInteger(1))
return NVVM::WGMMATypes::b1;
if (elemType.isInteger(8))
return NVVM::WGMMATypes::s8;
if (elemType.isUnsignedInteger(8))
return NVVM::WGMMATypes::u8;
llvm_unreachable("unsupported type");
};
return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
}

// Generate wgmma group
MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
/// Generates layout attribute for the input matrix for wgmma instruction
NVVM::MMALayoutAttr
generateWgmmaLayout(std::optional<bool> transpose) const {
if (transpose.value_or(false))
return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
}

auto makeAdd = [&](Value lhs, Value rhs) -> Value {
/// Generates shape attribute for wgmma instruction
NVVM::MMAShapeAttr generateWgmmaShape() const {
return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
}

/// Generates scale attributes of output matrix for wgmma instruction
NVVM::WGMMAScaleOutAttr generateScaleOut() const {
return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
NVVM::WGMMAScaleOut::one);
}
/// Generates scale attributes of input matrix for wgmma instruction
NVVM::WGMMAScaleInAttr generateScaleIn() const {
return NVVM::WGMMAScaleInAttr::get(op->getContext(),
NVVM::WGMMAScaleIn::one);
}

/// Basic function to generate Add
Value makeAdd(Value lhs, Value rhs) {
return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
};

auto iterateDescA = [&](Value desc, int iterM, int iterN,
int iterK) -> Value {
// todo : Handle column major
int byte = typeTensorA.getElementTypeBitWidth() / 8;
int tileShapeA = typeTensorA.getDimSize(1);
int incrementVal =
((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte;
/// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
/// Currently, it only handles row-major.
///
/// It moves the pointer like below for [128][64] size:
/// +2 +4 +6
/// ↓ ↓ ↓
/// descA ---> +--+--+--+--+
/// |->|->|->|->|
/// | | | | |
/// | | | | |
/// | | | | |
/// descA+512---> +-----------+
/// | | | | |
/// | | | | |
/// | | | | |
/// | | | | |
/// +-----------+
///
Value iterateDescriptorA(Value desc, int i, int j, int k) {
MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
Type elemA = matrixTypeA.getElementType();
int byte = elemA.getIntOrFloatBitWidth() / 8;
int tileShapeA = matrixTypeA.getDimSize(1);
int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
incrementVal = incrementVal >> exclude4LSB;
LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: "
<< iterK << "] [wgmma descriptors] Descriptor A + "
LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
<< "] [wgmma descriptors] Descriptor A + "
<< incrementVal << " | \t ");
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
};
}

auto iterateDescB = [&](Value desc, int iterM, int iterN,
int iterK) -> Value {
// todo : Handle row major
int byte = typeTensorB.getElementTypeBitWidth() / 8;
int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte;
/// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
/// Currently, it only handles column-major.
///
/// It moves the pointer like below for [128][64] size:
/// descB ---> +--+--+--+--+--+--+--+--+
/// |↓ | | | | | | | |
/// |↓ | | | | | | | |
/// |↓ | | | | | | | |
/// |↓ | | | | | | | |
/// +--+--+--+--+--+--+--+--+
///
Value iterateDescriptorB(Value desc, int i, int j, int k) {
MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
Type elemB = matrixTypeB.getElementType();
int byte = elemB.getIntOrFloatBitWidth() / 8;
int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
incrementVal = incrementVal >> exclude4LSB;
LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
};

b.create<NVVM::WgmmaFenceAlignedOp>();

SmallVector<Value> wgmmaResults;
for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) {
Value matrixC = adaptor.getMatrixC()[iterM];
Value matrixD = op.getMatrixD()[iterM];
Type structType = getTypeConverter()->convertType(matrixD.getType());
LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":"
<< (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0
<< ":" << wgmmaShapeN << "] += \n");
for (int iterK = 0; iterK < (sizeK / wgmmaShapeK); iterK++) {
Value descA = iterateDescA(descriptorA, iterM, 0, iterK);
Value descB = iterateDescB(descriptorB, iterM, 0, iterK);
LLVM_DEBUG(DBGS() << "\t wgmma."
<< "m" << wgmmaShapeM << "n" << wgmmaShapeN << "k"
<< wgmmaShapeK << "(A[" << (iterM * wgmmaShapeM)
<< ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "]["
<< (iterK * wgmmaShapeK) << ":"
<< (iterK * wgmmaShapeK + wgmmaShapeK) << "] * "
<< " B[" << (iterK * wgmmaShapeK) << ":"
<< (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0
<< ":" << wgmmaShapeN << "])\n");
matrixC = generateNVVMWgmmaOp(b, wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
structType, matrixC, descA, descB);
}
wgmmaResults.push_back(matrixC);
}
b.create<NVVM::WgmmaGroupSyncAlignedOp>();
b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());

ValueRange myres(wgmmaResults);
rewriter.replaceOp(op, myres);
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
Value generateWgmma(int i, int j, int k, Value matrixC, Value matrixD) {
LLVM_DEBUG(DBGS() << "\t wgmma."
<< "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
<< "(A[" << (iterationM * wgmmaM) << ":"
<< (iterationM * wgmmaM) + wgmmaM << "]["
<< (iterationK * wgmmaK) << ":"
<< (iterationK * wgmmaK + wgmmaK) << "] * "
<< " B[" << (iterationK * wgmmaK) << ":"
<< (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
<< wgmmaN << "])\n");

Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);

Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);

Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);

NVVM::MMAShapeAttr shape = generateWgmmaShape();
NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(op.getTransposeB());

auto overflow = NVVM::MMAIntOverflowAttr::get(
op->getContext(), NVVM::MMAIntOverflow::wrapped);

Type resultStructType = typeConverter.convertType(matrixD.getType());

return b.create<NVVM::WgmmaMmaAsyncOp>(
resultStructType, matrixC, descriptorA, descriptorB, shape, itypeA,
itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
}

/// Generates multiple wgmma instructions to complete the given GEMM shape
SmallVector<Value> generateWgmmaGroup() {
SmallVector<Value> wgmmaResults;

// Perform GEMM
for (int i = 0; i < iterationM; ++i) {
Value matrixC = adaptor.getMatrixC()[i];
Value matrixD = op.getMatrixD()[i];
for (int j = 0; j < iterationN; ++j)
for (int k = 0; k < iterationK; ++k)
matrixC = generateWgmma(i, j, k, matrixC, matrixD);
wgmmaResults.push_back(matrixC);
}

return wgmmaResults;
}

public:
WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
OpAdaptor adaptor, const LLVMTypeConverter &typeConverter)
: op(op), b(b), adaptor(adaptor), typeConverter(typeConverter) {
// Find the entire GEMM Shape
totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
<< "] += A[" << totalM << "][" << totalK << "] * B["
<< totalK << "][" << totalN << "] ---===\n");

// Find the shape for one wgmma instruction
findWgmmaShape(
totalM, totalN,
op.getDescriptorA().getType().getTensor().getElementType());

// Iterations counts to complete the given shape with wgmma shape
iterationM = totalM / wgmmaM;
iterationN = totalN / wgmmaN;
iterationK = totalK / wgmmaK;
}

/// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
/// includes generating a fence Op (WgmmaFenceAlignedOp) before the
/// instructions and group synchronization, as well as waiting
/// (WgmmaGroupSyncAlignedOp) for group synchronization
/// (WgmmaWaitGroupSyncOp) after the instructions.
SmallVector<Value> generateWarpgroupMma() {
b.create<NVVM::WgmmaFenceAlignedOp>();
SmallVector<Value> wgmmaResults = generateWgmmaGroup();
b.create<NVVM::WgmmaGroupSyncAlignedOp>();
b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
return wgmmaResults;
}
};

LogicalResult
matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
// Step 1. Build a helper class
WarpgroupGemm warpgroupGemm(op, b, adaptor, *this->getTypeConverter());

// Step 2. Get the entire GEMM Shape
SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();

// Step 3. Replace fragmented result struct with the op results
rewriter.replaceOp(op, wgmmaResults);
return success();
}
};
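As a quick sanity check on the descriptor arithmetic in `iterateDescriptorA` above, here is a small standalone sketch (plain host C++, written for this summary and not part of the commit) that reproduces the increments from the `[128][64]` f16 matrix-A example in its doc comment:

```cpp
#include <cstdio>

// Mirrors: incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte,
// then dropping the 4 least-significant bits (exclude4LSB) as the wgmma
// descriptor encoding requires.
int main() {
  const int exclude4LSB = 4;
  const int byte = 2;        // f16 element size in bytes
  const int tileShapeA = 64; // matrix-A tile: dimSize(1) of [128][64]
  const int totalK = 64;     // GEMM K extent for this tile
  const int wgmmaK = 16;     // k-extent of one f16 wgmma instruction
  const int iterationM = 2;  // totalM(128) / wgmmaM(64)
  const int iterationK = 4;  // totalK(64) / wgmmaK(16)

  for (int i = 0; i < iterationM; ++i)
    for (int k = 0; k < iterationK; ++k) {
      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
      incrementVal >>= exclude4LSB;
      std::printf("i=%d k=%d -> descA + %d\n", i, k, incrementVal);
    }
  return 0;
}
```

This prints descA + 0, 2, 4, 6 for the first row of wgmma tiles and descA + 512, 514, 516, 518 for the second, matching the `+2 +4 +6` and `descA+512` annotations in the ASCII diagram of `iterateDescriptorA`.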