[mlir][Pattern] Refactor the Pattern class into a "metadata only" class

The Pattern class was originally intended to be used for solely matching operations, but that use never materialized. All of the pattern infrastructure uses RewritePattern, and the infrastructure for pure matching(Matchers.h) is implemented inline. This means that this class isn't a useful abstraction at the moment, so this revision refactors it to solely encapsulate the "metadata" of a pattern. The metadata includes the various state describing a pattern; benefit, root operation, etc. The API on PatternApplicator is updated to now operate on `Pattern`s as nothing special from `RewritePattern` is necessary.

This refactoring is also necessary for the upcoming use of PDL patterns alongside C++ rewrite patterns.

Differential Revision: https://reviews.llvm.org/D86258
This commit is contained in:
River Riddle 2020-10-26 17:23:41 -07:00
parent 8a1ca2cd34
commit b99bd77162
6 changed files with 160 additions and 158 deletions

View File

@ -174,10 +174,10 @@ Each driver is responsible for defining its own operation visitation order as
well as pattern cost model, but the final application is performed via a
`PatternApplicator` class. This class takes as input the
`OwningRewritePatternList` and transforms the patterns based upon a provided
cost model. This cost model computes a final benefit for a given rewrite
pattern, using whatever driver specific information necessary. After a cost
model has been computed, the driver may begin to match patterns against
operations using `PatternApplicator::matchAndRewrite`.
cost model. This cost model computes a final benefit for a given pattern, using
whatever driver specific information necessary. After a cost model has been
computed, the driver may begin to match patterns against operations using
`PatternApplicator::matchAndRewrite`.
An example is shown below:
@ -209,7 +209,7 @@ void applyMyPatternDriver(Operation *op,
// Create the applicator and apply our cost model.
PatternApplicator applicator(patterns);
applicator.applyCostModel([](const RewritePattern &pattern) {
applicator.applyCostModel([](const Pattern &pattern) {
// Apply a default cost model.
// Note: This is just for demonstration, if the default cost model is truly
// desired `applicator.applyDefaultCostModel()` should be used

View File

@ -58,15 +58,23 @@ private:
};
//===----------------------------------------------------------------------===//
// Pattern class
// Pattern
//===----------------------------------------------------------------------===//
/// Instances of Pattern can be matched against SSA IR. These matches get used
/// in ways dependent on their subclasses and the driver doing the matching.
/// For example, RewritePatterns implement a rewrite from one matched pattern
/// to a replacement DAG tile.
/// This class contains all of the data related to a pattern, but does not
/// contain any methods or logic for the actual matching. This class is solely
/// used to interface with the metadata of a pattern, such as the benefit or
/// root operation.
class Pattern {
public:
/// Return a list of operations that may be generated when rewriting an
/// operation instance with this pattern.
ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
/// Return the root node that this pattern matches. Patterns that can match
/// multiple root types return None.
Optional<OperationName> getRootKind() const { return rootKind; }
/// Return the benefit (the inverse of "cost") of matching this pattern. The
/// benefit of a Pattern is always static - rewrites that may have dynamic
/// benefit can be instantiated multiple times (different Pattern instances)
@ -74,19 +82,11 @@ public:
/// condition predicates.
PatternBenefit getBenefit() const { return benefit; }
/// Return the root node that this pattern matches. Patterns that can match
/// multiple root types return None.
Optional<OperationName> getRootKind() const { return rootKind; }
//===--------------------------------------------------------------------===//
// Implementation hooks for patterns to implement.
//===--------------------------------------------------------------------===//
/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind().
virtual LogicalResult match(Operation *op) const = 0;
virtual ~Pattern() {}
/// Returns true if this pattern is known to result in recursive application,
/// i.e. this pattern may generate IR that also matches this pattern, but is
/// known to bound the recursion. This signals to a rewrite driver that it is
/// safe to apply this pattern recursively to generated IR.
bool hasBoundedRewriteRecursion() const { return hasBoundedRecursion; }
protected:
/// This class acts as a special tag that makes the desire to match "any"
@ -94,19 +94,38 @@ protected:
/// feature, and ensures that the user is making a conscious decision.
struct MatchAnyOpTypeTag {};
/// This constructor is used for patterns that match against a specific
/// operation type. The `benefit` is the expected benefit of matching this
/// pattern.
/// Construct a pattern with a certain benefit that matches the operation
/// with the given root name.
Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context);
/// Construct a pattern with a certain benefit that matches any operation
/// type. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
/// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
/// always be supplied here.
Pattern(PatternBenefit benefit, MatchAnyOpTypeTag tag);
/// Construct a pattern with a certain benefit that matches the operation with
/// the given root name. `generatedNames` contains the names of operations
/// that may be generated during a successful rewrite.
Pattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
PatternBenefit benefit, MLIRContext *context);
/// Construct a pattern that may match any operation type. `generatedNames`
/// contains the names of operations that may be generated during a successful
/// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
/// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
/// always be supplied here.
Pattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
MLIRContext *context, MatchAnyOpTypeTag tag);
/// This constructor is used when a pattern may match against multiple
/// different types of operations. The `benefit` is the expected benefit of
/// matching this pattern. `MatchAnyOpTypeTag` is just a tag to ensure that
/// the "match any" behavior is what the user actually desired,
/// `MatchAnyOpTypeTag()` should always be supplied here.
Pattern(PatternBenefit benefit, MatchAnyOpTypeTag);
/// Set the flag detailing if this pattern has bounded rewrite recursion or
/// not.
void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) {
hasBoundedRecursion = hasBoundedRecursionArg;
}
private:
/// A list of the potential operations that may be generated when rewriting
/// an op with this pattern.
SmallVector<OperationName, 2> generatedOps;
/// The root operation of the pattern. If the pattern matches a specific
/// operation, this contains the name of that operation. Contains None
/// otherwise.
@ -115,9 +134,14 @@ private:
/// The expected benefit of matching this pattern.
const PatternBenefit benefit;
virtual void anchor();
/// A boolean flag of whether this pattern has bounded recursion or not.
bool hasBoundedRecursion = false;
};
//===----------------------------------------------------------------------===//
// RewritePattern
//===----------------------------------------------------------------------===//
/// RewritePattern is the common base class for all DAG to DAG replacements.
/// There are two possible usages of this class:
/// * Multi-step RewritePattern with "match" and "rewrite"
@ -129,6 +153,8 @@ private:
///
class RewritePattern : public Pattern {
public:
virtual ~RewritePattern() {}
/// Rewrite the IR rooted at the specified operation with the result of
/// this pattern, generating any new operations with the specified
/// builder. If an unexpected error is encountered (an internal
@ -138,7 +164,7 @@ public:
/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind().
LogicalResult match(Operation *op) const override;
virtual LogicalResult match(Operation *op) const;
/// Attempt to match against code rooted at the specified operation,
/// which is the same operation code as getRootKind(). If successful, this
@ -152,44 +178,12 @@ public:
return failure();
}
/// Returns true if this pattern is known to result in recursive application,
/// i.e. this pattern may generate IR that also matches this pattern, but is
/// known to bound the recursion. This signals to a rewriter that it is safe
/// to apply this pattern recursively to generated IR.
virtual bool hasBoundedRewriteRecursion() const { return false; }
/// Return a list of operations that may be generated when rewriting an
/// operation instance with this pattern.
ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
protected:
/// Construct a rewrite pattern with a certain benefit that matches the
/// operation with the given root name.
RewritePattern(StringRef rootName, PatternBenefit benefit,
MLIRContext *context)
: Pattern(rootName, benefit, context) {}
/// Construct a rewrite pattern with a certain benefit that matches any
/// operation type. `MatchAnyOpTypeTag` is just a tag to ensure that the
/// "match any" behavior is what the user actually desired,
/// `MatchAnyOpTypeTag()` should always be supplied here.
RewritePattern(PatternBenefit benefit, MatchAnyOpTypeTag tag)
: Pattern(benefit, tag) {}
/// Construct a rewrite pattern with a certain benefit that matches the
/// operation with the given root name. `generatedNames` contains the names of
/// operations that may be generated during a successful rewrite.
RewritePattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
PatternBenefit benefit, MLIRContext *context);
/// Construct a rewrite pattern that may match any operation type.
/// `generatedNames` contains the names of operations that may be generated
/// during a successful rewrite. `MatchAnyOpTypeTag` is just a tag to ensure
/// that the "match any" behavior is what the user actually desired,
/// `MatchAnyOpTypeTag()` should always be supplied here.
RewritePattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
MLIRContext *context, MatchAnyOpTypeTag tag);
/// Inherit the base constructors from `Pattern`.
using Pattern::Pattern;
/// A list of the potential operations that may be generated when rewriting
/// an op with this pattern.
SmallVector<OperationName, 2> generatedOps;
/// An anchor for the virtual table.
virtual void anchor();
};
/// OpRewritePattern is a wrapper around RewritePattern that allows for
@ -232,7 +226,7 @@ template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
};
//===----------------------------------------------------------------------===//
// PatternRewriter class
// PatternRewriter
//===----------------------------------------------------------------------===//
/// This class coordinates the application of a pattern to the current function,
@ -498,7 +492,7 @@ public:
/// pattern. Users can query contained patterns and pass analysis results to
/// applyCostModel. Patterns to be discarded should have a benefit of
/// `impossibleToMatch`.
using CostModel = function_ref<PatternBenefit(const RewritePattern &)>;
using CostModel = function_ref<PatternBenefit(const Pattern &)>;
explicit PatternApplicator(const OwningRewritePatternList &owningPatternList)
: owningPatternList(owningPatternList) {}
@ -512,11 +506,11 @@ public:
/// onFailure: called when a pattern fails to match to perform cleanup.
/// onSuccess: called when a pattern match succeeds; return failure() to
/// invalidate the match and try another pattern.
LogicalResult matchAndRewrite(
Operation *op, PatternRewriter &rewriter,
function_ref<bool(const RewritePattern &)> canApply = {},
function_ref<void(const RewritePattern &)> onFailure = {},
function_ref<LogicalResult(const RewritePattern &)> onSuccess = {});
LogicalResult
matchAndRewrite(Operation *op, PatternRewriter &rewriter,
function_ref<bool(const Pattern &)> canApply = {},
function_ref<void(const Pattern &)> onFailure = {},
function_ref<LogicalResult(const Pattern &)> onSuccess = {});
/// Apply a cost model to the patterns within this applicator.
void applyCostModel(CostModel model);
@ -524,22 +518,22 @@ public:
/// Apply the default cost model that solely uses the pattern's static
/// benefit.
void applyDefaultCostModel() {
applyCostModel(
[](const RewritePattern &pattern) { return pattern.getBenefit(); });
applyCostModel([](const Pattern &pattern) { return pattern.getBenefit(); });
}
/// Walk all of the rewrite patterns within the applicator.
void walkAllPatterns(function_ref<void(const RewritePattern &)> walk);
/// Walk all of the patterns within the applicator.
void walkAllPatterns(function_ref<void(const Pattern &)> walk);
private:
/// Attempt to match and rewrite the given op with the given pattern, allowing
/// a predicate to decide if a pattern can be applied or not, and hooks for if
/// the pattern match was a success or failure.
LogicalResult matchAndRewrite(
Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
function_ref<bool(const RewritePattern &)> canApply,
function_ref<void(const RewritePattern &)> onFailure,
function_ref<LogicalResult(const RewritePattern &)> onSuccess);
LogicalResult
matchAndRewrite(Operation *op, const RewritePattern &pattern,
PatternRewriter &rewriter,
function_ref<bool(const Pattern &)> canApply,
function_ref<void(const Pattern &)> onFailure,
function_ref<LogicalResult(const Pattern &)> onSuccess);
/// The list that owns the patterns used within this applicator.
const OwningRewritePatternList &owningPatternList;

