[mlir][python] Add Python bindings for ml_program dialect.

Differential Revision: https://reviews.llvm.org/D125852
This commit is contained in:
Stella Laurenzo 2022-05-17 22:42:39 -07:00
parent 2bb252852c
commit 8b7e85f4f8
6 changed files with 216 additions and 0 deletions

View File

@ -132,6 +132,15 @@ declare_mlir_dialect_python_bindings(
dialects/_memref_ops_ext.py
DIALECT_NAME memref)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/MLProgramOps.td
SOURCES
dialects/ml_program.py
dialects/_ml_program_ops_ext.py
DIALECT_NAME ml_program)
declare_mlir_python_sources(
MLIRPythonSources.Dialects.quant
ADD_TO_PARENT MLIRPythonSources.Dialects

View File

@ -0,0 +1,15 @@
//===-- MLProgramOps.td - Entry point for MLProgramOps -----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef PYTHON_BINDINGS_MLPROGRAM_OPS
#define PYTHON_BINDINGS_MLPROGRAM_OPS
include "mlir/Bindings/Python/Attributes.td"
include "mlir/Dialect/MLProgram/IR/MLProgramOps.td"
#endif

View File

@ -0,0 +1,116 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from typing import Union
from ..ir import *
from ._ods_common import get_default_loc_context as _get_default_loc_context
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from ._ml_program_ops_gen import *
ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
RESULT_ATTRIBUTE_NAME = "res_attrs"
class FuncOp:
"""Specialization for the func op class."""
def __init__(self,
name,
type,
*,
visibility=None,
body_builder=None,
loc=None,
ip=None):
"""
Create a FuncOp with the provided `name`, `type`, and `visibility`.
- `name` is a string representing the function name.
- `type` is either a FunctionType or a pair of list describing inputs and
results.
- `visibility` is a string matching `public`, `private`, or `nested`. None
implies private visibility.
- `body_builder` is an optional callback, when provided a new entry block
is created and the callback is invoked with the new op as argument within
an InsertionPoint context already set for the block. The callback is
expected to insert a terminator in the block.
"""
sym_name = StringAttr.get(str(name))
# If the type is passed as a tuple, build a FunctionType on the fly.
if isinstance(type, tuple):
type = FunctionType.get(inputs=type[0], results=type[1])
type = TypeAttr.get(type)
sym_visibility = StringAttr.get(
str(visibility)) if visibility is not None else None
super().__init__(sym_name, type, sym_visibility, loc=loc, ip=ip)
if body_builder:
entry_block = self.add_entry_block()
with InsertionPoint(entry_block):
body_builder(self)
@property
def is_external(self):
return len(self.regions[0].blocks) == 0
@property
def body(self):
return self.regions[0]
@property
def type(self):
return FunctionType(TypeAttr(self.attributes["function_type"]).value)
@property
def visibility(self):
return self.attributes["sym_visibility"]
@property
def name(self) -> StringAttr:
return StringAttr(self.attributes["sym_name"])
@property
def entry_block(self):
if self.is_external:
raise IndexError('External function does not have a body')
return self.regions[0].blocks[0]
def add_entry_block(self):
"""
Add an entry block to the function body using the function signature to
infer block arguments.
Returns the newly created block
"""
if not self.is_external:
raise IndexError('The function already has an entry block!')
self.body.blocks.append(*self.type.inputs)
return self.body.blocks[0]
@property
def arg_attrs(self):
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
@arg_attrs.setter
def arg_attrs(self, attribute: Union[ArrayAttr, list]):
if isinstance(attribute, ArrayAttr):
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
else:
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
attribute, context=self.context)
@property
def arguments(self):
return self.entry_block.arguments
@property
def result_attrs(self):
return self.attributes[RESULT_ATTRIBUTE_NAME]
@result_attrs.setter
def result_attrs(self, attribute: ArrayAttr):
self.attributes[RESULT_ATTRIBUTE_NAME] = attribute

View File

@ -0,0 +1,5 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._ml_program_ops_gen import *

View File

@ -0,0 +1,28 @@
# RUN: %PYTHON %s | FileCheck %s
# This is just a smoke test that the dialect is functional.
from mlir.ir import *
from mlir.dialects import ml_program
def constructAndPrintInModule(f):
print("\nTEST:", f.__name__)
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
f()
print(module)
return f
# CHECK-LABEL: testFuncOp
@constructAndPrintInModule
def testFuncOp():
# CHECK: ml_program.func @foobar(%arg0: si32) -> si32
f = ml_program.FuncOp(
name="foobar",
type=([IntegerType.get_signed(32)], [IntegerType.get_signed(32)]))
block = f.add_entry_block()
with InsertionPoint(block):
# CHECK: ml_program.return
ml_program.ReturnOp([block.arguments[0]])

View File

@ -378,6 +378,49 @@ filegroup(
],
)
##---------------------------------------------------------------------------##
# MLProgram dialect.
##---------------------------------------------------------------------------##
td_library(
name = "MLProgramOpsPyTdFiles",
srcs = [
"//mlir:include/mlir/Bindings/Python/Attributes.td",
],
includes = ["../include"],
deps = [
"//mlir:MLProgramOpsTdFiles",
"//mlir:OpBaseTdFiles",
],
)
gentbl_filegroup(
name = "MLProgramOpsPyGen",
tbl_outs = [
(
[
"-gen-python-op-bindings",
"-bind-dialect=ml_program",
],
"mlir/dialects/_ml_program_ops_gen.py",
),
],
tblgen = "//mlir:mlir-tblgen",
td_file = "mlir/dialects/MLProgramOps.td",
deps = [
":MLProgramOpsPyTdFiles",
],
)
filegroup(
name = "MLProgramOpsPyFiles",
srcs = [
"mlir/dialects/_ml_program_ops_ext.py",
"mlir/dialects/ml_program.py",
":MLProgramOpsPyGen",
],
)
##---------------------------------------------------------------------------##
# PDL dialect.
##---------------------------------------------------------------------------##