[mlir][linalg][transform][python] Add mix-in for MapCopyToThreadsOp.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D157706
This commit is contained in:
Ingo Müller 2023-08-11 12:11:24 +00:00
parent 030e315ee7
commit 691a2fab88
2 changed files with 98 additions and 0 deletions

View File

@ -187,6 +187,66 @@ class InterchangeOp:
)
class MapCopyToThreadsOp:
"""Specialization for MapCopyToThreadsOp class."""
@overload
def __init__(
self,
forall_op_type: Type,
tiled_op_type: Type,
target: Union[Operation, OpView, Value],
*,
total_num_threads: Union[int, IntegerAttr],
desired_bit_alignment: Union[int, IntegerAttr],
loc=None,
ip=None,
):
...
@overload
def __init__(
self,
target: Union[Operation, OpView, Value],
*,
total_num_threads: Union[int, IntegerAttr],
desired_bit_alignment: Union[int, IntegerAttr],
loc=None,
ip=None,
):
...
def __init__(
self,
forall_op_type_or_target: Union[Operation, OpView, Type, Value],
tiled_op_type_or_none: Optional[Type] = None,
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
*,
total_num_threads: Union[int, IntegerAttr],
desired_bit_alignment: Union[int, IntegerAttr],
loc=None,
ip=None,
):
if isinstance(forall_op_type_or_target, Type):
forall_op_type = forall_op_type_or_target
tiled_op_type = tiled_op_type_or_none
target = target_or_none
else:
forall_op_type = transform.AnyOpType.get()
tiled_op_type = transform.AnyOpType.get()
target = forall_op_type_or_target
super().__init__(
forall_op_type,
tiled_op_type,
target,
total_num_threads=total_num_threads,
desired_bit_alignment=desired_bit_alignment,
loc=loc,
ip=ip,
)
class MatchOp:
"""Specialization for MatchOp class."""

View File

@ -97,6 +97,44 @@ def testInterchange():
# CHECK: iterator_interchange = [1, 0]
@run
def testMapCopyToThreadsOpCompact():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
)
with InsertionPoint(sequence.body):
structured.MapCopyToThreadsOp(
sequence.bodyTarget, total_num_threads=32, desired_bit_alignment=128
)
transform.YieldOp()
# CHECK-LABEL: TEST: testMapCopyToThreadsOpCompact
# CHECK: = transform.structured.gpu.map_copy_to_threads
# CHECK-SAME: total_num_threads = 32
# CHECK-SAME: desired_bit_alignment = 128
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
@run
def testMapCopyToThreadsOpTypes():
sequence = transform.SequenceOp(
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
)
with InsertionPoint(sequence.body):
structured.MapCopyToThreadsOp(
transform.OperationType.get("test.opA"),
transform.OperationType.get("test.opB"),
sequence.bodyTarget,
total_num_threads=32,
desired_bit_alignment=128,
)
transform.YieldOp()
# CHECK-LABEL: TEST: testMapCopyToThreadsOpTypes
# CHECK: = transform.structured.gpu.map_copy_to_threads
# CHECK-SAME: total_num_threads = 32
# CHECK-SAME: desired_bit_alignment = 128
# CHECK-SAME: (!transform.any_op) -> (!transform.op<"test.opA">, !transform.op<"test.opB">)
@run
def testMatchOpNamesString():
sequence = transform.SequenceOp(