[MLIR][python bindings] Add TypeCaster for returning refined types from python APIs

depends on D150839

This diff uses `MlirTypeID` to register `TypeCaster`s (i.e., `[](PyType pyType) -> DerivedTy { return pyType; }`) for all concrete types (i.e., `PyConcrete<...>`) that are then queried for (by `MlirTypeID`) and called in `struct type_caster<MlirType>::cast`. The result is that anywhere an `MlirType mlirType` is returned from a python binding, that `mlirType` is automatically cast to the correct concrete type. For example:

```
      c0 = arith.ConstantOp(f32, 0.0)
      # CHECK: F32Type(f32)
      print(repr(c0.result.type))

      unranked_tensor_type = UnrankedTensorType.get(f32)
      unranked_tensor = tensor.FromElementsOp(unranked_tensor_type, [c0]).result

      # CHECK: UnrankedTensorType
      print(type(unranked_tensor.type).__name__)
      # CHECK: UnrankedTensorType(tensor<*xf32>)
      print(repr(unranked_tensor.type))
```

This functionality immediately extends to typed attributes (i.e., `attr.type`).

The diff also implements similar functionality for `mlir_type_subclass`es but in a slightly different way - for such types (which have no cpp corresponding `class` or `struct`) the user must provide a type caster in python (similar to how `AttrBuilder` works) or in cpp as a `py::cpp_function`.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D150927
This commit is contained in:
max 2023-05-26 10:23:17 -05:00
parent 5310be521d
commit bfb1ba7526
26 changed files with 460 additions and 75 deletions

View File

@ -107,6 +107,23 @@
* delineated). */
#define MLIR_PYTHON_CAPI_FACTORY_ATTR "_CAPICreate"
/** Attribute on MLIR Python objects that expose a function for downcasting the
* corresponding Python object to a subclass if the object is in fact a subclass
* (Concrete or mlir_type_subclass) of ir.Type. The signature of the function
* is: def maybe_downcast(self) -> object where the resulting object will
* (possibly) be an instance of the subclass.
*/
#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR "maybe_downcast"
/** Attribute on main C extension module (_mlir) that corresponds to the
* type caster registration binding. The signature of the function is:
* def register_type_caster(MlirTypeID mlirTypeID, py::function typeCaster,
* bool replace)
* where replace indicates the typeCaster should replace any existing registered
* type casters (such as those for upstream ConcreteTypes).
*/
#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR "register_type_caster"
/// Gets a void* from a wrapped struct. Needed because const cast is different
/// between C/C++.
#ifdef __cplusplus

View File

@ -33,6 +33,8 @@ MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformOperationType(MlirType type);
MLIR_CAPI_EXPORTED MlirTypeID mlirTransformOperationTypeGetTypeID(void);
MLIR_CAPI_EXPORTED MlirType
mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName);

View File

@ -825,6 +825,9 @@ MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type);
/// Gets the type ID of the type.
MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type);
/// Gets the dialect a type belongs to.
MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type);
/// Checks whether a type is null.
static inline bool mlirTypeIsNull(MlirType type) { return !type.ptr; }

View File

@ -28,6 +28,7 @@
#include "llvm/ADT/Twine.h"
namespace py = pybind11;
using namespace py::literals;
// Raw CAPI type casters need to be declared before use, so always include them
// first.
@ -272,6 +273,7 @@ struct type_caster<MlirType> {
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr("Type")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
.release();
}
};
@ -424,20 +426,24 @@ public:
class mlir_type_subclass : public pure_subclass {
public:
using IsAFunctionTy = bool (*)(MlirType);
using GetTypeIDFunctionTy = MlirTypeID (*)();
/// Subclasses by looking up the super-class dynamically.
mlir_type_subclass(py::handle scope, const char *typeClassName,
IsAFunctionTy isaFunction)
IsAFunctionTy isaFunction,
GetTypeIDFunctionTy getTypeIDFunction = nullptr)
: mlir_type_subclass(
scope, typeClassName, isaFunction,
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type")) {}
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type"),
getTypeIDFunction) {}
/// Subclasses with a provided mlir.ir.Type super-class. This must
/// be used if the subclass is being defined in the same extension module
/// as the mlir.ir class (otherwise, it will trigger a recursive
/// initialization).
mlir_type_subclass(py::handle scope, const char *typeClassName,
IsAFunctionTy isaFunction, const py::object &superCls)
IsAFunctionTy isaFunction, const py::object &superCls,
GetTypeIDFunctionTy getTypeIDFunction = nullptr)
: pure_subclass(scope, typeClassName, superCls) {
// Casting constructor. Note that it hard, if not impossible, to properly
// call chain to parent `__init__` in pybind11 due to its special handling
@ -471,6 +477,19 @@ public:
"isinstance",
[isaFunction](MlirType other) { return isaFunction(other); },
py::arg("other_type"));
def("__repr__", [superCls, captureTypeName](py::object self) {
return py::repr(superCls(self))
.attr("replace")(superCls.attr("__name__"), captureTypeName);
});
if (getTypeIDFunction) {
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
getTypeIDFunction(),
pybind11::cpp_function(
[thisClass = thisClass](const py::object &mlirType) {
return thisClass(mlirType);
}));
}
}
};

