Add Python bindings for the builtin dialect

This includes some minor customization for FuncOp and ModuleOp.

Differential Revision: https://reviews.llvm.org/D95022
This commit is contained in:
Mehdi Amini 2021-01-20 05:53:44 +00:00
parent 1deee5cacb
commit 922b26cde4
7 changed files with 192 additions and 1 deletions

View File

@ -0,0 +1,15 @@
//===-- BuiltinOps.td - Entry point for builtin bindings ---*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef PYTHON_BINDINGS_BUILTIN_OPS
#define PYTHON_BINDINGS_BUILTIN_OPS
include "mlir/Bindings/Python/Attributes.td"
include "mlir/IR/BuiltinOps.td"
#endif

View File

@ -11,6 +11,7 @@ set(PY_SRC_FILES
mlir/ir.py
mlir/dialects/__init__.py
mlir/dialects/_linalg.py
mlir/dialects/_builtin.py
mlir/ir.py
mlir/passmanager.py
mlir/transforms/__init__.py
@ -36,6 +37,11 @@ endforeach()
# Generate dialect-specific bindings.
################################################################################
add_mlir_dialect_python_bindings(MLIRBindingsPythonBuiltinOps
TD_FILE BuiltinOps.td
DIALECT_NAME builtin)
add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonBuiltinOps)
add_mlir_dialect_python_bindings(MLIRBindingsPythonLinalgOps
TD_FILE LinalgOps.td
DIALECT_NAME linalg

View File

@ -43,7 +43,7 @@ def extend_opview_class(ext_module):
except AttributeError:
# Try to default resolve it.
try:
select_mixin = getattr(ext_module, parent_opview_cls.__name__)
mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
except AttributeError:
pass
else:

View File

@ -0,0 +1,93 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from mlir.ir import *
class ModuleOp:
"""Specialization for the module op class."""
def __init__(self, loc=None, ip=None):
super().__init__(
self._ods_build_default(operands=[], results=[], loc=loc, ip=ip))
body = self.regions[0].blocks.append()
with InsertionPoint(body):
Operation.create("module_terminator")
@property
def body(self):
return self.regions[0].blocks[0]
class FuncOp:
"""Specialization for the func op class."""
def __init__(self,
name,
type,
visibility,
body_builder=None,
loc=None,
ip=None):
"""
Create a FuncOp with the provided `name`, `type`, and `visibility`.
- `name` is a string representing the function name.
- `type` is either a FunctionType or a pair of list describing inputs and
results.
- `visibility` is a string matching `public`, `private`, or `nested`. The
empty string implies a private visibility.
- `body_builder` is an optional callback, when provided a new entry block
is created and the callback is invoked with the new op as argument within
an InsertionPoint context already set for the block. The callback is
expected to insert a terminator in the block.
"""
sym_name = StringAttr.get(str(name))
# If the type is passed as a tuple, build a FunctionType on the fly.
if isinstance(type, tuple):
type = FunctionType.get(inputs=type[0], results=type[1])
type = TypeAttr.get(type)
sym_visibility = StringAttr.get(
str(visibility)) if visibility is not None else None
super().__init__(sym_name, type, sym_visibility, loc, ip)
if body_builder:
entry_block = self.add_entry_block()
with InsertionPoint(entry_block):
body_builder(self)
@property
def is_external(self):
return len(self.regions[0].blocks) == 0
@property
def body(self):
return self.regions[0]
@property
def type(self):
return FunctionType(TypeAttr(self.attributes["type"]).value)
@property
def visibility(self):
return self.attributes["sym_visibility"]
@property
def name(self):
return self.attributes["sym_name"]
@property
def entry_block(self):
if self.is_external:
raise IndexError('External function does not have a body')
return self.regions[0].blocks[0]
def add_entry_block(self):
'''
Add an entry block to the function body using the function signature to infer block arguments
Returns the newly created block
'''
if not self.is_external:
raise IndexError('The function already has an entry block!')
self.body.blocks.append(*self.type.inputs)
return self.body.blocks[0]

View File

@ -0,0 +1,4 @@
[style]
based_on_style = google
column_limit = 80
indent_width = 2

View File

@ -0,0 +1,69 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
import mlir.dialects.builtin as builtin
import mlir.dialects.std as std
def run(f):
print("\nTEST:", f.__name__)
f()
# CHECK-LABEL: TEST: testBuildFuncOp
def testBuildFuncOp():
ctx = Context()
with Location.unknown(ctx) as loc:
m = builtin.ModuleOp()
f32 = F32Type.get()
tensor_type = RankedTensorType.get((2, 3, 4), f32)
with InsertionPoint.at_block_begin(m.body):
func = builtin.FuncOp(name="some_func",
type=FunctionType.get(
inputs=[tensor_type, tensor_type],
results=[tensor_type]),
visibility="nested")
# CHECK: Name is: "some_func"
print("Name is: ", func.name)
# CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
print("Type is: ", func.type)
# CHECK: Visibility is: "nested"
print("Visibility is: ", func.visibility)
try:
entry_block = func.entry_block
except IndexError as e:
# CHECK: External function does not have a body
print(e)
with InsertionPoint(func.add_entry_block()):
std.ReturnOp([func.entry_block.arguments[0]])
pass
try:
func.add_entry_block()
except IndexError as e:
# CHECK: The function already has an entry block!
print(e)
# Try the callback builder and passing type as tuple.
func = builtin.FuncOp(name="some_other_func",
type=([tensor_type, tensor_type], [tensor_type]),
visibility="nested",
body_builder=lambda func: std.ReturnOp(
[func.entry_block.arguments[0]]))
# CHECK: module {
# CHECK: func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
# CHECK: return %arg0 : tensor<2x3x4xf32>
# CHECK: }
# CHECK: func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
# CHECK: return %arg0 : tensor<2x3x4xf32>
# CHECK: }
print(m)
run(testBuildFuncOp)

View File

@ -716,6 +716,10 @@ static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
os << llvm::formatv(fileHeader, clDialectName.getValue());
os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
if (clDialectName == "builtin")
clDialectName = "";
for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
Operator op(rec);
if (op.getDialectName() == clDialectName.getValue())