mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-03-05 00:48:08 +00:00
[MLIR][python bindings] Add TypeCaster for returning refined types from python APIs
depends on D150839 This diff uses `MlirTypeID` to register `TypeCaster`s (i.e., `[](PyType pyType) -> DerivedTy { return pyType; }`) for all concrete types (i.e., `PyConcrete<...>`) that are then queried for (by `MlirTypeID`) and called in `struct type_caster<MlirType>::cast`. The result is that anywhere an `MlirType mlirType` is returned from a python binding, that `mlirType` is automatically cast to the correct concrete type. For example: ``` c0 = arith.ConstantOp(f32, 0.0) # CHECK: F32Type(f32) print(repr(c0.result.type)) unranked_tensor_type = UnrankedTensorType.get(f32) unranked_tensor = tensor.FromElementsOp(unranked_tensor_type, [c0]).result # CHECK: UnrankedTensorType print(type(unranked_tensor.type).__name__) # CHECK: UnrankedTensorType(tensor<*xf32>) print(repr(unranked_tensor.type)) ``` This functionality immediately extends to typed attributes (i.e., `attr.type`). The diff also implements similar functionality for `mlir_type_subclass`es but in a slightly different way - for such types (which have no cpp corresponding `class` or `struct`) the user must provide a type caster in python (similar to how `AttrBuilder` works) or in cpp as a `py::cpp_function`. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D150927
This commit is contained in:
parent
5310be521d
commit
bfb1ba7526
@ -107,6 +107,23 @@
|
||||
* delineated). */
|
||||
#define MLIR_PYTHON_CAPI_FACTORY_ATTR "_CAPICreate"
|
||||
|
||||
/** Attribute on MLIR Python objects that expose a function for downcasting the
|
||||
* corresponding Python object to a subclass if the object is in fact a subclass
|
||||
* (Concrete or mlir_type_subclass) of ir.Type. The signature of the function
|
||||
* is: def maybe_downcast(self) -> object where the resulting object will
|
||||
* (possibly) be an instance of the subclass.
|
||||
*/
|
||||
#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR "maybe_downcast"
|
||||
|
||||
/** Attribute on main C extension module (_mlir) that corresponds to the
|
||||
* type caster registration binding. The signature of the function is:
|
||||
* def register_type_caster(MlirTypeID mlirTypeID, py::function typeCaster,
|
||||
* bool replace)
|
||||
* where replace indicates the typeCaster should replace any existing registered
|
||||
* type casters (such as those for upstream ConcreteTypes).
|
||||
*/
|
||||
#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR "register_type_caster"
|
||||
|
||||
/// Gets a void* from a wrapped struct. Needed because const cast is different
|
||||
/// between C/C++.
|
||||
#ifdef __cplusplus
|
||||
|
@ -33,6 +33,8 @@ MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformOperationType(MlirType type);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformOperationTypeGetTypeID(void);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirType
|
||||
mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName);
|
||||
|
||||
|
@ -825,6 +825,9 @@ MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type);
|
||||
/// Gets the type ID of the type.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type);
|
||||
|
||||
/// Gets the dialect a type belongs to.
|
||||
MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type);
|
||||
|
||||
/// Checks whether a type is null.
|
||||
static inline bool mlirTypeIsNull(MlirType type) { return !type.ptr; }
|
||||
|
||||
|
@ -28,6 +28,7 @@
|
||||
#include "llvm/ADT/Twine.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
|
||||
// Raw CAPI type casters need to be declared before use, so always include them
|
||||
// first.
|
||||
@ -272,6 +273,7 @@ struct type_caster<MlirType> {
|
||||
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
|
||||
.attr("Type")
|
||||
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
|
||||
.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
|
||||
.release();
|
||||
}
|
||||
};
|
||||
@ -424,20 +426,24 @@ public:
|
||||
class mlir_type_subclass : public pure_subclass {
|
||||
public:
|
||||
using IsAFunctionTy = bool (*)(MlirType);
|
||||
using GetTypeIDFunctionTy = MlirTypeID (*)();
|
||||
|
||||
/// Subclasses by looking up the super-class dynamically.
|
||||
mlir_type_subclass(py::handle scope, const char *typeClassName,
|
||||
IsAFunctionTy isaFunction)
|
||||
IsAFunctionTy isaFunction,
|
||||
GetTypeIDFunctionTy getTypeIDFunction = nullptr)
|
||||
: mlir_type_subclass(
|
||||
scope, typeClassName, isaFunction,
|
||||
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type")) {}
|
||||
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type"),
|
||||
getTypeIDFunction) {}
|
||||
|
||||
/// Subclasses with a provided mlir.ir.Type super-class. This must
|
||||
/// be used if the subclass is being defined in the same extension module
|
||||
/// as the mlir.ir class (otherwise, it will trigger a recursive
|
||||
/// initialization).
|
||||
mlir_type_subclass(py::handle scope, const char *typeClassName,
|
||||
IsAFunctionTy isaFunction, const py::object &superCls)
|
||||
IsAFunctionTy isaFunction, const py::object &superCls,
|
||||
GetTypeIDFunctionTy getTypeIDFunction = nullptr)
|
||||
: pure_subclass(scope, typeClassName, superCls) {
|
||||
// Casting constructor. Note that it hard, if not impossible, to properly
|
||||
// call chain to parent `__init__` in pybind11 due to its special handling
|
||||
@ -471,6 +477,19 @@ public:
|
||||
"isinstance",
|
||||
[isaFunction](MlirType other) { return isaFunction(other); },
|
||||
py::arg("other_type"));
|
||||
def("__repr__", [superCls, captureTypeName](py::object self) {
|
||||
return py::repr(superCls(self))
|
||||
.attr("replace")(superCls.attr("__name__"), captureTypeName);
|
||||
});
|
||||
if (getTypeIDFunction) {
|
||||
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
|
||||
.attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
|
||||
getTypeIDFunction(),
|
||||
pybind11::cpp_function(
|
||||
[thisClass = thisClass](const py::object &mlirType) {
|
||||
return thisClass(mlirType);
|
||||
}));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -44,4 +44,25 @@ inline mlir::LogicalResult unwrap(MlirLogicalResult res) {
|
||||
DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
|
||||
DEFINE_C_API_PTR_METHODS(MlirTypeIDAllocator, mlir::TypeIDAllocator)
|
||||
|
||||
namespace llvm {
|
||||
|
||||
template <>
|
||||
struct DenseMapInfo<MlirTypeID> {
|
||||
static inline MlirTypeID getEmptyKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return mlirTypeIDCreate(pointer);
|
||||
}
|
||||
static inline MlirTypeID getTombstoneKey() {
|
||||
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
||||
return mlirTypeIDCreate(pointer);
|
||||
}
|
||||
static inline unsigned getHashValue(const MlirTypeID &val) {
|
||||
return mlirTypeIDHashValue(val);
|
||||
}
|
||||
static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) {
|
||||
return mlirTypeIDEqual(lhs, rhs);
|
||||
}
|
||||
};
|
||||
} // namespace llvm
|
||||
|
||||
#endif // MLIR_CAPI_SUPPORT_H
|
||||
|
@ -36,7 +36,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
|
||||
//===-------------------------------------------------------------------===//
|
||||
|
||||
auto operationType =
|
||||
mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType);
|
||||
mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType,
|
||||
mlirTransformOperationTypeGetTypeID);
|
||||
operationType.def_classmethod(
|
||||
"get",
|
||||
[](py::object cls, const std::string &operationName, MlirContext ctx) {
|
||||
|
@ -9,12 +9,15 @@
|
||||
#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
|
||||
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <optional>
|
||||
|
||||
#include "PybindUtils.h"
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir/CAPI/Support.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
|
||||
@ -54,16 +57,18 @@ public:
|
||||
/// entities.
|
||||
void loadDialectModule(llvm::StringRef dialectNamespace);
|
||||
|
||||
/// Decorator for registering a custom Dialect class. The class object must
|
||||
/// have a DIALECT_NAMESPACE attribute.
|
||||
pybind11::object registerDialectDecorator(pybind11::object pyClass);
|
||||
|
||||
/// Adds a user-friendly Attribute builder.
|
||||
/// Raises an exception if the mapping already exists.
|
||||
/// This is intended to be called by implementation code.
|
||||
void registerAttributeBuilder(const std::string &attributeKind,
|
||||
pybind11::function pyFunc);
|
||||
|
||||
/// Adds a user-friendly type caster. Raises an exception if the mapping
|
||||
/// already exists and replace == false. This is intended to be called by
|
||||
/// implementation code.
|
||||
void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
|
||||
bool replace = false);
|
||||
|
||||
/// Adds a concrete implementation dialect class.
|
||||
/// Raises an exception if the mapping already exists.
|
||||
/// This is intended to be called by implementation code.
|
||||
@ -80,6 +85,10 @@ public:
|
||||
std::optional<pybind11::function>
|
||||
lookupAttributeBuilder(const std::string &attributeKind);
|
||||
|
||||
/// Returns the custom type caster for MlirTypeID mlirTypeID.
|
||||
std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
|
||||
MlirDialect dialect);
|
||||
|
||||
/// Looks up a registered dialect class by namespace. Note that this may
|
||||
/// trigger loading of the defining module and can arbitrarily re-enter.
|
||||
std::optional<pybind11::object>
|
||||
@ -101,6 +110,10 @@ private:
|
||||
llvm::StringMap<pybind11::object> operationClassMap;
|
||||
/// Map of attribute ODS name to custom builder.
|
||||
llvm::StringMap<pybind11::object> attributeBuilderMap;
|
||||
/// Map of MlirTypeID to custom type caster.
|
||||
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
|
||||
/// Cache for map of MlirTypeID to custom type caster.
|
||||
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMapCache;
|
||||
|
||||
/// Set of dialect namespaces that we have attempted to import implementation
|
||||
/// modules for.
|
||||
|
@ -15,6 +15,7 @@
|
||||
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace mlir;
|
||||
@ -1023,8 +1024,7 @@ public:
|
||||
py::arg("value"), py::arg("context") = py::none(),
|
||||
"Gets a uniqued Type attribute");
|
||||
c.def_property_readonly("value", [](PyTypeAttribute &self) {
|
||||
return PyType(self.getContext()->getRef(),
|
||||
mlirTypeAttrGetValue(self.get()));
|
||||
return mlirTypeAttrGetValue(self.get());
|
||||
});
|
||||
}
|
||||
};
|
||||
|
@ -25,6 +25,7 @@
|
||||
#include <utility>
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
using namespace mlir;
|
||||
using namespace mlir::python;
|
||||
|
||||
@ -2121,13 +2122,12 @@ public:
|
||||
|
||||
/// Returns the list of types of the values held by container.
|
||||
template <typename Container>
|
||||
static std::vector<PyType> getValueTypes(Container &container,
|
||||
PyMlirContextRef &context) {
|
||||
std::vector<PyType> result;
|
||||
static std::vector<MlirType> getValueTypes(Container &container,
|
||||
PyMlirContextRef &context) {
|
||||
std::vector<MlirType> result;
|
||||
result.reserve(container.size());
|
||||
for (int i = 0, e = container.size(); i < e; ++i) {
|
||||
result.push_back(
|
||||
PyType(context, mlirValueGetType(container.getElement(i).get())));
|
||||
result.push_back(mlirValueGetType(container.getElement(i).get()));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@ -3148,11 +3148,8 @@ void mlir::python::populateIRCore(py::module &m) {
|
||||
"context",
|
||||
[](PyAttribute &self) { return self.getContext().getObject(); },
|
||||
"Context that owns the Attribute")
|
||||
.def_property_readonly("type",
|
||||
[](PyAttribute &self) {
|
||||
return PyType(self.getContext()->getRef(),
|
||||
mlirAttributeGetType(self));
|
||||
})
|
||||
.def_property_readonly(
|
||||
"type", [](PyAttribute &self) { return mlirAttributeGetType(self); })
|
||||
.def(
|
||||
"get_named",
|
||||
[](PyAttribute &self, std::string name) {
|
||||
@ -3247,7 +3244,7 @@ void mlir::python::populateIRCore(py::module &m) {
|
||||
mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
|
||||
if (mlirTypeIsNull(type))
|
||||
throw MLIRError("Unable to parse type", errors.take());
|
||||
return PyType(context->getRef(), type);
|
||||
return type;
|
||||
},
|
||||
py::arg("asm"), py::arg("context") = py::none(),
|
||||
kContextParseTypeDocstring)
|
||||
@ -3284,6 +3281,18 @@ void mlir::python::populateIRCore(py::module &m) {
|
||||
printAccum.parts.append(")");
|
||||
return printAccum.join();
|
||||
})
|
||||
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
|
||||
[](PyType &self) {
|
||||
MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
|
||||
assert(!mlirTypeIDIsNull(mlirTypeID) &&
|
||||
"mlirTypeID was expected to be non-null.");
|
||||
std::optional<pybind11::function> typeCaster =
|
||||
PyGlobals::get().lookupTypeCaster(mlirTypeID,
|
||||
mlirTypeGetDialect(self));
|
||||
if (!typeCaster)
|
||||
return py::cast(self);
|
||||
return typeCaster.value()(self);
|
||||
})
|
||||
.def_property_readonly("typeid", [](PyType &self) -> MlirTypeID {
|
||||
MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
|
||||
if (!mlirTypeIDIsNull(mlirTypeID))
|
||||
@ -3387,12 +3396,8 @@ void mlir::python::populateIRCore(py::module &m) {
|
||||
return printAccum.join();
|
||||
},
|
||||
py::arg("use_local_scope") = false, kGetNameAsOperand)
|
||||
.def_property_readonly("type",
|
||||
[](PyValue &self) {
|
||||
return PyType(
|
||||
self.getParentOperation()->getContext(),
|
||||
mlirValueGetType(self.get()));
|
||||
})
|
||||
.def_property_readonly(
|
||||
"type", [](PyValue &self) { return mlirValueGetType(self.get()); })
|
||||
.def(
|
||||
"replace_all_uses_with",
|
||||
[](PyValue &self, PyValue &with) {
|
||||
|
@ -321,11 +321,7 @@ public:
|
||||
py::module_local())
|
||||
.def_property_readonly(
|
||||
"element_type",
|
||||
[](PyShapedTypeComponents &self) {
|
||||
return PyType(PyMlirContext::forContext(
|
||||
mlirTypeGetContext(self.elementType)),
|
||||
self.elementType);
|
||||
},
|
||||
[](PyShapedTypeComponents &self) { return self.elementType; },
|
||||
"Returns the element type of the shaped type components.")
|
||||
.def_static(
|
||||
"get",
|
||||
|
@ -14,6 +14,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "mlir-c/Bindings/Python/Interop.h"
|
||||
#include "mlir-c/Support.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace mlir;
|
||||
@ -72,6 +73,15 @@ void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
|
||||
found = std::move(pyFunc);
|
||||
}
|
||||
|
||||
void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
|
||||
pybind11::function typeCaster,
|
||||
bool replace) {
|
||||
pybind11::object &found = typeCasterMap[mlirTypeID];
|
||||
if (found && !found.is_none() && !replace)
|
||||
throw std::runtime_error("Type caster is already registered");
|
||||
found = std::move(typeCaster);
|
||||
}
|
||||
|
||||
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
|
||||
py::object pyClass) {
|
||||
py::object &found = dialectClassMap[dialectNamespace];
|
||||
@ -110,6 +120,39 @@ PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
|
||||
MlirDialect dialect) {
|
||||
{
|
||||
// Fast match against the class map first (common case).
|
||||
const auto foundIt = typeCasterMapCache.find(mlirTypeID);
|
||||
if (foundIt != typeCasterMapCache.end()) {
|
||||
if (foundIt->second.is_none())
|
||||
return std::nullopt;
|
||||
assert(foundIt->second && "py::function is defined");
|
||||
return foundIt->second;
|
||||
}
|
||||
}
|
||||
|
||||
// Not found. Load the dialect namespace.
|
||||
loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
|
||||
|
||||
// Attempt to find from the canonical map and cache.
|
||||
{
|
||||
const auto foundIt = typeCasterMap.find(mlirTypeID);
|
||||
if (foundIt != typeCasterMap.end()) {
|
||||
if (foundIt->second.is_none())
|
||||
return std::nullopt;
|
||||
assert(foundIt->second && "py::object is defined");
|
||||
// Positive cache.
|
||||
typeCasterMapCache[mlirTypeID] = foundIt->second;
|
||||
return foundIt->second;
|
||||
}
|
||||
// Negative cache.
|
||||
typeCasterMap[mlirTypeID] = py::none();
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<py::object>
|
||||
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
|
||||
loadDialectModule(dialectNamespace);
|
||||
@ -164,4 +207,5 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
|
||||
void PyGlobals::clearImportCache() {
|
||||
loadedDialectModulesCache.clear();
|
||||
operationClassMapCache.clear();
|
||||
typeCasterMapCache.clear();
|
||||
}
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "Globals.h"
|
||||
#include "PybindUtils.h"
|
||||
|
||||
#include "mlir-c/AffineExpr.h"
|
||||
@ -868,9 +869,7 @@ public:
|
||||
|
||||
PyConcreteType() = default;
|
||||
PyConcreteType(PyMlirContextRef contextRef, MlirType t)
|
||||
: BaseTy(std::move(contextRef), t) {
|
||||
pybind11::implicitly_convertible<PyType, DerivedTy>();
|
||||
}
|
||||
: BaseTy(std::move(contextRef), t) {}
|
||||
PyConcreteType(PyType &orig)
|
||||
: PyConcreteType(orig.getContext(), castFrom(orig)) {}
|
||||
|
||||
@ -914,6 +913,13 @@ public:
|
||||
return printAccum.join();
|
||||
});
|
||||
|
||||
if (DerivedTy::getTypeIdFunction) {
|
||||
PyGlobals::get().registerTypeCaster(
|
||||
DerivedTy::getTypeIdFunction(),
|
||||
pybind11::cpp_function(
|
||||
[](PyType pyType) -> DerivedTy { return pyType; }));
|
||||
}
|
||||
|
||||
DerivedTy::bindDerived(cls);
|
||||
}
|
||||
|
||||
@ -1009,9 +1015,8 @@ public:
|
||||
return DerivedTy::isaFunction(otherAttr);
|
||||
},
|
||||
pybind11::arg("other"));
|
||||
cls.def_property_readonly("type", [](PyAttribute &attr) {
|
||||
return PyType(attr.getContext(), mlirAttributeGetType(attr));
|
||||
});
|
||||
cls.def_property_readonly(
|
||||
"type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); });
|
||||
DerivedTy::bindDerived(cls);
|
||||
}
|
||||
|
||||
|
@ -334,10 +334,7 @@ public:
|
||||
"Create a complex type");
|
||||
c.def_property_readonly(
|
||||
"element_type",
|
||||
[](PyComplexType &self) -> PyType {
|
||||
MlirType t = mlirComplexTypeGetElementType(self);
|
||||
return PyType(self.getContext(), t);
|
||||
},
|
||||
[](PyComplexType &self) { return mlirComplexTypeGetElementType(self); },
|
||||
"Returns element type.");
|
||||
}
|
||||
};
|
||||
@ -351,10 +348,7 @@ public:
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_property_readonly(
|
||||
"element_type",
|
||||
[](PyShapedType &self) {
|
||||
MlirType t = mlirShapedTypeGetElementType(self);
|
||||
return PyType(self.getContext(), t);
|
||||
},
|
||||
[](PyShapedType &self) { return mlirShapedTypeGetElementType(self); },
|
||||
"Returns the element type of the shaped type.");
|
||||
c.def_property_readonly(
|
||||
"has_rank",
|
||||
@ -641,9 +635,8 @@ public:
|
||||
"Create a tuple type");
|
||||
c.def(
|
||||
"get_type",
|
||||
[](PyTupleType &self, intptr_t pos) -> PyType {
|
||||
MlirType t = mlirTupleTypeGetType(self, pos);
|
||||
return PyType(self.getContext(), t);
|
||||
[](PyTupleType &self, intptr_t pos) {
|
||||
return mlirTupleTypeGetType(self, pos);
|
||||
},
|
||||
py::arg("pos"), "Returns the pos-th type in the tuple type.");
|
||||
c.def_property_readonly(
|
||||
@ -686,7 +679,7 @@ public:
|
||||
py::list types;
|
||||
for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
|
||||
++i) {
|
||||
types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i)));
|
||||
types.append(mlirFunctionTypeGetInput(t, i));
|
||||
}
|
||||
return types;
|
||||
},
|
||||
@ -698,8 +691,7 @@ public:
|
||||
py::list types;
|
||||
for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
|
||||
++i) {
|
||||
types.append(
|
||||
PyType(contextRef, mlirFunctionTypeGetResult(self, i)));
|
||||
types.append(mlirFunctionTypeGetResult(self, i));
|
||||
}
|
||||
return types;
|
||||
},
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace mlir;
|
||||
using namespace py::literals;
|
||||
using namespace mlir::python;
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
@ -35,12 +36,12 @@ PYBIND11_MODULE(_mlir, m) {
|
||||
self.getDialectSearchPrefixes().push_back(std::move(moduleName));
|
||||
self.clearImportCache();
|
||||
},
|
||||
py::arg("module_name"))
|
||||
"module_name"_a)
|
||||
.def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
|
||||
py::arg("dialect_namespace"), py::arg("dialect_class"),
|
||||
"dialect_namespace"_a, "dialect_class"_a,
|
||||
"Testing hook for directly registering a dialect")
|
||||
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
|
||||
py::arg("operation_name"), py::arg("operation_class"),
|
||||
"operation_name"_a, "operation_class"_a,
|
||||
"Testing hook for directly registering an operation");
|
||||
|
||||
// Aside from making the globals accessible to python, having python manage
|
||||
@ -58,11 +59,11 @@ PYBIND11_MODULE(_mlir, m) {
|
||||
PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
|
||||
return pyClass;
|
||||
},
|
||||
py::arg("dialect_class"),
|
||||
"dialect_class"_a,
|
||||
"Class decorator for registering a custom Dialect wrapper");
|
||||
m.def(
|
||||
"register_operation",
|
||||
[](py::object dialectClass) -> py::cpp_function {
|
||||
[](const py::object &dialectClass) -> py::cpp_function {
|
||||
return py::cpp_function(
|
||||
[dialectClass](py::object opClass) -> py::object {
|
||||
std::string operationName =
|
||||
@ -75,9 +76,17 @@ PYBIND11_MODULE(_mlir, m) {
|
||||
return opClass;
|
||||
});
|
||||
},
|
||||
py::arg("dialect_class"),
|
||||
"dialect_class"_a,
|
||||
"Produce a class decorator for registering an Operation class as part of "
|
||||
"a dialect");
|
||||
m.def(
|
||||
MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
|
||||
[](MlirTypeID mlirTypeID, py::function typeCaster, bool replace) {
|
||||
PyGlobals::get().registerTypeCaster(mlirTypeID, std::move(typeCaster),
|
||||
replace);
|
||||
},
|
||||
"typeid"_a, "type_caster"_a, "replace"_a = false,
|
||||
"Register a type caster for casting MLIR types to custom user types.");
|
||||
|
||||
// Define and populate IR submodule.
|
||||
auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
|
||||
|
@ -37,6 +37,10 @@ bool mlirTypeIsATransformOperationType(MlirType type) {
|
||||
return isa<transform::OperationType>(unwrap(type));
|
||||
}
|
||||
|
||||
MlirTypeID mlirTransformOperationTypeGetTypeID(void) {
|
||||
return wrap(transform::OperationType::getTypeID());
|
||||
}
|
||||
|
||||
MlirType mlirTransformOperationTypeGet(MlirContext ctx,
|
||||
MlirStringRef operationName) {
|
||||
return wrap(
|
||||
|
@ -324,6 +324,10 @@ MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
|
||||
return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType)));
|
||||
}
|
||||
|
||||
MlirType mlirUnrankedTensorTypeGetElementType(MlirType type) {
|
||||
return wrap(llvm::cast<UnrankedTensorType>(unwrap(type)).getElementType());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Ranked / Unranked MemRef type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -832,6 +832,10 @@ MlirTypeID mlirTypeGetTypeID(MlirType type) {
|
||||
return wrap(unwrap(type).getTypeID());
|
||||
}
|
||||
|
||||
MlirDialect mlirTypeGetDialect(MlirType type) {
|
||||
return wrap(&unwrap(type).getDialect());
|
||||
}
|
||||
|
||||
bool mlirTypeEqual(MlirType t1, MlirType t2) {
|
||||
return unwrap(t1) == unwrap(t2);
|
||||
}
|
||||
|
@ -23,7 +23,6 @@ bool mlirStringRefEqual(MlirStringRef string, MlirStringRef other) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeID API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MlirTypeID mlirTypeIDCreate(const void *ptr) {
|
||||
assert(reinterpret_cast<uintptr_t>(ptr) % 8 == 0 &&
|
||||
"ptr must be 8 byte aligned");
|
||||
|
@ -3,7 +3,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._python_test_ops_gen import *
|
||||
from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestTensorType
|
||||
from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestIntegerRankedTensorType
|
||||
|
||||
|
||||
def register_python_test_dialect(context, load=True):
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
from ._mlir_libs._mlir.ir import *
|
||||
from ._mlir_libs._mlir.ir import _GlobalDebug
|
||||
from ._mlir_libs._mlir import register_type_caster
|
||||
|
||||
|
||||
# Convenience decorator for registering user-friendly Attribute builders.
|
||||
|
@ -369,9 +369,9 @@ def testTensorValue():
|
||||
|
||||
# Classes of custom types that inherit from concrete types should have
|
||||
# static_typeid
|
||||
assert isinstance(test.TestTensorType.static_typeid, TypeID)
|
||||
assert isinstance(test.TestIntegerRankedTensorType.static_typeid, TypeID)
|
||||
# And it should be equal to the in-tree concrete type
|
||||
assert test.TestTensorType.static_typeid == t.type.typeid
|
||||
assert test.TestIntegerRankedTensorType.static_typeid == t.type.typeid
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: inferReturnTypeComponents
|
||||
@ -424,3 +424,46 @@ def inferReturnTypeComponents():
|
||||
print("rank:", shaped_type_components.rank)
|
||||
print("element type:", shaped_type_components.element_type)
|
||||
print("shape:", shaped_type_components.shape)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testCustomTypeTypeCaster
|
||||
@run
|
||||
def testCustomTypeTypeCaster():
|
||||
with Context() as ctx, Location.unknown():
|
||||
test.register_python_test_dialect(ctx)
|
||||
|
||||
a = test.TestType.get()
|
||||
assert a.typeid is not None
|
||||
|
||||
b = Type.parse("!python_test.test_type")
|
||||
# CHECK: !python_test.test_type
|
||||
print(b)
|
||||
# CHECK: TestType(!python_test.test_type)
|
||||
print(repr(b))
|
||||
|
||||
c = test.TestIntegerRankedTensorType.get([10, 10], 5)
|
||||
# CHECK: tensor<10x10xi5>
|
||||
print(c)
|
||||
# CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
|
||||
print(repr(c))
|
||||
|
||||
# CHECK: Type caster is already registered
|
||||
try:
|
||||
|
||||
def type_caster(pytype):
|
||||
return test.TestIntegerRankedTensorType(pytype)
|
||||
|
||||
register_type_caster(c.typeid, type_caster)
|
||||
except RuntimeError as e:
|
||||
print(e)
|
||||
|
||||
def type_caster(pytype):
|
||||
return test.TestIntegerRankedTensorType(pytype)
|
||||
|
||||
register_type_caster(c.typeid, type_caster, replace=True)
|
||||
|
||||
d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
|
||||
# CHECK: tensor<10x10xi5>
|
||||
print(d.type)
|
||||
# CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
|
||||
print(repr(d.type))
|
||||
|
@ -553,3 +553,42 @@ def testStridedLayoutAttr():
|
||||
print(f"rank: {len(attr.strides)}")
|
||||
# CHECK: strides are dynamic: [True, True, True]
|
||||
print(f"strides are dynamic: {[s == dynamic for s in attr.strides]}")
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testConcreteTypesRoundTrip
|
||||
@run
|
||||
def testConcreteTypesRoundTrip():
|
||||
with Context(), Location.unknown():
|
||||
|
||||
def print_item(attr):
|
||||
print(repr(attr.type))
|
||||
|
||||
# CHECK: F32Type(f32)
|
||||
print_item(Attribute.parse("42.0 : f32"))
|
||||
# CHECK: F32Type(f32)
|
||||
print_item(FloatAttr.get_f32(42.0))
|
||||
# CHECK: IntegerType(i64)
|
||||
print_item(IntegerAttr.get(IntegerType.get_signless(64), 42))
|
||||
|
||||
def print_container_item(attr_asm):
|
||||
attr = DenseElementsAttr(Attribute.parse(attr_asm))
|
||||
print(repr(attr.type))
|
||||
print(repr(attr.type.element_type))
|
||||
|
||||
# CHECK: RankedTensorType(tensor<i16>)
|
||||
# CHECK: IntegerType(i16)
|
||||
print_container_item("dense<123> : tensor<i16>")
|
||||
|
||||
# CHECK: RankedTensorType(tensor<f64>)
|
||||
# CHECK: F64Type(f64)
|
||||
print_container_item("dense<1.0> : tensor<f64>")
|
||||
|
||||
raw = Attribute.parse("vector<4xf32>")
|
||||
# CHECK: attr: vector<4xf32>
|
||||
print("attr:", raw)
|
||||
type_attr = TypeAttr(raw)
|
||||
|
||||
# CHECK: VectorType(vector<4xf32>)
|
||||
print(repr(type_attr.value))
|
||||
# CHECK: F32Type(f32)
|
||||
print(repr(type_attr.value.element_type))
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
import gc
|
||||
from mlir.ir import *
|
||||
from mlir.dialects import arith, tensor, func, memref
|
||||
|
||||
|
||||
def run(f):
|
||||
@ -382,15 +383,15 @@ def testMemRefType():
|
||||
f32 = F32Type.get()
|
||||
shape = [2, 3]
|
||||
loc = Location.unknown()
|
||||
memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
|
||||
memref_f32 = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
|
||||
# CHECK: memref type: memref<2x3xf32, 2>
|
||||
print("memref type:", memref)
|
||||
print("memref type:", memref_f32)
|
||||
# CHECK: memref layout: affine_map<(d0, d1) -> (d0, d1)>
|
||||
print("memref layout:", memref.layout)
|
||||
print("memref layout:", memref_f32.layout)
|
||||
# CHECK: memref affine map: (d0, d1) -> (d0, d1)
|
||||
print("memref affine map:", memref.affine_map)
|
||||
print("memref affine map:", memref_f32.affine_map)
|
||||
# CHECK: memory space: 2
|
||||
print("memory space:", memref.memory_space)
|
||||
print("memory space:", memref_f32.memory_space)
|
||||
|
||||
layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0]))
|
||||
memref_layout = MemRefType.get(shape, f32, layout=layout)
|
||||
@ -413,7 +414,7 @@ def testMemRefType():
|
||||
else:
|
||||
print("Exception not produced")
|
||||
|
||||
assert memref.shape == shape
|
||||
assert memref_f32.shape == shape
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testUnrankedMemRefType
|
||||
@ -482,9 +483,9 @@ def testFunctionType():
|
||||
input_types = [IntegerType.get_signless(32), IntegerType.get_signless(16)]
|
||||
result_types = [IndexType.get()]
|
||||
func = FunctionType.get(input_types, result_types)
|
||||
# CHECK: INPUTS: [Type(i32), Type(i16)]
|
||||
# CHECK: INPUTS: [IntegerType(i32), IntegerType(i16)]
|
||||
print("INPUTS:", func.inputs)
|
||||
# CHECK: RESULTS: [Type(index)]
|
||||
# CHECK: RESULTS: [IndexType(index)]
|
||||
print("RESULTS:", func.results)
|
||||
|
||||
|
||||
@ -599,3 +600,130 @@ def testTypeIDs():
|
||||
vector_type = Type.parse("vector<2x3xf32>")
|
||||
# CHECK: True
|
||||
print(ShapedType(vector_type).typeid == vector_type.typeid)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testConcreteTypesRoundTrip
|
||||
@run
|
||||
def testConcreteTypesRoundTrip():
|
||||
with Context() as ctx, Location.unknown():
|
||||
ctx.allow_unregistered_dialects = True
|
||||
|
||||
def print_downcasted(typ):
|
||||
downcasted = Type(typ).maybe_downcast()
|
||||
print(type(downcasted).__name__)
|
||||
print(repr(downcasted))
|
||||
|
||||
# CHECK: F16Type
|
||||
# CHECK: F16Type(f16)
|
||||
print_downcasted(F16Type.get())
|
||||
# CHECK: F32Type
|
||||
# CHECK: F32Type(f32)
|
||||
print_downcasted(F32Type.get())
|
||||
# CHECK: F64Type
|
||||
# CHECK: F64Type(f64)
|
||||
print_downcasted(F64Type.get())
|
||||
# CHECK: Float8E4M3B11FNUZType
|
||||
# CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
|
||||
print_downcasted(Float8E4M3B11FNUZType.get())
|
||||
# CHECK: Float8E4M3FNType
|
||||
# CHECK: Float8E4M3FNType(f8E4M3FN)
|
||||
print_downcasted(Float8E4M3FNType.get())
|
||||
# CHECK: Float8E4M3FNUZType
|
||||
# CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
|
||||
print_downcasted(Float8E4M3FNUZType.get())
|
||||
# CHECK: Float8E5M2Type
|
||||
# CHECK: Float8E5M2Type(f8E5M2)
|
||||
print_downcasted(Float8E5M2Type.get())
|
||||
# CHECK: Float8E5M2FNUZType
|
||||
# CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
|
||||
print_downcasted(Float8E5M2FNUZType.get())
|
||||
# CHECK: BF16Type
|
||||
# CHECK: BF16Type(bf16)
|
||||
print_downcasted(BF16Type.get())
|
||||
# CHECK: IndexType
|
||||
# CHECK: IndexType(index)
|
||||
print_downcasted(IndexType.get())
|
||||
# CHECK: IntegerType
|
||||
# CHECK: IntegerType(i32)
|
||||
print_downcasted(IntegerType.get_signless(32))
|
||||
|
||||
f32 = F32Type.get()
|
||||
ranked_tensor = tensor.EmptyOp([10, 10], f32).result
|
||||
# CHECK: RankedTensorType
|
||||
print(type(ranked_tensor.type).__name__)
|
||||
# CHECK: RankedTensorType(tensor<10x10xf32>)
|
||||
print(repr(ranked_tensor.type))
|
||||
|
||||
cf32 = ComplexType.get(f32)
|
||||
# CHECK: ComplexType
|
||||
print(type(cf32).__name__)
|
||||
# CHECK: ComplexType(complex<f32>)
|
||||
print(repr(cf32))
|
||||
|
||||
ranked_tensor = tensor.EmptyOp([10, 10], f32).result
|
||||
# CHECK: RankedTensorType
|
||||
print(type(ranked_tensor.type).__name__)
|
||||
# CHECK: RankedTensorType(tensor<10x10xf32>)
|
||||
print(repr(ranked_tensor.type))
|
||||
|
||||
vector = VectorType.get([10, 10], f32)
|
||||
tuple_type = TupleType.get_tuple([f32, vector])
|
||||
# CHECK: TupleType
|
||||
print(type(tuple_type).__name__)
|
||||
# CHECK: TupleType(tuple<f32, vector<10x10xf32>>)
|
||||
print(repr(tuple_type))
|
||||
# CHECK: F32Type(f32)
|
||||
print(repr(tuple_type.get_type(0)))
|
||||
# CHECK: VectorType(vector<10x10xf32>)
|
||||
print(repr(tuple_type.get_type(1)))
|
||||
|
||||
index_type = IndexType.get()
|
||||
|
||||
@func.FuncOp.from_py_func()
|
||||
def default_builder():
|
||||
c0 = arith.ConstantOp(f32, 0.0)
|
||||
unranked_tensor_type = UnrankedTensorType.get(f32)
|
||||
unranked_tensor = tensor.FromElementsOp(unranked_tensor_type, [c0]).result
|
||||
# CHECK: UnrankedTensorType
|
||||
print(type(unranked_tensor.type).__name__)
|
||||
# CHECK: UnrankedTensorType(tensor<*xf32>)
|
||||
print(repr(unranked_tensor.type))
|
||||
|
||||
c10 = arith.ConstantOp(index_type, 10)
|
||||
memref_f32_t = MemRefType.get([10, 10], f32)
|
||||
memref_f32 = memref.AllocOp(memref_f32_t, [c10, c10], []).result
|
||||
# CHECK: MemRefType
|
||||
print(type(memref_f32.type).__name__)
|
||||
# CHECK: MemRefType(memref<10x10xf32>)
|
||||
print(repr(memref_f32.type))
|
||||
|
||||
unranked_memref_t = UnrankedMemRefType.get(f32, Attribute.parse("2"))
|
||||
memref_f32 = memref.AllocOp(unranked_memref_t, [c10, c10], []).result
|
||||
# CHECK: UnrankedMemRefType
|
||||
print(type(memref_f32.type).__name__)
|
||||
# CHECK: UnrankedMemRefType(memref<*xf32, 2>)
|
||||
print(repr(memref_f32.type))
|
||||
|
||||
tuple_type = Operation.parse(
|
||||
f'"test.make_tuple"() : () -> tuple<i32, f32>'
|
||||
).result
|
||||
# CHECK: TupleType
|
||||
print(type(tuple_type.type).__name__)
|
||||
# CHECK: TupleType(tuple<i32, f32>)
|
||||
print(repr(tuple_type.type))
|
||||
|
||||
return c0, c10
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testCustomTypeTypeCaster
|
||||
# This tests being able to materialize a type from a dialect *and* have
|
||||
# the implemented type caster called without explicitly importing the dialect.
|
||||
# I.e., we get a transform.OperationType without explicitly importing the transform dialect.
|
||||
@run
|
||||
def testCustomTypeTypeCaster():
|
||||
with Context() as ctx, Location.unknown():
|
||||
t = Type.parse('!transform.op<"foo.bar">', Context())
|
||||
# CHECK: !transform.op<"foo.bar">
|
||||
print(t)
|
||||
# CHECK: OperationType(!transform.op<"foo.bar">)
|
||||
print(repr(t))
|
||||
|
@ -31,6 +31,10 @@ MlirType mlirPythonTestTestTypeGet(MlirContext context) {
|
||||
return wrap(python_test::TestTypeType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
MlirTypeID mlirPythonTestTestTypeGetTypeID(void) {
|
||||
return wrap(python_test::TestTypeType::getTypeID());
|
||||
}
|
||||
|
||||
bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value) {
|
||||
return mlirTypeIsATensor(wrap(unwrap(value).getType()));
|
||||
}
|
||||
|
@ -27,6 +27,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestType(MlirType type);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirPythonTestTestTypeGetTypeID(void);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value);
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
@ -7,11 +7,19 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PythonTestCAPI.h"
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace mlir::python::adaptors;
|
||||
using namespace pybind11::literals;
|
||||
|
||||
static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
|
||||
return mlirTypeIsARankedTensor(t) &&
|
||||
mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_mlirPythonTest, m) {
|
||||
m.def(
|
||||
@ -34,16 +42,38 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
|
||||
return cls(mlirPythonTestTestAttributeGet(ctx));
|
||||
},
|
||||
py::arg("cls"), py::arg("context") = py::none());
|
||||
mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType)
|
||||
mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType,
|
||||
mlirPythonTestTestTypeGetTypeID)
|
||||
.def_classmethod(
|
||||
"get",
|
||||
[](py::object cls, MlirContext ctx) {
|
||||
return cls(mlirPythonTestTestTypeGet(ctx));
|
||||
},
|
||||
py::arg("cls"), py::arg("context") = py::none());
|
||||
mlir_type_subclass(m, "TestTensorType", mlirTypeIsARankedTensor,
|
||||
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
|
||||
.attr("RankedTensorType"));
|
||||
auto cls =
|
||||
mlir_type_subclass(m, "TestIntegerRankedTensorType",
|
||||
mlirTypeIsARankedIntegerTensor,
|
||||
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
|
||||
.attr("RankedTensorType"))
|
||||
.def_classmethod(
|
||||
"get",
|
||||
[](const py::object &cls, std::vector<int64_t> shape,
|
||||
unsigned width, MlirContext ctx) {
|
||||
MlirAttribute encoding = mlirAttributeGetNull();
|
||||
return cls(mlirRankedTensorTypeGet(
|
||||
shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width),
|
||||
encoding));
|
||||
},
|
||||
"cls"_a, "shape"_a, "width"_a, "context"_a = py::none());
|
||||
assert(py::hasattr(cls.get_class(), "static_typeid") &&
|
||||
"TestIntegerRankedTensorType has no static_typeid");
|
||||
MlirTypeID mlirTypeID = mlirRankedTensorTypeGetTypeID();
|
||||
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
|
||||
.attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
|
||||
mlirTypeID, pybind11::cpp_function([cls](const py::object &mlirType) {
|
||||
return cls.get_class()(mlirType);
|
||||
}),
|
||||
/*replace=*/true);
|
||||
mlir_value_subclass(m, "TestTensorValue",
|
||||
mlirTypeIsAPythonTestTestTensorValue)
|
||||
.def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); });
|
||||
|
Loading…
x
Reference in New Issue
Block a user