View File

@ -44,4 +44,25 @@ inline mlir::LogicalResult unwrap(MlirLogicalResult res) {
DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
DEFINE_C_API_PTR_METHODS(MlirTypeIDAllocator, mlir::TypeIDAllocator)
namespace llvm {
template <>
struct DenseMapInfo<MlirTypeID> {
static inline MlirTypeID getEmptyKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlirTypeIDCreate(pointer);
}
static inline MlirTypeID getTombstoneKey() {
auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlirTypeIDCreate(pointer);
}
static inline unsigned getHashValue(const MlirTypeID &val) {
return mlirTypeIDHashValue(val);
}
static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) {
return mlirTypeIDEqual(lhs, rhs);
}
};
} // namespace llvm
#endif // MLIR_CAPI_SUPPORT_H

View File

@ -36,7 +36,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
//===-------------------------------------------------------------------===//
auto operationType =
mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType);
mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType,
mlirTransformOperationTypeGetTypeID);
operationType.def_classmethod(
"get",
[](py::object cls, const std::string &operationName, MlirContext ctx) {

View File

@ -9,12 +9,15 @@
#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
#include <optional>
#include <string>
#include <vector>
#include <optional>
#include "PybindUtils.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/Support.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
@ -54,16 +57,18 @@ public:
/// entities.
void loadDialectModule(llvm::StringRef dialectNamespace);
/// Decorator for registering a custom Dialect class. The class object must
/// have a DIALECT_NAMESPACE attribute.
pybind11::object registerDialectDecorator(pybind11::object pyClass);
/// Adds a user-friendly Attribute builder.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
void registerAttributeBuilder(const std::string &attributeKind,
pybind11::function pyFunc);
/// Adds a user-friendly type caster. Raises an exception if the mapping
/// already exists and replace == false. This is intended to be called by
/// implementation code.
void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
bool replace = false);
/// Adds a concrete implementation dialect class.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
@ -80,6 +85,10 @@ public:
std::optional<pybind11::function>
lookupAttributeBuilder(const std::string &attributeKind);
/// Returns the custom type caster for MlirTypeID mlirTypeID.
std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
MlirDialect dialect);
/// Looks up a registered dialect class by namespace. Note that this may
/// trigger loading of the defining module and can arbitrarily re-enter.
std::optional<pybind11::object>
@ -101,6 +110,10 @@ private:
llvm::StringMap<pybind11::object> operationClassMap;
/// Map of attribute ODS name to custom builder.
llvm::StringMap<pybind11::object> attributeBuilderMap;
/// Map of MlirTypeID to custom type caster.
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
/// Cache for map of MlirTypeID to custom type caster.
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMapCache;
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.

View File

@ -15,6 +15,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
namespace py = pybind11;
using namespace mlir;
@ -1023,8 +1024,7 @@ public:
py::arg("value"), py::arg("context") = py::none(),
"Gets a uniqued Type attribute");
c.def_property_readonly("value", [](PyTypeAttribute &self) {
return PyType(self.getContext()->getRef(),
mlirTypeAttrGetValue(self.get()));
return mlirTypeAttrGetValue(self.get());
});
}
};

