From 8fd207fd0dcc398c2fcfd953d7e3ebe7cb53f188 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Thu, 20 Jul 2023 09:58:41 +0000 Subject: [PATCH] [mlir][transform][structured][python] Allow str arg in match_op_names. Allow the `names` argument in `MatchOp.match_op_names` to be of type `str` in addition to `Sequence[str]`. In this case, the argument is treated as a list with one name, i.e., it is possible to write `MatchOp.match_op_names(..., "test.dummy")` instead of `MatchOp.match_op_names(..., ["test.dummy"])`. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D155807 --- .../dialects/_structured_transform_ops_ext.py | 11 +++++++---- .../dialects/transform_structured_ext.py | 18 ++++++++++++++++-- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 1936f4b0e0da..9f623efb5001 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -195,7 +195,7 @@ class MatchOp: def match_op_names( cls, target: Union[Operation, Value], - names: Sequence[str], + names: Union[str, Sequence[str]], *, loc=None, ip=None, @@ -208,7 +208,7 @@ class MatchOp: cls, result_type: Type, target: Union[Operation, Value], - names: Sequence[str], + names: Union[str, Sequence[str]], *, loc=None, ip=None, @@ -219,8 +219,8 @@ class MatchOp: def match_op_names( cls, result_type_or_target: Union[Type, Operation, Value], - target_or_names: Union[Operation, Value, Sequence[str]], - names_or_none: Optional[Sequence[str]] = None, + target_or_names: Union[Operation, Value, Sequence[str], str], + names_or_none: Optional[Union[Sequence[str], str]] = None, *, loc=None, ip=None, @@ -234,6 +234,9 @@ class MatchOp: target = result_type_or_target names = target_or_names + if isinstance(names, str): + names = [names] + return cls( result_type, _get_op_result_or_value(target), diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py index 0bcfd81d75ff..1da55edf777e 100644 --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -97,14 +97,28 @@ def testInterchange(): @run -def testMatchOpNames(): +def testMatchOpNamesString(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() + ) + with InsertionPoint(sequence.body): + structured.MatchOp.match_op_names(sequence.bodyTarget, "test.dummy") + transform.YieldOp() + # CHECK-LABEL: TEST: testMatchOpNamesString + # CHECK: transform.structured.match ops + # CHECK-SAME: ["test.dummy"] + # CHECK-SAME: (!transform.any_op) -> !transform.any_op + + +@run +def testMatchOpNamesList(): sequence = transform.SequenceOp( transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get() ) with InsertionPoint(sequence.body): structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"]) transform.YieldOp() - # CHECK-LABEL: TEST: testMatchOpNames + # CHECK-LABEL: TEST: testMatchOpNamesList # CHECK: transform.structured.match ops # CHECK-SAME: ["test.dummy"] # CHECK-SAME: (!transform.any_op) -> !transform.any_op