[mlir] Add Index Type, Floating Point Type and None Type subclasses to python bindings.

Based on the PyType and PyConcreteType classes, this patch implements the bindings of Index Type, Floating Point Type and None Type subclasses.
These three subclasses share the same binding strategy:
- The function pointer `isaFunction` points to `mlirTypeIsA***`.
- The `mlir***TypeGet` C API is bound with the `***Type` constructor in the python side.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D86466
This commit is contained in:
zhanghb97 2020-08-24 18:54:38 +00:00 committed by Stella Laurenzo
parent 0e6c9a6e79
commit 1f6c4d829c
2 changed files with 132 additions and 0 deletions

View File

@ -305,6 +305,102 @@ public:
}
};
/// Index Type subclass - IndexType.
class PyIndexType : public PyConcreteType<PyIndexType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
static constexpr const char *pyClassName = "IndexType";
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
c.def(py::init([](PyMlirContext &context) {
MlirType t = mlirIndexTypeGet(context.context);
return PyIndexType(t);
}),
py::keep_alive<0, 1>(), "Create a index type.");
}
};
/// Floating Point Type subclass - BF16Type.
class PyBF16Type : public PyConcreteType<PyBF16Type> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
static constexpr const char *pyClassName = "BF16Type";
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
c.def(py::init([](PyMlirContext &context) {
MlirType t = mlirBF16TypeGet(context.context);
return PyBF16Type(t);
}),
py::keep_alive<0, 1>(), "Create a bf16 type.");
}
};
/// Floating Point Type subclass - F16Type.
class PyF16Type : public PyConcreteType<PyF16Type> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
static constexpr const char *pyClassName = "F16Type";
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
c.def(py::init([](PyMlirContext &context) {
MlirType t = mlirF16TypeGet(context.context);
return PyF16Type(t);
}),
py::keep_alive<0, 1>(), "Create a f16 type.");
}
};
/// Floating Point Type subclass - F32Type.
class PyF32Type : public PyConcreteType<PyF32Type> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
static constexpr const char *pyClassName = "F32Type";
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
c.def(py::init([](PyMlirContext &context) {
MlirType t = mlirF32TypeGet(context.context);
return PyF32Type(t);
}),
py::keep_alive<0, 1>(), "Create a f32 type.");
}
};
/// Floating Point Type subclass - F64Type.
class PyF64Type : public PyConcreteType<PyF64Type> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
static constexpr const char *pyClassName = "F64Type";
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
c.def(py::init([](PyMlirContext &context) {
MlirType t = mlirF64TypeGet(context.context);
return PyF64Type(t);
}),
py::keep_alive<0, 1>(), "Create a f64 type.");
}
};
/// None Type subclass - NoneType.
class PyNoneType : public PyConcreteType<PyNoneType> {
public:
static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
static constexpr const char *pyClassName = "NoneType";
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
c.def(py::init([](PyMlirContext &context) {
MlirType t = mlirNoneTypeGet(context.context);
return PyNoneType(t);
}),
py::keep_alive<0, 1>(), "Create a none type.");
}
};
} // namespace
//------------------------------------------------------------------------------
@ -489,4 +585,10 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// Standard type bindings.
PyIntegerType::bind(m);
PyIndexType::bind(m);
PyBF16Type::bind(m);
PyF16Type::bind(m);
PyF32Type::bind(m);
PyF64Type::bind(m);
PyNoneType::bind(m);
}

View File

@ -124,3 +124,33 @@ def testIntegerType():
print("unsigned:", mlir.ir.IntegerType.get_unsigned(ctx, 64))
run(testIntegerType)
# CHECK-LABEL: TEST: testIndexType
def testIndexType():
ctx = mlir.ir.Context()
# CHECK: index type: index
print("index type:", mlir.ir.IndexType(ctx))
run(testIndexType)
# CHECK-LABEL: TEST: testFloatType
def testFloatType():
ctx = mlir.ir.Context()
# CHECK: float: bf16
print("float:", mlir.ir.BF16Type(ctx))
# CHECK: float: f16
print("float:", mlir.ir.F16Type(ctx))
# CHECK: float: f32
print("float:", mlir.ir.F32Type(ctx))
# CHECK: float: f64
print("float:", mlir.ir.F64Type(ctx))
run(testFloatType)
# CHECK-LABEL: TEST: testNoneType
def testNoneType():
ctx = mlir.ir.Context()
# CHECK: none type: none
print("none type:", mlir.ir.NoneType(ctx))
run(testNoneType)