[mlir][linalg][transform][python] Refactor TileOp mix-in.

This patch simplifies and improves the mix-in of the `TileOp`. In
particular:

* Accept all types of sizes (static, dynamic, scalable) in a single
  argument `sizes`.
* Use the existing convenience function to dispatch different types of
  sizes instead of repeating the implementation in the mix-in.
* Pass on `None` values as is of optional arguments to the init function
  of the super class.
* Reformat with default indentation width (4 spaces vs 2 spaces).
* Add a a test for providing scalable sizes.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D159417
This commit is contained in:
Ingo Müller 2023-09-04 08:28:57 +00:00
parent d5946fd3ed
commit ea4a5127c4
2 changed files with 47 additions and 61 deletions

View File

@ -571,107 +571,77 @@ class SplitOp:
class TileOp:
"""Specialization for TileOp class."""
"""Specialization for TileOp class."""
@overload
def __init__(
@overload
def __init__(
self,
loop_types: Union[Type, List[Type]],
target: Union[Operation, Value],
*,
sizes: Optional[
Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
] = None,
sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
interchange: OptionalIntList = None,
scalable_sizes: OptionalBoolList = None,
loc=None,
ip=None,
):
...
...
@overload
def __init__(
@overload
def __init__(
self,
target: Union[Operation, Value, OpView],
*,
sizes: Optional[
Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
] = None,
sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
interchange: OptionalIntList = None,
scalable_sizes: OptionalBoolList = None,
loc=None,
ip=None,
):
...
...
def __init__(
def __init__(
self,
loop_types_or_target: Union[Type, List[Type], Operation, Value],
target_or_none: Optional[Union[Operation, Value, OpView]] = None,
*,
sizes: Optional[
Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
] = None,
sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
interchange: OptionalIntList = None,
scalable_sizes: OptionalBoolList = None,
loc=None,
ip=None,
):
if interchange is None:
interchange = []
if sizes is None:
sizes = []
(
dynamic_sizes,
static_sizes,
scalable_sizes,
) = _dispatch_dynamic_index_list(sizes)
static_sizes = []
dynamic_sizes = []
if isinstance(sizes, ArrayAttr):
sizes_attr = sizes
else:
for size in sizes:
if isinstance(size, int):
static_sizes.append(size)
num_loops = sum(v if v == 0 else 1 for v in static_sizes)
if isinstance(loop_types_or_target, (Operation, Value, OpView)):
loop_types = [transform.AnyOpType.get()] * num_loops
target = loop_types_or_target
assert target_or_none is None, "Cannot construct TileOp with two targets."
else:
static_sizes.append(ShapedType.get_dynamic_size())
dynamic_sizes.append(_get_op_result_or_value(size))
sizes_attr = DenseI64ArrayAttr.get(static_sizes)
loop_types = (
([loop_types_or_target] * num_loops)
if isinstance(loop_types_or_target, Type)
else loop_types_or_target
)
target = target_or_none
num_loops = sum(
v if v == 0 else 1 for v in self.__extract_values(sizes_attr)
)
if scalable_sizes is None:
scalable_sizes = [False] * len(self.__extract_values(sizes_attr))
target = _get_op_result_or_value(target)
if isinstance(loop_types_or_target, (Operation, Value, OpView)):
loop_types = [transform.AnyOpType.get()] * num_loops
target = loop_types_or_target
assert target_or_none is None, "Cannot construct TileOp with two targets."
else:
loop_types = (
([loop_types_or_target] * num_loops)
if isinstance(loop_types_or_target, Type)
else loop_types_or_target
)
target = target_or_none
target = _get_op_result_or_value(target)
super().__init__(
super().__init__(
target.type,
loop_types,
target,
dynamic_sizes=dynamic_sizes,
static_sizes=sizes_attr,
static_sizes=static_sizes,
interchange=interchange,
scalable_sizes=scalable_sizes,
loc=loc,
ip=ip,
)
def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]:
if not attr:
return []
return [element for element in attr]
class TileToForallOp:
"""Specialization for TileToForallOp class."""

View File

@ -486,6 +486,22 @@ def testTileExplicitLoopTypeAll():
# CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.forall">
@run
def testTileScalable():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
with InsertionPoint(sequence.body):
structured.TileOp(
sequence.bodyTarget,
sizes=[4, [2]],
)
transform.YieldOp()
# CHECK-LABEL: TEST: testTileScalable
# CHECK: transform.sequence
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, [2]]
@run
def testTileToForallCompact():
sequence = transform.SequenceOp(