mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-05-14 01:46:41 +00:00

* Generic mlir.ir.Attribute class. * First standard attribute (mlir.ir.StringAttr), following the same pattern as generic vs standard types. * NamedAttribute class. Differential Revision: https://reviews.llvm.org/D86250
493 lines
17 KiB
C++
493 lines
17 KiB
C++
//===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "IRModules.h"
|
|
#include "PybindUtils.h"
|
|
|
|
#include "mlir-c/StandardAttributes.h"
|
|
#include "mlir-c/StandardTypes.h"
|
|
|
|
namespace py = pybind11;
|
|
using namespace mlir;
|
|
using namespace mlir::python;
|
|
|
|
//------------------------------------------------------------------------------
|
|
// Docstrings (trivial, non-duplicated docstrings are included inline).
|
|
//------------------------------------------------------------------------------
|
|
|
|
static const char kContextParseDocstring[] =
|
|
R"(Parses a module's assembly format from a string.
|
|
|
|
Returns a new MlirModule or raises a ValueError if the parsing fails.
|
|
|
|
See also: https://mlir.llvm.org/docs/LangRef/
|
|
)";
|
|
|
|
static const char kContextParseType[] = R"(Parses the assembly form of a type.
|
|
|
|
Returns a Type object or raises a ValueError if the type cannot be parsed.
|
|
|
|
See also: https://mlir.llvm.org/docs/LangRef/#type-system
|
|
)";
|
|
|
|
static const char kOperationStrDunderDocstring[] =
|
|
R"(Prints the assembly form of the operation with default options.
|
|
|
|
If more advanced control over the assembly formatting or I/O options is needed,
|
|
use the dedicated print method, which supports keyword arguments to customize
|
|
behavior.
|
|
)";
|
|
|
|
static const char kTypeStrDunderDocstring[] =
|
|
R"(Prints the assembly form of the type.)";
|
|
|
|
static const char kDumpDocstring[] =
|
|
R"(Dumps a debug representation of the object to stderr.)";
|
|
|
|
//------------------------------------------------------------------------------
|
|
// Conversion utilities.
|
|
//------------------------------------------------------------------------------
|
|
|
|
namespace {
|
|
|
|
/// Accumulates into a python string from a method that accepts an
|
|
/// MlirStringCallback.
|
|
struct PyPrintAccumulator {
|
|
py::list parts;
|
|
|
|
void *getUserData() { return this; }
|
|
|
|
MlirStringCallback getCallback() {
|
|
return [](const char *part, intptr_t size, void *userData) {
|
|
PyPrintAccumulator *printAccum =
|
|
static_cast<PyPrintAccumulator *>(userData);
|
|
py::str pyPart(part, size); // Decodes as UTF-8 by default.
|
|
printAccum->parts.append(std::move(pyPart));
|
|
};
|
|
}
|
|
|
|
py::str join() {
|
|
py::str delim("", 0);
|
|
return delim.attr("join")(parts);
|
|
}
|
|
};
|
|
|
|
/// Accumulates into a python string from a method that is expected to make
|
|
/// one (no more, no less) call to the callback (asserts internally on
|
|
/// violation).
|
|
struct PySinglePartStringAccumulator {
|
|
void *getUserData() { return this; }
|
|
|
|
MlirStringCallback getCallback() {
|
|
return [](const char *part, intptr_t size, void *userData) {
|
|
PySinglePartStringAccumulator *accum =
|
|
static_cast<PySinglePartStringAccumulator *>(userData);
|
|
assert(!accum->invoked &&
|
|
"PySinglePartStringAccumulator called back multiple times");
|
|
accum->invoked = true;
|
|
accum->value = py::str(part, size);
|
|
};
|
|
}
|
|
|
|
py::str takeValue() {
|
|
assert(invoked && "PySinglePartStringAccumulator not called back");
|
|
return std::move(value);
|
|
}
|
|
|
|
private:
|
|
py::str value;
|
|
bool invoked = false;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
//------------------------------------------------------------------------------
|
|
// PyAttribute.
|
|
//------------------------------------------------------------------------------
|
|
|
|
bool PyAttribute::operator==(const PyAttribute &other) {
|
|
return mlirAttributeEqual(attr, other.attr);
|
|
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
// PyNamedAttribute.
|
|
//------------------------------------------------------------------------------
|
|
|
|
PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
|
|
: ownedName(new std::string(std::move(ownedName))) {
|
|
namedAttr = mlirNamedAttributeGet(this->ownedName->c_str(), attr);
|
|
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
// PyType.
|
|
//------------------------------------------------------------------------------
|
|
|
|
bool PyType::operator==(const PyType &other) {
|
|
return mlirTypeEqual(type, other.type);
|
|
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
// Standard attribute subclasses.
|
|
//------------------------------------------------------------------------------
|
|
|
|
namespace {
|
|
|
|
/// CRTP base classes for Python attributes that subclass Attribute and should
|
|
/// be castable from it (i.e. via something like StringAttr(attr)).
|
|
template <typename T>
|
|
class PyConcreteAttribute : public PyAttribute {
|
|
public:
|
|
// Derived classes must define statics for:
|
|
// IsAFunctionTy isaFunction
|
|
// const char *pyClassName
|
|
using ClassTy = py::class_<T, PyAttribute>;
|
|
using IsAFunctionTy = int (*)(MlirAttribute);
|
|
|
|
PyConcreteAttribute() = default;
|
|
PyConcreteAttribute(MlirAttribute attr) : PyAttribute(attr) {}
|
|
PyConcreteAttribute(PyAttribute &orig)
|
|
: PyConcreteAttribute(castFrom(orig)) {}
|
|
|
|
static MlirAttribute castFrom(PyAttribute &orig) {
|
|
if (!T::isaFunction(orig.attr)) {
|
|
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
|
|
throw SetPyError(PyExc_ValueError,
|
|
llvm::Twine("Cannot cast attribute to ") +
|
|
T::pyClassName + " (from " + origRepr + ")");
|
|
}
|
|
return orig.attr;
|
|
}
|
|
|
|
static void bind(py::module &m) {
|
|
auto cls = ClassTy(m, T::pyClassName);
|
|
cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
|
|
T::bindDerived(cls);
|
|
}
|
|
|
|
/// Implemented by derived classes to add methods to the Python subclass.
|
|
static void bindDerived(ClassTy &m) {}
|
|
};
|
|
|
|
class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
|
|
static constexpr const char *pyClassName = "StringAttr";
|
|
using PyConcreteAttribute::PyConcreteAttribute;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def_static(
|
|
"get",
|
|
[](PyMlirContext &context, std::string value) {
|
|
MlirAttribute attr =
|
|
mlirStringAttrGet(context.context, value.size(), &value[0]);
|
|
return PyStringAttribute(attr);
|
|
},
|
|
py::keep_alive<0, 1>(), "Gets a uniqued string attribute");
|
|
c.def_static(
|
|
"get_typed",
|
|
[](PyType &type, std::string value) {
|
|
MlirAttribute attr =
|
|
mlirStringAttrTypedGet(type.type, value.size(), &value[0]);
|
|
return PyStringAttribute(attr);
|
|
},
|
|
py::keep_alive<0, 1>(),
|
|
"Gets a uniqued string attribute associated to a type");
|
|
c.def_property_readonly(
|
|
"value",
|
|
[](PyStringAttribute &self) {
|
|
PySinglePartStringAccumulator accum;
|
|
mlirStringAttrGetValue(self.attr, accum.getCallback(),
|
|
accum.getUserData());
|
|
return accum.takeValue();
|
|
},
|
|
"Returns the value of the string attribute");
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
//------------------------------------------------------------------------------
|
|
// Standard type subclasses.
|
|
//------------------------------------------------------------------------------
|
|
|
|
namespace {
|
|
|
|
/// CRTP base classes for Python types that subclass Type and should be
|
|
/// castable from it (i.e. via something like IntegerType(t)).
|
|
template <typename T>
|
|
class PyConcreteType : public PyType {
|
|
public:
|
|
// Derived classes must define statics for:
|
|
// IsAFunctionTy isaFunction
|
|
// const char *pyClassName
|
|
using ClassTy = py::class_<T, PyType>;
|
|
using IsAFunctionTy = int (*)(MlirType);
|
|
|
|
PyConcreteType() = default;
|
|
PyConcreteType(MlirType t) : PyType(t) {}
|
|
PyConcreteType(PyType &orig) : PyType(castFrom(orig)) {}
|
|
|
|
static MlirType castFrom(PyType &orig) {
|
|
if (!T::isaFunction(orig.type)) {
|
|
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
|
|
throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
|
|
T::pyClassName + " (from " +
|
|
origRepr + ")");
|
|
}
|
|
return orig.type;
|
|
}
|
|
|
|
static void bind(py::module &m) {
|
|
auto cls = ClassTy(m, T::pyClassName);
|
|
cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
|
|
T::bindDerived(cls);
|
|
}
|
|
|
|
/// Implemented by derived classes to add methods to the Python subclass.
|
|
static void bindDerived(ClassTy &m) {}
|
|
};
|
|
|
|
class PyIntegerType : public PyConcreteType<PyIntegerType> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
|
|
static constexpr const char *pyClassName = "IntegerType";
|
|
using PyConcreteType::PyConcreteType;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def_static(
|
|
"get_signless",
|
|
[](PyMlirContext &context, unsigned width) {
|
|
MlirType t = mlirIntegerTypeGet(context.context, width);
|
|
return PyIntegerType(t);
|
|
},
|
|
py::keep_alive<0, 1>(), "Create a signless integer type");
|
|
c.def_static(
|
|
"get_signed",
|
|
[](PyMlirContext &context, unsigned width) {
|
|
MlirType t = mlirIntegerTypeSignedGet(context.context, width);
|
|
return PyIntegerType(t);
|
|
},
|
|
py::keep_alive<0, 1>(), "Create a signed integer type");
|
|
c.def_static(
|
|
"get_unsigned",
|
|
[](PyMlirContext &context, unsigned width) {
|
|
MlirType t = mlirIntegerTypeUnsignedGet(context.context, width);
|
|
return PyIntegerType(t);
|
|
},
|
|
py::keep_alive<0, 1>(), "Create an unsigned integer type");
|
|
c.def_property_readonly(
|
|
"width",
|
|
[](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self.type); },
|
|
"Returns the width of the integer type");
|
|
c.def_property_readonly(
|
|
"is_signless",
|
|
[](PyIntegerType &self) -> bool {
|
|
return mlirIntegerTypeIsSignless(self.type);
|
|
},
|
|
"Returns whether this is a signless integer");
|
|
c.def_property_readonly(
|
|
"is_signed",
|
|
[](PyIntegerType &self) -> bool {
|
|
return mlirIntegerTypeIsSigned(self.type);
|
|
},
|
|
"Returns whether this is a signed integer");
|
|
c.def_property_readonly(
|
|
"is_unsigned",
|
|
[](PyIntegerType &self) -> bool {
|
|
return mlirIntegerTypeIsUnsigned(self.type);
|
|
},
|
|
"Returns whether this is an unsigned integer");
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
//------------------------------------------------------------------------------
|
|
// Populates the pybind11 IR submodule.
|
|
//------------------------------------------------------------------------------
|
|
|
|
void mlir::python::populateIRSubmodule(py::module &m) {
|
|
// Mapping of MlirContext
|
|
py::class_<PyMlirContext>(m, "Context")
|
|
.def(py::init<>())
|
|
.def(
|
|
"parse_module",
|
|
[](PyMlirContext &self, const std::string module) {
|
|
auto moduleRef =
|
|
mlirModuleCreateParse(self.context, module.c_str());
|
|
// TODO: Rework error reporting once diagnostic engine is exposed
|
|
// in C API.
|
|
if (mlirModuleIsNull(moduleRef)) {
|
|
throw SetPyError(
|
|
PyExc_ValueError,
|
|
"Unable to parse module assembly (see diagnostics)");
|
|
}
|
|
return PyModule(moduleRef);
|
|
},
|
|
py::keep_alive<0, 1>(), kContextParseDocstring)
|
|
.def(
|
|
"parse_attr",
|
|
[](PyMlirContext &self, std::string attrSpec) {
|
|
MlirAttribute type =
|
|
mlirAttributeParseGet(self.context, attrSpec.c_str());
|
|
// TODO: Rework error reporting once diagnostic engine is exposed
|
|
// in C API.
|
|
if (mlirAttributeIsNull(type)) {
|
|
throw SetPyError(PyExc_ValueError,
|
|
llvm::Twine("Unable to parse attribute: '") +
|
|
attrSpec + "'");
|
|
}
|
|
return PyAttribute(type);
|
|
},
|
|
py::keep_alive<0, 1>())
|
|
.def(
|
|
"parse_type",
|
|
[](PyMlirContext &self, std::string typeSpec) {
|
|
MlirType type = mlirTypeParseGet(self.context, typeSpec.c_str());
|
|
// TODO: Rework error reporting once diagnostic engine is exposed
|
|
// in C API.
|
|
if (mlirTypeIsNull(type)) {
|
|
throw SetPyError(PyExc_ValueError,
|
|
llvm::Twine("Unable to parse type: '") +
|
|
typeSpec + "'");
|
|
}
|
|
return PyType(type);
|
|
},
|
|
py::keep_alive<0, 1>(), kContextParseType);
|
|
|
|
// Mapping of Module
|
|
py::class_<PyModule>(m, "Module")
|
|
.def(
|
|
"dump",
|
|
[](PyModule &self) {
|
|
mlirOperationDump(mlirModuleGetOperation(self.module));
|
|
},
|
|
kDumpDocstring)
|
|
.def(
|
|
"__str__",
|
|
[](PyModule &self) {
|
|
auto operation = mlirModuleGetOperation(self.module);
|
|
PyPrintAccumulator printAccum;
|
|
mlirOperationPrint(operation, printAccum.getCallback(),
|
|
printAccum.getUserData());
|
|
return printAccum.join();
|
|
},
|
|
kOperationStrDunderDocstring);
|
|
|
|
// Mapping of Type.
|
|
py::class_<PyAttribute>(m, "Attribute")
|
|
.def(
|
|
"get_named",
|
|
[](PyAttribute &self, std::string name) {
|
|
return PyNamedAttribute(self.attr, std::move(name));
|
|
},
|
|
py::keep_alive<0, 1>(), "Binds a name to the attribute")
|
|
.def("__eq__",
|
|
[](PyAttribute &self, py::object &other) {
|
|
try {
|
|
PyAttribute otherAttribute = other.cast<PyAttribute>();
|
|
return self == otherAttribute;
|
|
} catch (std::exception &e) {
|
|
return false;
|
|
}
|
|
})
|
|
.def(
|
|
"dump", [](PyAttribute &self) { mlirAttributeDump(self.attr); },
|
|
kDumpDocstring)
|
|
.def(
|
|
"__str__",
|
|
[](PyAttribute &self) {
|
|
PyPrintAccumulator printAccum;
|
|
mlirAttributePrint(self.attr, printAccum.getCallback(),
|
|
printAccum.getUserData());
|
|
return printAccum.join();
|
|
},
|
|
kTypeStrDunderDocstring)
|
|
.def("__repr__", [](PyAttribute &self) {
|
|
// Generally, assembly formats are not printed for __repr__ because
|
|
// this can cause exceptionally long debug output and exceptions.
|
|
// However, attribute values are generally considered useful and are
|
|
// printed. This may need to be re-evaluated if debug dumps end up
|
|
// being excessive.
|
|
PyPrintAccumulator printAccum;
|
|
printAccum.parts.append("Attribute(");
|
|
mlirAttributePrint(self.attr, printAccum.getCallback(),
|
|
printAccum.getUserData());
|
|
printAccum.parts.append(")");
|
|
return printAccum.join();
|
|
});
|
|
|
|
py::class_<PyNamedAttribute>(m, "NamedAttribute")
|
|
.def("__repr__",
|
|
[](PyNamedAttribute &self) {
|
|
PyPrintAccumulator printAccum;
|
|
printAccum.parts.append("NamedAttribute(");
|
|
printAccum.parts.append(self.namedAttr.name);
|
|
printAccum.parts.append("=");
|
|
mlirAttributePrint(self.namedAttr.attribute,
|
|
printAccum.getCallback(),
|
|
printAccum.getUserData());
|
|
printAccum.parts.append(")");
|
|
return printAccum.join();
|
|
})
|
|
.def_property_readonly(
|
|
"name",
|
|
[](PyNamedAttribute &self) {
|
|
return py::str(self.namedAttr.name, strlen(self.namedAttr.name));
|
|
},
|
|
"The name of the NamedAttribute binding")
|
|
.def_property_readonly(
|
|
"attr",
|
|
[](PyNamedAttribute &self) {
|
|
return PyAttribute(self.namedAttr.attribute);
|
|
},
|
|
py::keep_alive<0, 1>(),
|
|
"The underlying generic attribute of the NamedAttribute binding");
|
|
|
|
// Standard attribute bindings.
|
|
PyStringAttribute::bind(m);
|
|
|
|
// Mapping of Type.
|
|
py::class_<PyType>(m, "Type")
|
|
.def("__eq__",
|
|
[](PyType &self, py::object &other) {
|
|
try {
|
|
PyType otherType = other.cast<PyType>();
|
|
return self == otherType;
|
|
} catch (std::exception &e) {
|
|
return false;
|
|
}
|
|
})
|
|
.def(
|
|
"dump", [](PyType &self) { mlirTypeDump(self.type); }, kDumpDocstring)
|
|
.def(
|
|
"__str__",
|
|
[](PyType &self) {
|
|
PyPrintAccumulator printAccum;
|
|
mlirTypePrint(self.type, printAccum.getCallback(),
|
|
printAccum.getUserData());
|
|
return printAccum.join();
|
|
},
|
|
kTypeStrDunderDocstring)
|
|
.def("__repr__", [](PyType &self) {
|
|
// Generally, assembly formats are not printed for __repr__ because
|
|
// this can cause exceptionally long debug output and exceptions.
|
|
// However, types are an exception as they typically have compact
|
|
// assembly forms and printing them is useful.
|
|
PyPrintAccumulator printAccum;
|
|
printAccum.parts.append("Type(");
|
|
mlirTypePrint(self.type, printAccum.getCallback(),
|
|
printAccum.getUserData());
|
|
printAccum.parts.append(")");
|
|
return printAccum.join();
|
|
});
|
|
|
|
// Standard type bindings.
|
|
PyIntegerType::bind(m);
|
|
}
|