mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-29 16:41:27 +00:00
[mlir] Add an option to still use bottom-up traversal
GreedyPatternRewriteDriver was changed from bottom-up traversal to top-down traversal. Not all passes work yet with that change for traversal order. To give some time for fixing, add an option to allow to switch back to bottom-up traversal. Use this option in FusionOfTensorOpsPass which fails otherwise. Differential Revision: https://reviews.llvm.org/D99059
This commit is contained in:
parent
82f6e0dde2
commit
c691b9686b
@ -35,26 +35,26 @@ namespace mlir {
|
||||
/// before attempting to match any of the provided patterns.
|
||||
LogicalResult
|
||||
applyPatternsAndFoldGreedily(Operation *op,
|
||||
const FrozenRewritePatternList &patterns);
|
||||
const FrozenRewritePatternList &patterns,
|
||||
bool useTopDownTraversal = true);
|
||||
|
||||
/// Rewrite the regions of the specified operation, with a user-provided limit
|
||||
/// on iterations to attempt before reaching convergence.
|
||||
LogicalResult
|
||||
applyPatternsAndFoldGreedily(Operation *op,
|
||||
const FrozenRewritePatternList &patterns,
|
||||
unsigned maxIterations);
|
||||
LogicalResult applyPatternsAndFoldGreedily(
|
||||
Operation *op, const FrozenRewritePatternList &patterns,
|
||||
unsigned maxIterations, bool useTopDownTraversal = true);
|
||||
|
||||
/// Rewrite the given regions, which must be isolated from above.
|
||||
LogicalResult
|
||||
applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
|
||||
const FrozenRewritePatternList &patterns);
|
||||
const FrozenRewritePatternList &patterns,
|
||||
bool useTopDownTraversal = true);
|
||||
|
||||
/// Rewrite the given regions, with a user-provided limit on iterations to
|
||||
/// attempt before reaching convergence.
|
||||
LogicalResult
|
||||
applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
|
||||
const FrozenRewritePatternList &patterns,
|
||||
unsigned maxIterations);
|
||||
LogicalResult applyPatternsAndFoldGreedily(
|
||||
MutableArrayRef<Region> regions, const FrozenRewritePatternList &patterns,
|
||||
unsigned maxIterations, bool useTopDownTraversal = true);
|
||||
|
||||
/// Applies the specified patterns on `op` alone while also trying to fold it,
|
||||
/// by selecting the highest benefits patterns in a greedy manner. Returns
|
||||
|
@ -1115,7 +1115,8 @@ struct FusionOfTensorOpsPass
|
||||
Operation *op = getOperation();
|
||||
OwningRewritePatternList patterns(op->getContext());
|
||||
populateLinalgTensorOpsFusionPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
|
||||
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
|
||||
/*useTopDown=*/false);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -37,8 +37,10 @@ namespace {
|
||||
class GreedyPatternRewriteDriver : public PatternRewriter {
|
||||
public:
|
||||
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
|
||||
const FrozenRewritePatternList &patterns)
|
||||
: PatternRewriter(ctx), matcher(patterns), folder(ctx) {
|
||||
const FrozenRewritePatternList &patterns,
|
||||
bool useTopDownTraversal)
|
||||
: PatternRewriter(ctx), matcher(patterns), folder(ctx),
|
||||
useTopDownTraversal(useTopDownTraversal) {
|
||||
worklist.reserve(64);
|
||||
|
||||
// Apply a simple cost model based solely on pattern benefit.
|
||||
@ -134,6 +136,9 @@ private:
|
||||
|
||||
/// Non-pattern based folder for operations.
|
||||
OperationFolder folder;
|
||||
|
||||
// Whether to use top-down or bottom-up traversal order.
|
||||
bool useTopDownTraversal;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
@ -153,14 +158,19 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
|
||||
|
||||
// Add all nested operations to the worklist in preorder.
|
||||
for (auto ®ion : regions)
|
||||
region.walk<WalkOrder::PreOrder>(
|
||||
[this](Operation *op) { worklist.push_back(op); });
|
||||
if (useTopDownTraversal)
|
||||
region.walk<WalkOrder::PreOrder>(
|
||||
[this](Operation *op) { worklist.push_back(op); });
|
||||
else
|
||||
region.walk([this](Operation *op) { addToWorklist(op); });
|
||||
|
||||
// Reverse the list so our pop-back loop processes them in-order.
|
||||
std::reverse(worklist.begin(), worklist.end());
|
||||
// Remember the reverse index.
|
||||
for (unsigned i = 0, e = worklist.size(); i != e; ++i)
|
||||
worklistMap[worklist[i]] = i;
|
||||
if (useTopDownTraversal) {
|
||||
// Reverse the list so our pop-back loop processes them in-order.
|
||||
std::reverse(worklist.begin(), worklist.end());
|
||||
// Remember the reverse index.
|
||||
for (unsigned i = 0, e = worklist.size(); i != e; ++i)
|
||||
worklistMap[worklist[i]] = i;
|
||||
}
|
||||
|
||||
// These are scratch vectors used in the folding loop below.
|
||||
SmallVector<Value, 8> originalOperands, resultValues;
|
||||
@ -231,28 +241,29 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
|
||||
/// top-level operation itself.
|
||||
///
|
||||
LogicalResult
|
||||
mlir::applyPatternsAndFoldGreedily(Operation *op,
|
||||
const FrozenRewritePatternList &patterns) {
|
||||
return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations);
|
||||
}
|
||||
LogicalResult
|
||||
mlir::applyPatternsAndFoldGreedily(Operation *op,
|
||||
const FrozenRewritePatternList &patterns,
|
||||
unsigned maxIterations) {
|
||||
return applyPatternsAndFoldGreedily(op->getRegions(), patterns,
|
||||
maxIterations);
|
||||
bool useTopDownTraversal) {
|
||||
return applyPatternsAndFoldGreedily(op, patterns, maxPatternMatchIterations,
|
||||
useTopDownTraversal);
|
||||
}
|
||||
LogicalResult mlir::applyPatternsAndFoldGreedily(
|
||||
Operation *op, const FrozenRewritePatternList &patterns,
|
||||
unsigned maxIterations, bool useTopDownTraversal) {
|
||||
return applyPatternsAndFoldGreedily(op->getRegions(), patterns, maxIterations,
|
||||
useTopDownTraversal);
|
||||
}
|
||||
/// Rewrite the given regions, which must be isolated from above.
|
||||
LogicalResult
|
||||
mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
|
||||
const FrozenRewritePatternList &patterns) {
|
||||
return applyPatternsAndFoldGreedily(regions, patterns,
|
||||
maxPatternMatchIterations);
|
||||
}
|
||||
LogicalResult
|
||||
mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
|
||||
const FrozenRewritePatternList &patterns,
|
||||
unsigned maxIterations) {
|
||||
bool useTopDownTraversal) {
|
||||
return applyPatternsAndFoldGreedily(
|
||||
regions, patterns, maxPatternMatchIterations, useTopDownTraversal);
|
||||
}
|
||||
LogicalResult mlir::applyPatternsAndFoldGreedily(
|
||||
MutableArrayRef<Region> regions, const FrozenRewritePatternList &patterns,
|
||||
unsigned maxIterations, bool useTopDownTraversal) {
|
||||
if (regions.empty())
|
||||
return success();
|
||||
|
||||
@ -267,7 +278,8 @@ mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
|
||||
"patterns can only be applied to operations IsolatedFromAbove");
|
||||
|
||||
// Start the pattern driver.
|
||||
GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns);
|
||||
GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns,
|
||||
useTopDownTraversal);
|
||||
bool converged = driver.simplify(regions, maxIterations);
|
||||
LLVM_DEBUG(if (!converged) {
|
||||
llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
|
||||
|
Loading…
Reference in New Issue
Block a user