From 56a698510faef5bf3ef224c229a049bb1e376a56 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 29 Apr 2020 15:08:05 -0700 Subject: [PATCH] [mlir][Pass][NFC] Merge OpToOpPassAdaptor and OpToOpPassAdaptorParallel This moves the threading check to runOnOperation. This produces a much cleaner interface for the adaptor pass, and will allow for the ability to enable/disable threading in a much cleaner way in the future. Differential Revision: https://reviews.llvm.org/D78313 --- mlir/lib/Pass/IRPrinting.cpp | 4 +- mlir/lib/Pass/Pass.cpp | 53 ++++++++---------------- mlir/lib/Pass/PassDetail.h | 71 +++++++++++--------------------- mlir/lib/Pass/PassStatistics.cpp | 25 +++++------ mlir/lib/Pass/PassTiming.cpp | 21 ++++------ 5 files changed, 63 insertions(+), 111 deletions(-) diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp index 679a9ec27ead..ba9ff989ea79 100644 --- a/mlir/lib/Pass/IRPrinting.cpp +++ b/mlir/lib/Pass/IRPrinting.cpp @@ -99,7 +99,7 @@ private: /// Returns true if the given pass is hidden from IR printing. static bool isHiddenPass(Pass *pass) { - return isAdaptorPass(pass) || isa(pass); + return isa(pass) || isa(pass); } static void printIR(Operation *op, bool printModuleScope, raw_ostream &out, @@ -173,7 +173,7 @@ void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) { } void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) { - if (isAdaptorPass(pass)) + if (isa(pass)) return; if (config->shouldPrintAfterOnlyOnChange()) beforePassFingerPrints.erase(pass); diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 53ccd4f005a4..b6bef48cb3ec 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -51,7 +51,7 @@ void Pass::copyOptionValuesFrom(const Pass *other) { /// an adaptor pass, print with the op_name(sub_pass,...) format. void Pass::printAsTextualPipeline(raw_ostream &os) { // Special case for adaptors to use the 'op_name(sub_passes)' format. - if (auto *adaptor = getAdaptorPassBase(this)) { + if (auto *adaptor = dyn_cast(this)) { llvm::interleaveComma(adaptor->getPassManagers(), os, [&](OpPassManager &pm) { os << pm.getOpName() << "("; @@ -152,15 +152,15 @@ struct OpPassManagerImpl { void OpPassManagerImpl::coalesceAdjacentAdaptorPasses() { // Bail out early if there are no adaptor passes. if (llvm::none_of(passes, [](std::unique_ptr &pass) { - return isAdaptorPass(pass.get()); + return isa(pass.get()); })) return; // Walk the pass list and merge adjacent adaptors. - OpToOpPassAdaptorBase *lastAdaptor = nullptr; + OpToOpPassAdaptor *lastAdaptor = nullptr; for (auto it = passes.begin(), e = passes.end(); it != e; ++it) { // Check to see if this pass is an adaptor. - if (auto *currentAdaptor = getAdaptorPassBase(it->get())) { + if (auto *currentAdaptor = dyn_cast(it->get())) { // If it is the first adaptor in a possible chain, remember it and // continue. if (!lastAdaptor) { @@ -243,16 +243,7 @@ LogicalResult OpPassManager::run(Operation *op, AnalysisManager am) { /// pass manager. OpPassManager &OpPassManager::nest(const OperationName &nestedName) { OpPassManager nested(nestedName, impl->disableThreads, impl->verifyPasses); - - /// Create an adaptor for this pass. If multi-threading is disabled, then - /// create a synchronous adaptor. - if (impl->disableThreads || !llvm::llvm_is_multithreaded()) { - auto *adaptor = new OpToOpPassAdaptor(std::move(nested)); - addPass(std::unique_ptr(adaptor)); - return adaptor->getPassManagers().front(); - } - - auto *adaptor = new OpToOpPassAdaptorParallel(std::move(nested)); + auto *adaptor = new OpToOpPassAdaptor(std::move(nested)); addPass(std::unique_ptr(adaptor)); return adaptor->getPassManagers().front(); } @@ -330,12 +321,12 @@ static OpPassManager *findPassManagerFor(MutableArrayRef mgrs, return it == mgrs.end() ? nullptr : &*it; } -OpToOpPassAdaptorBase::OpToOpPassAdaptorBase(OpPassManager &&mgr) { +OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) { mgrs.emplace_back(std::move(mgr)); } /// Merge the current pass adaptor into given 'rhs'. -void OpToOpPassAdaptorBase::mergeInto(OpToOpPassAdaptorBase &rhs) { +void OpToOpPassAdaptor::mergeInto(OpToOpPassAdaptor &rhs) { for (auto &pm : mgrs) { // If an existing pass manager exists, then merge the given pass manager // into it. @@ -357,7 +348,7 @@ void OpToOpPassAdaptorBase::mergeInto(OpToOpPassAdaptorBase &rhs) { } /// Returns the adaptor pass name. -std::string OpToOpPassAdaptorBase::getName() { +std::string OpToOpPassAdaptor::getAdaptorName() { std::string name = "Pipeline Collection : ["; llvm::raw_string_ostream os(name); llvm::interleaveComma(getPassManagers(), os, [&](OpPassManager &pm) { @@ -367,11 +358,16 @@ std::string OpToOpPassAdaptorBase::getName() { return os.str(); } -OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) - : OpToOpPassAdaptorBase(std::move(mgr)) {} - /// Run the held pipeline over all nested operations. void OpToOpPassAdaptor::runOnOperation() { + if (mgrs.front().getImpl().disableThreads || !llvm::llvm_is_multithreaded()) + runOnOperationImpl(); + else + runOnOperationAsyncImpl(); +} + +/// Run this pass adaptor synchronously. +void OpToOpPassAdaptor::runOnOperationImpl() { auto am = getAnalysisManager(); PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(), this}; @@ -397,9 +393,6 @@ void OpToOpPassAdaptor::runOnOperation() { } } -OpToOpPassAdaptorParallel::OpToOpPassAdaptorParallel(OpPassManager &&mgr) - : OpToOpPassAdaptorBase(std::move(mgr)) {} - /// Utility functor that checks if the two ranges of pass managers have a size /// mismatch. static bool hasSizeMismatch(ArrayRef lhs, @@ -409,8 +402,8 @@ static bool hasSizeMismatch(ArrayRef lhs, [&](size_t i) { return lhs[i].size() != rhs[i].size(); }); } -// Run the held pipeline asynchronously across the functions within the module. -void OpToOpPassAdaptorParallel::runOnOperation() { +/// Run this pass adaptor synchronously. +void OpToOpPassAdaptor::runOnOperationAsyncImpl() { AnalysisManager am = getAnalysisManager(); // Create the async executors if they haven't been created, or if the main @@ -491,16 +484,6 @@ void OpToOpPassAdaptorParallel::runOnOperation() { signalPassFailure(); } -/// Utility function to convert the given class to the base adaptor it is an -/// adaptor pass, returns nullptr otherwise. -OpToOpPassAdaptorBase *mlir::detail::getAdaptorPassBase(Pass *pass) { - if (auto *adaptor = dyn_cast(pass)) - return adaptor; - if (auto *adaptor = dyn_cast(pass)) - return adaptor; - return nullptr; -} - //===----------------------------------------------------------------------===// // PassCrashReproducer //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h index 59d9a7a0576f..2342a1a7af97 100644 --- a/mlir/lib/Pass/PassDetail.h +++ b/mlir/lib/Pass/PassDetail.h @@ -27,70 +27,45 @@ class VerifierPass : public PassWrapper> { // OpToOpPassAdaptor //===----------------------------------------------------------------------===// -/// A base class for Op-to-Op adaptor passes. -class OpToOpPassAdaptorBase { +/// An adaptor pass used to run operation passes over nested operations. +class OpToOpPassAdaptor + : public PassWrapper> { public: - OpToOpPassAdaptorBase(OpPassManager &&mgr); - OpToOpPassAdaptorBase(const OpToOpPassAdaptorBase &rhs) = default; + OpToOpPassAdaptor(OpPassManager &&mgr); + OpToOpPassAdaptor(const OpToOpPassAdaptor &rhs) = default; + + /// Run the held pipeline over all operations. + void runOnOperation() override; /// Merge the current pass adaptor into given 'rhs'. - void mergeInto(OpToOpPassAdaptorBase &rhs); + void mergeInto(OpToOpPassAdaptor &rhs); /// Returns the pass managers held by this adaptor. MutableArrayRef getPassManagers() { return mgrs; } - /// Returns the adaptor pass name. - std::string getName(); - -protected: - // A set of adaptors to run. - SmallVector mgrs; -}; - -/// An adaptor pass used to run operation passes over nested operations -/// synchronously on a single thread. -class OpToOpPassAdaptor - : public PassWrapper>, - public OpToOpPassAdaptorBase { -public: - OpToOpPassAdaptor(OpPassManager &&mgr); - - /// Run the held pipeline over all operations. - void runOnOperation() override; -}; - -/// An adaptor pass used to run operation passes over nested operations -/// asynchronously across multiple threads. -class OpToOpPassAdaptorParallel - : public PassWrapper>, - public OpToOpPassAdaptorBase { -public: - OpToOpPassAdaptorParallel(OpPassManager &&mgr); - - /// Run the held pipeline over all operations. - void runOnOperation() override; - /// Return the async pass managers held by this parallel adaptor. MutableArrayRef> getParallelPassManagers() { return asyncExecutors; } + /// Returns the adaptor pass name. + std::string getAdaptorName(); + private: - // A set of executors, cloned from the main executor, that run asynchronously - // on different threads. + /// Run this pass adaptor synchronously. + void runOnOperationImpl(); + + /// Run this pass adaptor asynchronously. + void runOnOperationAsyncImpl(); + + /// A set of adaptors to run. + SmallVector mgrs; + + /// A set of executors, cloned from the main executor, that run asynchronously + /// on different threads. This is used when threading is enabled. SmallVector, 8> asyncExecutors; }; -/// Utility function to convert the given class to the base adaptor it is an -/// adaptor pass, returns nullptr otherwise. -OpToOpPassAdaptorBase *getAdaptorPassBase(Pass *pass); - -/// Utility function to return if a pass refers to an adaptor pass. Adaptor -/// passes are those that internally execute a pipeline. -inline bool isAdaptorPass(Pass *pass) { - return isa(pass) || isa(pass); -} - } // end namespace detail } // end namespace mlir #endif // MLIR_PASS_PASSDETAIL_H_ diff --git a/mlir/lib/Pass/PassStatistics.cpp b/mlir/lib/Pass/PassStatistics.cpp index 7ac54f7cf1af..6ef0d3bbea6a 100644 --- a/mlir/lib/Pass/PassStatistics.cpp +++ b/mlir/lib/Pass/PassStatistics.cpp @@ -60,7 +60,7 @@ static void printPassEntry(raw_ostream &os, unsigned indent, StringRef pass, static void printResultsAsList(raw_ostream &os, OpPassManager &pm) { llvm::StringMap> mergedStats; std::function addStats = [&](Pass *pass) { - auto *adaptor = getAdaptorPassBase(pass); + auto *adaptor = dyn_cast(pass); // If this is not an adaptor, add the stats to the list if there are any. if (!adaptor) { @@ -105,13 +105,12 @@ static void printResultsAsList(raw_ostream &os, OpPassManager &pm) { static void printResultsAsPipeline(raw_ostream &os, OpPassManager &pm) { std::function printPass = [&](unsigned indent, Pass *pass) { - // Handle the case of an adaptor pass. - if (auto *adaptor = getAdaptorPassBase(pass)) { + if (auto *adaptor = dyn_cast(pass)) { // If this adaptor has more than one internal pipeline, print an entry for // it. auto mgrs = adaptor->getPassManagers(); if (mgrs.size() > 1) { - printPassEntry(os, indent, adaptor->getName()); + printPassEntry(os, indent, adaptor->getAdaptorName()); indent += 2; } @@ -195,8 +194,8 @@ void OpPassManager::mergeStatisticsInto(OpPassManager &other) { Pass &pass = std::get<0>(passPair), &otherPass = std::get<1>(passPair); // If this is an adaptor, then recursively merge the pass managers. - if (auto *adaptorPass = getAdaptorPassBase(&pass)) { - auto *otherAdaptorPass = getAdaptorPassBase(&otherPass); + if (auto *adaptorPass = dyn_cast(&pass)) { + auto *otherAdaptorPass = cast(&otherPass); for (auto mgrs : llvm::zip(adaptorPass->getPassManagers(), otherAdaptorPass->getPassManagers())) std::get<0>(mgrs).mergeStatisticsInto(std::get<1>(mgrs)); @@ -217,18 +216,16 @@ void OpPassManager::mergeStatisticsInto(OpPassManager &other) { /// consumption(e.g. dumping). static void prepareStatistics(OpPassManager &pm) { for (Pass &pass : pm.getPasses()) { - OpToOpPassAdaptorBase *adaptor = getAdaptorPassBase(&pass); + OpToOpPassAdaptor *adaptor = dyn_cast(&pass); if (!adaptor) continue; MutableArrayRef nestedPms = adaptor->getPassManagers(); - // If this is a parallel adaptor, merge the statistics from the async - // pass managers into the main nested pass managers. - if (auto *parallelAdaptor = dyn_cast(&pass)) { - for (auto &asyncPM : parallelAdaptor->getParallelPassManagers()) { - for (unsigned i = 0, e = asyncPM.size(); i != e; ++i) - asyncPM[i].mergeStatisticsInto(nestedPms[i]); - } + // Merge the statistics from the async pass managers into the main nested + // pass managers. + for (auto &asyncPM : adaptor->getParallelPassManagers()) { + for (unsigned i = 0, e = asyncPM.size(); i != e; ++i) + asyncPM[i].mergeStatisticsInto(nestedPms[i]); } // Prepare the statistics of each of the nested passes. diff --git a/mlir/lib/Pass/PassTiming.cpp b/mlir/lib/Pass/PassTiming.cpp index c8f0ad8afa50..71bf822a864b 100644 --- a/mlir/lib/Pass/PassTiming.cpp +++ b/mlir/lib/Pass/PassTiming.cpp @@ -277,17 +277,17 @@ void PassTiming::runAfterPipeline(const OperationName &name, /// Start a new timer for the given pass. void PassTiming::startPassTimer(Pass *pass) { - auto kind = isAdaptorPass(pass) ? TimerKind::PipelineCollection - : TimerKind::PassOrAnalysis; + auto kind = isa(pass) ? TimerKind::PipelineCollection + : TimerKind::PassOrAnalysis; Timer *timer = getTimer(pass, kind, [pass]() -> std::string { - if (auto *adaptor = getAdaptorPassBase(pass)) - return adaptor->getName(); + if (auto *adaptor = dyn_cast(pass)) + return adaptor->getAdaptorName(); return std::string(pass->getName()); }); // We don't actually want to time the adaptor passes, they gather their total // from their held passes. - if (!isAdaptorPass(pass)) + if (!isa(pass)) timer->start(); } @@ -302,9 +302,9 @@ void PassTiming::startAnalysisTimer(StringRef name, TypeID id) { void PassTiming::runAfterPass(Pass *pass, Operation *) { Timer *timer = popLastActiveTimer(); - // If this is an OpToOpPassAdaptorParallel, then we need to merge in the - // timing data for the pipelines running on other threads. - if (isa(pass)) { + // If this is a pass adaptor, then we need to merge in the timing data for the + // pipelines running on other threads. + if (isa(pass)) { auto toMerge = pipelinesToMerge.find({llvm::get_threadid(), pass}); if (toMerge != pipelinesToMerge.end()) { for (auto &it : toMerge->second) @@ -314,10 +314,7 @@ void PassTiming::runAfterPass(Pass *pass, Operation *) { return; } - // Adaptor passes aren't timed directly, so we don't need to stop their - // timers. - if (!isAdaptorPass(pass)) - timer->stop(); + timer->stop(); } /// Stop a timer.