mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-23 22:00:10 +00:00
[mlir][python] Add Python bindings for ml_program dialect.
Differential Revision: https://reviews.llvm.org/D125852
This commit is contained in:
parent
2bb252852c
commit
8b7e85f4f8
@ -132,6 +132,15 @@ declare_mlir_dialect_python_bindings(
|
||||
dialects/_memref_ops_ext.py
|
||||
DIALECT_NAME memref)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/MLProgramOps.td
|
||||
SOURCES
|
||||
dialects/ml_program.py
|
||||
dialects/_ml_program_ops_ext.py
|
||||
DIALECT_NAME ml_program)
|
||||
|
||||
declare_mlir_python_sources(
|
||||
MLIRPythonSources.Dialects.quant
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
|
15
mlir/python/mlir/dialects/MLProgramOps.td
Normal file
15
mlir/python/mlir/dialects/MLProgramOps.td
Normal file
@ -0,0 +1,15 @@
|
||||
//===-- MLProgramOps.td - Entry point for MLProgramOps -----*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef PYTHON_BINDINGS_MLPROGRAM_OPS
|
||||
#define PYTHON_BINDINGS_MLPROGRAM_OPS
|
||||
|
||||
include "mlir/Bindings/Python/Attributes.td"
|
||||
include "mlir/Dialect/MLProgram/IR/MLProgramOps.td"
|
||||
|
||||
#endif
|
116
mlir/python/mlir/dialects/_ml_program_ops_ext.py
Normal file
116
mlir/python/mlir/dialects/_ml_program_ops_ext.py
Normal file
@ -0,0 +1,116 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from typing import Union
|
||||
from ..ir import *
|
||||
from ._ods_common import get_default_loc_context as _get_default_loc_context
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from ._ml_program_ops_gen import *
|
||||
|
||||
|
||||
ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
|
||||
RESULT_ATTRIBUTE_NAME = "res_attrs"
|
||||
|
||||
|
||||
class FuncOp:
|
||||
"""Specialization for the func op class."""
|
||||
|
||||
def __init__(self,
|
||||
name,
|
||||
type,
|
||||
*,
|
||||
visibility=None,
|
||||
body_builder=None,
|
||||
loc=None,
|
||||
ip=None):
|
||||
"""
|
||||
Create a FuncOp with the provided `name`, `type`, and `visibility`.
|
||||
- `name` is a string representing the function name.
|
||||
- `type` is either a FunctionType or a pair of list describing inputs and
|
||||
results.
|
||||
- `visibility` is a string matching `public`, `private`, or `nested`. None
|
||||
implies private visibility.
|
||||
- `body_builder` is an optional callback, when provided a new entry block
|
||||
is created and the callback is invoked with the new op as argument within
|
||||
an InsertionPoint context already set for the block. The callback is
|
||||
expected to insert a terminator in the block.
|
||||
"""
|
||||
sym_name = StringAttr.get(str(name))
|
||||
|
||||
# If the type is passed as a tuple, build a FunctionType on the fly.
|
||||
if isinstance(type, tuple):
|
||||
type = FunctionType.get(inputs=type[0], results=type[1])
|
||||
|
||||
type = TypeAttr.get(type)
|
||||
sym_visibility = StringAttr.get(
|
||||
str(visibility)) if visibility is not None else None
|
||||
super().__init__(sym_name, type, sym_visibility, loc=loc, ip=ip)
|
||||
if body_builder:
|
||||
entry_block = self.add_entry_block()
|
||||
with InsertionPoint(entry_block):
|
||||
body_builder(self)
|
||||
|
||||
@property
|
||||
def is_external(self):
|
||||
return len(self.regions[0].blocks) == 0
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
return self.regions[0]
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return FunctionType(TypeAttr(self.attributes["function_type"]).value)
|
||||
|
||||
@property
|
||||
def visibility(self):
|
||||
return self.attributes["sym_visibility"]
|
||||
|
||||
@property
|
||||
def name(self) -> StringAttr:
|
||||
return StringAttr(self.attributes["sym_name"])
|
||||
|
||||
@property
|
||||
def entry_block(self):
|
||||
if self.is_external:
|
||||
raise IndexError('External function does not have a body')
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
def add_entry_block(self):
|
||||
"""
|
||||
Add an entry block to the function body using the function signature to
|
||||
infer block arguments.
|
||||
Returns the newly created block
|
||||
"""
|
||||
if not self.is_external:
|
||||
raise IndexError('The function already has an entry block!')
|
||||
self.body.blocks.append(*self.type.inputs)
|
||||
return self.body.blocks[0]
|
||||
|
||||
@property
|
||||
def arg_attrs(self):
|
||||
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
|
||||
|
||||
@arg_attrs.setter
|
||||
def arg_attrs(self, attribute: Union[ArrayAttr, list]):
|
||||
if isinstance(attribute, ArrayAttr):
|
||||
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
|
||||
else:
|
||||
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
|
||||
attribute, context=self.context)
|
||||
|
||||
@property
|
||||
def arguments(self):
|
||||
return self.entry_block.arguments
|
||||
|
||||
@property
|
||||
def result_attrs(self):
|
||||
return self.attributes[RESULT_ATTRIBUTE_NAME]
|
||||
|
||||
@result_attrs.setter
|
||||
def result_attrs(self, attribute: ArrayAttr):
|
||||
self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
|
5
mlir/python/mlir/dialects/ml_program.py
Normal file
5
mlir/python/mlir/dialects/ml_program.py
Normal file
@ -0,0 +1,5 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._ml_program_ops_gen import *
|
28
mlir/test/python/dialects/ml_program.py
Normal file
28
mlir/test/python/dialects/ml_program.py
Normal file
@ -0,0 +1,28 @@
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
# This is just a smoke test that the dialect is functional.
|
||||
|
||||
from mlir.ir import *
|
||||
from mlir.dialects import ml_program
|
||||
|
||||
|
||||
def constructAndPrintInModule(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
with Context(), Location.unknown():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
f()
|
||||
print(module)
|
||||
return f
|
||||
|
||||
|
||||
# CHECK-LABEL: testFuncOp
|
||||
@constructAndPrintInModule
|
||||
def testFuncOp():
|
||||
# CHECK: ml_program.func @foobar(%arg0: si32) -> si32
|
||||
f = ml_program.FuncOp(
|
||||
name="foobar",
|
||||
type=([IntegerType.get_signed(32)], [IntegerType.get_signed(32)]))
|
||||
block = f.add_entry_block()
|
||||
with InsertionPoint(block):
|
||||
# CHECK: ml_program.return
|
||||
ml_program.ReturnOp([block.arguments[0]])
|
@ -378,6 +378,49 @@ filegroup(
|
||||
],
|
||||
)
|
||||
|
||||
##---------------------------------------------------------------------------##
|
||||
# MLProgram dialect.
|
||||
##---------------------------------------------------------------------------##
|
||||
|
||||
td_library(
|
||||
name = "MLProgramOpsPyTdFiles",
|
||||
srcs = [
|
||||
"//mlir:include/mlir/Bindings/Python/Attributes.td",
|
||||
],
|
||||
includes = ["../include"],
|
||||
deps = [
|
||||
"//mlir:MLProgramOpsTdFiles",
|
||||
"//mlir:OpBaseTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
gentbl_filegroup(
|
||||
name = "MLProgramOpsPyGen",
|
||||
tbl_outs = [
|
||||
(
|
||||
[
|
||||
"-gen-python-op-bindings",
|
||||
"-bind-dialect=ml_program",
|
||||
],
|
||||
"mlir/dialects/_ml_program_ops_gen.py",
|
||||
),
|
||||
],
|
||||
tblgen = "//mlir:mlir-tblgen",
|
||||
td_file = "mlir/dialects/MLProgramOps.td",
|
||||
deps = [
|
||||
":MLProgramOpsPyTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "MLProgramOpsPyFiles",
|
||||
srcs = [
|
||||
"mlir/dialects/_ml_program_ops_ext.py",
|
||||
"mlir/dialects/ml_program.py",
|
||||
":MLProgramOpsPyGen",
|
||||
],
|
||||
)
|
||||
|
||||
##---------------------------------------------------------------------------##
|
||||
# PDL dialect.
|
||||
##---------------------------------------------------------------------------##
|
||||
|
Loading…
Reference in New Issue
Block a user