[mlir] Split out Python bindings entry point into a separate file

This will allow the bindings to be built as a library and reused in out-of-tree
projects that want to provide bindings on top of MLIR bindings.

Reviewed By: stellaraccident, mikeurbach

Differential Revision: https://reviews.llvm.org/D101075
This commit is contained in:
Alex Zinenko 2021-04-22 17:32:10 +02:00
parent 54ee962e47
commit ac0a70f373
3 changed files with 147 additions and 129 deletions

View File

@ -84,6 +84,7 @@ add_mlir_python_extension(MLIRCoreBindingsPythonExtension _mlir
IRAffine.cpp
IRAttributes.cpp
IRCore.cpp
IRModule.cpp
IRTypes.cpp
PybindUtils.cpp
Pass.cpp

View File

@ -0,0 +1,146 @@
//===- IRModule.cpp - IR pybind module ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "IRModule.h"
#include "Globals.h"
#include "PybindUtils.h"
#include <vector>
namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;
// -----------------------------------------------------------------------------
// PyGlobals
// -----------------------------------------------------------------------------
PyGlobals *PyGlobals::instance = nullptr;
PyGlobals::PyGlobals() {
assert(!instance && "PyGlobals already constructed");
instance = this;
}
PyGlobals::~PyGlobals() { instance = nullptr; }
void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
py::gil_scoped_acquire();
if (loadedDialectModulesCache.contains(dialectNamespace))
return;
// Since re-entrancy is possible, make a copy of the search prefixes.
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
py::object loaded;
for (std::string moduleName : localSearchPrefixes) {
moduleName.push_back('.');
moduleName.append(dialectNamespace.data(), dialectNamespace.size());
try {
py::gil_scoped_release();
loaded = py::module::import(moduleName.c_str());
} catch (py::error_already_set &e) {
if (e.matches(PyExc_ModuleNotFoundError)) {
continue;
} else {
throw;
}
}
break;
}
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
// may have occurred, which may do anything.
loadedDialectModulesCache.insert(dialectNamespace);
}
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
py::object pyClass) {
py::gil_scoped_acquire();
py::object &found = dialectClassMap[dialectNamespace];
if (found) {
throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
dialectNamespace +
"' is already registered.");
}
found = std::move(pyClass);
}
void PyGlobals::registerOperationImpl(const std::string &operationName,
py::object pyClass,
py::object rawOpViewClass) {
py::gil_scoped_acquire();
py::object &found = operationClassMap[operationName];
if (found) {
throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
operationName +
"' is already registered.");
}
found = std::move(pyClass);
rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
}
llvm::Optional<py::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
py::gil_scoped_acquire();
loadDialectModule(dialectNamespace);
// Fast match against the class map first (common case).
const auto foundIt = dialectClassMap.find(dialectNamespace);
if (foundIt != dialectClassMap.end()) {
if (foundIt->second.is_none())
return llvm::None;
assert(foundIt->second && "py::object is defined");
return foundIt->second;
}
// Not found and loading did not yield a registration. Negative cache.
dialectClassMap[dialectNamespace] = py::none();
return llvm::None;
}
llvm::Optional<pybind11::object>
PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
{
py::gil_scoped_acquire();
auto foundIt = rawOpViewClassMapCache.find(operationName);
if (foundIt != rawOpViewClassMapCache.end()) {
if (foundIt->second.is_none())
return llvm::None;
assert(foundIt->second && "py::object is defined");
return foundIt->second;
}
}
// Not found. Load the dialect namespace.
auto split = operationName.split('.');
llvm::StringRef dialectNamespace = split.first;
loadDialectModule(dialectNamespace);
// Attempt to find from the canonical map and cache.
{
py::gil_scoped_acquire();
auto foundIt = rawOpViewClassMap.find(operationName);
if (foundIt != rawOpViewClassMap.end()) {
if (foundIt->second.is_none())
return llvm::None;
assert(foundIt->second && "py::object is defined");
// Positive cache.
rawOpViewClassMapCache[operationName] = foundIt->second;
return foundIt->second;
} else {
// Negative cache.
rawOpViewClassMap[operationName] = py::none();
return llvm::None;
}
}
}
void PyGlobals::clearImportCache() {
py::gil_scoped_acquire();
loadedDialectModulesCache.clear();
rawOpViewClassMapCache.clear();
}

