mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-01 18:12:44 +00:00
[mlir][python] Function decorator for capturing a FuncOp from a python function.
* Moves this out of a test case where it was being developed to good effect and generalizes it. * Having tried a number of things like this, I think this balances concerns reasonably well. Differential Revision: https://reviews.llvm.org/D98989
This commit is contained in:
parent
b76c09023d
commit
d9343e6153
@ -1,6 +1,11 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import inspect
|
||||
|
||||
from ..ir import *
|
||||
|
||||
|
||||
@ -93,3 +98,99 @@ class FuncOp:
|
||||
raise IndexError('The function already has an entry block!')
|
||||
self.body.blocks.append(*self.type.inputs)
|
||||
return self.body.blocks[0]
|
||||
|
||||
@classmethod
|
||||
def from_py_func(FuncOp,
|
||||
*inputs: Type,
|
||||
results: Optional[Sequence[Type]] = None,
|
||||
name: Optional[str] = None):
|
||||
"""Decorator to define an MLIR FuncOp specified as a python function.
|
||||
|
||||
Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
|
||||
active for the current thread (i.e. established in a `with` block).
|
||||
|
||||
When applied as a decorator to a Python function, an entry block will
|
||||
be constructed for the FuncOp with types as specified in `*inputs`. The
|
||||
block arguments will be passed positionally to the Python function. In
|
||||
addition, if the Python function accepts keyword arguments generally or
|
||||
has a corresponding keyword argument, the following will be passed:
|
||||
* `func_op`: The `func` op being defined.
|
||||
|
||||
By default, the function name will be the Python function `__name__`. This
|
||||
can be overriden by passing the `name` argument to the decorator.
|
||||
|
||||
If `results` is not specified, then the decorator will implicitly
|
||||
insert a `ReturnOp` with the `Value`'s returned from the decorated
|
||||
function. It will also set the `FuncOp` type with the actual return
|
||||
value types. If `results` is specified, then the decorated function
|
||||
must return `None` and no implicit `ReturnOp` is added (nor are the result
|
||||
types updated). The implicit behavior is intended for simple, single-block
|
||||
cases, and users should specify result types explicitly for any complicated
|
||||
cases.
|
||||
|
||||
The decorated function can further be called from Python and will insert
|
||||
a `CallOp` at the then-current insertion point, returning either None (
|
||||
if no return values), a unary Value (for one result), or a list of Values).
|
||||
This mechanism cannot be used to emit recursive calls (by construction).
|
||||
"""
|
||||
|
||||
def decorator(f):
|
||||
from . import std
|
||||
# Introspect the callable for optional features.
|
||||
sig = inspect.signature(f)
|
||||
has_arg_func_op = False
|
||||
for param in sig.parameters.values():
|
||||
if param.kind == param.VAR_KEYWORD:
|
||||
has_arg_func_op = True
|
||||
if param.name == "func_op" and (param.kind
|
||||
== param.POSITIONAL_OR_KEYWORD or
|
||||
param.kind == param.KEYWORD_ONLY):
|
||||
has_arg_func_op = True
|
||||
|
||||
# Emit the FuncOp.
|
||||
implicit_return = results is None
|
||||
symbol_name = name or f.__name__
|
||||
function_type = FunctionType.get(
|
||||
inputs=inputs, results=[] if implicit_return else results)
|
||||
func_op = FuncOp(name=symbol_name, type=function_type)
|
||||
with InsertionPoint(func_op.add_entry_block()):
|
||||
func_args = func_op.entry_block.arguments
|
||||
func_kwargs = {}
|
||||
if has_arg_func_op:
|
||||
func_kwargs["func_op"] = func_op
|
||||
return_values = f(*func_args, **func_kwargs)
|
||||
if not implicit_return:
|
||||
return_types = list(results)
|
||||
assert return_values is None, (
|
||||
"Capturing a python function with explicit `results=` "
|
||||
"requires that the wrapped function returns None.")
|
||||
else:
|
||||
# Coerce return values, add ReturnOp and rewrite func type.
|
||||
if return_values is None:
|
||||
return_values = []
|
||||
elif isinstance(return_values, Value):
|
||||
return_values = [return_values]
|
||||
else:
|
||||
return_values = list(return_values)
|
||||
std.ReturnOp(return_values)
|
||||
# Recompute the function type.
|
||||
return_types = [v.type for v in return_values]
|
||||
function_type = FunctionType.get(inputs=inputs, results=return_types)
|
||||
func_op.attributes["type"] = TypeAttr.get(function_type)
|
||||
|
||||
def emit_call_op(*call_args):
|
||||
call_op = std.CallOp(return_types, FlatSymbolRefAttr.get(symbol_name),
|
||||
call_args)
|
||||
if return_types is None:
|
||||
return None
|
||||
elif len(return_types) == 1:
|
||||
return call_op.result
|
||||
else:
|
||||
return call_op.results
|
||||
|
||||
wrapped = emit_call_op
|
||||
wrapped.__name__ = f.__name__
|
||||
wrapped.func_op = func_op
|
||||
return wrapped
|
||||
|
||||
return decorator
|
||||
|
@ -8,9 +8,106 @@ import mlir.dialects.std as std
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
f()
|
||||
return f
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testFromPyFunc
|
||||
@run
|
||||
def testFromPyFunc():
|
||||
with Context() as ctx, Location.unknown() as loc:
|
||||
m = builtin.ModuleOp()
|
||||
f32 = F32Type.get()
|
||||
f64 = F64Type.get()
|
||||
with InsertionPoint.at_block_terminator(m.body):
|
||||
# CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
|
||||
# CHECK: return %arg0 : f64
|
||||
@builtin.FuncOp.from_py_func(f64)
|
||||
def unary_return(a):
|
||||
return a
|
||||
|
||||
# CHECK-LABEL: func @binary_return(%arg0: f32, %arg1: f64) -> (f32, f64)
|
||||
# CHECK: return %arg0, %arg1 : f32, f64
|
||||
@builtin.FuncOp.from_py_func(f32, f64)
|
||||
def binary_return(a, b):
|
||||
return a, b
|
||||
|
||||
# CHECK-LABEL: func @none_return(%arg0: f32, %arg1: f64)
|
||||
# CHECK: return
|
||||
@builtin.FuncOp.from_py_func(f32, f64)
|
||||
def none_return(a, b):
|
||||
pass
|
||||
|
||||
# CHECK-LABEL: func @call_unary
|
||||
# CHECK: %0 = call @unary_return(%arg0) : (f64) -> f64
|
||||
# CHECK: return %0 : f64
|
||||
@builtin.FuncOp.from_py_func(f64)
|
||||
def call_unary(a):
|
||||
return unary_return(a)
|
||||
|
||||
# CHECK-LABEL: func @call_binary
|
||||
# CHECK: %0:2 = call @binary_return(%arg0, %arg1) : (f32, f64) -> (f32, f64)
|
||||
# CHECK: return %0#0, %0#1 : f32, f64
|
||||
@builtin.FuncOp.from_py_func(f32, f64)
|
||||
def call_binary(a, b):
|
||||
return binary_return(a, b)
|
||||
|
||||
# CHECK-LABEL: func @call_none
|
||||
# CHECK: call @none_return(%arg0, %arg1) : (f32, f64) -> ()
|
||||
# CHECK: return
|
||||
@builtin.FuncOp.from_py_func(f32, f64)
|
||||
def call_none(a, b):
|
||||
return none_return(a, b)
|
||||
|
||||
## Variants and optional feature tests.
|
||||
# CHECK-LABEL: func @from_name_arg
|
||||
@builtin.FuncOp.from_py_func(f32, f64, name="from_name_arg")
|
||||
def explicit_name(a, b):
|
||||
return b
|
||||
|
||||
@builtin.FuncOp.from_py_func(f32, f64)
|
||||
def positional_func_op(a, b, func_op):
|
||||
assert isinstance(func_op, builtin.FuncOp)
|
||||
return b
|
||||
|
||||
@builtin.FuncOp.from_py_func(f32, f64)
|
||||
def kw_func_op(a, b=None, func_op=None):
|
||||
assert isinstance(func_op, builtin.FuncOp)
|
||||
return b
|
||||
|
||||
@builtin.FuncOp.from_py_func(f32, f64)
|
||||
def kwargs_func_op(a, b=None, **kwargs):
|
||||
assert isinstance(kwargs["func_op"], builtin.FuncOp)
|
||||
return b
|
||||
|
||||
# CHECK-LABEL: func @explicit_results(%arg0: f32, %arg1: f64) -> f64
|
||||
# CHECK: return %arg1 : f64
|
||||
@builtin.FuncOp.from_py_func(f32, f64, results=[f64])
|
||||
def explicit_results(a, b):
|
||||
std.ReturnOp([b])
|
||||
|
||||
print(m)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testFromPyFuncErrors
|
||||
@run
|
||||
def testFromPyFuncErrors():
|
||||
with Context() as ctx, Location.unknown() as loc:
|
||||
m = builtin.ModuleOp()
|
||||
f32 = F32Type.get()
|
||||
f64 = F64Type.get()
|
||||
with InsertionPoint.at_block_terminator(m.body):
|
||||
try:
|
||||
|
||||
@builtin.FuncOp.from_py_func(f64, results=[f64])
|
||||
def unary_return(a):
|
||||
return a
|
||||
except AssertionError as e:
|
||||
# CHECK: Capturing a python function with explicit `results=` requires that the wrapped function returns None.
|
||||
print(e)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testBuildFuncOp
|
||||
@run
|
||||
def testBuildFuncOp():
|
||||
ctx = Context()
|
||||
with Location.unknown(ctx) as loc:
|
||||
@ -64,6 +161,3 @@ def testBuildFuncOp():
|
||||
# CHECK: return %arg0 : tensor<2x3x4xf32>
|
||||
# CHECK: }
|
||||
print(m)
|
||||
|
||||
|
||||
run(testBuildFuncOp)
|
||||
|
@ -10,46 +10,6 @@ from mlir.dialects import std
|
||||
from mlir.dialects.linalg.opdsl.lang import *
|
||||
|
||||
|
||||
# TODO: Find a home for this quality of life helper.
|
||||
def build_function(*inputs: Type, results: Optional[Sequence[Type]] = None):
|
||||
"""Decorator that emits a function in a more pythonic way.
|
||||
|
||||
If result types are not specified, they are inferred from the function
|
||||
returns. The `ReturnOp` is implicitly added upon the wrapped function return.
|
||||
"""
|
||||
|
||||
def decorator(f):
|
||||
return_types = results
|
||||
symbol_name = f.__name__
|
||||
function_type = FunctionType.get(inputs=inputs, results=results or [])
|
||||
func_op = builtin.FuncOp(name=symbol_name, type=function_type)
|
||||
with InsertionPoint(func_op.add_entry_block()):
|
||||
func_args = func_op.entry_block.arguments
|
||||
return_values = f(*func_args)
|
||||
if return_values is None:
|
||||
return_values = []
|
||||
elif isinstance(return_values, Value):
|
||||
return_values = [return_values]
|
||||
else:
|
||||
return_values = list(return_values)
|
||||
std.ReturnOp(return_values)
|
||||
if return_types is None:
|
||||
# Recompute the function type.
|
||||
return_types = [v.type for v in return_values]
|
||||
function_type = FunctionType.get(inputs=inputs, results=return_types)
|
||||
# TODO: Have an API or a setter for this.
|
||||
func_op.attributes["type"] = TypeAttr.get(function_type)
|
||||
|
||||
# TODO: When turning this into a real facility, return a function that emits
|
||||
# a `call` to the function instead of doing nothing.
|
||||
wrapped = lambda: None
|
||||
wrapped.__name__ = symbol_name
|
||||
wrapped.func_op = func_op
|
||||
return wrapped
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@linalg_structured_op
|
||||
def matmul_mono(A=TensorDef(T, S.M, S.K),
|
||||
B=TensorDef(T, S.K, S.N),
|
||||
@ -92,8 +52,8 @@ with Context() as ctx, Location.unknown():
|
||||
# CHECK-SAME: ins(%[[A]], %[[B]]
|
||||
# CHECK-SAME: outs(%[[INITC]]
|
||||
|
||||
@build_function(RankedTensorType.get((4, 16), f32),
|
||||
RankedTensorType.get((16, 8), f32))
|
||||
@builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f32),
|
||||
RankedTensorType.get((16, 8), f32))
|
||||
def test_matmul_mono(lhs, rhs):
|
||||
# TODO: Enable outs inference and add sugar for InitTensorOp
|
||||
# construction.
|
||||
@ -114,9 +74,9 @@ with Context() as ctx, Location.unknown():
|
||||
# CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
|
||||
# CHECK-NEXT: linalg.yield %[[ADD]] : i32
|
||||
# CHECK-NEXT: -> tensor<4x8xi32>
|
||||
@build_function(RankedTensorType.get((4, 16), i8),
|
||||
RankedTensorType.get((16, 8), i8),
|
||||
RankedTensorType.get((4, 8), i32))
|
||||
@builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i8),
|
||||
RankedTensorType.get((16, 8), i8),
|
||||
RankedTensorType.get((4, 8), i32))
|
||||
def test_i8i8i32_matmul(lhs, rhs, init_result):
|
||||
return matmul_poly(lhs, rhs, outs=[init_result])
|
||||
|
||||
@ -128,9 +88,9 @@ with Context() as ctx, Location.unknown():
|
||||
# CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
|
||||
# CHECK-NEXT: linalg.yield %[[ADD]] : i32
|
||||
# CHECK-NEXT: -> tensor<4x8xi32>
|
||||
@build_function(RankedTensorType.get((4, 16), i8),
|
||||
RankedTensorType.get((16, 8), i16),
|
||||
RankedTensorType.get((4, 8), i32))
|
||||
@builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i8),
|
||||
RankedTensorType.get((16, 8), i16),
|
||||
RankedTensorType.get((4, 8), i32))
|
||||
def test_i8i16i32_matmul(lhs, rhs, init_result):
|
||||
return matmul_poly(lhs, rhs, outs=[init_result])
|
||||
|
||||
@ -142,9 +102,9 @@ with Context() as ctx, Location.unknown():
|
||||
# CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i16
|
||||
# CHECK-NEXT: linalg.yield %[[ADD]] : i16
|
||||
# CHECK-NEXT: -> tensor<4x8xi16>
|
||||
@build_function(RankedTensorType.get((4, 16), i32),
|
||||
RankedTensorType.get((16, 8), i32),
|
||||
RankedTensorType.get((4, 8), i16))
|
||||
@builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i32),
|
||||
RankedTensorType.get((16, 8), i32),
|
||||
RankedTensorType.get((4, 8), i16))
|
||||
def test_i32i32i16_matmul(lhs, rhs, init_result):
|
||||
return matmul_poly(lhs, rhs, outs=[init_result])
|
||||
|
||||
@ -156,9 +116,9 @@ with Context() as ctx, Location.unknown():
|
||||
# CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
|
||||
# CHECK-NEXT: linalg.yield %[[ADD]] : f32
|
||||
# CHECK-NEXT: -> tensor<4x8xf32>
|
||||
@build_function(RankedTensorType.get((4, 16), i8),
|
||||
RankedTensorType.get((16, 8), i8),
|
||||
RankedTensorType.get((4, 8), f32))
|
||||
@builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), i8),
|
||||
RankedTensorType.get((16, 8), i8),
|
||||
RankedTensorType.get((4, 8), f32))
|
||||
def test_i8i8f32_matmul(lhs, rhs, init_result):
|
||||
return matmul_poly(lhs, rhs, outs=[init_result])
|
||||
|
||||
@ -170,9 +130,9 @@ with Context() as ctx, Location.unknown():
|
||||
# CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
|
||||
# CHECK-NEXT: linalg.yield %[[ADD]] : f32
|
||||
# CHECK-NEXT: -> tensor<4x8xf32>
|
||||
@build_function(RankedTensorType.get((4, 16), f16),
|
||||
RankedTensorType.get((16, 8), f16),
|
||||
RankedTensorType.get((4, 8), f32))
|
||||
@builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f16),
|
||||
RankedTensorType.get((16, 8), f16),
|
||||
RankedTensorType.get((4, 8), f32))
|
||||
def test_f16f16f32_matmul(lhs, rhs, init_result):
|
||||
return matmul_poly(lhs, rhs, outs=[init_result])
|
||||
|
||||
@ -184,9 +144,9 @@ with Context() as ctx, Location.unknown():
|
||||
# CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
|
||||
# CHECK-NEXT: linalg.yield %[[ADD]] : f32
|
||||
# CHECK-NEXT: -> tensor<4x8xf32>
|
||||
@build_function(RankedTensorType.get((4, 16), f64),
|
||||
RankedTensorType.get((16, 8), f64),
|
||||
RankedTensorType.get((4, 8), f32))
|
||||
@builtin.FuncOp.from_py_func(RankedTensorType.get((4, 16), f64),
|
||||
RankedTensorType.get((16, 8), f64),
|
||||
RankedTensorType.get((4, 8), f32))
|
||||
def test_f64f64f32_matmul(lhs, rhs, init_result):
|
||||
return matmul_poly(lhs, rhs, outs=[init_result])
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user