View File

@ -1042,7 +1042,12 @@ public:
class VectorInsertStridedSliceOpSameRankRewritePattern
: public OpRewritePattern<InsertStridedSliceOp> {
public:
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
: OpRewritePattern<InsertStridedSliceOp>(ctx) {
// This pattern creates recursive InsertStridedSliceOp, but the recursion is
// bounded as the rank is strictly decreasing.
setHasBoundedRewriteRecursion();
}
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
PatternRewriter &rewriter) const override {
@ -1093,9 +1098,6 @@ public:
rewriter.replaceOp(op, res);
return success();
}
/// This pattern creates recursive InsertStridedSliceOp, but the recursion is
/// bounded as the rank is strictly decreasing.
bool hasBoundedRewriteRecursion() const final { return true; }
};
/// Returns the strides if the memory underlying `memRefType` has a contiguous
@ -1505,7 +1507,12 @@ private:
class VectorExtractStridedSliceOpConversion
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
VectorExtractStridedSliceOpConversion(MLIRContext *ctx)
: OpRewritePattern<ExtractStridedSliceOp>(ctx) {
// This pattern creates recursive ExtractStridedSliceOp, but the recursion
// is bounded as the rank is strictly decreasing.
setHasBoundedRewriteRecursion();
}
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
@ -1552,9 +1559,6 @@ public:
rewriter.replaceOp(op, res);
return success();
}
/// This pattern creates recursive ExtractStridedSliceOp, but the recursion is
/// bounded as the rank is strictly decreasing.
bool hasBoundedRewriteRecursion() const final { return true; }
};
} // namespace

