[mlir] Introduce Python bindings for the PDL dialect

This change adds full python bindings for PDL, including types and operations
with additional mixins to make operation construction more similar to the PDL
syntax.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D117458
This commit is contained in:
Denys Shabalin 2022-01-13 10:53:21 +01:00
parent a8890995ee
commit ed21c9276a
12 changed files with 975 additions and 2 deletions

View File

@ -49,6 +49,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLRangeType(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGet(MlirType elementType);
MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGetElementType(MlirType type);
//===---------------------------------------------------------------------===//
// TypeType
//===---------------------------------------------------------------------===//

View File

@ -0,0 +1,102 @@
//===- DialectPDL.cpp - 'pdl' dialect submodule ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/Dialect/PDL.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
namespace py = pybind11;
using namespace llvm;
using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::adaptors;
void populateDialectPDLSubmodule(const pybind11::module &m) {
//===-------------------------------------------------------------------===//
// PDLType
//===-------------------------------------------------------------------===//
auto pdlType = mlir_type_subclass(m, "PDLType", mlirTypeIsAPDLType);
//===-------------------------------------------------------------------===//
// AttributeType
//===-------------------------------------------------------------------===//
auto attributeType =
mlir_type_subclass(m, "AttributeType", mlirTypeIsAPDLAttributeType);
attributeType.def_classmethod(
"get",
[](py::object cls, MlirContext ctx) {
return cls(mlirPDLAttributeTypeGet(ctx));
},
"Get an instance of AttributeType in given context.", py::arg("cls"),
py::arg("context") = py::none());
//===-------------------------------------------------------------------===//
// OperationType
//===-------------------------------------------------------------------===//
auto operationType =
mlir_type_subclass(m, "OperationType", mlirTypeIsAPDLOperationType);
operationType.def_classmethod(
"get",
[](py::object cls, MlirContext ctx) {
return cls(mlirPDLOperationTypeGet(ctx));
},
"Get an instance of OperationType in given context.", py::arg("cls"),
py::arg("context") = py::none());
//===-------------------------------------------------------------------===//
// RangeType
//===-------------------------------------------------------------------===//
auto rangeType = mlir_type_subclass(m, "RangeType", mlirTypeIsAPDLRangeType);
rangeType.def_classmethod(
"get",
[](py::object cls, MlirType elementType) {
return cls(mlirPDLRangeTypeGet(elementType));
},
"Gets an instance of RangeType in the same context as the provided "
"element type.",
py::arg("cls"), py::arg("element_type"));
rangeType.def_property_readonly(
"element_type",
[](MlirType type) { return mlirPDLRangeTypeGetElementType(type); },
"Get the element type.");
//===-------------------------------------------------------------------===//
// TypeType
//===-------------------------------------------------------------------===//
auto typeType = mlir_type_subclass(m, "TypeType", mlirTypeIsAPDLTypeType);
typeType.def_classmethod(
"get",
[](py::object cls, MlirContext ctx) {
return cls(mlirPDLTypeTypeGet(ctx));
},
"Get an instance of TypeType in given context.", py::arg("cls"),
py::arg("context") = py::none());
//===-------------------------------------------------------------------===//
// ValueType
//===-------------------------------------------------------------------===//
auto valueType = mlir_type_subclass(m, "ValueType", mlirTypeIsAPDLValueType);
valueType.def_classmethod(
"get",
[](py::object cls, MlirContext ctx) {
return cls(mlirPDLValueTypeGet(ctx));
},
"Get an instance of TypeType in given context.", py::arg("cls"),
py::arg("context") = py::none());
}
PYBIND11_MODULE(_mlirDialectsPDL, m) {
m.doc() = "MLIR PDL dialect.";
populateDialectPDLSubmodule(m);
}

View File

@ -60,6 +60,10 @@ MlirType mlirPDLRangeTypeGet(MlirType elementType) {
return wrap(pdl::RangeType::get(unwrap(elementType)));
}
MlirType mlirPDLRangeTypeGetElementType(MlirType type) {
return wrap(unwrap(type).cast<pdl::RangeType>().getElementType());
}
//===---------------------------------------------------------------------===//
// TypeType
//===---------------------------------------------------------------------===//

View File

@ -123,6 +123,15 @@ declare_mlir_python_sources(
dialects/quant.py
_mlir_libs/_mlir/dialects/quant.pyi)
declare_mlir_python_sources(
MLIRPythonSources.Dialects.pdl
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
SOURCES
dialects/pdl.py
dialects/_pdl_ops_ext.py
_mlir_libs/_mlir/dialects/pdl.pyi)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
@ -243,6 +252,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
MLIRCAPIQuant
)
declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind
MODULE_NAME _mlirDialectsPDL
ADD_TO_PARENT MLIRPythonSources.Dialects.pdl
ROOT_DIR "${PYTHON_SOURCE_DIR}"
SOURCES
DialectPDL.cpp
PRIVATE_LINK_LIBS
LLVMSupport
EMBED_CAPI_LINK_LIBS
MLIRCAPIIR
MLIRCAPIPDL
)
declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind
MODULE_NAME _mlirDialectsSparseTensor
ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor

