Mirror of https://github.com/capstone-engine/llvm-capstone.git
Synced 2024-12-14 11:39:35 +00:00
[mlir][nvgpu] Use ImplicitLocOpBuilder in nvgpu-to-nvvm pass (NFC) (#67993)
For the sake of better readability, this PR uses `ImplicitLocOpBuilder` instead of passing a rewriter and a `Location` to every `create` call.
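To illustrate the pattern applied throughout the pass, here is a minimal before/after sketch in the spirit of the `truncToI32` change below (the `*Before`/`*After` names are illustrative, not from the patch): `ImplicitLocOpBuilder` wraps the rewriter and captures one `Location` at construction, so subsequent `create<...>` calls no longer take an explicit `loc` argument.

// Before: the Location must be threaded into every builder call.
static Value truncToI32Before(ConversionPatternRewriter &rewriter, Location loc,
                              Value value) {
  return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), value);
}

// After: the builder remembers op.getLoc(), so create<> omits the location.
static Value truncToI32After(ImplicitLocOpBuilder &b, Value value) {
  return b.create<LLVM::TruncOp>(b.getI32Type(), value);
}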
This commit is contained in:
parent
2b5a6d774c
commit
ee49cda7d4
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
@@ -44,13 +45,12 @@ constexpr int exclude4LSB = 4;

 /// GPU has 32 bit registers, this function truncates values when larger width
 /// is not needed.
-static Value truncToI32(ConversionPatternRewriter &rewriter, Location loc,
-                        Value value) {
+static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
   Type type = value.getType();
   assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
   if (type.getIntOrFloatBitWidth() <= 32)
     return value;
-  return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), value);
+  return b.create<LLVM::TruncOp>(b.getI32Type(), value);
 }

 /// Returns the type for the intrinsic given the vectorResultType of the
@@ -170,22 +170,21 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
 /// `nvvm.mma.sync` op expects these argments to be a given in a long list of
 /// scalars of certain types. This function helps unpack the `vector` arguments
 /// and cast them to the types expected by `nvvm.mma.sync`.
-static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
-                                              Location loc, Value operand,
+static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
+                                              Value operand,
                                               NVVM::MMATypes operandPtxType) {
   SmallVector<Value> result;
-  Type i32Ty = rewriter.getI32Type();
-  Type f64Ty = rewriter.getF64Type();
-  Type f32Ty = rewriter.getF32Type();
-  Type i8Ty = rewriter.getI8Type();
-  Type i4Ty = rewriter.getIntegerType(4);
-  Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
-  Type i4x8Ty = LLVM::getFixedVectorType(i4Ty, 8);
+  Type i32Ty = b.getI32Type();
+  Type f64Ty = b.getF64Type();
+  Type f32Ty = b.getF32Type();
+  Type i64Ty = b.getI64Type();
+  Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
+  Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
   Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
   auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());

   for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
-    Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
+    Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);

     // For 4xi8 vectors, the intrinsic expects these to be provided as i32
     // scalar types.
@@ -193,8 +192,7 @@ static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
         arrayTy.getElementType() == i4x8Ty ||
         (arrayTy.getElementType() == f32x1Ty &&
          operandPtxType == NVVM::MMATypes::tf32)) {
-      result.push_back(
-          rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
+      result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
       continue;
     }

@@ -207,10 +205,9 @@ static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
          innerArrayTy.getElementType() == f32Ty)) {
       for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
            idx < innerSize; idx++) {
-        result.push_back(rewriter.create<LLVM::ExtractElementOp>(
-            loc, toUse,
-            rewriter.create<LLVM::ConstantOp>(
-                loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
+        result.push_back(b.create<LLVM::ExtractElementOp>(
+            toUse,
+            b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
       }
       continue;
     }
@@ -256,7 +253,7 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
   matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MLIRContext *ctx = getContext();
-    Location loc = op->getLoc();
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

     // The result type of ldmatrix will always be a struct of 32bit integer
     // registers if more than one 32bit value is returned. Otherwise, the result
@@ -283,10 +280,10 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {

     auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
     Value srcPtr =
-        getStridedElementPtr(loc, srcMemrefType, adaptor.getSrcMemref(),
+        getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
                              adaptor.getIndices(), rewriter);
-    Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
-        loc, ldMatrixResultType, srcPtr,
+    Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
+        ldMatrixResultType, srcPtr,
         /*num=*/op.getNumTiles(),
         /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
                                      : NVVM::MMALayout::row);
@@ -296,15 +293,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
     // actual vector type (still of width 32b) and repack them into a result
     // struct.
     Type finalResultType = typeConverter->convertType(vectorResultType);
-    Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
+    Value result = b.create<LLVM::UndefOp>(finalResultType);
     for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
       Value i32Register =
-          num32BitRegs > 1
-              ? rewriter.create<LLVM::ExtractValueOp>(loc, ldMatrixResult, i)
-              : ldMatrixResult;
-      Value casted =
-          rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
-      result = rewriter.create<LLVM::InsertValueOp>(loc, result, casted, i);
+          num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
+                           : ldMatrixResult;
+      Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
+      result = b.create<LLVM::InsertValueOp>(result, casted, i);
     }

     rewriter.replaceOp(op, result);
@@ -335,7 +330,7 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
   LogicalResult
   matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     // Get the shapes of the MMAMatrix type being used. The shapes will
     // choose which intrinsic this op will be lowered to.
     VectorType aType = op.getMatrixA().getType();
@@ -368,17 +363,17 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
       overflow = NVVM::MMAIntOverflow::satfinite;

     SmallVector<Value> matA =
-        unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), *ptxTypeA);
+        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
     SmallVector<Value> matB =
-        unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), *ptxTypeB);
+        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
     SmallVector<Value> matC =
-        unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);
+        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);

     Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
     Type intrinsicResTy = inferIntrinsicResultType(
         typeConverter->convertType(op->getResultTypes()[0]));
-    Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
-        op.getLoc(), intrinsicResTy, matA, matB, matC,
+    Value intrinsicResult = b.create<NVVM::MmaOp>(
+        intrinsicResTy, matA, matB, matC,
         /*shape=*/gemmShape,
         /*b1Op=*/std::nullopt,
         /*intOverflow=*/overflow,
@@ -511,14 +506,14 @@ static std::string buildMmaSparseAsmString(
 /// Builds an inline assembly operation corresponding to the specified MMA
 /// sparse sync operation.
 static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
-    Location loc, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
+    ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
     NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
     std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
     ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
     int64_t metadataSelector, const std::array<int64_t, 3> &shape,
-    Type intrinsicResultType, ConversionPatternRewriter &rewriter) {
-  auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
-                                                  LLVM::AsmDialect::AD_ATT);
+    Type intrinsicResultType) {
+  auto asmDialectAttr =
+      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);

   const unsigned matASize = unpackedAData.size();
   const unsigned matBSize = unpackedB.size();
@@ -536,15 +531,15 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
   llvm::append_range(asmVals, args);
   asmVals.push_back(indexData);

-  return rewriter.create<LLVM::InlineAsmOp>(loc,
-                                            /*resultTypes=*/intrinsicResultType,
-                                            /*operands=*/asmVals,
-                                            /*asm_string=*/asmStr,
-                                            /*constraints=*/constraintStr,
-                                            /*has_side_effects=*/true,
-                                            /*is_align_stack=*/false,
-                                            /*asm_dialect=*/asmDialectAttr,
-                                            /*operand_attrs=*/ArrayAttr());
+  return b.create<LLVM::InlineAsmOp>(
+      /*resultTypes=*/intrinsicResultType,
+      /*operands=*/asmVals,
+      /*asm_string=*/asmStr,
+      /*constraints=*/constraintStr,
+      /*has_side_effects=*/true,
+      /*is_align_stack=*/false,
+      /*asm_dialect=*/asmDialectAttr,
+      /*operand_attrs=*/ArrayAttr());
 }

 /// Lowers `nvgpu.mma.sp.sync` to inline assembly.
@@ -555,7 +550,7 @@ struct NVGPUMmaSparseSyncLowering
   LogicalResult
   matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
     // Get the shapes of the MMAMatrix type being used. The shapes will
     // choose which intrinsic this op will be lowered to.
     VectorType aType = op.getMatrixA().getType();
@@ -586,11 +581,11 @@ struct NVGPUMmaSparseSyncLowering
       overflow = NVVM::MMAIntOverflow::satfinite;

     SmallVector<Value> matA =
-        unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), *ptxTypeA);
+        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
     SmallVector<Value> matB =
-        unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), *ptxTypeB);
+        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
     SmallVector<Value> matC =
-        unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);
+        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);

     Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
     Type intrinsicResTy = inferIntrinsicResultType(
@@ -602,13 +597,13 @@ struct NVGPUMmaSparseSyncLowering
             LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
       return op->emitOpError() << "Expected metadata type to be LLVM "
                                   "VectorType of 2 i16 elements";
-    sparseMetadata = rewriter.create<LLVM::BitcastOp>(
-        loc, rewriter.getI32Type(), sparseMetadata);
+    sparseMetadata =
+        b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);

     FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
-        loc, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
+        b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
         matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
-        intrinsicResTy, rewriter);
+        intrinsicResTy);
     if (failed(intrinsicResult))
       return failure();

@@ -629,10 +624,12 @@ struct NVGPUAsyncCopyLowering
   LogicalResult
   matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Location loc = op.getLoc();
     auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
-    Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.getDst(),
-                                        adaptor.getDstIndices(), rewriter);
+    Value dstPtr =
+        getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
+                             adaptor.getDstIndices(), rewriter);
     auto i8Ty = IntegerType::get(op.getContext(), 8);
     FailureOr<unsigned> dstAddressSpace =
         getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
@@ -642,7 +639,7 @@ struct NVGPUAsyncCopyLowering
     auto dstPointerType =
         getTypeConverter()->getPointerType(i8Ty, *dstAddressSpace);
     if (!getTypeConverter()->useOpaquePointers())
-      dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);
+      dstPtr = b.create<LLVM::BitcastOp>(dstPointerType, dstPtr);

     auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
     FailureOr<unsigned> srcAddressSpace =
@@ -656,12 +653,11 @@ struct NVGPUAsyncCopyLowering
     auto srcPointerType =
         getTypeConverter()->getPointerType(i8Ty, *srcAddressSpace);
     if (!getTypeConverter()->useOpaquePointers())
-      scrPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, scrPtr);
+      scrPtr = b.create<LLVM::BitcastOp>(srcPointerType, scrPtr);
     // Intrinsics takes a global pointer so we need an address space cast.
     auto srcPointerGlobalType = getTypeConverter()->getPointerType(
         i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace);
-    scrPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
-                                                    scrPtr);
+    scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
     int64_t dstElements = adaptor.getDstElements().getZExtValue();
     int64_t sizeInBytes =
         (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
@@ -675,16 +671,14 @@ struct NVGPUAsyncCopyLowering
       // memory) of CpAsyncOp is read only for SrcElements number of elements.
       // The rest of the DstElements in the destination (shared memory) are
      // filled with zeros.
-      Value c3I32 = rewriter.create<LLVM::ConstantOp>(
-          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(3));
-      Value bitwidth = rewriter.create<LLVM::ConstantOp>(
-          loc, rewriter.getI32Type(),
-          rewriter.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
-      Value srcElementsI32 =
-          rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), srcBytes);
-      srcBytes = rewriter.create<LLVM::LShrOp>(
-          loc, rewriter.create<LLVM::MulOp>(loc, bitwidth, srcElementsI32),
-          c3I32);
+      Value c3I32 =
+          b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
+      Value bitwidth = b.create<LLVM::ConstantOp>(
+          b.getI32Type(),
+          b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
+      Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
+      srcBytes = b.create<LLVM::LShrOp>(
+          b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
     }
     // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
     // 16 dst bytes.
@@ -693,15 +687,14 @@ struct NVGPUAsyncCopyLowering
             ? NVVM::LoadCacheModifierKind::CG
             : NVVM::LoadCacheModifierKind::CA;

-    rewriter.create<NVVM::CpAsyncOp>(
-        loc, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
+    b.create<NVVM::CpAsyncOp>(
+        dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
         NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
         srcBytes);

     // Drop the result token.
-    Value zero = rewriter.create<LLVM::ConstantOp>(
-        op->getLoc(), IntegerType::get(op.getContext(), 32),
-        rewriter.getI32IntegerAttr(0));
+    Value zero = b.create<LLVM::ConstantOp>(
+        IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
     rewriter.replaceOp(op, zero);
     return success();
   }
@@ -790,14 +783,14 @@ struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
 public:
   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
   /// Returns the base pointer of the mbarrier object.
-  Value getMbarrierPtr(Operation *op, nvgpu::MBarrierGroupType mbarType,
-                       Value memrefDesc, Value mbarId,
+  Value getMbarrierPtr(ImplicitLocOpBuilder &b,
+                       nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
+                       Value mbarId,
                        ConversionPatternRewriter &rewriter) const {
     MemRefType mbarrierMemrefType =
         nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
     return ConvertToLLVMPattern::getStridedElementPtr(
-        op->getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
-    return memrefDesc;
+        b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
   }
 };

@@ -809,11 +802,12 @@ struct NVGPUMBarrierInitLowering
   LogicalResult
   matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
     rewriter.setInsertionPoint(op);
-    Value barrier = getMbarrierPtr(op, mbarrierType, adaptor.getBarriers(),
+    Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                    adaptor.getMbarId(), rewriter);
-    Value count = truncToI32(rewriter, op->getLoc(), adaptor.getCount());
+    Value count = truncToI32(b, adaptor.getCount());
     if (isMbarrierShared(mbarrierType)) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(op, barrier,
                                                               count);
@@ -831,8 +825,9 @@ struct NVGPUMBarrierArriveLowering
   LogicalResult
   matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     Value barrier =
-        getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
     Type tokenType = getTypeConverter()->convertType(
         nvgpu::MBarrierTokenType::get(op->getContext()));
@@ -856,12 +851,13 @@ struct NVGPUMBarrierArriveNoCompleteLowering
   LogicalResult
   matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     Value barrier =
-        getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
     Type tokenType = getTypeConverter()->convertType(
         nvgpu::MBarrierTokenType::get(op->getContext()));
-    Value count = truncToI32(rewriter, op->getLoc(), adaptor.getCount());
+    Value count = truncToI32(b, adaptor.getCount());
     if (isMbarrierShared(op.getBarriers().getType())) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
           op, tokenType, barrier, count);
@@ -880,8 +876,9 @@ struct NVGPUMBarrierTestWaitLowering
   LogicalResult
   matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     Value barrier =
-        getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
     Type retType = rewriter.getI1Type();
     if (isMbarrierShared(op.getBarriers().getType())) {
@@ -902,10 +899,11 @@ struct NVGPUMBarrierArriveExpectTxLowering
   LogicalResult
   matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     Value barrier =
-        getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
-    Value txcount = truncToI32(rewriter, op->getLoc(), adaptor.getTxcount());
+    Value txcount = truncToI32(b, adaptor.getTxcount());

     if (isMbarrierShared(op.getBarriers().getType())) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
@@ -926,11 +924,12 @@ struct NVGPUMBarrierTryWaitParityLowering
   LogicalResult
   matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     Value barrier =
-        getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
-    Value ticks = truncToI32(rewriter, op->getLoc(), adaptor.getTicks());
-    Value phase = truncToI32(rewriter, op->getLoc(), adaptor.getPhase());
+    Value ticks = truncToI32(b, adaptor.getTicks());
+    Value phase = truncToI32(b, adaptor.getPhase());

     if (isMbarrierShared(op.getBarriers().getType())) {
       rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
@@ -950,16 +949,17 @@ struct NVGPUTmaAsyncLoadOpLowering
   LogicalResult
   matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
     Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
                                       adaptor.getDst(), {}, rewriter);
     Value barrier =
-        getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
+        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);

     SmallVector<Value> coords = adaptor.getCoordinates();
     for (auto [index, value] : llvm::enumerate(coords)) {
-      coords[index] = truncToI32(rewriter, op->getLoc(), value);
+      coords[index] = truncToI32(b, value);
     }

     rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
@@ -976,7 +976,7 @@ struct NVGPUGenerateGmmaDescriptorLowering
   matchAndRewrite(nvgpu::GenerateGmmaDescriptorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {

-    Location loc = op->getLoc();
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);

     nvgpu::TensorMapSwizzleKind swizzleKind =
         op.getTensorMap().getType().getSwizzle();
@@ -992,20 +992,18 @@ struct NVGPUGenerateGmmaDescriptorLowering
         : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
                                                                     : 0;

-    auto ti64 = rewriter.getIntegerType(64);
+    auto ti64 = b.getIntegerType(64);
     auto makeConst = [&](uint64_t index) -> Value {
-      return rewriter.create<LLVM::ConstantOp>(
-          loc, ti64, rewriter.getI64IntegerAttr(index));
+      return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
     };
     auto shiftLeft = [&](Value value, unsigned shift) -> Value {
-      return rewriter.create<LLVM::ShlOp>(loc, ti64, value, makeConst(shift));
+      return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
     };
     auto shiftRight = [&](Value value, unsigned shift) -> Value {
-      return rewriter.create<LLVM::LShrOp>(loc, ti64, value, makeConst(shift));
+      return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
     };
     auto insertBit = [&](Value desc, Value val, int startBit) {
-      return rewriter.create<LLVM::OrOp>(loc, ti64, desc,
-                                         shiftLeft(val, startBit));
+      return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
     };

     int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
@@ -1019,7 +1017,7 @@ struct NVGPUGenerateGmmaDescriptorLowering
     Value baseAddr = getStridedElementPtr(
         op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
         adaptor.getTensor(), {}, rewriter);
-    Value basePtr = rewriter.create<LLVM::PtrToIntOp>(loc, ti64, baseAddr);
+    Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
     // Just use 14 bits for base address
     Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);

@@ -1050,16 +1048,13 @@ struct NVGPUGenerateGmmaDescriptorLowering
   }
 };

-static Value makeI64Const(RewriterBase &rewriter, Operation *op,
-                          int32_t index) {
-  return rewriter.create<LLVM::ConstantOp>(op->getLoc(),
-                                           rewriter.getIntegerType(64),
-                                           rewriter.getI32IntegerAttr(index));
+static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
+  return b.create<LLVM::ConstantOp>(b.getIntegerType(64),
+                                    b.getI32IntegerAttr(index));
 }

 /// Returns a Value that holds data type enum that is expected by CUDA driver.
-static Value elementTypeAsLLVMConstant(RewriterBase &rewriter, Operation *op,
-                                       Type type) {
+static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
   // Enum is from CUDA driver API
   // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
   enum CUtensorMapDataTypeEnum {
@@ -1079,25 +1074,25 @@ static Value elementTypeAsLLVMConstant(RewriterBase &rewriter, Operation *op,
   };

   if (type.isUnsignedInteger(8))
-    return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_UINT8);
+    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
   if (type.isUnsignedInteger(16))
-    return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_UINT16);
+    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
   if (type.isUnsignedInteger(32))
-    return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_UINT32);
+    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
   if (type.isUnsignedInteger(64))
-    return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_UINT64);
+    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
   if (type.isSignlessInteger(32))
-    return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_INT32);
+    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
   if (type.isSignlessInteger(64))
-    return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_INT64);
+    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
   if (type.isF16())
-    return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
+    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
   if (type.isF32())
-    return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
+    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
   if (type.isF64())
-    return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
+    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
   if (type.isBF16())
-    return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
+    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);

   llvm_unreachable("Not supported data type");
 }
@@ -1109,23 +1104,22 @@ struct NVGPUTmaCreateDescriptorOpLowering
   LogicalResult
   matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     LLVM::LLVMPointerType llvmPointerType = getTypeConverter()->getPointerType(
         IntegerType::get(op->getContext(), 8));
     Type llvmInt64Type = IntegerType::get(op->getContext(), 64);

-    Value tensorElementType = elementTypeAsLLVMConstant(
-        rewriter, op, op.getTensor().getType().getElementType());
+    Value tensorElementType =
+        elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
     auto promotedOperands = getTypeConverter()->promoteOperands(
-        loc, op->getOperands(), adaptor.getOperands(), rewriter);
+        b.getLoc(), op->getOperands(), adaptor.getOperands(), b);

-    Value boxArrayPtr = rewriter.create<LLVM::AllocaOp>(
-        loc, llvmPointerType, llvmInt64Type, makeI64Const(rewriter, op, 5));
+    Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
+                                                 makeI64Const(b, 5));
     for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
-      Value gep = rewriter.create<LLVM::GEPOp>(
-          loc, llvmPointerType, llvmPointerType, boxArrayPtr,
-          makeI64Const(rewriter, op, index));
-      rewriter.create<LLVM::StoreOp>(loc, value, gep);
+      Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
+                                        boxArrayPtr, makeI64Const(b, index));
+      b.create<LLVM::StoreOp>(value, gep);
     }

     nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
@@ -1135,12 +1129,10 @@ struct NVGPUTmaCreateDescriptorOpLowering
     arguments.push_back(promotedOperands[1]); // descriptor
     arguments.push_back(tensorElementType); // data type
     arguments.push_back(
-        makeI64Const(rewriter, op, (int)desc.getInterleave())); // interleave
-    arguments.push_back(
-        makeI64Const(rewriter, op, (int)desc.getSwizzle())); // swizzle
-    arguments.push_back(
-        makeI64Const(rewriter, op, (int)desc.getL2promo())); // l2promo
-    arguments.push_back(makeI64Const(rewriter, op, (int)desc.getOob())); // oob
+        makeI64Const(b, (int)desc.getInterleave())); // interleave
+    arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
+    arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
+    arguments.push_back(makeI64Const(b, (int)desc.getOob())); // oob
     arguments.push_back(boxArrayPtr); // box dimensions

     // Set data types of the arguments
@@ -1157,7 +1149,7 @@ struct NVGPUTmaCreateDescriptorOpLowering
     FunctionCallBuilder hostRegisterCallBuilder = {
         "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
     Value tensorMap =
-        hostRegisterCallBuilder.create(loc, rewriter, arguments).getResult();
+        hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();

     rewriter.replaceOp(op, tensorMap);
     return success();
@@ -1191,11 +1183,10 @@ struct NVGPUWarpgroupMmaOpLowering
     return success();
   }

-  Value generateNVVMWgmmaOp(MLIRContext *ctx,
-                            ConversionPatternRewriter &rewriter, Location loc,
-                            int m, int n, int k, Type resultStructType,
-                            Value inout, Value descriptorA,
-                            Value descriptorB) const {
+  Value generateNVVMWgmmaOp(ImplicitLocOpBuilder &b, int m, int n, int k,
+                            Type resultStructType, Value inout,
+                            Value descriptorA, Value descriptorB) const {
+    MLIRContext *ctx = b.getContext();
     auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
     auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
     auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
@@ -1205,15 +1196,16 @@ struct NVGPUWarpgroupMmaOpLowering
     auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
     auto overflow =
         NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
-    Value res = rewriter.create<NVVM::WgmmaMmaAsyncOp>(
-        loc, resultStructType, inout, descriptorA, descriptorB, shape, itype,
-        itype, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
+    Value res = b.create<NVVM::WgmmaMmaAsyncOp>(
+        resultStructType, inout, descriptorA, descriptorB, shape, itype, itype,
+        scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
     return res;
   }

   LogicalResult
   matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
     int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
     int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
     int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
@@ -1232,13 +1224,11 @@ struct NVGPUWarpgroupMmaOpLowering
     Value descriptorB = adaptor.getDescriptorB();

     // Generate wgmma group
-
-    auto loc = op->getLoc();
     MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
     MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();

     auto makeAdd = [&](Value lhs, Value rhs) -> Value {
-      return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+      return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
     };

     auto iterateDescA = [&](Value desc, int iterM, int iterN,
@@ -1254,7 +1244,7 @@ struct NVGPUWarpgroupMmaOpLowering
                          << incrementVal << " | \t ");
       if (!incrementVal)
         return desc;
-      return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
+      return makeAdd(desc, makeI64Const(b, incrementVal));
     };

     auto iterateDescB = [&](Value desc, int iterM, int iterN,
@@ -1266,10 +1256,10 @@ struct NVGPUWarpgroupMmaOpLowering
       LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
       if (!incrementVal)
         return desc;
-      return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
+      return makeAdd(desc, makeI64Const(b, incrementVal));
     };

-    rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);
+    b.create<NVVM::WgmmaFenceAlignedOp>();

     SmallVector<Value> wgmmaResults;
     for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) {
@@ -1291,14 +1281,13 @@ struct NVGPUWarpgroupMmaOpLowering
                            << " B[" << (iterK * wgmmaShapeK) << ":"
                            << (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0
                            << ":" << wgmmaShapeN << "])\n");
-        matrixC = generateNVVMWgmmaOp(op->getContext(), rewriter, loc,
-                                      wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
+        matrixC = generateNVVMWgmmaOp(b, wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
                                       structType, matrixC, descA, descB);
       }
       wgmmaResults.push_back(matrixC);
     }
-    rewriter.create<NVVM::WgmmaGroupSyncAlignedOp>(loc);
-    rewriter.create<NVVM::WgmmaWaitGroupSyncOp>(loc, op.getWaitGroup());
+    b.create<NVVM::WgmmaGroupSyncAlignedOp>();
+    b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());

     ValueRange myres(wgmmaResults);
     rewriter.replaceOp(op, myres);