[MLIR][python bindings] implement PyValue subclassing to enable operator overloading

Differential Revision: https://reviews.llvm.org/D147758
This commit is contained in:
max 2023-04-14 14:20:33 -05:00
parent bfa02523b2
commit 69cc3cfb21
7 changed files with 96 additions and 1 deletions

View File

@ -453,6 +453,62 @@ public:
}
};
/// Creates a custom subclass of mlir.ir.Value, implementing a casting
/// constructor and type checking methods.
class mlir_value_subclass : public pure_subclass {
public:
using IsAFunctionTy = bool (*)(MlirValue);
/// Subclasses by looking up the super-class dynamically.
mlir_value_subclass(py::handle scope, const char *valueClassName,
IsAFunctionTy isaFunction)
: mlir_value_subclass(
scope, valueClassName, isaFunction,
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Value")) {
}
/// Subclasses with a provided mlir.ir.Value super-class. This must
/// be used if the subclass is being defined in the same extension module
/// as the mlir.ir class (otherwise, it will trigger a recursive
/// initialization).
mlir_value_subclass(py::handle scope, const char *valueClassName,
IsAFunctionTy isaFunction, const py::object &superCls)
: pure_subclass(scope, valueClassName, superCls) {
// Casting constructor. Note that it hard, if not impossible, to properly
// call chain to parent `__init__` in pybind11 due to its special handling
// for init functions that don't have a fully constructed self-reference,
// which makes it impossible to forward it to `__init__` of a superclass.
// Instead, provide a custom `__new__` and call that of a superclass, which
// eventually calls `__init__` of the superclass. Since attribute subclasses
// have no additional members, we can just return the instance thus created
// without amending it.
std::string captureValueName(
valueClassName); // As string in case if valueClassName is not static.
py::cpp_function newCf(
[superCls, isaFunction, captureValueName](py::object cls,
py::object otherValue) {
MlirValue rawValue = py::cast<MlirValue>(otherValue);
if (!isaFunction(rawValue)) {
auto origRepr = py::repr(otherValue).cast<std::string>();
throw std::invalid_argument((llvm::Twine("Cannot cast value to ") +
captureValueName + " (from " +
origRepr + ")")
.str());
}
py::object self = superCls.attr("__new__")(cls, otherValue);
return self;
},
py::name("__new__"), py::arg("cls"), py::arg("cast_from_value"));
thisClass.attr("__new__") = newCf;
// 'isinstance' method.
def_staticmethod(
"isinstance",
[isaFunction](MlirValue other) { return isaFunction(other); },
py::arg("other_value"));
}
};
} // namespace adaptors
} // namespace python
} // namespace mlir

View File

@ -3260,6 +3260,7 @@ void mlir::python::populateIRCore(py::module &m) {
// Mapping of Value.
//----------------------------------------------------------------------------
py::class_<PyValue>(m, "Value", py::module_local())
.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"))
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
.def_property_readonly(

View File

@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._python_test_ops_gen import *
from .._mlir_libs._mlirPythonTest import TestAttr, TestType
from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue
def register_python_test_dialect(context, load=True):
from .._mlir_libs import _mlirPythonTest

View File

@ -2,6 +2,7 @@
from mlir.ir import *
import mlir.dialects.python_test as test
import mlir.dialects.tensor as tensor
def run(f):
print("\nTEST:", f.__name__)
@ -302,3 +303,30 @@ def testCustomType():
pass
else:
raise
@run
# CHECK-LABEL: TEST: testTensorValue
def testTensorValue():
with Context() as ctx, Location.unknown():
test.register_python_test_dialect(ctx)
i8 = IntegerType.get_signless(8)
class Tensor(test.TestTensorValue):
def __str__(self):
return super().__str__().replace("Value", "Tensor")
module = Module.create()
with InsertionPoint(module.body):
t = tensor.EmptyOp([10, 10], i8).result
# CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
print(Value(t))
tt = Tensor(t)
# CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
print(tt)
# CHECK: False
print(tt.is_null())

View File

@ -8,6 +8,7 @@
#include "PythonTestCAPI.h"
#include "PythonTestDialect.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/CAPI/Wrap.h"
@ -29,3 +30,7 @@ bool mlirTypeIsAPythonTestTestType(MlirType type) {
MlirType mlirPythonTestTestTypeGet(MlirContext context) {
return wrap(python_test::TestTypeType::get(unwrap(context)));
}
bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value) {
return mlirTypeIsATensor(wrap(unwrap(value).getType()));
}

View File

@ -27,6 +27,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestType(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context);
MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestTensorValue(MlirValue value);
#ifdef __cplusplus
}
#endif

View File

@ -40,4 +40,7 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
return cls(mlirPythonTestTestTypeGet(ctx));
},
py::arg("cls"), py::arg("context") = py::none());
mlir_value_subclass(m, "TestTensorValue",
mlirTypeIsAPythonTestTestTensorValue)
.def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); });
}