mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-28 16:11:29 +00:00
[mlir] support interfaces in Python bindings
Introduce the initial support for operation interfaces in C API and Python bindings. Interfaces are a key component of MLIR's extensibility and should be available in bindings to make use of full potential of MLIR. This initial implementation exposes InferTypeOpInterface all the way to the Python bindings since it can be later used to simplify the operation construction methods by inferring their return types instead of requiring the user to do so. The general infrastructure for binding interfaces is defined and InferTypeOpInterface can be used as an example for binding other interfaces. Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D111656
This commit is contained in:
parent
1f49b71fe5
commit
14c9207063
@ -123,6 +123,7 @@ add_subdirectory(include/mlir)
|
||||
add_subdirectory(lib)
|
||||
# C API needs all dialects for registration, but should be built before tests.
|
||||
add_subdirectory(lib/CAPI)
|
||||
|
||||
if (MLIR_INCLUDE_TESTS)
|
||||
add_definitions(-DMLIR_INCLUDE_TESTS)
|
||||
add_custom_target(MLIRUnitTests)
|
||||
|
@ -536,6 +536,68 @@ except ValueError:
|
||||
concrete = OpResult(value)
|
||||
```
|
||||
|
||||
#### Interfaces
|
||||
|
||||
MLIR interfaces are a mechanism to interact with the IR without needing to know
|
||||
specific types of operations but only some of their aspects. Operation
|
||||
interfaces are available as Python classes with the same name as their C++
|
||||
counterparts. Objects of these classes can be constructed from either:
|
||||
|
||||
- an object of the `Operation` class or of any `OpView` subclass; in this
|
||||
case, all interface methods are available;
|
||||
- a subclass of `OpView` and a context; in this case, only the *static*
|
||||
interface methods are available as there is no associated operation.
|
||||
|
||||
In both cases, construction of the interface raises a `ValueError` if the
|
||||
operation class does not implement the interface in the given context (or, for
|
||||
operations, in the context that the operation is defined in). Similarly to
|
||||
attributes and types, the MLIR context may be set up by a surrounding context
|
||||
manager.
|
||||
|
||||
```python
|
||||
from mlir.ir import Context, InferTypeOpInterface
|
||||
|
||||
with Context():
|
||||
op = <...>
|
||||
|
||||
# Attempt to cast the operation into an interface.
|
||||
try:
|
||||
iface = InferTypeOpInterface(op)
|
||||
except ValueError:
|
||||
print("Operation does not implement InferTypeOpInterface.")
|
||||
raise
|
||||
|
||||
# All methods are available on interface objects constructed from an Operation
|
||||
# or an OpView.
|
||||
iface.someInstanceMethod()
|
||||
|
||||
# An interface object can also be constructed given an OpView subclass. It
|
||||
# also needs a context in which the interface will be looked up. The context
|
||||
# can be provided explicitly or set up by the surrounding context manager.
|
||||
try:
|
||||
iface = InferTypeOpInterface(some_dialect.SomeOp)
|
||||
except ValueError:
|
||||
print("SomeOp does not implement InferTypeOpInterface.")
|
||||
raise
|
||||
|
||||
# Calling an instance method on an interface object constructed from a class
|
||||
# will raise TypeError.
|
||||
try:
|
||||
iface.someInstanceMethod()
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
# One can still call static interface methods though.
|
||||
iface.inferOpReturnTypes(<...>)
|
||||
```
|
||||
|
||||
If an interface object was constructed from an `Operation` or an `OpView`, they
|
||||
are available as `.operation` and `.opview` properties of the interface object,
|
||||
respectively.
|
||||
|
||||
Only a subset of operation interfaces are currently provided in Python bindings.
|
||||
Attribute and type interfaces are not yet available in Python bindings.
|
||||
|
||||
### Creating IR Objects
|
||||
|
||||
Python bindings also support IR creation and manipulation.
|
||||
|
@ -194,3 +194,23 @@ counterparts. `wrap` converts a C++ class into a C structure and `unwrap` does
|
||||
the inverse conversion. Once the C++ object is available, the API implementation
|
||||
should rely on `isa` to implement `mlirXIsAY` and is expected to use `cast`
|
||||
inside other API calls.
|
||||
|
||||
### Extensions for Interfaces
|
||||
|
||||
Interfaces can follow the example of IR interfaces and should be placed in the
|
||||
appropriate library (e.g., common interfaces in `mlir-c/Interfaces` and
|
||||
dialect-specific interfaces in their dialect library). Similarly to other type
|
||||
hierarchies, interfaces are not expected to have objects of their own type and
|
||||
instead operate on top-level objects: `MlirAttribute`, `MlirOperation` and
|
||||
`MlirType`. Static interface methods are expected to take as leading argument a
|
||||
canonical identifier of the class, `MlirStringRef` with the name for operations
|
||||
and `MlirTypeID` for attributes and types, followed by `MlirContext` in which
|
||||
the interfaces are registered.
|
||||
|
||||
Individual interfaces are expected provide a `mlir<InterfaceName>TypeID()`
|
||||
function that can be used to check whether an object or a class implements this
|
||||
interface using `mlir<Attribute/Operation/Type>ImplementsInterface` or
|
||||
`mlir<Attribute/Operation?Type>ImplementsInterfaceStatic` functions,
|
||||
respectively. Rationale: C++ `isa` only works when an object exists, static
|
||||
methods are usually dispatched to using templates; lookup by `TypeID` in
|
||||
`MLIRContext` works even without an object.
|
||||
|
67
mlir/include/mlir-c/Interfaces.h
Normal file
67
mlir/include/mlir-c/Interfaces.h
Normal file
@ -0,0 +1,67 @@
|
||||
//===-- mlir-c/Interfaces.h - C API to Core MLIR IR interfaces ----*- C -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
||||
// Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This header declares the C interface to MLIR interface classes. It is
|
||||
// intended to contain interfaces defined in lib/Interfaces.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_C_DIALECT_H
|
||||
#define MLIR_C_DIALECT_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Support.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/// Returns `true` if the given operation implements an interface identified by
|
||||
/// its TypeID.
|
||||
MLIR_CAPI_EXPORTED bool
|
||||
mlirOperationImplementsInterface(MlirOperation operation,
|
||||
MlirTypeID interfaceTypeID);
|
||||
|
||||
/// Returns `true` if the operation identified by its canonical string name
|
||||
/// implements the interface identified by its TypeID in the given context.
|
||||
/// Note that interfaces may be attached to operations in some contexts and not
|
||||
/// others.
|
||||
MLIR_CAPI_EXPORTED bool
|
||||
mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
|
||||
MlirContext context,
|
||||
MlirTypeID interfaceTypeID);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InferTypeOpInterface.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns the interface TypeID of the InferTypeOpInterface.
|
||||
MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID();
|
||||
|
||||
/// These callbacks are used to return multiple types from functions while
|
||||
/// transferring ownerhsip to the caller. The first argument is the number of
|
||||
/// consecutive elements pointed to by the second argument. The third argument
|
||||
/// is an opaque pointer forwarded to the callback by the caller.
|
||||
typedef void (*MlirTypesCallback)(intptr_t, MlirType *, void *);
|
||||
|
||||
/// Infers the return types of the operation identified by its canonical given
|
||||
/// the arguments that will be supplied to its generic builder. Calls `callback`
|
||||
/// with the types of inferred arguments, potentially several times, on success.
|
||||
/// Returns failure otherwise.
|
||||
MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
|
||||
MlirStringRef opName, MlirContext context, MlirLocation location,
|
||||
intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
|
||||
intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback,
|
||||
void *userData);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MLIR_C_DIALECT_H
|
18
mlir/include/mlir/CAPI/Interfaces.h
Normal file
18
mlir/include/mlir/CAPI/Interfaces.h
Normal file
@ -0,0 +1,18 @@
|
||||
//===- Interfaces.h - C API Utils for MLIR interfaces -----------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file contains declarations of implementation details of the C API for
|
||||
// MLIR interface classes. This file should not be included from C++ code other
|
||||
// than C API implementation nor from C code.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_CAPI_INTERFACES_H
|
||||
#define MLIR_CAPI_INTERFACES_H
|
||||
|
||||
#endif // MLIR_CAPI_INTERFACES_H
|
240
mlir/lib/Bindings/Python/IRInterfaces.cpp
Normal file
240
mlir/lib/Bindings/Python/IRInterfaces.cpp
Normal file
@ -0,0 +1,240 @@
|
||||
//===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "IRModule.h"
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/Interfaces.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace mlir {
|
||||
namespace python {
|
||||
|
||||
constexpr static const char *constructorDoc =
|
||||
R"(Creates an interface from a given operation/opview object or from a
|
||||
subclass of OpView. Raises ValueError if the operation does not implement the
|
||||
interface.)";
|
||||
|
||||
constexpr static const char *operationDoc =
|
||||
R"(Returns an Operation for which the interface was constructed.)";
|
||||
|
||||
constexpr static const char *opviewDoc =
|
||||
R"(Returns an OpView subclass _instance_ for which the interface was
|
||||
constructed)";
|
||||
|
||||
constexpr static const char *inferReturnTypesDoc =
|
||||
R"(Given the arguments required to build an operation, attempts to infer
|
||||
its return types. Raises ValueError on failure.)";
|
||||
|
||||
/// CRTP base class for Python classes representing MLIR Op interfaces.
|
||||
/// Interface hierarchies are flat so no base class is expected here. The
|
||||
/// derived class is expected to define the following static fields:
|
||||
/// - `const char *pyClassName` - the name of the Python class to create;
|
||||
/// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
|
||||
/// of the interface.
|
||||
/// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
|
||||
/// interface-specific methods.
|
||||
///
|
||||
/// An interface class may be constructed from either an Operation/OpView object
|
||||
/// or from a subclass of OpView. In the latter case, only the static interface
|
||||
/// methods are available, similarly to calling ConcereteOp::staticMethod on the
|
||||
/// C++ side. Implementations of concrete interfaces can use the `isStatic`
|
||||
/// method to check whether the interface object was constructed from a class or
|
||||
/// an operation/opview instance. The `getOpName` always succeeds and returns a
|
||||
/// canonical name of the operation suitable for lookups.
|
||||
template <typename ConcreteIface>
|
||||
class PyConcreteOpInterface {
|
||||
protected:
|
||||
using ClassTy = py::class_<ConcreteIface>;
|
||||
using GetTypeIDFunctionTy = MlirTypeID (*)();
|
||||
|
||||
public:
|
||||
/// Constructs an interface instance from an object that is either an
|
||||
/// operation or a subclass of OpView. In the latter case, only the static
|
||||
/// methods of the interface are accessible to the caller.
|
||||
PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
|
||||
: obj(object) {
|
||||
try {
|
||||
operation = &py::cast<PyOperation &>(obj);
|
||||
} catch (py::cast_error &err) {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
try {
|
||||
operation = &py::cast<PyOpView &>(obj).getOperation();
|
||||
} catch (py::cast_error &err) {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
if (operation != nullptr) {
|
||||
if (!mlirOperationImplementsInterface(*operation,
|
||||
ConcreteIface::getInterfaceID())) {
|
||||
std::string msg = "the operation does not implement ";
|
||||
throw py::value_error(msg + ConcreteIface::pyClassName);
|
||||
}
|
||||
|
||||
MlirIdentifier identifier = mlirOperationGetName(*operation);
|
||||
MlirStringRef stringRef = mlirIdentifierStr(identifier);
|
||||
opName = std::string(stringRef.data, stringRef.length);
|
||||
} else {
|
||||
try {
|
||||
opName = obj.attr("OPERATION_NAME").template cast<std::string>();
|
||||
} catch (py::cast_error &err) {
|
||||
throw py::type_error(
|
||||
"Op interface does not refer to an operation or OpView class");
|
||||
}
|
||||
|
||||
if (!mlirOperationImplementsInterfaceStatic(
|
||||
mlirStringRefCreate(opName.data(), opName.length()),
|
||||
context.resolve().get(), ConcreteIface::getInterfaceID())) {
|
||||
std::string msg = "the operation does not implement ";
|
||||
throw py::value_error(msg + ConcreteIface::pyClassName);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates the Python bindings for this class in the given module.
|
||||
static void bind(py::module &m) {
|
||||
py::class_<ConcreteIface> cls(m, "InferTypeOpInterface",
|
||||
py::module_local());
|
||||
cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
|
||||
py::arg("context") = py::none(), constructorDoc)
|
||||
.def_property_readonly("operation",
|
||||
&PyConcreteOpInterface::getOperationObject,
|
||||
operationDoc)
|
||||
.def_property_readonly("opview", &PyConcreteOpInterface::getOpView,
|
||||
opviewDoc);
|
||||
ConcreteIface::bindDerived(cls);
|
||||
}
|
||||
|
||||
/// Hook for derived classes to add class-specific bindings.
|
||||
static void bindDerived(ClassTy &cls) {}
|
||||
|
||||
/// Returns `true` if this object was constructed from a subclass of OpView
|
||||
/// rather than from an operation instance.
|
||||
bool isStatic() { return operation == nullptr; }
|
||||
|
||||
/// Returns the operation instance from which this object was constructed.
|
||||
/// Throws a type error if this object was constructed from a subclass of
|
||||
/// OpView.
|
||||
py::object getOperationObject() {
|
||||
if (operation == nullptr) {
|
||||
throw py::type_error("Cannot get an operation from a static interface");
|
||||
}
|
||||
|
||||
return operation->getRef().releaseObject();
|
||||
}
|
||||
|
||||
/// Returns the opview of the operation instance from which this object was
|
||||
/// constructed. Throws a type error if this object was constructed form a
|
||||
/// subclass of OpView.
|
||||
py::object getOpView() {
|
||||
if (operation == nullptr) {
|
||||
throw py::type_error("Cannot get an opview from a static interface");
|
||||
}
|
||||
|
||||
return operation->createOpView();
|
||||
}
|
||||
|
||||
/// Returns the canonical name of the operation this interface is constructed
|
||||
/// from.
|
||||
const std::string &getOpName() { return opName; }
|
||||
|
||||
private:
|
||||
PyOperation *operation = nullptr;
|
||||
std::string opName;
|
||||
py::object obj;
|
||||
};
|
||||
|
||||
/// Python wrapper for InterTypeOpInterface. This interface has only static
|
||||
/// methods.
|
||||
class PyInferTypeOpInterface
|
||||
: public PyConcreteOpInterface<PyInferTypeOpInterface> {
|
||||
public:
|
||||
using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface;
|
||||
|
||||
constexpr static const char *pyClassName = "InferTypeOpInterface";
|
||||
constexpr static GetTypeIDFunctionTy getInterfaceID =
|
||||
&mlirInferTypeOpInterfaceTypeID;
|
||||
|
||||
/// C-style user-data structure for type appending callback.
|
||||
struct AppendResultsCallbackData {
|
||||
std::vector<PyType> &inferredTypes;
|
||||
PyMlirContext &pyMlirContext;
|
||||
};
|
||||
|
||||
/// Appends the types provided as the two first arguments to the user-data
|
||||
/// structure (expects AppendResultsCallbackData).
|
||||
static void appendResultsCallback(intptr_t nTypes, MlirType *types,
|
||||
void *userData) {
|
||||
auto *data = static_cast<AppendResultsCallbackData *>(userData);
|
||||
data->inferredTypes.reserve(data->inferredTypes.size() + nTypes);
|
||||
for (intptr_t i = 0; i < nTypes; ++i) {
|
||||
data->inferredTypes.push_back(
|
||||
PyType(data->pyMlirContext.getRef(), types[i]));
|
||||
}
|
||||
}
|
||||
|
||||
/// Given the arguments required to build an operation, attempts to infer its
|
||||
/// return types. Throws value_error on faliure.
|
||||
std::vector<PyType>
|
||||
inferReturnTypes(llvm::Optional<std::vector<PyValue>> operands,
|
||||
llvm::Optional<PyAttribute> attributes,
|
||||
llvm::Optional<std::vector<PyRegion>> regions,
|
||||
DefaultingPyMlirContext context,
|
||||
DefaultingPyLocation location) {
|
||||
llvm::SmallVector<MlirValue> mlirOperands;
|
||||
llvm::SmallVector<MlirRegion> mlirRegions;
|
||||
|
||||
if (operands) {
|
||||
mlirOperands.reserve(operands->size());
|
||||
for (PyValue &value : *operands) {
|
||||
mlirOperands.push_back(value);
|
||||
}
|
||||
}
|
||||
|
||||
if (regions) {
|
||||
mlirRegions.reserve(regions->size());
|
||||
for (PyRegion ®ion : *regions) {
|
||||
mlirRegions.push_back(region);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<PyType> inferredTypes;
|
||||
PyMlirContext &pyContext = context.resolve();
|
||||
AppendResultsCallbackData data{inferredTypes, pyContext};
|
||||
MlirStringRef opNameRef =
|
||||
mlirStringRefCreate(getOpName().data(), getOpName().length());
|
||||
MlirAttribute attributeDict =
|
||||
attributes ? attributes->get() : mlirAttributeGetNull();
|
||||
|
||||
MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes(
|
||||
opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
|
||||
mlirOperands.data(), attributeDict, mlirRegions.size(),
|
||||
mlirRegions.data(), &appendResultsCallback, &data);
|
||||
|
||||
if (mlirLogicalResultIsFailure(result)) {
|
||||
throw py::value_error("Failed to infer result types");
|
||||
}
|
||||
|
||||
return inferredTypes;
|
||||
}
|
||||
|
||||
static void bindDerived(ClassTy &cls) {
|
||||
cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
|
||||
py::arg("operands") = py::none(),
|
||||
py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
|
||||
py::arg("context") = py::none(), py::arg("loc") = py::none(),
|
||||
inferReturnTypesDoc);
|
||||
}
|
||||
};
|
||||
|
||||
void populateIRInterfaces(py::module &m) { PyInferTypeOpInterface::bind(m); }
|
||||
|
||||
} // namespace python
|
||||
} // namespace mlir
|
@ -859,6 +859,7 @@ private:
|
||||
void populateIRAffine(pybind11::module &m);
|
||||
void populateIRAttributes(pybind11::module &m);
|
||||
void populateIRCore(pybind11::module &m);
|
||||
void populateIRInterfaces(pybind11::module &m);
|
||||
void populateIRTypes(pybind11::module &m);
|
||||
|
||||
} // namespace python
|
||||
|
@ -85,6 +85,7 @@ PYBIND11_MODULE(_mlir, m) {
|
||||
populateIRCore(irModule);
|
||||
populateIRAffine(irModule);
|
||||
populateIRAttributes(irModule);
|
||||
populateIRInterfaces(irModule);
|
||||
populateIRTypes(irModule);
|
||||
|
||||
// Define and populate PassManager submodule.
|
||||
|
@ -2,6 +2,7 @@ add_subdirectory(Debug)
|
||||
add_subdirectory(Dialect)
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(ExecutionEngine)
|
||||
add_subdirectory(Interfaces)
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Registration)
|
||||
add_subdirectory(Transforms)
|
||||
|
5
mlir/lib/CAPI/Interfaces/CMakeLists.txt
Normal file
5
mlir/lib/CAPI/Interfaces/CMakeLists.txt
Normal file
@ -0,0 +1,5 @@
|
||||
add_mlir_public_c_api_library(MLIRCAPIInterfaces
|
||||
Interfaces.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRInferTypeOpInterface)
|
82
mlir/lib/CAPI/Interfaces/Interfaces.cpp
Normal file
82
mlir/lib/CAPI/Interfaces/Interfaces.cpp
Normal file
@ -0,0 +1,82 @@
|
||||
//===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir-c/Interfaces.h"
|
||||
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Wrap.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "llvm/ADT/ScopeExit.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
bool mlirOperationImplementsInterface(MlirOperation operation,
|
||||
MlirTypeID interfaceTypeID) {
|
||||
const AbstractOperation *abstractOp =
|
||||
unwrap(operation)->getAbstractOperation();
|
||||
return abstractOp && abstractOp->hasInterface(unwrap(interfaceTypeID));
|
||||
}
|
||||
|
||||
bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
|
||||
MlirContext context,
|
||||
MlirTypeID interfaceTypeID) {
|
||||
const AbstractOperation *abstractOp = AbstractOperation::lookup(
|
||||
StringRef(operationName.data, operationName.length), unwrap(context));
|
||||
return abstractOp && abstractOp->hasInterface(unwrap(interfaceTypeID));
|
||||
}
|
||||
|
||||
MlirTypeID mlirInferTypeOpInterfaceTypeID() {
|
||||
return wrap(InferTypeOpInterface::getInterfaceID());
|
||||
}
|
||||
|
||||
MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
|
||||
MlirStringRef opName, MlirContext context, MlirLocation location,
|
||||
intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
|
||||
intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback,
|
||||
void *userData) {
|
||||
StringRef name(opName.data, opName.length);
|
||||
const AbstractOperation *abstractOp =
|
||||
AbstractOperation::lookup(name, unwrap(context));
|
||||
if (!abstractOp)
|
||||
return mlirLogicalResultFailure();
|
||||
|
||||
llvm::Optional<Location> maybeLocation = llvm::None;
|
||||
if (!mlirLocationIsNull(location))
|
||||
maybeLocation = unwrap(location);
|
||||
SmallVector<Value> unwrappedOperands;
|
||||
(void)unwrapList(nOperands, operands, unwrappedOperands);
|
||||
DictionaryAttr attributeDict;
|
||||
if (!mlirAttributeIsNull(attributes))
|
||||
attributeDict = unwrap(attributes).cast<DictionaryAttr>();
|
||||
|
||||
// Create a vector of unique pointers to regions and make sure they are not
|
||||
// deleted when exiting the scope. This is a hack caused by C++ API expecting
|
||||
// an list of unique pointers to regions (without ownership transfer
|
||||
// semantics) and C API making ownership transfer explicit.
|
||||
SmallVector<std::unique_ptr<Region>> unwrappedRegions;
|
||||
unwrappedRegions.reserve(nRegions);
|
||||
for (intptr_t i = 0; i < nRegions; ++i)
|
||||
unwrappedRegions.emplace_back(unwrap(*(regions + i)));
|
||||
auto cleaner = llvm::make_scope_exit([&]() {
|
||||
for (auto ®ion : unwrappedRegions)
|
||||
region.release();
|
||||
});
|
||||
|
||||
SmallVector<Type> inferredTypes;
|
||||
if (failed(abstractOp->getInterface<InferTypeOpInterface>()->inferReturnTypes(
|
||||
unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
|
||||
unwrappedRegions, inferredTypes)))
|
||||
return mlirLogicalResultFailure();
|
||||
|
||||
SmallVector<MlirType> wrappedInferredTypes;
|
||||
wrappedInferredTypes.reserve(inferredTypes.size());
|
||||
for (Type t : inferredTypes)
|
||||
wrappedInferredTypes.push_back(wrap(t));
|
||||
callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
|
||||
return mlirLogicalResultSuccess();
|
||||
}
|
@ -113,12 +113,25 @@ declare_mlir_dialect_python_bindings(
|
||||
dialects/_memref_ops_ext.py
|
||||
DIALECT_NAME memref)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonTestSources.Dialects
|
||||
# TODO: this uses a tablegen file from the test directory and should be
|
||||
# decoupled from here.
|
||||
declare_mlir_python_sources(
|
||||
MLIRPythonSources.Dialects.PythonTest
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/PythonTest.td
|
||||
SOURCES dialects/python_test.py
|
||||
DIALECT_NAME python_test)
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
SOURCES dialects/python_test.py)
|
||||
set(LLVM_TARGET_DEFINITIONS
|
||||
"${MLIR_MAIN_SRC_DIR}/test/python/python_test_ops.td")
|
||||
mlir_tablegen(
|
||||
"dialects/_python_test_ops_gen.py"
|
||||
-gen-python-op-bindings
|
||||
-bind-dialect=python_test)
|
||||
add_public_tablegen_target(PythonTestDialectPyIncGen)
|
||||
declare_mlir_python_sources(
|
||||
MLIRPythonSources.Dialects.PythonTest.ops_gen
|
||||
ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects.PythonTest
|
||||
SOURCES "dialects/_python_test_ops_gen.py")
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
@ -192,6 +205,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
|
||||
${PYTHON_SOURCE_DIR}/IRAffine.cpp
|
||||
${PYTHON_SOURCE_DIR}/IRAttributes.cpp
|
||||
${PYTHON_SOURCE_DIR}/IRCore.cpp
|
||||
${PYTHON_SOURCE_DIR}/IRInterfaces.cpp
|
||||
${PYTHON_SOURCE_DIR}/IRModule.cpp
|
||||
${PYTHON_SOURCE_DIR}/IRTypes.cpp
|
||||
${PYTHON_SOURCE_DIR}/PybindUtils.cpp
|
||||
@ -201,6 +215,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
|
||||
EMBED_CAPI_LINK_LIBS
|
||||
MLIRCAPIDebug
|
||||
MLIRCAPIIR
|
||||
MLIRCAPIInterfaces
|
||||
MLIRCAPIRegistration # TODO: See about dis-aggregating
|
||||
|
||||
# Dialects
|
||||
@ -297,6 +312,20 @@ declare_mlir_python_extension(MLIRPythonExtension.Transforms
|
||||
MLIRCAPITransforms
|
||||
)
|
||||
|
||||
# TODO: This should not be included in the main Python extension. However,
|
||||
# putting it into MLIRPythonTestSources along with the dialect declaration
|
||||
# above confuses Python module loader when running under lit.
|
||||
declare_mlir_python_extension(MLIRPythonExtension.PythonTest
|
||||
MODULE_NAME _mlirPythonTest
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
SOURCES
|
||||
${MLIR_SOURCE_DIR}/test/python/lib/PythonTestModule.cpp
|
||||
PRIVATE_LINK_LIBS
|
||||
LLVMSupport
|
||||
EMBED_CAPI_LINK_LIBS
|
||||
MLIRCAPIPythonTestDialect
|
||||
)
|
||||
|
||||
################################################################################
|
||||
# Common CAPI dependency DSO.
|
||||
# All python extensions must link through one DSO which exports the CAPI, and
|
||||
@ -336,7 +365,6 @@ add_mlir_python_modules(MLIRPythonModules
|
||||
MLIRPythonCAPI
|
||||
)
|
||||
|
||||
|
||||
add_mlir_python_modules(MLIRPythonTestModules
|
||||
ROOT_PREFIX "${MLIR_BINARY_DIR}/python_packages/mlir_test/mlir"
|
||||
INSTALL_PREFIX "python_packages/mlir_test/mlir"
|
||||
|
@ -1,33 +0,0 @@
|
||||
//===-- python_test_ops.td - Python test Op definitions ----*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef PYTHON_TEST_OPS
|
||||
#define PYTHON_TEST_OPS
|
||||
|
||||
include "mlir/Bindings/Python/Attributes.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def Python_Test_Dialect : Dialect {
|
||||
let name = "python_test";
|
||||
let cppNamespace = "PythonTest";
|
||||
}
|
||||
class TestOp<string mnemonic, list<OpTrait> traits = []>
|
||||
: Op<Python_Test_Dialect, mnemonic, traits>;
|
||||
|
||||
def AttributedOp : TestOp<"attributed_op"> {
|
||||
let arguments = (ins I32Attr:$mandatory_i32,
|
||||
OptionalAttr<I32Attr>:$optional_i32,
|
||||
UnitAttr:$unit);
|
||||
}
|
||||
|
||||
def PropertyOp : TestOp<"property_op"> {
|
||||
let arguments = (ins I32Attr:$property,
|
||||
I32:$idx);
|
||||
}
|
||||
|
||||
#endif // PYTHON_TEST_OPS
|
@ -3,3 +3,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._python_test_ops_gen import *
|
||||
|
||||
|
||||
def register_python_test_dialect(context, load=True):
|
||||
from .._mlir_libs import _mlirPythonTest
|
||||
_mlirPythonTest.register_python_test_dialect(context, load)
|
||||
|
@ -1,6 +1,10 @@
|
||||
add_subdirectory(CAPI)
|
||||
add_subdirectory(lib)
|
||||
|
||||
if (MLIR_ENABLE_BINDINGS_PYTHON)
|
||||
add_subdirectory(python)
|
||||
endif()
|
||||
|
||||
# Passed to lit.site.cfg.py.so that the out of tree Standalone dialect test
|
||||
# can find MLIR's CMake configuration
|
||||
set(MLIR_CMAKE_DIR
|
||||
|
8
mlir/test/python/CMakeLists.txt
Normal file
8
mlir/test/python/CMakeLists.txt
Normal file
@ -0,0 +1,8 @@
|
||||
set(LLVM_TARGET_DEFINITIONS python_test_ops.td)
|
||||
mlir_tablegen(lib/PythonTestDialect.h.inc -gen-dialect-decls)
|
||||
mlir_tablegen(lib/PythonTestDialect.cpp.inc -gen-dialect-defs)
|
||||
mlir_tablegen(lib/PythonTestOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(lib/PythonTestOps.cpp.inc -gen-op-defs)
|
||||
add_public_tablegen_target(MLIRPythonTestIncGen)
|
||||
|
||||
add_subdirectory(lib)
|
@ -6,8 +6,10 @@ import mlir.dialects.python_test as test
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
return f
|
||||
|
||||
# CHECK-LABEL: TEST: testAttributes
|
||||
@run
|
||||
def testAttributes():
|
||||
with Context() as ctx, Location.unknown():
|
||||
ctx.allow_unregistered_dialects = True
|
||||
@ -127,4 +129,47 @@ def testAttributes():
|
||||
del op.unit
|
||||
print(f"Unit: {op.unit}")
|
||||
|
||||
run(testAttributes)
|
||||
|
||||
# CHECK-LABEL: TEST: inferReturnTypes
|
||||
@run
|
||||
def inferReturnTypes():
|
||||
with Context() as ctx, Location.unknown(ctx):
|
||||
test.register_python_test_dialect(ctx)
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
op = test.InferResultsOp(
|
||||
IntegerType.get_signless(32), IntegerType.get_signless(64))
|
||||
dummy = test.DummyOp()
|
||||
|
||||
# CHECK: [Type(i32), Type(i64)]
|
||||
iface = InferTypeOpInterface(op)
|
||||
print(iface.inferReturnTypes())
|
||||
|
||||
# CHECK: [Type(i32), Type(i64)]
|
||||
iface_static = InferTypeOpInterface(test.InferResultsOp)
|
||||
print(iface.inferReturnTypes())
|
||||
|
||||
assert isinstance(iface.opview, test.InferResultsOp)
|
||||
assert iface.opview == iface.operation.opview
|
||||
|
||||
try:
|
||||
iface_static.opview
|
||||
except TypeError:
|
||||
pass
|
||||
else:
|
||||
assert False, ("not expected to be able to obtain an opview from a static"
|
||||
" interface")
|
||||
|
||||
try:
|
||||
InferTypeOpInterface(dummy)
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
assert False, "not expected dummy op to implement the interface"
|
||||
|
||||
try:
|
||||
InferTypeOpInterface(test.DummyOp)
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
assert False, "not expected dummy op class to implement the interface"
|
||||
|
33
mlir/test/python/lib/CMakeLists.txt
Normal file
33
mlir/test/python/lib/CMakeLists.txt
Normal file
@ -0,0 +1,33 @@
|
||||
set(LLVM_OPTIONAL_SOURCES
|
||||
PythonTestCAPI.cpp
|
||||
PythonTestDialect.cpp
|
||||
PythonTestModule.cpp
|
||||
)
|
||||
|
||||
add_mlir_library(MLIRPythonTestDialect
|
||||
PythonTestDialect.cpp
|
||||
|
||||
EXCLUDE_FROM_LIBMLIR
|
||||
|
||||
DEPENDS
|
||||
MLIRPythonTestIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRInferTypeOpInterface
|
||||
MLIRIR
|
||||
MLIRSupport
|
||||
)
|
||||
|
||||
add_mlir_public_c_api_library(MLIRCAPIPythonTestDialect
|
||||
PythonTestCAPI.cpp
|
||||
|
||||
DEPENDS
|
||||
MLIRPythonTestIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRCAPIInterfaces
|
||||
MLIRCAPIIR
|
||||
MLIRCAPIRegistration
|
||||
MLIRPythonTestDialect
|
||||
)
|
||||
|
14
mlir/test/python/lib/PythonTestCAPI.cpp
Normal file
14
mlir/test/python/lib/PythonTestCAPI.cpp
Normal file
@ -0,0 +1,14 @@
|
||||
//===- PythonTestCAPI.cpp - C API for the PythonTest dialect --------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PythonTestCAPI.h"
|
||||
#include "PythonTestDialect.h"
|
||||
#include "mlir/CAPI/Registration.h"
|
||||
|
||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(PythonTest, python_test,
|
||||
python_test::PythonTestDialect)
|
24
mlir/test/python/lib/PythonTestCAPI.h
Normal file
24
mlir/test/python/lib/PythonTestCAPI.h
Normal file
@ -0,0 +1,24 @@
|
||||
//===- PythonTestCAPI.h - C API for the PythonTest dialect ------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TEST_PYTHON_LIB_PYTHONTESTCAPI_H
|
||||
#define MLIR_TEST_PYTHON_LIB_PYTHONTESTCAPI_H
|
||||
|
||||
#include "mlir-c/Registration.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(PythonTest, python_test);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MLIR_TEST_PYTHON_LIB_PYTHONTESTCAPI_H
|
25
mlir/test/python/lib/PythonTestDialect.cpp
Normal file
25
mlir/test/python/lib/PythonTestDialect.cpp
Normal file
@ -0,0 +1,25 @@
|
||||
//===- PythonTestDialect.cpp - PythonTest dialect definition --------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PythonTestDialect.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
|
||||
#include "PythonTestDialect.cpp.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "PythonTestOps.cpp.inc"
|
||||
|
||||
namespace python_test {
|
||||
void PythonTestDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "PythonTestOps.cpp.inc"
|
||||
>();
|
||||
}
|
||||
} // namespace python_test
|
21
mlir/test/python/lib/PythonTestDialect.h
Normal file
21
mlir/test/python/lib/PythonTestDialect.h
Normal file
@ -0,0 +1,21 @@
|
||||
//===- PythonTestDialect.h - PythonTest dialect definition ------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H
|
||||
#define MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
|
||||
#include "PythonTestDialect.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "PythonTestOps.h.inc"
|
||||
|
||||
#endif // MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H
|
26
mlir/test/python/lib/PythonTestModule.cpp
Normal file
26
mlir/test/python/lib/PythonTestModule.cpp
Normal file
@ -0,0 +1,26 @@
|
||||
//===- PythonTestModule.cpp - Python extension for the PythonTest dialect -===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PythonTestCAPI.h"
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
PYBIND11_MODULE(_mlirPythonTest, m) {
|
||||
m.def(
|
||||
"register_python_test_dialect",
|
||||
[](MlirContext context, bool load) {
|
||||
MlirDialectHandle pythonTestDialect =
|
||||
mlirGetDialectHandle__python_test__();
|
||||
mlirDialectHandleRegisterDialect(pythonTestDialect, context);
|
||||
if (load) {
|
||||
mlirDialectHandleLoadDialect(pythonTestDialect, context);
|
||||
}
|
||||
},
|
||||
py::arg("context"), py::arg("load") = true);
|
||||
}
|
@ -11,10 +11,11 @@
|
||||
|
||||
include "mlir/Bindings/Python/Attributes.td"
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
|
||||
def Python_Test_Dialect : Dialect {
|
||||
let name = "python_test";
|
||||
let cppNamespace = "PythonTest";
|
||||
let cppNamespace = "python_test";
|
||||
}
|
||||
class TestOp<string mnemonic, list<OpTrait> traits = []>
|
||||
: Op<Python_Test_Dialect, mnemonic, traits>;
|
||||
@ -30,4 +31,25 @@ def PropertyOp : TestOp<"property_op"> {
|
||||
I32:$idx);
|
||||
}
|
||||
|
||||
def DummyOp : TestOp<"dummy_op"> {
|
||||
}
|
||||
|
||||
def InferResultsOp : TestOp<"infer_results_op", [InferTypeOpInterface]> {
|
||||
let arguments = (ins);
|
||||
let results = (outs AnyInteger:$single, AnyInteger:$doubled);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static ::mlir::LogicalResult inferReturnTypes(
|
||||
::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
|
||||
::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
|
||||
::mlir::RegionRange regions,
|
||||
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
|
||||
::mlir::Builder b(context);
|
||||
inferredReturnTypes.push_back(b.getI32Type());
|
||||
inferredReturnTypes.push_back(b.getI64Type());
|
||||
return ::mlir::success();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // PYTHON_TEST_OPS
|
||||
|
@ -333,6 +333,7 @@ cc_library(
|
||||
"include/mlir-c/ExecutionEngine.h",
|
||||
"include/mlir-c/IR.h",
|
||||
"include/mlir-c/IntegerSet.h",
|
||||
"include/mlir-c/Interfaces.h",
|
||||
"include/mlir-c/Pass.h",
|
||||
"include/mlir-c/Registration.h",
|
||||
"include/mlir-c/Support.h",
|
||||
@ -360,6 +361,20 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "CAPIInterfaces",
|
||||
srcs = [
|
||||
"lib/CAPI/Interfaces/Interfaces.cpp",
|
||||
],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":CAPIIR",
|
||||
":IR",
|
||||
":InferTypeOpInterface",
|
||||
"//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "CAPIAsync",
|
||||
srcs = [
|
||||
@ -558,6 +573,7 @@ cc_library(
|
||||
"lib/Bindings/Python/IRAffine.cpp",
|
||||
"lib/Bindings/Python/IRAttributes.cpp",
|
||||
"lib/Bindings/Python/IRCore.cpp",
|
||||
"lib/Bindings/Python/IRInterfaces.cpp",
|
||||
"lib/Bindings/Python/IRModule.cpp",
|
||||
"lib/Bindings/Python/IRTypes.cpp",
|
||||
"lib/Bindings/Python/Pass.cpp",
|
||||
@ -581,6 +597,7 @@ cc_library(
|
||||
":CAPIDebug",
|
||||
":CAPIGPU",
|
||||
":CAPIIR",
|
||||
":CAPIInterfaces",
|
||||
":CAPILinalg",
|
||||
":CAPIRegistration",
|
||||
":CAPISparseTensor",
|
||||
|
Loading…
Reference in New Issue
Block a user