mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-13 10:42:05 +00:00
[mlir][linalg][transform][python] Refactor TileOp mix-in.
This patch simplifies and improves the mix-in of the `TileOp`. In particular: * Accept all types of sizes (static, dynamic, scalable) in a single argument `sizes`. * Use the existing convenience function to dispatch different types of sizes instead of repeating the implementation in the mix-in. * Pass on `None` values as is of optional arguments to the init function of the super class. * Reformat with default indentation width (4 spaces vs 2 spaces). * Add a a test for providing scalable sizes. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D159417
This commit is contained in:
parent
d5946fd3ed
commit
ea4a5127c4
@ -571,107 +571,77 @@ class SplitOp:
|
||||
|
||||
|
||||
class TileOp:
|
||||
"""Specialization for TileOp class."""
|
||||
"""Specialization for TileOp class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
loop_types: Union[Type, List[Type]],
|
||||
target: Union[Operation, Value],
|
||||
*,
|
||||
sizes: Optional[
|
||||
Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
|
||||
] = None,
|
||||
sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
|
||||
interchange: OptionalIntList = None,
|
||||
scalable_sizes: OptionalBoolList = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, Value, OpView],
|
||||
*,
|
||||
sizes: Optional[
|
||||
Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
|
||||
] = None,
|
||||
sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
|
||||
interchange: OptionalIntList = None,
|
||||
scalable_sizes: OptionalBoolList = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
...
|
||||
|
||||
def __init__(
|
||||
def __init__(
|
||||
self,
|
||||
loop_types_or_target: Union[Type, List[Type], Operation, Value],
|
||||
target_or_none: Optional[Union[Operation, Value, OpView]] = None,
|
||||
*,
|
||||
sizes: Optional[
|
||||
Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]
|
||||
] = None,
|
||||
sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
|
||||
interchange: OptionalIntList = None,
|
||||
scalable_sizes: OptionalBoolList = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if interchange is None:
|
||||
interchange = []
|
||||
if sizes is None:
|
||||
sizes = []
|
||||
(
|
||||
dynamic_sizes,
|
||||
static_sizes,
|
||||
scalable_sizes,
|
||||
) = _dispatch_dynamic_index_list(sizes)
|
||||
|
||||
static_sizes = []
|
||||
dynamic_sizes = []
|
||||
if isinstance(sizes, ArrayAttr):
|
||||
sizes_attr = sizes
|
||||
else:
|
||||
for size in sizes:
|
||||
if isinstance(size, int):
|
||||
static_sizes.append(size)
|
||||
num_loops = sum(v if v == 0 else 1 for v in static_sizes)
|
||||
|
||||
if isinstance(loop_types_or_target, (Operation, Value, OpView)):
|
||||
loop_types = [transform.AnyOpType.get()] * num_loops
|
||||
target = loop_types_or_target
|
||||
assert target_or_none is None, "Cannot construct TileOp with two targets."
|
||||
else:
|
||||
static_sizes.append(ShapedType.get_dynamic_size())
|
||||
dynamic_sizes.append(_get_op_result_or_value(size))
|
||||
sizes_attr = DenseI64ArrayAttr.get(static_sizes)
|
||||
loop_types = (
|
||||
([loop_types_or_target] * num_loops)
|
||||
if isinstance(loop_types_or_target, Type)
|
||||
else loop_types_or_target
|
||||
)
|
||||
target = target_or_none
|
||||
|
||||
num_loops = sum(
|
||||
v if v == 0 else 1 for v in self.__extract_values(sizes_attr)
|
||||
)
|
||||
if scalable_sizes is None:
|
||||
scalable_sizes = [False] * len(self.__extract_values(sizes_attr))
|
||||
target = _get_op_result_or_value(target)
|
||||
|
||||
if isinstance(loop_types_or_target, (Operation, Value, OpView)):
|
||||
loop_types = [transform.AnyOpType.get()] * num_loops
|
||||
target = loop_types_or_target
|
||||
assert target_or_none is None, "Cannot construct TileOp with two targets."
|
||||
else:
|
||||
loop_types = (
|
||||
([loop_types_or_target] * num_loops)
|
||||
if isinstance(loop_types_or_target, Type)
|
||||
else loop_types_or_target
|
||||
)
|
||||
target = target_or_none
|
||||
|
||||
target = _get_op_result_or_value(target)
|
||||
|
||||
super().__init__(
|
||||
super().__init__(
|
||||
target.type,
|
||||
loop_types,
|
||||
target,
|
||||
dynamic_sizes=dynamic_sizes,
|
||||
static_sizes=sizes_attr,
|
||||
static_sizes=static_sizes,
|
||||
interchange=interchange,
|
||||
scalable_sizes=scalable_sizes,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]:
|
||||
if not attr:
|
||||
return []
|
||||
return [element for element in attr]
|
||||
|
||||
|
||||
class TileToForallOp:
|
||||
"""Specialization for TileToForallOp class."""
|
||||
|
@ -486,6 +486,22 @@ def testTileExplicitLoopTypeAll():
|
||||
# CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.forall">
|
||||
|
||||
|
||||
@run
|
||||
def testTileScalable():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.TileOp(
|
||||
sequence.bodyTarget,
|
||||
sizes=[4, [2]],
|
||||
)
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: testTileScalable
|
||||
# CHECK: transform.sequence
|
||||
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, [2]]
|
||||
|
||||
|
||||
@run
|
||||
def testTileToForallCompact():
|
||||
sequence = transform.SequenceOp(
|
||||
|
Loading…
Reference in New Issue
Block a user