View File

@ -25,6 +25,7 @@
#include <utility>
namespace py = pybind11;
using namespace py::literals;
using namespace mlir;
using namespace mlir::python;
@ -2121,13 +2122,12 @@ public:
/// Returns the list of types of the values held by container.
template <typename Container>
static std::vector<PyType> getValueTypes(Container &container,
PyMlirContextRef &context) {
std::vector<PyType> result;
static std::vector<MlirType> getValueTypes(Container &container,
PyMlirContextRef &context) {
std::vector<MlirType> result;
result.reserve(container.size());
for (int i = 0, e = container.size(); i < e; ++i) {
result.push_back(
PyType(context, mlirValueGetType(container.getElement(i).get())));
result.push_back(mlirValueGetType(container.getElement(i).get()));
}
return result;
}
@ -3148,11 +3148,8 @@ void mlir::python::populateIRCore(py::module &m) {
"context",
[](PyAttribute &self) { return self.getContext().getObject(); },
"Context that owns the Attribute")
.def_property_readonly("type",
[](PyAttribute &self) {
return PyType(self.getContext()->getRef(),
mlirAttributeGetType(self));
})
.def_property_readonly(
"type", [](PyAttribute &self) { return mlirAttributeGetType(self); })
.def(
"get_named",
[](PyAttribute &self, std::string name) {
@ -3247,7 +3244,7 @@ void mlir::python::populateIRCore(py::module &m) {
mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
if (mlirTypeIsNull(type))
throw MLIRError("Unable to parse type", errors.take());
return PyType(context->getRef(), type);
return type;
},
py::arg("asm"), py::arg("context") = py::none(),
kContextParseTypeDocstring)
@ -3284,6 +3281,18 @@ void mlir::python::populateIRCore(py::module &m) {
printAccum.parts.append(")");
return printAccum.join();
})
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
[](PyType &self) {
MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
assert(!mlirTypeIDIsNull(mlirTypeID) &&
"mlirTypeID was expected to be non-null.");
std::optional<pybind11::function> typeCaster =
PyGlobals::get().lookupTypeCaster(mlirTypeID,
mlirTypeGetDialect(self));
if (!typeCaster)
return py::cast(self);
return typeCaster.value()(self);
})
.def_property_readonly("typeid", [](PyType &self) -> MlirTypeID {
MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
if (!mlirTypeIDIsNull(mlirTypeID))
@ -3387,12 +3396,8 @@ void mlir::python::populateIRCore(py::module &m) {
return printAccum.join();
},
py::arg("use_local_scope") = false, kGetNameAsOperand)
.def_property_readonly("type",
[](PyValue &self) {
return PyType(
self.getParentOperation()->getContext(),
mlirValueGetType(self.get()));
})
.def_property_readonly(
"type", [](PyValue &self) { return mlirValueGetType(self.get()); })
.def(
"replace_all_uses_with",
[](PyValue &self, PyValue &with) {

View File

@ -321,11 +321,7 @@ public:
py::module_local())
.def_property_readonly(
"element_type",
[](PyShapedTypeComponents &self) {
return PyType(PyMlirContext::forContext(
mlirTypeGetContext(self.elementType)),
self.elementType);
},
[](PyShapedTypeComponents &self) { return self.elementType; },
"Returns the element type of the shaped type components.")
.def_static(
"get",

View File

@ -14,6 +14,7 @@
#include <vector>
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Support.h"
namespace py = pybind11;
using namespace mlir;
@ -72,6 +73,15 @@ void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
found = std::move(pyFunc);
}
void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
pybind11::function typeCaster,
bool replace) {
pybind11::object &found = typeCasterMap[mlirTypeID];
if (found && !found.is_none() && !replace)
throw std::runtime_error("Type caster is already registered");
found = std::move(typeCaster);
}
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
py::object pyClass) {
py::object &found = dialectClassMap[dialectNamespace];
@ -110,6 +120,39 @@ PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
return std::nullopt;
}
std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
MlirDialect dialect) {
{
// Fast match against the class map first (common case).
const auto foundIt = typeCasterMapCache.find(mlirTypeID);
if (foundIt != typeCasterMapCache.end()) {
if (foundIt->second.is_none())
return std::nullopt;
assert(foundIt->second && "py::function is defined");
return foundIt->second;
}
}
// Not found. Load the dialect namespace.
loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
// Attempt to find from the canonical map and cache.
{
const auto foundIt = typeCasterMap.find(mlirTypeID);
if (foundIt != typeCasterMap.end()) {
if (foundIt->second.is_none())
return std::nullopt;
assert(foundIt->second && "py::object is defined");
// Positive cache.
typeCasterMapCache[mlirTypeID] = foundIt->second;
return foundIt->second;
}
// Negative cache.
typeCasterMap[mlirTypeID] = py::none();
return std::nullopt;
}
}
std::optional<py::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
loadDialectModule(dialectNamespace);
@ -164,4 +207,5 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
void PyGlobals::clearImportCache() {
loadedDialectModulesCache.clear();
operationClassMapCache.clear();
typeCasterMapCache.clear();
}

