[mlir] Quality of life improvements to python API types. (#66723)

* Moves several orphaned methods from Operation/OpView -> _OperationBase
so that both hierarchies share them (whether unknown or known to ODS).
* Adds typing information for missing `MLIRError` exception.
* Adds `DiagnosticInfo` typing.
* Adds `DenseResourceElementsAttr` typing that was missing.
This commit is contained in:
Stella Laurenzo 2023-09-18 21:30:41 -07:00 committed by GitHub
parent acfb99d9fd
commit 33df617dfd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 67 additions and 39 deletions

View File

@ -2768,6 +2768,24 @@ void mlir::python::populateIRCore(py::module &m) {
return PyOpAttributeMap(
self.getOperation().getRef());
})
.def_property_readonly(
"context",
[](PyOperationBase &self) {
PyOperation &concreteOperation = self.getOperation();
concreteOperation.checkValid();
return concreteOperation.getContext().getObject();
},
"Context that owns the Operation")
.def_property_readonly("name",
[](PyOperationBase &self) {
auto &concreteOperation = self.getOperation();
concreteOperation.checkValid();
MlirOperation operation =
concreteOperation.get();
MlirStringRef name = mlirIdentifierStr(
mlirOperationGetName(operation));
return py::str(name.data, name.length);
})
.def_property_readonly("operands",
[](PyOperationBase &self) {
return PyOpOperandList(
@ -2813,6 +2831,14 @@ void mlir::python::populateIRCore(py::module &m) {
},
"Returns the source location the operation was defined or derived "
"from.")
.def_property_readonly("parent",
[](PyOperationBase &self) -> py::object {
auto parent =
self.getOperation().getParentOperation();
if (parent)
return parent->getObject();
return py::none();
})
.def(
"__str__",
[](PyOperationBase &self) {
@ -2855,6 +2881,12 @@ void mlir::python::populateIRCore(py::module &m) {
.def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
"Puts self immediately before the other operation in its parent "
"block.")
.def(
"clone",
[](PyOperationBase &self, py::object ip) {
return self.getOperation().clone(ip);
},
py::arg("ip") = py::none())
.def(
"detach_from_parent",
[](PyOperationBase &self) {
@ -2866,7 +2898,8 @@ void mlir::python::populateIRCore(py::module &m) {
operation.detachFromParent();
return operation.createOpView();
},
"Detaches the operation from its parent block.");
"Detaches the operation from its parent block.")
.def("erase", [](PyOperationBase &self) { self.getOperation().erase(); });
py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
.def_static("create", &PyOperation::create, py::arg("name"),
@ -2887,45 +2920,17 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("context") = py::none(),
"Parses an operation. Supports both text assembly format and binary "
"bytecode format.")
.def_property_readonly("parent",
[](PyOperation &self) -> py::object {
auto parent = self.getParentOperation();
if (parent)
return parent->getObject();
return py::none();
})
.def("erase", &PyOperation::erase)
.def("clone", &PyOperation::clone, py::arg("ip") = py::none())
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyOperation::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
.def_property_readonly("name",
[](PyOperation &self) {
self.checkValid();
MlirOperation operation = self.get();
MlirStringRef name = mlirIdentifierStr(
mlirOperationGetName(operation));
return py::str(name.data, name.length);
})
.def_property_readonly(
"context",
[](PyOperation &self) {
self.checkValid();
return self.getContext().getObject();
},
"Context that owns the Operation")
.def_property_readonly("operation", [](py::object self) { return self; })
.def_property_readonly("opview", &PyOperation::createOpView);
auto opViewClass =
py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
.def(py::init<py::object>(), py::arg("operation"))
.def_property_readonly("operation", &PyOpView::getOperationObject)
.def_property_readonly(
"context",
[](PyOpView &self) {
return self.getOperation().getContext().getObject();
},
"Context that owns the Operation")
.def_property_readonly("opview", [](py::object self) { return self; })
.def("__str__", [](PyOpView &self) {
return py::str(self.getOperationObject());
});

View File

@ -43,11 +43,13 @@ __all__ = [
"DenseElementsAttr",
"DenseFPElementsAttr",
"DenseIntElementsAttr",
"DenseResourceElementsAttr",
"Dialect",
"DialectDescriptor",
"Dialects",
"Diagnostic",
"DiagnosticHandler",
"DiagnosticInfo",
"DiagnosticSeverity",
"DictAttr",
"Float8E4M3FNType",
@ -74,6 +76,7 @@ __all__ = [
"Location",
"MemRefType",
"Module",
"MLIRError",
"NamedAttribute",
"NoneType",
"OpaqueType",
@ -123,10 +126,16 @@ class _OperationBase:
@property
def attributes(self) -> OpAttributeMap: ...
@property
def context(self) -> Context: ...
@property
def location(self) -> Location: ...
@property
def name(self) -> str: ...
@property
def operands(self) -> OpOperandList: ...
@property
@property
def parent(self) -> Optional[_OperationBase]: ...
def regions(self) -> RegionSequence: ...
@property
def result(self) -> OpResult: ...
@ -530,6 +539,10 @@ class DenseIntElementsAttr(DenseElementsAttr):
@property
def type(self) -> Type: ...
class DenseResourceElementsAttr(Attribute):
@staticmethod
def get_from_buffer(array: Any, name: str, type: Type, alignment: Optional[int] = None, is_mutable: bool = False, context: Optional[Context] = None) -> None: ...
class Dialect:
def __init__(self, descriptor: DialectDescriptor) -> None: ...
@property
@ -563,6 +576,17 @@ class DiagnosticHandler:
def __enter__(self) -> DiagnosticHandler: ...
def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ...
class DiagnosticInfo:
def __init__(self, diag: Diagnostic) -> None: ...
@property
def severity(self) -> "DiagnosticSeverity": ...
@property
def location(self) -> "Location": ...
@property
def message(self) -> str: ...
@property
def notes(self) -> Sequence["DiagnosticInfo"]: ...
class DiagnosticSeverity:
ERROR: DiagnosticSeverity
WARNING: DiagnosticSeverity
@ -871,6 +895,9 @@ class Module:
@property
def operation(self) -> Operation: ...
class MLIRError(Exception):
def __init__(self, message: str, error_diagnostics: List[DiagnosticInfo]) -> None: ...
class NamedAttribute:
@property
def attr(self) -> Attribute: ...
@ -950,9 +977,9 @@ class OpView(_OperationBase):
loc: Optional[Location] = None,
ip: Optional[InsertionPoint] = None) -> _TOperation: ...
@property
def context(self) -> Context: ...
@property
def operation(self) -> Operation: ...
@property
def opview(self) -> "OpView": ...
class Operation(_OperationBase):
def _CAPICreate(self) -> object: ...
@ -968,13 +995,9 @@ class Operation(_OperationBase):
@property
def _CAPIPtr(self) -> object: ...
@property
def context(self) -> Context: ...
@property
def name(self) -> str: ...
def operation(self) -> "Operation": ...
@property
def opview(self) -> OpView: ...
@property
def parent(self) -> Optional[_OperationBase]: ...
class OperationIterator:
def __iter__(self) -> OperationIterator: ...

View File

@ -20,6 +20,6 @@ class PassManager:
def enable_verifier(self, enable: bool) -> None: ...
@staticmethod
def parse(pipeline: str, context: Optional[_ir.Context] = None) -> PassManager: ...
def run(self, module: _ir.Module) -> None: ...
def run(self, module: _ir._OperationBase) -> None: ...
@property
def _CAPIPtr(self) -> object: ...