Make GPU to CUDA transformations independent of CUDA runtime.

The actual transformation from PTX source to a CUDA binary is now factored out,
enabling compiling and testing the transformations independently of a CUDA
runtime.

MLIR still has to be built with NVPTX target support for the conversions to be
built and tested.

PiperOrigin-RevId: 255167139
This commit is contained in:
Stephan Herhut 2019-06-26 05:16:11 -07:00 committed by A. Unique TensorFlower
parent a4c3a6455c
commit c72c6c3907
9 changed files with 132 additions and 116 deletions

View File

@ -39,6 +39,14 @@ function(whole_archive_link target)
set_target_properties(${target} PROPERTIES LINK_FLAGS ${link_flags})
endfunction(whole_archive_link)
# Build the CUDA conversions and run the corresponding tests if the NVPTX
# backend is available.
if ("NVPTX" IN_LIST LLVM_TARGETS_TO_BUILD)
set(MLIR_CUDA_CONVERSIONS_ENABLED 1)
else()
set(MLIR_CUDA_CONVERSIONS_ENABLED 0)
endif()
include_directories( "include")
include_directories( ${MLIR_INCLUDE_DIR})

View File

@ -17,20 +17,31 @@
#ifndef MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
#define MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
#include <functional>
#include <memory>
#include <string>
#include <vector>
namespace mlir {
class ModulePassBase;
class Function;
using OwnedCubin = std::unique_ptr<std::vector<char>>;
using CubinGenerator =
std::function<OwnedCubin(const std::string &, Function &)>;
/// Creates a pass to convert kernel functions into CUBIN blobs.
///
/// This transformation takes the body of each function that is annotated with
/// the 'nvvm.kernel' attribute, copies it to a new LLVM module, compiles the
/// module with help of the nvptx backend and the CUDA driver into a CUDA
/// binary blob (cubin) and attaches such blob as a string attribute named
/// 'nvvm.cubin' to the kernel function.
/// module with help of the nvptx backend to PTX and then invokes the provided
/// cubinGenerator to produce a binary blob (the cubin). Such blob is then
/// attached as a string attribute named 'nvvm.cubin' to the kernel function.
/// After the transformation, the body of the kernel function is removed (i.e.,
/// it is turned into a declaration).
ModulePassBase *createConvertGPUKernelToCubinPass();
ModulePassBase *
createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator);
/// Creates a pass to convert a gpu.launch_func operation into a sequence of
/// CUDA calls.

View File