View File

@ -13,6 +13,7 @@
#include <utility>
#include <vector>
#include "Globals.h"
#include "PybindUtils.h"
#include "mlir-c/AffineExpr.h"
@ -868,9 +869,7 @@ public:
PyConcreteType() = default;
PyConcreteType(PyMlirContextRef contextRef, MlirType t)
: BaseTy(std::move(contextRef), t) {
pybind11::implicitly_convertible<PyType, DerivedTy>();
}
: BaseTy(std::move(contextRef), t) {}
PyConcreteType(PyType &orig)
: PyConcreteType(orig.getContext(), castFrom(orig)) {}
@ -914,6 +913,13 @@ public:
return printAccum.join();
});
if (DerivedTy::getTypeIdFunction) {
PyGlobals::get().registerTypeCaster(
DerivedTy::getTypeIdFunction(),
pybind11::cpp_function(
[](PyType pyType) -> DerivedTy { return pyType; }));
}
DerivedTy::bindDerived(cls);
}
@ -1009,9 +1015,8 @@ public:
return DerivedTy::isaFunction(otherAttr);
},
pybind11::arg("other"));
cls.def_property_readonly("type", [](PyAttribute &attr) {
return PyType(attr.getContext(), mlirAttributeGetType(attr));
});
cls.def_property_readonly(
"type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); });
DerivedTy::bindDerived(cls);
}

View File

@ -334,10 +334,7 @@ public:
"Create a complex type");
c.def_property_readonly(
"element_type",
[](PyComplexType &self) -> PyType {
MlirType t = mlirComplexTypeGetElementType(self);
return PyType(self.getContext(), t);
},
[](PyComplexType &self) { return mlirComplexTypeGetElementType(self); },
"Returns element type.");
}
};
@ -351,10 +348,7 @@ public:
static void bindDerived(ClassTy &c) {
c.def_property_readonly(
"element_type",
[](PyShapedType &self) {
MlirType t = mlirShapedTypeGetElementType(self);
return PyType(self.getContext(), t);
},
[](PyShapedType &self) { return mlirShapedTypeGetElementType(self); },
"Returns the element type of the shaped type.");
c.def_property_readonly(
"has_rank",
@ -641,9 +635,8 @@ public:
"Create a tuple type");
c.def(
"get_type",
[](PyTupleType &self, intptr_t pos) -> PyType {
MlirType t = mlirTupleTypeGetType(self, pos);
return PyType(self.getContext(), t);
[](PyTupleType &self, intptr_t pos) {
return mlirTupleTypeGetType(self, pos);
},
py::arg("pos"), "Returns the pos-th type in the tuple type.");
c.def_property_readonly(
@ -686,7 +679,7 @@ public:
py::list types;
for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
++i) {
types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i)));
types.append(mlirFunctionTypeGetInput(t, i));
}
return types;
},
@ -698,8 +691,7 @@ public:
py::list types;
for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
++i) {
types.append(
PyType(contextRef, mlirFunctionTypeGetResult(self, i)));
types.append(mlirFunctionTypeGetResult(self, i));
}
return types;
},

View File