View File

@ -0,0 +1,64 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from typing import Optional
from mlir.ir import Type, Context
__all__ = [
'PDLType',
'AttributeType',
'OperationType',
'RangeType',
'TypeType',
'ValueType',
]
class PDLType(Type):
@staticmethod
def isinstance(type: Type) -> bool: ...
class AttributeType(Type):
@staticmethod
def isinstance(type: Type) -> bool: ...
@staticmethod
def get(context: Optional[Context] = None) -> AttributeType: ...
class OperationType(Type):
@staticmethod
def isinstance(type: Type) -> bool: ...
@staticmethod
def get(context: Optional[Context] = None) -> OperationType: ...
class RangeType(Type):
@staticmethod
def isinstance(type: Type) -> bool: ...
@staticmethod
def get(element_type: Type) -> RangeType: ...
@property
def element_type(self) -> Type: ...
class TypeType(Type):
@staticmethod
def isinstance(type: Type) -> bool: ...
@staticmethod
def get(context: Optional[Context] = None) -> TypeType: ...
class ValueType(Type):
@staticmethod
def isinstance(type: Type) -> bool: ...
@staticmethod
def get(context: Optional[Context] = None) -> ValueType: ...

View File

@ -0,0 +1,15 @@
//===-- PDLOps.td - Entry point for PDLOps bind ------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef PYTHON_BINDINGS_PDL_OPS
#define PYTHON_BINDINGS_PDL_OPS
include "mlir/Bindings/Python/Attributes.td"
include "mlir/Dialect/PDL/IR/PDLOps.td"
#endif

View File

