[mlir][python] Add support for arg_attrs and other attrs to NamedSequenceOp

This commit is contained in:
Nicolas Vasilache 2023-11-08 13:17:25 +00:00
parent d5cfdcaacb
commit 5967375fcf
2 changed files with 9 additions and 2 deletions

View File

@ -172,11 +172,17 @@ class NamedSequenceOp(NamedSequenceOp):
sym_name,
input_types: Sequence[Type],
result_types: Sequence[Type],
sym_visibility=None,
arg_attrs=None,
res_attrs=None
):
function_type = FunctionType.get(input_types, result_types)
super().__init__(
sym_name=sym_name,
function_type=TypeAttr.get(function_type),
sym_visibility=sym_visibility,
arg_attrs=arg_attrs,
res_attrs=res_attrs
)
self.regions[0].blocks.append(*input_types)

View File

@ -153,6 +153,7 @@ def testTransformPDLOps(module: Module):
# CHECK: }
# CHECK: }
@run
def testNamedSequenceOp(module: Module):
module.operation.attributes["transform.with_named_sequence"] = UnitAttr.get()
@ -160,12 +161,12 @@ def testNamedSequenceOp(module: Module):
"__transform_main",
[transform.AnyOpType.get()],
[transform.AnyOpType.get()],
)
arg_attrs = [{"transform.consumed": UnitAttr.get()}])
with InsertionPoint(named_sequence.body):
transform.YieldOp([named_sequence.bodyTarget])
# CHECK-LABEL: TEST: testNamedSequenceOp
# CHECK: module attributes {transform.with_named_sequence} {
# CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op) -> !transform.any_op {
# CHECK: transform.named_sequence @__transform_main(%[[ARG0:.+]]: !transform.any_op {transform.consumed}) -> !transform.any_op {
# CHECK: yield %[[ARG0]] : !transform.any_op