@ -16,6 +16,7 @@
namespace py = pybind11;
using namespace mlir;
using namespace py::literals;
using namespace mlir::python;
// -----------------------------------------------------------------------------
@ -35,12 +36,12 @@ PYBIND11_MODULE(_mlir, m) {
self.getDialectSearchPrefixes().push_back(std::move(moduleName));
self.clearImportCache();
},
py::arg("module_name"))
"module_name"_a)
.def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
py::arg("dialect_namespace"), py::arg("dialect_class"),
"dialect_namespace"_a, "dialect_class"_a,
"Testing hook for directly registering a dialect")
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
py::arg("operation_name"), py::arg("operation_class"),
"operation_name"_a, "operation_class"_a,
"Testing hook for directly registering an operation");
// Aside from making the globals accessible to python, having python manage
@ -58,11 +59,11 @@ PYBIND11_MODULE(_mlir, m) {
PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
return pyClass;
},
py::arg("dialect_class"),
"dialect_class"_a,
"Class decorator for registering a custom Dialect wrapper");
m.def(
"register_operation",
[](py::object dialectClass) -> py::cpp_function {
[](const py::object &dialectClass) -> py::cpp_function {
return py::cpp_function(
[dialectClass](py::object opClass) -> py::object {
std::string operationName =
@ -75,9 +76,17 @@ PYBIND11_MODULE(_mlir, m) {
return opClass;
});
},
py::arg("dialect_class"),
"dialect_class"_a,
"Produce a class decorator for registering an Operation class as part of "
"a dialect");
m.def(
MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
[](MlirTypeID mlirTypeID, py::function typeCaster, bool replace) {
PyGlobals::get().registerTypeCaster(mlirTypeID, std::move(typeCaster),
replace);
},
"typeid"_a, "type_caster"_a, "replace"_a = false,
"Register a type caster for casting MLIR types to custom user types.");
// Define and populate IR submodule.
auto irModule = m.def_submodule("ir", "MLIR IR Bindings");

View File

@ -37,6 +37,10 @@ bool mlirTypeIsATransformOperationType(MlirType type) {
return isa<transform::OperationType>(unwrap(type));
}
MlirTypeID mlirTransformOperationTypeGetTypeID(void) {
return wrap(transform::OperationType::getTypeID());
}
MlirType mlirTransformOperationTypeGet(MlirContext ctx,
MlirStringRef operationName) {
return wrap(

View File

@ -324,6 +324,10 @@ MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType)));
}
MlirType mlirUnrankedTensorTypeGetElementType(MlirType type) {
return wrap(llvm::cast<UnrankedTensorType>(unwrap(type)).getElementType());
}
//===----------------------------------------------------------------------===//
// Ranked / Unranked MemRef type.
//===----------------------------------------------------------------------===//

View File

@ -832,6 +832,10 @@ MlirTypeID mlirTypeGetTypeID(MlirType type) {
return wrap(unwrap(type).getTypeID());
}
MlirDialect mlirTypeGetDialect(MlirType type) {
return wrap(&unwrap(type).getDialect());
}
bool mlirTypeEqual(MlirType t1, MlirType t2) {
return unwrap(t1) == unwrap(t2);
}

View File