@ -144,7 +144,8 @@ def get_op_result_or_value(
def get_op_results_or_values(
arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _Sequence[_cext.ir.Value]]
arg: _Union[_cext.ir.OpView, _cext.ir.Operation,
_Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]]]
) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]:
"""Returns the given sequence of values or the results of the given op.
@ -157,4 +158,4 @@ def get_op_results_or_values(
elif isinstance(arg, _cext.ir.Operation):
return arg.results
else:
return arg
return [get_op_result_or_value(element) for element in arg]

View File

@ -0,0 +1,284 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ..dialects import pdl
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Union, Optional, Sequence, List, Mapping
from ._ods_common import get_op_result_or_value as _get_value, get_op_results_or_values as _get_values
def _get_int_attr(bits: int, value: Union[IntegerAttr, int]) -> IntegerAttr:
"""Converts the given value to signless integer attribute of given bit width."""
if isinstance(value, int):
ty = IntegerType.get_signless(bits)
return IntegerAttr.get(ty, value)
else:
return value
def _get_array_attr(attrs: Union[ArrayAttr, Sequence[Attribute]]) -> ArrayAttr:
"""Converts the given value to array attribute."""
if isinstance(attrs, ArrayAttr):
return attrs
else:
return ArrayAttr.get(list(attrs))
def _get_str_array_attr(attrs: Union[ArrayAttr, Sequence[str]]) -> ArrayAttr:
"""Converts the given value to string array attribute."""
if isinstance(attrs, ArrayAttr):
return attrs
else:
return ArrayAttr.get([StringAttr.get(s) for s in attrs])
def _get_str_attr(name: Union[StringAttr, str]) -> Optional[StringAttr]:
"""Converts the given value to string attribute."""
if isinstance(name, str):
return StringAttr.get(name)
else:
return name
def _get_type_attr(type: Union[TypeAttr, Type]) -> TypeAttr:
"""Converts the given value to type attribute."""
if isinstance(type, Type):
return TypeAttr.get(type)
else:
return type
class ApplyNativeConstraintOp:
"""Specialization for PDL apply native constraint op class."""
def __init__(self,
name: Union[str, StringAttr],
args: Sequence[Union[OpView, Operation, Value]] = [],
params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
*,
loc=None,
ip=None):
name = _get_str_attr(name)
args = _get_values(args)
params = params if params is None else _get_array_attr(params)
super().__init__(name, args, params, loc=loc, ip=ip)
class ApplyNativeRewriteOp:
"""Specialization for PDL apply native rewrite op class."""
def __init__(self,
results: Sequence[Type],
name: Union[str, StringAttr],
args: Sequence[Union[OpView, Operation, Value]] = [],
params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
*,
loc=None,
ip=None):
name = _get_str_attr(name)
args = _get_values(args)
params = params if params is None else _get_array_attr(params)
super().__init__(results, name, args, params, loc=loc, ip=ip)
class AttributeOp:
"""Specialization for PDL attribute op class."""
def __init__(self,
type: Optional[Union[OpView, Operation, Value]] = None,
value: Optional[Attribute] = None,
*,
loc=None,
ip=None):
type = type if type is None else _get_value(type)
result = pdl.AttributeType.get()
super().__init__(result, type, value, loc=loc, ip=ip)
class EraseOp:
"""Specialization for PDL erase op class."""
def __init__(self,
operation: Optional[Union[OpView, Operation, Value]] = None,
*,
loc=None,
ip=None):
operation = _get_value(operation)
super().__init__(operation, loc=loc, ip=ip)
class OperandOp:
"""Specialization for PDL operand op class."""
def __init__(self,
type: Optional[Union[OpView, Operation, Value]] = None,
*,
loc=None,
ip=None):
type = type if type is None else _get_value(type)
result = pdl.ValueType.get()
super().__init__(result, type, loc=loc, ip=ip)
class OperandsOp:
"""Specialization for PDL operands op class."""
def __init__(self,
types: Optional[Union[OpView, Operation, Value]] = None,
*,
loc=None,
ip=None):
types = types if types is None else _get_value(types)
result = pdl.RangeType.get(pdl.ValueType.get())
super().__init__(result, types, loc=loc, ip=ip)
class OperationOp:
"""Specialization for PDL operand op class."""
def __init__(self,
name: Optional[Union[str, StringAttr]] = None,
args: Sequence[Union[OpView, Operation, Value]] = [],
attributes: Mapping[str, Union[OpView, Operation, Value]] = {},
types: Sequence[Union[OpView, Operation, Value]] = [],
*,
loc=None,
ip=None):
name = name if name is None else _get_str_attr(name)
args = _get_values(args)
attributeNames = []
attributeValues = []
for attrName, attrValue in attributes.items():
attributeNames.append(StringAttr.get(attrName))
attributeValues.append(_get_value(attrValue))
attributeNames = ArrayAttr.get(attributeNames)
types = _get_values(types)
result = pdl.OperationType.get()
super().__init__(result, name, args, attributeValues, attributeNames, types, loc=loc, ip=ip)
class PatternOp:
"""Specialization for PDL pattern op class."""
def __init__(self,
benefit: Union[IntegerAttr, int],
name: Optional[Union[StringAttr, str]] = None,
*,
loc=None,
ip=None):
"""Creates an PDL `pattern` operation."""
name_attr = None if name is None else _get_str_attr(name)
benefit_attr = _get_int_attr(16, benefit)
super().__init__(benefit_attr, name_attr, loc=loc, ip=ip)
self.regions[0].blocks.append()
@property
def body(self):
"""Return the body (block) of the pattern."""
return self.regions[0].blocks[0]
class ReplaceOp:
"""Specialization for PDL replace op class."""
def __init__(self,
op: Union[OpView, Operation, Value],
*,
with_op: Optional[Union[OpView, Operation, Value]] = None,
with_values: Sequence[Union[OpView, Operation, Value]] = [],
loc=None,
ip=None):
op = _get_value(op)
with_op = with_op if with_op is None else _get_value(with_op)
with_values = _get_values(with_values)
super().__init__(op, with_op, with_values, loc=loc, ip=ip)
class ResultOp:
"""Specialization for PDL result op class."""
def __init__(self,
parent: Union[OpView, Operation, Value],
index: Union[IntegerAttr, int],
*,
loc=None,
ip=None):
index = _get_int_attr(32, index)
parent = _get_value(parent)
result = pdl.ValueType.get()
super().__init__(result, parent, index, loc=loc, ip=ip)
class ResultsOp:
"""Specialization for PDL results op class."""
def __init__(self,
result: Type,
parent: Union[OpView, Operation, Value],
index: Optional[Union[IntegerAttr, int]] = None,
*,
loc=None,
ip=None):
parent = _get_value(parent)
index = index if index is None else _get_int_attr(32, index)
super().__init__(result, parent, index, loc=loc, ip=ip)
class RewriteOp:
"""Specialization for PDL rewrite op class."""
def __init__(self,
root: Optional[Union[OpView, Operation, Value]] = None,
name: Optional[Union[StringAttr, str]] = None,
args: Sequence[Union[OpView, Operation, Value]] = [],
params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
*,
loc=None,
ip=None):
root = root if root is None else _get_value(root)
name = name if name is None else _get_str_attr(name)
args = _get_values(args)
params = params if params is None else _get_array_attr(params)
super().__init__(root, name, args, params, loc=loc, ip=ip)
def add_body(self):
"""Add body (block) to the rewrite."""
self.regions[0].blocks.append()
return self.body
@property
def body(self):
"""Return the body (block) of the rewrite."""
return self.regions[0].blocks[0]
class TypeOp:
"""Specialization for PDL type op class."""
def __init__(self,
type: Optional[Union[TypeAttr, Type]] = None,
*,
loc=None,
ip=None):
type = type if type is None else _get_type_attr(type)
result = pdl.TypeType.get()
super().__init__(result, type, loc=loc, ip=ip)
class TypesOp:
"""Specialization for PDL types op class."""
def __init__(self,
types: Sequence[Union[TypeAttr, Type]] = [],
*,
loc=None,
ip=None):
types = _get_array_attr([_get_type_attr(ty) for ty in types])
types = None if not types else types
result = pdl.RangeType.get(pdl.TypeType.get())
super().__init__(result, types, loc=loc, ip=ip)

View File

@ -0,0 +1,6 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._pdl_ops_gen import *
from .._mlir_libs._mlirDialectsPDL import *

View File

@ -146,6 +146,7 @@ void testRangeType(MlirContext ctx) {
MlirType parsedType = mlirTypeParseGet(
ctx, mlirStringRefCreateFromCString("!pdl.range<type>"));
MlirType constructedType = mlirPDLRangeTypeGet(typeType);
MlirType elementType = mlirPDLRangeTypeGetElementType(constructedType);
assert(!mlirTypeIsNull(typeType) && "couldn't get PDLTypeType");
assert(!mlirTypeIsNull(parsedType) && "couldn't parse PDLAttributeType");
@ -191,11 +192,15 @@ void testRangeType(MlirContext ctx) {
// CHECK: equal: 1
fprintf(stderr, "equal: %d\n", mlirTypeEqual(parsedType, constructedType));
// CHECK: equal: 1
fprintf(stderr, "equal: %d\n", mlirTypeEqual(typeType, elementType));
// CHECK: !pdl.range<type>
mlirTypeDump(parsedType);
// CHECK: !pdl.range<type>
mlirTypeDump(constructedType);
// CHECK: !pdl.type
mlirTypeDump(elementType);
fprintf(stderr, "\n\n");
}

View File

@ -0,0 +1,318 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
from mlir.dialects.pdl import *
def constructAndPrintInModule(f):
print("\nTEST:", f.__name__)
with Context(), Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
f()
print(module)
return f
# CHECK: module {
# CHECK: pdl.pattern @operations : benefit(1) {
# CHECK: %0 = pdl.attribute
# CHECK: %1 = pdl.type
# CHECK: %2 = pdl.operation {"attr" = %0} -> (%1 : !pdl.type)
# CHECK: %3 = pdl.result 0 of %2
# CHECK: %4 = pdl.operand
# CHECK: %5 = pdl.operation(%3, %4 : !pdl.value, !pdl.value)
# CHECK: pdl.rewrite %5 with "rewriter"
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_operations():
pattern = PatternOp(1, "operations")
with InsertionPoint(pattern.body):
attr = AttributeOp()
ty = TypeOp()
op0 = OperationOp(attributes={"attr": attr}, types=[ty])
op0_result = ResultOp(op0, 0)
input = OperandOp()
root = OperationOp(args=[op0_result, input])
RewriteOp(root, "rewriter")
# CHECK: module {
# CHECK: pdl.pattern @rewrite_with_args : benefit(1) {
# CHECK: %0 = pdl.operand
# CHECK: %1 = pdl.operation(%0 : !pdl.value)
# CHECK: pdl.rewrite %1 with "rewriter"(%0 : !pdl.value)
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_with_args():
pattern = PatternOp(1, "rewrite_with_args")
with InsertionPoint(pattern.body):
input = OperandOp()
root = OperationOp(args=[input])
RewriteOp(root, "rewriter", args=[input])
# CHECK: module {
# CHECK: pdl.pattern @rewrite_with_params : benefit(1) {
# CHECK: %0 = pdl.operation
# CHECK: pdl.rewrite %0 with "rewriter" ["I am param"]
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_with_params():
pattern = PatternOp(1, "rewrite_with_params")
with InsertionPoint(pattern.body):
op = OperationOp()
RewriteOp(op, "rewriter", params=[StringAttr.get("I am param")])
# CHECK: module {
# CHECK: pdl.pattern @rewrite_with_args_and_params : benefit(1) {
# CHECK: %0 = pdl.operand
# CHECK: %1 = pdl.operation(%0 : !pdl.value)
# CHECK: pdl.rewrite %1 with "rewriter" ["I am param"](%0 : !pdl.value)
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_with_args_and_params():
pattern = PatternOp(1, "rewrite_with_args_and_params")
with InsertionPoint(pattern.body):
input = OperandOp()
root = OperationOp(args=[input])
RewriteOp(root, "rewriter", params=[StringAttr.get("I am param")], args=[input])
# CHECK: module {
# CHECK: pdl.pattern @rewrite_multi_root_optimal : benefit(1) {
# CHECK: %0 = pdl.operand
# CHECK: %1 = pdl.operand
# CHECK: %2 = pdl.type
# CHECK: %3 = pdl.operation(%0 : !pdl.value) -> (%2 : !pdl.type)
# CHECK: %4 = pdl.result 0 of %3
# CHECK: %5 = pdl.operation(%4 : !pdl.value)
# CHECK: %6 = pdl.operation(%1 : !pdl.value) -> (%2 : !pdl.type)
# CHECK: %7 = pdl.result 0 of %6
# CHECK: %8 = pdl.operation(%4, %7 : !pdl.value, !pdl.value)
# CHECK: pdl.rewrite with "rewriter" ["I am param"](%5, %8 : !pdl.operation, !pdl.operation)
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_multi_root_optimal():
pattern = PatternOp(1, "rewrite_multi_root_optimal")
with InsertionPoint(pattern.body):
input1 = OperandOp()
input2 = OperandOp()
ty = TypeOp()
op1 = OperationOp(args=[input1], types=[ty])
val1 = ResultOp(op1, 0)
root1 = OperationOp(args=[val1])
op2 = OperationOp(args=[input2], types=[ty])
val2 = ResultOp(op2, 0)
root2 = OperationOp(args=[val1, val2])
RewriteOp(name="rewriter", params=[StringAttr.get("I am param")], args=[root1, root2])
# CHECK: module {
# CHECK: pdl.pattern @rewrite_multi_root_forced : benefit(1) {
# CHECK: %0 = pdl.operand
# CHECK: %1 = pdl.operand
# CHECK: %2 = pdl.type
# CHECK: %3 = pdl.operation(%0 : !pdl.value) -> (%2 : !pdl.type)
# CHECK: %4 = pdl.result 0 of %3
# CHECK: %5 = pdl.operation(%4 : !pdl.value)
# CHECK: %6 = pdl.operation(%1 : !pdl.value) -> (%2 : !pdl.type)
# CHECK: %7 = pdl.result 0 of %6
# CHECK: %8 = pdl.operation(%4, %7 : !pdl.value, !pdl.value)
# CHECK: pdl.rewrite %5 with "rewriter" ["I am param"](%8 : !pdl.operation)
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_multi_root_forced():
pattern = PatternOp(1, "rewrite_multi_root_forced")
with InsertionPoint(pattern.body):
input1 = OperandOp()
input2 = OperandOp()
ty = TypeOp()
op1 = OperationOp(args=[input1], types=[ty])
val1 = ResultOp(op1, 0)
root1 = OperationOp(args=[val1])
op2 = OperationOp(args=[input2], types=[ty])
val2 = ResultOp(op2, 0)
root2 = OperationOp(args=[val1, val2])
RewriteOp(root1, name="rewriter", params=[StringAttr.get("I am param")], args=[root2])
# CHECK: module {
# CHECK: pdl.pattern @rewrite_add_body : benefit(1) {
# CHECK: %0 = pdl.type : i32
# CHECK: %1 = pdl.type
# CHECK: %2 = pdl.operation -> (%0, %1 : !pdl.type, !pdl.type)
# CHECK: pdl.rewrite %2 {
# CHECK: %3 = pdl.type
# CHECK: %4 = pdl.operation "foo.op" -> (%0, %3 : !pdl.type, !pdl.type)
# CHECK: pdl.replace %2 with %4
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_add_body():
pattern = PatternOp(1, "rewrite_add_body")
with InsertionPoint(pattern.body):
ty1 = TypeOp(IntegerType.get_signless(32))
ty2 = TypeOp()
root = OperationOp(types=[ty1, ty2])
rewrite = RewriteOp(root)
with InsertionPoint(rewrite.add_body()):
ty3 = TypeOp()
newOp = OperationOp(name="foo.op", types=[ty1, ty3])
ReplaceOp(root, with_op=newOp)
# CHECK: module {
# CHECK: pdl.pattern @rewrite_type : benefit(1) {
# CHECK: %0 = pdl.type : i32
# CHECK: %1 = pdl.type
# CHECK: %2 = pdl.operation -> (%0, %1 : !pdl.type, !pdl.type)
# CHECK: pdl.rewrite %2 {
# CHECK: %3 = pdl.operation "foo.op" -> (%0, %1 : !pdl.type, !pdl.type)
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_type():
pattern = PatternOp(1, "rewrite_type")
with InsertionPoint(pattern.body):
ty1 = TypeOp(IntegerType.get_signless(32))
ty2 = TypeOp()
root = OperationOp(types=[ty1, ty2])
rewrite = RewriteOp(root)
with InsertionPoint(rewrite.add_body()):
newOp = OperationOp(name="foo.op", types=[ty1, ty2])
# CHECK: module {
# CHECK: pdl.pattern @rewrite_types : benefit(1) {
# CHECK: %0 = pdl.types
# CHECK: %1 = pdl.operation -> (%0 : !pdl.range<type>)
# CHECK: pdl.rewrite %1 {
# CHECK: %2 = pdl.types : [i32, i64]
# CHECK: %3 = pdl.operation "foo.op" -> (%0, %2 : !pdl.range<type>, !pdl.range<type>)
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_types():
pattern = PatternOp(1, "rewrite_types")
with InsertionPoint(pattern.body):
types = TypesOp()
root = OperationOp(types=[types])
rewrite = RewriteOp(root)
with InsertionPoint(rewrite.add_body()):
otherTypes = TypesOp([IntegerType.get_signless(32), IntegerType.get_signless(64)])
newOp = OperationOp(name="foo.op", types=[types, otherTypes])
# CHECK: module {
# CHECK: pdl.pattern @rewrite_operands : benefit(1) {
# CHECK: %0 = pdl.types
# CHECK: %1 = pdl.operands : %0
# CHECK: %2 = pdl.operation(%1 : !pdl.range<value>)
# CHECK: pdl.rewrite %2 {
# CHECK: %3 = pdl.operation "foo.op" -> (%0 : !pdl.range<type>)
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_rewrite_operands():
pattern = PatternOp(1, "rewrite_operands")
with InsertionPoint(pattern.body):
types = TypesOp()
operands = OperandsOp(types)
root = OperationOp(args=[operands])
rewrite = RewriteOp(root)
with InsertionPoint(rewrite.add_body()):
newOp = OperationOp(name="foo.op", types=[types])
# CHECK: module {
# CHECK: pdl.pattern @native_rewrite : benefit(1) {
# CHECK: %0 = pdl.operation
# CHECK: pdl.rewrite %0 {
# CHECK: pdl.apply_native_rewrite "NativeRewrite"(%0 : !pdl.operation)
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_native_rewrite():
pattern = PatternOp(1, "native_rewrite")
with InsertionPoint(pattern.body):
root = OperationOp()
rewrite = RewriteOp(root)
with InsertionPoint(rewrite.add_body()):
ApplyNativeRewriteOp([], "NativeRewrite", args=[root])
# CHECK: module {
# CHECK: pdl.pattern @attribute_with_value : benefit(1) {
# CHECK: %0 = pdl.operation
# CHECK: pdl.rewrite %0 {
# CHECK: %1 = pdl.attribute "value"
# CHECK: pdl.apply_native_rewrite "NativeRewrite"(%1 : !pdl.attribute)
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_attribute_with_value():
pattern = PatternOp(1, "attribute_with_value")
with InsertionPoint(pattern.body):
root = OperationOp()
rewrite = RewriteOp(root)
with InsertionPoint(rewrite.add_body()):
attr = AttributeOp(value=Attribute.parse('"value"'))
ApplyNativeRewriteOp([], "NativeRewrite", args=[attr])
# CHECK: module {
# CHECK: pdl.pattern @erase : benefit(1) {
# CHECK: %0 = pdl.operation
# CHECK: pdl.rewrite %0 {
# CHECK: pdl.erase %0
# CHECK: }
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_erase():
pattern = PatternOp(1, "erase")
with InsertionPoint(pattern.body):
root = OperationOp()
rewrite = RewriteOp(root)
with InsertionPoint(rewrite.add_body()):
EraseOp(root)
# CHECK: module {
# CHECK: pdl.pattern @operation_results : benefit(1) {
# CHECK: %0 = pdl.types
# CHECK: %1 = pdl.operation -> (%0 : !pdl.range<type>)
# CHECK: %2 = pdl.results of %1
# CHECK: %3 = pdl.operation(%2 : !pdl.range<value>)
# CHECK: pdl.rewrite %3 with "rewriter"
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_operation_results():
valueRange = RangeType.get(ValueType.get())
pattern = PatternOp(1, "operation_results")
with InsertionPoint(pattern.body):
types = TypesOp()
inputOp = OperationOp(types=[types])
results = ResultsOp(valueRange, inputOp)
root = OperationOp(args=[results])
RewriteOp(root, name="rewriter")
# CHECK: module {
# CHECK: pdl.pattern : benefit(1) {
# CHECK: %0 = pdl.type
# CHECK: pdl.apply_native_constraint "typeConstraint" [](%0 : !pdl.type)
# CHECK: %1 = pdl.operation -> (%0 : !pdl.type)
# CHECK: pdl.rewrite %1 with "rewrite"
# CHECK: }
# CHECK: }
@constructAndPrintInModule
def test_apply_native_constraint():
pattern = PatternOp(1)
with InsertionPoint(pattern.body):
resultType = TypeOp()
ApplyNativeConstraintOp("typeConstraint", args=[resultType], params=[])
root = OperationOp(types=[resultType])
RewriteOp(root, name="rewrite")

View File

@ -0,0 +1,150 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
from mlir.dialects import pdl
def run(f):
print("\nTEST:", f.__name__)
f()
return f
# CHECK-LABEL: TEST: test_attribute_type
@run
def test_attribute_type():
with Context():
parsedType = Type.parse("!pdl.attribute")
constructedType = pdl.AttributeType.get()
assert pdl.AttributeType.isinstance(parsedType)
assert not pdl.OperationType.isinstance(parsedType)
assert not pdl.RangeType.isinstance(parsedType)
assert not pdl.TypeType.isinstance(parsedType)
assert not pdl.ValueType.isinstance(parsedType)
assert pdl.AttributeType.isinstance(constructedType)
assert not pdl.OperationType.isinstance(constructedType)
assert not pdl.RangeType.isinstance(constructedType)
assert not pdl.TypeType.isinstance(constructedType)
assert not pdl.ValueType.isinstance(constructedType)
assert parsedType == constructedType
# CHECK: !pdl.attribute
print(parsedType)
# CHECK: !pdl.attribute
print(constructedType)
# CHECK-LABEL: TEST: test_operation_type
@run
def test_operation_type():
with Context():
parsedType = Type.parse("!pdl.operation")
constructedType = pdl.OperationType.get()
assert not pdl.AttributeType.isinstance(parsedType)
assert pdl.OperationType.isinstance(parsedType)
assert not pdl.RangeType.isinstance(parsedType)
assert not pdl.TypeType.isinstance(parsedType)
assert not pdl.ValueType.isinstance(parsedType)
assert not pdl.AttributeType.isinstance(constructedType)
assert pdl.OperationType.isinstance(constructedType)
assert not pdl.RangeType.isinstance(constructedType)
assert not pdl.TypeType.isinstance(constructedType)
assert not pdl.ValueType.isinstance(constructedType)
assert parsedType == constructedType
# CHECK: !pdl.operation
print(parsedType)
# CHECK: !pdl.operation
print(constructedType)
# CHECK-LABEL: TEST: test_range_type
@run
def test_range_type():
with Context():
typeType = Type.parse("!pdl.type")
parsedType = Type.parse("!pdl.range<type>")
constructedType = pdl.RangeType.get(typeType)
elementType = constructedType.element_type
assert not pdl.AttributeType.isinstance(parsedType)
assert not pdl.OperationType.isinstance(parsedType)
assert pdl.RangeType.isinstance(parsedType)
assert not pdl.TypeType.isinstance(parsedType)
assert not pdl.ValueType.isinstance(parsedType)
assert not pdl.AttributeType.isinstance(constructedType)
assert not pdl.OperationType.isinstance(constructedType)
assert pdl.RangeType.isinstance(constructedType)
assert not pdl.TypeType.isinstance(constructedType)
assert not pdl.ValueType.isinstance(constructedType)
assert parsedType == constructedType
assert elementType == typeType
# CHECK: !pdl.range<type>
print(parsedType)
# CHECK: !pdl.range<type>
print(constructedType)
# CHECK: !pdl.type
print(elementType)
# CHECK-LABEL: TEST: test_type_type
@run
def test_type_type():
with Context():
parsedType = Type.parse("!pdl.type")
constructedType = pdl.TypeType.get()
assert not pdl.AttributeType.isinstance(parsedType)
assert not pdl.OperationType.isinstance(parsedType)
assert not pdl.RangeType.isinstance(parsedType)
assert pdl.TypeType.isinstance(parsedType)
assert not pdl.ValueType.isinstance(parsedType)
assert not pdl.AttributeType.isinstance(constructedType)
assert not pdl.OperationType.isinstance(constructedType)
assert not pdl.RangeType.isinstance(constructedType)
assert pdl.TypeType.isinstance(constructedType)
assert not pdl.ValueType.isinstance(constructedType)
assert parsedType == constructedType
# CHECK: !pdl.type
print(parsedType)
# CHECK: !pdl.type
print(constructedType)
# CHECK-LABEL: TEST: test_value_type
@run
def test_value_type():
with Context():
parsedType = Type.parse("!pdl.value")
constructedType = pdl.ValueType.get()
assert not pdl.AttributeType.isinstance(parsedType)
assert not pdl.OperationType.isinstance(parsedType)
assert not pdl.RangeType.isinstance(parsedType)
assert not pdl.TypeType.isinstance(parsedType)
assert pdl.ValueType.isinstance(parsedType)
assert not pdl.AttributeType.isinstance(constructedType)
assert not pdl.OperationType.isinstance(constructedType)
assert not pdl.RangeType.isinstance(constructedType)
assert not pdl.TypeType.isinstance(constructedType)
assert pdl.ValueType.isinstance(constructedType)
assert parsedType == constructedType
# CHECK: !pdl.value
print(parsedType)
# CHECK: !pdl.value
print(constructedType)