[mlir][spirv][gpu] Add conversion for load/store/mad coop matrix ops (#66311)
This is plugged in as an alternative lowering path in the GPU to SPIR-V dialect conversion. Add custom op builders for coop matrix ops to make the create functions nicer to work with and less error-prone. The latter is accomplished by following the op syntax and by requiring the stride to be a constant op, to avoid confusion around the order of arguments. The remaining lowering patterns will be added in a future patch.
parent f66cd9e955 · commit ed4daeaa13
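To make the intended usage concrete, here is a minimal sketch of a call site for the new load builder, assuming an in-scope ConversionPatternRewriter `rewriter`, a Location `loc`, a converted cooperative matrix type `coopType`, and a pointer Value `bufferPtr` (the names and the stride value 32 are illustrative; the real call sites are in the lowering patterns further down this diff):

    // Materialize the stride as a spirv.Constant first. The builder takes a
    // spirv::ConstantOp rather than a plain Value, so the stride cannot be
    // confused with the pointer operand.
    auto stride = rewriter.create<spirv::ConstantOp>(
        loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(32));

    // Remaining arguments follow the op syntax; the optional memory-access
    // attribute defaults to none (spirv::MemoryAccessAttr{}).
    rewriter.create<spirv::KHRCooperativeMatrixLoadOp>(
        loc, coopType, bufferPtr, stride,
        spirv::CooperativeMatrixLayoutKHR::RowMajor);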
@@ -30,11 +30,21 @@ class MMAMatrixType;
 void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                 RewritePatternSet &patterns);
 
+/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
+/// using the KHR Cooperative Matrix extension.
+void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
+    SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
+
 /// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
 /// using the NV Cooperative Matrix extension.
 void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
     SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
 
+/// Returns a KHR cooperative matrix type corresponding to the MMAMatrixType
+/// `type`.
+spirv::CooperativeMatrixType
+convertMMAToSPIRVCoopMatrixType(gpu::MMAMatrixType type);
+
 /// Returns an NV cooperative matrix type corresponding to the MMAMatrixType
 /// `type`.
 spirv::CooperativeMatrixNVType
@@ -567,7 +567,11 @@ def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
   let options = [
     Option<"use64bitIndex", "use-64bit-index",
            "bool", /*default=*/"false",
-           "Use 64-bit integers to convert index types">
+           "Use 64-bit integers to convert index types">,
+    Option<"useCoopMatrixNV", "use-coop-matrix-nv",
+           "bool", /*default=*/"true",
+           "Use the NV cooperative matrix extension instead of the KHR extension"
+           " to lower GPU WMMA ops">,
   ];
 }
@@ -146,6 +146,15 @@ def SPIRV_KHRCooperativeMatrixLoadOp : SPIRV_KhrVendorOp<"CooperativeMatrixLoad"
   let results = (outs
     SPIRV_AnyCooperativeMatrix:$result
   );
+
+  let builders = [
+    OpBuilder<(ins "Type":$result, "Value":$pointer,
+                   "spirv::ConstantOp":$stride,
+                   "spirv::CooperativeMatrixLayoutKHR":$layout), [{
+      build($_builder, $_state, result, pointer, layout, stride,
+            spirv::MemoryAccessAttr{});
+    }]>
+  ];
 }
 
 // -----
@@ -226,6 +235,15 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStore"
   );
 
   let results = (outs);
+
+  let builders = [
+    OpBuilder<(ins "Value":$pointer, "Value":$object,
+                   "spirv::ConstantOp":$stride,
+                   "spirv::CooperativeMatrixLayoutKHR":$layout), [{
+      build($_builder, $_state, pointer, object, layout, stride,
+            spirv::MemoryAccessAttr{});
+    }]>
+  ];
 }
 
 // -----
@@ -332,6 +350,13 @@ def SPIRV_KHRCooperativeMatrixMulAddOp : SPIRV_KhrVendorOp<"CooperativeMatrixMulAdd"
   let results = (outs
     SPIRV_AnyCooperativeMatrix:$result
   );
+
+  let builders = [
+    OpBuilder<(ins "Value":$a, "Value":$b, "Value":$c), [{
+      build($_builder, $_state, a, b, c,
+            spirv::CooperativeMatrixOperandsKHRAttr{});
+    }]>
+  ];
 }
 
 //===----------------------------------------------------------------------===//
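A corresponding sketch for the mul-add builder, with `rewriter`, `loc`, and cooperative matrix Values `a`, `b`, and `c` assumed in scope (compare the compute-op lowering below, which uses the replaceOpWithNewOp form of the same builder):

    // The result type is inferred from the accumulator operand; the optional
    // CooperativeMatrixOperandsKHR attribute defaults to empty.
    auto mad = rewriter.create<spirv::KHRCooperativeMatrixMulAddOp>(loc, a, b, c);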
@@ -86,13 +86,25 @@ void GPUToSPIRVPass::runOnOperation() {
   SPIRVConversionOptions options;
   options.use64bitIndex = this->use64bitIndex;
   SPIRVTypeConverter typeConverter(targetAttr, options);
-  typeConverter.addConversion([&](gpu::MMAMatrixType type) -> Type {
-    return convertMMAToSPIRVCoopMatrixNVType(type);
+
+  typeConverter.addConversion([useNV = this->useCoopMatrixNV.getValue()](
+                                  gpu::MMAMatrixType type) -> Type {
+    if (useNV)
+      return convertMMAToSPIRVCoopMatrixNVType(type);
+
+    return convertMMAToSPIRVCoopMatrixType(type);
   });
+
   RewritePatternSet patterns(context);
   populateGPUToSPIRVPatterns(typeConverter, patterns);
-  populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
-                                                       patterns);
+  if (this->useCoopMatrixNV) {
+    populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typeConverter,
+                                                         patterns);
+  } else {
+    populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(typeConverter,
+                                                          patterns);
+  }
+
   // TODO: Change SPIR-V conversion to be progressive and remove the following
   // patterns.
   mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
@@ -18,22 +18,28 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/StringSwitch.h"
 
-namespace mlir::nv {
-namespace {
+#include <cassert>
+
+namespace mlir {
 
 /// Creates a SPIR-V op to replace the given GPU subgroup mma elementwise op
 /// when the elementwise op is directly supported with cooperative matrix
 /// types. Returns false if it cannot.
 ///
 /// See SPV_NV_cooperative_matrix for supported elementwise ops.
 static bool createElementwiseOp(ConversionPatternRewriter &builder,
-                                gpu::SubgroupMmaElementwiseOp op,
-                                spirv::CooperativeMatrixNVType coopType,
+                                gpu::SubgroupMmaElementwiseOp op, Type coopType,
                                 ValueRange operands) {
+  assert((isa<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType>(
+      coopType)));
+
   switch (op.getOpType()) {
   case gpu::MMAElementwiseOp::ADDF:
     builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
@@ -71,6 +77,110 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
   return false;
 }
 
+//===----------------------------------------------------------------------===//
+// SPV_KHR_cooperative_matrix
+//===----------------------------------------------------------------------===//
+
+namespace khr {
+namespace {
+
+/// Converts the GPU MMA loadOp to KHRCooperativeMatrixLoad op in the SPIRV
+/// dialect.
+struct WmmaLoadOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    Location loc = op->getLoc();
+
+    auto retType = cast<gpu::MMAMatrixType>(op.getRes().getType());
+    MemRefType memrefType = op.getSrcMemref().getType();
+    Value bufferPtr =
+        spirv::getElementPtr(typeConverter, memrefType, adaptor.getSrcMemref(),
+                             adaptor.getIndices(), loc, rewriter);
+
+    auto coopType =
+        typeConverter.convertType<spirv::CooperativeMatrixType>(retType);
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    int64_t stride = op.getLeadDimension().getSExtValue();
+    IntegerType i32Type = rewriter.getI32Type();
+    auto strideValue = rewriter.create<spirv::ConstantOp>(
+        loc, i32Type, IntegerAttr::get(i32Type, stride));
+
+    bool isColMajor = op.getTranspose().value_or(false);
+    auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
+                             : spirv::CooperativeMatrixLayoutKHR::RowMajor;
+
+    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixLoadOp>(
+        op, coopType, bufferPtr, strideValue, layout);
+    return success();
+  }
+};
+
+/// Converts the GPU MMA StoreOp to KHRCooperativeMatrixStore op in the SPIRV
+/// dialect.
+struct WmmaStoreOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+    Location loc = op->getLoc();
+
+    auto memrefType = cast<MemRefType>(op.getDstMemref().getType());
+    Value bufferPtr =
+        spirv::getElementPtr(typeConverter, memrefType, adaptor.getDstMemref(),
+                             adaptor.getIndices(), loc, rewriter);
+
+    int64_t stride = op.getLeadDimension().getSExtValue();
+    IntegerType i32Type = rewriter.getI32Type();
+    auto strideValue = rewriter.create<spirv::ConstantOp>(
+        loc, i32Type, IntegerAttr::get(i32Type, stride));
+
+    bool isColMajor = op.getTranspose().value_or(false);
+    auto layout = isColMajor ? spirv::CooperativeMatrixLayoutKHR::ColumnMajor
+                             : spirv::CooperativeMatrixLayoutKHR::RowMajor;
+
+    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixStoreOp>(
+        op, bufferPtr, adaptor.getSrc(), strideValue, layout);
+    return success();
+  }
+};
+
+/// Converts GPU MMA Compute to KHRCooperativeMatrixMulAdd op in the SPIRV
+/// dialect.
+struct WmmaMmaOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaComputeOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<spirv::KHRCooperativeMatrixMulAddOp>(
+        subgroupMmaComputeOp, adaptor.getOpA(), adaptor.getOpB(),
+        adaptor.getOpC());
+    return success();
+  }
+};
+
+} // namespace
+} // namespace khr
+
+//===----------------------------------------------------------------------===//
+// SPV_NV_cooperative_matrix
+//===----------------------------------------------------------------------===//
+
+namespace nv {
+namespace {
+
 /// Converts the GPU MMA loadOp to NVCooperativeMatrixLoad op in the SPIRV
 /// dialect.
 struct WmmaLoadOpToSPIRVLowering final
@@ -247,7 +357,8 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering final
 };
 
 } // namespace
-} // namespace mlir::nv
+} // namespace nv
+} // namespace mlir
 
 mlir::spirv::CooperativeMatrixNVType
 mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) {
@@ -257,6 +368,30 @@ mlir::convertMMAToSPIRVCoopMatrixNVType(gpu::MMAMatrixType type) {
       elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]);
 }
 
+mlir::spirv::CooperativeMatrixType
+mlir::convertMMAToSPIRVCoopMatrixType(gpu::MMAMatrixType type) {
+  ArrayRef<int64_t> retTypeShape = type.getShape();
+  Type elementType = type.getElementType();
+
+  auto use =
+      llvm::StringSwitch<spirv::CooperativeMatrixUseKHR>(type.getOperand())
+          .Case("AOp", spirv::CooperativeMatrixUseKHR::MatrixA)
+          .Case("BOp", spirv::CooperativeMatrixUseKHR::MatrixB)
+          .Default(spirv::CooperativeMatrixUseKHR::MatrixAcc);
+
+  return spirv::CooperativeMatrixType::get(elementType, retTypeShape[0],
+                                           retTypeShape[1],
+                                           spirv::Scope::Subgroup, use);
+}
+
+void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
+    SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
+  using namespace mlir;
+  MLIRContext *context = patterns.getContext();
+  patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
+               khr::WmmaStoreOpToSPIRVLowering>(converter, context);
+}
+
 void mlir::populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
     SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
   using namespace mlir;
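For reference, the use mapping implemented by the StringSwitch above, as exercised by the new test file below:

    !gpu.mma_matrix<16x16xf16, "AOp"> -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
    !gpu.mma_matrix<16x16xf16, "BOp"> -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
    !gpu.mma_matrix<16x16xf16, "COp"> -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>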
@@ -0,0 +1,80 @@
+// RUN: mlir-opt --convert-gpu-to-spirv="use-coop-matrix-nv=false" --cse \
+// RUN:   --split-input-file --verify-diagnostics %s | FileCheck %s
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.6,
+    [Shader, CooperativeMatrixKHR, Float16],
+    [SPV_KHR_storage_buffer_storage_class, SPV_KHR_cooperative_matrix]>,
+    #spirv.resource_limits<>>} {
+
+  gpu.module @kernels {
+    // CHECK-LABEL: spirv.func @gpu_wmma_load_op
+    // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
+    gpu.func @gpu_wmma_load_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      %i = arith.constant 16 : index
+      %j = arith.constant 16 : index
+      // CHECK: %[[STRIDE:.+]] = spirv.Constant 32 : i32
+      // CHECK: spirv.KHR.CooperativeMatrixLoad {{%.*}}, %[[STRIDE]], <RowMajor> :
+      // CHECK-SAME: !spirv.ptr<f32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index} :
+        memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
+
+      // CHECK: spirv.KHR.CooperativeMatrixLoad {{%.*}}, %[[STRIDE]], <ColumnMajor> :
+      // CHECK-SAME: !spirv.ptr<f32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %1 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index, transpose} :
+        memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: spirv.Return
+      gpu.return
+    }
+
+    // CHECK-LABEL: spirv.func @gpu_wmma_store_op
+    // CHECK-SAME: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+    gpu.func @gpu_wmma_store_op(%arg0: memref<32x32xf16, #spirv.storage_class<StorageBuffer>>,
+                                %arg1: !gpu.mma_matrix<16x16xf16, "COp">) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      %i = arith.constant 16 : index
+      %j = arith.constant 16 : index
+      // CHECK: %[[STRIDE:.+]] = spirv.Constant 32 : i32
+      // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, %[[STRIDE]], <RowMajor> :
+      // CHECK-SAME: !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>
+
+      // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, %[[STRIDE]], <ColumnMajor> :
+      // CHECK-SAME: !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension = 32 : index, transpose} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>
+      // CHECK: spirv.Return
+      gpu.return
+    }
+
+    // CHECK-LABEL: spirv.func @gpu_wmma_mma_op
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
+    // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+    gpu.func @gpu_wmma_mma_op(%A: !gpu.mma_matrix<16x16xf16, "AOp">,
+                              %B: !gpu.mma_matrix<16x16xf16, "BOp">,
+                              %C: !gpu.mma_matrix<16x16xf16, "COp">,
+                              %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+      // CHECK: %[[MAD:.*]] = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} :
+      // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>,
+      // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixB>
+      // CHECK-SAME: -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+      %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">,
+        !gpu.mma_matrix<16x16xf16, "BOp">
+        -> !gpu.mma_matrix<16x16xf16, "COp">
+
+      %i = arith.constant 0 : index
+      // CHECK: spirv.KHR.CooperativeMatrixStore {{%.+}}, %[[MAD]], %{{.+}}, <RowMajor>
+      gpu.subgroup_mma_store_matrix %D, %ptr[%i,%i] {leadDimension = 32 : index} :
+        !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+      // CHECK: spirv.Return
+      gpu.return
+    }
+
+  }
+}
@@ -1,4 +1,5 @@
-// RUN: mlir-opt --convert-gpu-to-spirv --split-input-file --verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --convert-gpu-to-spirv="use-coop-matrix-nv=true" \
+// RUN:   --split-input-file --verify-diagnostics %s | FileCheck %s
 
 module attributes {
   gpu.container_module,