diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index b0df2feae16b..cce525dcdcbe 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -39,7 +39,7 @@ namespace mlir {
 
 using namespace mlir;
 
-/// Number of bits that needs to excluded when building matrix descriptor for
+/// Number of bits that needs to be excluded when building matrix descriptor for
 /// wgmma operations.
 constexpr int exclude4LSB = 4;
 
@@ -1160,137 +1160,276 @@ struct NVGPUWarpgroupMmaOpLowering
     : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
   using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;
 
-  LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType,
-                              int &wgmmaShapeM, int &wgmmaShapeN,
-                              int &wgmmaShapeK) const {
-    wgmmaShapeM = 64;
-    wgmmaShapeN = sizeN;
-    if (inputElemType.isTF32()) {
-      wgmmaShapeK = 8;
-    } else if (inputElemType.isF16() || inputElemType.isBF16()) {
-      wgmmaShapeK = 16;
-    } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() ||
-               inputElemType.isInteger(16)) {
-      wgmmaShapeK = 32;
-    } else if (inputElemType.isInteger(1)) {
-      wgmmaShapeK = 256;
-    } else {
-      llvm_unreachable("msg: not supported K shape");
-    }
-    LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM
-                      << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK
-                      << "]\n");
-    return success();
-  }
+  /// This is a helper class to generate the NVVM Ops required for warp-group
+  /// level matrix multiplication.
+  /// When the given GEMM shape is larger than the shape of a single wgmma
+  /// instruction in PTX, it generates multiple NVVM::WgmmaMmaAsyncOp Ops,
+  /// groups them, and executes them asynchronously. The class also handles
+  /// waiting for completion and iterates through WarpgroupMatrixDescriptor to
+  /// create descriptors for each instruction.
+  ///
+  /// For example, this is the case when the GEMM shape is 128x128x128:
+  ///
+  ///   nvvm.wgmma.fence.aligned
+  ///
+  ///   nvvm.wgmma.mma.async descA, descB
+  ///   iterate(descA, descB)
+  ///   nvvm.wgmma.mma.async descA, descB
+  ///   [6 more times]
+  ///
+  ///   nvvm.wgmma.group.sync.aligned
+  ///   nvvm.wgmma.wait.group.sync [groupId]
+  ///
+  class WarpgroupGemm {
+    nvgpu::WarpgroupMmaOp op;
+    ImplicitLocOpBuilder b;
+    OpAdaptor adaptor;
+    const LLVMTypeConverter &typeConverter;
 
-  Value generateNVVMWgmmaOp(ImplicitLocOpBuilder &b, int m, int n, int k,
-                            Type resultStructType, Value inout,
-                            Value descriptorA, Value descriptorB) const {
-    MLIRContext *ctx = b.getContext();
-    auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
-    auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
-    auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
-    auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
-    auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
-    // todo: handle other input and output types
-    auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
-    auto overflow =
-        NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
-    Value res = b.create<NVVM::WgmmaMmaAsyncOp>(
-        resultStructType, inout, descriptorA, descriptorB, shape, itype, itype,
-        scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
-    return res;
-  }
+    // Entire shape of the given Op
+    int64_t totalM, totalN, totalK;
 
-  LogicalResult
-  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
-    int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
-    int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
-    int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
+    // Shape of one wgmma instruction
+    int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;
 
-    LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A["
-                      << sizeM << "][" << sizeK << "] * B[" << sizeK << "]["
-                      << sizeN << "] ---===\n");
+    // Iteration counts for GEMM
+    int iterationM = 0, iterationN = 0, iterationK = 0;
 
-    int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK;
-    if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM,
-                             wgmmaShapeN, wgmmaShapeK))) {
-      return failure();
+    /// The function returns the shape of the wgmma instruction that is defined
+    /// in the PTX programming guide:
+    /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
+    void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
+      wgmmaM = 64;
+      wgmmaN = sizeN;
+      if (inputElemType.isTF32()) {
+        wgmmaK = 8;
+      } else if (inputElemType.isF16() || inputElemType.isBF16()) {
+        wgmmaK = 16;
+      } else if (inputElemType.isFloat8E4M3FN() ||
+                 inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
+        wgmmaK = 32;
+      } else if (inputElemType.isInteger(1)) {
+        wgmmaK = 256;
+      } else {
+        llvm_unreachable("msg: not supported K shape");
+      }
+      LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
+                        << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
     }
 
-    Value descriptorA = adaptor.getDescriptorA();
-    Value descriptorB = adaptor.getDescriptorB();
+    /// Generates a WGMMATypesAttr from the given MLIR Type.
+    NVVM::WGMMATypesAttr generateWgmmaType(Type type) const {
+      auto getWgmmaType = [](Type elemType) {
+        if (elemType.isF32() || elemType.isTF32())
+          return NVVM::WGMMATypes::tf32;
+        if (elemType.isF16())
+          return NVVM::WGMMATypes::f16;
+        if (elemType.isBF16())
+          return NVVM::WGMMATypes::bf16;
+        if (elemType.isFloat8E4M3FN())
+          return NVVM::WGMMATypes::e4m3;
+        if (elemType.isFloat8E5M2())
+          return NVVM::WGMMATypes::e5m2;
+        if (elemType.isInteger(1))
+          return NVVM::WGMMATypes::b1;
+        if (elemType.isInteger(8))
+          return NVVM::WGMMATypes::s8;
+        if (elemType.isUnsignedInteger(8))
+          return NVVM::WGMMATypes::u8;
+        llvm_unreachable("unsupported type");
+      };
+      return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
+    }
 
-    // Generate wgmma group
-    MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
-    MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
+    /// Generates the layout attribute of an input matrix for the wgmma
+    /// instruction.
+    NVVM::MMALayoutAttr
+    generateWgmmaLayout(std::optional<bool> transpose) const {
+      if (transpose.value_or(false))
+        return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
+      return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
+    }
 
-    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
+    /// Generates the shape attribute for the wgmma instruction.
+    NVVM::MMAShapeAttr generateWgmmaShape() const {
+      return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
+    }
+
+    /// Generates the scale attribute of the output matrix for the wgmma
+    /// instruction.
+    NVVM::WGMMAScaleOutAttr generateScaleOut() const {
+      return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
+                                          NVVM::WGMMAScaleOut::one);
+    }
+    /// Generates the scale attribute of an input matrix for the wgmma
+    /// instruction.
+    NVVM::WGMMAScaleInAttr generateScaleIn() const {
+      return NVVM::WGMMAScaleInAttr::get(op->getContext(),
+                                         NVVM::WGMMAScaleIn::one);
+    }
+
+    /// Basic function to generate an add operation.
+    Value makeAdd(Value lhs, Value rhs) {
       return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
     };
 
-    auto iterateDescA = [&](Value desc, int iterM, int iterN,
-                            int iterK) -> Value {
-      // todo : Handle column major
-      int byte = typeTensorA.getElementTypeBitWidth() / 8;
-      int tileShapeA = typeTensorA.getDimSize(1);
-      int incrementVal =
-          ((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte;
+    /// Moves the descriptor pointer of matrix-A for the next wgmma
+    /// instruction. Currently, it only handles row-major.
+    ///
+    /// It moves the pointer as shown below for a [128][64] matrix:
+    ///           +2  +4  +6
+    ///            ↓   ↓   ↓
+    /// descA ---> +--+--+--+--+
+    ///            |->|->|->|->|
+    ///            |  |  |  |  |
+    ///            |  |  |  |  |
+    ///            |  |  |  |  |
+    /// descA+512---> +-----------+
+    ///            |  |  |  |  |
+    ///            |  |  |  |  |
+    ///            |  |  |  |  |
+    ///            |  |  |  |  |
+    ///            +-----------+
+    ///
+    Value iterateDescriptorA(Value desc, int i, int j, int k) {
+      MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
+      Type elemA = matrixTypeA.getElementType();
+      int byte = elemA.getIntOrFloatBitWidth() / 8;
+      int tileShapeA = matrixTypeA.getDimSize(1);
+      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
       incrementVal = incrementVal >> exclude4LSB;
-      LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: "
-                        << iterK << "] [wgmma descriptors] Descriptor A + "
+      LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
+                        << "] [wgmma descriptors] Descriptor A + "
                         << incrementVal << " | \t ");
       if (!incrementVal)
         return desc;
       return makeAdd(desc, makeI64Const(b, incrementVal));
-    };
+    }
 
-    auto iterateDescB = [&](Value desc, int iterM, int iterN,
-                            int iterK) -> Value {
-      // todo : Handle row major
-      int byte = typeTensorB.getElementTypeBitWidth() / 8;
-      int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte;
+    /// Moves the descriptor pointer of matrix-B for the next wgmma
+    /// instruction. Currently, it only handles column-major.
+    ///
+    /// It moves the pointer as shown below for a [128][64] matrix:
+    /// descB ---> +--+--+--+--+--+--+--+--+
+    ///            |↓ |  |  |  |  |  |  |  |
+    ///            |↓ |  |  |  |  |  |  |  |
+    ///            |↓ |  |  |  |  |  |  |  |
+    ///            |↓ |  |  |  |  |  |  |  |
+    ///            +--+--+--+--+--+--+--+--+
+    ///
+    Value iterateDescriptorB(Value desc, int i, int j, int k) {
+      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
+      Type elemB = matrixTypeB.getElementType();
+      int byte = elemB.getIntOrFloatBitWidth() / 8;
+      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
       incrementVal = incrementVal >> exclude4LSB;
       LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
       if (!incrementVal)
         return desc;
       return makeAdd(desc, makeI64Const(b, incrementVal));
-    };
-
-    b.create<NVVM::WgmmaFenceAlignedOp>();
-
-    SmallVector<Value> wgmmaResults;
-    for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) {
-      Value matrixC = adaptor.getMatrixC()[iterM];
-      Value matrixD = op.getMatrixD()[iterM];
-      Type structType = getTypeConverter()->convertType(matrixD.getType());
-      LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":"
-                        << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0
-                        << ":" << wgmmaShapeN << "] += \n");
-      for (int iterK = 0; iterK < (sizeK / wgmmaShapeK); iterK++) {
-        Value descA = iterateDescA(descriptorA, iterM, 0, iterK);
-        Value descB = iterateDescB(descriptorB, iterM, 0, iterK);
-        LLVM_DEBUG(DBGS() << "\t wgmma."
- << "m" << wgmmaShapeM << "n" << wgmmaShapeN << "k" - << wgmmaShapeK << "(A[" << (iterM * wgmmaShapeM) - << ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" - << (iterK * wgmmaShapeK) << ":" - << (iterK * wgmmaShapeK + wgmmaShapeK) << "] * " - << " B[" << (iterK * wgmmaShapeK) << ":" - << (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0 - << ":" << wgmmaShapeN << "])\n"); - matrixC = generateNVVMWgmmaOp(b, wgmmaShapeM, wgmmaShapeN, wgmmaShapeK, - structType, matrixC, descA, descB); - } - wgmmaResults.push_back(matrixC); } - b.create(); - b.create(op.getWaitGroup()); - ValueRange myres(wgmmaResults); - rewriter.replaceOp(op, myres); + /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix + /// descriptors and arranges them based on induction variables: i, j, and k. + Value generateWgmma(int i, int j, int k, Value matrixC, Value matrixD) { + LLVM_DEBUG(DBGS() << "\t wgmma." + << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK + << "(A[" << (iterationM * wgmmaM) << ":" + << (iterationM * wgmmaM) + wgmmaM << "][" + << (iterationK * wgmmaK) << ":" + << (iterationK * wgmmaK + wgmmaK) << "] * " + << " B[" << (iterationK * wgmmaK) << ":" + << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" + << wgmmaN << "])\n"); + + Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k); + Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k); + + Type elemA = op.getDescriptorA().getType().getTensor().getElementType(); + NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA); + + Type elemB = op.getDescriptorB().getType().getTensor().getElementType(); + NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB); + + NVVM::MMAShapeAttr shape = generateWgmmaShape(); + NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut(); + NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn(); + NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA()); + NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(op.getTransposeB()); + + auto overflow = NVVM::MMAIntOverflowAttr::get( + op->getContext(), NVVM::MMAIntOverflow::wrapped); + + Type resultStructType = typeConverter.convertType(matrixD.getType()); + + return b.create( + resultStructType, matrixC, descriptorA, descriptorB, shape, itypeA, + itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow); + } + + /// Generates multiple wgmma instructions to complete the given GEMM shape + SmallVector generateWgmmaGroup() { + SmallVector wgmmaResults; + + // Perform GEMM + for (int i = 0; i < iterationM; ++i) { + Value matrixC = adaptor.getMatrixC()[i]; + Value matrixD = op.getMatrixD()[i]; + for (int j = 0; j < iterationN; ++j) + for (int k = 0; k < iterationK; ++k) + matrixC = generateWgmma(i, j, k, matrixC, matrixD); + wgmmaResults.push_back(matrixC); + } + + return wgmmaResults; + } + + public: + WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b, + OpAdaptor adaptor, const LLVMTypeConverter &typeConverter) + : op(op), b(b), adaptor(adaptor), typeConverter(typeConverter) { + // Find the entire GEMM Shape + totalM = op.getDescriptorA().getType().getTensor().getDimSize(0); + totalN = op.getDescriptorB().getType().getTensor().getDimSize(1); + totalK = op.getDescriptorA().getType().getTensor().getDimSize(1); + LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN + << "] += A[" << totalM << "][" << totalK << "] * B[" + << totalK << "][" << totalN << "] ---===\n"); + + // Find the shape for one wgmma instruction + findWgmmaShape( + totalM, totalN, + 
+          op.getDescriptorA().getType().getTensor().getElementType());
+
+      // Iteration counts to complete the given shape with the wgmma shape
+      iterationM = totalM / wgmmaM;
+      iterationN = totalN / wgmmaN;
+      iterationK = totalK / wgmmaK;
+    }
+
+    /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
+    /// includes generating a fence Op (WgmmaFenceAlignedOp) before the
+    /// instructions, as well as group synchronization
+    /// (WgmmaGroupSyncAlignedOp) and waiting for completion
+    /// (WgmmaWaitGroupSyncOp) after the instructions.
+    SmallVector<Value> generateWarpgroupMma() {
+      b.create<NVVM::WgmmaFenceAlignedOp>();
+      SmallVector<Value> wgmmaResults = generateWgmmaGroup();
+      b.create<NVVM::WgmmaGroupSyncAlignedOp>();
+      b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
+      return wgmmaResults;
+    }
+  };
+
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+    // Step 1. Build a helper class.
+    WarpgroupGemm warpgroupGemm(op, b, adaptor, *this->getTypeConverter());
+
+    // Step 2. Generate the wgmma Ops that cover the entire GEMM shape.
+    SmallVector<Value> wgmmaResults = warpgroupGemm.generateWarpgroupMma();
+
+    // Step 3. Replace the fragmented result struct with the op results.
+    rewriter.replaceOp(op, wgmmaResults);
     return success();
   }
 };
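
For readers checking the arithmetic in the patch, the following standalone sketch (not part of the patch) mirrors the shape selection in findWgmmaShape, the iteration counts computed in the WarpgroupGemm constructor, and the descriptor increments from iterateDescriptorA/iterateDescriptorB. It assumes f16 operands and the 128x128x128 GEMM used as the example in the class comment; the 128x128 tile sizes and all concrete numbers are illustrative assumptions, and the real code uses int64_t for the total shape.

// Standalone sketch (not part of the patch). Assumptions: f16 operands,
// GEMM shape 128x128x128, tile shapes A = [128][128] and B = [128][128].
#include <cassert>
#include <cstdio>

constexpr int exclude4LSB = 4; // same constant as in NVGPUToNVVM.cpp

int main() {
  // findWgmmaShape: for f16 the instruction shape is m64 x n<totalN> x k16.
  const int totalM = 128, totalN = 128, totalK = 128; // int64_t in the patch
  const int wgmmaM = 64, wgmmaN = totalN, wgmmaK = 16;

  // Iteration counts computed in the WarpgroupGemm constructor.
  const int iterationM = totalM / wgmmaM; // 2
  const int iterationN = totalN / wgmmaN; // 1
  const int iterationK = totalK / wgmmaK; // 8
  assert(iterationM == 2 && iterationN == 1 && iterationK == 8);

  const int byte = 16 / 8;    // f16 element size in bytes
  const int tileShapeA = 128; // matrixTypeA.getDimSize(1), assumed
  const int dimB0 = 128;      // matrixTypeB.getDimSize(0), assumed

  for (int i = 0; i < iterationM; ++i) {
    for (int k = 0; k < iterationK; ++k) {
      // iterateDescriptorA: step within the k dimension, then jump to the
      // next 64-row block of A; the 4 LSBs are dropped as in the descriptor.
      int offA = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
      offA >>= exclude4LSB;
      // iterateDescriptorB: step down the columns of B for each k iteration.
      int offB = dimB0 * wgmmaK * k * byte;
      offB >>= exclude4LSB;
      std::printf("i=%d k=%d  descA += %d  descB += %d\n", i, k, offA, offB);
    }
  }
  return 0;
}

With these assumed numbers the helper would emit one WgmmaFenceAlignedOp, sixteen WgmmaMmaAsyncOps (two accumulator fragments times eight k-steps), one WgmmaGroupSyncAlignedOp, and one WgmmaWaitGroupSyncOp, which matches the sequence sketched in the WarpgroupGemm class comment.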