[mlir] Apply py::module_local() to a few more classes.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D109776
This commit is contained in:
Sean Silva 2021-09-14 21:55:54 +00:00
parent 336291e777
commit 8dca953dd3
3 changed files with 6 additions and 4 deletions

View File

@ -20,7 +20,7 @@ void mlir::python::populateDialectSparseTensorSubmodule(
py::module m, const py::module &irModule) { py::module m, const py::module &irModule) {
auto attributeClass = irModule.attr("Attribute"); auto attributeClass = irModule.attr("Attribute");
py::enum_<MlirSparseTensorDimLevelType>(m, "DimLevelType") py::enum_<MlirSparseTensorDimLevelType>(m, "DimLevelType", py::module_local())
.value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) .value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE)
.value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) .value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED)
.value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON); .value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON);

View File

@ -678,7 +678,8 @@ public:
} }
static void bind(pybind11::module &m) { static void bind(pybind11::module &m) {
auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol()); auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(),
pybind11::module_local());
cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>()); cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>());
DerivedTy::bindDerived(cls); DerivedTy::bindDerived(cls);
} }
@ -741,7 +742,7 @@ public:
} }
static void bind(pybind11::module &m) { static void bind(pybind11::module &m) {
auto cls = ClassTy(m, DerivedTy::pyClassName); auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local());
cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>()); cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>());
cls.def_static("isinstance", [](PyType &otherType) -> bool { cls.def_static("isinstance", [](PyType &otherType) -> bool {
return DerivedTy::isaFunction(otherType); return DerivedTy::isaFunction(otherType);

View File

@ -262,7 +262,8 @@ public:
/// Binds the indexing and length methods in the Python class. /// Binds the indexing and length methods in the Python class.
static void bind(pybind11::module &m) { static void bind(pybind11::module &m) {
auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName) auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName,
pybind11::module_local())
.def("__len__", &Sliceable::dunderLen) .def("__len__", &Sliceable::dunderLen)
.def("__getitem__", &Sliceable::dunderGetItem) .def("__getitem__", &Sliceable::dunderGetItem)
.def("__getitem__", &Sliceable::dunderGetItemSlice); .def("__getitem__", &Sliceable::dunderGetItemSlice);