mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-05-14 09:56:33 +00:00

* Extends Context/Operation interning to cover Module as well. * Implements Module.context, Attribute.context, Type.context, and Location.context back-references (facilitated testing and also on the TODO list). * Adds method to create an empty Module. * Discovered missing in npcomp. Differential Revision: https://reviews.llvm.org/D89294
1864 lines
66 KiB
C++
1864 lines
66 KiB
C++
//===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "IRModules.h"
|
|
#include "PybindUtils.h"
|
|
|
|
#include "mlir-c/Bindings/Python/Interop.h"
|
|
#include "mlir-c/Registration.h"
|
|
#include "mlir-c/StandardAttributes.h"
|
|
#include "mlir-c/StandardTypes.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include <pybind11/stl.h>
|
|
|
|
namespace py = pybind11;
|
|
using namespace mlir;
|
|
using namespace mlir::python;
|
|
|
|
using llvm::SmallVector;
|
|
|
|
//------------------------------------------------------------------------------
|
|
// Docstrings (trivial, non-duplicated docstrings are included inline).
|
|
//------------------------------------------------------------------------------
|
|
|
|
// Shared docstring constants for the Python bindings. Within the section
// visible here only kAppendBlockDocstring is referenced (by
// PyBlockList::bind); the remaining constants are presumably attached to
// bindings defined later in the file -- TODO confirm against their users.
static const char kContextCreateOperationDocstring[] =
|
|
R"(Creates a new operation.
|
|
|
|
Args:
|
|
name: Operation name (e.g. "dialect.operation").
|
|
location: A Location object.
|
|
results: Sequence of Type representing op result types.
|
|
attributes: Dict of str:Attribute.
|
|
successors: List of Block for the operation's successors.
|
|
regions: Number of regions to create.
|
|
|
|
Returns:
|
|
A new "detached" Operation object. Detached operations can be added
|
|
to blocks, which causes them to become "attached."
|
|
)";
|
|
|
|
static const char kContextParseDocstring[] =
|
|
R"(Parses a module's assembly format from a string.
|
|
|
|
Returns a new MlirModule or raises a ValueError if the parsing fails.
|
|
|
|
See also: https://mlir.llvm.org/docs/LangRef/
|
|
)";
|
|
|
|
static const char kContextParseTypeDocstring[] =
|
|
R"(Parses the assembly form of a type.
|
|
|
|
Returns a Type object or raises a ValueError if the type cannot be parsed.
|
|
|
|
See also: https://mlir.llvm.org/docs/LangRef/#type-system
|
|
)";
|
|
|
|
static const char kContextGetUnknownLocationDocstring[] =
|
|
R"(Gets a Location representing an unknown location)";
|
|
|
|
static const char kContextGetFileLocationDocstring[] =
|
|
R"(Gets a Location representing a file, line and column)";
|
|
|
|
static const char kOperationStrDunderDocstring[] =
|
|
R"(Prints the assembly form of the operation with default options.
|
|
|
|
If more advanced control over the assembly formatting or I/O options is needed,
|
|
use the dedicated print method, which supports keyword arguments to customize
|
|
behavior.
|
|
)";
|
|
|
|
static const char kTypeStrDunderDocstring[] =
|
|
R"(Prints the assembly form of the type.)";
|
|
|
|
static const char kDumpDocstring[] =
|
|
R"(Dumps a debug representation of the object to stderr.)";
|
|
|
|
static const char kAppendBlockDocstring[] =
|
|
R"(Appends a new block, with argument types as positional args.
|
|
|
|
Returns:
|
|
The created block.
|
|
)";
|
|
|
|
//------------------------------------------------------------------------------
|
|
// Conversion utilities.
|
|
//------------------------------------------------------------------------------
|
|
|
|
namespace {
|
|
|
|
/// Accumulates into a python string from a method that accepts an
|
|
/// MlirStringCallback.
|
|
struct PyPrintAccumulator {
|
|
py::list parts;
|
|
|
|
void *getUserData() { return this; }
|
|
|
|
MlirStringCallback getCallback() {
|
|
return [](const char *part, intptr_t size, void *userData) {
|
|
PyPrintAccumulator *printAccum =
|
|
static_cast<PyPrintAccumulator *>(userData);
|
|
py::str pyPart(part, size); // Decodes as UTF-8 by default.
|
|
printAccum->parts.append(std::move(pyPart));
|
|
};
|
|
}
|
|
|
|
py::str join() {
|
|
py::str delim("", 0);
|
|
return delim.attr("join")(parts);
|
|
}
|
|
};
|
|
|
|
/// Accumulates into a python string from a method that is expected to make
|
|
/// one (no more, no less) call to the callback (asserts internally on
|
|
/// violation).
|
|
struct PySinglePartStringAccumulator {
|
|
void *getUserData() { return this; }
|
|
|
|
MlirStringCallback getCallback() {
|
|
return [](const char *part, intptr_t size, void *userData) {
|
|
PySinglePartStringAccumulator *accum =
|
|
static_cast<PySinglePartStringAccumulator *>(userData);
|
|
assert(!accum->invoked &&
|
|
"PySinglePartStringAccumulator called back multiple times");
|
|
accum->invoked = true;
|
|
accum->value = py::str(part, size);
|
|
};
|
|
}
|
|
|
|
py::str takeValue() {
|
|
assert(invoked && "PySinglePartStringAccumulator not called back");
|
|
return std::move(value);
|
|
}
|
|
|
|
private:
|
|
py::str value;
|
|
bool invoked = false;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
//------------------------------------------------------------------------------
|
|
// Type-checking utilities.
|
|
//------------------------------------------------------------------------------
|
|
|
|
namespace {
|
|
|
|
/// Checks whether the given type is an integer or float type.
|
|
int mlirTypeIsAIntegerOrFloat(MlirType type) {
|
|
return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
|
|
mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
|
|
}
|
|
|
|
} // namespace
|
|
|
|
//------------------------------------------------------------------------------
|
|
// Collections.
|
|
//------------------------------------------------------------------------------
|
|
|
|
namespace {
|
|
|
|
class PyRegionIterator {
|
|
public:
|
|
PyRegionIterator(PyOperationRef operation)
|
|
: operation(std::move(operation)) {}
|
|
|
|
PyRegionIterator &dunderIter() { return *this; }
|
|
|
|
PyRegion dunderNext() {
|
|
operation->checkValid();
|
|
if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
|
|
throw py::stop_iteration();
|
|
}
|
|
MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
|
|
return PyRegion(operation, region);
|
|
}
|
|
|
|
static void bind(py::module &m) {
|
|
py::class_<PyRegionIterator>(m, "RegionIterator")
|
|
.def("__iter__", &PyRegionIterator::dunderIter)
|
|
.def("__next__", &PyRegionIterator::dunderNext);
|
|
}
|
|
|
|
private:
|
|
PyOperationRef operation;
|
|
int nextIndex = 0;
|
|
};
|
|
|
|
/// Regions of an op are fixed length and indexed numerically so are represented
|
|
/// with a sequence-like container.
|
|
class PyRegionList {
|
|
public:
|
|
PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
|
|
|
|
intptr_t dunderLen() {
|
|
operation->checkValid();
|
|
return mlirOperationGetNumRegions(operation->get());
|
|
}
|
|
|
|
PyRegion dunderGetItem(intptr_t index) {
|
|
// dunderLen checks validity.
|
|
if (index < 0 || index >= dunderLen()) {
|
|
throw SetPyError(PyExc_IndexError,
|
|
"attempt to access out of bounds region");
|
|
}
|
|
MlirRegion region = mlirOperationGetRegion(operation->get(), index);
|
|
return PyRegion(operation, region);
|
|
}
|
|
|
|
static void bind(py::module &m) {
|
|
py::class_<PyRegionList>(m, "ReqionSequence")
|
|
.def("__len__", &PyRegionList::dunderLen)
|
|
.def("__getitem__", &PyRegionList::dunderGetItem);
|
|
}
|
|
|
|
private:
|
|
PyOperationRef operation;
|
|
};
|
|
|
|
class PyBlockIterator {
|
|
public:
|
|
PyBlockIterator(PyOperationRef operation, MlirBlock next)
|
|
: operation(std::move(operation)), next(next) {}
|
|
|
|
PyBlockIterator &dunderIter() { return *this; }
|
|
|
|
PyBlock dunderNext() {
|
|
operation->checkValid();
|
|
if (mlirBlockIsNull(next)) {
|
|
throw py::stop_iteration();
|
|
}
|
|
|
|
PyBlock returnBlock(operation, next);
|
|
next = mlirBlockGetNextInRegion(next);
|
|
return returnBlock;
|
|
}
|
|
|
|
static void bind(py::module &m) {
|
|
py::class_<PyBlockIterator>(m, "BlockIterator")
|
|
.def("__iter__", &PyBlockIterator::dunderIter)
|
|
.def("__next__", &PyBlockIterator::dunderNext);
|
|
}
|
|
|
|
private:
|
|
PyOperationRef operation;
|
|
MlirBlock next;
|
|
};
|
|
|
|
/// Blocks are exposed by the C-API as a forward-only linked list. In Python,
|
|
/// we present them as a more full-featured list-like container but optimzie
|
|
/// it for forward iteration. Blocks are always owned by a region.
|
|
class PyBlockList {
|
|
public:
|
|
PyBlockList(PyOperationRef operation, MlirRegion region)
|
|
: operation(std::move(operation)), region(region) {}
|
|
|
|
PyBlockIterator dunderIter() {
|
|
operation->checkValid();
|
|
return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
|
|
}
|
|
|
|
intptr_t dunderLen() {
|
|
operation->checkValid();
|
|
intptr_t count = 0;
|
|
MlirBlock block = mlirRegionGetFirstBlock(region);
|
|
while (!mlirBlockIsNull(block)) {
|
|
count += 1;
|
|
block = mlirBlockGetNextInRegion(block);
|
|
}
|
|
return count;
|
|
}
|
|
|
|
PyBlock dunderGetItem(intptr_t index) {
|
|
operation->checkValid();
|
|
if (index < 0) {
|
|
throw SetPyError(PyExc_IndexError,
|
|
"attempt to access out of bounds block");
|
|
}
|
|
MlirBlock block = mlirRegionGetFirstBlock(region);
|
|
while (!mlirBlockIsNull(block)) {
|
|
if (index == 0) {
|
|
return PyBlock(operation, block);
|
|
}
|
|
block = mlirBlockGetNextInRegion(block);
|
|
index -= 1;
|
|
}
|
|
throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
|
|
}
|
|
|
|
PyBlock appendBlock(py::args pyArgTypes) {
|
|
operation->checkValid();
|
|
llvm::SmallVector<MlirType, 4> argTypes;
|
|
argTypes.reserve(pyArgTypes.size());
|
|
for (auto &pyArg : pyArgTypes) {
|
|
argTypes.push_back(pyArg.cast<PyType &>().type);
|
|
}
|
|
|
|
MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
|
|
mlirRegionAppendOwnedBlock(region, block);
|
|
return PyBlock(operation, block);
|
|
}
|
|
|
|
static void bind(py::module &m) {
|
|
py::class_<PyBlockList>(m, "BlockList")
|
|
.def("__getitem__", &PyBlockList::dunderGetItem)
|
|
.def("__iter__", &PyBlockList::dunderIter)
|
|
.def("__len__", &PyBlockList::dunderLen)
|
|
.def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
|
|
}
|
|
|
|
private:
|
|
PyOperationRef operation;
|
|
MlirRegion region;
|
|
};
|
|
|
|
class PyOperationIterator {
|
|
public:
|
|
PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
|
|
: parentOperation(std::move(parentOperation)), next(next) {}
|
|
|
|
PyOperationIterator &dunderIter() { return *this; }
|
|
|
|
py::object dunderNext() {
|
|
parentOperation->checkValid();
|
|
if (mlirOperationIsNull(next)) {
|
|
throw py::stop_iteration();
|
|
}
|
|
|
|
PyOperationRef returnOperation =
|
|
PyOperation::forOperation(parentOperation->getContext(), next);
|
|
next = mlirOperationGetNextInBlock(next);
|
|
return returnOperation.releaseObject();
|
|
}
|
|
|
|
static void bind(py::module &m) {
|
|
py::class_<PyOperationIterator>(m, "OperationIterator")
|
|
.def("__iter__", &PyOperationIterator::dunderIter)
|
|
.def("__next__", &PyOperationIterator::dunderNext);
|
|
}
|
|
|
|
private:
|
|
PyOperationRef parentOperation;
|
|
MlirOperation next;
|
|
};
|
|
|
|
/// Operations are exposed by the C-API as a forward-only linked list. In
|
|
/// Python, we present them as a more full-featured list-like container but
|
|
/// optimzie it for forward iteration. Iterable operations are always owned
|
|
/// by a block.
|
|
class PyOperationList {
|
|
public:
|
|
PyOperationList(PyOperationRef parentOperation, MlirBlock block)
|
|
: parentOperation(std::move(parentOperation)), block(block) {}
|
|
|
|
PyOperationIterator dunderIter() {
|
|
parentOperation->checkValid();
|
|
return PyOperationIterator(parentOperation,
|
|
mlirBlockGetFirstOperation(block));
|
|
}
|
|
|
|
intptr_t dunderLen() {
|
|
parentOperation->checkValid();
|
|
intptr_t count = 0;
|
|
MlirOperation childOp = mlirBlockGetFirstOperation(block);
|
|
while (!mlirOperationIsNull(childOp)) {
|
|
count += 1;
|
|
childOp = mlirOperationGetNextInBlock(childOp);
|
|
}
|
|
return count;
|
|
}
|
|
|
|
py::object dunderGetItem(intptr_t index) {
|
|
parentOperation->checkValid();
|
|
if (index < 0) {
|
|
throw SetPyError(PyExc_IndexError,
|
|
"attempt to access out of bounds operation");
|
|
}
|
|
MlirOperation childOp = mlirBlockGetFirstOperation(block);
|
|
while (!mlirOperationIsNull(childOp)) {
|
|
if (index == 0) {
|
|
return PyOperation::forOperation(parentOperation->getContext(), childOp)
|
|
.releaseObject();
|
|
}
|
|
childOp = mlirOperationGetNextInBlock(childOp);
|
|
index -= 1;
|
|
}
|
|
throw SetPyError(PyExc_IndexError,
|
|
"attempt to access out of bounds operation");
|
|
}
|
|
|
|
void insert(int index, PyOperation &newOperation) {
|
|
parentOperation->checkValid();
|
|
newOperation.checkValid();
|
|
if (index < 0) {
|
|
throw SetPyError(
|
|
PyExc_IndexError,
|
|
"only positive insertion indices are supported for operations");
|
|
}
|
|
if (newOperation.isAttached()) {
|
|
throw SetPyError(
|
|
PyExc_ValueError,
|
|
"attempt to insert an operation that has already been inserted");
|
|
}
|
|
// TODO: Needing to do this check is unfortunate, especially since it will
|
|
// be a forward-scan, just like the following call to
|
|
// mlirBlockInsertOwnedOperation. Switch to insert before/after once
|
|
// D88148 lands.
|
|
if (index > dunderLen()) {
|
|
throw SetPyError(PyExc_IndexError,
|
|
"attempt to insert operation past end");
|
|
}
|
|
mlirBlockInsertOwnedOperation(block, index, newOperation.get());
|
|
newOperation.setAttached();
|
|
// TODO: Rework the parentKeepAlive so as to avoid ownership hazards under
|
|
// the new ownership.
|
|
}
|
|
|
|
static void bind(py::module &m) {
|
|
py::class_<PyOperationList>(m, "OperationList")
|
|
.def("__getitem__", &PyOperationList::dunderGetItem)
|
|
.def("__iter__", &PyOperationList::dunderIter)
|
|
.def("__len__", &PyOperationList::dunderLen)
|
|
.def("insert", &PyOperationList::insert, py::arg("index"),
|
|
py::arg("operation"),
|
|
"Inserts an operation at an indexed position");
|
|
}
|
|
|
|
private:
|
|
PyOperationRef parentOperation;
|
|
MlirBlock block;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
//------------------------------------------------------------------------------
|
|
// PyMlirContext
|
|
//------------------------------------------------------------------------------
|
|
|
|
/// Wraps `context` and records this wrapper in the live-context map, keyed by
/// the raw context pointer. The GIL is held while the map is modified.
PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
  py::gil_scoped_acquire acquire;
  getLiveContexts()[context.ptr] = this;
}
|
|
|
|
/// Removes this wrapper from the live-context map and destroys the wrapped
/// MlirContext. The GIL is held for the duration.
PyMlirContext::~PyMlirContext() {
  // Note that the only public way to construct an instance is via the
  // forContext method, which always puts the associated handle into
  // liveContexts.
  py::gil_scoped_acquire acquire;
  auto &live = getLiveContexts();
  live.erase(context.ptr);
  mlirContextDestroy(context);
}
|
|
|
|
/// Returns the capsule produced by mlirPythonContextToCapsule for the wrapped
/// context; the returned py::object takes over the reference it is given.
py::object PyMlirContext::getCapsule() {
  PyObject *rawCapsule = mlirPythonContextToCapsule(get());
  return py::reinterpret_steal<py::object>(rawCapsule);
}
|
|
|
|
/// Extracts an MlirContext from `capsule` and returns the Python object of its
/// wrapper (looked up or created by forContext). If the extracted context is
/// null, the already-set Python error is propagated.
py::object PyMlirContext::createFromCapsule(py::object capsule) {
  MlirContext extracted = mlirPythonCapsuleToContext(capsule.ptr());
  if (!mlirContextIsNull(extracted))
    return forContext(extracted).releaseObject();
  throw py::error_already_set();
}
|
|
|
|
/// Creates a fresh MlirContext, registers all dialects on it and returns a
/// newly heap-allocated wrapper. The caller receives the raw pointer.
PyMlirContext *PyMlirContext::createNewContextForInit() {
  MlirContext freshContext = mlirContextCreate();
  mlirRegisterAllDialects(freshContext);
  return new PyMlirContext(freshContext);
}
|
|
|
|
/// Returns a reference to the wrapper registered for `context` in the
/// live-context map, creating and registering a new wrapper when none exists.
/// The GIL is held for the duration.
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
  py::gil_scoped_acquire acquire;
  auto &liveContexts = getLiveContexts();
  auto found = liveContexts.find(context.ptr);
  if (found != liveContexts.end()) {
    // Use existing.
    PyMlirContext *existing = found->second;
    py::object existingRef = py::cast(existing);
    return PyMlirContextRef(existing, std::move(existingRef));
  }

