[mlir][gpu][sparse] adding cusparse sddmm support
Differential Revision: https://reviews.llvm.org/D151279
commit cf44847b4d (parent 35d7fa45bd)
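For orientation: SDDMM (sampled dense-dense matrix multiplication) computes `C = alpha * (A x B) o spy(C) + beta * C`, i.e. the dense product A·B is evaluated only at the positions stored in the sparse matrix C. A minimal CPU reference sketch of these semantics follows (illustrative only; the `Csr` struct and `sddmmReference` helper are assumptions for this example, not code from the patch):

```cpp
#include <cstdint>
#include <vector>

// CPU reference for SDDMM over a CSR sampling pattern (illustrative only):
// C.values[p] = beta * C.values[p] + alpha * dot(A[i,:], B[:,j])
// for every stored entry (i, j) of C; unstored entries are never touched.
struct Csr {
  std::vector<int64_t> rowPos;  // size m + 1
  std::vector<int64_t> colIdxs; // size nnz
  std::vector<double> values;   // size nnz
};

void sddmmReference(int64_t m, int64_t n, int64_t k,
                    const std::vector<double> &A, // m x k, row-major
                    const std::vector<double> &B, // k x n, row-major
                    Csr &C, double alpha = 1.0, double beta = 1.0) {
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t p = C.rowPos[i]; p < C.rowPos[i + 1]; ++p) {
      int64_t j = C.colIdxs[p];
      double dot = 0.0;
      for (int64_t l = 0; l < k; ++l)
        dot += A[i * k + l] * B[l * n + j];
      C.values[p] = beta * C.values[p] + alpha * dot;
    }
  }
}
```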
@@ -2047,4 +2047,109 @@ def GPU_SpMMOp : GPU_Op<"spmm", [GPU_AsyncOpInterface]> {
   }];
 }
 
+def GPU_SDDMMBufferSizeOp : GPU_Op<"sddmm_buffer_size", [GPU_AsyncOpInterface]> {
+  let summary = "Precompute buffer size for SDDMM operation";
+  let description = [{
+    The `gpu.sddmm_buffer_size` operation returns the buffer size required
+    to perform the SDDMM operation on the given sparse and dense matrices.
+    The operation expects handles returned by previous sparse operations
+    to construct an environment and the operands for SDDMM.
+
+    If the `async` keyword is present, the op is executed asynchronously (i.e.
+    it does not block until the execution has finished on the device). In
+    that case, it returns a !gpu.async.token in addition to the buffer size.
+
+    Example:
+
+    ```mlir
+    %buffersz, %token = gpu.sddmm_buffer_size async [%dep] %env, %dnmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %spmatC
+    ```
+
+    The matrix arguments can also be associated with one of the following
+    operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value
+    is NON_TRANSPOSE.
+  }];
+
+  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
+                       GPU_SparseEnvHandle:$env,
+                       GPU_TransposeModeAttr:$modeA,
+                       GPU_TransposeModeAttr:$modeB,
+                       GPU_SparseDnMatHandle:$dnmatA,
+                       GPU_SparseDnMatHandle:$dnmatB,
+                       GPU_SparseSpMatHandle:$spmatC);
+  let results = (outs Res<Index>:$bufferSz, Optional<GPU_AsyncToken>:$asyncToken);
+
+  let builders = [OpBuilder<(ins
+      "::mlir::Type":$bufferSz,
+      "::mlir::Type":$asyncToken,
+      "::mlir::ValueRange":$asyncDependencies,
+      "::mlir::Value":$env,
+      "::mlir::Value":$dnmatA,
+      "::mlir::Value":$dnmatB,
+      "::mlir::Value":$spmatC), [{
+    auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
+    auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
+    return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies,
+                 env, modeA, modeB, dnmatA, dnmatB, spmatC);}]>
+  ];
+
+  let assemblyFormat = [{
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+    $env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC attr-dict
+  }];
+}
+
+def GPU_SDDMMOp : GPU_Op<"sddmm", [GPU_AsyncOpInterface]> {
+  let summary = "SDDMM operation";
+  let description = [{
+    The `gpu.sddmm` operation performs the SDDMM operation on the given sparse and
+    dense matrices, and buffer. The operation expects handles returned by previous
+    sparse operations to construct an environment and the operands for SDDMM. The
+    buffer must have been allocated on the device.
+
+    If the `async` keyword is present, the op is executed asynchronously (i.e.
+    it does not block until the execution has finished on the device). In
+    that case, it returns a !gpu.async.token.
+
+    Example:
+
+    ```mlir
+    %token = gpu.sddmm async [%dep] %env, %dnmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %spmatC, %buffer
+    ```
+
+    The matrix arguments can also be associated with one of the following
+    operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value
+    is NON_TRANSPOSE.
+  }];
+
+  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
+                       GPU_SparseEnvHandle:$env,
+                       GPU_TransposeModeAttr:$modeA,
+                       GPU_TransposeModeAttr:$modeB,
+                       GPU_SparseDnMatHandle:$dnmatA,
+                       GPU_SparseDnMatHandle:$dnmatB,
+                       GPU_SparseSpMatHandle:$spmatC,
+                       AnyMemRef:$buffer);
+  let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
+
+  let builders = [OpBuilder<(ins
+      "::mlir::Type":$asyncToken,
+      "::mlir::ValueRange":$asyncDependencies,
+      "::mlir::Value":$env,
+      "::mlir::Value":$dnmatA,
+      "::mlir::Value":$dnmatB,
+      "::mlir::Value":$spmatC,
+      "::mlir::Value":$buffer), [{
+    auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
+    auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
+    return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA,
+                 modeB, dnmatA, dnmatB, spmatC, buffer);}]>
+  ];
+
+  let assemblyFormat = [{
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+    $env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $buffer attr-dict `:` type($buffer)
+  }];
+}
+
 #endif // GPU_OPS
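The default builders above supply NON_TRANSPOSE for both modes. A hedged sketch of chaining the two new ops through their generated C++ builders (the `buildSddmm` helper and its value parameters are assumptions for illustration, not code from this patch):

```cpp
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Sketch (not from the patch): build the buffer-size query followed by the
// SDDMM itself, chaining async tokens. All Value parameters are assumed to
// already carry the matching handle/memref types.
static Value buildSddmm(OpBuilder &builder, Location loc, Value token,
                        Value env, Value dnmatA, Value dnmatB, Value spmatC,
                        Value buffer) {
  Type tokenType = builder.getType<gpu::AsyncTokenType>();
  // gpu.sddmm_buffer_size: yields the required workspace size plus a token.
  auto bufOp = builder.create<gpu::SDDMMBufferSizeOp>(
      loc, builder.getIndexType(), tokenType, ValueRange{token}, env, dnmatA,
      dnmatB, spmatC);
  // A real pipeline would gpu.alloc a buffer of bufOp.getBufferSz() here;
  // this sketch assumes `buffer` was allocated by the caller.
  auto sddmmOp = builder.create<gpu::SDDMMOp>(
      loc, tokenType, ValueRange{bufOp.getAsyncToken()}, env, dnmatA, dnmatB,
      spmatC, buffer);
  return sddmmOp.getAsyncToken();
}
```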
@@ -257,6 +257,18 @@ protected:
       {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
        llvmPointerType, llvmPointerType, llvmInt32Type, llvmPointerType,
        llvmPointerType /* void *stream */}};
+  FunctionCallBuilder SDDMMBufferSizeCallBuilder = {
+      "mgpuSDDMMBufferSize",
+      llvmIntPtrType,
+      {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
+       llvmPointerType, llvmPointerType, llvmInt32Type,
+       llvmPointerType /* void *stream */}};
+  FunctionCallBuilder SDDMMCallBuilder = {
+      "mgpuSDDMM",
+      llvmVoidType,
+      {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
+       llvmPointerType, llvmPointerType, llvmInt32Type, llvmPointerType,
+       llvmPointerType /* void *stream */}};
 };
 
 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
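For reference, the two call builders above bind to runtime symbols whose C-level signatures look as follows (derived from the argument type lists here and the wrapper definitions later in this commit; `llvmIntPtrType` lowers to `intptr_t`, and the parameter names are descriptive assumptions):

```cpp
#include <cstdint>

// C view of the runtime calls the two builders above emit. Argument order
// matches the builders' type lists: environment handle, the two transpose
// modes, dense A and B, sparse C, the element-width tag, (buffer,) stream.
extern "C" intptr_t mgpuSDDMMBufferSize(void *env, int32_t modeA,
                                        int32_t modeB, void *dnmatA,
                                        void *dnmatB, void *spmatC, int32_t dw,
                                        void *stream);
extern "C" void mgpuSDDMM(void *env, int32_t modeA, int32_t modeB,
                          void *dnmatA, void *dnmatB, void *spmatC, int32_t dw,
                          void *buffer, void *stream);
```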
@@ -599,6 +611,20 @@ private:
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+class ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern
+    : public ConvertOpToGpuRuntimeCallPattern<gpu::SDDMMBufferSizeOp> {
+public:
+  ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern(
+      LLVMTypeConverter &typeConverter)
+      : ConvertOpToGpuRuntimeCallPattern<gpu::SDDMMBufferSizeOp>(
+            typeConverter) {}
+
+private:
+  LogicalResult
+  matchAndRewrite(gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 class ConvertSpMMOpToGpuRuntimeCallPattern
     : public ConvertOpToGpuRuntimeCallPattern<gpu::SpMMOp> {
 public:
@@ -611,6 +637,18 @@ private:
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+class ConvertSDDMMOpToGpuRuntimeCallPattern
+    : public ConvertOpToGpuRuntimeCallPattern<gpu::SDDMMOp> {
+public:
+  ConvertSDDMMOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+      : ConvertOpToGpuRuntimeCallPattern<gpu::SDDMMOp>(typeConverter) {}
+
+private:
+  LogicalResult
+  matchAndRewrite(gpu::SDDMMOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
 void GpuToLLVMConversionPass::runOnOperation() {
@@ -1245,7 +1283,8 @@ LogicalResult ConvertCreateDnVecOpToGpuRuntimeCallPattern::matchAndRewrite(
       MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
   if (!getTypeConverter()->useOpaquePointers())
     pVec = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pVec);
-  Type dType = llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
+  Type dType =
+      llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
   auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               dType.getIntOrFloatBitWidth());
   auto handle =
@@ -1281,7 +1320,8 @@ LogicalResult ConvertCreateDnMatOpToGpuRuntimeCallPattern::matchAndRewrite(
       MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
   if (!getTypeConverter()->useOpaquePointers())
     pMat = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pMat);
-  Type dType = llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
+  Type dType =
+      llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
   auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               dType.getIntOrFloatBitWidth());
   auto handle =
@@ -1325,8 +1365,10 @@ LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
     pColIdxs = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pColIdxs);
     pValues = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pValues);
   }
-  Type iType = llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
-  Type dType = llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
+  Type iType =
+      llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
+  Type dType =
+      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
   auto iw = rewriter.create<LLVM::ConstantOp>(
       loc, llvmInt32Type, iType.isIndex() ? 64 : iType.getIntOrFloatBitWidth());
   auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
@@ -1360,9 +1402,12 @@ LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
     pColIdxs = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pColIdxs);
     pValues = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pValues);
   }
-  Type pType = llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
-  Type iType = llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
-  Type dType = llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
+  Type pType =
+      llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
+  Type iType =
+      llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
+  Type dType =
+      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
   auto pw = rewriter.create<LLVM::ConstantOp>(
       loc, llvmInt32Type, pType.isIndex() ? 64 : pType.getIntOrFloatBitWidth());
   auto iw = rewriter.create<LLVM::ConstantOp>(
@@ -1445,9 +1490,9 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
-  Type dType = getSpMatElemType(op.getSpmatA());
   auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
   auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
+  Type dType = getSpMatElemType(op.getSpmatA());
   auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               dType.getIntOrFloatBitWidth());
   auto stream = adaptor.getAsyncDependencies().front();
@@ -1461,6 +1506,29 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
   return success();
 }
 
+LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
+    gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
+      failed(isAsyncWithOneDependency(rewriter, op)))
+    return failure();
+  Location loc = op.getLoc();
+  auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
+  auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
+  Type dType = getSpMatElemType(op.getSpmatC());
+  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
+                                              dType.getIntOrFloatBitWidth());
+  auto stream = adaptor.getAsyncDependencies().front();
+  auto bufferSize =
+      SDDMMBufferSizeCallBuilder
+          .create(loc, rewriter,
+                  {adaptor.getEnv(), modeA, modeB, adaptor.getDnmatA(),
+                   adaptor.getDnmatB(), adaptor.getSpmatC(), dw, stream})
+          .getResult();
+  rewriter.replaceOp(op, {bufferSize, stream});
+  return success();
+}
+
 LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
     gpu::SpMMOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -1468,11 +1536,11 @@ LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
+  auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
+  auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
   Type dType = getSpMatElemType(op.getSpmatA());
   auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               dType.getIntOrFloatBitWidth());
-  auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
-  auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
   auto stream = adaptor.getAsyncDependencies().front();
   Value pBuf =
       MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
@@ -1494,6 +1562,31 @@ static void addOpaquePointerConversion(LLVMTypeConverter &converter) {
   });
 }
 
+LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
+    gpu::SDDMMOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
+      failed(isAsyncWithOneDependency(rewriter, op)))
+    return failure();
+  Location loc = op.getLoc();
+  Type dType = getSpMatElemType(op.getSpmatC());
+  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
+                                              dType.getIntOrFloatBitWidth());
+  auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
+  auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
+  auto stream = adaptor.getAsyncDependencies().front();
+  Value pBuf =
+      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
+  if (!getTypeConverter()->useOpaquePointers())
+    pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
+  SDDMMCallBuilder.create(loc, rewriter,
+                          {adaptor.getEnv(), modeA, modeB, adaptor.getDnmatA(),
+                           adaptor.getDnmatB(), adaptor.getSpmatC(), dw, pBuf,
+                           stream});
+  rewriter.replaceOp(op, {stream});
+  return success();
+}
+
 void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns,
                                                StringRef gpuBinaryAnnotation,
@@ -1526,7 +1619,9 @@ void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
                ConvertSpMVOpToGpuRuntimeCallPattern,
                ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
-               ConvertSpMMOpToGpuRuntimeCallPattern>(converter);
+               ConvertSpMMOpToGpuRuntimeCallPattern,
+               ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
+               ConvertSDDMMOpToGpuRuntimeCallPattern>(converter);
   patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
       converter, gpuBinaryAnnotation, kernelBarePtrCallConv);
   patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());
@@ -404,3 +404,37 @@ mgpuSpMM(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c, int32_t dw,
                                          matB, betap, matC, dtp,
                                          CUSPARSE_SPMM_ALG_DEFAULT, buf))
 }
+
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
+mgpuSDDMMBufferSize(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
+                    int32_t dw, CUstream /*stream*/) {
+  cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
+  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
+  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
+  cusparseDnMatDescr_t matA = reinterpret_cast<cusparseDnMatDescr_t>(a);
+  cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
+  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
+  cudaDataType_t dtp = dataTp(dw);
+  ALPHABETA(dw, alpha, beta)
+  size_t bufferSize = 0;
+  CUSPARSE_REPORT_IF_ERROR(cusparseSDDMM_bufferSize(
+      handle, modeA, modeB, alphap, matA, matB, betap, matC, dtp,
+      CUSPARSE_SDDMM_ALG_DEFAULT, &bufferSize))
+  return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
+}
+
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
+mgpuSDDMM(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
+          int32_t dw, void *buf, CUstream /*stream*/) {
+  cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
+  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
+  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
+  cusparseDnMatDescr_t matA = reinterpret_cast<cusparseDnMatDescr_t>(a);
+  cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
+  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
+  cudaDataType_t dtp = dataTp(dw);
+  ALPHABETA(dw, alpha, beta)
+  CUSPARSE_REPORT_IF_ERROR(cusparseSDDMM(handle, modeA, modeB, alphap, matA,
+                                         matB, betap, matC, dtp,
+                                         CUSPARSE_SDDMM_ALG_DEFAULT, buf))
+}
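These wrappers split cuSPARSE's usual query/allocate/execute protocol across two entry points so the workspace can be allocated with gpu.alloc in between. A minimal standalone sketch of the same sequence written directly against cuSPARSE, for f32 data and default modes (error handling omitted; assumes the handle and the matA/matB/matC descriptors already exist; not code from this patch):

```cpp
#include <cuda_runtime.h>
#include <cusparse.h>

// Sketch: the raw cuSPARSE sequence that mgpuSDDMMBufferSize and mgpuSDDMM
// split between them.
void sddmmDirect(cusparseHandle_t handle, cusparseDnMatDescr_t matA,
                 cusparseDnMatDescr_t matB, cusparseSpMatDescr_t matC) {
  float alpha = 1.0f, beta = 1.0f;
  size_t bufferSize = 0;
  // 1. Query the workspace size (what mgpuSDDMMBufferSize does).
  cusparseSDDMM_bufferSize(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
                           CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA,
                           matB, &beta, matC, CUDA_R_32F,
                           CUSPARSE_SDDMM_ALG_DEFAULT, &bufferSize);
  // 2. Allocate the workspace on the device (gpu.alloc in the MLIR flow).
  void *buffer = nullptr;
  cudaMalloc(&buffer, bufferSize ? bufferSize : 1); // avoid zero-alloc
  // 3. Execute (what mgpuSDDMM does); only entries in matC's sparsity
  //    pattern are updated.
  cusparseSDDMM(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
                CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA, matB, &beta,
                matC, CUDA_R_32F, CUSPARSE_SDDMM_ALG_DEFAULT, buffer);
  cudaFree(buffer);
}
```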
@@ -62,6 +62,36 @@ module attributes {gpu.container_module} {
     return
   }
 
+  // CHECK-LABEL: func @sddmm
+  // CHECK: llvm.call @mgpuStreamCreate
+  // CHECK: llvm.call @mgpuMemAlloc
+  // CHECK: llvm.call @mgpuMemAlloc
+  // CHECK: llvm.call @mgpuCreateSparseEnv
+  // CHECK: llvm.call @mgpuCreateCsr
+  // CHECK: llvm.call @mgpuCreateDnMat
+  // CHECK: llvm.call @mgpuSDDMMBufferSize
+  // CHECK: llvm.call @mgpuSDDMM
+  // CHECK: llvm.call @mgpuDestroySpMat
+  // CHECK: llvm.call @mgpuDestroyDnMat
+  // CHECK: llvm.call @mgpuDestroySparseEnv
+  // CHECK: llvm.call @mgpuStreamSynchronize
+  // CHECK: llvm.call @mgpuStreamDestroy
+  func.func @sddmm(%arg0: index) {
+    %token0 = gpu.wait async
+    %mem1, %token1 = gpu.alloc async [%token0] (%arg0) : memref<?xindex>
+    %mem2, %token2 = gpu.alloc async [%token1] (%arg0) : memref<?xf64>
+    %env, %token3 = gpu.create_sparse_env async [%token2]
+    %spmat, %token4 = gpu.create_csr async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
+    %dnmat, %token5 = gpu.create_dn_mat async [%token4] %arg0, %arg0, %mem2 : memref<?xf64>
+    %bufferSz, %token6 = gpu.sddmm_buffer_size async [%token5] %env, %dnmat, %dnmat, %spmat
+    %token7 = gpu.sddmm async [%token6] %env, %dnmat, %dnmat, %spmat, %mem2 : memref<?xf64>
+    %token8 = gpu.destroy_sp_mat async [%token7] %spmat
+    %token9 = gpu.destroy_dn_mat async [%token8] %dnmat
+    %token10 = gpu.destroy_sparse_env async [%token9] %env
+    gpu.wait [%token10]
+    return
+  }
+
 }
 
@@ -344,16 +344,20 @@ module attributes {gpu.container_module} {
     %bufferSz2, %token10 = gpu.spmm_buffer_size async [%token9] %env, %spmat, %dnmat, %dnmat
     // CHECK: gpu.spmm async
     %token11 = gpu.spmm async [%token10] %env, %spmat, %dnmat, %dnmat, %mem2 : memref<?xf64>
+    // CHECK: gpu.sddmm_buffer_size async
+    %bufferSz3, %token12 = gpu.sddmm_buffer_size async [%token11] %env, %dnmat, %dnmat, %spmat
+    // CHECK: gpu.sddmm async
+    %token13 = gpu.sddmm async [%token12] %env, %dnmat, %dnmat, %spmat, %mem2 : memref<?xf64>
     // CHECK: gpu.destroy_dn_mat async
-    %token12 = gpu.destroy_dn_mat async [%token11] %dnmat
+    %token14 = gpu.destroy_dn_mat async [%token13] %dnmat
     // CHECK: gpu.destroy_sp_mat async
-    %token13 = gpu.destroy_sp_mat async [%token12] %spmat
+    %token15 = gpu.destroy_sp_mat async [%token14] %spmat
     // CHECK: gpu.destroy_dn_vec async
-    %token14 = gpu.destroy_dn_vec async [%token13] %dnvec
+    %token16 = gpu.destroy_dn_vec async [%token15] %dnvec
     // CHECK: gpu.destroy_sparse_env async
-    %token15 = gpu.destroy_sparse_env async [%token14] %env
+    %token17 = gpu.destroy_sparse_env async [%token16] %env
     // CHECK: gpu.wait
-    gpu.wait [%token15]
+    gpu.wait [%token17]
     return
   }
 }