Cleanup linalg integration test

This CL performs post-commit cleanups.
    It adds the ability to specify which shared libraries to load dynamically in ExecutionEngine. The linalg integration test is updated to use a shared library.
    Additional minor cleanups related to LLVM lowering of Linalg are also included.

--

PiperOrigin-RevId: 248346589
This commit is contained in:
Nicolas Vasilache 2019-05-15 09:26:27 -07:00 committed by Mehdi Amini
parent 8d5bd823b0
commit 6aa5cc8b06
15 changed files with 99 additions and 83 deletions

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
# RUN: $(dirname %s)/test_edsc %s | FileCheck %s
# RUN: %p/test_edsc %s | FileCheck %s
"""Python2 and 3 test for the MLIR EDSC Python bindings"""
import google_mlir.bindings.python.pybind as E

View File

@ -32,7 +32,7 @@
namespace llvm {
template <typename T> class Expected;
class Module;
}
} // namespace llvm
namespace mlir {
@ -61,16 +61,21 @@ public:
/// runs it on the MLIR module. If `transformer` is
/// provided, it will be called on the LLVM module during JIT-compilation and
/// can be used, e.g., for reporting or optimization.
/// If `sharedLibPaths` are provided, the underlying JIT-compilation will open
/// and link the shared libraries for symbol resolution.
static llvm::Expected<std::unique_ptr<ExecutionEngine>>
create(Module *m, PassManager *pm,
std::function<llvm::Error(llvm::Module *)> transformer = {});
std::function<llvm::Error(llvm::Module *)> transformer = {},
ArrayRef<StringRef> sharedLibPaths = {});
/// Creates an execution engine for the given module. If `transformer` is
/// provided, it will be called on the LLVM module during JIT-compilation and
/// can be used, e.g., for reporting or optimization.
/// If `sharedLibPaths` are provided, the underlying JIT-compilation will open
/// and link the shared libraries for symbol resolution.
static llvm::Expected<std::unique_ptr<ExecutionEngine>>
create(Module *m,
std::function<llvm::Error(llvm::Module *)> transformer = {});
create(Module *m, std::function<llvm::Error(llvm::Module *)> transformer = {},
ArrayRef<StringRef> sharedLibPaths = {});
/// Looks up a packed-argument function with the given name and returns a
/// pointer to it. Propagates errors in case of failure.

View File

@ -28,6 +28,7 @@
namespace llvm {
class IntegerType;
class LLVMContext;
class Module;
class Type;
}
@ -51,6 +52,9 @@ public:
/// to each of the MLIR types converted with `convertType`.
Type packFunctionResults(ArrayRef<Type> types);
/// Returns the LLVM context.
llvm::LLVMContext &getLLVMContext();
protected:
/// Create a set of converters that live in the pass object by passing them a
/// reference to the LLVM IR dialect. Store the module associated with the

View File

@ -2,3 +2,7 @@ set(LLVM_TARGET_DEFINITIONS LinalgOps.td)
mlir_tablegen(LinalgOps.h.inc -gen-op-decls)
mlir_tablegen(LinalgOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRLinalgOpsIncGen)
set(LLVM_TARGET_DEFINITIONS LinalgLibraryOps.td)
mlir_tablegen(LinalgLibraryOps.h.inc -gen-op-decls)
mlir_tablegen(LinalgLibraryOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRLinalgLibraryOpsIncGen)

View File

