Add an action for each iteration of the GreedyPatternRewriteDriver

Differential Revision: https://reviews.llvm.org/D149101
This commit is contained in:
Mehdi Amini 2023-02-22 22:43:45 -07:00
parent 2d58925362
commit 87e6e490e7
3 changed files with 43 additions and 19 deletions

View File

@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/IR/Action.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Rewrite/PatternApplicator.h"
@ -410,6 +411,24 @@ RegionPatternRewriteDriver::RegionPatternRewriteDriver(
}
}
namespace {
class GreedyPatternRewriteIteration
: public tracing::ActionImpl<GreedyPatternRewriteIteration> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GreedyPatternRewriteIteration)
GreedyPatternRewriteIteration(ArrayRef<IRUnit> units, int64_t iteration)
: tracing::ActionImpl<GreedyPatternRewriteIteration>(units),
iteration(iteration) {}
static constexpr StringLiteral tag = "GreedyPatternRewriteIteration";
void print(raw_ostream &os) const override {
os << "GreedyPatternRewriteIteration(" << iteration << ")";
}
private:
int64_t iteration = 0;
};
} // namespace
LogicalResult RegionPatternRewriteDriver::simplify() && {
auto insertKnownConstant = [&](Operation *op) {
// Check for existing constants when populating the worklist. This avoids
@ -423,6 +442,7 @@ LogicalResult RegionPatternRewriteDriver::simplify() && {
bool changed = false;
int64_t iteration = 0;
MLIRContext *ctx = getContext();
do {
// Check if the iteration limit was reached.
if (iteration++ >= config.maxIterations &&
@ -455,12 +475,16 @@ LogicalResult RegionPatternRewriteDriver::simplify() && {
worklistMap[worklist[i]] = i;
}
changed = processWorklist();
ctx->executeAction<GreedyPatternRewriteIteration>(
[&] {
changed = processWorklist();
// After applying patterns, make sure that the CFG of each of the regions
// is kept up to date.
if (config.enableRegionSimplification)
changed |= succeeded(simplifyRegions(*this, region));
// After applying patterns, make sure that the CFG of each of the
// regions is kept up to date.
if (config.enableRegionSimplification)
changed |= succeeded(simplifyRegions(*this, region));
},
{&region}, iteration);
} while (changed);
// Whether the rewrite converges, i.e. wasn't changed in the last iteration.

View File

@ -16,24 +16,24 @@ func.func @c() {
////////////////////////////////////
/// 1. All actions should be logged.
// RUN: mlir-opt %s --log-actions-to=- -pass-pipeline="builtin.module(func.func(canonicalize))" -o %t --mlir-disable-threading | FileCheck %s
// RUN: mlir-opt %s --log-actions-to=- -pass-pipeline="builtin.module(func.func(test-stats-pass))" -o %t --mlir-disable-threading | FileCheck %s
// Specify the current file as filter, expect to see all actions.
// RUN: mlir-opt %s --log-mlir-actions-filter=%s --log-actions-to=- -pass-pipeline="builtin.module(func.func(canonicalize))" -o %t --mlir-disable-threading | FileCheck %s
// RUN: mlir-opt %s --log-mlir-actions-filter=%s --log-actions-to=- -pass-pipeline="builtin.module(func.func(test-stats-pass))" -o %t --mlir-disable-threading | FileCheck %s
// CHECK: [thread {{.*}}] begins (no breakpoint) Action `pass-execution` running `Canonicalizer` on Operation `func.func` (func.func @a() {...}
// CHECK: [thread {{.*}}] begins (no breakpoint) Action `pass-execution` running `{{.*}}TestStatisticPass` on Operation `func.func` (func.func @a() {...}
// CHECK-NEXT: [thread {{.*}}] completed `pass-execution`
// CHECK-NEXT: [thread {{.*}}] begins (no breakpoint) Action `pass-execution` running `Canonicalizer` on Operation `func.func` (func.func @b() {...}
// CHECK-NEXT: [thread {{.*}}] begins (no breakpoint) Action `pass-execution` running `{{.*}}TestStatisticPass` on Operation `func.func` (func.func @b() {...}
// CHECK-NEXT: [thread {{.*}}] completed `pass-execution`
// CHECK-NEXT: [thread {{.*}}] begins (no breakpoint) Action `pass-execution` running `Canonicalizer` on Operation `func.func` (func.func @c() {...}
// CHECK-NEXT: [thread {{.*}}] begins (no breakpoint) Action `pass-execution` running `{{.*}}TestStatisticPass` on Operation `func.func` (func.func @c() {...}
// CHECK-NEXT: [thread {{.*}}] completed `pass-execution`
////////////////////////////////////
/// 2. No match
// Specify a non-existing file as filter, expect to see no actions.
// RUN: mlir-opt %s --log-mlir-actions-filter=foo.mlir --log-actions-to=- -pass-pipeline="builtin.module(func.func(canonicalize))" -o %t --mlir-disable-threading | FileCheck %s --check-prefix=CHECK-NONE --allow-empty
// RUN: mlir-opt %s --log-mlir-actions-filter=foo.mlir --log-actions-to=- -pass-pipeline="builtin.module(func.func(test-stats-pass))" -o %t --mlir-disable-threading | FileCheck %s --check-prefix=CHECK-NONE --allow-empty
// Filter on a non-matching line, expect to see no actions.
// RUN: mlir-opt %s --log-mlir-actions-filter=%s:1 --log-actions-to=- -pass-pipeline="builtin.module(func.func(canonicalize))" -o %t --mlir-disable-threading | FileCheck %s --check-prefix=CHECK-NONE --allow-empty
// RUN: mlir-opt %s --log-mlir-actions-filter=%s:1 --log-actions-to=- -pass-pipeline="builtin.module(func.func(test-stats-pass))" -o %t --mlir-disable-threading | FileCheck %s --check-prefix=CHECK-NONE --allow-empty
// Invalid Filter
// CHECK-NONE-NOT: Canonicalizer
@ -42,19 +42,19 @@ func.func @c() {
/// 3. Matching filters
// Filter the second function only
// RUN: mlir-opt %s --log-mlir-actions-filter=%s:8 --log-actions-to=- -pass-pipeline="builtin.module(func.func(canonicalize))" -o %t --mlir-disable-threading | FileCheck %s --check-prefix=CHECK-SECOND
// RUN: mlir-opt %s --log-mlir-actions-filter=%s:8 --log-actions-to=- -pass-pipeline="builtin.module(func.func(test-stats-pass))" -o %t --mlir-disable-threading | FileCheck %s --check-prefix=CHECK-SECOND
// CHECK-SECOND-NOT: @a
// CHECK-SECOND-NOT: @c
// CHECK-SECOND: [thread {{.*}}] begins (no breakpoint) Action `pass-execution` running `Canonicalizer` on Operation `func.func` (func.func @b() {...}
// CHECK-SECOND: [thread {{.*}}] begins (no breakpoint) Action `pass-execution` running `{{.*}}TestStatisticPass` on Operation `func.func` (func.func @b() {...}
// CHECK-SECOND-NEXT: [thread {{.*}}] completed `pass-execution`
// Filter the first and third functions
// RUN: mlir-opt %s --log-mlir-actions-filter=%s:4,%s:12 --log-actions-to=- -pass-pipeline="builtin.module(func.func(canonicalize))" -o %t --mlir-disable-threading | FileCheck %s --check-prefix=CHECK-FIRST-THIRD
// RUN: mlir-opt %s --log-mlir-actions-filter=%s:4,%s:12 --log-actions-to=- -pass-pipeline="builtin.module(func.func(test-stats-pass))" -o %t --mlir-disable-threading | FileCheck %s --check-prefix=CHECK-FIRST-THIRD
// CHECK-FIRST-THIRD-NOT: Canonicalizer
// CHECK-FIRST-THIRD: [thread {{.*}}] begins (no breakpoint) Action `pass-execution` running `Canonicalizer` on Operation `func.func` (func.func @a() {...}
// CHECK-FIRST-THIRD: [thread {{.*}}] begins (no breakpoint) Action `pass-execution` running `{{.*}}TestStatisticPass` on Operation `func.func` (func.func @a() {...}
// CHECK-FIRST-THIRD-NEXT: [thread {{.*}}] completed `pass-execution`
// CHECK-FIRST-THIRD-NEXT: [thread {{.*}}] begins (no breakpoint) Action `pass-execution` running `Canonicalizer` on Operation `func.func` (func.func @c() {...}
// CHECK-FIRST-THIRD-NEXT: [thread {{.*}}] begins (no breakpoint) Action `pass-execution` running `{{.*}}TestStatisticPass` on Operation `func.func` (func.func @c() {...}
// CHECK-FIRST-THIRD-NEXT: [thread {{.*}}] completed `pass-execution`
// CHECK-FIRST-THIRD-NOT: Canonicalizer

View File

@ -1,6 +1,6 @@
// RUN: mlir-opt %s --log-actions-to=- -canonicalize -test-module-pass | FileCheck %s
// RUN: mlir-opt %s --log-actions-to=- -test-stats-pass -test-module-pass | FileCheck %s
// CHECK: [thread {{.*}}] begins (no breakpoint) Action `pass-execution` running `Canonicalizer` on Operation `builtin.module` (module {...}
// CHECK: [thread {{.*}}] begins (no breakpoint) Action `pass-execution` running `{{.*}}TestStatisticPass` on Operation `builtin.module` (module {...})`
// CHECK-NEXT: [thread {{.*}}] completed `pass-execution`
// CHECK-NEXT: [thread {{.*}}] begins (no breakpoint) Action `pass-execution` running `{{.*}}TestModulePass` on Operation `builtin.module` (module {...}
// CHECK-NEXT: [thread {{.*}}] completed `pass-execution`