mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-01-01 13:20:25 +00:00
[mlir:PDL] Remove the ConstantParams support from native Constraints/Rewrites
This support has never really worked well, and is incredibly clunky to use (it effectively creates two argument APIs), and clunky to generate (it isn't clear how we should actually expose this from PDL frontends). Treating these as just attribute arguments is much much cleaner in every aspect of the stack. If we need to optimize lots of constant parameters, it would be better to investigate internal representation optimizations (e.g. batch attribute creation), that do not affect the user (we want a clean external API). Differential Revision: https://reviews.llvm.org/D121569
This commit is contained in:
parent
469c58944d
commit
9595f3568a
@ -1007,14 +1007,12 @@ the C++ PDL API. For example, the constraints above may be registered as:
|
|||||||
|
|
||||||
```c++
|
```c++
|
||||||
// TODO: Cleanup when we allow more accessible wrappers around PDL functions.
|
// TODO: Cleanup when we allow more accessible wrappers around PDL functions.
|
||||||
static LogicalResult hasOneUseImpl(PDLValue pdlValue, ArrayAttr constantParams,
|
static LogicalResult hasOneUseImpl(PDLValue pdlValue, PatternRewriter &rewriter) {
|
||||||
PatternRewriter &rewriter) {
|
|
||||||
Value value = pdlValue.cast<Value>();
|
Value value = pdlValue.cast<Value>();
|
||||||
|
|
||||||
return success(value.hasOneUse());
|
return success(value.hasOneUse());
|
||||||
}
|
}
|
||||||
static LogicalResult hasSameElementTypeImpl(
|
static LogicalResult hasSameElementTypeImpl(ArrayRef<PDLValue> pdlValues,
|
||||||
ArrayRef<PDLValue> pdlValues, ArrayAttr constantParams,
|
|
||||||
PatternRewriter &rewriter) {
|
PatternRewriter &rewriter) {
|
||||||
Value value1 = pdlValues[0].cast<Value>();
|
Value value1 = pdlValues[0].cast<Value>();
|
||||||
Value value2 = pdlValues[1].cast<Value>();
|
Value value2 = pdlValues[1].cast<Value>();
|
||||||
@ -1310,8 +1308,8 @@ the C++ PDL API. For example, the rewrite above may be registered as:
|
|||||||
|
|
||||||
```c++
|
```c++
|
||||||
// TODO: Cleanup when we allow more accessible wrappers around PDL functions.
|
// TODO: Cleanup when we allow more accessible wrappers around PDL functions.
|
||||||
static void buildOpImpl(ArrayRef<PDLValue> args, ArrayAttr constantParams,
|
static void buildOpImpl(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
|
||||||
PatternRewriter &rewriter, PDLResultList &results) {
|
PDLResultList &results) {
|
||||||
Value value = args[0].cast<Value>();
|
Value value = args[0].cast<Value>();
|
||||||
|
|
||||||
// insert special rewrite logic here.
|
// insert special rewrite logic here.
|
||||||
|
@ -35,33 +35,18 @@ def PDL_ApplyNativeConstraintOp
|
|||||||
let description = [{
|
let description = [{
|
||||||
`pdl.apply_native_constraint` operations apply a native C++ constraint, that
|
`pdl.apply_native_constraint` operations apply a native C++ constraint, that
|
||||||
has been registered externally with the consumer of PDL, to a given set of
|
has been registered externally with the consumer of PDL, to a given set of
|
||||||
entities. The constraint is permitted to accept any number of constant
|
entities.
|
||||||
valued parameters.
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
```mlir
|
```mlir
|
||||||
// Apply `myConstraint` to the entities defined by `input`, `attr`, and
|
// Apply `myConstraint` to the entities defined by `input`, `attr`, and `op`.
|
||||||
// `op`. `42`, `"abc"`, and `i32` are constant parameters passed to the
|
pdl.apply_native_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
|
||||||
// constraint.
|
|
||||||
pdl.apply_native_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
|
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins StrAttr:$name,
|
let arguments = (ins StrAttr:$name, Variadic<PDL_AnyType>:$args);
|
||||||
Variadic<PDL_AnyType>:$args,
|
let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict";
|
||||||
OptionalAttr<ArrayAttr>:$constParams);
|
|
||||||
let assemblyFormat = [{
|
|
||||||
$name ($constParams^)? `(` $args `:` type($args) `)` attr-dict
|
|
||||||
}];
|
|
||||||
|
|
||||||
let builders = [
|
|
||||||
OpBuilder<(ins "StringRef":$name, CArg<"ValueRange", "{}">:$args,
|
|
||||||
CArg<"ArrayRef<Attribute>", "{}">:$params), [{
|
|
||||||
build($_builder, $_state, $_builder.getStringAttr(name), args,
|
|
||||||
params.empty() ? ArrayAttr() : $_builder.getArrayAttr(params));
|
|
||||||
}]>,
|
|
||||||
];
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -76,26 +61,22 @@ def PDL_ApplyNativeRewriteOp
|
|||||||
`pdl.apply_native_rewrite` operations apply a native C++ function, that has
|
`pdl.apply_native_rewrite` operations apply a native C++ function, that has
|
||||||
been registered externally with the consumer of PDL, to perform a rewrite
|
been registered externally with the consumer of PDL, to perform a rewrite
|
||||||
and optionally return a number of values. The native function may accept any
|
and optionally return a number of values. The native function may accept any
|
||||||
number of arguments and constant attribute parameters. This operation is
|
number of arguments. This operation is used within a pdl.rewrite region to enable
|
||||||
used within a pdl.rewrite region to enable the interleaving of native
|
the interleaving of native rewrite methods with other pdl constructs.
|
||||||
rewrite methods with other pdl constructs.
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
```mlir
|
```mlir
|
||||||
// Apply a native rewrite method that returns an attribute.
|
// Apply a native rewrite method that returns an attribute.
|
||||||
%ret = pdl.apply_native_rewrite "myNativeFunc"[42, "gt"](%arg0, %arg1) : !pdl.attribute
|
%ret = pdl.apply_native_rewrite "myNativeFunc"(%arg0, %arg1) : !pdl.attribute
|
||||||
```
|
```
|
||||||
|
|
||||||
```c++
|
```c++
|
||||||
// The native rewrite as defined in C++:
|
// The native rewrite as defined in C++:
|
||||||
static void myNativeFunc(ArrayRef<PDLValue> args, ArrayAttr constantParams,
|
static void myNativeFunc(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
|
||||||
PatternRewriter &rewriter,
|
|
||||||
PDLResultList &results) {
|
PDLResultList &results) {
|
||||||
Value arg0 = args[0].cast<Value>();
|
Value arg0 = args[0].cast<Value>();
|
||||||
Value arg1 = args[1].cast<Value>();
|
Value arg1 = args[1].cast<Value>();
|
||||||
IntegerAttr param0 = constantParams[0].cast<IntegerAttr>();
|
|
||||||
StringAttr param1 = constantParams[1].cast<StringAttr>();
|
|
||||||
|
|
||||||
// Just push back the first param attribute.
|
// Just push back the first param attribute.
|
||||||
results.push_back(param0);
|
results.push_back(param0);
|
||||||
@ -107,13 +88,10 @@ def PDL_ApplyNativeRewriteOp
|
|||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins StrAttr:$name,
|
let arguments = (ins StrAttr:$name, Variadic<PDL_AnyType>:$args);
|
||||||
Variadic<PDL_AnyType>:$args,
|
|
||||||
OptionalAttr<ArrayAttr>:$constParams);
|
|
||||||
let results = (outs Variadic<PDL_AnyType>:$results);
|
let results = (outs Variadic<PDL_AnyType>:$results);
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$name ($constParams^)? (`(` $args^ `:` type($args) `)`)?
|
$name (`(` $args^ `:` type($args) `)`)? (`:` type($results)^)? attr-dict
|
||||||
(`:` type($results)^)? attr-dict
|
|
||||||
}];
|
}];
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
@ -588,16 +566,15 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [
|
|||||||
rewrite is specified either via a string name (`name`) to a native
|
rewrite is specified either via a string name (`name`) to a native
|
||||||
rewrite function, or via the region body. The rewrite region, if specified,
|
rewrite function, or via the region body. The rewrite region, if specified,
|
||||||
must contain a single block. If the rewrite is external it functions
|
must contain a single block. If the rewrite is external it functions
|
||||||
similarly to `pdl.apply_native_rewrite`, and takes a set of constant
|
similarly to `pdl.apply_native_rewrite`, and takes a set of additional
|
||||||
parameters and a set of additional positional values defined within the
|
positional values defined within the matcher as arguments. If the rewrite is
|
||||||
matcher as arguments. If the rewrite is external, the root operation is
|
external, the root operation is passed to the native function as the leading
|
||||||
passed to the native function as the leading arguments. The root operation,
|
arguments. The root operation, if provided, specifies the starting point in
|
||||||
if provided, specifies the starting point in the pattern for the subgraph
|
the pattern for the subgraph isomorphism search. Pattern matching will proceed
|
||||||
isomorphism search. Pattern matching will proceed from this node downward
|
from this node downward (towards the defining operation) or upward
|
||||||
(towards the defining operation) or upward (towards the users) until all
|
(towards the users) until all the operations in the pattern have been matched.
|
||||||
the operations in the pattern have been matched. If the root is omitted,
|
If the root is omitted, the pdl_interp lowering will automatically select
|
||||||
the pdl_interp lowering will automatically select the best root of the
|
the best root of the pdl.rewrite among all the operations in the pattern.
|
||||||
pdl.rewrite among all the operations in the pattern.
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@ -623,12 +600,10 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [
|
|||||||
|
|
||||||
let arguments = (ins Optional<PDL_Operation>:$root,
|
let arguments = (ins Optional<PDL_Operation>:$root,
|
||||||
OptionalAttr<StrAttr>:$name,
|
OptionalAttr<StrAttr>:$name,
|
||||||
Variadic<PDL_AnyType>:$externalArgs,
|
Variadic<PDL_AnyType>:$externalArgs);
|
||||||
OptionalAttr<ArrayAttr>:$externalConstParams);
|
|
||||||
let regions = (region AnyRegion:$body);
|
let regions = (region AnyRegion:$body);
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
($root^)? (`with` $name^ ($externalConstParams^)?
|
($root^)? (`with` $name^ (`(` $externalArgs^ `:` type($externalArgs) `)`)?)?
|
||||||
(`(` $externalArgs^ `:` type($externalArgs) `)`)?)?
|
|
||||||
($body^)?
|
($body^)?
|
||||||
attr-dict-with-keyword
|
attr-dict-with-keyword
|
||||||
}];
|
}];
|
||||||
|
@ -89,25 +89,21 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
|
|||||||
let description = [{
|
let description = [{
|
||||||
`pdl_interp.apply_constraint` operations apply a generic constraint, that
|
`pdl_interp.apply_constraint` operations apply a generic constraint, that
|
||||||
has been registered with the interpreter, with a given set of positional
|
has been registered with the interpreter, with a given set of positional
|
||||||
values. The constraint may have any number of constant parameters. On
|
values. On success, this operation branches to the true destination,
|
||||||
success, this operation branches to the true destination, otherwise the
|
otherwise the false destination is taken.
|
||||||
false destination is taken.
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
```mlir
|
```mlir
|
||||||
// Apply `myConstraint` to the entities defined by `input`, `attr`, and
|
// Apply `myConstraint` to the entities defined by `input`, `attr`, and
|
||||||
// `op`.
|
// `op`.
|
||||||
pdl_interp.apply_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) -> ^matchDest, ^failureDest
|
pdl_interp.apply_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) -> ^matchDest, ^failureDest
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins StrAttr:$name,
|
let arguments = (ins StrAttr:$name, Variadic<PDL_AnyType>:$args);
|
||||||
Variadic<PDL_AnyType>:$args,
|
|
||||||
OptionalAttr<ArrayAttr>:$constParams);
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$name ($constParams^)? `(` $args `:` type($args) `)` attr-dict `->`
|
$name `(` $args `:` type($args) `)` attr-dict `->` successors
|
||||||
successors
|
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -120,9 +116,8 @@ def PDLInterp_ApplyRewriteOp : PDLInterp_Op<"apply_rewrite"> {
|
|||||||
let description = [{
|
let description = [{
|
||||||
`pdl_interp.apply_rewrite` operations invoke an external rewriter that has
|
`pdl_interp.apply_rewrite` operations invoke an external rewriter that has
|
||||||
been registered with the interpreter to perform the rewrite after a
|
been registered with the interpreter to perform the rewrite after a
|
||||||
successful match. The rewrite is passed a set of positional arguments,
|
successful match. The rewrite is passed a set of positional arguments. The
|
||||||
and a set of constant parameters. The rewrite function may return any
|
rewrite function may return any number of results.
|
||||||
number of results.
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@ -136,19 +131,12 @@ def PDLInterp_ApplyRewriteOp : PDLInterp_Op<"apply_rewrite"> {
|
|||||||
// Rewriter operating on the root operation along with additional arguments
|
// Rewriter operating on the root operation along with additional arguments
|
||||||
// from the matcher.
|
// from the matcher.
|
||||||
pdl_interp.apply_rewrite "rewriter"(%root : !pdl.operation, %value : !pdl.value)
|
pdl_interp.apply_rewrite "rewriter"(%root : !pdl.operation, %value : !pdl.value)
|
||||||
|
|
||||||
// Rewriter operating on the root operation along with additional arguments
|
|
||||||
// and constant parameters.
|
|
||||||
pdl_interp.apply_rewrite "rewriter"[42](%root : !pdl.operation, %value : !pdl.value)
|
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
let arguments = (ins StrAttr:$name,
|
let arguments = (ins StrAttr:$name, Variadic<PDL_AnyType>:$args);
|
||||||
Variadic<PDL_AnyType>:$args,
|
|
||||||
OptionalAttr<ArrayAttr>:$constParams);
|
|
||||||
let results = (outs Variadic<PDL_AnyType>:$results);
|
let results = (outs Variadic<PDL_AnyType>:$results);
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$name ($constParams^)? (`(` $args^ `:` type($args) `)`)?
|
$name (`(` $args^ `:` type($args) `)`)? (`:` type($results)^)? attr-dict
|
||||||
(`:` type($results)^)? attr-dict
|
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -584,24 +584,16 @@ protected:
|
|||||||
// PDLPatternModule
|
// PDLPatternModule
|
||||||
|
|
||||||
/// A generic PDL pattern constraint function. This function applies a
|
/// A generic PDL pattern constraint function. This function applies a
|
||||||
/// constraint to a given set of opaque PDLValue entities. The second parameter
|
/// constraint to a given set of opaque PDLValue entities. Returns success if
|
||||||
/// is a set of constant value parameters specified in Attribute form. Returns
|
|
||||||
/// success if the constraint successfully held, failure otherwise.
|
|
||||||
using PDLConstraintFunction = std::function<LogicalResult(
|
|
||||||
ArrayRef<PDLValue>, ArrayAttr, PatternRewriter &)>;
|
|
||||||
/// A native PDL rewrite function. This function performs a rewrite on the
|
|
||||||
/// given set of values and constant parameters. Any results from this rewrite
|
|
||||||
/// that should be passed back to PDL should be added to the provided result
|
|
||||||
/// list. This method is only invoked when the corresponding match was
|
|
||||||
/// successful.
|
|
||||||
using PDLRewriteFunction = std::function<void(
|
|
||||||
ArrayRef<PDLValue>, ArrayAttr, PatternRewriter &, PDLResultList &)>;
|
|
||||||
/// A generic PDL pattern constraint function. This function applies a
|
|
||||||
/// constraint to a given opaque PDLValue entity. The second parameter is a set
|
|
||||||
/// of constant value parameters specified in Attribute form. Returns success if
|
|
||||||
/// the constraint successfully held, failure otherwise.
|
/// the constraint successfully held, failure otherwise.
|
||||||
using PDLSingleEntityConstraintFunction =
|
using PDLConstraintFunction =
|
||||||
std::function<LogicalResult(PDLValue, ArrayAttr, PatternRewriter &)>;
|
std::function<LogicalResult(ArrayRef<PDLValue>, PatternRewriter &)>;
|
||||||
|
/// A native PDL rewrite function. This function performs a rewrite on the
|
||||||
|
/// given set of values. Any results from this rewrite that should be passed
|
||||||
|
/// back to PDL should be added to the provided result list. This method is only
|
||||||
|
/// invoked when the corresponding match was successful.
|
||||||
|
using PDLRewriteFunction =
|
||||||
|
std::function<void(ArrayRef<PDLValue>, PatternRewriter &, PDLResultList &)>;
|
||||||
|
|
||||||
/// This class contains all of the necessary data for a set of PDL patterns, or
|
/// This class contains all of the necessary data for a set of PDL patterns, or
|
||||||
/// pattern rewrites specified in the form of the PDL dialect. This PDL module
|
/// pattern rewrites specified in the form of the PDL dialect. This PDL module
|
||||||
@ -630,15 +622,14 @@ public:
|
|||||||
/// Register a single entity constraint function.
|
/// Register a single entity constraint function.
|
||||||
template <typename SingleEntityFn>
|
template <typename SingleEntityFn>
|
||||||
std::enable_if_t<!llvm::is_invocable<SingleEntityFn, ArrayRef<PDLValue>,
|
std::enable_if_t<!llvm::is_invocable<SingleEntityFn, ArrayRef<PDLValue>,
|
||||||
ArrayAttr, PatternRewriter &>::value>
|
PatternRewriter &>::value>
|
||||||
registerConstraintFunction(StringRef name, SingleEntityFn &&constraintFn) {
|
registerConstraintFunction(StringRef name, SingleEntityFn &&constraintFn) {
|
||||||
registerConstraintFunction(
|
registerConstraintFunction(
|
||||||
name, [constraintFn = std::forward<SingleEntityFn>(constraintFn)](
|
name, [constraintFn = std::forward<SingleEntityFn>(constraintFn)](
|
||||||
ArrayRef<PDLValue> values, ArrayAttr constantParams,
|
ArrayRef<PDLValue> values, PatternRewriter &rewriter) {
|
||||||
PatternRewriter &rewriter) {
|
|
||||||
assert(values.size() == 1 &&
|
assert(values.size() == 1 &&
|
||||||
"expected values to have a single entity");
|
"expected values to have a single entity");
|
||||||
return constraintFn(values[0], constantParams, rewriter);
|
return constraintFn(values[0], rewriter);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -431,9 +431,8 @@ void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock,
|
|||||||
}
|
}
|
||||||
case Predicates::ConstraintQuestion: {
|
case Predicates::ConstraintQuestion: {
|
||||||
auto *cstQuestion = cast<ConstraintQuestion>(question);
|
auto *cstQuestion = cast<ConstraintQuestion>(question);
|
||||||
builder.create<pdl_interp::ApplyConstraintOp>(
|
builder.create<pdl_interp::ApplyConstraintOp>(loc, cstQuestion->getName(),
|
||||||
loc, cstQuestion->getName(), args, cstQuestion->getParams(), success,
|
args, success, failure);
|
||||||
failure);
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@ -644,8 +643,7 @@ SymbolRefAttr PatternLowering::generateRewriter(
|
|||||||
auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue);
|
auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue);
|
||||||
args.append(mappedArgs.begin(), mappedArgs.end());
|
args.append(mappedArgs.begin(), mappedArgs.end());
|
||||||
builder.create<pdl_interp::ApplyRewriteOp>(
|
builder.create<pdl_interp::ApplyRewriteOp>(
|
||||||
rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args,
|
rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args);
|
||||||
rewriter.externalConstParamsAttr());
|
|
||||||
} else {
|
} else {
|
||||||
// Otherwise this is a dag rewriter defined using PDL operations.
|
// Otherwise this is a dag rewriter defined using PDL operations.
|
||||||
for (Operation &rewriteOp : *rewriter.getBody()) {
|
for (Operation &rewriteOp : *rewriter.getBody()) {
|
||||||
@ -678,8 +676,8 @@ void PatternLowering::generateRewriter(
|
|||||||
arguments.push_back(mapRewriteValue(argument));
|
arguments.push_back(mapRewriteValue(argument));
|
||||||
auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
|
auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
|
||||||
rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(),
|
rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(),
|
||||||
arguments, rewriteOp.constParamsAttr());
|
arguments);
|
||||||
for (auto it : llvm::zip(rewriteOp.results(), interpOp.getResults()))
|
for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
|
||||||
rewriteValues[std::get<0>(it)] = std::get<1>(it);
|
rewriteValues[std::get<0>(it)] = std::get<1>(it);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -445,9 +445,8 @@ struct AttributeQuestion
|
|||||||
|
|
||||||
/// Apply a parameterized constraint to multiple position values.
|
/// Apply a parameterized constraint to multiple position values.
|
||||||
struct ConstraintQuestion
|
struct ConstraintQuestion
|
||||||
: public PredicateBase<
|
: public PredicateBase<ConstraintQuestion, Qualifier,
|
||||||
ConstraintQuestion, Qualifier,
|
std::tuple<StringRef, ArrayRef<Position *>>,
|
||||||
std::tuple<StringRef, ArrayRef<Position *>, Attribute>,
|
|
||||||
Predicates::ConstraintQuestion> {
|
Predicates::ConstraintQuestion> {
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
|
|
||||||
@ -457,17 +456,11 @@ struct ConstraintQuestion
|
|||||||
/// Return the arguments of the constraint.
|
/// Return the arguments of the constraint.
|
||||||
ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
|
ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
|
||||||
|
|
||||||
/// Return the constant parameters of the constraint.
|
|
||||||
ArrayAttr getParams() const {
|
|
||||||
return std::get<2>(key).dyn_cast_or_null<ArrayAttr>();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Construct an instance with the given storage allocator.
|
/// Construct an instance with the given storage allocator.
|
||||||
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
|
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
|
||||||
KeyTy key) {
|
KeyTy key) {
|
||||||
return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
|
return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
|
||||||
alloc.copyInto(std::get<1>(key)),
|
alloc.copyInto(std::get<1>(key))});
|
||||||
std::get<2>(key)});
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -667,10 +660,8 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Create a predicate that applies a generic constraint.
|
/// Create a predicate that applies a generic constraint.
|
||||||
Predicate getConstraint(StringRef name, ArrayRef<Position *> pos,
|
Predicate getConstraint(StringRef name, ArrayRef<Position *> pos) {
|
||||||
Attribute params) {
|
return {ConstraintQuestion::get(uniquer, std::make_tuple(name, pos)),
|
||||||
return {
|
|
||||||
ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, params)),
|
|
||||||
TrueAnswer::get(uniquer)};
|
TrueAnswer::get(uniquer)};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -263,7 +263,6 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
|
|||||||
PredicateBuilder &builder,
|
PredicateBuilder &builder,
|
||||||
DenseMap<Value, Position *> &inputs) {
|
DenseMap<Value, Position *> &inputs) {
|
||||||
OperandRange arguments = op.args();
|
OperandRange arguments = op.args();
|
||||||
ArrayAttr parameters = op.constParamsAttr();
|
|
||||||
|
|
||||||
std::vector<Position *> allPositions;
|
std::vector<Position *> allPositions;
|
||||||
allPositions.reserve(arguments.size());
|
allPositions.reserve(arguments.size());
|
||||||
@ -274,7 +273,7 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
|
|||||||
Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
|
Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
|
||||||
comparePosDepth);
|
comparePosDepth);
|
||||||
PredicateBuilder::Predicate pred =
|
PredicateBuilder::Predicate pred =
|
||||||
builder.getConstraint(op.name(), allPositions, parameters);
|
builder.getConstraint(op.name(), allPositions);
|
||||||
predList.emplace_back(pos, pred);
|
predList.emplace_back(pos, pred);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -425,10 +425,6 @@ LogicalResult RewriteOp::verifyRegions() {
|
|||||||
return emitOpError() << "expected no external arguments when the "
|
return emitOpError() << "expected no external arguments when the "
|
||||||
"rewrite is specified inline";
|
"rewrite is specified inline";
|
||||||
}
|
}
|
||||||
if (externalConstParams()) {
|
|
||||||
return emitOpError() << "expected no external constant parameters when "
|
|
||||||
"the rewrite is specified inline";
|
|
||||||
}
|
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -757,8 +757,7 @@ void Generator::generate(pdl_interp::ApplyConstraintOp op,
|
|||||||
ByteCodeWriter &writer) {
|
ByteCodeWriter &writer) {
|
||||||
assert(constraintToMemIndex.count(op.getName()) &&
|
assert(constraintToMemIndex.count(op.getName()) &&
|
||||||
"expected index for constraint function");
|
"expected index for constraint function");
|
||||||
writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()],
|
writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
|
||||||
op.getConstParamsAttr());
|
|
||||||
writer.appendPDLValueList(op.getArgs());
|
writer.appendPDLValueList(op.getArgs());
|
||||||
writer.append(op.getSuccessors());
|
writer.append(op.getSuccessors());
|
||||||
}
|
}
|
||||||
@ -766,8 +765,7 @@ void Generator::generate(pdl_interp::ApplyRewriteOp op,
|
|||||||
ByteCodeWriter &writer) {
|
ByteCodeWriter &writer) {
|
||||||
assert(externalRewriterToMemIndex.count(op.getName()) &&
|
assert(externalRewriterToMemIndex.count(op.getName()) &&
|
||||||
"expected index for rewrite function");
|
"expected index for rewrite function");
|
||||||
writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()],
|
writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
|
||||||
op.getConstParamsAttr());
|
|
||||||
writer.appendPDLValueList(op.getArgs());
|
writer.appendPDLValueList(op.getArgs());
|
||||||
|
|
||||||
ResultRange results = op.getResults();
|
ResultRange results = op.getResults();
|
||||||
@ -1333,37 +1331,33 @@ public:
|
|||||||
void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
|
void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
|
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
|
||||||
const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
|
const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
|
||||||
ArrayAttr constParams = read<ArrayAttr>();
|
|
||||||
SmallVector<PDLValue, 16> args;
|
SmallVector<PDLValue, 16> args;
|
||||||
readList<PDLValue>(args);
|
readList<PDLValue>(args);
|
||||||
|
|
||||||
LLVM_DEBUG({
|
LLVM_DEBUG({
|
||||||
llvm::dbgs() << " * Arguments: ";
|
llvm::dbgs() << " * Arguments: ";
|
||||||
llvm::interleaveComma(args, llvm::dbgs());
|
llvm::interleaveComma(args, llvm::dbgs());
|
||||||
llvm::dbgs() << "\n * Parameters: " << constParams << "\n";
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Invoke the constraint and jump to the proper destination.
|
// Invoke the constraint and jump to the proper destination.
|
||||||
selectJump(succeeded(constraintFn(args, constParams, rewriter)));
|
selectJump(succeeded(constraintFn(args, rewriter)));
|
||||||
}
|
}
|
||||||
|
|
||||||
void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
|
void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
|
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
|
||||||
const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
|
const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
|
||||||
ArrayAttr constParams = read<ArrayAttr>();
|
|
||||||
SmallVector<PDLValue, 16> args;
|
SmallVector<PDLValue, 16> args;
|
||||||
readList<PDLValue>(args);
|
readList<PDLValue>(args);
|
||||||
|
|
||||||
LLVM_DEBUG({
|
LLVM_DEBUG({
|
||||||
llvm::dbgs() << " * Arguments: ";
|
llvm::dbgs() << " * Arguments: ";
|
||||||
llvm::interleaveComma(args, llvm::dbgs());
|
llvm::interleaveComma(args, llvm::dbgs());
|
||||||
llvm::dbgs() << "\n * Parameters: " << constParams << "\n";
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Execute the rewrite function.
|
// Execute the rewrite function.
|
||||||
ByteCodeField numResults = read();
|
ByteCodeField numResults = read();
|
||||||
ByteCodeRewriteResultList results(numResults);
|
ByteCodeRewriteResultList results(numResults);
|
||||||
rewriteFn(args, constParams, rewriter, results);
|
rewriteFn(args, rewriter, results);
|
||||||
|
|
||||||
assert(results.getResults().size() == numResults &&
|
assert(results.getResults().size() == numResults &&
|
||||||
"native PDL rewrite function returned unexpected number of results");
|
"native PDL rewrite function returned unexpected number of results");
|
||||||
|
@ -190,7 +190,7 @@ void CodeGen::generateConstraintOrRewrite(StringRef name, bool isConstraint,
|
|||||||
// what we need as a frontend.
|
// what we need as a frontend.
|
||||||
os << "static " << (isConstraint ? "::mlir::LogicalResult " : "void ") << name
|
os << "static " << (isConstraint ? "::mlir::LogicalResult " : "void ") << name
|
||||||
<< "PDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, "
|
<< "PDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, "
|
||||||
"::mlir::ArrayAttr constParams, ::mlir::PatternRewriter &rewriter"
|
"::mlir::PatternRewriter &rewriter"
|
||||||
<< (isConstraint ? "" : ", ::mlir::PDLResultList &results") << ") {\n";
|
<< (isConstraint ? "" : ", ::mlir::PDLResultList &results") << ") {\n";
|
||||||
|
|
||||||
const char *argumentInitStr = R"(
|
const char *argumentInitStr = R"(
|
||||||
|
@ -200,9 +200,9 @@ void CodeGen::genImpl(const ast::CompoundStmt *stmt) {
|
|||||||
static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr,
|
static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
if (isa<pdl::PatternOp>(builder.getInsertionBlock()->getParentOp())) {
|
if (isa<pdl::PatternOp>(builder.getInsertionBlock()->getParentOp())) {
|
||||||
pdl::RewriteOp rewrite = builder.create<pdl::RewriteOp>(
|
pdl::RewriteOp rewrite =
|
||||||
loc, rootExpr, /*name=*/StringAttr(),
|
builder.create<pdl::RewriteOp>(loc, rootExpr, /*name=*/StringAttr(),
|
||||||
/*externalArgs=*/ValueRange(), /*externalConstParams=*/ArrayAttr());
|
/*externalArgs=*/ValueRange());
|
||||||
builder.createBlock(&rewrite.body());
|
builder.createBlock(&rewrite.body());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -564,14 +564,8 @@ SmallVector<Value> CodeGen::genConstraintOrRewriteCall(const T *decl,
|
|||||||
} else {
|
} else {
|
||||||
resultTypes.push_back(genType(declResultType));
|
resultTypes.push_back(genType(declResultType));
|
||||||
}
|
}
|
||||||
|
Operation *pdlOp = builder.create<PDLOpT>(
|
||||||
// FIXME: We currently do not have a modeling for the "constant params"
|
loc, resultTypes, decl->getName().getName(), inputs);
|
||||||
// support PDL provides. We should either figure out a modeling for this, or
|
|
||||||
// refactor the support within PDL to be something a bit more reasonable for
|
|
||||||
// what we need as a frontend.
|
|
||||||
Operation *pdlOp = builder.create<PDLOpT>(loc, resultTypes,
|
|
||||||
decl->getName().getName(), inputs,
|
|
||||||
/*params=*/ArrayAttr());
|
|
||||||
return pdlOp->getResults();
|
return pdlOp->getResults();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -59,14 +59,12 @@ class ApplyNativeConstraintOp:
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
name: Union[str, StringAttr],
|
name: Union[str, StringAttr],
|
||||||
args: Sequence[Union[OpView, Operation, Value]] = [],
|
args: Sequence[Union[OpView, Operation, Value]] = [],
|
||||||
params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
|
|
||||||
*,
|
*,
|
||||||
loc=None,
|
loc=None,
|
||||||
ip=None):
|
ip=None):
|
||||||
name = _get_str_attr(name)
|
name = _get_str_attr(name)
|
||||||
args = _get_values(args)
|
args = _get_values(args)
|
||||||
params = params if params is None else _get_array_attr(params)
|
super().__init__(name, args, loc=loc, ip=ip)
|
||||||
super().__init__(name, args, params, loc=loc, ip=ip)
|
|
||||||
|
|
||||||
|
|
||||||
class ApplyNativeRewriteOp:
|
class ApplyNativeRewriteOp:
|
||||||
@ -76,14 +74,12 @@ class ApplyNativeRewriteOp:
|
|||||||
results: Sequence[Type],
|
results: Sequence[Type],
|
||||||
name: Union[str, StringAttr],
|
name: Union[str, StringAttr],
|
||||||
args: Sequence[Union[OpView, Operation, Value]] = [],
|
args: Sequence[Union[OpView, Operation, Value]] = [],
|
||||||
params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
|
|
||||||
*,
|
*,
|
||||||
loc=None,
|
loc=None,
|
||||||
ip=None):
|
ip=None):
|
||||||
name = _get_str_attr(name)
|
name = _get_str_attr(name)
|
||||||
args = _get_values(args)
|
args = _get_values(args)
|
||||||
params = params if params is None else _get_array_attr(params)
|
super().__init__(results, name, args, loc=loc, ip=ip)
|
||||||
super().__init__(results, name, args, params, loc=loc, ip=ip)
|
|
||||||
|
|
||||||
|
|
||||||
class AttributeOp:
|
class AttributeOp:
|
||||||
@ -236,15 +232,13 @@ class RewriteOp:
|
|||||||
root: Optional[Union[OpView, Operation, Value]] = None,
|
root: Optional[Union[OpView, Operation, Value]] = None,
|
||||||
name: Optional[Union[StringAttr, str]] = None,
|
name: Optional[Union[StringAttr, str]] = None,
|
||||||
args: Sequence[Union[OpView, Operation, Value]] = [],
|
args: Sequence[Union[OpView, Operation, Value]] = [],
|
||||||
params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
|
|
||||||
*,
|
*,
|
||||||
loc=None,
|
loc=None,
|
||||||
ip=None):
|
ip=None):
|
||||||
root = root if root is None else _get_value(root)
|
root = root if root is None else _get_value(root)
|
||||||
name = name if name is None else _get_str_attr(name)
|
name = name if name is None else _get_str_attr(name)
|
||||||
args = _get_values(args)
|
args = _get_values(args)
|
||||||
params = params if params is None else _get_array_attr(params)
|
super().__init__(root, name, args, loc=loc, ip=ip)
|
||||||
super().__init__(root, name, args, params, loc=loc, ip=ip)
|
|
||||||
|
|
||||||
def add_body(self):
|
def add_body(self):
|
||||||
"""Add body (block) to the rewrite."""
|
"""Add body (block) to the rewrite."""
|
||||||
|
@ -64,7 +64,7 @@ module @constraints {
|
|||||||
// CHECK-DAG: %[[INPUT:.*]] = pdl_interp.get_operand 0 of %[[ROOT]]
|
// CHECK-DAG: %[[INPUT:.*]] = pdl_interp.get_operand 0 of %[[ROOT]]
|
||||||
// CHECK-DAG: %[[INPUT1:.*]] = pdl_interp.get_operand 1 of %[[ROOT]]
|
// CHECK-DAG: %[[INPUT1:.*]] = pdl_interp.get_operand 1 of %[[ROOT]]
|
||||||
// CHECK-DAG: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]]
|
// CHECK-DAG: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]]
|
||||||
// CHECK: pdl_interp.apply_constraint "multi_constraint" [true](%[[INPUT]], %[[INPUT1]], %[[RESULT]]
|
// CHECK: pdl_interp.apply_constraint "multi_constraint"(%[[INPUT]], %[[INPUT1]], %[[RESULT]]
|
||||||
|
|
||||||
pdl.pattern : benefit(1) {
|
pdl.pattern : benefit(1) {
|
||||||
%input0 = operand
|
%input0 = operand
|
||||||
@ -72,7 +72,7 @@ module @constraints {
|
|||||||
%root = operation(%input0, %input1 : !pdl.value, !pdl.value)
|
%root = operation(%input0, %input1 : !pdl.value, !pdl.value)
|
||||||
%result0 = result 0 of %root
|
%result0 = result 0 of %root
|
||||||
|
|
||||||
pdl.apply_native_constraint "multi_constraint"[true](%input0, %input1, %result0 : !pdl.value, !pdl.value, !pdl.value)
|
pdl.apply_native_constraint "multi_constraint"(%input0, %input1, %result0 : !pdl.value, !pdl.value, !pdl.value)
|
||||||
rewrite %root with "rewriter"
|
rewrite %root with "rewriter"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -393,11 +393,11 @@ module @predicate_ordering {
|
|||||||
// CHECK: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]]
|
// CHECK: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]]
|
||||||
// CHECK-NEXT: pdl_interp.is_not_null %[[RESULT]]
|
// CHECK-NEXT: pdl_interp.is_not_null %[[RESULT]]
|
||||||
// CHECK: %[[RESULT_TYPE:.*]] = pdl_interp.get_value_type of %[[RESULT]]
|
// CHECK: %[[RESULT_TYPE:.*]] = pdl_interp.get_value_type of %[[RESULT]]
|
||||||
// CHECK: pdl_interp.apply_constraint "typeConstraint" [](%[[RESULT_TYPE]]
|
// CHECK: pdl_interp.apply_constraint "typeConstraint"(%[[RESULT_TYPE]]
|
||||||
|
|
||||||
pdl.pattern : benefit(1) {
|
pdl.pattern : benefit(1) {
|
||||||
%resultType = type
|
%resultType = type
|
||||||
pdl.apply_native_constraint "typeConstraint"[](%resultType : !pdl.type)
|
pdl.apply_native_constraint "typeConstraint"(%resultType : !pdl.type)
|
||||||
%root = operation -> (%resultType : !pdl.type)
|
%root = operation -> (%resultType : !pdl.type)
|
||||||
rewrite %root with "rewriter"
|
rewrite %root with "rewriter"
|
||||||
}
|
}
|
||||||
|
@ -6,11 +6,11 @@
|
|||||||
module @external {
|
module @external {
|
||||||
// CHECK: module @rewriters
|
// CHECK: module @rewriters
|
||||||
// CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation, %[[INPUT:.*]]: !pdl.value)
|
// CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation, %[[INPUT:.*]]: !pdl.value)
|
||||||
// CHECK: pdl_interp.apply_rewrite "rewriter" [true](%[[ROOT]], %[[INPUT]] : !pdl.operation, !pdl.value)
|
// CHECK: pdl_interp.apply_rewrite "rewriter"(%[[ROOT]], %[[INPUT]] : !pdl.operation, !pdl.value)
|
||||||
pdl.pattern : benefit(1) {
|
pdl.pattern : benefit(1) {
|
||||||
%input = operand
|
%input = operand
|
||||||
%root = operation "foo.op"(%input : !pdl.value)
|
%root = operation "foo.op"(%input : !pdl.value)
|
||||||
rewrite %root with "rewriter"[true](%input : !pdl.value)
|
rewrite %root with "rewriter"(%input : !pdl.value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -191,13 +191,13 @@ module @replace_with_no_results {
|
|||||||
module @apply_native_rewrite {
|
module @apply_native_rewrite {
|
||||||
// CHECK: module @rewriters
|
// CHECK: module @rewriters
|
||||||
// CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation)
|
// CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation)
|
||||||
// CHECK: %[[TYPE:.*]] = pdl_interp.apply_rewrite "functor" [true](%[[ROOT]] : !pdl.operation) : !pdl.type
|
// CHECK: %[[TYPE:.*]] = pdl_interp.apply_rewrite "functor"(%[[ROOT]] : !pdl.operation) : !pdl.type
|
||||||
// CHECK: pdl_interp.create_operation "foo.op" -> (%[[TYPE]] : !pdl.type)
|
// CHECK: pdl_interp.create_operation "foo.op" -> (%[[TYPE]] : !pdl.type)
|
||||||
pdl.pattern : benefit(1) {
|
pdl.pattern : benefit(1) {
|
||||||
%type = type
|
%type = type
|
||||||
%root = operation "foo.op" -> (%type : !pdl.type)
|
%root = operation "foo.op" -> (%type : !pdl.type)
|
||||||
rewrite %root {
|
rewrite %root {
|
||||||
%newType = apply_native_rewrite "functor"[true](%root : !pdl.operation) : !pdl.type
|
%newType = apply_native_rewrite "functor"(%root : !pdl.operation) : !pdl.type
|
||||||
%newOp = operation "foo.op" -> (%newType : !pdl.type)
|
%newOp = operation "foo.op" -> (%newType : !pdl.type)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -8,7 +8,7 @@ pdl.pattern : benefit(1) {
|
|||||||
%op = operation "foo.op"
|
%op = operation "foo.op"
|
||||||
|
|
||||||
// expected-error@below {{expected at least one argument}}
|
// expected-error@below {{expected at least one argument}}
|
||||||
"pdl.apply_native_constraint"() {name = "foo", params = []} : () -> ()
|
"pdl.apply_native_constraint"() {name = "foo"} : () -> ()
|
||||||
rewrite %op with "rewriter"
|
rewrite %op with "rewriter"
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -22,7 +22,7 @@ pdl.pattern : benefit(1) {
|
|||||||
%op = operation "foo.op"
|
%op = operation "foo.op"
|
||||||
rewrite %op {
|
rewrite %op {
|
||||||
// expected-error@below {{expected at least one argument}}
|
// expected-error@below {{expected at least one argument}}
|
||||||
"pdl.apply_native_rewrite"() {name = "foo", params = []} : () -> ()
|
"pdl.apply_native_rewrite"() {name = "foo"} : () -> ()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -264,19 +264,6 @@ pdl.pattern : benefit(1) {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
pdl.pattern : benefit(1) {
|
|
||||||
%op = operation "foo.op"
|
|
||||||
|
|
||||||
// expected-error@below {{expected no external constant parameters when the rewrite is specified inline}}
|
|
||||||
"pdl.rewrite"(%op) ({
|
|
||||||
^bb1:
|
|
||||||
}) {
|
|
||||||
operand_segment_sizes = dense<[1,0]> : vector<2xi32>,
|
|
||||||
externalConstParams = []} : (!pdl.operation) -> ()
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
pdl.pattern : benefit(1) {
|
pdl.pattern : benefit(1) {
|
||||||
%op = operation "foo.op"
|
%op = operation "foo.op"
|
||||||
|
|
||||||
|
@ -27,21 +27,6 @@ pdl.pattern @rewrite_with_args : benefit(1) {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
pdl.pattern @rewrite_with_params : benefit(1) {
|
|
||||||
%root = operation
|
|
||||||
rewrite %root with "rewriter"["I am param"]
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
pdl.pattern @rewrite_with_args_and_params : benefit(1) {
|
|
||||||
%input = operand
|
|
||||||
%root = operation(%input : !pdl.value)
|
|
||||||
rewrite %root with "rewriter"["I am param"](%input : !pdl.value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
pdl.pattern @rewrite_multi_root_optimal : benefit(2) {
|
pdl.pattern @rewrite_multi_root_optimal : benefit(2) {
|
||||||
%input1 = operand
|
%input1 = operand
|
||||||
%input2 = operand
|
%input2 = operand
|
||||||
@ -52,7 +37,7 @@ pdl.pattern @rewrite_multi_root_optimal : benefit(2) {
|
|||||||
%op2 = operation(%input2 : !pdl.value) -> (%type : !pdl.type)
|
%op2 = operation(%input2 : !pdl.value) -> (%type : !pdl.type)
|
||||||
%val2 = result 0 of %op2
|
%val2 = result 0 of %op2
|
||||||
%root2 = operation(%val1, %val2 : !pdl.value, !pdl.value)
|
%root2 = operation(%val1, %val2 : !pdl.value, !pdl.value)
|
||||||
rewrite with "rewriter"["I am param"](%root1, %root2 : !pdl.operation, !pdl.operation)
|
rewrite with "rewriter"(%root1, %root2 : !pdl.operation, !pdl.operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
@ -67,7 +52,7 @@ pdl.pattern @rewrite_multi_root_forced : benefit(2) {
|
|||||||
%op2 = operation(%input2 : !pdl.value) -> (%type : !pdl.type)
|
%op2 = operation(%input2 : !pdl.value) -> (%type : !pdl.type)
|
||||||
%val2 = result 0 of %op2
|
%val2 = result 0 of %op2
|
||||||
%root2 = operation(%val1, %val2 : !pdl.value, !pdl.value)
|
%root2 = operation(%val1, %val2 : !pdl.value, !pdl.value)
|
||||||
rewrite %root1 with "rewriter"["I am param"](%root2 : !pdl.operation)
|
rewrite %root1 with "rewriter"(%root2 : !pdl.operation)
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -90,7 +90,7 @@ module @patterns {
|
|||||||
module @rewriters {
|
module @rewriters {
|
||||||
pdl_interp.func @success(%root : !pdl.operation) {
|
pdl_interp.func @success(%root : !pdl.operation) {
|
||||||
%operand = pdl_interp.get_operand 0 of %root
|
%operand = pdl_interp.get_operand 0 of %root
|
||||||
pdl_interp.apply_rewrite "rewriter"[42](%root, %operand : !pdl.operation, !pdl.value)
|
pdl_interp.apply_rewrite "rewriter"(%root, %operand : !pdl.operation, !pdl.value)
|
||||||
pdl_interp.finalize
|
pdl_interp.finalize
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -99,7 +99,7 @@ module @patterns {
|
|||||||
// CHECK-LABEL: test.apply_rewrite_1
|
// CHECK-LABEL: test.apply_rewrite_1
|
||||||
// CHECK: %[[INPUT:.*]] = "test.op_input"
|
// CHECK: %[[INPUT:.*]] = "test.op_input"
|
||||||
// CHECK-NOT: "test.op"
|
// CHECK-NOT: "test.op"
|
||||||
// CHECK: "test.success"(%[[INPUT]]) {constantParams = [42]}
|
// CHECK: "test.success"(%[[INPUT]])
|
||||||
module @ir attributes { test.apply_rewrite_1 } {
|
module @ir attributes { test.apply_rewrite_1 } {
|
||||||
%input = "test.op_input"() : () -> i32
|
%input = "test.op_input"() : () -> i32
|
||||||
"test.op"(%input) : (i32) -> ()
|
"test.op"(%input) : (i32) -> ()
|
||||||
|
@ -15,19 +15,16 @@ using namespace mlir;
|
|||||||
|
|
||||||
/// Custom constraint invoked from PDL.
|
/// Custom constraint invoked from PDL.
|
||||||
static LogicalResult customSingleEntityConstraint(PDLValue value,
|
static LogicalResult customSingleEntityConstraint(PDLValue value,
|
||||||
ArrayAttr constantParams,
|
|
||||||
PatternRewriter &rewriter) {
|
PatternRewriter &rewriter) {
|
||||||
Operation *rootOp = value.cast<Operation *>();
|
Operation *rootOp = value.cast<Operation *>();
|
||||||
return success(rootOp->getName().getStringRef() == "test.op");
|
return success(rootOp->getName().getStringRef() == "test.op");
|
||||||
}
|
}
|
||||||
static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values,
|
static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values,
|
||||||
ArrayAttr constantParams,
|
|
||||||
PatternRewriter &rewriter) {
|
PatternRewriter &rewriter) {
|
||||||
return customSingleEntityConstraint(values[1], constantParams, rewriter);
|
return customSingleEntityConstraint(values[1], rewriter);
|
||||||
}
|
}
|
||||||
static LogicalResult
|
static LogicalResult
|
||||||
customMultiEntityVariadicConstraint(ArrayRef<PDLValue> values,
|
customMultiEntityVariadicConstraint(ArrayRef<PDLValue> values,
|
||||||
ArrayAttr constantParams,
|
|
||||||
PatternRewriter &rewriter) {
|
PatternRewriter &rewriter) {
|
||||||
if (llvm::any_of(values, [](const PDLValue &value) { return !value; }))
|
if (llvm::any_of(values, [](const PDLValue &value) { return !value; }))
|
||||||
return failure();
|
return failure();
|
||||||
@ -39,32 +36,29 @@ customMultiEntityVariadicConstraint(ArrayRef<PDLValue> values,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Custom creator invoked from PDL.
|
// Custom creator invoked from PDL.
|
||||||
static void customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams,
|
static void customCreate(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
|
||||||
PatternRewriter &rewriter, PDLResultList &results) {
|
PDLResultList &results) {
|
||||||
results.push_back(rewriter.createOperation(
|
results.push_back(rewriter.createOperation(
|
||||||
OperationState(args[0].cast<Operation *>()->getLoc(), "test.success")));
|
OperationState(args[0].cast<Operation *>()->getLoc(), "test.success")));
|
||||||
}
|
}
|
||||||
static void customVariadicResultCreate(ArrayRef<PDLValue> args,
|
static void customVariadicResultCreate(ArrayRef<PDLValue> args,
|
||||||
ArrayAttr constantParams,
|
|
||||||
PatternRewriter &rewriter,
|
PatternRewriter &rewriter,
|
||||||
PDLResultList &results) {
|
PDLResultList &results) {
|
||||||
Operation *root = args[0].cast<Operation *>();
|
Operation *root = args[0].cast<Operation *>();
|
||||||
results.push_back(root->getOperands());
|
results.push_back(root->getOperands());
|
||||||
results.push_back(root->getOperands().getTypes());
|
results.push_back(root->getOperands().getTypes());
|
||||||
}
|
}
|
||||||
static void customCreateType(ArrayRef<PDLValue> args, ArrayAttr constantParams,
|
static void customCreateType(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
|
||||||
PatternRewriter &rewriter,
|
|
||||||
PDLResultList &results) {
|
PDLResultList &results) {
|
||||||
results.push_back(rewriter.getF32Type());
|
results.push_back(rewriter.getF32Type());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Custom rewriter invoked from PDL.
|
/// Custom rewriter invoked from PDL.
|
||||||
static void customRewriter(ArrayRef<PDLValue> args, ArrayAttr constantParams,
|
static void customRewriter(ArrayRef<PDLValue> args, PatternRewriter &rewriter,
|
||||||
PatternRewriter &rewriter, PDLResultList &results) {
|
PDLResultList &results) {
|
||||||
Operation *root = args[0].cast<Operation *>();
|
Operation *root = args[0].cast<Operation *>();
|
||||||
OperationState successOpState(root->getLoc(), "test.success");
|
OperationState successOpState(root->getLoc(), "test.success");
|
||||||
successOpState.addOperands(args[1].cast<Value>());
|
successOpState.addOperands(args[1].cast<Value>());
|
||||||
successOpState.addAttribute("constantParams", constantParams);
|
|
||||||
rewriter.createOperation(successOpState);
|
rewriter.createOperation(successOpState);
|
||||||
rewriter.eraseOp(root);
|
rewriter.eraseOp(root);
|
||||||
}
|
}
|
||||||
|
@ -43,7 +43,7 @@ Pattern => erase op<test.op3>;
|
|||||||
|
|
||||||
// Check the generation of native constraints and rewrites.
|
// Check the generation of native constraints and rewrites.
|
||||||
|
|
||||||
// CHECK: static ::mlir::LogicalResult TestCstPDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, ::mlir::ArrayAttr constParams,
|
// CHECK: static ::mlir::LogicalResult TestCstPDLFn(::llvm::ArrayRef<::mlir::PDLValue> values,
|
||||||
// CHECK-SAME: ::mlir::PatternRewriter &rewriter) {
|
// CHECK-SAME: ::mlir::PatternRewriter &rewriter) {
|
||||||
// CHECK: ::mlir::Attribute attr = {};
|
// CHECK: ::mlir::Attribute attr = {};
|
||||||
// CHECK: if (values[0])
|
// CHECK: if (values[0])
|
||||||
@ -69,7 +69,7 @@ Pattern => erase op<test.op3>;
|
|||||||
|
|
||||||
// CHECK-NOT: TestUnusedCst
|
// CHECK-NOT: TestUnusedCst
|
||||||
|
|
||||||
// CHECK: static void TestRewritePDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, ::mlir::ArrayAttr constParams,
|
// CHECK: static void TestRewritePDLFn(::llvm::ArrayRef<::mlir::PDLValue> values,
|
||||||
// CHECK-SAME: ::mlir::PatternRewriter &rewriter, ::mlir::PDLResultList &results) {
|
// CHECK-SAME: ::mlir::PatternRewriter &rewriter, ::mlir::PDLResultList &results) {
|
||||||
// CHECK: ::mlir::Attribute attr = {};
|
// CHECK: ::mlir::Attribute attr = {};
|
||||||
// CHECK: ::mlir::Operation * op = {};
|
// CHECK: ::mlir::Operation * op = {};
|
||||||
|
@ -53,34 +53,6 @@ def test_rewrite_with_args():
|
|||||||
root = OperationOp(args=[input])
|
root = OperationOp(args=[input])
|
||||||
RewriteOp(root, "rewriter", args=[input])
|
RewriteOp(root, "rewriter", args=[input])
|
||||||
|
|
||||||
# CHECK: module {
|
|
||||||
# CHECK: pdl.pattern @rewrite_with_params : benefit(1) {
|
|
||||||
# CHECK: %0 = operation
|
|
||||||
# CHECK: rewrite %0 with "rewriter" ["I am param"]
|
|
||||||
# CHECK: }
|
|
||||||
# CHECK: }
|
|
||||||
@constructAndPrintInModule
|
|
||||||
def test_rewrite_with_params():
|
|
||||||
pattern = PatternOp(1, "rewrite_with_params")
|
|
||||||
with InsertionPoint(pattern.body):
|
|
||||||
op = OperationOp()
|
|
||||||
RewriteOp(op, "rewriter", params=[StringAttr.get("I am param")])
|
|
||||||
|
|
||||||
# CHECK: module {
|
|
||||||
# CHECK: pdl.pattern @rewrite_with_args_and_params : benefit(1) {
|
|
||||||
# CHECK: %0 = operand
|
|
||||||
# CHECK: %1 = operation(%0 : !pdl.value)
|
|
||||||
# CHECK: rewrite %1 with "rewriter" ["I am param"](%0 : !pdl.value)
|
|
||||||
# CHECK: }
|
|
||||||
# CHECK: }
|
|
||||||
@constructAndPrintInModule
|
|
||||||
def test_rewrite_with_args_and_params():
|
|
||||||
pattern = PatternOp(1, "rewrite_with_args_and_params")
|
|
||||||
with InsertionPoint(pattern.body):
|
|
||||||
input = OperandOp()
|
|
||||||
root = OperationOp(args=[input])
|
|
||||||
RewriteOp(root, "rewriter", params=[StringAttr.get("I am param")], args=[input])
|
|
||||||
|
|
||||||
# CHECK: module {
|
# CHECK: module {
|
||||||
# CHECK: pdl.pattern @rewrite_multi_root_optimal : benefit(1) {
|
# CHECK: pdl.pattern @rewrite_multi_root_optimal : benefit(1) {
|
||||||
# CHECK: %0 = operand
|
# CHECK: %0 = operand
|
||||||
@ -92,7 +64,7 @@ def test_rewrite_with_args_and_params():
|
|||||||
# CHECK: %6 = operation(%1 : !pdl.value) -> (%2 : !pdl.type)
|
# CHECK: %6 = operation(%1 : !pdl.value) -> (%2 : !pdl.type)
|
||||||
# CHECK: %7 = result 0 of %6
|
# CHECK: %7 = result 0 of %6
|
||||||
# CHECK: %8 = operation(%4, %7 : !pdl.value, !pdl.value)
|
# CHECK: %8 = operation(%4, %7 : !pdl.value, !pdl.value)
|
||||||
# CHECK: rewrite with "rewriter" ["I am param"](%5, %8 : !pdl.operation, !pdl.operation)
|
# CHECK: rewrite with "rewriter"(%5, %8 : !pdl.operation, !pdl.operation)
|
||||||
# CHECK: }
|
# CHECK: }
|
||||||
# CHECK: }
|
# CHECK: }
|
||||||
@constructAndPrintInModule
|
@constructAndPrintInModule
|
||||||
@ -108,7 +80,7 @@ def test_rewrite_multi_root_optimal():
|
|||||||
op2 = OperationOp(args=[input2], types=[ty])
|
op2 = OperationOp(args=[input2], types=[ty])
|
||||||
val2 = ResultOp(op2, 0)
|
val2 = ResultOp(op2, 0)
|
||||||
root2 = OperationOp(args=[val1, val2])
|
root2 = OperationOp(args=[val1, val2])
|
||||||
RewriteOp(name="rewriter", params=[StringAttr.get("I am param")], args=[root1, root2])
|
RewriteOp(name="rewriter", args=[root1, root2])
|
||||||
|
|
||||||
# CHECK: module {
|
# CHECK: module {
|
||||||
# CHECK: pdl.pattern @rewrite_multi_root_forced : benefit(1) {
|
# CHECK: pdl.pattern @rewrite_multi_root_forced : benefit(1) {
|
||||||
@ -121,7 +93,7 @@ def test_rewrite_multi_root_optimal():
|
|||||||
# CHECK: %6 = operation(%1 : !pdl.value) -> (%2 : !pdl.type)
|
# CHECK: %6 = operation(%1 : !pdl.value) -> (%2 : !pdl.type)
|
||||||
# CHECK: %7 = result 0 of %6
|
# CHECK: %7 = result 0 of %6
|
||||||
# CHECK: %8 = operation(%4, %7 : !pdl.value, !pdl.value)
|
# CHECK: %8 = operation(%4, %7 : !pdl.value, !pdl.value)
|
||||||
# CHECK: rewrite %5 with "rewriter" ["I am param"](%8 : !pdl.operation)
|
# CHECK: rewrite %5 with "rewriter"(%8 : !pdl.operation)
|
||||||
# CHECK: }
|
# CHECK: }
|
||||||
# CHECK: }
|
# CHECK: }
|
||||||
@constructAndPrintInModule
|
@constructAndPrintInModule
|
||||||
@ -137,7 +109,7 @@ def test_rewrite_multi_root_forced():
|
|||||||
op2 = OperationOp(args=[input2], types=[ty])
|
op2 = OperationOp(args=[input2], types=[ty])
|
||||||
val2 = ResultOp(op2, 0)
|
val2 = ResultOp(op2, 0)
|
||||||
root2 = OperationOp(args=[val1, val2])
|
root2 = OperationOp(args=[val1, val2])
|
||||||
RewriteOp(root1, name="rewriter", params=[StringAttr.get("I am param")], args=[root2])
|
RewriteOp(root1, name="rewriter", args=[root2])
|
||||||
|
|
||||||
# CHECK: module {
|
# CHECK: module {
|
||||||
# CHECK: pdl.pattern @rewrite_add_body : benefit(1) {
|
# CHECK: pdl.pattern @rewrite_add_body : benefit(1) {
|
||||||
@ -303,7 +275,7 @@ def test_operation_results():
|
|||||||
# CHECK: module {
|
# CHECK: module {
|
||||||
# CHECK: pdl.pattern : benefit(1) {
|
# CHECK: pdl.pattern : benefit(1) {
|
||||||
# CHECK: %0 = type
|
# CHECK: %0 = type
|
||||||
# CHECK: apply_native_constraint "typeConstraint" [](%0 : !pdl.type)
|
# CHECK: apply_native_constraint "typeConstraint"(%0 : !pdl.type)
|
||||||
# CHECK: %1 = operation -> (%0 : !pdl.type)
|
# CHECK: %1 = operation -> (%0 : !pdl.type)
|
||||||
# CHECK: rewrite %1 with "rewrite"
|
# CHECK: rewrite %1 with "rewrite"
|
||||||
# CHECK: }
|
# CHECK: }
|
||||||
@ -313,6 +285,6 @@ def test_apply_native_constraint():
|
|||||||
pattern = PatternOp(1)
|
pattern = PatternOp(1)
|
||||||
with InsertionPoint(pattern.body):
|
with InsertionPoint(pattern.body):
|
||||||
resultType = TypeOp()
|
resultType = TypeOp()
|
||||||
ApplyNativeConstraintOp("typeConstraint", args=[resultType], params=[])
|
ApplyNativeConstraintOp("typeConstraint", args=[resultType])
|
||||||
root = OperationOp(types=[resultType])
|
root = OperationOp(types=[resultType])
|
||||||
RewriteOp(root, name="rewrite")
|
RewriteOp(root, name="rewrite")
|
||||||
|
Loading…
Reference in New Issue
Block a user