[mlir] support interfaces in Python bindings

Introduce the initial support for operation interfaces in C API and Python
bindings. Interfaces are a key component of MLIR's extensibility and should be
available in bindings to make use of full potential of MLIR.

This initial implementation exposes InferTypeOpInterface all the way to the
Python bindings since it can be later used to simplify the operation
construction methods by inferring their return types instead of requiring the
user to do so. The general infrastructure for binding interfaces is defined and
InferTypeOpInterface can be used as an example for binding other interfaces.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D111656
This commit is contained in:
Alex Zinenko 2021-10-14 17:18:28 +02:00
parent 1f49b71fe5
commit 14c9207063
25 changed files with 778 additions and 41 deletions

View File

@ -123,6 +123,7 @@ add_subdirectory(include/mlir)
add_subdirectory(lib)
# C API needs all dialects for registration, but should be built before tests.
add_subdirectory(lib/CAPI)
if (MLIR_INCLUDE_TESTS)
add_definitions(-DMLIR_INCLUDE_TESTS)
add_custom_target(MLIRUnitTests)

View File

@ -536,6 +536,68 @@ except ValueError:
concrete = OpResult(value)
```
#### Interfaces
MLIR interfaces are a mechanism to interact with the IR without needing to know
specific types of operations but only some of their aspects. Operation
interfaces are available as Python classes with the same name as their C++
counterparts. Objects of these classes can be constructed from either:
- an object of the `Operation` class or of any `OpView` subclass; in this
case, all interface methods are available;
- a subclass of `OpView` and a context; in this case, only the *static*
interface methods are available as there is no associated operation.
In both cases, construction of the interface raises a `ValueError` if the
operation class does not implement the interface in the given context (or, for
operations, in the context that the operation is defined in). Similarly to
attributes and types, the MLIR context may be set up by a surrounding context
manager.
```python
from mlir.ir import Context, InferTypeOpInterface
with Context():
op = <...>
# Attempt to cast the operation into an interface.
try:
iface = InferTypeOpInterface(op)
except ValueError:
print("Operation does not implement InferTypeOpInterface.")
raise
# All methods are available on interface objects constructed from an Operation
# or an OpView.
iface.someInstanceMethod()
# An interface object can also be constructed given an OpView subclass. It
# also needs a context in which the interface will be looked up. The context
# can be provided explicitly or set up by the surrounding context manager.
try:
iface = InferTypeOpInterface(some_dialect.SomeOp)
except ValueError:
print("SomeOp does not implement InferTypeOpInterface.")
raise
# Calling an instance method on an interface object constructed from a class
# will raise TypeError.
try:
iface.someInstanceMethod()
except TypeError:
pass
# One can still call static interface methods though.
iface.inferOpReturnTypes(<...>)
```
If an interface object was constructed from an `Operation` or an `OpView`, they
are available as `.operation` and `.opview` properties of the interface object,
respectively.
Only a subset of operation interfaces are currently provided in Python bindings.
Attribute and type interfaces are not yet available in Python bindings.
### Creating IR Objects
Python bindings also support IR creation and manipulation.

View File

@ -194,3 +194,23 @@ counterparts. `wrap` converts a C++ class into a C structure and `unwrap` does
the inverse conversion. Once the C++ object is available, the API implementation
should rely on `isa` to implement `mlirXIsAY` and is expected to use `cast`
inside other API calls.
### Extensions for Interfaces
Interfaces can follow the example of IR interfaces and should be placed in the
appropriate library (e.g., common interfaces in `mlir-c/Interfaces` and
dialect-specific interfaces in their dialect library). Similarly to other type
hierarchies, interfaces are not expected to have objects of their own type and
instead operate on top-level objects: `MlirAttribute`, `MlirOperation` and
`MlirType`. Static interface methods are expected to take as leading argument a
canonical identifier of the class, `MlirStringRef` with the name for operations
and `MlirTypeID` for attributes and types, followed by `MlirContext` in which
the interfaces are registered.
Individual interfaces are expected provide a `mlir<InterfaceName>TypeID()`
function that can be used to check whether an object or a class implements this
interface using `mlir<Attribute/Operation/Type>ImplementsInterface` or
`mlir<Attribute/Operation?Type>ImplementsInterfaceStatic` functions,
respectively. Rationale: C++ `isa` only works when an object exists, static
methods are usually dispatched to using templates; lookup by `TypeID` in
`MLIRContext` works even without an object.

View File

@ -0,0 +1,67 @@
//===-- mlir-c/Interfaces.h - C API to Core MLIR IR interfaces ----*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header declares the C interface to MLIR interface classes. It is
// intended to contain interfaces defined in lib/Interfaces.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_C_DIALECT_H
#define MLIR_C_DIALECT_H
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#ifdef __cplusplus
extern "C" {
#endif
/// Returns `true` if the given operation implements an interface identified by
/// its TypeID.
MLIR_CAPI_EXPORTED bool
mlirOperationImplementsInterface(MlirOperation operation,
MlirTypeID interfaceTypeID);
/// Returns `true` if the operation identified by its canonical string name
/// implements the interface identified by its TypeID in the given context.
/// Note that interfaces may be attached to operations in some contexts and not
/// others.
MLIR_CAPI_EXPORTED bool
mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
MlirContext context,
MlirTypeID interfaceTypeID);
//===----------------------------------------------------------------------===//
// InferTypeOpInterface.
//===----------------------------------------------------------------------===//
/// Returns the interface TypeID of the InferTypeOpInterface.
MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID();
/// These callbacks are used to return multiple types from functions while
/// transferring ownerhsip to the caller. The first argument is the number of
/// consecutive elements pointed to by the second argument. The third argument
/// is an opaque pointer forwarded to the callback by the caller.
typedef void (*MlirTypesCallback)(intptr_t, MlirType *, void *);
/// Infers the return types of the operation identified by its canonical given
/// the arguments that will be supplied to its generic builder. Calls `callback`
/// with the types of inferred arguments, potentially several times, on success.
/// Returns failure otherwise.
MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
MlirStringRef opName, MlirContext context, MlirLocation location,
intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback,
void *userData);
#ifdef __cplusplus
}
#endif
#endif // MLIR_C_DIALECT_H

View File

@ -0,0 +1,18 @@
//===- Interfaces.h - C API Utils for MLIR interfaces -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains declarations of implementation details of the C API for
// MLIR interface classes. This file should not be included from C++ code other
// than C API implementation nor from C code.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CAPI_INTERFACES_H
#define MLIR_CAPI_INTERFACES_H
#endif // MLIR_CAPI_INTERFACES_H

View File

@ -0,0 +1,240 @@
//===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "IRModule.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Interfaces.h"
namespace py = pybind11;
namespace mlir {
namespace python {
constexpr static const char *constructorDoc =
R"(Creates an interface from a given operation/opview object or from a
subclass of OpView. Raises ValueError if the operation does not implement the
interface.)";
constexpr static const char *operationDoc =
R"(Returns an Operation for which the interface was constructed.)";
constexpr static const char *opviewDoc =
R"(Returns an OpView subclass _instance_ for which the interface was
constructed)";
constexpr static const char *inferReturnTypesDoc =
R"(Given the arguments required to build an operation, attempts to infer
its return types. Raises ValueError on failure.)";
/// CRTP base class for Python classes representing MLIR Op interfaces.
/// Interface hierarchies are flat so no base class is expected here. The
/// derived class is expected to define the following static fields:
/// - `const char *pyClassName` - the name of the Python class to create;
/// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
/// of the interface.
/// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
/// interface-specific methods.
///
/// An interface class may be constructed from either an Operation/OpView object
/// or from a subclass of OpView. In the latter case, only the static interface
/// methods are available, similarly to calling ConcereteOp::staticMethod on the
/// C++ side. Implementations of concrete interfaces can use the `isStatic`
/// method to check whether the interface object was constructed from a class or
/// an operation/opview instance. The `getOpName` always succeeds and returns a
/// canonical name of the operation suitable for lookups.
template <typename ConcreteIface>
class PyConcreteOpInterface {
protected:
using ClassTy = py::class_<ConcreteIface>;
using GetTypeIDFunctionTy = MlirTypeID (*)();
public:
/// Constructs an interface instance from an object that is either an
/// operation or a subclass of OpView. In the latter case, only the static
/// methods of the interface are accessible to the caller.
PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
: obj(object) {
try {
operation = &py::cast<PyOperation &>(obj);
} catch (py::cast_error &err) {
// Do nothing.
}
try {
operation = &py::cast<PyOpView &>(obj).getOperation();
} catch (py::cast_error &err) {
// Do nothing.
}
if (operation != nullptr) {
if (!mlirOperationImplementsInterface(*operation,
ConcreteIface::getInterfaceID())) {
std::string msg = "the operation does not implement ";
throw py::value_error(msg + ConcreteIface::pyClassName);
}
MlirIdentifier identifier = mlirOperationGetName(*operation);
MlirStringRef stringRef = mlirIdentifierStr(identifier);
opName = std::string(stringRef.data, stringRef.length);
} else {
try {
opName = obj.attr("OPERATION_NAME").template cast<std::string>();
} catch (py::cast_error &err) {
throw py::type_error(
"Op interface does not refer to an operation or OpView class");
}
if (!mlirOperationImplementsInterfaceStatic(
mlirStringRefCreate(opName.data(), opName.length()),
context.resolve().get(), ConcreteIface::getInterfaceID())) {
std::string msg = "the operation does not implement ";
throw py::value_error(msg + ConcreteIface::pyClassName);
}
}
}
/// Creates the Python bindings for this class in the given module.
static void bind(py::module &m) {
py::class_<ConcreteIface> cls(m, "InferTypeOpInterface",
py::module_local());
cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
py::arg("context") = py::none(), constructorDoc)
.def_property_readonly("operation",
&PyConcreteOpInterface::getOperationObject,
operationDoc)
.def_property_readonly("opview", &PyConcreteOpInterface::getOpView,
opviewDoc);
ConcreteIface::bindDerived(cls);
}
/// Hook for derived classes to add class-specific bindings.
static void bindDerived(ClassTy &cls) {}
/// Returns `true` if this object was constructed from a subclass of OpView
/// rather than from an operation instance.
bool isStatic() { return operation == nullptr; }
/// Returns the operation instance from which this object was constructed.
/// Throws a type error if this object was constructed from a subclass of
/// OpView.
py::object getOperationObject() {
if (operation == nullptr) {
throw py::type_error("Cannot get an operation from a static interface");
}
return operation->getRef().releaseObject();
}
/// Returns the opview of the operation instance from which this object was
/// constructed. Throws a type error if this object was constructed form a
/// subclass of OpView.
py::object getOpView() {
if (operation == nullptr) {
throw py::type_error("Cannot get an opview from a static interface");
}
return operation->createOpView();
}
/// Returns the canonical name of the operation this interface is constructed
/// from.
const std::string &getOpName() { return opName; }
private:
PyOperation *operation = nullptr;
std::string opName;
py::object obj;
};
/// Python wrapper for InterTypeOpInterface. This interface has only static
/// methods.
class PyInferTypeOpInterface
: public PyConcreteOpInterface<PyInferTypeOpInterface> {
public:
using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface;
constexpr static const char *pyClassName = "InferTypeOpInterface";
constexpr static GetTypeIDFunctionTy getInterfaceID =
&mlirInferTypeOpInterfaceTypeID;
/// C-style user-data structure for type appending callback.
struct AppendResultsCallbackData {
std::vector<PyType> &inferredTypes;
PyMlirContext &pyMlirContext;
};
/// Appends the types provided as the two first arguments to the user-data
/// structure (expects AppendResultsCallbackData).
static void appendResultsCallback(intptr_t nTypes, MlirType *types,
void *userData) {
auto *data = static_cast<AppendResultsCallbackData *>(userData);
data->inferredTypes.reserve(data->inferredTypes.size() + nTypes);
for (intptr_t i = 0; i < nTypes; ++i) {
data->inferredTypes.push_back(
PyType(data->pyMlirContext.getRef(), types[i]));
}
}
/// Given the arguments required to build an operation, attempts to infer its
/// return types. Throws value_error on faliure.
std::vector<PyType>
inferReturnTypes(llvm::Optional<std::vector<PyValue>> operands,
llvm::Optional<PyAttribute> attributes,
llvm::Optional<std::vector<PyRegion>> regions,
DefaultingPyMlirContext context,
DefaultingPyLocation location) {
llvm::SmallVector<MlirValue> mlirOperands;
llvm::SmallVector<MlirRegion> mlirRegions;
if (operands) {
mlirOperands.reserve(operands->size());
for (PyValue &value : *operands) {
mlirOperands.push_back(value);
}
}
if (regions) {
mlirRegions.reserve(regions->size());
for (PyRegion &region : *regions) {
mlirRegions.push_back(region);
}
}
std::vector<PyType> inferredTypes;
PyMlirContext &pyContext = context.resolve();
AppendResultsCallbackData data{inferredTypes, pyContext};
MlirStringRef opNameRef =
mlirStringRefCreate(getOpName().data(), getOpName().length());
MlirAttribute attributeDict =
attributes ? attributes->get() : mlirAttributeGetNull();
MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes(
opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
mlirOperands.data(), attributeDict, mlirRegions.size(),
mlirRegions.data(), &appendResultsCallback, &data);
if (mlirLogicalResultIsFailure(result)) {
throw py::value_error("Failed to infer result types");
}
return inferredTypes;
}
static void bindDerived(ClassTy &cls) {
cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
py::arg("operands") = py::none(),
py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
py::arg("context") = py::none(), py::arg("loc") = py::none(),
inferReturnTypesDoc);
}
};
void populateIRInterfaces(py::module &m) { PyInferTypeOpInterface::bind(m); }
} // namespace python
} // namespace mlir

View File

@ -859,6 +859,7 @@ private:
void populateIRAffine(pybind11::module &m);
void populateIRAttributes(pybind11::module &m);
void populateIRCore(pybind11::module &m);
void populateIRInterfaces(pybind11::module &m);
void populateIRTypes(pybind11::module &m);
} // namespace python

View File

@ -85,6 +85,7 @@ PYBIND11_MODULE(_mlir, m) {
populateIRCore(irModule);
populateIRAffine(irModule);
populateIRAttributes(irModule);
populateIRInterfaces(irModule);
populateIRTypes(irModule);
// Define and populate PassManager submodule.

View File

@ -2,6 +2,7 @@ add_subdirectory(Debug)
add_subdirectory(Dialect)
add_subdirectory(Conversion)
add_subdirectory(ExecutionEngine)
add_subdirectory(Interfaces)
add_subdirectory(IR)
add_subdirectory(Registration)
add_subdirectory(Transforms)

View File

@ -0,0 +1,5 @@
add_mlir_public_c_api_library(MLIRCAPIInterfaces
Interfaces.cpp
LINK_LIBS PUBLIC
MLIRInferTypeOpInterface)

View File

@ -0,0 +1,82 @@
//===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/Interfaces.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/ScopeExit.h"
using namespace mlir;
bool mlirOperationImplementsInterface(MlirOperation operation,
MlirTypeID interfaceTypeID) {
const AbstractOperation *abstractOp =
unwrap(operation)->getAbstractOperation();
return abstractOp && abstractOp->hasInterface(unwrap(interfaceTypeID));
}
bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
MlirContext context,
MlirTypeID interfaceTypeID) {
const AbstractOperation *abstractOp = AbstractOperation::lookup(
StringRef(operationName.data, operationName.length), unwrap(context));
return abstractOp && abstractOp->hasInterface(unwrap(interfaceTypeID));
}
MlirTypeID mlirInferTypeOpInterfaceTypeID() {
return wrap(InferTypeOpInterface::getInterfaceID());
}
MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
MlirStringRef opName, MlirContext context, MlirLocation location,
intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback,
void *userData) {
StringRef name(opName.data, opName.length);
const AbstractOperation *abstractOp =
AbstractOperation::lookup(name, unwrap(context));
if (!abstractOp)
return mlirLogicalResultFailure();
llvm::Optional<Location> maybeLocation = llvm::None;
if (!mlirLocationIsNull(location))
maybeLocation = unwrap(location);
SmallVector<Value> unwrappedOperands;
(void)unwrapList(nOperands, operands, unwrappedOperands);
DictionaryAttr attributeDict;
if (!mlirAttributeIsNull(attributes))
attributeDict = unwrap(attributes).cast<DictionaryAttr>();
// Create a vector of unique pointers to regions and make sure they are not
// deleted when exiting the scope. This is a hack caused by C++ API expecting
// an list of unique pointers to regions (without ownership transfer
// semantics) and C API making ownership transfer explicit.
SmallVector<std::unique_ptr<Region>> unwrappedRegions;
unwrappedRegions.reserve(nRegions);
for (intptr_t i = 0; i < nRegions; ++i)
unwrappedRegions.emplace_back(unwrap(*(regions + i)));
auto cleaner = llvm::make_scope_exit([&]() {
for (auto &region : unwrappedRegions)
region.release();
});
SmallVector<Type> inferredTypes;
if (failed(abstractOp->getInterface<InferTypeOpInterface>()->inferReturnTypes(
unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
unwrappedRegions, inferredTypes)))
return mlirLogicalResultFailure();
SmallVector<MlirType> wrappedInferredTypes;
wrappedInferredTypes.reserve(inferredTypes.size());
for (Type t : inferredTypes)
wrappedInferredTypes.push_back(wrap(t));
callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
return mlirLogicalResultSuccess();
}

View File

@ -113,12 +113,25 @@ declare_mlir_dialect_python_bindings(
dialects/_memref_ops_ext.py
DIALECT_NAME memref)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonTestSources.Dialects
# TODO: this uses a tablegen file from the test directory and should be
# decoupled from here.
declare_mlir_python_sources(
MLIRPythonSources.Dialects.PythonTest
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/PythonTest.td
SOURCES dialects/python_test.py
DIALECT_NAME python_test)
ADD_TO_PARENT MLIRPythonSources.Dialects
SOURCES dialects/python_test.py)
set(LLVM_TARGET_DEFINITIONS
"${MLIR_MAIN_SRC_DIR}/test/python/python_test_ops.td")
mlir_tablegen(
"dialects/_python_test_ops_gen.py"
-gen-python-op-bindings
-bind-dialect=python_test)
add_public_tablegen_target(PythonTestDialectPyIncGen)
declare_mlir_python_sources(
MLIRPythonSources.Dialects.PythonTest.ops_gen
ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
ADD_TO_PARENT MLIRPythonSources.Dialects.PythonTest
SOURCES "dialects/_python_test_ops_gen.py")
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
@ -192,6 +205,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
${PYTHON_SOURCE_DIR}/IRAffine.cpp
${PYTHON_SOURCE_DIR}/IRAttributes.cpp
${PYTHON_SOURCE_DIR}/IRCore.cpp
${PYTHON_SOURCE_DIR}/IRInterfaces.cpp
${PYTHON_SOURCE_DIR}/IRModule.cpp
${PYTHON_SOURCE_DIR}/IRTypes.cpp
${PYTHON_SOURCE_DIR}/PybindUtils.cpp
@ -201,6 +215,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
EMBED_CAPI_LINK_LIBS
MLIRCAPIDebug
MLIRCAPIIR
MLIRCAPIInterfaces
MLIRCAPIRegistration # TODO: See about dis-aggregating
# Dialects
@ -297,6 +312,20 @@ declare_mlir_python_extension(MLIRPythonExtension.Transforms
MLIRCAPITransforms
)
# TODO: This should not be included in the main Python extension. However,
# putting it into MLIRPythonTestSources along with the dialect declaration
# above confuses Python module loader when running under lit.
declare_mlir_python_extension(MLIRPythonExtension.PythonTest
MODULE_NAME _mlirPythonTest
ADD_TO_PARENT MLIRPythonSources.Dialects
SOURCES
${MLIR_SOURCE_DIR}/test/python/lib/PythonTestModule.cpp
PRIVATE_LINK_LIBS
LLVMSupport
EMBED_CAPI_LINK_LIBS
MLIRCAPIPythonTestDialect
)
################################################################################
# Common CAPI dependency DSO.
# All python extensions must link through one DSO which exports the CAPI, and
@ -336,7 +365,6 @@ add_mlir_python_modules(MLIRPythonModules
MLIRPythonCAPI
)
add_mlir_python_modules(MLIRPythonTestModules
ROOT_PREFIX "${MLIR_BINARY_DIR}/python_packages/mlir_test/mlir"
INSTALL_PREFIX "python_packages/mlir_test/mlir"

View File

@ -1,33 +0,0 @@
//===-- python_test_ops.td - Python test Op definitions ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef PYTHON_TEST_OPS
#define PYTHON_TEST_OPS
include "mlir/Bindings/Python/Attributes.td"
include "mlir/IR/OpBase.td"
def Python_Test_Dialect : Dialect {
let name = "python_test";
let cppNamespace = "PythonTest";
}
class TestOp<string mnemonic, list<OpTrait> traits = []>
: Op<Python_Test_Dialect, mnemonic, traits>;
def AttributedOp : TestOp<"attributed_op"> {
let arguments = (ins I32Attr:$mandatory_i32,
OptionalAttr<I32Attr>:$optional_i32,
UnitAttr:$unit);
}
def PropertyOp : TestOp<"property_op"> {
let arguments = (ins I32Attr:$property,
I32:$idx);
}
#endif // PYTHON_TEST_OPS

View File

@ -3,3 +3,8 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._python_test_ops_gen import *
def register_python_test_dialect(context, load=True):
from .._mlir_libs import _mlirPythonTest
_mlirPythonTest.register_python_test_dialect(context, load)

View File

@ -1,6 +1,10 @@
add_subdirectory(CAPI)
add_subdirectory(lib)
if (MLIR_ENABLE_BINDINGS_PYTHON)
add_subdirectory(python)
endif()
# Passed to lit.site.cfg.py.so that the out of tree Standalone dialect test
# can find MLIR's CMake configuration
set(MLIR_CMAKE_DIR

View File

@ -0,0 +1,8 @@
set(LLVM_TARGET_DEFINITIONS python_test_ops.td)
mlir_tablegen(lib/PythonTestDialect.h.inc -gen-dialect-decls)
mlir_tablegen(lib/PythonTestDialect.cpp.inc -gen-dialect-defs)
mlir_tablegen(lib/PythonTestOps.h.inc -gen-op-decls)
mlir_tablegen(lib/PythonTestOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRPythonTestIncGen)
add_subdirectory(lib)

View File

@ -6,8 +6,10 @@ import mlir.dialects.python_test as test
def run(f):
print("\nTEST:", f.__name__)
f()
return f
# CHECK-LABEL: TEST: testAttributes
@run
def testAttributes():
with Context() as ctx, Location.unknown():
ctx.allow_unregistered_dialects = True
@ -127,4 +129,47 @@ def testAttributes():
del op.unit
print(f"Unit: {op.unit}")
run(testAttributes)
# CHECK-LABEL: TEST: inferReturnTypes
@run
def inferReturnTypes():
with Context() as ctx, Location.unknown(ctx):
test.register_python_test_dialect(ctx)
module = Module.create()
with InsertionPoint(module.body):
op = test.InferResultsOp(
IntegerType.get_signless(32), IntegerType.get_signless(64))
dummy = test.DummyOp()
# CHECK: [Type(i32), Type(i64)]
iface = InferTypeOpInterface(op)
print(iface.inferReturnTypes())
# CHECK: [Type(i32), Type(i64)]
iface_static = InferTypeOpInterface(test.InferResultsOp)
print(iface.inferReturnTypes())
assert isinstance(iface.opview, test.InferResultsOp)
assert iface.opview == iface.operation.opview
try:
iface_static.opview
except TypeError:
pass
else:
assert False, ("not expected to be able to obtain an opview from a static"
" interface")
try:
InferTypeOpInterface(dummy)
except ValueError:
pass
else:
assert False, "not expected dummy op to implement the interface"
try:
InferTypeOpInterface(test.DummyOp)
except ValueError:
pass
else:
assert False, "not expected dummy op class to implement the interface"

View File

@ -0,0 +1,33 @@
set(LLVM_OPTIONAL_SOURCES
PythonTestCAPI.cpp
PythonTestDialect.cpp
PythonTestModule.cpp
)
add_mlir_library(MLIRPythonTestDialect
PythonTestDialect.cpp
EXCLUDE_FROM_LIBMLIR
DEPENDS
MLIRPythonTestIncGen
LINK_LIBS PUBLIC
MLIRInferTypeOpInterface
MLIRIR
MLIRSupport
)
add_mlir_public_c_api_library(MLIRCAPIPythonTestDialect
PythonTestCAPI.cpp
DEPENDS
MLIRPythonTestIncGen
LINK_LIBS PUBLIC
MLIRCAPIInterfaces
MLIRCAPIIR
MLIRCAPIRegistration
MLIRPythonTestDialect
)

View File

@ -0,0 +1,14 @@
//===- PythonTestCAPI.cpp - C API for the PythonTest dialect --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PythonTestCAPI.h"
#include "PythonTestDialect.h"
#include "mlir/CAPI/Registration.h"
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(PythonTest, python_test,
python_test::PythonTestDialect)

View File

@ -0,0 +1,24 @@
//===- PythonTestCAPI.h - C API for the PythonTest dialect ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TEST_PYTHON_LIB_PYTHONTESTCAPI_H
#define MLIR_TEST_PYTHON_LIB_PYTHONTESTCAPI_H
#include "mlir-c/Registration.h"
#ifdef __cplusplus
extern "C" {
#endif
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(PythonTest, python_test);
#ifdef __cplusplus
}
#endif
#endif // MLIR_TEST_PYTHON_LIB_PYTHONTESTCAPI_H

View File

@ -0,0 +1,25 @@
//===- PythonTestDialect.cpp - PythonTest dialect definition --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PythonTestDialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "PythonTestDialect.cpp.inc"
#define GET_OP_CLASSES
#include "PythonTestOps.cpp.inc"
namespace python_test {
void PythonTestDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "PythonTestOps.cpp.inc"
>();
}
} // namespace python_test

View File

@ -0,0 +1,21 @@
//===- PythonTestDialect.h - PythonTest dialect definition ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H
#define MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "PythonTestDialect.h.inc"
#define GET_OP_CLASSES
#include "PythonTestOps.h.inc"
#endif // MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H

View File

@ -0,0 +1,26 @@
//===- PythonTestModule.cpp - Python extension for the PythonTest dialect -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PythonTestCAPI.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
namespace py = pybind11;
PYBIND11_MODULE(_mlirPythonTest, m) {
m.def(
"register_python_test_dialect",
[](MlirContext context, bool load) {
MlirDialectHandle pythonTestDialect =
mlirGetDialectHandle__python_test__();
mlirDialectHandleRegisterDialect(pythonTestDialect, context);
if (load) {
mlirDialectHandleLoadDialect(pythonTestDialect, context);
}
},
py::arg("context"), py::arg("load") = true);
}

View File

@ -11,10 +11,11 @@
include "mlir/Bindings/Python/Attributes.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
def Python_Test_Dialect : Dialect {
let name = "python_test";
let cppNamespace = "PythonTest";
let cppNamespace = "python_test";
}
class TestOp<string mnemonic, list<OpTrait> traits = []>
: Op<Python_Test_Dialect, mnemonic, traits>;
@ -30,4 +31,25 @@ def PropertyOp : TestOp<"property_op"> {
I32:$idx);
}
def DummyOp : TestOp<"dummy_op"> {
}
def InferResultsOp : TestOp<"infer_results_op", [InferTypeOpInterface]> {
let arguments = (ins);
let results = (outs AnyInteger:$single, AnyInteger:$doubled);
let extraClassDeclaration = [{
static ::mlir::LogicalResult inferReturnTypes(
::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
::mlir::RegionRange regions,
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
::mlir::Builder b(context);
inferredReturnTypes.push_back(b.getI32Type());
inferredReturnTypes.push_back(b.getI64Type());
return ::mlir::success();
}
}];
}
#endif // PYTHON_TEST_OPS

View File

@ -333,6 +333,7 @@ cc_library(
"include/mlir-c/ExecutionEngine.h",
"include/mlir-c/IR.h",
"include/mlir-c/IntegerSet.h",
"include/mlir-c/Interfaces.h",
"include/mlir-c/Pass.h",
"include/mlir-c/Registration.h",
"include/mlir-c/Support.h",
@ -360,6 +361,20 @@ cc_library(
],
)
cc_library(
name = "CAPIInterfaces",
srcs = [
"lib/CAPI/Interfaces/Interfaces.cpp",
],
includes = ["include"],
deps = [
":CAPIIR",
":IR",
":InferTypeOpInterface",
"//llvm:Support",
],
)
cc_library(
name = "CAPIAsync",
srcs = [
@ -558,6 +573,7 @@ cc_library(
"lib/Bindings/Python/IRAffine.cpp",
"lib/Bindings/Python/IRAttributes.cpp",
"lib/Bindings/Python/IRCore.cpp",
"lib/Bindings/Python/IRInterfaces.cpp",
"lib/Bindings/Python/IRModule.cpp",
"lib/Bindings/Python/IRTypes.cpp",
"lib/Bindings/Python/Pass.cpp",
@ -581,6 +597,7 @@ cc_library(
":CAPIDebug",
":CAPIGPU",
":CAPIIR",
":CAPIInterfaces",
":CAPILinalg",
":CAPIRegistration",
":CAPISparseTensor",