@ -23,7 +23,6 @@ bool mlirStringRefEqual(MlirStringRef string, MlirStringRef other) {
//===----------------------------------------------------------------------===//
// TypeID API.
//===----------------------------------------------------------------------===//
MlirTypeID mlirTypeIDCreate(const void *ptr) {
assert(reinterpret_cast<uintptr_t>(ptr) % 8 == 0 &&
"ptr must be 8 byte aligned");

View File

@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._python_test_ops_gen import *
from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestTensorType
from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestIntegerRankedTensorType
def register_python_test_dialect(context, load=True):

View File

@ -4,6 +4,7 @@
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
from ._mlir_libs._mlir import register_type_caster
# Convenience decorator for registering user-friendly Attribute builders.

View File

@ -369,9 +369,9 @@ def testTensorValue():
# Classes of custom types that inherit from concrete types should have
# static_typeid
assert isinstance(test.TestTensorType.static_typeid, TypeID)
assert isinstance(test.TestIntegerRankedTensorType.static_typeid, TypeID)
# And it should be equal to the in-tree concrete type
assert test.TestTensorType.static_typeid == t.type.typeid
assert test.TestIntegerRankedTensorType.static_typeid == t.type.typeid
# CHECK-LABEL: TEST: inferReturnTypeComponents
@ -424,3 +424,46 @@ def inferReturnTypeComponents():
print("rank:", shaped_type_components.rank)
print("element type:", shaped_type_components.element_type)
print("shape:", shaped_type_components.shape)
# CHECK-LABEL: TEST: testCustomTypeTypeCaster
@run
def testCustomTypeTypeCaster():
with Context() as ctx, Location.unknown():
test.register_python_test_dialect(ctx)
a = test.TestType.get()
assert a.typeid is not None
b = Type.parse("!python_test.test_type")
# CHECK: !python_test.test_type
print(b)
# CHECK: TestType(!python_test.test_type)
print(repr(b))
c = test.TestIntegerRankedTensorType.get([10, 10], 5)
# CHECK: tensor<10x10xi5>
print(c)
# CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
print(repr(c))
# CHECK: Type caster is already registered
try:
def type_caster(pytype):
return test.TestIntegerRankedTensorType(pytype)
register_type_caster(c.typeid, type_caster)
except RuntimeError as e:
print(e)
def type_caster(pytype):
return test.TestIntegerRankedTensorType(pytype)
register_type_caster(c.typeid, type_caster, replace=True)
d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
# CHECK: tensor<10x10xi5>
print(d.type)
# CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
print(repr(d.type))

View File

@ -553,3 +553,42 @@ def testStridedLayoutAttr():
print(f"rank: {len(attr.strides)}")
# CHECK: strides are dynamic: [True, True, True]
print(f"strides are dynamic: {[s == dynamic for s in attr.strides]}")
# CHECK-LABEL: TEST: testConcreteTypesRoundTrip
@run
def testConcreteTypesRoundTrip():
with Context(), Location.unknown():
def print_item(attr):
print(repr(attr.type))
# CHECK: F32Type(f32)
print_item(Attribute.parse("42.0 : f32"))
# CHECK: F32Type(f32)
print_item(FloatAttr.get_f32(42.0))
# CHECK: IntegerType(i64)
print_item(IntegerAttr.get(IntegerType.get_signless(64), 42))
def print_container_item(attr_asm):
attr = DenseElementsAttr(Attribute.parse(attr_asm))
print(repr(attr.type))
print(repr(attr.type.element_type))
# CHECK: RankedTensorType(tensor<i16>)
# CHECK: IntegerType(i16)
print_container_item("dense<123> : tensor<i16>")
# CHECK: RankedTensorType(tensor<f64>)
# CHECK: F64Type(f64)
print_container_item("dense<1.0> : tensor<f64>")
raw = Attribute.parse("vector<4xf32>")
# CHECK: attr: vector<4xf32>
print("attr:", raw)
type_attr = TypeAttr(raw)
# CHECK: VectorType(vector<4xf32>)
print(repr(type_attr.value))
# CHECK: F32Type(f32)
print(repr(type_attr.value.element_type))

View File

@ -2,6 +2,7 @@
import gc
from mlir.ir import *
from mlir.dialects import arith, tensor, func, memref
def run(f):
@ -382,15 +383,15 @@ def testMemRefType():
f32 = F32Type.get()
shape = [2, 3]
loc = Location.unknown()
memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
memref_f32 = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
# CHECK: memref type: memref<2x3xf32, 2>
print("memref type:", memref)
print("memref type:", memref_f32)
# CHECK: memref layout: affine_map<(d0, d1) -> (d0, d1)>
print("memref layout:", memref.layout)
print("memref layout:", memref_f32.layout)
# CHECK: memref affine map: (d0, d1) -> (d0, d1)
print("memref affine map:", memref.affine_map)
print("memref affine map:", memref_f32.affine_map)
# CHECK: memory space: 2
print("memory space:", memref.memory_space)
print("memory space:", memref_f32.memory_space)
layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0]))
memref_layout = MemRefType.get(shape, f32, layout=layout)
@ -413,7 +414,7 @@ def testMemRefType():
else:
print("Exception not produced")
assert memref.shape == shape
assert memref_f32.shape == shape
# CHECK-LABEL: TEST: testUnrankedMemRefType
@ -482,9 +483,9 @@ def testFunctionType():
input_types = [IntegerType.get_signless(32), IntegerType.get_signless(16)]
result_types = [IndexType.get()]
func = FunctionType.get(input_types, result_types)
# CHECK: INPUTS: [Type(i32), Type(i16)]
# CHECK: INPUTS: [IntegerType(i32), IntegerType(i16)]
print("INPUTS:", func.inputs)
# CHECK: RESULTS: [Type(index)]
# CHECK: RESULTS: [IndexType(index)]
print("RESULTS:", func.results)
@ -599,3 +600,130 @@ def testTypeIDs():
vector_type = Type.parse("vector<2x3xf32>")
# CHECK: True
print(ShapedType(vector_type).typeid == vector_type.typeid)
# CHECK-LABEL: TEST: testConcreteTypesRoundTrip
@run
def testConcreteTypesRoundTrip():
with Context() as ctx, Location.unknown():
ctx.allow_unregistered_dialects = True
def print_downcasted(typ):
downcasted = Type(typ).maybe_downcast()
print(type(downcasted).__name__)
print(repr(downcasted))
# CHECK: F16Type
# CHECK: F16Type(f16)
print_downcasted(F16Type.get())
# CHECK: F32Type
# CHECK: F32Type(f32)
print_downcasted(F32Type.get())
# CHECK: F64Type
# CHECK: F64Type(f64)
print_downcasted(F64Type.get())
# CHECK: Float8E4M3B11FNUZType
# CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
print_downcasted(Float8E4M3B11FNUZType.get())
# CHECK: Float8E4M3FNType
# CHECK: Float8E4M3FNType(f8E4M3FN)
print_downcasted(Float8E4M3FNType.get())
# CHECK: Float8E4M3FNUZType
# CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
print_downcasted(Float8E4M3FNUZType.get())
# CHECK: Float8E5M2Type
# CHECK: Float8E5M2Type(f8E5M2)
print_downcasted(Float8E5M2Type.get())
# CHECK: Float8E5M2FNUZType
# CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
print_downcasted(Float8E5M2FNUZType.get())
# CHECK: BF16Type
# CHECK: BF16Type(bf16)
print_downcasted(BF16Type.get())
# CHECK: IndexType
# CHECK: IndexType(index)
print_downcasted(IndexType.get())
# CHECK: IntegerType
# CHECK: IntegerType(i32)
print_downcasted(IntegerType.get_signless(32))
f32 = F32Type.get()
ranked_tensor = tensor.EmptyOp([10, 10], f32).result
# CHECK: RankedTensorType
print(type(ranked_tensor.type).__name__)
# CHECK: RankedTensorType(tensor<10x10xf32>)
print(repr(ranked_tensor.type))
cf32 = ComplexType.get(f32)
# CHECK: ComplexType
print(type(cf32).__name__)
# CHECK: ComplexType(complex<f32>)
print(repr(cf32))
ranked_tensor = tensor.EmptyOp([10, 10], f32).result
# CHECK: RankedTensorType
print(type(ranked_tensor.type).__name__)
# CHECK: RankedTensorType(tensor<10x10xf32>)
print(repr(ranked_tensor.type))
vector = VectorType.get([10, 10], f32)
tuple_type = TupleType.get_tuple([f32, vector])
# CHECK: TupleType
print(type(tuple_type).__name__)
# CHECK: TupleType(tuple<f32, vector<10x10xf32>>)
print(repr(tuple_type))
# CHECK: F32Type(f32)
print(repr(tuple_type.get_type(0)))
# CHECK: VectorType(vector<10x10xf32>)
print(repr(tuple_type.get_type(1)))
index_type = IndexType.get()
@func.FuncOp.from_py_func()
def default_builder():
c0 = arith.ConstantOp(f32, 0.0)
unranked_tensor_type = UnrankedTensorType.get(f32)
unranked_tensor = tensor.FromElementsOp(unranked_tensor_type, [c0]).result
# CHECK: UnrankedTensorType
print(type(unranked_tensor.type).__name__)
# CHECK: UnrankedTensorType(tensor<*xf32>)
print(repr(unranked_tensor.type))
c10 = arith.ConstantOp(index_type, 10)
memref_f32_t = MemRefType.get([10, 10], f32)
memref_f32 = memref.AllocOp(memref_f32_t, [c10, c10], []).result
# CHECK: MemRefType
print(type(memref_f32.type).__name__)
# CHECK: MemRefType(memref<10x10xf32>)
print(repr(memref_f32.type))
unranked_memref_t = UnrankedMemRefType.get(f32, Attribute.parse("2"))
memref_f32 = memref.AllocOp(unranked_memref_t, [c10, c10], []).result
# CHECK: UnrankedMemRefType
print(type(memref_f32.type).__name__)
# CHECK: UnrankedMemRefType(memref<*xf32, 2>)
print(repr(memref_f32.type))
tuple_type = Operation.parse(
f'"test.make_tuple"() : () -> tuple<i32, f32>'
).result
# CHECK: TupleType
print(type(tuple_type.type).__name__)
# CHECK: TupleType(tuple<i32, f32>)
print(repr(tuple_type.type))
return c0, c10
# CHECK-LABEL: TEST: testCustomTypeTypeCaster
# This tests being able to materialize a type from a dialect *and* have
# the implemented type caster called without explicitly importing the dialect.
# I.e., we get a transform.OperationType without explicitly importing the transform dialect.
@run
def testCustomTypeTypeCaster():
with Context() as ctx, Location.unknown():
t = Type.parse('!transform.op<"foo.bar">', Context())
# CHECK: !transform.op<"foo.bar">
print(t)
# CHECK: OperationType(!transform.op<"foo.bar">)
print(repr(t))

