[mlir][python] Function decorator for capturing a FuncOp from a python function.

* Moves this out of a test case, where it was being developed to good effect, and generalizes it.
* Having tried a number of things like this, I think this balances concerns reasonably well.
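
A minimal usage sketch (condensed from the new test below; `identity` and `caller` are hypothetical names used only for illustration):

    import mlir.dialects.builtin as builtin
    import mlir.dialects.std as std
    from mlir.ir import *

    with Context() as ctx, Location.unknown():
      m = builtin.ModuleOp()
      f64 = F64Type.get()
      with InsertionPoint.at_block_terminator(m.body):

        # The entry block and the `std.return` are created by the decorator;
        # the function type is rewritten to (f64) -> f64 from the returned
        # Value.
        @builtin.FuncOp.from_py_func(f64)
        def identity(a):
          return a

        # Calling the wrapper from Python emits `call @identity` at the
        # then-current insertion point.
        @builtin.FuncOp.from_py_func(f64)
        def caller(a):
          return identity(a)

      print(m)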

Differential Revision: https://reviews.llvm.org/D98989
Stella Laurenzo 2021-03-19 15:43:42 -07:00
parent b76c09023d
commit d9343e6153
3 changed files with 218 additions and 63 deletions

@@ -1,6 +1,11 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from typing import Optional, Sequence
import inspect
from ..ir import *
@@ -93,3 +98,99 @@ class FuncOp:
      raise IndexError('The function already has an entry block!')
    self.body.blocks.append(*self.type.inputs)
    return self.body.blocks[0]

  @classmethod
  def from_py_func(FuncOp,
                   *inputs: Type,
                   results: Optional[Sequence[Type]] = None,
                   name: Optional[str] = None):
"""Decorator to define an MLIR FuncOp specified as a python function.
Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
active for the current thread (i.e. established in a `with` block).
When applied as a decorator to a Python function, an entry block will
be constructed for the FuncOp with types as specified in `*inputs`. The
block arguments will be passed positionally to the Python function. In
addition, if the Python function accepts keyword arguments generally or
has a corresponding keyword argument, the following will be passed:
* `func_op`: The `func` op being defined.
By default, the function name will be the Python function `__name__`. This
can be overriden by passing the `name` argument to the decorator.
If `results` is not specified, then the decorator will implicitly
insert a `ReturnOp` with the `Value`'s returned from the decorated
function. It will also set the `FuncOp` type with the actual return
value types. If `results` is specified, then the decorated function
must return `None` and no implicit `ReturnOp` is added (nor are the result
types updated). The implicit behavior is intended for simple, single-block
cases, and users should specify result types explicitly for any complicated
cases.
The decorated function can further be called from Python and will insert
a `CallOp` at the then-current insertion point, returning either None (
if no return values), a unary Value (for one result), or a list of Values).
This mechanism cannot be used to emit recursive calls (by construction).
"""

    def decorator(f):
      from . import std
      # Introspect the callable for optional features.
      sig = inspect.signature(f)
      has_arg_func_op = False
      for param in sig.parameters.values():
        if param.kind == param.VAR_KEYWORD:
          has_arg_func_op = True
        if param.name == "func_op" and (param.kind
                                        == param.POSITIONAL_OR_KEYWORD or
                                        param.kind == param.KEYWORD_ONLY):
          has_arg_func_op = True

      # Emit the FuncOp.
      implicit_return = results is None
      symbol_name = name or f.__name__
      function_type = FunctionType.get(
          inputs=inputs, results=[] if implicit_return else results)
      func_op = FuncOp(name=symbol_name, type=function_type)
      with InsertionPoint(func_op.add_entry_block()):
        func_args = func_op.entry_block.arguments
        func_kwargs = {}
        if has_arg_func_op:
          func_kwargs["func_op"] = func_op
        return_values = f(*func_args, **func_kwargs)
        if not implicit_return:
          return_types = list(results)
          assert return_values is None, (
              "Capturing a python function with explicit `results=` "
              "requires that the wrapped function returns None.")
        else:
          # Coerce return values, add ReturnOp and rewrite func type.
          if return_values is None:
            return_values = []
          elif isinstance(return_values, Value):
            return_values = [return_values]
          else:
            return_values = list(return_values)
          std.ReturnOp(return_values)
          # Recompute the function type.
          return_types = [v.type for v in return_values]
          function_type = FunctionType.get(inputs=inputs, results=return_types)
          func_op.attributes["type"] = TypeAttr.get(function_type)

      def emit_call_op(*call_args):
        call_op = std.CallOp(return_types, FlatSymbolRefAttr.get(symbol_name),
                             call_args)
        # A zero-result call yields None, matching the docstring.
        if not return_types:
          return None
        elif len(return_types) == 1:
          return call_op.result
        else:
          return call_op.results

      wrapped = emit_call_op
      wrapped.__name__ = f.__name__
      wrapped.func_op = func_op
      return wrapped

    return decorator
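
As a quick reference, a sketch of the two optional features the docstring describes, assuming the same setup as the sketch above (`with_func_op` and `explicit` are hypothetical names; both features are exercised by the test file below):

    with InsertionPoint.at_block_terminator(m.body):

      # The reserved `func_op` keyword receives the FuncOp being defined.
      @builtin.FuncOp.from_py_func(f64)
      def with_func_op(a, func_op=None):
        assert isinstance(func_op, builtin.FuncOp)
        return a

      # With explicit `results=`, no implicit ReturnOp is inserted: the body
      # must emit its own terminator and return None to Python.
      @builtin.FuncOp.from_py_func(f64, results=[f64])
      def explicit(a):
        std.ReturnOp([a])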

@@ -8,9 +8,106 @@ import mlir.dialects.std as std
def run(f):
  print("\nTEST:", f.__name__)
  f()
  return f

# CHECK-LABEL: TEST: testFromPyFunc
@run
def testFromPyFunc():
  with Context() as ctx, Location.unknown() as loc:
    m = builtin.ModuleOp()
    f32 = F32Type.get()
    f64 = F64Type.get()
    with InsertionPoint.at_block_terminator(m.body):

      # CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
      # CHECK: return %arg0 : f64
      @builtin.FuncOp.from_py_func(f64)
      def unary_return(a):
        return a

      # CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
      # CHECK: return %arg0, %arg1 : f32, f64
      @builtin.FuncOp.from_py_func(f32, f64)
      def binary_return(a, b):
        return a, b

      # CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
      # CHECK: return
      @builtin.FuncOp.from_py_func(f32, f64)
      def none_return(a, b):
        pass

      # CHECK-LABEL: func @call_unary
      # CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
      # CHECK: return %0 : f64
      @builtin.FuncOp.from_py_func(f64)
      def call_unary(a):
        return unary_return(a)

      # CHECK-LABEL: func @call_binary
      # CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
      # CHECK: return %0#0, %0#1 : f32, f64
      @builtin.FuncOp.from_py_func(f32, f64)
      def call_binary(a, b):
        return binary_return(a, b)

      # CHECK-LABEL: func @call_none
      # CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
      # CHECK: return
      @builtin.FuncOp.from_py_func(f32, f64)
      def call_none(a, b):
        return none_return(a, b)

      ## Variants and optional feature tests.
      # CHECK-LABEL: func @from_name_arg
      @builtin.FuncOp.from_py_func(f32, f64, name="from_name_arg")
      def explicit_name(a, b):
        return b

      @builtin.FuncOp.from_py_func(f32, f64)
      def positional_func_op(a, b, func_op):
        assert isinstance(func_op, builtin.FuncOp)
        return b

      @builtin.FuncOp.from_py_func(f32, f64)
      def kw_func_op(a, b=None, func_op=None):
        assert isinstance(func_op, builtin.FuncOp)
        return b

      @builtin.FuncOp.from_py_func(f32, f64)
      def kwargs_func_op(a, b=None, **kwargs):
        assert isinstance(kwargs["func_op"], builtin.FuncOp)
        return b

      # CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
      # CHECK: return %arg1 : f64
      @builtin.FuncOp.from_py_func(f32, f64, results=[f64])
      def explicit_results(a, b):
        std.ReturnOp([b])

    print(m)

# CHECK-LABEL: TEST: testFromPyFuncErrors
@run
def testFromPyFuncErrors():
  with Context() as ctx, Location.unknown() as loc:
    m = builtin.ModuleOp()
    f32 = F32Type.get()
    f64 = F64Type.get()
    with InsertionPoint.at_block_terminator(m.body):
      try:

        @builtin.FuncOp.from_py_func(f64, results=[f64])
        def unary_return(a):
          return a
      except AssertionError as e:
        # CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
        print(e)

# CHECK-LABEL: TEST: testBuildFuncOp
@run
def testBuildFuncOp():
  ctx = Context()
  with Location.unknown(ctx) as loc:
@@ -64,6 +161,3 @@ def testBuildFuncOp():
  # CHECK: return %arg0 : tensor<2x3x4xf32>
  # CHECK: }
  print(m)
run(testBuildFuncOp)

@@ -10,46 +10,6 @@ from mlir.dialects import std
from mlir.dialects.linalg.opdsl.lang import *

# TODO: Find a home for this quality of life helper.
def build_function(*inputs: Type, results: Optional[Sequence[Type]] = None):
  """Decorator that emits a function in a more pythonic way.

  If result types are not specified, they are inferred from the function
  returns. The `ReturnOp` is implicitly added upon the wrapped function return.
  """

  def decorator(f):
    return_types = results
    symbol_name = f.__name__
    function_type = FunctionType.get(inputs=inputs, results=results or [])
    func_op = builtin.FuncOp(name=symbol_name, type=function_type)
    with InsertionPoint(func_op.add_entry_block()):
      func_args = func_op.entry_block.arguments
      return_values = f(*func_args)
      if return_values is None:
        return_values = []
      elif isinstance(return_values, Value):
        return_values = [return_values]
      else:
        return_values = list(return_values)
      std.ReturnOp(return_values)
      if return_types is None:
        # Recompute the function type.
        return_types = [v.type for v in return_values]
        function_type = FunctionType.get(inputs=inputs, results=return_types)
        # TODO: Have an API or a setter for this.
        func_op.attributes["type"] = TypeAttr.get(function_type)

    # TODO: When turning this into a real facility, return a function that
    # emits a `call` to the function instead of doing nothing.
    wrapped = lambda: None
    wrapped.__name__ = symbol_name
    wrapped.func_op = func_op
    return wrapped

  return decorator

@linalg_structured_op
def matmul_mono(A=TensorDef(T, S.M, S.K),
                B=TensorDef(T, S.K, S.N),
@@ -92,8 +52,8 @@ with Context() as ctx, Location.unknown():
    # CHECK-SAME: ins(%[[A]], %[[B]]
    # CHECK-SAME: outs(%[[INITC]]
    @build_function(RankedTensorType.get((4, 16), f32),
                    RankedTensorType.get((16, 8), f32))
    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
                                 RankedTensorType.get((16, 8), f32))
    def test_matmul_mono(lhs, rhs):
      # TODO: Enable outs inference and add sugar for InitTensorOp
      # construction.
@@ -114,9 +74,9 @@ with Context() as ctx, Location.unknown():
    # CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
    # CHECK-NEXT: linalg.yield %[[ADD]] : i32
    # CHECK-NEXT: -> tensor<4x8xi32>
    @build_function(RankedTensorType.get((4, 16), i8),
                    RankedTensorType.get((16, 8), i8),
                    RankedTensorType.get((4, 8), i32))
    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i8),
                                 RankedTensorType.get((16, 8), i8),
                                 RankedTensorType.get((4, 8), i32))
    def test_i8i8i32_matmul(lhs, rhs, init_result):
      return matmul_poly(lhs, rhs, outs=[init_result])
@@ -128,9 +88,9 @@ with Context() as ctx, Location.unknown():
    # CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
    # CHECK-NEXT: linalg.yield %[[ADD]] : i32
    # CHECK-NEXT: -> tensor<4x8xi32>
    @build_function(RankedTensorType.get((4, 16), i8),
                    RankedTensorType.get((16, 8), i16),
                    RankedTensorType.get((4, 8), i32))
    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i8),
                                 RankedTensorType.get((16, 8), i16),
                                 RankedTensorType.get((4, 8), i32))
    def test_i8i16i32_matmul(lhs, rhs, init_result):
      return matmul_poly(lhs, rhs, outs=[init_result])
@@ -142,9 +102,9 @@ with Context() as ctx, Location.unknown():
    # CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16
    # CHECK-NEXT: linalg.yield %[[ADD]] : i16
    # CHECK-NEXT: -> tensor<4x8xi16>
    @build_function(RankedTensorType.get((4, 16), i32),
                    RankedTensorType.get((16, 8), i32),
                    RankedTensorType.get((4, 8), i16))
    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32),
                                 RankedTensorType.get((16, 8), i32),
                                 RankedTensorType.get((4, 8), i16))
    def test_i32i32i16_matmul(lhs, rhs, init_result):
      return matmul_poly(lhs, rhs, outs=[init_result])
@@ -156,9 +116,9 @@ with Context() as ctx, Location.unknown():
    # CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
    # CHECK-NEXT: linalg.yield %[[ADD]] : f32
    # CHECK-NEXT: -> tensor<4x8xf32>
    @build_function(RankedTensorType.get((4, 16), i8),
                    RankedTensorType.get((16, 8), i8),
                    RankedTensorType.get((4, 8), f32))
    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i8),
                                 RankedTensorType.get((16, 8), i8),
                                 RankedTensorType.get((4, 8), f32))
    def test_i8i8f32_matmul(lhs, rhs, init_result):
      return matmul_poly(lhs, rhs, outs=[init_result])
@@ -170,9 +130,9 @@ with Context() as ctx, Location.unknown():
    # CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
    # CHECK-NEXT: linalg.yield %[[ADD]] : f32
    # CHECK-NEXT: -> tensor<4x8xf32>
    @build_function(RankedTensorType.get((4, 16), f16),
                    RankedTensorType.get((16, 8), f16),
                    RankedTensorType.get((4, 8), f32))
    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f16),
                                 RankedTensorType.get((16, 8), f16),
                                 RankedTensorType.get((4, 8), f32))
    def test_f16f16f32_matmul(lhs, rhs, init_result):
      return matmul_poly(lhs, rhs, outs=[init_result])
@@ -184,9 +144,9 @@ with Context() as ctx, Location.unknown():
    # CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
    # CHECK-NEXT: linalg.yield %[[ADD]] : f32
    # CHECK-NEXT: -> tensor<4x8xf32>
    @build_function(RankedTensorType.get((4, 16), f64),
                    RankedTensorType.get((16, 8), f64),
                    RankedTensorType.get((4, 8), f32))
    @builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f64),
                                 RankedTensorType.get((16, 8), f64),
                                 RankedTensorType.get((4, 8), f32))
    def test_f64f64f32_matmul(lhs, rhs, init_result):
      return matmul_poly(lhs, rhs, outs=[init_result])