[mlir][nvgpu] Use ImplicitLocOpBuilder in nvgpu-to-nvvm pass (NFC) (#67993)

For the sake of better readability, this PR uses `ImplicitLocOpBuilder`
instead of rewriter+loc
This commit is contained in:
Guray Ozen 2023-10-03 10:52:36 +02:00 committed by GitHub
parent 2b5a6d774c
commit ee49cda7d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -19,6 +19,7 @@
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
@ -44,13 +45,12 @@ constexpr int exclude4LSB = 4;
/// GPU has 32 bit registers, this function truncates values when larger width
/// is not needed.
static Value truncToI32(ConversionPatternRewriter &rewriter, Location loc,
Value value) {
static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
Type type = value.getType();
assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
if (type.getIntOrFloatBitWidth() <= 32)
return value;
return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), value);
return b.create<LLVM::TruncOp>(b.getI32Type(), value);
}
/// Returns the type for the intrinsic given the vectorResultType of the
@ -170,22 +170,21 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
/// `nvvm.mma.sync` op expects these argments to be a given in a long list of
/// scalars of certain types. This function helps unpack the `vector` arguments
/// and cast them to the types expected by `nvvm.mma.sync`.
static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
Location loc, Value operand,
static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
Value operand,
NVVM::MMATypes operandPtxType) {
SmallVector<Value> result;
Type i32Ty = rewriter.getI32Type();
Type f64Ty = rewriter.getF64Type();
Type f32Ty = rewriter.getF32Type();
Type i8Ty = rewriter.getI8Type();
Type i4Ty = rewriter.getIntegerType(4);
Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4);
Type i4x8Ty = LLVM::getFixedVectorType(i4Ty, 8);
Type i32Ty = b.getI32Type();
Type f64Ty = b.getF64Type();
Type f32Ty = b.getF32Type();
Type i64Ty = b.getI64Type();
Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);
// For 4xi8 vectors, the intrinsic expects these to be provided as i32
// scalar types.
@ -193,8 +192,7 @@ static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
arrayTy.getElementType() == i4x8Ty ||
(arrayTy.getElementType() == f32x1Ty &&
operandPtxType == NVVM::MMATypes::tf32)) {
result.push_back(
rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI32Type(), toUse));
result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
continue;
}
@ -207,10 +205,9 @@ static SmallVector<Value> unpackOperandVector(RewriterBase &rewriter,
innerArrayTy.getElementType() == f32Ty)) {
for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
idx < innerSize; idx++) {
result.push_back(rewriter.create<LLVM::ExtractElementOp>(
loc, toUse,
rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx))));
result.push_back(b.create<LLVM::ExtractElementOp>(
toUse,
b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
}
continue;
}
@ -256,7 +253,7 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MLIRContext *ctx = getContext();
Location loc = op->getLoc();
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
// The result type of ldmatrix will always be a struct of 32bit integer
// registers if more than one 32bit value is returned. Otherwise, the result
@ -283,10 +280,10 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
Value srcPtr =
getStridedElementPtr(loc, srcMemrefType, adaptor.getSrcMemref(),
getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
adaptor.getIndices(), rewriter);
Value ldMatrixResult = rewriter.create<NVVM::LdMatrixOp>(
loc, ldMatrixResultType, srcPtr,
Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
ldMatrixResultType, srcPtr,
/*num=*/op.getNumTiles(),
/*layout=*/op.getTranspose() ? NVVM::MMALayout::col
: NVVM::MMALayout::row);
@ -296,15 +293,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
// actual vector type (still of width 32b) and repack them into a result
// struct.
Type finalResultType = typeConverter->convertType(vectorResultType);
Value result = rewriter.create<LLVM::UndefOp>(loc, finalResultType);
Value result = b.create<LLVM::UndefOp>(finalResultType);
for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
Value i32Register =
num32BitRegs > 1
? rewriter.create<LLVM::ExtractValueOp>(loc, ldMatrixResult, i)
: ldMatrixResult;
Value casted =
rewriter.create<LLVM::BitcastOp>(loc, innerVectorType, i32Register);
result = rewriter.create<LLVM::InsertValueOp>(loc, result, casted, i);
num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
: ldMatrixResult;
Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
result = b.create<LLVM::InsertValueOp>(result, casted, i);
}
rewriter.replaceOp(op, result);
@ -335,7 +330,7 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
LogicalResult
matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
// Get the shapes of the MMAMatrix type being used. The shapes will
// choose which intrinsic this op will be lowered to.
VectorType aType = op.getMatrixA().getType();
@ -368,17 +363,17 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
overflow = NVVM::MMAIntOverflow::satfinite;
SmallVector<Value> matA =
unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), *ptxTypeA);
unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
SmallVector<Value> matB =
unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), *ptxTypeB);
unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
SmallVector<Value> matC =
unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);
unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
Type intrinsicResTy = inferIntrinsicResultType(
typeConverter->convertType(op->getResultTypes()[0]));
Value intrinsicResult = rewriter.create<NVVM::MmaOp>(
op.getLoc(), intrinsicResTy, matA, matB, matC,
Value intrinsicResult = b.create<NVVM::MmaOp>(
intrinsicResTy, matA, matB, matC,
/*shape=*/gemmShape,
/*b1Op=*/std::nullopt,
/*intOverflow=*/overflow,
@ -511,14 +506,14 @@ static std::string buildMmaSparseAsmString(
/// Builds an inline assembly operation corresponding to the specified MMA
/// sparse sync operation.
static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
Location loc, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
int64_t metadataSelector, const std::array<int64_t, 3> &shape,
Type intrinsicResultType, ConversionPatternRewriter &rewriter) {
auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
LLVM::AsmDialect::AD_ATT);
Type intrinsicResultType) {
auto asmDialectAttr =
LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);
const unsigned matASize = unpackedAData.size();
const unsigned matBSize = unpackedB.size();
@ -536,15 +531,15 @@ static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
llvm::append_range(asmVals, args);
asmVals.push_back(indexData);
return rewriter.create<LLVM::InlineAsmOp>(loc,
/*resultTypes=*/intrinsicResultType,
/*operands=*/asmVals,
/*asm_string=*/asmStr,
/*constraints=*/constraintStr,
/*has_side_effects=*/true,
/*is_align_stack=*/false,
/*asm_dialect=*/asmDialectAttr,
/*operand_attrs=*/ArrayAttr());
return b.create<LLVM::InlineAsmOp>(
/*resultTypes=*/intrinsicResultType,
/*operands=*/asmVals,
/*asm_string=*/asmStr,
/*constraints=*/constraintStr,
/*has_side_effects=*/true,
/*is_align_stack=*/false,
/*asm_dialect=*/asmDialectAttr,
/*operand_attrs=*/ArrayAttr());
}
/// Lowers `nvgpu.mma.sp.sync` to inline assembly.
@ -555,7 +550,7 @@ struct NVGPUMmaSparseSyncLowering
LogicalResult
matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
// Get the shapes of the MMAMatrix type being used. The shapes will
// choose which intrinsic this op will be lowered to.
VectorType aType = op.getMatrixA().getType();
@ -586,11 +581,11 @@ struct NVGPUMmaSparseSyncLowering
overflow = NVVM::MMAIntOverflow::satfinite;
SmallVector<Value> matA =
unpackOperandVector(rewriter, loc, adaptor.getMatrixA(), *ptxTypeA);
unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
SmallVector<Value> matB =
unpackOperandVector(rewriter, loc, adaptor.getMatrixB(), *ptxTypeB);
unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
SmallVector<Value> matC =
unpackOperandVector(rewriter, loc, adaptor.getMatrixC(), *ptxTypeC);
unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);
Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
Type intrinsicResTy = inferIntrinsicResultType(
@ -602,13 +597,13 @@ struct NVGPUMmaSparseSyncLowering
LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
return op->emitOpError() << "Expected metadata type to be LLVM "
"VectorType of 2 i16 elements";
sparseMetadata = rewriter.create<LLVM::BitcastOp>(
loc, rewriter.getI32Type(), sparseMetadata);
sparseMetadata =
b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);
FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
loc, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
intrinsicResTy, rewriter);
intrinsicResTy);
if (failed(intrinsicResult))
return failure();
@ -629,10 +624,12 @@ struct NVGPUAsyncCopyLowering
LogicalResult
matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Location loc = op.getLoc();
auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.getDst(),
adaptor.getDstIndices(), rewriter);
Value dstPtr =
getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
adaptor.getDstIndices(), rewriter);
auto i8Ty = IntegerType::get(op.getContext(), 8);
FailureOr<unsigned> dstAddressSpace =
getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
@ -642,7 +639,7 @@ struct NVGPUAsyncCopyLowering
auto dstPointerType =
getTypeConverter()->getPointerType(i8Ty, *dstAddressSpace);
if (!getTypeConverter()->useOpaquePointers())
dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);
dstPtr = b.create<LLVM::BitcastOp>(dstPointerType, dstPtr);
auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
FailureOr<unsigned> srcAddressSpace =
@ -656,12 +653,11 @@ struct NVGPUAsyncCopyLowering
auto srcPointerType =
getTypeConverter()->getPointerType(i8Ty, *srcAddressSpace);
if (!getTypeConverter()->useOpaquePointers())
scrPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, scrPtr);
scrPtr = b.create<LLVM::BitcastOp>(srcPointerType, scrPtr);
// Intrinsics takes a global pointer so we need an address space cast.
auto srcPointerGlobalType = getTypeConverter()->getPointerType(
i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace);
scrPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
scrPtr);
scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
int64_t dstElements = adaptor.getDstElements().getZExtValue();
int64_t sizeInBytes =
(dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
@ -675,16 +671,14 @@ struct NVGPUAsyncCopyLowering
// memory) of CpAsyncOp is read only for SrcElements number of elements.
// The rest of the DstElements in the destination (shared memory) are
// filled with zeros.
Value c3I32 = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(3));
Value bitwidth = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(),
rewriter.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
Value srcElementsI32 =
rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), srcBytes);
srcBytes = rewriter.create<LLVM::LShrOp>(
loc, rewriter.create<LLVM::MulOp>(loc, bitwidth, srcElementsI32),
c3I32);
Value c3I32 =
b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
Value bitwidth = b.create<LLVM::ConstantOp>(
b.getI32Type(),
b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
srcBytes = b.create<LLVM::LShrOp>(
b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
}
// Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
// 16 dst bytes.
@ -693,15 +687,14 @@ struct NVGPUAsyncCopyLowering
? NVVM::LoadCacheModifierKind::CG
: NVVM::LoadCacheModifierKind::CA;
rewriter.create<NVVM::CpAsyncOp>(
loc, dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
b.create<NVVM::CpAsyncOp>(
dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
srcBytes);
// Drop the result token.
Value zero = rewriter.create<LLVM::ConstantOp>(
op->getLoc(), IntegerType::get(op.getContext(), 32),
rewriter.getI32IntegerAttr(0));
Value zero = b.create<LLVM::ConstantOp>(
IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
rewriter.replaceOp(op, zero);
return success();
}
@ -790,14 +783,14 @@ struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
public:
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
/// Returns the base pointer of the mbarrier object.
Value getMbarrierPtr(Operation *op, nvgpu::MBarrierGroupType mbarType,
Value memrefDesc, Value mbarId,
Value getMbarrierPtr(ImplicitLocOpBuilder &b,
nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
Value mbarId,
ConversionPatternRewriter &rewriter) const {
MemRefType mbarrierMemrefType =
nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
return ConvertToLLVMPattern::getStridedElementPtr(
op->getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
return memrefDesc;
b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
}
};
@ -809,11 +802,12 @@ struct NVGPUMBarrierInitLowering
LogicalResult
matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
rewriter.setInsertionPoint(op);
Value barrier = getMbarrierPtr(op, mbarrierType, adaptor.getBarriers(),
Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Value count = truncToI32(rewriter, op->getLoc(), adaptor.getCount());
Value count = truncToI32(b, adaptor.getCount());
if (isMbarrierShared(mbarrierType)) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(op, barrier,
count);
@ -831,8 +825,9 @@ struct NVGPUMBarrierArriveLowering
LogicalResult
matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value barrier =
getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Type tokenType = getTypeConverter()->convertType(
nvgpu::MBarrierTokenType::get(op->getContext()));
@ -856,12 +851,13 @@ struct NVGPUMBarrierArriveNoCompleteLowering
LogicalResult
matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value barrier =
getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Type tokenType = getTypeConverter()->convertType(
nvgpu::MBarrierTokenType::get(op->getContext()));
Value count = truncToI32(rewriter, op->getLoc(), adaptor.getCount());
Value count = truncToI32(b, adaptor.getCount());
if (isMbarrierShared(op.getBarriers().getType())) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
op, tokenType, barrier, count);
@ -880,8 +876,9 @@ struct NVGPUMBarrierTestWaitLowering
LogicalResult
matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value barrier =
getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Type retType = rewriter.getI1Type();
if (isMbarrierShared(op.getBarriers().getType())) {
@ -902,10 +899,11 @@ struct NVGPUMBarrierArriveExpectTxLowering
LogicalResult
matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value barrier =
getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Value txcount = truncToI32(rewriter, op->getLoc(), adaptor.getTxcount());
Value txcount = truncToI32(b, adaptor.getTxcount());
if (isMbarrierShared(op.getBarriers().getType())) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
@ -926,11 +924,12 @@ struct NVGPUMBarrierTryWaitParityLowering
LogicalResult
matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value barrier =
getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Value ticks = truncToI32(rewriter, op->getLoc(), adaptor.getTicks());
Value phase = truncToI32(rewriter, op->getLoc(), adaptor.getPhase());
Value ticks = truncToI32(b, adaptor.getTicks());
Value phase = truncToI32(b, adaptor.getPhase());
if (isMbarrierShared(op.getBarriers().getType())) {
rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
@ -950,16 +949,17 @@ struct NVGPUTmaAsyncLoadOpLowering
LogicalResult
matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
adaptor.getDst(), {}, rewriter);
Value barrier =
getMbarrierPtr(op, op.getBarriers().getType(), adaptor.getBarriers(),
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
SmallVector<Value> coords = adaptor.getCoordinates();
for (auto [index, value] : llvm::enumerate(coords)) {
coords[index] = truncToI32(rewriter, op->getLoc(), value);
coords[index] = truncToI32(b, value);
}
rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
@ -976,7 +976,7 @@ struct NVGPUGenerateGmmaDescriptorLowering
matchAndRewrite(nvgpu::GenerateGmmaDescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
nvgpu::TensorMapSwizzleKind swizzleKind =
op.getTensorMap().getType().getSwizzle();
@ -992,20 +992,18 @@ struct NVGPUGenerateGmmaDescriptorLowering
: (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
: 0;
auto ti64 = rewriter.getIntegerType(64);
auto ti64 = b.getIntegerType(64);
auto makeConst = [&](uint64_t index) -> Value {
return rewriter.create<LLVM::ConstantOp>(
loc, ti64, rewriter.getI64IntegerAttr(index));
return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
};
auto shiftLeft = [&](Value value, unsigned shift) -> Value {
return rewriter.create<LLVM::ShlOp>(loc, ti64, value, makeConst(shift));
return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
};
auto shiftRight = [&](Value value, unsigned shift) -> Value {
return rewriter.create<LLVM::LShrOp>(loc, ti64, value, makeConst(shift));
return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
};
auto insertBit = [&](Value desc, Value val, int startBit) {
return rewriter.create<LLVM::OrOp>(loc, ti64, desc,
shiftLeft(val, startBit));
return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
};
int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
@ -1019,7 +1017,7 @@ struct NVGPUGenerateGmmaDescriptorLowering
Value baseAddr = getStridedElementPtr(
op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
adaptor.getTensor(), {}, rewriter);
Value basePtr = rewriter.create<LLVM::PtrToIntOp>(loc, ti64, baseAddr);
Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
// Just use 14 bits for base address
Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);
@ -1050,16 +1048,13 @@ struct NVGPUGenerateGmmaDescriptorLowering
}
};
static Value makeI64Const(RewriterBase &rewriter, Operation *op,
int32_t index) {
return rewriter.create<LLVM::ConstantOp>(op->getLoc(),
rewriter.getIntegerType(64),
rewriter.getI32IntegerAttr(index));
static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
return b.create<LLVM::ConstantOp>(b.getIntegerType(64),
b.getI32IntegerAttr(index));
}
/// Returns a Value that holds data type enum that is expected by CUDA driver.
static Value elementTypeAsLLVMConstant(RewriterBase &rewriter, Operation *op,
Type type) {
static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
// Enum is from CUDA driver API
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
enum CUtensorMapDataTypeEnum {
@ -1079,25 +1074,25 @@ static Value elementTypeAsLLVMConstant(RewriterBase &rewriter, Operation *op,
};
if (type.isUnsignedInteger(8))
return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_UINT8);
return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
if (type.isUnsignedInteger(16))
return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_UINT16);
return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
if (type.isUnsignedInteger(32))
return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_UINT32);
return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
if (type.isUnsignedInteger(64))
return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_UINT64);
return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
if (type.isSignlessInteger(32))
return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_INT32);
return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
if (type.isSignlessInteger(64))
return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_INT64);
return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
if (type.isF16())
return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
if (type.isF32())
return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
if (type.isF64())
return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
if (type.isBF16())
return makeI64Const(rewriter, op, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
llvm_unreachable("Not supported data type");
}
@ -1109,23 +1104,22 @@ struct NVGPUTmaCreateDescriptorOpLowering
LogicalResult
matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
LLVM::LLVMPointerType llvmPointerType = getTypeConverter()->getPointerType(
IntegerType::get(op->getContext(), 8));
Type llvmInt64Type = IntegerType::get(op->getContext(), 64);
Value tensorElementType = elementTypeAsLLVMConstant(
rewriter, op, op.getTensor().getType().getElementType());
Value tensorElementType =
elementTypeAsLLVMConstant(b, op.getTensor().getType().getElementType());
auto promotedOperands = getTypeConverter()->promoteOperands(
loc, op->getOperands(), adaptor.getOperands(), rewriter);
b.getLoc(), op->getOperands(), adaptor.getOperands(), b);
Value boxArrayPtr = rewriter.create<LLVM::AllocaOp>(
loc, llvmPointerType, llvmInt64Type, makeI64Const(rewriter, op, 5));
Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
makeI64Const(b, 5));
for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
Value gep = rewriter.create<LLVM::GEPOp>(
loc, llvmPointerType, llvmPointerType, boxArrayPtr,
makeI64Const(rewriter, op, index));
rewriter.create<LLVM::StoreOp>(loc, value, gep);
Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
boxArrayPtr, makeI64Const(b, index));
b.create<LLVM::StoreOp>(value, gep);
}
nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
@ -1135,12 +1129,10 @@ struct NVGPUTmaCreateDescriptorOpLowering
arguments.push_back(promotedOperands[1]); // descriptor
arguments.push_back(tensorElementType); // data type
arguments.push_back(
makeI64Const(rewriter, op, (int)desc.getInterleave())); // interleave
arguments.push_back(
makeI64Const(rewriter, op, (int)desc.getSwizzle())); // swizzle
arguments.push_back(
makeI64Const(rewriter, op, (int)desc.getL2promo())); // l2promo
arguments.push_back(makeI64Const(rewriter, op, (int)desc.getOob())); // oob
makeI64Const(b, (int)desc.getInterleave())); // interleave
arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
arguments.push_back(makeI64Const(b, (int)desc.getOob())); // oob
arguments.push_back(boxArrayPtr); // box dimensions
// Set data types of the arguments
@ -1157,7 +1149,7 @@ struct NVGPUTmaCreateDescriptorOpLowering
FunctionCallBuilder hostRegisterCallBuilder = {
"mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
Value tensorMap =
hostRegisterCallBuilder.create(loc, rewriter, arguments).getResult();
hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();
rewriter.replaceOp(op, tensorMap);
return success();
@ -1191,11 +1183,10 @@ struct NVGPUWarpgroupMmaOpLowering
return success();
}
Value generateNVVMWgmmaOp(MLIRContext *ctx,
ConversionPatternRewriter &rewriter, Location loc,
int m, int n, int k, Type resultStructType,
Value inout, Value descriptorA,
Value descriptorB) const {
Value generateNVVMWgmmaOp(ImplicitLocOpBuilder &b, int m, int n, int k,
Type resultStructType, Value inout,
Value descriptorA, Value descriptorB) const {
MLIRContext *ctx = b.getContext();
auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
@ -1205,15 +1196,16 @@ struct NVGPUWarpgroupMmaOpLowering
auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
auto overflow =
NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
Value res = rewriter.create<NVVM::WgmmaMmaAsyncOp>(
loc, resultStructType, inout, descriptorA, descriptorB, shape, itype,
itype, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
Value res = b.create<NVVM::WgmmaMmaAsyncOp>(
resultStructType, inout, descriptorA, descriptorB, shape, itype, itype,
scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
return res;
}
LogicalResult
matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0);
int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1);
int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1);
@ -1232,13 +1224,11 @@ struct NVGPUWarpgroupMmaOpLowering
Value descriptorB = adaptor.getDescriptorB();
// Generate wgmma group
auto loc = op->getLoc();
MemRefType typeTensorA = op.getDescriptorA().getType().getTensor();
MemRefType typeTensorB = op.getDescriptorB().getType().getTensor();
auto makeAdd = [&](Value lhs, Value rhs) -> Value {
return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
};
auto iterateDescA = [&](Value desc, int iterM, int iterN,
@ -1254,7 +1244,7 @@ struct NVGPUWarpgroupMmaOpLowering
<< incrementVal << " | \t ");
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
return makeAdd(desc, makeI64Const(b, incrementVal));
};
auto iterateDescB = [&](Value desc, int iterM, int iterN,
@ -1266,10 +1256,10 @@ struct NVGPUWarpgroupMmaOpLowering
LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(rewriter, op, incrementVal));
return makeAdd(desc, makeI64Const(b, incrementVal));
};
rewriter.create<NVVM::WgmmaFenceAlignedOp>(loc);
b.create<NVVM::WgmmaFenceAlignedOp>();
SmallVector<Value> wgmmaResults;
for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) {
@ -1291,14 +1281,13 @@ struct NVGPUWarpgroupMmaOpLowering
<< " B[" << (iterK * wgmmaShapeK) << ":"
<< (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0
<< ":" << wgmmaShapeN << "])\n");
matrixC = generateNVVMWgmmaOp(op->getContext(), rewriter, loc,
wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
matrixC = generateNVVMWgmmaOp(b, wgmmaShapeM, wgmmaShapeN, wgmmaShapeK,
structType, matrixC, descA, descB);
}
wgmmaResults.push_back(matrixC);
}
rewriter.create<NVVM::WgmmaGroupSyncAlignedOp>(loc);
rewriter.create<NVVM::WgmmaWaitGroupSyncOp>(loc, op.getWaitGroup());
b.create<NVVM::WgmmaGroupSyncAlignedOp>();
b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
ValueRange myres(wgmmaResults);
rewriter.replaceOp(op, myres);