View File

@ -31,6 +31,10 @@ MlirType mlirPythonTestTestTypeGet(MlirContext context) {
return wrap(python_test::TestTypeType::get(unwrap(context)));
}
MlirTypeID mlirPythonTestTestTypeGetTypeID(void) {
return wrap(python_test::TestTypeType::getTypeID());
}
bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value) {
return mlirTypeIsATensor(wrap(unwrap(value).getType()));
}

View File

@ -27,6 +27,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestType(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context);
MLIR_CAPI_EXPORTED MlirTypeID mlirPythonTestTestTypeGetTypeID(void);
MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value);
#ifdef __cplusplus

View File

@ -7,11 +7,19 @@
//===----------------------------------------------------------------------===//
#include "PythonTestCAPI.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
namespace py = pybind11;
using namespace mlir::python::adaptors;
using namespace pybind11::literals;
static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
return mlirTypeIsARankedTensor(t) &&
mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
}
PYBIND11_MODULE(_mlirPythonTest, m) {
m.def(
@ -34,16 +42,38 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
return cls(mlirPythonTestTestAttributeGet(ctx));
},
py::arg("cls"), py::arg("context") = py::none());
mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType)
mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType,
mlirPythonTestTestTypeGetTypeID)
.def_classmethod(
"get",
[](py::object cls, MlirContext ctx) {
return cls(mlirPythonTestTestTypeGet(ctx));
},
py::arg("cls"), py::arg("context") = py::none());
mlir_type_subclass(m, "TestTensorType", mlirTypeIsARankedTensor,
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr("RankedTensorType"));
auto cls =
mlir_type_subclass(m, "TestIntegerRankedTensorType",
mlirTypeIsARankedIntegerTensor,
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr("RankedTensorType"))
.def_classmethod(
"get",
[](const py::object &cls, std::vector<int64_t> shape,
unsigned width, MlirContext ctx) {
MlirAttribute encoding = mlirAttributeGetNull();
return cls(mlirRankedTensorTypeGet(
shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width),
encoding));
},
"cls"_a, "shape"_a, "width"_a, "context"_a = py::none());
assert(py::hasattr(cls.get_class(), "static_typeid") &&
"TestIntegerRankedTensorType has no static_typeid");
MlirTypeID mlirTypeID = mlirRankedTensorTypeGetTypeID();
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
mlirTypeID, pybind11::cpp_function([cls](const py::object &mlirType) {
return cls.get_class()(mlirType);
}),
/*replace=*/true);
mlir_value_subclass(m, "TestTensorValue",
mlirTypeIsAPythonTestTestTensorValue)
.def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); });