@ -131,7 +131,8 @@ public:
// Setup the object layer to use our custom memory manager in order to
// resolve calls to library functions present in the process.
OrcJIT(llvm::orc::JITTargetMachineBuilder machineBuilder,
llvm::DataLayout layout, IRTransformer transform)
llvm::DataLayout layout, IRTransformer transform,
ArrayRef<StringRef> sharedLibPaths)
: irTransformer(transform),
objectLayer(
session,
@ -144,11 +145,12 @@ public:
threadSafeCtx(llvm::make_unique<llvm::LLVMContext>()) {
session.getMainJITDylib().setGenerator(
SearchGenerator(layout.getGlobalPrefix()));
loadLibraries(sharedLibPaths);
}
// Create a JIT engine for the current host.
static Expected<std::unique_ptr<OrcJIT>>
createDefault(IRTransformer transformer) {
createDefault(IRTransformer transformer, ArrayRef<StringRef> sharedLibPaths) {
auto machineBuilder = llvm::orc::JITTargetMachineBuilder::detectHost();
if (!machineBuilder)
return machineBuilder.takeError();
@ -158,7 +160,8 @@ public:
return dataLayout.takeError();
return llvm::make_unique<OrcJIT>(std::move(*machineBuilder),
std::move(*dataLayout), transformer);
std::move(*dataLayout), transformer,
sharedLibPaths);
}
// Add an LLVM module to the main library managed by the JIT engine.
@ -190,6 +193,10 @@ private:
};
}
// Iterate over shareLibPaths and load the corresponding libraries for symbol
// resolution.
void loadLibraries(ArrayRef<StringRef> sharedLibPaths);
IRTransformer irTransformer;
llvm::orc::ExecutionSession session;
llvm::orc::RTDyldObjectLinkingLayer objectLayer;
@ -202,6 +209,29 @@ private:
} // end namespace impl
} // namespace mlir
void mlir::impl::OrcJIT::loadLibraries(ArrayRef<StringRef> sharedLibPaths) {
for (auto libPath : sharedLibPaths) {
auto mb = llvm::MemoryBuffer::getFile(libPath);
if (!mb) {
llvm::errs() << "Could not create MemoryBuffer for: " << libPath << " "
<< mb.getError().message() << "\n";
continue;
}
auto &JD = session.createJITDylib(libPath);
auto loaded = llvm::orc::DynamicLibrarySearchGenerator::Load(
libPath.data(), dataLayout.getGlobalPrefix());
if (!loaded) {
llvm::errs() << "Could not load: " << libPath << " " << loaded.takeError()
<< "\n";
continue;
}
JD.setGenerator(loaded.get());
auto res = objectLayer.add(JD, std::move(mb.get()));
if (res)
llvm::errs() << "Could not add: " << libPath << " " << res << "\n";
}
}
// Wrap a string into an llvm::StringError.
static inline Error make_string_error(const llvm::Twine &message) {
return llvm::make_error<llvm::StringError>(message.str(),
@ -318,11 +348,12 @@ void packFunctionArguments(llvm::Module *module) {
// Out of line for PIMPL unique_ptr.
ExecutionEngine::~ExecutionEngine() = default;
Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
Module *m, PassManager *pm,
std::function<llvm::Error(llvm::Module *)> transformer) {
Expected<std::unique_ptr<ExecutionEngine>>
ExecutionEngine::create(Module *m, PassManager *pm,
std::function<llvm::Error(llvm::Module *)> transformer,
ArrayRef<StringRef> sharedLibPaths) {
auto engine = llvm::make_unique<ExecutionEngine>();
auto expectedJIT = impl::OrcJIT::createDefault(transformer);
auto expectedJIT = impl::OrcJIT::createDefault(transformer, sharedLibPaths);
if (!expectedJIT)
return expectedJIT.takeError();
@ -345,12 +376,14 @@ Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
return std::move(engine);
}
Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
Module *m, std::function<llvm::Error(llvm::Module *)> transformer) {
Expected<std::unique_ptr<ExecutionEngine>>
ExecutionEngine::create(Module *m,
std::function<llvm::Error(llvm::Module *)> transformer,
ArrayRef<StringRef> sharedLibPaths) {
// Construct and run the default MLIR pipeline.
PassManager manager;
getDefaultPasses(manager, {});
return create(m, &manager, transformer);
return create(m, &manager, transformer, sharedLibPaths);
}
Expected<void (*)(void **)> ExecutionEngine::lookup(StringRef name) const {

View File

@ -40,6 +40,11 @@
using namespace mlir;
// Get the LLVM context.
llvm::LLVMContext &LLVMLowering::getLLVMContext() {
return module->getContext();
}
// Wrap the given LLVM IR type into an LLVM IR dialect type.
Type LLVMLowering::wrap(llvm::Type *llvmType) {
return LLVM::LLVMType::get(llvmDialect->getContext(), llvmType);

View File

@ -9,4 +9,4 @@ add_llvm_library(MLIRLinalg
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Linalg
)
add_dependencies(MLIRLinalg MLIRLinalgOpsIncGen)
add_dependencies(MLIRLinalg MLIRLinalgOpsIncGen MLIRLinalgLibraryOpsIncGen)

View File

@ -166,18 +166,15 @@ public:
auto indexType = IndexType::get(op->getContext());
auto voidPtrTy = LLVM::LLVMType::get(
op->getContext(),
lowering.convertType(IntegerType::get(8, op->getContext()))
.cast<LLVM::LLVMType>()
.getUnderlyingType()
->getPointerTo());
llvm::IntegerType::get(lowering.getLLVMContext(), 8)->getPointerTo());
auto int64Ty = lowering.convertType(operands[0]->getType());
// Insert the `malloc` declaration if it is not already present.
Function *mallocFunc =
op->getFunction()->getModule()->getNamedFunction("malloc");
auto *module = op->getFunction()->getModule();
Function *mallocFunc = module->getNamedFunction("malloc");
if (!mallocFunc) {
auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType);
op->getFunction()->getModule()->getFunctions().push_back(mallocFunc);
module->getFunctions().push_back(mallocFunc);
}
// Get MLIR types for injecting element pointer.
@ -225,17 +222,14 @@ public:
FuncBuilder &rewriter) const override {
auto voidPtrTy = LLVM::LLVMType::get(
op->getContext(),
lowering.convertType(IntegerType::get(8, op->getContext()))
.cast<LLVM::LLVMType>()
.getUnderlyingType()
->getPointerTo());
llvm::IntegerType::get(lowering.getLLVMContext(), 8)->getPointerTo());
// Insert the `free` declaration if it is not already present.
Function *freeFunc =
op->getFunction()->getModule()->getNamedFunction("free");
auto *module = op->getFunction()->getModule();
Function *freeFunc = module->getNamedFunction("free");
if (!freeFunc) {
auto freeType = rewriter.getFunctionType(voidPtrTy, {});
freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType);
op->getFunction()->getModule()->getFunctions().push_back(freeFunc);
module->getFunctions().push_back(freeFunc);
}
// Get MLIR types for extracting element pointer.

View File

@ -4,6 +4,10 @@ llvm_canonicalize_cmake_booleans(
LLVM_BUILD_EXAMPLES
)
# Passed to lit.site.cfg.py.in to set up the path where to find the libraries
# for linalg integration tests.
set(MLIR_LINALG_INTEGRATION_TEST_LIB_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
configure_lit_site_cfg(
${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py
@ -20,14 +24,13 @@ configure_lit_site_cfg(
set(MLIR_TEST_DEPENDS
FileCheck count not
MLIRUnitTests
mlir-blas-cpu-runner
mlir-cpu-runner
mlir-opt
mlir-tblgen
mlir-translate
sdot
)
if(LLVM_BUILD_EXAMPLES)
list(APPEND MLIR_TEST_DEPENDS
linalg1-opt

View File

@ -62,6 +62,7 @@ tools.extend([
ToolSubst('toy-ch3', unresolved='ignore'),
ToolSubst('toy-ch4', unresolved='ignore'),
ToolSubst('toy-ch5', unresolved='ignore'),
ToolSubst('%linalg_test_lib_dir', config.linalg_test_lib_dir, unresolved='ignore'),
])
llvm_config.add_tool_substitutions(tools, tool_dirs)

View File

@ -30,6 +30,7 @@ config.host_arch = "@HOST_ARCH@"
config.mlir_src_root = "@MLIR_SOURCE_DIR@"
config.mlir_obj_root = "@MLIR_BINARY_DIR@"
config.mlir_tools_dir = "@MLIR_TOOLS_DIR@"
config.linalg_test_lib_dir = "@MLIR_LINALG_INTEGRATION_TEST_LIB_DIR@"
config.build_examples = @LLVM_BUILD_EXAMPLES@
# Support substitution of the tools_dir with user parameters. This is

View File

@ -1,27 +1,2 @@
set(LIBS
MLIRAffineOps
MLIRAnalysis
MLIREDSC
MLIRExecutionEngine
MLIRIR
MLIRLLVMIR
MLIRParser
MLIRTargetLLVMIR
MLIRTransforms
MLIRSupport
MLIRCPURunnerLib
LLVMCore
LLVMSupport
)
add_executable(mlir-blas-cpu-runner
mlir-blas-cpu-runner.cpp
)
llvm_update_compile_flags(mlir-blas-cpu-runner)
whole_archive_link(mlir-blas-cpu-runner
MLIRLLVMIR
MLIRStandardOps
MLIRTargetLLVMIR
MLIRTransforms
MLIRTranslation
)
target_link_libraries(mlir-blas-cpu-runner MLIRIR ${LIBS})
add_llvm_library(sdot SHARED sdot.cpp)

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | mlir-blas-cpu-runner -e entry1 -entry-point-result=f32 | FileCheck %s
// RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e entry1 -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libsdot.so | FileCheck %s
func @cblas_sdot(!llvm.i64, !llvm<"float*">, !llvm.i64, !llvm<"float*">, !llvm.i64) -> !llvm.float

View File

@ -1,4 +1,4 @@
//===- mlir-blas-cpu-runner.cpp - MLIR CPU Execution Driver + Blas Support ===//
//===- sdot.cpp - Simple sdot Blas Function -------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
@ -15,16 +15,10 @@
// limitations under the License.
// =============================================================================
//
// Main entry point.
// Sdot implementation.
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/DynamicLibrary.h"
#ifdef WITH_LAPACK
#include "lapack/cblas.h"
#else
extern "C" float cblas_sdot(const int N, const float *X, const int incX,
const float *Y, const int incY) {
float res = 0.0f;
@ -32,16 +26,3 @@ extern "C" float cblas_sdot(const int N, const float *X, const int incX,
res += X[i * incX] * Y[i * incY];
return res;
}
#endif
extern int run(int argc, char **argv);
void addSymbols() {
using llvm::sys::DynamicLibrary;
DynamicLibrary::AddSymbol("cblas_sdot", (void *)(&cblas_sdot));
}
int main(int argc, char **argv) {
addSymbols();
return run(argc, argv);
}

View File

@ -79,6 +79,12 @@ static llvm::cl::opt<bool> optO2("O2", llvm::cl::desc("Run opt O2 passes"),
static llvm::cl::opt<bool> optO3("O3", llvm::cl::desc("Run opt O3 passes"),
llvm::cl::cat(optFlags));
static llvm::cl::OptionCategory clOptionsCategory("linking options");
static llvm::cl::list<std::string>
clSharedLibs("shared-libs", llvm::cl::desc("Libraries to link dynamically"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
llvm::cl::cat(clOptionsCategory));
static std::unique_ptr<Module> parseMLIRInput(StringRef inputFilename,
MLIRContext *context) {
// Set up the input file.
@ -156,7 +162,9 @@ static Error compileAndExecuteFunctionWithMemRefs(
if (!expectedArguments)
return expectedArguments.takeError();
auto expectedEngine = mlir::ExecutionEngine::create(module, transformer);
SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end());
auto expectedEngine =
mlir::ExecutionEngine::create(module, transformer, libs);
if (!expectedEngine)
return expectedEngine.takeError();
@ -193,7 +201,9 @@ static Error compileAndExecuteSingleFloatReturnFunction(
if (llvmTy != llvmTy->getFloatTy(llvmTy->getContext()))
return make_string_error("only single llvm.f32 function result supported");
auto expectedEngine = mlir::ExecutionEngine::create(module, transformer);
SmallVector<StringRef, 4> libs(clSharedLibs.begin(), clSharedLibs.end());
auto expectedEngine =
mlir::ExecutionEngine::create(module, transformer, libs);
if (!expectedEngine)
return expectedEngine.takeError();