mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-15 12:09:51 +00:00
[mlir][linalg][transform][python] Add mix-in for MapCopyToThreadsOp.
Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D157706
This commit is contained in:
parent
030e315ee7
commit
691a2fab88
@ -187,6 +187,66 @@ class InterchangeOp:
|
||||
)
|
||||
|
||||
|
||||
class MapCopyToThreadsOp:
|
||||
"""Specialization for MapCopyToThreadsOp class."""
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
forall_op_type: Type,
|
||||
tiled_op_type: Type,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
total_num_threads: Union[int, IntegerAttr],
|
||||
desired_bit_alignment: Union[int, IntegerAttr],
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
target: Union[Operation, OpView, Value],
|
||||
*,
|
||||
total_num_threads: Union[int, IntegerAttr],
|
||||
desired_bit_alignment: Union[int, IntegerAttr],
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
forall_op_type_or_target: Union[Operation, OpView, Type, Value],
|
||||
tiled_op_type_or_none: Optional[Type] = None,
|
||||
target_or_none: Optional[Union[Operation, OpView, Value]] = None,
|
||||
*,
|
||||
total_num_threads: Union[int, IntegerAttr],
|
||||
desired_bit_alignment: Union[int, IntegerAttr],
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if isinstance(forall_op_type_or_target, Type):
|
||||
forall_op_type = forall_op_type_or_target
|
||||
tiled_op_type = tiled_op_type_or_none
|
||||
target = target_or_none
|
||||
else:
|
||||
forall_op_type = transform.AnyOpType.get()
|
||||
tiled_op_type = transform.AnyOpType.get()
|
||||
target = forall_op_type_or_target
|
||||
|
||||
super().__init__(
|
||||
forall_op_type,
|
||||
tiled_op_type,
|
||||
target,
|
||||
total_num_threads=total_num_threads,
|
||||
desired_bit_alignment=desired_bit_alignment,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class MatchOp:
|
||||
"""Specialization for MatchOp class."""
|
||||
|
||||
|
@ -97,6 +97,44 @@ def testInterchange():
|
||||
# CHECK: iterator_interchange = [1, 0]
|
||||
|
||||
|
||||
@run
|
||||
def testMapCopyToThreadsOpCompact():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.MapCopyToThreadsOp(
|
||||
sequence.bodyTarget, total_num_threads=32, desired_bit_alignment=128
|
||||
)
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: testMapCopyToThreadsOpCompact
|
||||
# CHECK: = transform.structured.gpu.map_copy_to_threads
|
||||
# CHECK-SAME: total_num_threads = 32
|
||||
# CHECK-SAME: desired_bit_alignment = 128
|
||||
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
|
||||
|
||||
@run
|
||||
def testMapCopyToThreadsOpTypes():
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], transform.AnyOpType.get()
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
structured.MapCopyToThreadsOp(
|
||||
transform.OperationType.get("test.opA"),
|
||||
transform.OperationType.get("test.opB"),
|
||||
sequence.bodyTarget,
|
||||
total_num_threads=32,
|
||||
desired_bit_alignment=128,
|
||||
)
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: testMapCopyToThreadsOpTypes
|
||||
# CHECK: = transform.structured.gpu.map_copy_to_threads
|
||||
# CHECK-SAME: total_num_threads = 32
|
||||
# CHECK-SAME: desired_bit_alignment = 128
|
||||
# CHECK-SAME: (!transform.any_op) -> (!transform.op<"test.opA">, !transform.op<"test.opB">)
|
||||
|
||||
|
||||
@run
|
||||
def testMatchOpNamesString():
|
||||
sequence = transform.SequenceOp(
|
||||
|
Loading…
Reference in New Issue
Block a user