  // Create.
  // NOTE(review): unlike PyModule::forModule, this cast does not pass
  // return_value_policy::take_ownership -- confirm that wrappers created on
  // this path are intentionally not owned by the Python object.
  PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
  py::object createdRef = py::cast(unownedContextWrapper);
  assert(createdRef && "cast to py::object failed");
  liveContexts[context.ptr] = unownedContextWrapper;
  return PyMlirContextRef(unownedContextWrapper, std::move(createdRef));
}
|
|
|
|
/// Returns the single, function-local static map of live context wrappers.
PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
  static LiveContextMap registry;
  return registry;
}
|
|
|
|
/// Number of entries currently in the live-context map.
size_t PyMlirContext::getLiveCount() {
  const auto &live = getLiveContexts();
  return live.size();
}
|
|
|
|
size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
|
|
|
|
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
|
|
|
|
/// Builds a new detached operation from Python-supplied components.
///
/// Parameters:
///   name:       operation name handed to mlirOperationStateGet.
///   location:   location wrapper whose `loc` handle is attached to the state.
///   results:    optional result types; a null entry raises ValueError.
///   attributes: optional dict; each key is cast to std::string and each value
///               to PyAttribute (a failing cast propagates pybind11's error).
///   successors: optional successor blocks; a null entry raises ValueError.
///   regions:    number of empty regions to create; a negative value raises
///               ValueError.
///
/// Returns the Python object of the PyOperation created through
/// PyOperation::createDetached.
py::object PyMlirContext::createOperation(
    std::string name, PyLocation location,
    llvm::Optional<std::vector<PyType *>> results,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors, int regions) {
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw SetPyError(PyExc_ValueError, "result type cannot be None");
      mlirResults.push_back(result->type);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (auto &it : *attributes) {
      // Named `key` so it does not shadow the operation `name` parameter.
      auto key = it.first.cast<std::string>();
      auto &attribute = it.second.cast<PyAttribute &>();
      // TODO: Verify attribute originates from the same context.
      mlirAttributes.emplace_back(std::move(key), attribute.attr);
    }
  }
  // Unpack/validate successors.
  if (successors) {
    // Fill the function-level vector directly: it is the one consumed when
    // the operation state is populated below, so a block-local vector here
    // would drop every successor.
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state = mlirOperationStateGet(name.c_str(), location.loc);
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(
          mlirNamedAttributeGet(it.first.c_str(), it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  MlirOperation operation = mlirOperationCreate(&state);
  return PyOperation::createDetached(getRef(), operation).releaseObject();
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
// PyModule
|
|
//------------------------------------------------------------------------------
|
|
|
|
/// Stores the module handle and hands the context reference to the
/// BaseContextObject base. Registration in the context's live-module map is
/// done by forModule, not here.
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
    : BaseContextObject(std::move(contextRef)), module(module) {}
|
|
|
|
/// Removes this module from its context's live-module map (asserting that it
/// was registered) and destroys the wrapped MlirModule. The GIL is held for
/// the duration.
PyModule::~PyModule() {
  py::gil_scoped_acquire acquire;
  auto &registry = getContext()->liveModules;
  assert(registry.count(module.ptr) == 1 &&
         "destroying module not in live map");
  registry.erase(module.ptr);
  mlirModuleDestroy(module);
}
|
|
|
|
/// Returns a reference to the PyModule registered for `module` in its
/// context's live-module map, creating and registering a new Python-owned
/// wrapper when none exists.
PyModuleRef PyModule::forModule(MlirModule module) {
  PyMlirContextRef contextRef =
      PyMlirContext::forContext(mlirModuleGetContext(module));

  py::gil_scoped_acquire acquire;
  auto &liveModules = contextRef->liveModules;
  auto found = liveModules.find(module.ptr);
  if (found != liveModules.end()) {
    // Use existing.
    PyModule *existing = found->second.second;
    py::object existingRef =
        py::reinterpret_borrow<py::object>(found->second.first);
    return PyModuleRef(existing, std::move(existingRef));
  }

  // Create.
  PyModule *unownedModule = new PyModule(std::move(contextRef), module);
  // Note that the default return value policy on cast is automatic_reference,
  // which does not take ownership (delete will not be called).
  // Just be explicit.
  py::object createdRef =
      py::cast(unownedModule, py::return_value_policy::take_ownership);
  unownedModule->handle = createdRef;
  liveModules[module.ptr] =
      std::make_pair(unownedModule->handle, unownedModule);
  return PyModuleRef(unownedModule, std::move(createdRef));
}
|
|
|
|
/// Extracts an MlirModule from `capsule` and returns the Python object of its
/// wrapper (looked up or created by forModule). If the extracted module is
/// null, the already-set Python error is propagated.
py::object PyModule::createFromCapsule(py::object capsule) {
  MlirModule extracted = mlirPythonCapsuleToModule(capsule.ptr());
  if (!mlirModuleIsNull(extracted))
    return forModule(extracted).releaseObject();
  throw py::error_already_set();
}
|
|
|
|
/// Returns the capsule produced by mlirPythonModuleToCapsule for the wrapped
/// module; the returned py::object takes over the reference it is given.
py::object PyModule::getCapsule() {
  PyObject *rawCapsule = mlirPythonModuleToCapsule(get());
  return py::reinterpret_steal<py::object>(rawCapsule);
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
// PyOperation
|
|
//------------------------------------------------------------------------------
|
|
|
|
/// Stores the operation handle and hands the context reference to the
/// BaseContextObject base. Registration in the context's live-operation map
/// is done by createInstance, not here.
PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
    : BaseContextObject(std::move(contextRef)), operation(operation) {}
|
|
|
|
/// Removes this operation from its context's live-operation map (asserting
/// that it was registered). The underlying MlirOperation is destroyed only
/// when the operation is not attached.
PyOperation::~PyOperation() {
  auto &registry = getContext()->liveOperations;
  assert(registry.count(operation.ptr) == 1 &&
         "destroying operation not in live map");
  registry.erase(operation.ptr);
  if (isAttached())
    return;
  mlirOperationDestroy(operation);
}
|
|
|
|
/// Allocates a PyOperation for `operation`, gives ownership of it to a new
/// Python object, stores `parentKeepAlive` on it when that object is truthy,
/// and registers the wrapper in the context's live-operation map.
PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
                                           MlirOperation operation,
                                           py::object parentKeepAlive) {
  // Take the map reference before contextRef is moved into the new wrapper.
  auto &registry = contextRef->liveOperations;
  // Create.
  PyOperation *unownedOperation =
      new PyOperation(std::move(contextRef), operation);
  // Note that the default return value policy on cast is automatic_reference,
  // which does not take ownership (delete will not be called).
  // Just be explicit.
  py::object pyRef =
      py::cast(unownedOperation, py::return_value_policy::take_ownership);
  unownedOperation->handle = pyRef;
  if (parentKeepAlive)
    unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
  registry[operation.ptr] = std::make_pair(pyRef, unownedOperation);
  return PyOperationRef(unownedOperation, std::move(pyRef));
}
|
|
|
|
/// Returns a reference to the PyOperation registered for `operation` in the
/// context's live-operation map, or creates one via createInstance when none
/// exists. For an existing wrapper, an assertion checks that its stored
/// parentKeepAlive is the same object as the one passed in.
PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
                                         MlirOperation operation,
                                         py::object parentKeepAlive) {
  auto &registry = contextRef->liveOperations;
  auto found = registry.find(operation.ptr);
  if (found != registry.end()) {
    // Use existing.
    PyOperation *existing = found->second.second;
    assert(existing->parentKeepAlive.is(parentKeepAlive));
    py::object existingRef =
        py::reinterpret_borrow<py::object>(found->second.first);
    return PyOperationRef(existing, std::move(existingRef));
  }
  // Create.
  return createInstance(std::move(contextRef), operation,
                        std::move(parentKeepAlive));
}
|
|
|
|
/// Creates a wrapper for an operation that must not yet be present in the
/// context's live-operation map (asserted) and marks it as not attached.
PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
                                           MlirOperation operation,
                                           py::object parentKeepAlive) {
  auto &registry = contextRef->liveOperations;
  assert(registry.count(operation.ptr) == 0 &&
         "cannot create detached operation that already exists");
  // Keeps the variable referenced when the assertion is compiled out.
  (void)registry;

  PyOperationRef detached = createInstance(std::move(contextRef), operation,
                                           std::move(parentKeepAlive));
  detached->attached = false;
  return detached;
}
|
|
|
|
void PyOperation::checkValid() {
|
|
if (!valid) {
|
|
throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
|
|
}
|
|
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
// PyAttribute.
|
|
//------------------------------------------------------------------------------
|
|
|
|
/// Two wrappers compare equal when mlirAttributeEqual reports their wrapped
/// attributes as equal.
bool PyAttribute::operator==(const PyAttribute &other) {
  const bool sameAttribute = mlirAttributeEqual(attr, other.attr);
  return sameAttribute;
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
// PyNamedAttribute.
|
|
//------------------------------------------------------------------------------
|
|
|
|
/// Moves the name into a heap-allocated std::string held by this object and
/// builds the MlirNamedAttribute from that string's character buffer and
/// `attr`.
PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
    : ownedName(new std::string(std::move(ownedName))) {
  // `this->` selects the member: the parameter of the same name has already
  // been moved from.
  const char *nameChars = this->ownedName->c_str();
  namedAttr = mlirNamedAttributeGet(nameChars, attr);
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
// PyType.
|
|
//------------------------------------------------------------------------------
|
|
|
|
/// Two wrappers compare equal when mlirTypeEqual reports their wrapped types
/// as equal.
bool PyType::operator==(const PyType &other) {
  const bool sameType = mlirTypeEqual(type, other.type);
  return sameType;
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
// Standard attribute subclasses.
|
|
//------------------------------------------------------------------------------
|
|
|
|
namespace {
|
|
|
|
/// CRTP base classes for Python attributes that subclass Attribute and should
|
|
/// be castable from it (i.e. via something like StringAttr(attr)).
|
|
/// By default, attribute class hierarchies are one level deep (i.e. a
|
|
/// concrete attribute class extends PyAttribute); however, intermediate
|
|
/// python-visible base classes can be modeled by specifying a BaseTy.
|
|
template <typename DerivedTy, typename BaseTy = PyAttribute>
|
|
class PyConcreteAttribute : public BaseTy {
|
|
public:
|
|
// Derived classes must define statics for:
|
|
// IsAFunctionTy isaFunction
|
|
// const char *pyClassName
|
|
using ClassTy = py::class_<DerivedTy, PyAttribute>;
|
|
using IsAFunctionTy = int (*)(MlirAttribute);
|
|
|
|
PyConcreteAttribute() = default;
|
|
PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
|
|
: BaseTy(std::move(contextRef), attr) {}
|
|
PyConcreteAttribute(PyAttribute &orig)
|
|
: PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
|
|
|
|
static MlirAttribute castFrom(PyAttribute &orig) {
|
|
if (!DerivedTy::isaFunction(orig.attr)) {
|
|
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
|
|
throw SetPyError(PyExc_ValueError,
|
|
llvm::Twine("Cannot cast attribute to ") +
|
|
DerivedTy::pyClassName + " (from " + origRepr + ")");
|
|
}
|
|
return orig.attr;
|
|
}
|
|
|
|
static void bind(py::module &m) {
|
|
auto cls = ClassTy(m, DerivedTy::pyClassName);
|
|
cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
|
|
DerivedTy::bindDerived(cls);
|
|
}
|
|
|
|
/// Implemented by derived classes to add methods to the Python subclass.
|
|
static void bindDerived(ClassTy &m) {}
|
|
};
|
|
|
|
/// Float Point Attribute subclass - FloatAttr.
|
|
class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
|
|
static constexpr const char *pyClassName = "FloatAttr";
|
|
using PyConcreteAttribute::PyConcreteAttribute;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def_static(
|
|
"get",
|
|
// TODO: Make the location optional and create a default location.
|
|
[](PyType &type, double value, PyLocation &loc) {
|
|
MlirAttribute attr =
|
|
mlirFloatAttrDoubleGetChecked(type.type, value, loc.loc);
|
|
// TODO: Rework error reporting once diagnostic engine is exposed
|
|
// in C API.
|
|
if (mlirAttributeIsNull(attr)) {
|
|
throw SetPyError(PyExc_ValueError,
|
|
llvm::Twine("invalid '") +
|
|
py::repr(py::cast(type)).cast<std::string>() +
|
|
"' and expected floating point type.");
|
|
}
|
|
return PyFloatAttribute(type.getContext(), attr);
|
|
},
|
|
py::arg("type"), py::arg("value"), py::arg("loc"),
|
|
"Gets an uniqued float point attribute associated to a type");
|
|
c.def_static(
|
|
"get_f32",
|
|
[](PyMlirContext &context, double value) {
|
|
MlirAttribute attr = mlirFloatAttrDoubleGet(
|
|
context.get(), mlirF32TypeGet(context.get()), value);
|
|
return PyFloatAttribute(context.getRef(), attr);
|
|
},
|
|
py::arg("context"), py::arg("value"),
|
|
"Gets an uniqued float point attribute associated to a f32 type");
|
|
c.def_static(
|
|
"get_f64",
|
|
[](PyMlirContext &context, double value) {
|
|
MlirAttribute attr = mlirFloatAttrDoubleGet(
|
|
context.get(), mlirF64TypeGet(context.get()), value);
|
|
return PyFloatAttribute(context.getRef(), attr);
|
|
},
|
|
py::arg("context"), py::arg("value"),
|
|
"Gets an uniqued float point attribute associated to a f64 type");
|
|
c.def_property_readonly(
|
|
"value",
|
|
[](PyFloatAttribute &self) {
|
|
return mlirFloatAttrGetValueDouble(self.attr);
|
|
},
|
|
"Returns the value of the float point attribute");
|
|
}
|
|
};
|
|
|
|
/// Integer Attribute subclass - IntegerAttr.
|
|
class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
|
|
static constexpr const char *pyClassName = "IntegerAttr";
|
|
using PyConcreteAttribute::PyConcreteAttribute;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def_static(
|
|
"get",
|
|
[](PyType &type, int64_t value) {
|
|
MlirAttribute attr = mlirIntegerAttrGet(type.type, value);
|
|
return PyIntegerAttribute(type.getContext(), attr);
|
|
},
|
|
py::arg("type"), py::arg("value"),
|
|
"Gets an uniqued integer attribute associated to a type");
|
|
c.def_property_readonly(
|
|
"value",
|
|
[](PyIntegerAttribute &self) {
|
|
return mlirIntegerAttrGetValueInt(self.attr);
|
|
},
|
|
"Returns the value of the integer attribute");
|
|
}
|
|
};
|
|
|
|
/// Bool Attribute subclass - BoolAttr.
|
|
class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
|
|
static constexpr const char *pyClassName = "BoolAttr";
|
|
using PyConcreteAttribute::PyConcreteAttribute;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def_static(
|
|
"get",
|
|
[](PyMlirContext &context, bool value) {
|
|
MlirAttribute attr = mlirBoolAttrGet(context.get(), value);
|
|
return PyBoolAttribute(context.getRef(), attr);
|
|
},
|
|
py::arg("context"), py::arg("value"), "Gets an uniqued bool attribute");
|
|
c.def_property_readonly(
|
|
"value",
|
|
[](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self.attr); },
|
|
"Returns the value of the bool attribute");
|
|
}
|
|
};
|
|
|
|
/// String Attribute subclass - StringAttr.
class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
  static constexpr const char *pyClassName = "StringAttr";
  using PyConcreteAttribute::PyConcreteAttribute;

  /// Adds the static factories `get` and `get_typed` and the read-only `value`
  /// property to the Python class.
  static void bindDerived(ClassTy &c) {
    c.def_static(
        "get",
        [](PyMlirContext &context, std::string value) {
          // The string is passed as an explicit length plus character buffer.
          MlirAttribute created =
              mlirStringAttrGet(context.get(), value.size(), value.data());
          return PyStringAttribute(context.getRef(), created);
        },
        "Gets a uniqued string attribute");
    c.def_static(
        "get_typed",
        [](PyType &type, std::string value) {
          MlirAttribute created =
              mlirStringAttrTypedGet(type.type, value.size(), value.data());
          return PyStringAttribute(type.getContext(), created);
        },

        "Gets a uniqued string attribute associated to a type");
    c.def_property_readonly(
        "value",
        [](PyStringAttribute &self) {
          MlirStringRef contents = mlirStringAttrGetValue(self.attr);
          return py::str(contents.data, contents.length);
        },
        "Returns the value of the string attribute");
  }
};
|
|
|
|
} // namespace
|
|
|
|
//------------------------------------------------------------------------------
|
|
// Standard type subclasses.
|
|
//------------------------------------------------------------------------------
|
|
|
|
namespace {
|
|
|
|
/// CRTP base classes for Python types that subclass Type and should be
|
|
/// castable from it (i.e. via something like IntegerType(t)).
|
|
/// By default, type class hierarchies are one level deep (i.e. a
|
|
/// concrete type class extends PyType); however, intermediate python-visible
|
|
/// base classes can be modeled by specifying a BaseTy.
|
|
template <typename DerivedTy, typename BaseTy = PyType>
|
|
class PyConcreteType : public BaseTy {
|
|
public:
|
|
// Derived classes must define statics for:
|
|
// IsAFunctionTy isaFunction
|
|
// const char *pyClassName
|
|
using ClassTy = py::class_<DerivedTy, BaseTy>;
|
|
using IsAFunctionTy = int (*)(MlirType);
|
|
|
|
PyConcreteType() = default;
|
|
PyConcreteType(PyMlirContextRef contextRef, MlirType t)
|
|
: BaseTy(std::move(contextRef), t) {}
|
|
PyConcreteType(PyType &orig)
|
|
: PyConcreteType(orig.getContext(), castFrom(orig)) {}
|
|
|
|
static MlirType castFrom(PyType &orig) {
|
|
if (!DerivedTy::isaFunction(orig.type)) {
|
|
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
|
|
throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
|
|
DerivedTy::pyClassName +
|
|
" (from " + origRepr + ")");
|
|
}
|
|
return orig.type;
|
|
}
|
|
|
|
static void bind(py::module &m) {
|
|
auto cls = ClassTy(m, DerivedTy::pyClassName);
|
|
cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
|
|
DerivedTy::bindDerived(cls);
|
|
}
|
|
|
|
/// Implemented by derived classes to add methods to the Python subclass.
|
|
static void bindDerived(ClassTy &m) {}
|
|
};
|
|
|
|
/// Python-visible wrapper for integer types, exposed as `IntegerType`.
class PyIntegerType : public PyConcreteType<PyIntegerType> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
  static constexpr const char *pyClassName = "IntegerType";
  using PyConcreteType::PyConcreteType;

  /// Adds one factory per signedness flavour plus the `width` and
  /// signedness-query properties to the Python class.
  static void bindDerived(ClassTy &c) {
    // --- Factories -------------------------------------------------------
    auto makeSignless = [](PyMlirContext &context, unsigned width) {
      MlirType raw = mlirIntegerTypeGet(context.get(), width);
      return PyIntegerType(context.getRef(), raw);
    };
    c.def_static("get_signless", makeSignless,
                 "Create a signless integer type");

    auto makeSigned = [](PyMlirContext &context, unsigned width) {
      MlirType raw = mlirIntegerTypeSignedGet(context.get(), width);
      return PyIntegerType(context.getRef(), raw);
    };
    c.def_static("get_signed", makeSigned, "Create a signed integer type");

    auto makeUnsigned = [](PyMlirContext &context, unsigned width) {
      MlirType raw = mlirIntegerTypeUnsignedGet(context.get(), width);
      return PyIntegerType(context.getRef(), raw);
    };
    c.def_static("get_unsigned", makeUnsigned,
                 "Create an unsigned integer type");

    // --- Read-only properties --------------------------------------------
    auto readWidth = [](PyIntegerType &self) {
      return mlirIntegerTypeGetWidth(self.type);
    };
    c.def_property_readonly("width", readWidth,
                            "Returns the width of the integer type");

    // The explicit bool return types make Python see True/False.
    auto querySignless = [](PyIntegerType &self) -> bool {
      return mlirIntegerTypeIsSignless(self.type);
    };
    c.def_property_readonly("is_signless", querySignless,
                            "Returns whether this is a signless integer");

    auto querySigned = [](PyIntegerType &self) -> bool {
      return mlirIntegerTypeIsSigned(self.type);
    };
    c.def_property_readonly("is_signed", querySigned,
                            "Returns whether this is a signed integer");

    auto queryUnsigned = [](PyIntegerType &self) -> bool {
      return mlirIntegerTypeIsUnsigned(self.type);
    };
    c.def_property_readonly("is_unsigned", queryUnsigned,
                            "Returns whether this is an unsigned integer");
  }
};
|
|
|
|
/// Index Type subclass - IndexType.
|
|
class PyIndexType : public PyConcreteType<PyIndexType> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
|
|
static constexpr const char *pyClassName = "IndexType";
|
|
using PyConcreteType::PyConcreteType;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def(py::init([](PyMlirContext &context) {
|
|
MlirType t = mlirIndexTypeGet(context.get());
|
|
return PyIndexType(context.getRef(), t);
|
|
}),
|
|
"Create a index type.");
|
|
}
|
|
};
|
|
|
|
/// Floating Point Type subclass - BF16Type.
|
|
class PyBF16Type : public PyConcreteType<PyBF16Type> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
|
|
static constexpr const char *pyClassName = "BF16Type";
|
|
using PyConcreteType::PyConcreteType;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def(py::init([](PyMlirContext &context) {
|
|
MlirType t = mlirBF16TypeGet(context.get());
|
|
return PyBF16Type(context.getRef(), t);
|
|
}),
|
|
"Create a bf16 type.");
|
|
}
|
|
};
|
|
|
|
/// Floating Point Type subclass - F16Type.
|
|
class PyF16Type : public PyConcreteType<PyF16Type> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
|
|
static constexpr const char *pyClassName = "F16Type";
|
|
using PyConcreteType::PyConcreteType;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def(py::init([](PyMlirContext &context) {
|
|
MlirType t = mlirF16TypeGet(context.get());
|
|
return PyF16Type(context.getRef(), t);
|
|
}),
|
|
"Create a f16 type.");
|
|
}
|
|
};
|
|
|
|
/// Floating Point Type subclass - F32Type.
|
|
class PyF32Type : public PyConcreteType<PyF32Type> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
|
|
static constexpr const char *pyClassName = "F32Type";
|
|
using PyConcreteType::PyConcreteType;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def(py::init([](PyMlirContext &context) {
|
|
MlirType t = mlirF32TypeGet(context.get());
|
|
return PyF32Type(context.getRef(), t);
|
|
}),
|
|
"Create a f32 type.");
|
|
}
|
|
};
|
|
|
|
/// Floating Point Type subclass - F64Type.
|
|
class PyF64Type : public PyConcreteType<PyF64Type> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
|
|
static constexpr const char *pyClassName = "F64Type";
|
|
using PyConcreteType::PyConcreteType;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def(py::init([](PyMlirContext &context) {
|
|
MlirType t = mlirF64TypeGet(context.get());
|
|
return PyF64Type(context.getRef(), t);
|
|
}),
|
|
"Create a f64 type.");
|
|
}
|
|
};
|
|
|
|
/// None Type subclass - NoneType.
|
|
class PyNoneType : public PyConcreteType<PyNoneType> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
|
|
static constexpr const char *pyClassName = "NoneType";
|
|
using PyConcreteType::PyConcreteType;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def(py::init([](PyMlirContext &context) {
|
|
MlirType t = mlirNoneTypeGet(context.get());
|
|
return PyNoneType(context.getRef(), t);
|
|
}),
|
|
"Create a none type.");
|
|
}
|
|
};
|
|
|
|
/// Complex Type subclass - ComplexType.
|
|
class PyComplexType : public PyConcreteType<PyComplexType> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
|
|
static constexpr const char *pyClassName = "ComplexType";
|
|
using PyConcreteType::PyConcreteType;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def_static(
|
|
"get_complex",
|
|
[](PyType &elementType) {
|
|
// The element must be a floating point or integer scalar type.
|
|
if (mlirTypeIsAIntegerOrFloat(elementType.type)) {
|
|
MlirType t = mlirComplexTypeGet(elementType.type);
|
|
return PyComplexType(elementType.getContext(), t);
|
|
}
|
|
throw SetPyError(
|
|
PyExc_ValueError,
|
|
llvm::Twine("invalid '") +
|
|
py::repr(py::cast(elementType)).cast<std::string>() +
|
|
"' and expected floating point or integer type.");
|
|
},
|
|
"Create a complex type");
|
|
c.def_property_readonly(
|
|
"element_type",
|
|
[](PyComplexType &self) -> PyType {
|
|
MlirType t = mlirComplexTypeGetElementType(self.type);
|
|
return PyType(self.getContext(), t);
|
|
},
|
|
"Returns element type.");
|
|
}
|
|
};
|
|
|
|
/// Python-visible wrapper for shaped types, exposed as `ShapedType`. Used as
/// the python-visible base class of the vector, tensor and memref wrappers
/// defined after it.
class PyShapedType : public PyConcreteType<PyShapedType> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped;
  static constexpr const char *pyClassName = "ShapedType";
  using PyConcreteType::PyConcreteType;

  /// Adds the shape-inspection properties and methods to the Python class.
  static void bindDerived(ClassTy &c) {
    auto readElementType = [](PyShapedType &self) {
      MlirType raw = mlirShapedTypeGetElementType(self.type);
      return PyType(self.getContext(), raw);
    };
    c.def_property_readonly("element_type", readElementType,
                            "Returns the element type of the shaped type.");

    auto queryHasRank = [](PyShapedType &self) -> bool {
      return mlirShapedTypeHasRank(self.type);
    };
    c.def_property_readonly(
        "has_rank", queryHasRank,
        "Returns whether the given shaped type is ranked.");

    // Rank-dependent queries raise ValueError on unranked types (see
    // requireHasRank).
    auto readRank = [](PyShapedType &self) {
      self.requireHasRank();
      return mlirShapedTypeGetRank(self.type);
    };
    c.def_property_readonly(
        "rank", readRank, "Returns the rank of the given ranked shaped type.");

    auto queryStaticShape = [](PyShapedType &self) -> bool {
      return mlirShapedTypeHasStaticShape(self.type);
    };
    c.def_property_readonly(
        "has_static_shape", queryStaticShape,
        "Returns whether the given shaped type has a static shape.");

    auto queryDynamicDim = [](PyShapedType &self, intptr_t dim) -> bool {
      self.requireHasRank();
      return mlirShapedTypeIsDynamicDim(self.type, dim);
    };
    c.def("is_dynamic_dim", queryDynamicDim,
          "Returns whether the dim-th dimension of the given shaped type is "
          "dynamic.");

    auto readDimSize = [](PyShapedType &self, intptr_t dim) {
      self.requireHasRank();
      return mlirShapedTypeGetDimSize(self.type, dim);
    };
    c.def("get_dim_size", readDimSize,
          "Returns the dim-th dimension of the given ranked shaped type.");

    // Static helper: classifies a size value without needing an instance.
    auto queryDynamicSize = [](int64_t size) -> bool {
      return mlirShapedTypeIsDynamicSize(size);
    };
    c.def_static("is_dynamic_size", queryDynamicSize,
                 "Returns whether the given dimension size indicates a dynamic "
                 "dimension.");

    auto queryDynamicStrideOrOffset = [](PyShapedType &self,
                                         int64_t val) -> bool {
      self.requireHasRank();
      return mlirShapedTypeIsDynamicStrideOrOffset(val);
    };
    c.def("is_dynamic_stride_or_offset", queryDynamicStrideOrOffset,
          "Returns whether the given value is used as a placeholder for "
          "dynamic strides and offsets in shaped types.");
  }

