[mlir][gpu][mlir-cuda-runner] Refactor ConvertKernelFuncToCubin to be generic.

Make ConvertKernelFuncToCubin pass to be generic:

- Rename to ConvertKernelFuncToBlob.
- Allow specifying triple, target chip, target features.
- Initializing LLVM backend is supplied by a callback function.
- Lowering process from MLIR module to LLVM module is via another callback.
- Change mlir-cuda-runner to adopt the revised pass.
- Add new tests for lowering to ROCm HSA code object (HSACO).
- Tests for CUDA and ROCm are kept in separate directories.

Differential Revision: https://reviews.llvm.org/D80142
This commit is contained in:
Wen-Heng (Jack) Chung 2020-05-22 16:25:00 -05:00
parent fdaa391e3d
commit 061fb8eb2d
17 changed files with 359 additions and 275 deletions

View File

@ -31,6 +31,15 @@ endif()
# TODO: we should use a config.h file like LLVM does
add_definitions(-DMLIR_CUDA_CONVERSIONS_ENABLED=${MLIR_CUDA_CONVERSIONS_ENABLED})
# Build the ROCm conversions and run according tests if the AMDGPU backend
# is available
if ("AMDGPU" IN_LIST LLVM_TARGETS_TO_BUILD)
set(MLIR_ROCM_CONVERSIONS_ENABLED 1)
else()
set(MLIR_ROCM_CONVERSIONS_ENABLED 0)
endif()
add_definitions(-DMLIR_ROCM_CONVERSIONS_ENABLED=${MLIR_ROCM_CONVERSIONS_ENABLED})
set(MLIR_CUDA_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir CUDA runner")
set(MLIR_VULKAN_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir Vulkan runner")

View File

@ -9,19 +9,33 @@
#define MLIR_CONVERSION_GPUCOMMON_GPUCOMMONPASS_H_
#include "mlir/Support/LLVM.h"
#include <functional>
#include <memory>
#include <string>
#include "llvm/IR/Module.h"
#include <vector>
namespace mlir {
class Location;
class LogicalResult;
class ModuleOp;
class Operation;
template <typename T>
class OperationPass;
namespace gpu {
class GPUModuleOp;
} // namespace gpu
namespace LLVM {
class LLVMDialect;
} // namespace LLVM
using OwnedBlob = std::unique_ptr<std::vector<char>>;
using BlobGenerator =
std::function<OwnedBlob(const std::string &, Location, StringRef)>;
using LoweringCallback =
std::function<std::unique_ptr<llvm::Module>(Operation *)>;
/// Creates a pass to convert a gpu.launch_func operation into a sequence of
/// GPU runtime calls.
///
@ -31,6 +45,34 @@ class OperationPass;
std::unique_ptr<OperationPass<ModuleOp>>
createConvertGpuLaunchFuncToGpuRuntimeCallsPass();
/// Creates a pass to convert kernel functions into GPU target object blobs.
///
/// This transformation takes the body of each function that is annotated with
/// the 'gpu.kernel' attribute, copies it to a new LLVM module, compiles the
/// module with help of the GPU backend to target object and then invokes
/// the provided blobGenerator to produce a binary blob. Such blob is then
/// attached as a string attribute to the kernel function.
///
/// Following callbacks are to be provided by user:
/// - loweringCallback : lower the module to an LLVM module.
/// - blobGenerator : build a blob executable on target GPU.
///
/// Information wrt LLVM backend are to be supplied by user:
/// - triple : target triple to be used.
/// - targetChip : mcpu to be used.
/// - features : target-specific features to be used.
///
/// Information about result attribute is to be specified by user:
/// - gpuBinaryAnnotation : the name of the attribute which contains the blob.
///
/// After the transformation, the body of the kernel function is removed (i.e.,
/// it is turned into a declaration).
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
createConvertGPUKernelToBlobPass(LoweringCallback loweringCallback,
BlobGenerator blobGenerator, StringRef triple,
StringRef targetChip, StringRef features,
StringRef gpuBinaryAnnotation);
} // namespace mlir
#endif // MLIR_CONVERSION_GPUCOMMON_GPUCOMMONPASS_H_

View File

@ -1,50 +0,0 @@
//===- GPUToCUDAPass.h - MLIR CUDA runtime support --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
#define MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
#include "mlir/Support/LLVM.h"
#include <functional>
#include <memory>
#include <string>
#include <vector>
namespace mlir {
class Location;
class ModuleOp;
template <typename T> class OperationPass;
namespace gpu {
class GPUModuleOp;
} // namespace gpu
namespace LLVM {
class LLVMDialect;
} // namespace LLVM
using OwnedCubin = std::unique_ptr<std::vector<char>>;
using CubinGenerator =
std::function<OwnedCubin(const std::string &, Location, StringRef)>;
/// Creates a pass to convert kernel functions into CUBIN blobs.
///
/// This transformation takes the body of each function that is annotated with
/// the 'nvvm.kernel' attribute, copies it to a new LLVM module, compiles the
/// module with help of the nvptx backend to PTX and then invokes the provided
/// cubinGenerator to produce a binary blob (the cubin). Such blob is then
/// attached as a string attribute named 'nvvm.cubin' to the kernel function.
/// After the transformation, the body of the kernel function is removed (i.e.,
/// it is turned into a declaration).
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator);
} // namespace mlir
#endif // MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_

View File

@ -16,7 +16,6 @@
#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h"

View File

@ -1,7 +1,6 @@
add_subdirectory(AffineToStandard)
add_subdirectory(AVX512ToLLVM)
add_subdirectory(GPUCommon)
add_subdirectory(GPUToCUDA)
add_subdirectory(GPUToNVVM)
add_subdirectory(GPUToROCDL)
add_subdirectory(GPUToSPIRV)

View File

@ -1,9 +1,6 @@
set(SOURCES
ConvertLaunchFuncToRuntimeCalls.cpp
)
add_mlir_conversion_library(MLIRGPUtoGPURuntimeTransforms
${SOURCES}
ConvertLaunchFuncToRuntimeCalls.cpp
ConvertKernelFuncToBlob.cpp
DEPENDS
MLIRConversionPassIncGen

View File

@ -0,0 +1,168 @@
//===- ConvertKernelFuncToBlob.cpp - MLIR GPU lowering passes -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert gpu kernel functions into a
// corresponding binary blob that can be executed on a GPU. Currently
// only translates the function itself but no dependencies.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/Twine.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/Mutex.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
using namespace mlir;
namespace {
/// A pass converting tagged kernel modules to a blob with target instructions.
///
/// If tagged as a kernel module, each contained function is translated to
/// user-specified IR. A user provided BlobGenerator then compiles the IR to
/// GPU binary code, which is then attached as an attribute to the function.
/// The function body is erased.
class GpuKernelToBlobPass
: public PassWrapper<GpuKernelToBlobPass, OperationPass<gpu::GPUModuleOp>> {
public:
GpuKernelToBlobPass(LoweringCallback loweringCallback,
BlobGenerator blobGenerator, StringRef triple,
StringRef targetChip, StringRef features,
StringRef gpuBinaryAnnotation)
: loweringCallback(loweringCallback), blobGenerator(blobGenerator),
triple(triple), targetChip(targetChip), features(features),
blobAnnotation(gpuBinaryAnnotation) {}
void runOnOperation() override {
gpu::GPUModuleOp module = getOperation();
// Lock access to the llvm context.
llvm::sys::SmartScopedLock<true> scopedLock(
module.getContext()
->getRegisteredDialect<LLVM::LLVMDialect>()
->getLLVMContextMutex());
// Lower the module to a llvm module.
std::unique_ptr<llvm::Module> llvmModule = loweringCallback(module);
if (!llvmModule)
return signalPassFailure();
// Translate the llvm module to a target blob and attach the result as
// attribute to the module.
if (auto blobAttr = translateGPUModuleToBinaryAnnotation(
*llvmModule, module.getLoc(), module.getName()))
module.setAttr(blobAnnotation, blobAttr);
else
signalPassFailure();
}
private:
std::string translateModuleToISA(llvm::Module &module,
llvm::TargetMachine &targetMachine);
/// Converts llvmModule to a blob with target instructions using the
/// user-provided generator. Location is used for error reporting and name is
/// forwarded to the blob generator to use in its logging mechanisms.
OwnedBlob convertModuleToBlob(llvm::Module &llvmModule, Location loc,
StringRef name);
/// Translates llvmModule to a blob with target instructions and returns the
/// result as attribute.
StringAttr translateGPUModuleToBinaryAnnotation(llvm::Module &llvmModule,
Location loc, StringRef name);
LoweringCallback loweringCallback;
BlobGenerator blobGenerator;
llvm::Triple triple;
StringRef targetChip;
StringRef features;
StringRef blobAnnotation;
};
} // anonymous namespace
std::string
GpuKernelToBlobPass::translateModuleToISA(llvm::Module &module,
llvm::TargetMachine &targetMachine) {
std::string targetISA;
{
// Clone the llvm module into a new context to enable concurrent compilation
// with multiple threads.
llvm::LLVMContext llvmContext;
auto clone = LLVM::cloneModuleIntoNewContext(&llvmContext, &module);
llvm::raw_string_ostream stream(targetISA);
llvm::buffer_ostream pstream(stream);
llvm::legacy::PassManager codegenPasses;
targetMachine.addPassesToEmitFile(codegenPasses, pstream, nullptr,
llvm::CGFT_AssemblyFile);
codegenPasses.run(*clone);
}
return targetISA;
}
OwnedBlob GpuKernelToBlobPass::convertModuleToBlob(llvm::Module &llvmModule,
Location loc,
StringRef name) {
std::unique_ptr<llvm::TargetMachine> targetMachine;
{
std::string error;
const llvm::Target *target =
llvm::TargetRegistry::lookupTarget("", triple, error);
if (target == nullptr) {
emitError(loc, "cannot initialize target triple");
return {};
}
targetMachine.reset(target->createTargetMachine(triple.str(), targetChip,
features, {}, {}));
}
llvmModule.setDataLayout(targetMachine->createDataLayout());
auto targetISA = translateModuleToISA(llvmModule, *targetMachine);
return blobGenerator(targetISA, loc, name);
}
StringAttr GpuKernelToBlobPass::translateGPUModuleToBinaryAnnotation(
llvm::Module &llvmModule, Location loc, StringRef name) {
auto blob = convertModuleToBlob(llvmModule, loc, name);
if (!blob)
return {};
return StringAttr::get({blob->data(), blob->size()}, loc->getContext());
}
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
mlir::createConvertGPUKernelToBlobPass(LoweringCallback loweringCallback,
BlobGenerator blobGenerator,
StringRef triple, StringRef targetChip,
StringRef features,
StringRef gpuBinaryAnnotation) {
return std::make_unique<GpuKernelToBlobPass>(loweringCallback, blobGenerator,
triple, targetChip, features,
gpuBinaryAnnotation);
}

View File

@ -1,35 +0,0 @@
set(LLVM_OPTIONAL_SOURCES
ConvertKernelFuncToCubin.cpp
)
if (MLIR_CUDA_CONVERSIONS_ENABLED)
set(NVPTX_LIBS
MC
NVPTXCodeGen
NVPTXDesc
NVPTXInfo
)
add_mlir_conversion_library(MLIRGPUtoCUDATransforms
ConvertKernelFuncToCubin.cpp
DEPENDS
MLIRConversionPassIncGen
intrinsics_gen
LINK_COMPONENTS
Core
${NVPTX_LIBS}
LINK_LIBS PUBLIC
MLIRGPU
MLIRIR
MLIRLLVMIR
MLIRNVVMIR
MLIRPass
MLIRSupport
MLIRTargetNVVMIR
)
else()
add_library(MLIRGPUtoCUDATransforms INTERFACE IMPORTED GLOBAL)
endif()

View File

@ -1,165 +0,0 @@
//===- ConvertKernelFuncToCubin.cpp - MLIR GPU lowering passes ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert gpu kernel functions into a
// corresponding binary blob that can be executed on a CUDA GPU. Currently
// only translates the function itself but no dependencies.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/NVVMIR.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/Twine.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/Mutex.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
using namespace mlir;
namespace {
// TODO(herhut): Move to shared location.
static constexpr const char *kCubinAnnotation = "nvvm.cubin";
/// A pass converting tagged kernel modules to cubin blobs.
///
/// If tagged as a kernel module, each contained function is translated to NVVM
/// IR and further to PTX. A user provided CubinGenerator compiles the PTX to
/// GPU binary code, which is then attached as an attribute to the function. The
/// function body is erased.
class GpuKernelToCubinPass
: public PassWrapper<GpuKernelToCubinPass,
OperationPass<gpu::GPUModuleOp>> {
public:
GpuKernelToCubinPass(CubinGenerator cubinGenerator)
: cubinGenerator(cubinGenerator) {}
void runOnOperation() override {
gpu::GPUModuleOp module = getOperation();
// Lock access to the llvm context.
llvm::sys::SmartScopedLock<true> scopedLock(
module.getContext()
->getRegisteredDialect<LLVM::LLVMDialect>()
->getLLVMContextMutex());
// Make sure the NVPTX target is initialized.
LLVMInitializeNVPTXTarget();
LLVMInitializeNVPTXTargetInfo();
LLVMInitializeNVPTXTargetMC();
LLVMInitializeNVPTXAsmPrinter();
auto llvmModule = translateModuleToNVVMIR(module);
if (!llvmModule)
return signalPassFailure();
// Translate the module to CUBIN and attach the result as attribute to the
// module.
if (auto cubinAttr = translateGPUModuleToCubinAnnotation(
*llvmModule, module.getLoc(), module.getName()))
module.setAttr(kCubinAnnotation, cubinAttr);
else
signalPassFailure();
}
private:
std::string translateModuleToPtx(llvm::Module &module,
llvm::TargetMachine &target_machine);
/// Converts llvmModule to cubin using the user-provided generator. Location
/// is used for error reporting and name is forwarded to the CUBIN generator
/// to use in its logging mechanisms.
OwnedCubin convertModuleToCubin(llvm::Module &llvmModule, Location loc,
StringRef name);
/// Translates llvmModule to cubin and returns the result as attribute.
StringAttr translateGPUModuleToCubinAnnotation(llvm::Module &llvmModule,
Location loc, StringRef name);
CubinGenerator cubinGenerator;
};
} // anonymous namespace
std::string GpuKernelToCubinPass::translateModuleToPtx(
llvm::Module &module, llvm::TargetMachine &target_machine) {
std::string ptx;
{
// Clone the llvm module into a new context to enable concurrent compilation
// with multiple threads.
// TODO(zinenko): Reevaluate model of ownership of LLVMContext in
// LLVMDialect.
llvm::LLVMContext llvmContext;
auto clone = LLVM::cloneModuleIntoNewContext(&llvmContext, &module);
llvm::raw_string_ostream stream(ptx);
llvm::buffer_ostream pstream(stream);
llvm::legacy::PassManager codegen_passes;
target_machine.addPassesToEmitFile(codegen_passes, pstream, nullptr,
llvm::CGFT_AssemblyFile);
codegen_passes.run(*clone);
}
return ptx;
}
OwnedCubin GpuKernelToCubinPass::convertModuleToCubin(llvm::Module &llvmModule,
Location loc,
StringRef name) {
std::unique_ptr<llvm::TargetMachine> targetMachine;
{
std::string error;
// TODO(herhut): Make triple configurable.
constexpr const char *cudaTriple = "nvptx64-nvidia-cuda";
llvm::Triple triple(cudaTriple);
const llvm::Target *target =
llvm::TargetRegistry::lookupTarget("", triple, error);
if (target == nullptr) {
emitError(loc, "cannot initialize target triple");
return {};
}
targetMachine.reset(
target->createTargetMachine(triple.str(), "sm_35", "+ptx60", {}, {}));
}
// Set the data layout of the llvm module to match what the ptx target needs.
llvmModule.setDataLayout(targetMachine->createDataLayout());
auto ptx = translateModuleToPtx(llvmModule, *targetMachine);
return cubinGenerator(ptx, loc, name);
}
StringAttr GpuKernelToCubinPass::translateGPUModuleToCubinAnnotation(
llvm::Module &llvmModule, Location loc, StringRef name) {
auto cubin = convertModuleToCubin(llvmModule, loc, name);
if (!cubin)
return {};
return StringAttr::get({cubin->data(), cubin->size()}, loc->getContext());
}
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
mlir::createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator) {
return std::make_unique<GpuKernelToCubinPass>(cubinGenerator);
}

View File

@ -0,0 +1,2 @@
if not config.run_rocm_tests:
config.unsupported = True

View File

@ -0,0 +1,26 @@
// RUN: mlir-opt %s --test-kernel-to-hsaco -split-input-file | FileCheck %s
// CHECK: attributes {rocdl.hsaco = "HSACO"}
gpu.module @foo {
llvm.func @kernel(%arg0 : !llvm.float, %arg1 : !llvm<"float*">)
// CHECK: attributes {gpu.kernel}
attributes { gpu.kernel } {
llvm.return
}
}
// -----
gpu.module @bar {
// CHECK: func @kernel_a
llvm.func @kernel_a()
attributes { gpu.kernel } {
llvm.return
}
// CHECK: func @kernel_b
llvm.func @kernel_b()
attributes { gpu.kernel } {
llvm.return
}
}

View File

@ -1,3 +1,21 @@
if (MLIR_CUDA_CONVERSIONS_ENABLED)
set(NVPTX_LIBS
MC
NVPTXCodeGen
NVPTXDesc
NVPTXInfo
)
endif()
if (MLIR_ROCM_CONVERSIONS_ENABLED)
set(AMDGPU_LIBS
MC
AMDGPUCodeGen
AMDGPUDesc
AMDGPUInfo
)
endif()
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestTransforms
TestAllReduceLowering.cpp
@ -5,6 +23,7 @@ add_mlir_library(MLIRTestTransforms
TestCallGraph.cpp
TestConstantFold.cpp
TestConvertGPUKernelToCubin.cpp
TestConvertGPUKernelToHsaco.cpp
TestDominance.cpp
TestLoopFusion.cpp
TestGpuMemoryPromotion.cpp
@ -31,18 +50,26 @@ add_mlir_library(MLIRTestTransforms
MLIRStandardOpsIncGen
MLIRTestVectorTransformPatternsIncGen
LINK_COMPONENTS
${AMDGPU_LIBS}
${NVPTX_LIBS}
LINK_LIBS PUBLIC
MLIRAffineOps
MLIRAnalysis
MLIREDSC
MLIRGPU
MLIRGPUtoCUDATransforms
MLIRGPUtoGPURuntimeTransforms
MLIRLinalgOps
MLIRLinalgTransforms
MLIRNVVMIR
MLIRSCF
MLIRGPU
MLIRPass
MLIRROCDLIR
MLIRStandardOpsTransforms
MLIRTargetNVVMIR
MLIRTargetROCDLIR
MLIRTestDialect
MLIRTransformUtils
MLIRVectorToSCF

View File

@ -6,26 +6,36 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/NVVMIR.h"
#include "llvm/Support/TargetSelect.h"
using namespace mlir;
#if MLIR_CUDA_CONVERSIONS_ENABLED
static OwnedCubin compilePtxToCubinForTesting(const std::string &, Location,
StringRef) {
static OwnedBlob compilePtxToCubinForTesting(const std::string &, Location,
StringRef) {
const char data[] = "CUBIN";
return std::make_unique<std::vector<char>>(data, data + sizeof(data) - 1);
}
namespace mlir {
void registerTestConvertGPUKernelToCubinPass() {
PassPipelineRegistration<>("test-kernel-to-cubin",
"Convert all kernel functions to CUDA cubin blobs",
[](OpPassManager &pm) {
pm.addPass(createConvertGPUKernelToCubinPass(
compilePtxToCubinForTesting));
});
PassPipelineRegistration<>(
"test-kernel-to-cubin",
"Convert all kernel functions to CUDA cubin blobs",
[](OpPassManager &pm) {
// Initialize LLVM NVPTX backend.
LLVMInitializeNVPTXTarget();
LLVMInitializeNVPTXTargetInfo();
LLVMInitializeNVPTXTargetMC();
LLVMInitializeNVPTXAsmPrinter();
pm.addPass(createConvertGPUKernelToBlobPass(
translateModuleToNVVMIR, compilePtxToCubinForTesting,
"nvptx64-nvidia-cuda", "sm_35", "+ptx60", "nvvm.cubin"));
});
}
} // namespace mlir
#endif

View File

@ -0,0 +1,41 @@
//===- TestConvertGPUKernelToHsaco.cpp - Test gpu kernel hsaco lowering ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/ROCDLIR.h"
#include "llvm/Support/TargetSelect.h"
using namespace mlir;
#if MLIR_ROCM_CONVERSIONS_ENABLED
static OwnedBlob compileIsaToHsacoForTesting(const std::string &, Location,
StringRef) {
const char data[] = "HSACO";
return std::make_unique<std::vector<char>>(data, data + sizeof(data) - 1);
}
namespace mlir {
void registerTestConvertGPUKernelToHsacoPass() {
PassPipelineRegistration<>(
"test-kernel-to-hsaco",
"Convert all kernel functions to ROCm hsaco blobs",
[](OpPassManager &pm) {
// Initialize LLVM AMDGPU backend.
LLVMInitializeAMDGPUTarget();
LLVMInitializeAMDGPUTargetInfo();
LLVMInitializeAMDGPUTargetMC();
LLVMInitializeAMDGPUAsmPrinter();
pm.addPass(createConvertGPUKernelToBlobPass(
translateModuleToROCDLIR, compileIsaToHsacoForTesting,
"amdgcn-amd-amdhsa", "gfx900", "-code-object-v3", "rocdl.hsaco"));
});
}
} // namespace mlir
#endif

View File

@ -38,6 +38,7 @@ config.build_examples = @LLVM_BUILD_EXAMPLES@
config.run_cuda_tests = @MLIR_CUDA_CONVERSIONS_ENABLED@
config.cuda_wrapper_library_dir = "@MLIR_CUDA_WRAPPER_LIBRARY_DIR@"
config.enable_cuda_runner = @MLIR_CUDA_RUNNER_ENABLED@
config.run_rocm_tests = @MLIR_ROCM_CONVERSIONS_ENABLED@
config.vulkan_wrapper_library_dir = "@MLIR_VULKAN_WRAPPER_LIBRARY_DIR@"
config.enable_vulkan_runner = @MLIR_VULKAN_RUNNER_ENABLED@

View File

@ -15,7 +15,6 @@
#include "llvm/ADT/STLExtras.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
@ -30,6 +29,7 @@
#include "mlir/InitAllDialects.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/NVVMIR.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/InitLLVM.h"
@ -57,8 +57,8 @@ inline void emit_cuda_error(const llvm::Twine &message, const char *buffer,
} \
}
OwnedCubin compilePtxToCubin(const std::string ptx, Location loc,
StringRef name) {
OwnedBlob compilePtxToCubin(const std::string ptx, Location loc,
StringRef name) {
char jitErrorBuffer[4096] = {0};
RETURN_ON_CUDA_ERROR(cuInit(0), "cuInit");
@ -97,7 +97,7 @@ OwnedCubin compilePtxToCubin(const std::string ptx, Location loc,
"cuLinkComplete");
char *cubinAsChar = static_cast<char *>(cubinData);
OwnedCubin result =
OwnedBlob result =
std::make_unique<std::vector<char>>(cubinAsChar, cubinAsChar + cubinSize);
// This will also destroy the cubin data.
@ -114,7 +114,9 @@ static LogicalResult runMLIRPasses(ModuleOp m) {
auto &kernelPm = pm.nest<gpu::GPUModuleOp>();
kernelPm.addPass(createStripDebugInfoPass());
kernelPm.addPass(createLowerGpuOpsToNVVMOpsPass());
kernelPm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin));
kernelPm.addPass(createConvertGPUKernelToBlobPass(
translateModuleToNVVMIR, compilePtxToCubin, "nvptx64-nvidia-cuda",
"sm_35", "+ptx60", "nvvm.cubin"));
pm.addPass(createLowerToLLVMPass());
pm.addPass(createConvertGpuLaunchFuncToGpuRuntimeCallsPass());
@ -127,6 +129,13 @@ int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
// Initialize LLVM NVPTX backend.
LLVMInitializeNVPTXTarget();
LLVMInitializeNVPTXTargetInfo();
LLVMInitializeNVPTXTargetMC();
LLVMInitializeNVPTXAsmPrinter();
mlir::initializeLLVMPasses();
return mlir::JitRunnerMain(argc, argv, &runMLIRPasses);
}

View File

@ -46,6 +46,7 @@ void registerTestLoopPermutationPass();
void registerTestCallGraphPass();
void registerTestConstantFold();
void registerTestConvertGPUKernelToCubinPass();
void registerTestConvertGPUKernelToHsacoPass();
void registerTestDominancePass();
void registerTestFunc();
void registerTestGpuMemoryPromotionPass();
@ -112,6 +113,9 @@ void registerTestPasses() {
registerTestConstantFold();
#if MLIR_CUDA_CONVERSIONS_ENABLED
registerTestConvertGPUKernelToCubinPass();
#endif
#if MLIR_ROCM_CONVERSIONS_ENABLED
registerTestConvertGPUKernelToHsacoPass();
#endif
registerTestBufferPlacementPreparationPass();
registerTestDominancePass();