[mlir][transform][lingalg][python] Replace pdl.operation => transform.any_op. (#66392)

For some reason, the mix-ins of the Python bindings of this dialect used
the PDL type for "any op". However, PDL isn't involved here, so it makes
more sense to use the corresponding type of the transform dialect. This
PR changes that.
This commit is contained in:
Ingo Müller 2023-09-15 09:06:07 +02:00 committed by GitHub
parent 6d73cca186
commit 5d3489e940
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,7 +4,7 @@
try:
from ..ir import *
from ..dialects import pdl, transform
from ..dialects import transform
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
@ -203,7 +203,8 @@ class DecomposeOp:
"""Specialization for DecomposeOp class."""
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip)
transformed_type = transform.AnyOpType.get()
super().__init__(transformed_type, target, loc=loc, ip=ip)
class FuseIntoContainingOp:
@ -274,7 +275,8 @@ class GeneralizeOp:
"""Specialization for GeneralizeOp class."""
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip)
transformed_type = transform.AnyOpType.get()
super().__init__(transformed_type, target, loc=loc, ip=ip)
class InterchangeOp:
@ -288,9 +290,9 @@ class InterchangeOp:
loc=None,
ip=None,
):
pdl_operation_type = pdl.OperationType.get()
transformed_type = transform.AnyOpType.get()
super().__init__(
pdl_operation_type,
transformed_type,
target,
iterator_interchange=iterator_interchange,
loc=loc,
@ -503,11 +505,11 @@ class PadOp:
):
transpose_paddings = _get_int_array_array_attr(transpose_paddings)
pdl_operation_type = pdl.OperationType.get()
any_op_type = transform.AnyOpType.get()
super().__init__(
pdl_operation_type,
pdl_operation_type,
pdl_operation_type,
any_op_type,
any_op_type,
any_op_type,
target,
padding_values=padding_values,
padding_dimensions=padding_dimensions,
@ -524,8 +526,8 @@ class ScalarizeOp:
"""Specialization for ScalarizeOp class."""
def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
pdl_operation_type = pdl.OperationType.get()
super().__init__(pdl_operation_type, target, loc=loc, ip=ip)
result_type = transform.AnyOpType.get()
super().__init__(result_type, target, loc=loc, ip=ip)
class SplitOp:
@ -736,9 +738,9 @@ class VectorizeOp:
loc=None,
ip=None,
):
pdl_operation_type = pdl.OperationType.get()
transformed_type = transform.AnyOpType.get()
super().__init__(
pdl_operation_type,
transformed_type,
target,
disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,