private:
  /// Raises a Python ValueError unless the wrapped type has a rank.
  void requireHasRank() {
    if (mlirShapedTypeHasRank(type))
      return;
    throw SetPyError(PyExc_ValueError,
                     "calling this method requires that the type has a rank.");
  }
};
|
|
|
|
/// Vector Type subclass - VectorType.
|
|
class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
|
|
static constexpr const char *pyClassName = "VectorType";
|
|
using PyConcreteType::PyConcreteType;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def_static(
|
|
"get_vector",
|
|
// TODO: Make the location optional and create a default location.
|
|
[](std::vector<int64_t> shape, PyType &elementType, PyLocation &loc) {
|
|
MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
|
|
elementType.type, loc.loc);
|
|
// TODO: Rework error reporting once diagnostic engine is exposed
|
|
// in C API.
|
|
if (mlirTypeIsNull(t)) {
|
|
throw SetPyError(
|
|
PyExc_ValueError,
|
|
llvm::Twine("invalid '") +
|
|
py::repr(py::cast(elementType)).cast<std::string>() +
|
|
"' and expected floating point or integer type.");
|
|
}
|
|
return PyVectorType(elementType.getContext(), t);
|
|
},
|
|
"Create a vector type");
|
|
}
|
|
};
|
|
|
|
/// Ranked Tensor Type subclass - RankedTensorType.
|
|
class PyRankedTensorType
|
|
: public PyConcreteType<PyRankedTensorType, PyShapedType> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
|
|
static constexpr const char *pyClassName = "RankedTensorType";
|
|
using PyConcreteType::PyConcreteType;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def_static(
|
|
"get_ranked_tensor",
|
|
// TODO: Make the location optional and create a default location.
|
|
[](std::vector<int64_t> shape, PyType &elementType, PyLocation &loc) {
|
|
MlirType t = mlirRankedTensorTypeGetChecked(
|
|
shape.size(), shape.data(), elementType.type, loc.loc);
|
|
// TODO: Rework error reporting once diagnostic engine is exposed
|
|
// in C API.
|
|
if (mlirTypeIsNull(t)) {
|
|
throw SetPyError(
|
|
PyExc_ValueError,
|
|
llvm::Twine("invalid '") +
|
|
py::repr(py::cast(elementType)).cast<std::string>() +
|
|
"' and expected floating point, integer, vector or "
|
|
"complex "
|
|
"type.");
|
|
}
|
|
return PyRankedTensorType(elementType.getContext(), t);
|
|
},
|
|
"Create a ranked tensor type");
|
|
}
|
|
};
|
|
|
|
/// Unranked Tensor Type subclass - UnrankedTensorType.
|
|
class PyUnrankedTensorType
|
|
: public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
|
|
static constexpr const char *pyClassName = "UnrankedTensorType";
|
|
using PyConcreteType::PyConcreteType;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def_static(
|
|
"get_unranked_tensor",
|
|
// TODO: Make the location optional and create a default location.
|
|
[](PyType &elementType, PyLocation &loc) {
|
|
MlirType t =
|
|
mlirUnrankedTensorTypeGetChecked(elementType.type, loc.loc);
|
|
// TODO: Rework error reporting once diagnostic engine is exposed
|
|
// in C API.
|
|
if (mlirTypeIsNull(t)) {
|
|
throw SetPyError(
|
|
PyExc_ValueError,
|
|
llvm::Twine("invalid '") +
|
|
py::repr(py::cast(elementType)).cast<std::string>() +
|
|
"' and expected floating point, integer, vector or "
|
|
"complex "
|
|
"type.");
|
|
}
|
|
return PyUnrankedTensorType(elementType.getContext(), t);
|
|
},
|
|
"Create a unranked tensor type");
|
|
}
|
|
};
|
|
|
|
/// Ranked MemRef Type subclass - MemRefType.
|
|
class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
|
|
static constexpr const char *pyClassName = "MemRefType";
|
|
using PyConcreteType::PyConcreteType;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
// TODO: Add mlirMemRefTypeGet and mlirMemRefTypeGetAffineMap binding
|
|
// once the affine map binding is completed.
|
|
c.def_static(
|
|
"get_contiguous_memref",
|
|
// TODO: Make the location optional and create a default location.
|
|
[](PyType &elementType, std::vector<int64_t> shape,
|
|
unsigned memorySpace, PyLocation &loc) {
|
|
MlirType t = mlirMemRefTypeContiguousGetChecked(
|
|
elementType.type, shape.size(), shape.data(), memorySpace,
|
|
loc.loc);
|
|
// TODO: Rework error reporting once diagnostic engine is exposed
|
|
// in C API.
|
|
if (mlirTypeIsNull(t)) {
|
|
throw SetPyError(
|
|
PyExc_ValueError,
|
|
llvm::Twine("invalid '") +
|
|
py::repr(py::cast(elementType)).cast<std::string>() +
|
|
"' and expected floating point, integer, vector or "
|
|
"complex "
|
|
"type.");
|
|
}
|
|
return PyMemRefType(elementType.getContext(), t);
|
|
},
|
|
"Create a memref type")
|
|
.def_property_readonly(
|
|
"num_affine_maps",
|
|
[](PyMemRefType &self) -> intptr_t {
|
|
return mlirMemRefTypeGetNumAffineMaps(self.type);
|
|
},
|
|
"Returns the number of affine layout maps in the given MemRef "
|
|
"type.")
|
|
.def_property_readonly(
|
|
"memory_space",
|
|
[](PyMemRefType &self) -> unsigned {
|
|
return mlirMemRefTypeGetMemorySpace(self.type);
|
|
},
|
|
"Returns the memory space of the given MemRef type.");
|
|
}
|
|
};
|
|
|
|
/// Unranked MemRef Type subclass - UnrankedMemRefType.
|
|
class PyUnrankedMemRefType
|
|
: public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
|
|
static constexpr const char *pyClassName = "UnrankedMemRefType";
|
|
using PyConcreteType::PyConcreteType;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def_static(
|
|
"get_unranked_memref",
|
|
// TODO: Make the location optional and create a default location.
|
|
[](PyType &elementType, unsigned memorySpace, PyLocation &loc) {
|
|
MlirType t = mlirUnrankedMemRefTypeGetChecked(elementType.type,
|
|
memorySpace, loc.loc);
|
|
// TODO: Rework error reporting once diagnostic engine is exposed
|
|
// in C API.
|
|
if (mlirTypeIsNull(t)) {
|
|
throw SetPyError(
|
|
PyExc_ValueError,
|
|
llvm::Twine("invalid '") +
|
|
py::repr(py::cast(elementType)).cast<std::string>() +
|
|
"' and expected floating point, integer, vector or "
|
|
"complex "
|
|
"type.");
|
|
}
|
|
return PyUnrankedMemRefType(elementType.getContext(), t);
|
|
},
|
|
"Create a unranked memref type")
|
|
.def_property_readonly(
|
|
"memory_space",
|
|
[](PyUnrankedMemRefType &self) -> unsigned {
|
|
return mlirUnrankedMemrefGetMemorySpace(self.type);
|
|
},
|
|
"Returns the memory space of the given Unranked MemRef type.");
|
|
}
|
|
};
|
|
|
|
/// Tuple Type subclass - TupleType.
|
|
class PyTupleType : public PyConcreteType<PyTupleType> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
|
|
static constexpr const char *pyClassName = "TupleType";
|
|
using PyConcreteType::PyConcreteType;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def_static(
|
|
"get_tuple",
|
|
[](PyMlirContext &context, py::list elementList) {
|
|
intptr_t num = py::len(elementList);
|
|
// Mapping py::list to SmallVector.
|
|
SmallVector<MlirType, 4> elements;
|
|
for (auto element : elementList)
|
|
elements.push_back(element.cast<PyType>().type);
|
|
MlirType t = mlirTupleTypeGet(context.get(), num, elements.data());
|
|
return PyTupleType(context.getRef(), t);
|
|
},
|
|
"Create a tuple type");
|
|
c.def(
|
|
"get_type",
|
|
[](PyTupleType &self, intptr_t pos) -> PyType {
|
|
MlirType t = mlirTupleTypeGetType(self.type, pos);
|
|
return PyType(self.getContext(), t);
|
|
},
|
|
"Returns the pos-th type in the tuple type.");
|
|
c.def_property_readonly(
|
|
"num_types",
|
|
[](PyTupleType &self) -> intptr_t {
|
|
return mlirTupleTypeGetNumTypes(self.type);
|
|
},
|
|
"Returns the number of types contained in a tuple.");
|
|
}
|
|
};
|
|
|
|
/// Function type.
|
|
class PyFunctionType : public PyConcreteType<PyFunctionType> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
|
|
static constexpr const char *pyClassName = "FunctionType";
|
|
using PyConcreteType::PyConcreteType;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def_static(
|
|
"get",
|
|
[](PyMlirContext &context, std::vector<PyType> inputs,
|
|
std::vector<PyType> results) {
|
|
SmallVector<MlirType, 4> inputsRaw(inputs.begin(), inputs.end());
|
|
SmallVector<MlirType, 4> resultsRaw(results.begin(), results.end());
|
|
MlirType t = mlirFunctionTypeGet(context.get(), inputsRaw.size(),
|
|
inputsRaw.data(), resultsRaw.size(),
|
|
resultsRaw.data());
|
|
return PyFunctionType(context.getRef(), t);
|
|
},
|
|
py::arg("context"), py::arg("inputs"), py::arg("results"),
|
|
"Gets a FunctionType from a list of input and result types");
|
|
c.def_property_readonly(
|
|
"inputs",
|
|
[](PyFunctionType &self) {
|
|
MlirType t = self.type;
|
|
auto contextRef = self.getContext();
|
|
py::list types;
|
|
for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self.type);
|
|
i < e; ++i) {
|
|
types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i)));
|
|
}
|
|
return types;
|
|
},
|
|
"Returns the list of input types in the FunctionType.");
|
|
c.def_property_readonly(
|
|
"results",
|
|
[](PyFunctionType &self) {
|
|
MlirType t = self.type;
|
|
auto contextRef = self.getContext();
|
|
py::list types;
|
|
for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self.type);
|
|
i < e; ++i) {
|
|
types.append(PyType(contextRef, mlirFunctionTypeGetResult(t, i)));
|
|
}
|
|
return types;
|
|
},
|
|
"Returns the list of result types in the FunctionType.");
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
//------------------------------------------------------------------------------
|
|
// Populates the pybind11 IR submodule.
|
|
//------------------------------------------------------------------------------
|
|
|
|
/// Registers every IR-related Python class on the pybind11 module `m`:
/// Context, Location, Module, Operation, Region, Block, Attribute,
/// NamedAttribute and Type, followed by the standard attribute/type
/// subclasses and the container/iterator helpers.
///
/// The generic Attribute and Type classes are registered before their
/// concrete subclasses are bound.
void mlir::python::populateIRSubmodule(py::module &m) {
  // Mapping of MlirContext
  py::class_<PyMlirContext>(m, "Context")
      .def(py::init<>(&PyMlirContext::createNewContextForInit))
      .def_static("_get_live_count", &PyMlirContext::getLiveCount)
      .def("_get_context_again",
           [](PyMlirContext &self) {
             PyMlirContextRef ref = PyMlirContext::forContext(self.get());
             return ref.releaseObject();
           })
      .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
      .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                             &PyMlirContext::getCapsule)
      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
      .def_property(
          "allow_unregistered_dialects",
          [](PyMlirContext &self) -> bool {
            return mlirContextGetAllowUnregisteredDialects(self.get());
          },
          [](PyMlirContext &self, bool value) {
            mlirContextSetAllowUnregisteredDialects(self.get(), value);
          })
      .def("create_operation", &PyMlirContext::createOperation, py::arg("name"),
           py::arg("location"), py::arg("results") = py::none(),
           py::arg("attributes") = py::none(),
           py::arg("successors") = py::none(), py::arg("regions") = 0,
           kContextCreateOperationDocstring)
      .def(
          "parse_module",
          [](PyMlirContext &self, const std::string moduleAsm) {
            MlirModule module =
                mlirModuleCreateParse(self.get(), moduleAsm.c_str());
            // TODO: Rework error reporting once diagnostic engine is exposed
            // in C API.
            if (mlirModuleIsNull(module)) {
              throw SetPyError(
                  PyExc_ValueError,
                  "Unable to parse module assembly (see diagnostics)");
            }
            return PyModule::forModule(module).releaseObject();
          },
          kContextParseDocstring)
      .def(
          "create_module",
          // NOTE(review): the module is created from `loc` alone; `self` is
          // not consulted, so nothing here checks that `loc` belongs to this
          // context.
          [](PyMlirContext &self, PyLocation &loc) {
            MlirModule module = mlirModuleCreateEmpty(loc.loc);
            return PyModule::forModule(module).releaseObject();
          },
          py::arg("loc"), "Creates an empty module")
      .def(
          "parse_attr",
          [](PyMlirContext &self, std::string attrSpec) {
            MlirAttribute attr =
                mlirAttributeParseGet(self.get(), attrSpec.c_str());
            // TODO: Rework error reporting once diagnostic engine is exposed
            // in C API.
            if (mlirAttributeIsNull(attr)) {
              throw SetPyError(PyExc_ValueError,
                               llvm::Twine("Unable to parse attribute: '") +
                                   attrSpec + "'");
            }
            return PyAttribute(self.getRef(), attr);
          },
          py::keep_alive<0, 1>())
      .def(
          "parse_type",
          [](PyMlirContext &self, std::string typeSpec) {
            MlirType type = mlirTypeParseGet(self.get(), typeSpec.c_str());
            // TODO: Rework error reporting once diagnostic engine is exposed
            // in C API.
            if (mlirTypeIsNull(type)) {
              throw SetPyError(PyExc_ValueError,
                               llvm::Twine("Unable to parse type: '") +
                                   typeSpec + "'");
            }
            return PyType(self.getRef(), type);
          },
          kContextParseTypeDocstring)
      .def(
          "get_unknown_location",
          [](PyMlirContext &self) {
            return PyLocation(self.getRef(),
                              mlirLocationUnknownGet(self.get()));
          },
          kContextGetUnknownLocationDocstring)
      .def(
          "get_file_location",
          [](PyMlirContext &self, std::string filename, int line, int col) {
            return PyLocation(self.getRef(),
                              mlirLocationFileLineColGet(
                                  self.get(), filename.c_str(), line, col));
          },
          kContextGetFileLocationDocstring, py::arg("filename"),
          py::arg("line"), py::arg("col"));

  // Mapping of Location.
  py::class_<PyLocation>(m, "Location")
      .def_property_readonly(
          "context",
          [](PyLocation &self) { return self.getContext().getObject(); },
          "Context that owns the Location")
      .def("__repr__", [](PyLocation &self) {
        PyPrintAccumulator printAccum;
        mlirLocationPrint(self.loc, printAccum.getCallback(),
                          printAccum.getUserData());
        return printAccum.join();
      });

  // Mapping of Module
  py::class_<PyModule>(m, "Module")
      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
      .def_property_readonly(
          "context",
          [](PyModule &self) { return self.getContext().getObject(); },
          "Context that created the Module")
      .def_property_readonly(
          "operation",
          [](PyModule &self) {
            // The module's own Python object is passed along with the
            // operation handle when looking up/creating the wrapper.
            return PyOperation::forOperation(self.getContext(),
                                             mlirModuleGetOperation(self.get()),
                                             self.getRef().releaseObject())
                .releaseObject();
          },
          "Accesses the module as an operation")
      .def(
          "dump",
          [](PyModule &self) {
            mlirOperationDump(mlirModuleGetOperation(self.get()));
          },
          kDumpDocstring)
      .def(
          "__str__",
          [](PyModule &self) {
            MlirOperation operation = mlirModuleGetOperation(self.get());
            PyPrintAccumulator printAccum;
            mlirOperationPrint(operation, printAccum.getCallback(),
                               printAccum.getUserData());
            return printAccum.join();
          },
          kOperationStrDunderDocstring);

  // Mapping of Operation.
  py::class_<PyOperation>(m, "Operation")
      .def_property_readonly(
          "context",
          [](PyOperation &self) { return self.getContext().getObject(); },
          "Context that owns the Operation")
      .def_property_readonly(
          "regions",
          [](PyOperation &self) { return PyRegionList(self.getRef()); })
      .def("__iter__",
           [](PyOperation &self) { return PyRegionIterator(self.getRef()); })
      .def(
          "__str__",
          [](PyOperation &self) {
            self.checkValid();
            PyPrintAccumulator printAccum;
            mlirOperationPrint(self.get(), printAccum.getCallback(),
                               printAccum.getUserData());
            return printAccum.join();
          },
          // Same docstring as Module.__str__, which prints through the same
          // mlirOperationPrint entry point.
          kOperationStrDunderDocstring);

  // Mapping of PyRegion.
  py::class_<PyRegion>(m, "Region")
      .def_property_readonly(
          "blocks",
          [](PyRegion &self) {
            return PyBlockList(self.getParentOperation(), self.get());
          },
          "Returns a forward-optimized sequence of blocks.")
      .def(
          "__iter__",
          [](PyRegion &self) {
            self.checkValid();
            MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
            return PyBlockIterator(self.getParentOperation(), firstBlock);
          },
          "Iterates over blocks in the region.")
      .def("__eq__", [](PyRegion &self, py::object &other) {
        // Regions compare equal when they wrap the same underlying pointer;
        // anything that cannot be cast to a Region compares unequal.
        try {
          PyRegion *otherRegion = other.cast<PyRegion *>();
          return self.get().ptr == otherRegion->get().ptr;
        } catch (const std::exception &) {
          return false;
        }
      });

  // Mapping of PyBlock.
  py::class_<PyBlock>(m, "Block")
      .def_property_readonly(
          "operations",
          [](PyBlock &self) {
            return PyOperationList(self.getParentOperation(), self.get());
          },
          "Returns a forward-optimized sequence of operations.")
      .def(
          "__iter__",
          [](PyBlock &self) {
            self.checkValid();
            MlirOperation firstOperation =
                mlirBlockGetFirstOperation(self.get());
            return PyOperationIterator(self.getParentOperation(),
                                       firstOperation);
          },
          "Iterates over operations in the block.")
      .def("__eq__",
           [](PyBlock &self, py::object &other) {
             // Blocks compare equal when they wrap the same underlying
             // pointer; anything that cannot be cast to a Block compares
             // unequal.
             try {
               PyBlock *otherBlock = other.cast<PyBlock *>();
               return self.get().ptr == otherBlock->get().ptr;
             } catch (const std::exception &) {
               return false;
             }
           })
      .def(
          "__str__",
          [](PyBlock &self) {
            self.checkValid();
            PyPrintAccumulator printAccum;
            mlirBlockPrint(self.get(), printAccum.getCallback(),
                           printAccum.getUserData());
            return printAccum.join();
          },
          "Returns the assembly form of the block.");

  // Mapping of Attribute.
  py::class_<PyAttribute>(m, "Attribute")
      .def_property_readonly(
          "context",
          [](PyAttribute &self) { return self.getContext().getObject(); },
          "Context that owns the Attribute")
      .def(
          "get_named",
          [](PyAttribute &self, std::string name) {
            return PyNamedAttribute(self.attr, std::move(name));
          },
          py::keep_alive<0, 1>(), "Binds a name to the attribute")
      .def("__eq__",
           [](PyAttribute &self, py::object &other) {
             // Objects that cannot be cast to an Attribute compare unequal.
             try {
               PyAttribute otherAttribute = other.cast<PyAttribute>();
               return self == otherAttribute;
             } catch (const std::exception &) {
               return false;
             }
           })
      .def(
          "dump", [](PyAttribute &self) { mlirAttributeDump(self.attr); },
          kDumpDocstring)
      .def(
          "__str__",
          [](PyAttribute &self) {
            PyPrintAccumulator printAccum;
            mlirAttributePrint(self.attr, printAccum.getCallback(),
                               printAccum.getUserData());
            return printAccum.join();
          },
          "Returns the assembly form of the attribute.")
      .def("__repr__", [](PyAttribute &self) {
        // Generally, assembly formats are not printed for __repr__ because
        // this can cause exceptionally long debug output and exceptions.
        // However, attribute values are generally considered useful and are
        // printed. This may need to be re-evaluated if debug dumps end up
        // being excessive.
        PyPrintAccumulator printAccum;
        printAccum.parts.append("Attribute(");
        mlirAttributePrint(self.attr, printAccum.getCallback(),
                           printAccum.getUserData());
        printAccum.parts.append(")");
        return printAccum.join();
      });

  // Mapping of NamedAttribute.
  py::class_<PyNamedAttribute>(m, "NamedAttribute")
      .def("__repr__",
           [](PyNamedAttribute &self) {
             PyPrintAccumulator printAccum;
             printAccum.parts.append("NamedAttribute(");
             printAccum.parts.append(self.namedAttr.name);
             printAccum.parts.append("=");
             mlirAttributePrint(self.namedAttr.attribute,
                                printAccum.getCallback(),
                                printAccum.getUserData());
             printAccum.parts.append(")");
             return printAccum.join();
           })
      .def_property_readonly(
          "name",
          [](PyNamedAttribute &self) {
            return py::str(self.namedAttr.name, strlen(self.namedAttr.name));
          },
          "The name of the NamedAttribute binding")
      .def_property_readonly(
          "attr",
          [](PyNamedAttribute &self) {
            // TODO: When named attribute is removed/refactored, also remove
            // this constructor (it does an inefficient table lookup).
            auto contextRef = PyMlirContext::forContext(
                mlirAttributeGetContext(self.namedAttr.attribute));
            return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
          },
          py::keep_alive<0, 1>(),
          "The underlying generic attribute of the NamedAttribute binding");

  // Standard attribute bindings.
  PyFloatAttribute::bind(m);
  PyIntegerAttribute::bind(m);
  PyBoolAttribute::bind(m);
  PyStringAttribute::bind(m);

  // Mapping of Type.
  py::class_<PyType>(m, "Type")
      .def_property_readonly(
          "context", [](PyType &self) { return self.getContext().getObject(); },
          "Context that owns the Type")
      .def("__eq__",
           [](PyType &self, py::object &other) {
             // Objects that cannot be cast to a Type compare unequal.
             try {
               PyType otherType = other.cast<PyType>();
               return self == otherType;
             } catch (const std::exception &) {
               return false;
             }
           })
      .def(
          "dump", [](PyType &self) { mlirTypeDump(self.type); }, kDumpDocstring)
      .def(
          "__str__",
          [](PyType &self) {
            PyPrintAccumulator printAccum;
            mlirTypePrint(self.type, printAccum.getCallback(),
                          printAccum.getUserData());
            return printAccum.join();
          },
          kTypeStrDunderDocstring)
      .def("__repr__", [](PyType &self) {
        // Generally, assembly formats are not printed for __repr__ because
        // this can cause exceptionally long debug output and exceptions.
        // However, types are an exception as they typically have compact
        // assembly forms and printing them is useful.
        PyPrintAccumulator printAccum;
        printAccum.parts.append("Type(");
        mlirTypePrint(self.type, printAccum.getCallback(),
                      printAccum.getUserData());
        printAccum.parts.append(")");
        return printAccum.join();
      });

  // Standard type bindings.
  PyIntegerType::bind(m);
  PyIndexType::bind(m);
  PyBF16Type::bind(m);
  PyF16Type::bind(m);
  PyF32Type::bind(m);
  PyF64Type::bind(m);
  PyNoneType::bind(m);
  PyComplexType::bind(m);
  // ShapedType is bound before the classes that name it as their base.
  PyShapedType::bind(m);
  PyVectorType::bind(m);
  PyRankedTensorType::bind(m);
  PyUnrankedTensorType::bind(m);
  PyMemRefType::bind(m);
  PyUnrankedMemRefType::bind(m);
  PyTupleType::bind(m);
  PyFunctionType::bind(m);

  // Container bindings.
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);
}
|