@ -1,22 +1,16 @@
# The CUDA conversions are only available if we have a working CUDA install.
include(CheckLanguage)
check_language(CUDA)
if(MLIR_CUDA_CONVERSIONS_ENABLED)
llvm_map_components_to_libnames(nvptx "NVPTX")
if(CMAKE_CUDA_COMPILER)
# Enable the CUDA language so that CMake finds the headers and library for us.
enable_language(CUDA)
add_llvm_library(MLIRGPUtoCUDATransforms
ConvertKernelFuncToCubin.cpp
ConvertLaunchFuncToCudaCalls.cpp
add_llvm_library(MLIRGPUtoCUDATransforms
ConvertKernelFuncToCubin.cpp
ConvertLaunchFuncToCudaCalls.cpp
)
target_include_directories(MLIRGPUtoCUDATransforms
PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
)
target_link_libraries(MLIRGPUtoCUDATransforms
MLIRGPU
MLIRLLVMIR
MLIRNVVMIR
MLIRPass
${CUDART_LIBRARY}
target_link_libraries(MLIRGPUtoCUDATransforms
MLIRGPU
MLIRLLVMIR
MLIRNVVMIR
MLIRPass
MLIRTargetNVVMIR
${nvptx}
)
endif()

View File

@ -43,31 +43,52 @@
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "cuda.h"
using namespace mlir;
namespace mlir {
namespace {
// TODO(herhut): Move to shared location.
constexpr const char *kCubinAnnotation = "nvvm.cubin";
static constexpr const char *kCubinAnnotation = "nvvm.cubin";
inline void emit_cuda_error(const llvm::Twine &message, CUresult error,
Function &function) {
function.emitError(
message.concat(" failed with error code").concat(llvm::Twine{error}));
}
/// A pass converting tagged kernel functions to cubin blobs.
class GpuKernelToCubinPass : public ModulePass<GpuKernelToCubinPass> {
public:
GpuKernelToCubinPass(
CubinGenerator cubinGenerator = compilePtxToCubinForTesting)
: cubinGenerator(cubinGenerator) {}
#define RETURN_ON_CUDA_ERROR(expr, msg) \
do { \
auto _cuda_error = (expr); \
if (_cuda_error != CUDA_SUCCESS) { \
emit_cuda_error(msg, _cuda_error, function); \
return {}; \
} \
} while (0)
// Run the dialect converter on the module.
void runOnModule() override {
// Make sure the NVPTX target is initialized.
LLVMInitializeNVPTXTarget();
LLVMInitializeNVPTXTargetInfo();
LLVMInitializeNVPTXTargetMC();
LLVMInitializeNVPTXAsmPrinter();
std::string translateModuleToPtx(llvm::Module &module,
llvm::TargetMachine &target_machine) {
for (auto &function : getModule()) {
if (!gpu::GPUDialect::isKernel(&function) || function.isExternal()) {
continue;
}
if (failed(translateGpuKernelToCubinAnnotation(function)))
signalPassFailure();
}
}
private:
static OwnedCubin compilePtxToCubinForTesting(const std::string &ptx,
Function &function);
std::string translateModuleToPtx(llvm::Module &module,
llvm::TargetMachine &target_machine);
OwnedCubin convertModuleToCubin(llvm::Module &llvmModule, Function &function);
LogicalResult translateGpuKernelToCubinAnnotation(Function &function);
CubinGenerator cubinGenerator;
};
} // anonymous namespace
std::string GpuKernelToCubinPass::translateModuleToPtx(
llvm::Module &module, llvm::TargetMachine &target_machine) {
std::string ptx;
{
llvm::raw_string_ostream stream(ptx);
@ -81,51 +102,15 @@ std::string translateModuleToPtx(llvm::Module &module,
return ptx;
}
using OwnedCubin = std::unique_ptr<std::vector<char>>;
llvm::Optional<OwnedCubin> compilePtxToCubin(std::string &ptx,
Function &function) {
RETURN_ON_CUDA_ERROR(cuInit(0), "cuInit");
// Linking requires a device context.
// TODO(herhut): Figure out why context is required and what it is used for.
CUdevice device;
RETURN_ON_CUDA_ERROR(cuDeviceGet(&device, 0), "cuDeviceGet");
CUcontext context;
RETURN_ON_CUDA_ERROR(cuCtxCreate(&context, 0, device), "cuCtxCreate");
CUlinkState linkState;
RETURN_ON_CUDA_ERROR(cuLinkCreate(0, /* number of jit options */
nullptr, /* jit options */
nullptr, /* jit option values */
&linkState),
"cuLinkCreate");
RETURN_ON_CUDA_ERROR(
cuLinkAddData(linkState, CUjitInputType::CU_JIT_INPUT_PTX,
const_cast<void *>(static_cast<const void *>(ptx.c_str())),
ptx.length(), function.getName().c_str(), /* kernel name */
0, /* number of jit options */
nullptr, /* jit options */
nullptr /* jit option values */
),
"cuLinkAddData");
void *cubinData;
size_t cubinSize;
RETURN_ON_CUDA_ERROR(cuLinkComplete(linkState, &cubinData, &cubinSize),
"cuLinkComplete");
char *cubinAsChar = static_cast<char *>(cubinData);
OwnedCubin result = llvm::make_unique<std::vector<char>>(
cubinAsChar, cubinAsChar + cubinSize);
// This will also destroy the cubin data.
RETURN_ON_CUDA_ERROR(cuLinkDestroy(linkState), "cuLinkDestroy");
return result;
/// Placeholder cubin generator that needs no CUDA runtime.
///
/// Neither `ptx` nor `function` is read; the result is always a freshly
/// allocated byte vector holding the five characters "CUBIN".
OwnedCubin
GpuKernelToCubinPass::compilePtxToCubinForTesting(const std::string &ptx,
Function &function) {
const char data[] = "CUBIN";
// sizeof(data) includes the literal's terminating NUL; "- 1" leaves it out of
// the returned blob.
return llvm::make_unique<std::vector<char>>(data, data + sizeof(data) - 1);
}
llvm::Optional<OwnedCubin> convertModuleToCubin(llvm::Module &llvmModule,
Function &function) {
OwnedCubin GpuKernelToCubinPass::convertModuleToCubin(llvm::Module &llvmModule,
Function &function) {
std::unique_ptr<llvm::TargetMachine> targetMachine;
{
std::string error;
@ -147,10 +132,11 @@ llvm::Optional<OwnedCubin> convertModuleToCubin(llvm::Module &llvmModule,
auto ptx = translateModuleToPtx(llvmModule, *targetMachine);
return compilePtxToCubin(ptx, function);
return cubinGenerator(ptx, function);
}
LogicalResult translateGpuKernelToCubinAnnotation(Function &function) {
LogicalResult
GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(Function &function) {
Builder builder(function.getContext());
std::unique_ptr<Module> module(builder.createModule());
@ -165,8 +151,7 @@ LogicalResult translateGpuKernelToCubinAnnotation(Function &function) {
return function.emitError("Translation to CUDA binary failed.");
function.setAttr(kCubinAnnotation,
builder.getStringAttr(
{cubin.getValue()->data(), cubin.getValue()->size()}));
builder.getStringAttr({cubin->data(), cubin->size()}));
// Remove the body of the kernel function now that it has been translated.
// The main reason to do this is so that the resulting module no longer
@ -177,34 +162,11 @@ LogicalResult translateGpuKernelToCubinAnnotation(Function &function) {
return success();
}
} // anonymous namespace
/// A pass converting tagged kernel functions to cubin blobs.
class GpuKernelToCubinPass : public ModulePass<GpuKernelToCubinPass> {
public:
// Run the dialect converter on the module.
void runOnModule() override {
// Make sure the NVPTX target is initialized.
LLVMInitializeNVPTXTarget();
LLVMInitializeNVPTXTargetInfo();
LLVMInitializeNVPTXTargetMC();
LLVMInitializeNVPTXAsmPrinter();
for (auto &function : getModule()) {
if (!gpu::GPUDialect::isKernel(&function) || function.isExternal()) {
continue;
}
if (failed(translateGpuKernelToCubinAnnotation(function)))
signalPassFailure();
}
}
};
ModulePassBase *createConvertGPUKernelToCubinPass() {
return new GpuKernelToCubinPass();
/// Creates a GpuKernelToCubinPass that is constructed with `cubinGenerator`.
///
/// The pass is heap-allocated with `new` and returned as a raw pointer.
/// NOTE(review): presumably the caller (e.g. a pass manager) takes ownership of
/// the returned object; confirm against the call sites.
ModulePassBase *
mlir::createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator) {
return new GpuKernelToCubinPass(cubinGenerator);
}
static PassRegistration<GpuKernelToCubinPass>
pass("kernel-to-cubin", "Convert all kernel functions to CUDA cubin blobs");
} // namespace mlir
pass("test-kernel-to-cubin",
"Convert all kernel functions to CUDA cubin blobs");

View File

@ -0,0 +1,2 @@
if not config.run_cuda_tests:
config.unsupported = True

View File

@ -0,0 +1,25 @@
// RUN: mlir-opt %s --launch-func-to-cuda | FileCheck %s
func @cubin_getter() -> !llvm<"i8*">
func @kernel(!llvm.float, !llvm<"float*">)
attributes { gpu.kernel, nvvm.cubingetter = @cubin_getter }
func @foo() {
%0 = "op"() : () -> (!llvm.float)
%1 = "op"() : () -> (!llvm<"float*">)
%cst = constant 8 : index
// CHECK: %5 = llvm.alloca %4 x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
// CHECK: %6 = llvm.call @mcuModuleLoad(%5, %3) : (!llvm<"i8**">, !llvm<"i8*">) -> !llvm.i32
// CHECK: %32 = llvm.alloca %31 x !llvm<"i8*"> : (!llvm.i32) -> !llvm<"i8**">
// CHECK: %33 = llvm.call @mcuModuleGetFunction(%32, %7, %9) : (!llvm<"i8**">, !llvm<"i8*">, !llvm<"i8*">) -> !llvm.i32
// CHECK: %34 = llvm.call @mcuGetStreamHelper() : () -> !llvm<"i8*">
// CHECK: %48 = llvm.call @mcuLaunchKernel(%35, %c8, %c8, %c8, %c8, %c8, %c8, %2, %34, %38, %47) : (!llvm<"i8*">, index, index, index, index, index, index, !llvm.i32, !llvm<"i8*">, !llvm<"i8**">, !llvm<"i8**">) -> !llvm.i32
// CHECK: %49 = llvm.call @mcuStreamSynchronize(%34) : (!llvm<"i8*">) -> !llvm.i32
"gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1) { kernel = @kernel }
: (index, index, index, index, index, index, !llvm.float, !llvm<"float*">) -> ()
return
}

View File

@ -0,0 +1,8 @@
// RUN: mlir-opt %s --test-kernel-to-cubin | FileCheck %s
func @kernel(%arg0 : !llvm.float, %arg1 : !llvm<"float*">)
// CHECK: attributes {gpu.kernel, nvvm.cubin = "CUBIN"}
attributes { gpu.kernel } {
// CHECK-NOT: llvm.return
llvm.return
}

View File

@ -32,6 +32,7 @@ config.mlir_obj_root = "@MLIR_BINARY_DIR@"
config.mlir_tools_dir = "@MLIR_TOOLS_DIR@"
config.linalg_test_lib_dir = "@MLIR_LINALG_INTEGRATION_TEST_LIB_DIR@"
config.build_examples = @LLVM_BUILD_EXAMPLES@
config.run_cuda_tests = @MLIR_CUDA_CONVERSIONS_ENABLED@
# Support substitution of the tools_dir with user parameters. This is
# used when we can't determine the tool dir at configuration time.

View File

@ -40,6 +40,11 @@ set(LIBS
MLIRSupport
MLIRVectorOps
)
if(MLIR_CUDA_CONVERSIONS_ENABLED)
list(APPEND LIBS
MLIRGPUtoCUDATransforms
)
endif()
add_llvm_executable(mlir-opt
mlir-opt.cpp
)