mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-02-13 13:45:16 +00:00
Add Python bindings for the builtin dialect
This includes some minor customization for FuncOp and ModuleOp. Differential Revision: https://reviews.llvm.org/D95022
This commit is contained in:
parent
1deee5cacb
commit
922b26cde4
15
mlir/lib/Bindings/Python/BuiltinOps.td
Normal file
15
mlir/lib/Bindings/Python/BuiltinOps.td
Normal file
@ -0,0 +1,15 @@
|
||||
//===-- BuiltinOps.td - Entry point for builtin bindings ---*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef PYTHON_BINDINGS_BUILTIN_OPS
|
||||
#define PYTHON_BINDINGS_BUILTIN_OPS
|
||||
|
||||
include "mlir/Bindings/Python/Attributes.td"
|
||||
include "mlir/IR/BuiltinOps.td"
|
||||
|
||||
#endif
|
@ -11,6 +11,7 @@ set(PY_SRC_FILES
|
||||
mlir/ir.py
|
||||
mlir/dialects/__init__.py
|
||||
mlir/dialects/_linalg.py
|
||||
mlir/dialects/_builtin.py
|
||||
mlir/ir.py
|
||||
mlir/passmanager.py
|
||||
mlir/transforms/__init__.py
|
||||
@ -36,6 +37,11 @@ endforeach()
|
||||
# Generate dialect-specific bindings.
|
||||
################################################################################
|
||||
|
||||
add_mlir_dialect_python_bindings(MLIRBindingsPythonBuiltinOps
|
||||
TD_FILE BuiltinOps.td
|
||||
DIALECT_NAME builtin)
|
||||
add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonBuiltinOps)
|
||||
|
||||
add_mlir_dialect_python_bindings(MLIRBindingsPythonLinalgOps
|
||||
TD_FILE LinalgOps.td
|
||||
DIALECT_NAME linalg
|
||||
|
@ -43,7 +43,7 @@ def extend_opview_class(ext_module):
|
||||
except AttributeError:
|
||||
# Try to default resolve it.
|
||||
try:
|
||||
select_mixin = getattr(ext_module, parent_opview_cls.__name__)
|
||||
mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
|
93
mlir/lib/Bindings/Python/mlir/dialects/_builtin.py
Normal file
93
mlir/lib/Bindings/Python/mlir/dialects/_builtin.py
Normal file
@ -0,0 +1,93 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
from mlir.ir import *
|
||||
|
||||
|
||||
class ModuleOp:
|
||||
"""Specialization for the module op class."""
|
||||
|
||||
def __init__(self, loc=None, ip=None):
|
||||
super().__init__(
|
||||
self._ods_build_default(operands=[], results=[], loc=loc, ip=ip))
|
||||
body = self.regions[0].blocks.append()
|
||||
with InsertionPoint(body):
|
||||
Operation.create("module_terminator")
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
|
||||
class FuncOp:
|
||||
"""Specialization for the func op class."""
|
||||
|
||||
def __init__(self,
|
||||
name,
|
||||
type,
|
||||
visibility,
|
||||
body_builder=None,
|
||||
loc=None,
|
||||
ip=None):
|
||||
"""
|
||||
Create a FuncOp with the provided `name`, `type`, and `visibility`.
|
||||
- `name` is a string representing the function name.
|
||||
- `type` is either a FunctionType or a pair of list describing inputs and
|
||||
results.
|
||||
- `visibility` is a string matching `public`, `private`, or `nested`. The
|
||||
empty string implies a private visibility.
|
||||
- `body_builder` is an optional callback, when provided a new entry block
|
||||
is created and the callback is invoked with the new op as argument within
|
||||
an InsertionPoint context already set for the block. The callback is
|
||||
expected to insert a terminator in the block.
|
||||
"""
|
||||
sym_name = StringAttr.get(str(name))
|
||||
|
||||
# If the type is passed as a tuple, build a FunctionType on the fly.
|
||||
if isinstance(type, tuple):
|
||||
type = FunctionType.get(inputs=type[0], results=type[1])
|
||||
|
||||
type = TypeAttr.get(type)
|
||||
sym_visibility = StringAttr.get(
|
||||
str(visibility)) if visibility is not None else None
|
||||
super().__init__(sym_name, type, sym_visibility, loc, ip)
|
||||
if body_builder:
|
||||
entry_block = self.add_entry_block()
|
||||
with InsertionPoint(entry_block):
|
||||
body_builder(self)
|
||||
|
||||
@property
|
||||
def is_external(self):
|
||||
return len(self.regions[0].blocks) == 0
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
return self.regions[0]
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return FunctionType(TypeAttr(self.attributes["type"]).value)
|
||||
|
||||
@property
|
||||
def visibility(self):
|
||||
return self.attributes["sym_visibility"]
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.attributes["sym_name"]
|
||||
|
||||
@property
|
||||
def entry_block(self):
|
||||
if self.is_external:
|
||||
raise IndexError('External function does not have a body')
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
def add_entry_block(self):
|
||||
'''
|
||||
Add an entry block to the function body using the function signature to infer block arguments
|
||||
Returns the newly created block
|
||||
'''
|
||||
if not self.is_external:
|
||||
raise IndexError('The function already has an entry block!')
|
||||
self.body.blocks.append(*self.type.inputs)
|
||||
return self.body.blocks[0]
|
4
mlir/test/Bindings/Python/.style.yapf
Normal file
4
mlir/test/Bindings/Python/.style.yapf
Normal file
@ -0,0 +1,4 @@
|
||||
[style]
|
||||
based_on_style = google
|
||||
column_limit = 80
|
||||
indent_width = 2
|
69
mlir/test/Bindings/Python/dialects/builtin.py
Normal file
69
mlir/test/Bindings/Python/dialects/builtin.py
Normal file
@ -0,0 +1,69 @@
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
from mlir.ir import *
|
||||
import mlir.dialects.builtin as builtin
|
||||
import mlir.dialects.std as std
|
||||
|
||||
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testBuildFuncOp
|
||||
def testBuildFuncOp():
|
||||
ctx = Context()
|
||||
with Location.unknown(ctx) as loc:
|
||||
m = builtin.ModuleOp()
|
||||
|
||||
f32 = F32Type.get()
|
||||
tensor_type = RankedTensorType.get((2, 3, 4), f32)
|
||||
with InsertionPoint.at_block_begin(m.body):
|
||||
func = builtin.FuncOp(name="some_func",
|
||||
type=FunctionType.get(
|
||||
inputs=[tensor_type, tensor_type],
|
||||
results=[tensor_type]),
|
||||
visibility="nested")
|
||||
# CHECK: Name is: "some_func"
|
||||
print("Name is: ", func.name)
|
||||
|
||||
# CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
|
||||
print("Type is: ", func.type)
|
||||
|
||||
# CHECK: Visibility is: "nested"
|
||||
print("Visibility is: ", func.visibility)
|
||||
|
||||
try:
|
||||
entry_block = func.entry_block
|
||||
except IndexError as e:
|
||||
# CHECK: External function does not have a body
|
||||
print(e)
|
||||
|
||||
with InsertionPoint(func.add_entry_block()):
|
||||
std.ReturnOp([func.entry_block.arguments[0]])
|
||||
pass
|
||||
|
||||
try:
|
||||
func.add_entry_block()
|
||||
except IndexError as e:
|
||||
# CHECK: The function already has an entry block!
|
||||
print(e)
|
||||
|
||||
# Try the callback builder and passing type as tuple.
|
||||
func = builtin.FuncOp(name="some_other_func",
|
||||
type=([tensor_type, tensor_type], [tensor_type]),
|
||||
visibility="nested",
|
||||
body_builder=lambda func: std.ReturnOp(
|
||||
[func.entry_block.arguments[0]]))
|
||||
|
||||
# CHECK: module {
|
||||
# CHECK: func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
|
||||
# CHECK: return %arg0 : tensor<2x3x4xf32>
|
||||
# CHECK: }
|
||||
# CHECK: func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
|
||||
# CHECK: return %arg0 : tensor<2x3x4xf32>
|
||||
# CHECK: }
|
||||
print(m)
|
||||
|
||||
|
||||
run(testBuildFuncOp)
|
@ -716,6 +716,10 @@ static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
|
||||
os << llvm::formatv(fileHeader, clDialectName.getValue());
|
||||
os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
|
||||
|
||||
if (clDialectName == "builtin")
|
||||
clDialectName = "";
|
||||
|
||||
for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
|
||||
Operator op(rec);
|
||||
if (op.getDialectName() == clDialectName.getValue())
|
||||
|
Loading…
x
Reference in New Issue
Block a user