mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-02-13 05:40:59 +00:00
[MLIR][python bindings] implement PyValue
subclassing to enable operator overloading
Differential Revision: https://reviews.llvm.org/D147758
This commit is contained in:
parent
bfa02523b2
commit
69cc3cfb21
@ -453,6 +453,62 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Creates a custom subclass of mlir.ir.Value, implementing a casting
|
||||
/// constructor and type checking methods.
|
||||
class mlir_value_subclass : public pure_subclass {
|
||||
public:
|
||||
using IsAFunctionTy = bool (*)(MlirValue);
|
||||
|
||||
/// Subclasses by looking up the super-class dynamically.
|
||||
mlir_value_subclass(py::handle scope, const char *valueClassName,
|
||||
IsAFunctionTy isaFunction)
|
||||
: mlir_value_subclass(
|
||||
scope, valueClassName, isaFunction,
|
||||
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Value")) {
|
||||
}
|
||||
|
||||
/// Subclasses with a provided mlir.ir.Value super-class. This must
|
||||
/// be used if the subclass is being defined in the same extension module
|
||||
/// as the mlir.ir class (otherwise, it will trigger a recursive
|
||||
/// initialization).
|
||||
mlir_value_subclass(py::handle scope, const char *valueClassName,
|
||||
IsAFunctionTy isaFunction, const py::object &superCls)
|
||||
: pure_subclass(scope, valueClassName, superCls) {
|
||||
// Casting constructor. Note that it hard, if not impossible, to properly
|
||||
// call chain to parent `__init__` in pybind11 due to its special handling
|
||||
// for init functions that don't have a fully constructed self-reference,
|
||||
// which makes it impossible to forward it to `__init__` of a superclass.
|
||||
// Instead, provide a custom `__new__` and call that of a superclass, which
|
||||
// eventually calls `__init__` of the superclass. Since attribute subclasses
|
||||
// have no additional members, we can just return the instance thus created
|
||||
// without amending it.
|
||||
std::string captureValueName(
|
||||
valueClassName); // As string in case if valueClassName is not static.
|
||||
py::cpp_function newCf(
|
||||
[superCls, isaFunction, captureValueName](py::object cls,
|
||||
py::object otherValue) {
|
||||
MlirValue rawValue = py::cast<MlirValue>(otherValue);
|
||||
if (!isaFunction(rawValue)) {
|
||||
auto origRepr = py::repr(otherValue).cast<std::string>();
|
||||
throw std::invalid_argument((llvm::Twine("Cannot cast value to ") +
|
||||
captureValueName + " (from " +
|
||||
origRepr + ")")
|
||||
.str());
|
||||
}
|
||||
py::object self = superCls.attr("__new__")(cls, otherValue);
|
||||
return self;
|
||||
},
|
||||
py::name("__new__"), py::arg("cls"), py::arg("cast_from_value"));
|
||||
thisClass.attr("__new__") = newCf;
|
||||
|
||||
// 'isinstance' method.
|
||||
def_staticmethod(
|
||||
"isinstance",
|
||||
[isaFunction](MlirValue other) { return isaFunction(other); },
|
||||
py::arg("other_value"));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace adaptors
|
||||
} // namespace python
|
||||
} // namespace mlir
|
||||
|
@ -3260,6 +3260,7 @@ void mlir::python::populateIRCore(py::module &m) {
|
||||
// Mapping of Value.
|
||||
//----------------------------------------------------------------------------
|
||||
py::class_<PyValue>(m, "Value", py::module_local())
|
||||
.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"))
|
||||
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
|
||||
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
|
||||
.def_property_readonly(
|
||||
|
@ -3,7 +3,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._python_test_ops_gen import *
|
||||
from .._mlir_libs._mlirPythonTest import TestAttr, TestType
|
||||
from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue
|
||||
|
||||
def register_python_test_dialect(context, load=True):
|
||||
from .._mlir_libs import _mlirPythonTest
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
from mlir.ir import *
|
||||
import mlir.dialects.python_test as test
|
||||
import mlir.dialects.tensor as tensor
|
||||
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
@ -302,3 +303,30 @@ def testCustomType():
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
@run
|
||||
# CHECK-LABEL: TEST: testTensorValue
|
||||
def testTensorValue():
|
||||
with Context() as ctx, Location.unknown():
|
||||
test.register_python_test_dialect(ctx)
|
||||
|
||||
i8 = IntegerType.get_signless(8)
|
||||
|
||||
class Tensor(test.TestTensorValue):
|
||||
def __str__(self):
|
||||
return super().__str__().replace("Value", "Tensor")
|
||||
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
t = tensor.EmptyOp([10, 10], i8).result
|
||||
|
||||
# CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
|
||||
print(Value(t))
|
||||
|
||||
tt = Tensor(t)
|
||||
# CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
|
||||
print(tt)
|
||||
|
||||
# CHECK: False
|
||||
print(tt.is_null())
|
||||
|
@ -8,6 +8,7 @@
|
||||
|
||||
#include "PythonTestCAPI.h"
|
||||
#include "PythonTestDialect.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir/CAPI/Registration.h"
|
||||
#include "mlir/CAPI/Wrap.h"
|
||||
|
||||
@ -29,3 +30,7 @@ bool mlirTypeIsAPythonTestTestType(MlirType type) {
|
||||
MlirType mlirPythonTestTestTypeGet(MlirContext context) {
|
||||
return wrap(python_test::TestTypeType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value) {
|
||||
return mlirTypeIsATensor(wrap(unwrap(value).getType()));
|
||||
}
|
||||
|
@ -27,6 +27,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestType(MlirType type);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
@ -40,4 +40,7 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
|
||||
return cls(mlirPythonTestTestTypeGet(ctx));
|
||||
},
|
||||
py::arg("cls"), py::arg("context") = py::none());
|
||||
mlir_value_subclass(m, "TestTensorValue",
|
||||
mlirTypeIsAPythonTestTestTensorValue)
|
||||
.def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); });
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user