View File

@ -20,135 +20,6 @@ namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;
// -----------------------------------------------------------------------------
// PyGlobals
// -----------------------------------------------------------------------------
PyGlobals *PyGlobals::instance = nullptr;
PyGlobals::PyGlobals() {
assert(!instance && "PyGlobals already constructed");
instance = this;
}
PyGlobals::~PyGlobals() { instance = nullptr; }
void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
py::gil_scoped_acquire();
if (loadedDialectModulesCache.contains(dialectNamespace))
return;
// Since re-entrancy is possible, make a copy of the search prefixes.
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
py::object loaded;
for (std::string moduleName : localSearchPrefixes) {
moduleName.push_back('.');
moduleName.append(dialectNamespace.data(), dialectNamespace.size());
try {
py::gil_scoped_release();
loaded = py::module::import(moduleName.c_str());
} catch (py::error_already_set &e) {
if (e.matches(PyExc_ModuleNotFoundError)) {
continue;
} else {
throw;
}
}
break;
}
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
// may have occurred, which may do anything.
loadedDialectModulesCache.insert(dialectNamespace);
}
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
py::object pyClass) {
py::gil_scoped_acquire();
py::object &found = dialectClassMap[dialectNamespace];
if (found) {
throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
dialectNamespace +
"' is already registered.");
}
found = std::move(pyClass);
}
void PyGlobals::registerOperationImpl(const std::string &operationName,
py::object pyClass,
py::object rawOpViewClass) {
py::gil_scoped_acquire();
py::object &found = operationClassMap[operationName];
if (found) {
throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
operationName +
"' is already registered.");
}
found = std::move(pyClass);
rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
}
llvm::Optional<py::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
py::gil_scoped_acquire();
loadDialectModule(dialectNamespace);
// Fast match against the class map first (common case).
const auto foundIt = dialectClassMap.find(dialectNamespace);
if (foundIt != dialectClassMap.end()) {
if (foundIt->second.is_none())
return llvm::None;
assert(foundIt->second && "py::object is defined");
return foundIt->second;
}
// Not found and loading did not yield a registration. Negative cache.
dialectClassMap[dialectNamespace] = py::none();
return llvm::None;
}
llvm::Optional<pybind11::object>
PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
{
py::gil_scoped_acquire();
auto foundIt = rawOpViewClassMapCache.find(operationName);
if (foundIt != rawOpViewClassMapCache.end()) {
if (foundIt->second.is_none())
return llvm::None;
assert(foundIt->second && "py::object is defined");
return foundIt->second;
}
}
// Not found. Load the dialect namespace.
auto split = operationName.split('.');
llvm::StringRef dialectNamespace = split.first;
loadDialectModule(dialectNamespace);
// Attempt to find from the canonical map and cache.
{
py::gil_scoped_acquire();
auto foundIt = rawOpViewClassMap.find(operationName);
if (foundIt != rawOpViewClassMap.end()) {
if (foundIt->second.is_none())
return llvm::None;
assert(foundIt->second && "py::object is defined");
// Positive cache.
rawOpViewClassMapCache[operationName] = foundIt->second;
return foundIt->second;
} else {
// Negative cache.
rawOpViewClassMap[operationName] = py::none();
return llvm::None;
}
}
}
void PyGlobals::clearImportCache() {
py::gil_scoped_acquire();
loadedDialectModulesCache.clear();
rawOpViewClassMapCache.clear();
}
// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------