View File

@ -16,6 +16,10 @@ using namespace mlir;
#define DEBUG_TYPE "pattern-match"
//===----------------------------------------------------------------------===//
// PatternBenefit
//===----------------------------------------------------------------------===//
PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
"This pattern match benefit is too large to represent");
@ -27,20 +31,35 @@ unsigned short PatternBenefit::getBenefit() const {
}
//===----------------------------------------------------------------------===//
// Pattern implementation
// Pattern
//===----------------------------------------------------------------------===//
Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
MLIRContext *context)
: rootKind(OperationName(rootName, context)), benefit(benefit) {}
Pattern::Pattern(PatternBenefit benefit, MatchAnyOpTypeTag)
Pattern::Pattern(PatternBenefit benefit, MatchAnyOpTypeTag tag)
: benefit(benefit) {}
// Out-of-line vtable anchor.
void Pattern::anchor() {}
Pattern::Pattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
PatternBenefit benefit, MLIRContext *context)
: Pattern(rootName, benefit, context) {
generatedOps.reserve(generatedNames.size());
std::transform(generatedNames.begin(), generatedNames.end(),
std::back_inserter(generatedOps), [context](StringRef name) {
return OperationName(name, context);
});
}
Pattern::Pattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
MLIRContext *context, MatchAnyOpTypeTag tag)
: Pattern(benefit, tag) {
generatedOps.reserve(generatedNames.size());
std::transform(generatedNames.begin(), generatedNames.end(),
std::back_inserter(generatedOps), [context](StringRef name) {
return OperationName(name, context);
});
}
//===----------------------------------------------------------------------===//
// RewritePattern and PatternRewriter implementation
// RewritePattern
//===----------------------------------------------------------------------===//
void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
@ -52,26 +71,12 @@ LogicalResult RewritePattern::match(Operation *op) const {
llvm_unreachable("need to implement either match or matchAndRewrite!");
}
RewritePattern::RewritePattern(StringRef rootName,
ArrayRef<StringRef> generatedNames,
PatternBenefit benefit, MLIRContext *context)
: Pattern(rootName, benefit, context) {
generatedOps.reserve(generatedNames.size());
std::transform(generatedNames.begin(), generatedNames.end(),
std::back_inserter(generatedOps), [context](StringRef name) {
return OperationName(name, context);
});
}
RewritePattern::RewritePattern(ArrayRef<StringRef> generatedNames,
PatternBenefit benefit, MLIRContext *context,
MatchAnyOpTypeTag tag)
: Pattern(benefit, tag) {
generatedOps.reserve(generatedNames.size());
std::transform(generatedNames.begin(), generatedNames.end(),
std::back_inserter(generatedOps), [context](StringRef name) {
return OperationName(name, context);
});
}
/// Out-of-line vtable anchor.
void RewritePattern::anchor() {}
//===----------------------------------------------------------------------===//
// PatternRewriter
//===----------------------------------------------------------------------===//
PatternRewriter::~PatternRewriter() {
// Out of line to provide a vtable anchor for the class.
@ -201,7 +206,7 @@ void PatternRewriter::cloneRegionBefore(Region &region, Block *before) {
}
//===----------------------------------------------------------------------===//
// PatternMatcher implementation
// PatternApplicator
//===----------------------------------------------------------------------===//
void PatternApplicator::applyCostModel(CostModel model) {
@ -266,16 +271,16 @@ void PatternApplicator::applyCostModel(CostModel model) {
}
void PatternApplicator::walkAllPatterns(
function_ref<void(const RewritePattern &)> walk) {
function_ref<void(const Pattern &)> walk) {
for (auto &it : owningPatternList)
walk(*it);
}
LogicalResult PatternApplicator::matchAndRewrite(
Operation *op, PatternRewriter &rewriter,
function_ref<bool(const RewritePattern &)> canApply,
function_ref<void(const RewritePattern &)> onFailure,
function_ref<LogicalResult(const RewritePattern &)> onSuccess) {
function_ref<bool(const Pattern &)> canApply,
function_ref<void(const Pattern &)> onFailure,
function_ref<LogicalResult(const Pattern &)> onSuccess) {
// Check to see if there are patterns matching this specific operation type.
MutableArrayRef<RewritePattern *> opPatterns;
auto patternIt = patterns.find(op->getName());
@ -315,9 +320,9 @@ LogicalResult PatternApplicator::matchAndRewrite(
LogicalResult PatternApplicator::matchAndRewrite(
Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
function_ref<bool(const RewritePattern &)> canApply,
function_ref<void(const RewritePattern &)> onFailure,
function_ref<LogicalResult(const RewritePattern &)> onSuccess) {
function_ref<bool(const Pattern &)> canApply,
function_ref<void(const Pattern &)> onFailure,
function_ref<LogicalResult(const Pattern &)> onSuccess) {
// Check that the pattern can be applied.
if (canApply && !canApply(pattern))
return failure();

View File

@ -1452,7 +1452,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
namespace {
/// A set of rewrite patterns that can be used to legalize a given operation.
using LegalizationPatterns = SmallVector<const RewritePattern *, 1>;
using LegalizationPatterns = SmallVector<const Pattern *, 1>;
/// This class defines a recursive operation legalizer.
class OperationLegalizer {
@ -1484,12 +1484,11 @@ private:
/// Return true if the given pattern may be applied to the given operation,
/// false otherwise.
bool canApplyPattern(Operation *op, const RewritePattern &pattern,
bool canApplyPattern(Operation *op, const Pattern &pattern,
ConversionPatternRewriter &rewriter);
/// Legalize the resultant IR after successfully applying the given pattern.
LogicalResult legalizePatternResult(Operation *op,
const RewritePattern &pattern,
LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
ConversionPatternRewriter &rewriter,
RewriterState &curState);
@ -1546,7 +1545,7 @@ private:
DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
/// The current set of patterns that have been applied.
SmallPtrSet<const RewritePattern *, 8> appliedPatterns;
SmallPtrSet<const Pattern *, 8> appliedPatterns;
/// The legalization information provided by the target.
ConversionTarget &target;
@ -1697,13 +1696,13 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
auto &rewriterImpl = rewriter.getImpl();
// Functor that returns if the given pattern may be applied.
auto canApply = [&](const RewritePattern &pattern) {
auto canApply = [&](const Pattern &pattern) {
return canApplyPattern(op, pattern, rewriter);
};
// Functor that cleans up the rewriter state after a pattern failed to match.
RewriterState curState = rewriterImpl.getCurrentState();
auto onFailure = [&](const RewritePattern &pattern) {
auto onFailure = [&](const Pattern &pattern) {
LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match"));
rewriterImpl.resetState(curState);
appliedPatterns.erase(&pattern);
@ -1711,7 +1710,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
// Functor that performs additional legalization when a pattern is
// successfully applied.
auto onSuccess = [&](const RewritePattern &pattern) {
auto onSuccess = [&](const Pattern &pattern) {
auto result = legalizePatternResult(op, pattern, rewriter, curState);
appliedPatterns.erase(&pattern);
if (failed(result))
@ -1724,8 +1723,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
onSuccess);
}
bool OperationLegalizer::canApplyPattern(Operation *op,
const RewritePattern &pattern,
bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
ConversionPatternRewriter &rewriter) {
LLVM_DEBUG({
auto &os = rewriter.getImpl().logger;
@ -1747,9 +1745,10 @@ bool OperationLegalizer::canApplyPattern(Operation *op,
return true;
}
LogicalResult OperationLegalizer::legalizePatternResult(
Operation *op, const RewritePattern &pattern,
ConversionPatternRewriter &rewriter, RewriterState &curState) {
LogicalResult
OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
ConversionPatternRewriter &rewriter,
RewriterState &curState) {
auto &impl = rewriter.getImpl();
#ifndef NDEBUG
@ -1877,13 +1876,12 @@ void OperationLegalizer::buildLegalizationGraph(
// generate it.
DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps;
// A mapping between an operation and any currently invalid patterns it has.
DenseMap<OperationName, SmallPtrSet<const RewritePattern *, 2>>
invalidPatterns;
DenseMap<OperationName, SmallPtrSet<const Pattern *, 2>> invalidPatterns;
// A worklist of patterns to consider for legality.
llvm::SetVector<const RewritePattern *> patternWorklist;
llvm::SetVector<const Pattern *> patternWorklist;
// Build the mapping from operations to the parent ops that may generate them.
applicator.walkAllPatterns([&](const RewritePattern &pattern) {
applicator.walkAllPatterns([&](const Pattern &pattern) {
Optional<OperationName> root = pattern.getRootKind();
// If the pattern has no specific root, we can't analyze the relationship
@ -1914,7 +1912,7 @@ void OperationLegalizer::buildLegalizationGraph(
// recurse into itself. It would be better to perform this kind of filtering
// at a higher level than here anyways.
if (!anyOpLegalizerPatterns.empty()) {
for (const RewritePattern *pattern : patternWorklist)
for (const Pattern *pattern : patternWorklist)
legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
return;
}
@ -1964,15 +1962,15 @@ void OperationLegalizer::computeLegalizationGraphBenefit(
// Apply a cost model to the pattern applicator. We order patterns first by
// depth then benefit. `legalizerPatterns` contains per-op patterns by
// decreasing benefit.
applicator.applyCostModel([&](const RewritePattern &p) {
ArrayRef<const RewritePattern *> orderedPatternList;
if (Optional<OperationName> rootName = p.getRootKind())
applicator.applyCostModel([&](const Pattern &pattern) {
ArrayRef<const Pattern *> orderedPatternList;
if (Optional<OperationName> rootName = pattern.getRootKind())
orderedPatternList = legalizerPatterns[*rootName];
else
orderedPatternList = anyOpLegalizerPatterns;
// If the pattern is not found, then it was removed and cannot be matched.
auto it = llvm::find(orderedPatternList, &p);
auto it = llvm::find(orderedPatternList, &pattern);
if (it == orderedPatternList.end())
return PatternBenefit::impossibleToMatch();
@ -2014,9 +2012,9 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
unsigned minDepth = std::numeric_limits<unsigned>::max();
// Compute the depth for each pattern within the set.
SmallVector<std::pair<const RewritePattern *, unsigned>, 4> patternsByDepth;
SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth;
patternsByDepth.reserve(patterns.size());
for (const RewritePattern *pattern : patterns) {
for (const Pattern *pattern : patterns) {
unsigned depth = 0;
for (auto generatedOp : pattern->getGeneratedOps()) {
unsigned generatedOpDepth = computeOpLegalizationDepth(
@ -2037,8 +2035,8 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
// Sort the patterns by those likely to be the most beneficial.
llvm::array_pod_sort(
patternsByDepth.begin(), patternsByDepth.end(),
[](const std::pair<const RewritePattern *, unsigned> *lhs,
const std::pair<const RewritePattern *, unsigned> *rhs) {
[](const std::pair<const Pattern *, unsigned> *lhs,
const std::pair<const Pattern *, unsigned> *rhs) {
// First sort by the smaller pattern legalization depth.
if (lhs->second != rhs->second)
return llvm::array_pod_sort_comparator<unsigned>(&lhs->second,

View File

@ -452,7 +452,11 @@ struct TestNonRootReplacement : public RewritePattern {
/// bounded recursion.
struct TestBoundedRecursiveRewrite
: public OpRewritePattern<TestRecursiveRewriteOp> {
using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern;
TestBoundedRecursiveRewrite(MLIRContext *ctx)
: OpRewritePattern<TestRecursiveRewriteOp>(ctx) {
// The conversion target handles bounding the recursion of this pattern.
setHasBoundedRewriteRecursion();
}
LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
PatternRewriter &rewriter) const final {
@ -462,9 +466,6 @@ struct TestBoundedRecursiveRewrite
});
return success();
}
/// The conversion target handles bounding the recursion of this pattern.
bool hasBoundedRewriteRecursion() const final { return true; }
};
struct TestNestedOpCreationUndoRewrite