From 9595f3568ade92509d5f1e0d45066e31def762a6 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Sun, 13 Mar 2022 22:09:20 -0700 Subject: [PATCH] [mlir:PDL] Remove the ConstantParams support from native Constraints/Rewrites This support has never really worked well, and is incredibly clunky to use (it effectively creates two argument APIs), and clunky to generate (it isn't clear how we should actually expose this from PDL frontends). Treating these as just attribute arguments is much much cleaner in every aspect of the stack. If we need to optimize lots of constant parameters, it would be better to investigate internal representation optimizations (e.g. batch attribute creation), that do not affect the user (we want a clean external API). Differential Revision: https://reviews.llvm.org/D121569 --- mlir/docs/PDLL.md | 12 ++-- mlir/include/mlir/Dialect/PDL/IR/PDLOps.td | 69 ++++++------------- .../mlir/Dialect/PDLInterp/IR/PDLInterpOps.td | 30 +++----- mlir/include/mlir/IR/PatternMatch.h | 33 ++++----- .../PDLToPDLInterp/PDLToPDLInterp.cpp | 12 ++-- .../lib/Conversion/PDLToPDLInterp/Predicate.h | 23 ++----- .../PDLToPDLInterp/PredicateTree.cpp | 3 +- mlir/lib/Dialect/PDL/IR/PDL.cpp | 4 -- mlir/lib/Rewrite/ByteCode.cpp | 14 ++-- mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp | 2 +- mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp | 16 ++--- mlir/python/mlir/dialects/_pdl_ops_ext.py | 12 +--- .../pdl-to-pdl-interp-matcher.mlir | 8 +-- .../pdl-to-pdl-interp-rewriter.mlir | 8 +-- mlir/test/Dialect/PDL/invalid.mlir | 17 +---- mlir/test/Dialect/PDL/ops.mlir | 19 +---- mlir/test/Rewrite/pdl-bytecode.mlir | 4 +- mlir/test/lib/Rewrite/TestPDLByteCode.cpp | 18 ++--- mlir/test/mlir-pdll/CodeGen/CPP/general.pdll | 4 +- mlir/test/python/dialects/pdl_ops.py | 40 ++--------- 20 files changed, 102 insertions(+), 246 deletions(-) diff --git a/mlir/docs/PDLL.md b/mlir/docs/PDLL.md index 0107e6777339..ab24a680d37b 100644 --- a/mlir/docs/PDLL.md +++ b/mlir/docs/PDLL.md @@ -1007,15 +1007,13 @@ the C++ PDL API. For example, the constraints above may be registered as: ```c++ // TODO: Cleanup when we allow more accessible wrappers around PDL functions. -static LogicalResult hasOneUseImpl(PDLValue pdlValue, ArrayAttr constantParams, - PatternRewriter &rewriter) { +static LogicalResult hasOneUseImpl(PDLValue pdlValue, PatternRewriter &rewriter) { Value value = pdlValue.cast(); return success(value.hasOneUse()); } -static LogicalResult hasSameElementTypeImpl( - ArrayRef pdlValues, ArrayAttr constantParams, - PatternRewriter &rewriter) { +static LogicalResult hasSameElementTypeImpl(ArrayRef pdlValues, + PatternRewriter &rewriter) { Value value1 = pdlValues[0].cast(); Value value2 = pdlValues[1].cast(); @@ -1310,8 +1308,8 @@ the C++ PDL API. For example, the rewrite above may be registered as: ```c++ // TODO: Cleanup when we allow more accessible wrappers around PDL functions. -static void buildOpImpl(ArrayRef args, ArrayAttr constantParams, - PatternRewriter &rewriter, PDLResultList &results) { +static void buildOpImpl(ArrayRef args, PatternRewriter &rewriter, + PDLResultList &results) { Value value = args[0].cast(); // insert special rewrite logic here. diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td index eec878ba57b5..7f0253e59a32 100644 --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -35,33 +35,18 @@ def PDL_ApplyNativeConstraintOp let description = [{ `pdl.apply_native_constraint` operations apply a native C++ constraint, that has been registered externally with the consumer of PDL, to a given set of - entities. The constraint is permitted to accept any number of constant - valued parameters. + entities. Example: ```mlir - // Apply `myConstraint` to the entities defined by `input`, `attr`, and - // `op`. `42`, `"abc"`, and `i32` are constant parameters passed to the - // constraint. - pdl.apply_native_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) + // Apply `myConstraint` to the entities defined by `input`, `attr`, and `op`. + pdl.apply_native_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) ``` }]; - let arguments = (ins StrAttr:$name, - Variadic:$args, - OptionalAttr:$constParams); - let assemblyFormat = [{ - $name ($constParams^)? `(` $args `:` type($args) `)` attr-dict - }]; - - let builders = [ - OpBuilder<(ins "StringRef":$name, CArg<"ValueRange", "{}">:$args, - CArg<"ArrayRef", "{}">:$params), [{ - build($_builder, $_state, $_builder.getStringAttr(name), args, - params.empty() ? ArrayAttr() : $_builder.getArrayAttr(params)); - }]>, - ]; + let arguments = (ins StrAttr:$name, Variadic:$args); + let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict"; let hasVerifier = 1; } @@ -76,26 +61,22 @@ def PDL_ApplyNativeRewriteOp `pdl.apply_native_rewrite` operations apply a native C++ function, that has been registered externally with the consumer of PDL, to perform a rewrite and optionally return a number of values. The native function may accept any - number of arguments and constant attribute parameters. This operation is - used within a pdl.rewrite region to enable the interleaving of native - rewrite methods with other pdl constructs. + number of arguments. This operation is used within a pdl.rewrite region to enable + the interleaving of native rewrite methods with other pdl constructs. Example: ```mlir // Apply a native rewrite method that returns an attribute. - %ret = pdl.apply_native_rewrite "myNativeFunc"[42, "gt"](%arg0, %arg1) : !pdl.attribute + %ret = pdl.apply_native_rewrite "myNativeFunc"(%arg0, %arg1) : !pdl.attribute ``` ```c++ // The native rewrite as defined in C++: - static void myNativeFunc(ArrayRef args, ArrayAttr constantParams, - PatternRewriter &rewriter, + static void myNativeFunc(ArrayRef args, PatternRewriter &rewriter, PDLResultList &results) { Value arg0 = args[0].cast(); Value arg1 = args[1].cast(); - IntegerAttr param0 = constantParams[0].cast(); - StringAttr param1 = constantParams[1].cast(); // Just push back the first param attribute. results.push_back(param0); @@ -107,13 +88,10 @@ def PDL_ApplyNativeRewriteOp ``` }]; - let arguments = (ins StrAttr:$name, - Variadic:$args, - OptionalAttr:$constParams); + let arguments = (ins StrAttr:$name, Variadic:$args); let results = (outs Variadic:$results); let assemblyFormat = [{ - $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? - (`:` type($results)^)? attr-dict + $name (`(` $args^ `:` type($args) `)`)? (`:` type($results)^)? attr-dict }]; let hasVerifier = 1; } @@ -588,16 +566,15 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [ rewrite is specified either via a string name (`name`) to a native rewrite function, or via the region body. The rewrite region, if specified, must contain a single block. If the rewrite is external it functions - similarly to `pdl.apply_native_rewrite`, and takes a set of constant - parameters and a set of additional positional values defined within the - matcher as arguments. If the rewrite is external, the root operation is - passed to the native function as the leading arguments. The root operation, - if provided, specifies the starting point in the pattern for the subgraph - isomorphism search. Pattern matching will proceed from this node downward - (towards the defining operation) or upward (towards the users) until all - the operations in the pattern have been matched. If the root is omitted, - the pdl_interp lowering will automatically select the best root of the - pdl.rewrite among all the operations in the pattern. + similarly to `pdl.apply_native_rewrite`, and takes a set of additional + positional values defined within the matcher as arguments. If the rewrite is + external, the root operation is passed to the native function as the leading + arguments. The root operation, if provided, specifies the starting point in + the pattern for the subgraph isomorphism search. Pattern matching will proceed + from this node downward (towards the defining operation) or upward + (towards the users) until all the operations in the pattern have been matched. + If the root is omitted, the pdl_interp lowering will automatically select + the best root of the pdl.rewrite among all the operations in the pattern. Example: @@ -623,12 +600,10 @@ def PDL_RewriteOp : PDL_Op<"rewrite", [ let arguments = (ins Optional:$root, OptionalAttr:$name, - Variadic:$externalArgs, - OptionalAttr:$externalConstParams); + Variadic:$externalArgs); let regions = (region AnyRegion:$body); let assemblyFormat = [{ - ($root^)? (`with` $name^ ($externalConstParams^)? - (`(` $externalArgs^ `:` type($externalArgs) `)`)?)? + ($root^)? (`with` $name^ (`(` $externalArgs^ `:` type($externalArgs) `)`)?)? ($body^)? attr-dict-with-keyword }]; diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td index a8dbbc082c3c..fbc73c070872 100644 --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -89,25 +89,21 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> { let description = [{ `pdl_interp.apply_constraint` operations apply a generic constraint, that has been registered with the interpreter, with a given set of positional - values. The constraint may have any number of constant parameters. On - success, this operation branches to the true destination, otherwise the - false destination is taken. + values. On success, this operation branches to the true destination, + otherwise the false destination is taken. Example: ```mlir // Apply `myConstraint` to the entities defined by `input`, `attr`, and // `op`. - pdl_interp.apply_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) -> ^matchDest, ^failureDest + pdl_interp.apply_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) -> ^matchDest, ^failureDest ``` }]; - let arguments = (ins StrAttr:$name, - Variadic:$args, - OptionalAttr:$constParams); + let arguments = (ins StrAttr:$name, Variadic:$args); let assemblyFormat = [{ - $name ($constParams^)? `(` $args `:` type($args) `)` attr-dict `->` - successors + $name `(` $args `:` type($args) `)` attr-dict `->` successors }]; } @@ -120,9 +116,8 @@ def PDLInterp_ApplyRewriteOp : PDLInterp_Op<"apply_rewrite"> { let description = [{ `pdl_interp.apply_rewrite` operations invoke an external rewriter that has been registered with the interpreter to perform the rewrite after a - successful match. The rewrite is passed a set of positional arguments, - and a set of constant parameters. The rewrite function may return any - number of results. + successful match. The rewrite is passed a set of positional arguments. The + rewrite function may return any number of results. Example: @@ -136,19 +131,12 @@ def PDLInterp_ApplyRewriteOp : PDLInterp_Op<"apply_rewrite"> { // Rewriter operating on the root operation along with additional arguments // from the matcher. pdl_interp.apply_rewrite "rewriter"(%root : !pdl.operation, %value : !pdl.value) - - // Rewriter operating on the root operation along with additional arguments - // and constant parameters. - pdl_interp.apply_rewrite "rewriter"[42](%root : !pdl.operation, %value : !pdl.value) ``` }]; - let arguments = (ins StrAttr:$name, - Variadic:$args, - OptionalAttr:$constParams); + let arguments = (ins StrAttr:$name, Variadic:$args); let results = (outs Variadic:$results); let assemblyFormat = [{ - $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? - (`:` type($results)^)? attr-dict + $name (`(` $args^ `:` type($args) `)`)? (`:` type($results)^)? attr-dict }]; } diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 8fd9fa0caaf3..11f85ee38bef 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -584,24 +584,16 @@ protected: // PDLPatternModule /// A generic PDL pattern constraint function. This function applies a -/// constraint to a given set of opaque PDLValue entities. The second parameter -/// is a set of constant value parameters specified in Attribute form. Returns -/// success if the constraint successfully held, failure otherwise. -using PDLConstraintFunction = std::function, ArrayAttr, PatternRewriter &)>; -/// A native PDL rewrite function. This function performs a rewrite on the -/// given set of values and constant parameters. Any results from this rewrite -/// that should be passed back to PDL should be added to the provided result -/// list. This method is only invoked when the corresponding match was -/// successful. -using PDLRewriteFunction = std::function, ArrayAttr, PatternRewriter &, PDLResultList &)>; -/// A generic PDL pattern constraint function. This function applies a -/// constraint to a given opaque PDLValue entity. The second parameter is a set -/// of constant value parameters specified in Attribute form. Returns success if +/// constraint to a given set of opaque PDLValue entities. Returns success if /// the constraint successfully held, failure otherwise. -using PDLSingleEntityConstraintFunction = - std::function; +using PDLConstraintFunction = + std::function, PatternRewriter &)>; +/// A native PDL rewrite function. This function performs a rewrite on the +/// given set of values. Any results from this rewrite that should be passed +/// back to PDL should be added to the provided result list. This method is only +/// invoked when the corresponding match was successful. +using PDLRewriteFunction = + std::function, PatternRewriter &, PDLResultList &)>; /// This class contains all of the necessary data for a set of PDL patterns, or /// pattern rewrites specified in the form of the PDL dialect. This PDL module @@ -630,15 +622,14 @@ public: /// Register a single entity constraint function. template std::enable_if_t, - ArrayAttr, PatternRewriter &>::value> + PatternRewriter &>::value> registerConstraintFunction(StringRef name, SingleEntityFn &&constraintFn) { registerConstraintFunction( name, [constraintFn = std::forward(constraintFn)]( - ArrayRef values, ArrayAttr constantParams, - PatternRewriter &rewriter) { + ArrayRef values, PatternRewriter &rewriter) { assert(values.size() == 1 && "expected values to have a single entity"); - return constraintFn(values[0], constantParams, rewriter); + return constraintFn(values[0], rewriter); }); } diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp index 9529c59f5a92..6057a2193946 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -431,9 +431,8 @@ void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, } case Predicates::ConstraintQuestion: { auto *cstQuestion = cast(question); - builder.create( - loc, cstQuestion->getName(), args, cstQuestion->getParams(), success, - failure); + builder.create(loc, cstQuestion->getName(), + args, success, failure); break; } default: @@ -644,8 +643,7 @@ SymbolRefAttr PatternLowering::generateRewriter( auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue); args.append(mappedArgs.begin(), mappedArgs.end()); builder.create( - rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args, - rewriter.externalConstParamsAttr()); + rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args); } else { // Otherwise this is a dag rewriter defined using PDL operations. for (Operation &rewriteOp : *rewriter.getBody()) { @@ -678,8 +676,8 @@ void PatternLowering::generateRewriter( arguments.push_back(mapRewriteValue(argument)); auto interpOp = builder.create( rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(), - arguments, rewriteOp.constParamsAttr()); - for (auto it : llvm::zip(rewriteOp.results(), interpOp.getResults())) + arguments); + for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults())) rewriteValues[std::get<0>(it)] = std::get<1>(it); } diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h index 1d723996f8c3..81a11529b97c 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h @@ -445,10 +445,9 @@ struct AttributeQuestion /// Apply a parameterized constraint to multiple position values. struct ConstraintQuestion - : public PredicateBase< - ConstraintQuestion, Qualifier, - std::tuple, Attribute>, - Predicates::ConstraintQuestion> { + : public PredicateBase>, + Predicates::ConstraintQuestion> { using Base::Base; /// Return the name of the constraint. @@ -457,17 +456,11 @@ struct ConstraintQuestion /// Return the arguments of the constraint. ArrayRef getArgs() const { return std::get<1>(key); } - /// Return the constant parameters of the constraint. - ArrayAttr getParams() const { - return std::get<2>(key).dyn_cast_or_null(); - } - /// Construct an instance with the given storage allocator. static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc, KeyTy key) { return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)), - alloc.copyInto(std::get<1>(key)), - std::get<2>(key)}); + alloc.copyInto(std::get<1>(key))}); } }; @@ -667,11 +660,9 @@ public: } /// Create a predicate that applies a generic constraint. - Predicate getConstraint(StringRef name, ArrayRef pos, - Attribute params) { - return { - ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, params)), - TrueAnswer::get(uniquer)}; + Predicate getConstraint(StringRef name, ArrayRef pos) { + return {ConstraintQuestion::get(uniquer, std::make_tuple(name, pos)), + TrueAnswer::get(uniquer)}; } /// Create a predicate comparing a value with null. diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp index 424c91f2aca1..b77ea40b8725 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -263,7 +263,6 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, PredicateBuilder &builder, DenseMap &inputs) { OperandRange arguments = op.args(); - ArrayAttr parameters = op.constParamsAttr(); std::vector allPositions; allPositions.reserve(arguments.size()); @@ -274,7 +273,7 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, Position *pos = *std::max_element(allPositions.begin(), allPositions.end(), comparePosDepth); PredicateBuilder::Predicate pred = - builder.getConstraint(op.name(), allPositions, parameters); + builder.getConstraint(op.name(), allPositions); predList.emplace_back(pos, pred); } diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp index c11f2f6ca9bc..3a0280c797ab 100644 --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -425,10 +425,6 @@ LogicalResult RewriteOp::verifyRegions() { return emitOpError() << "expected no external arguments when the " "rewrite is specified inline"; } - if (externalConstParams()) { - return emitOpError() << "expected no external constant parameters when " - "the rewrite is specified inline"; - } return success(); } diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index ba5519b39651..fb929a8494bd 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -757,8 +757,7 @@ void Generator::generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer) { assert(constraintToMemIndex.count(op.getName()) && "expected index for constraint function"); - writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()], - op.getConstParamsAttr()); + writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]); writer.appendPDLValueList(op.getArgs()); writer.append(op.getSuccessors()); } @@ -766,8 +765,7 @@ void Generator::generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer) { assert(externalRewriterToMemIndex.count(op.getName()) && "expected index for rewrite function"); - writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()], - op.getConstParamsAttr()); + writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]); writer.appendPDLValueList(op.getArgs()); ResultRange results = op.getResults(); @@ -1333,37 +1331,33 @@ public: void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; - ArrayAttr constParams = read(); SmallVector args; readList(args); LLVM_DEBUG({ llvm::dbgs() << " * Arguments: "; llvm::interleaveComma(args, llvm::dbgs()); - llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; }); // Invoke the constraint and jump to the proper destination. - selectJump(succeeded(constraintFn(args, constParams, rewriter))); + selectJump(succeeded(constraintFn(args, rewriter))); } void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; - ArrayAttr constParams = read(); SmallVector args; readList(args); LLVM_DEBUG({ llvm::dbgs() << " * Arguments: "; llvm::interleaveComma(args, llvm::dbgs()); - llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; }); // Execute the rewrite function. ByteCodeField numResults = read(); ByteCodeRewriteResultList results(numResults); - rewriteFn(args, constParams, rewriter, results); + rewriteFn(args, rewriter, results); assert(results.getResults().size() == numResults && "native PDL rewrite function returned unexpected number of results"); diff --git a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp index d5045ca07cb1..14cdf44628a7 100644 --- a/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp +++ b/mlir/lib/Tools/PDLL/CodeGen/CPPGen.cpp @@ -190,7 +190,7 @@ void CodeGen::generateConstraintOrRewrite(StringRef name, bool isConstraint, // what we need as a frontend. os << "static " << (isConstraint ? "::mlir::LogicalResult " : "void ") << name << "PDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, " - "::mlir::ArrayAttr constParams, ::mlir::PatternRewriter &rewriter" + "::mlir::PatternRewriter &rewriter" << (isConstraint ? "" : ", ::mlir::PDLResultList &results") << ") {\n"; const char *argumentInitStr = R"( diff --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp index b4b2bca2071b..334093e576d5 100644 --- a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp +++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp @@ -200,9 +200,9 @@ void CodeGen::genImpl(const ast::CompoundStmt *stmt) { static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr, Location loc) { if (isa(builder.getInsertionBlock()->getParentOp())) { - pdl::RewriteOp rewrite = builder.create( - loc, rootExpr, /*name=*/StringAttr(), - /*externalArgs=*/ValueRange(), /*externalConstParams=*/ArrayAttr()); + pdl::RewriteOp rewrite = + builder.create(loc, rootExpr, /*name=*/StringAttr(), + /*externalArgs=*/ValueRange()); builder.createBlock(&rewrite.body()); } } @@ -564,14 +564,8 @@ SmallVector CodeGen::genConstraintOrRewriteCall(const T *decl, } else { resultTypes.push_back(genType(declResultType)); } - - // FIXME: We currently do not have a modeling for the "constant params" - // support PDL provides. We should either figure out a modeling for this, or - // refactor the support within PDL to be something a bit more reasonable for - // what we need as a frontend. - Operation *pdlOp = builder.create(loc, resultTypes, - decl->getName().getName(), inputs, - /*params=*/ArrayAttr()); + Operation *pdlOp = builder.create( + loc, resultTypes, decl->getName().getName(), inputs); return pdlOp->getResults(); } diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py index 364db53854f8..fb5b519c7c02 100644 --- a/mlir/python/mlir/dialects/_pdl_ops_ext.py +++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py @@ -59,14 +59,12 @@ class ApplyNativeConstraintOp: def __init__(self, name: Union[str, StringAttr], args: Sequence[Union[OpView, Operation, Value]] = [], - params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, *, loc=None, ip=None): name = _get_str_attr(name) args = _get_values(args) - params = params if params is None else _get_array_attr(params) - super().__init__(name, args, params, loc=loc, ip=ip) + super().__init__(name, args, loc=loc, ip=ip) class ApplyNativeRewriteOp: @@ -76,14 +74,12 @@ class ApplyNativeRewriteOp: results: Sequence[Type], name: Union[str, StringAttr], args: Sequence[Union[OpView, Operation, Value]] = [], - params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, *, loc=None, ip=None): name = _get_str_attr(name) args = _get_values(args) - params = params if params is None else _get_array_attr(params) - super().__init__(results, name, args, params, loc=loc, ip=ip) + super().__init__(results, name, args, loc=loc, ip=ip) class AttributeOp: @@ -236,15 +232,13 @@ class RewriteOp: root: Optional[Union[OpView, Operation, Value]] = None, name: Optional[Union[StringAttr, str]] = None, args: Sequence[Union[OpView, Operation, Value]] = [], - params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, *, loc=None, ip=None): root = root if root is None else _get_value(root) name = name if name is None else _get_str_attr(name) args = _get_values(args) - params = params if params is None else _get_array_attr(params) - super().__init__(root, name, args, params, loc=loc, ip=ip) + super().__init__(root, name, args, loc=loc, ip=ip) def add_body(self): """Add body (block) to the rewrite.""" diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir index d9a8706471fe..457042767130 100644 --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -64,7 +64,7 @@ module @constraints { // CHECK-DAG: %[[INPUT:.*]] = pdl_interp.get_operand 0 of %[[ROOT]] // CHECK-DAG: %[[INPUT1:.*]] = pdl_interp.get_operand 1 of %[[ROOT]] // CHECK-DAG: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]] - // CHECK: pdl_interp.apply_constraint "multi_constraint" [true](%[[INPUT]], %[[INPUT1]], %[[RESULT]] + // CHECK: pdl_interp.apply_constraint "multi_constraint"(%[[INPUT]], %[[INPUT1]], %[[RESULT]] pdl.pattern : benefit(1) { %input0 = operand @@ -72,7 +72,7 @@ module @constraints { %root = operation(%input0, %input1 : !pdl.value, !pdl.value) %result0 = result 0 of %root - pdl.apply_native_constraint "multi_constraint"[true](%input0, %input1, %result0 : !pdl.value, !pdl.value, !pdl.value) + pdl.apply_native_constraint "multi_constraint"(%input0, %input1, %result0 : !pdl.value, !pdl.value, !pdl.value) rewrite %root with "rewriter" } } @@ -393,11 +393,11 @@ module @predicate_ordering { // CHECK: %[[RESULT:.*]] = pdl_interp.get_result 0 of %[[ROOT]] // CHECK-NEXT: pdl_interp.is_not_null %[[RESULT]] // CHECK: %[[RESULT_TYPE:.*]] = pdl_interp.get_value_type of %[[RESULT]] - // CHECK: pdl_interp.apply_constraint "typeConstraint" [](%[[RESULT_TYPE]] + // CHECK: pdl_interp.apply_constraint "typeConstraint"(%[[RESULT_TYPE]] pdl.pattern : benefit(1) { %resultType = type - pdl.apply_native_constraint "typeConstraint"[](%resultType : !pdl.type) + pdl.apply_native_constraint "typeConstraint"(%resultType : !pdl.type) %root = operation -> (%resultType : !pdl.type) rewrite %root with "rewriter" } diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir index 8ca771f87f65..4d6d524e89fc 100644 --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir @@ -6,11 +6,11 @@ module @external { // CHECK: module @rewriters // CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation, %[[INPUT:.*]]: !pdl.value) - // CHECK: pdl_interp.apply_rewrite "rewriter" [true](%[[ROOT]], %[[INPUT]] : !pdl.operation, !pdl.value) + // CHECK: pdl_interp.apply_rewrite "rewriter"(%[[ROOT]], %[[INPUT]] : !pdl.operation, !pdl.value) pdl.pattern : benefit(1) { %input = operand %root = operation "foo.op"(%input : !pdl.value) - rewrite %root with "rewriter"[true](%input : !pdl.value) + rewrite %root with "rewriter"(%input : !pdl.value) } } @@ -191,13 +191,13 @@ module @replace_with_no_results { module @apply_native_rewrite { // CHECK: module @rewriters // CHECK: func @pdl_generated_rewriter(%[[ROOT:.*]]: !pdl.operation) - // CHECK: %[[TYPE:.*]] = pdl_interp.apply_rewrite "functor" [true](%[[ROOT]] : !pdl.operation) : !pdl.type + // CHECK: %[[TYPE:.*]] = pdl_interp.apply_rewrite "functor"(%[[ROOT]] : !pdl.operation) : !pdl.type // CHECK: pdl_interp.create_operation "foo.op" -> (%[[TYPE]] : !pdl.type) pdl.pattern : benefit(1) { %type = type %root = operation "foo.op" -> (%type : !pdl.type) rewrite %root { - %newType = apply_native_rewrite "functor"[true](%root : !pdl.operation) : !pdl.type + %newType = apply_native_rewrite "functor"(%root : !pdl.operation) : !pdl.type %newOp = operation "foo.op" -> (%newType : !pdl.type) } } diff --git a/mlir/test/Dialect/PDL/invalid.mlir b/mlir/test/Dialect/PDL/invalid.mlir index f8f641e8d298..a39d1c5e80f8 100644 --- a/mlir/test/Dialect/PDL/invalid.mlir +++ b/mlir/test/Dialect/PDL/invalid.mlir @@ -8,7 +8,7 @@ pdl.pattern : benefit(1) { %op = operation "foo.op" // expected-error@below {{expected at least one argument}} - "pdl.apply_native_constraint"() {name = "foo", params = []} : () -> () + "pdl.apply_native_constraint"() {name = "foo"} : () -> () rewrite %op with "rewriter" } @@ -22,7 +22,7 @@ pdl.pattern : benefit(1) { %op = operation "foo.op" rewrite %op { // expected-error@below {{expected at least one argument}} - "pdl.apply_native_rewrite"() {name = "foo", params = []} : () -> () + "pdl.apply_native_rewrite"() {name = "foo"} : () -> () } } @@ -264,19 +264,6 @@ pdl.pattern : benefit(1) { // ----- -pdl.pattern : benefit(1) { - %op = operation "foo.op" - - // expected-error@below {{expected no external constant parameters when the rewrite is specified inline}} - "pdl.rewrite"(%op) ({ - ^bb1: - }) { - operand_segment_sizes = dense<[1,0]> : vector<2xi32>, - externalConstParams = []} : (!pdl.operation) -> () -} - -// ----- - pdl.pattern : benefit(1) { %op = operation "foo.op" diff --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir index 1e2261a3c2a4..472dd250feaa 100644 --- a/mlir/test/Dialect/PDL/ops.mlir +++ b/mlir/test/Dialect/PDL/ops.mlir @@ -27,21 +27,6 @@ pdl.pattern @rewrite_with_args : benefit(1) { // ----- -pdl.pattern @rewrite_with_params : benefit(1) { - %root = operation - rewrite %root with "rewriter"["I am param"] -} - -// ----- - -pdl.pattern @rewrite_with_args_and_params : benefit(1) { - %input = operand - %root = operation(%input : !pdl.value) - rewrite %root with "rewriter"["I am param"](%input : !pdl.value) -} - -// ----- - pdl.pattern @rewrite_multi_root_optimal : benefit(2) { %input1 = operand %input2 = operand @@ -52,7 +37,7 @@ pdl.pattern @rewrite_multi_root_optimal : benefit(2) { %op2 = operation(%input2 : !pdl.value) -> (%type : !pdl.type) %val2 = result 0 of %op2 %root2 = operation(%val1, %val2 : !pdl.value, !pdl.value) - rewrite with "rewriter"["I am param"](%root1, %root2 : !pdl.operation, !pdl.operation) + rewrite with "rewriter"(%root1, %root2 : !pdl.operation, !pdl.operation) } // ----- @@ -67,7 +52,7 @@ pdl.pattern @rewrite_multi_root_forced : benefit(2) { %op2 = operation(%input2 : !pdl.value) -> (%type : !pdl.type) %val2 = result 0 of %op2 %root2 = operation(%val1, %val2 : !pdl.value, !pdl.value) - rewrite %root1 with "rewriter"["I am param"](%root2 : !pdl.operation) + rewrite %root1 with "rewriter"(%root2 : !pdl.operation) } // ----- diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir index 88024332bbba..d06c500241b0 100644 --- a/mlir/test/Rewrite/pdl-bytecode.mlir +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ -90,7 +90,7 @@ module @patterns { module @rewriters { pdl_interp.func @success(%root : !pdl.operation) { %operand = pdl_interp.get_operand 0 of %root - pdl_interp.apply_rewrite "rewriter"[42](%root, %operand : !pdl.operation, !pdl.value) + pdl_interp.apply_rewrite "rewriter"(%root, %operand : !pdl.operation, !pdl.value) pdl_interp.finalize } } @@ -99,7 +99,7 @@ module @patterns { // CHECK-LABEL: test.apply_rewrite_1 // CHECK: %[[INPUT:.*]] = "test.op_input" // CHECK-NOT: "test.op" -// CHECK: "test.success"(%[[INPUT]]) {constantParams = [42]} +// CHECK: "test.success"(%[[INPUT]]) module @ir attributes { test.apply_rewrite_1 } { %input = "test.op_input"() : () -> i32 "test.op"(%input) : (i32) -> () diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp index 748e54822718..ef399a0d7007 100644 --- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp +++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp @@ -15,19 +15,16 @@ using namespace mlir; /// Custom constraint invoked from PDL. static LogicalResult customSingleEntityConstraint(PDLValue value, - ArrayAttr constantParams, PatternRewriter &rewriter) { Operation *rootOp = value.cast(); return success(rootOp->getName().getStringRef() == "test.op"); } static LogicalResult customMultiEntityConstraint(ArrayRef values, - ArrayAttr constantParams, PatternRewriter &rewriter) { - return customSingleEntityConstraint(values[1], constantParams, rewriter); + return customSingleEntityConstraint(values[1], rewriter); } static LogicalResult customMultiEntityVariadicConstraint(ArrayRef values, - ArrayAttr constantParams, PatternRewriter &rewriter) { if (llvm::any_of(values, [](const PDLValue &value) { return !value; })) return failure(); @@ -39,32 +36,29 @@ customMultiEntityVariadicConstraint(ArrayRef values, } // Custom creator invoked from PDL. -static void customCreate(ArrayRef args, ArrayAttr constantParams, - PatternRewriter &rewriter, PDLResultList &results) { +static void customCreate(ArrayRef args, PatternRewriter &rewriter, + PDLResultList &results) { results.push_back(rewriter.createOperation( OperationState(args[0].cast()->getLoc(), "test.success"))); } static void customVariadicResultCreate(ArrayRef args, - ArrayAttr constantParams, PatternRewriter &rewriter, PDLResultList &results) { Operation *root = args[0].cast(); results.push_back(root->getOperands()); results.push_back(root->getOperands().getTypes()); } -static void customCreateType(ArrayRef args, ArrayAttr constantParams, - PatternRewriter &rewriter, +static void customCreateType(ArrayRef args, PatternRewriter &rewriter, PDLResultList &results) { results.push_back(rewriter.getF32Type()); } /// Custom rewriter invoked from PDL. -static void customRewriter(ArrayRef args, ArrayAttr constantParams, - PatternRewriter &rewriter, PDLResultList &results) { +static void customRewriter(ArrayRef args, PatternRewriter &rewriter, + PDLResultList &results) { Operation *root = args[0].cast(); OperationState successOpState(root->getLoc(), "test.success"); successOpState.addOperands(args[1].cast()); - successOpState.addAttribute("constantParams", constantParams); rewriter.createOperation(successOpState); rewriter.eraseOp(root); } diff --git a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll index 5b94c66b4350..9f0ea1386322 100644 --- a/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll +++ b/mlir/test/mlir-pdll/CodeGen/CPP/general.pdll @@ -43,7 +43,7 @@ Pattern => erase op; // Check the generation of native constraints and rewrites. -// CHECK: static ::mlir::LogicalResult TestCstPDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, ::mlir::ArrayAttr constParams, +// CHECK: static ::mlir::LogicalResult TestCstPDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, // CHECK-SAME: ::mlir::PatternRewriter &rewriter) { // CHECK: ::mlir::Attribute attr = {}; // CHECK: if (values[0]) @@ -69,7 +69,7 @@ Pattern => erase op; // CHECK-NOT: TestUnusedCst -// CHECK: static void TestRewritePDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, ::mlir::ArrayAttr constParams, +// CHECK: static void TestRewritePDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, // CHECK-SAME: ::mlir::PatternRewriter &rewriter, ::mlir::PDLResultList &results) { // CHECK: ::mlir::Attribute attr = {}; // CHECK: ::mlir::Operation * op = {}; diff --git a/mlir/test/python/dialects/pdl_ops.py b/mlir/test/python/dialects/pdl_ops.py index 2388ccadab2e..b575f19bb6c4 100644 --- a/mlir/test/python/dialects/pdl_ops.py +++ b/mlir/test/python/dialects/pdl_ops.py @@ -53,34 +53,6 @@ def test_rewrite_with_args(): root = OperationOp(args=[input]) RewriteOp(root, "rewriter", args=[input]) -# CHECK: module { -# CHECK: pdl.pattern @rewrite_with_params : benefit(1) { -# CHECK: %0 = operation -# CHECK: rewrite %0 with "rewriter" ["I am param"] -# CHECK: } -# CHECK: } -@constructAndPrintInModule -def test_rewrite_with_params(): - pattern = PatternOp(1, "rewrite_with_params") - with InsertionPoint(pattern.body): - op = OperationOp() - RewriteOp(op, "rewriter", params=[StringAttr.get("I am param")]) - -# CHECK: module { -# CHECK: pdl.pattern @rewrite_with_args_and_params : benefit(1) { -# CHECK: %0 = operand -# CHECK: %1 = operation(%0 : !pdl.value) -# CHECK: rewrite %1 with "rewriter" ["I am param"](%0 : !pdl.value) -# CHECK: } -# CHECK: } -@constructAndPrintInModule -def test_rewrite_with_args_and_params(): - pattern = PatternOp(1, "rewrite_with_args_and_params") - with InsertionPoint(pattern.body): - input = OperandOp() - root = OperationOp(args=[input]) - RewriteOp(root, "rewriter", params=[StringAttr.get("I am param")], args=[input]) - # CHECK: module { # CHECK: pdl.pattern @rewrite_multi_root_optimal : benefit(1) { # CHECK: %0 = operand @@ -92,7 +64,7 @@ def test_rewrite_with_args_and_params(): # CHECK: %6 = operation(%1 : !pdl.value) -> (%2 : !pdl.type) # CHECK: %7 = result 0 of %6 # CHECK: %8 = operation(%4, %7 : !pdl.value, !pdl.value) -# CHECK: rewrite with "rewriter" ["I am param"](%5, %8 : !pdl.operation, !pdl.operation) +# CHECK: rewrite with "rewriter"(%5, %8 : !pdl.operation, !pdl.operation) # CHECK: } # CHECK: } @constructAndPrintInModule @@ -108,7 +80,7 @@ def test_rewrite_multi_root_optimal(): op2 = OperationOp(args=[input2], types=[ty]) val2 = ResultOp(op2, 0) root2 = OperationOp(args=[val1, val2]) - RewriteOp(name="rewriter", params=[StringAttr.get("I am param")], args=[root1, root2]) + RewriteOp(name="rewriter", args=[root1, root2]) # CHECK: module { # CHECK: pdl.pattern @rewrite_multi_root_forced : benefit(1) { @@ -121,7 +93,7 @@ def test_rewrite_multi_root_optimal(): # CHECK: %6 = operation(%1 : !pdl.value) -> (%2 : !pdl.type) # CHECK: %7 = result 0 of %6 # CHECK: %8 = operation(%4, %7 : !pdl.value, !pdl.value) -# CHECK: rewrite %5 with "rewriter" ["I am param"](%8 : !pdl.operation) +# CHECK: rewrite %5 with "rewriter"(%8 : !pdl.operation) # CHECK: } # CHECK: } @constructAndPrintInModule @@ -137,7 +109,7 @@ def test_rewrite_multi_root_forced(): op2 = OperationOp(args=[input2], types=[ty]) val2 = ResultOp(op2, 0) root2 = OperationOp(args=[val1, val2]) - RewriteOp(root1, name="rewriter", params=[StringAttr.get("I am param")], args=[root2]) + RewriteOp(root1, name="rewriter", args=[root2]) # CHECK: module { # CHECK: pdl.pattern @rewrite_add_body : benefit(1) { @@ -303,7 +275,7 @@ def test_operation_results(): # CHECK: module { # CHECK: pdl.pattern : benefit(1) { # CHECK: %0 = type -# CHECK: apply_native_constraint "typeConstraint" [](%0 : !pdl.type) +# CHECK: apply_native_constraint "typeConstraint"(%0 : !pdl.type) # CHECK: %1 = operation -> (%0 : !pdl.type) # CHECK: rewrite %1 with "rewrite" # CHECK: } @@ -313,6 +285,6 @@ def test_apply_native_constraint(): pattern = PatternOp(1) with InsertionPoint(pattern.body): resultType = TypeOp() - ApplyNativeConstraintOp("typeConstraint", args=[resultType], params=[]) + ApplyNativeConstraintOp("typeConstraint", args=[resultType]) root = OperationOp(types=[resultType]) RewriteOp(root, name="rewrite")