From 1f6c4d829c2dad147e30dcb0611eb9886dae9155 Mon Sep 17 00:00:00 2001 From: zhanghb97 Date: Mon, 24 Aug 2020 18:54:38 +0000 Subject: [PATCH] [mlir] Add Index Type, Floating Point Type and None Type subclasses to python bindings. Based on the PyType and PyConcreteType classes, this patch implements the bindings of Index Type, Floating Point Type and None Type subclasses. These three subclasses share the same binding strategy: - The function pointer `isaFunction` points to `mlirTypeIsA***`. - The `mlir***TypeGet` C API is bound with the `***Type` constructor in the python side. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D86466 --- mlir/lib/Bindings/Python/IRModules.cpp | 102 +++++++++++++++++++++++++ mlir/test/Bindings/Python/ir_types.py | 30 ++++++++ 2 files changed, 132 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp index ae48e33d3530..2f5735f83975 100644 --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -305,6 +305,102 @@ public: } }; +/// Index Type subclass - IndexType. +class PyIndexType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; + static constexpr const char *pyClassName = "IndexType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def(py::init([](PyMlirContext &context) { + MlirType t = mlirIndexTypeGet(context.context); + return PyIndexType(t); + }), + py::keep_alive<0, 1>(), "Create a index type."); + } +}; + +/// Floating Point Type subclass - BF16Type. +class PyBF16Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; + static constexpr const char *pyClassName = "BF16Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def(py::init([](PyMlirContext &context) { + MlirType t = mlirBF16TypeGet(context.context); + return PyBF16Type(t); + }), + py::keep_alive<0, 1>(), "Create a bf16 type."); + } +}; + +/// Floating Point Type subclass - F16Type. +class PyF16Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; + static constexpr const char *pyClassName = "F16Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def(py::init([](PyMlirContext &context) { + MlirType t = mlirF16TypeGet(context.context); + return PyF16Type(t); + }), + py::keep_alive<0, 1>(), "Create a f16 type."); + } +}; + +/// Floating Point Type subclass - F32Type. +class PyF32Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; + static constexpr const char *pyClassName = "F32Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def(py::init([](PyMlirContext &context) { + MlirType t = mlirF32TypeGet(context.context); + return PyF32Type(t); + }), + py::keep_alive<0, 1>(), "Create a f32 type."); + } +}; + +/// Floating Point Type subclass - F64Type. +class PyF64Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; + static constexpr const char *pyClassName = "F64Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def(py::init([](PyMlirContext &context) { + MlirType t = mlirF64TypeGet(context.context); + return PyF64Type(t); + }), + py::keep_alive<0, 1>(), "Create a f64 type."); + } +}; + +/// None Type subclass - NoneType. +class PyNoneType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; + static constexpr const char *pyClassName = "NoneType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def(py::init([](PyMlirContext &context) { + MlirType t = mlirNoneTypeGet(context.context); + return PyNoneType(t); + }), + py::keep_alive<0, 1>(), "Create a none type."); + } +}; + } // namespace //------------------------------------------------------------------------------ @@ -489,4 +585,10 @@ void mlir::python::populateIRSubmodule(py::module &m) { // Standard type bindings. PyIntegerType::bind(m); + PyIndexType::bind(m); + PyBF16Type::bind(m); + PyF16Type::bind(m); + PyF32Type::bind(m); + PyF64Type::bind(m); + PyNoneType::bind(m); } diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py index 1dce0a95c812..32e26c57518a 100644 --- a/mlir/test/Bindings/Python/ir_types.py +++ b/mlir/test/Bindings/Python/ir_types.py @@ -124,3 +124,33 @@ def testIntegerType(): print("unsigned:", mlir.ir.IntegerType.get_unsigned(ctx, 64)) run(testIntegerType) + +# CHECK-LABEL: TEST: testIndexType +def testIndexType(): + ctx = mlir.ir.Context() + # CHECK: index type: index + print("index type:", mlir.ir.IndexType(ctx)) + +run(testIndexType) + +# CHECK-LABEL: TEST: testFloatType +def testFloatType(): + ctx = mlir.ir.Context() + # CHECK: float: bf16 + print("float:", mlir.ir.BF16Type(ctx)) + # CHECK: float: f16 + print("float:", mlir.ir.F16Type(ctx)) + # CHECK: float: f32 + print("float:", mlir.ir.F32Type(ctx)) + # CHECK: float: f64 + print("float:", mlir.ir.F64Type(ctx)) + +run(testFloatType) + +# CHECK-LABEL: TEST: testNoneType +def testNoneType(): + ctx = mlir.ir.Context() + # CHECK: none type: none + print("none type:", mlir.ir.NoneType(ctx)) + +run(testNoneType)