[mlir][Pass][NFC] Merge OpToOpPassAdaptor and OpToOpPassAdaptorParallel

This moves the threading check to runOnOperation. This produces a much cleaner interface for the adaptor pass, and will allow for the ability to enable/disable threading in a much cleaner way in the future.

Differential Revision: https://reviews.llvm.org/D78313
This commit is contained in:
River Riddle 2020-04-29 15:08:05 -07:00
parent 30d17d8852
commit 56a698510f
5 changed files with 63 additions and 111 deletions

View File

@ -99,7 +99,7 @@ private:
/// Returns true if the given pass is hidden from IR printing.
static bool isHiddenPass(Pass *pass) {
return isAdaptorPass(pass) || isa<VerifierPass>(pass);
return isa<OpToOpPassAdaptor>(pass) || isa<VerifierPass>(pass);
}
static void printIR(Operation *op, bool printModuleScope, raw_ostream &out,
@ -173,7 +173,7 @@ void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) {
}
void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) {
if (isAdaptorPass(pass))
if (isa<OpToOpPassAdaptor>(pass))
return;
if (config->shouldPrintAfterOnlyOnChange())
beforePassFingerPrints.erase(pass);

View File

@ -51,7 +51,7 @@ void Pass::copyOptionValuesFrom(const Pass *other) {
/// an adaptor pass, print with the op_name(sub_pass,...) format.
void Pass::printAsTextualPipeline(raw_ostream &os) {
// Special case for adaptors to use the 'op_name(sub_passes)' format.
if (auto *adaptor = getAdaptorPassBase(this)) {
if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(this)) {
llvm::interleaveComma(adaptor->getPassManagers(), os,
[&](OpPassManager &pm) {
os << pm.getOpName() << "(";
@ -152,15 +152,15 @@ struct OpPassManagerImpl {
void OpPassManagerImpl::coalesceAdjacentAdaptorPasses() {
// Bail out early if there are no adaptor passes.
if (llvm::none_of(passes, [](std::unique_ptr<Pass> &pass) {
return isAdaptorPass(pass.get());
return isa<OpToOpPassAdaptor>(pass.get());
}))
return;
// Walk the pass list and merge adjacent adaptors.
OpToOpPassAdaptorBase *lastAdaptor = nullptr;
OpToOpPassAdaptor *lastAdaptor = nullptr;
for (auto it = passes.begin(), e = passes.end(); it != e; ++it) {
// Check to see if this pass is an adaptor.
if (auto *currentAdaptor = getAdaptorPassBase(it->get())) {
if (auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(it->get())) {
// If it is the first adaptor in a possible chain, remember it and
// continue.
if (!lastAdaptor) {
@ -243,16 +243,7 @@ LogicalResult OpPassManager::run(Operation *op, AnalysisManager am) {
/// pass manager.
OpPassManager &OpPassManager::nest(const OperationName &nestedName) {
OpPassManager nested(nestedName, impl->disableThreads, impl->verifyPasses);
/// Create an adaptor for this pass. If multi-threading is disabled, then
/// create a synchronous adaptor.
if (impl->disableThreads || !llvm::llvm_is_multithreaded()) {
auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
addPass(std::unique_ptr<Pass>(adaptor));
return adaptor->getPassManagers().front();
}
auto *adaptor = new OpToOpPassAdaptorParallel(std::move(nested));
auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
addPass(std::unique_ptr<Pass>(adaptor));
return adaptor->getPassManagers().front();
}
@ -330,12 +321,12 @@ static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
return it == mgrs.end() ? nullptr : &*it;
}
OpToOpPassAdaptorBase::OpToOpPassAdaptorBase(OpPassManager &&mgr) {
OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) {
mgrs.emplace_back(std::move(mgr));
}
/// Merge the current pass adaptor into given 'rhs'.
void OpToOpPassAdaptorBase::mergeInto(OpToOpPassAdaptorBase &rhs) {
void OpToOpPassAdaptor::mergeInto(OpToOpPassAdaptor &rhs) {
for (auto &pm : mgrs) {
// If an existing pass manager exists, then merge the given pass manager
// into it.
@ -357,7 +348,7 @@ void OpToOpPassAdaptorBase::mergeInto(OpToOpPassAdaptorBase &rhs) {
}
/// Returns the adaptor pass name.
std::string OpToOpPassAdaptorBase::getName() {
std::string OpToOpPassAdaptor::getAdaptorName() {
std::string name = "Pipeline Collection : [";
llvm::raw_string_ostream os(name);
llvm::interleaveComma(getPassManagers(), os, [&](OpPassManager &pm) {
@ -367,11 +358,16 @@ std::string OpToOpPassAdaptorBase::getName() {
return os.str();
}
OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr)
: OpToOpPassAdaptorBase(std::move(mgr)) {}
/// Run the held pipeline over all nested operations.
void OpToOpPassAdaptor::runOnOperation() {
if (mgrs.front().getImpl().disableThreads || !llvm::llvm_is_multithreaded())
runOnOperationImpl();
else
runOnOperationAsyncImpl();
}
/// Run this pass adaptor synchronously.
void OpToOpPassAdaptor::runOnOperationImpl() {
auto am = getAnalysisManager();
PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
this};
@ -397,9 +393,6 @@ void OpToOpPassAdaptor::runOnOperation() {
}
}
OpToOpPassAdaptorParallel::OpToOpPassAdaptorParallel(OpPassManager &&mgr)
: OpToOpPassAdaptorBase(std::move(mgr)) {}
/// Utility functor that checks if the two ranges of pass managers have a size
/// mismatch.
static bool hasSizeMismatch(ArrayRef<OpPassManager> lhs,
@ -409,8 +402,8 @@ static bool hasSizeMismatch(ArrayRef<OpPassManager> lhs,
[&](size_t i) { return lhs[i].size() != rhs[i].size(); });
}
// Run the held pipeline asynchronously across the functions within the module.
void OpToOpPassAdaptorParallel::runOnOperation() {
/// Run this pass adaptor synchronously.
void OpToOpPassAdaptor::runOnOperationAsyncImpl() {
AnalysisManager am = getAnalysisManager();
// Create the async executors if they haven't been created, or if the main
@ -491,16 +484,6 @@ void OpToOpPassAdaptorParallel::runOnOperation() {
signalPassFailure();
}
/// Utility function to convert the given class to the base adaptor it is an
/// adaptor pass, returns nullptr otherwise.
OpToOpPassAdaptorBase *mlir::detail::getAdaptorPassBase(Pass *pass) {
if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
return adaptor;
if (auto *adaptor = dyn_cast<OpToOpPassAdaptorParallel>(pass))
return adaptor;
return nullptr;
}
//===----------------------------------------------------------------------===//
// PassCrashReproducer
//===----------------------------------------------------------------------===//

View File

@ -27,70 +27,45 @@ class VerifierPass : public PassWrapper<VerifierPass, OperationPass<>> {
// OpToOpPassAdaptor
//===----------------------------------------------------------------------===//
/// A base class for Op-to-Op adaptor passes.
class OpToOpPassAdaptorBase {
/// An adaptor pass used to run operation passes over nested operations.
class OpToOpPassAdaptor
: public PassWrapper<OpToOpPassAdaptor, OperationPass<>> {
public:
OpToOpPassAdaptorBase(OpPassManager &&mgr);
OpToOpPassAdaptorBase(const OpToOpPassAdaptorBase &rhs) = default;
OpToOpPassAdaptor(OpPassManager &&mgr);
OpToOpPassAdaptor(const OpToOpPassAdaptor &rhs) = default;
/// Run the held pipeline over all operations.
void runOnOperation() override;
/// Merge the current pass adaptor into given 'rhs'.
void mergeInto(OpToOpPassAdaptorBase &rhs);
void mergeInto(OpToOpPassAdaptor &rhs);
/// Returns the pass managers held by this adaptor.
MutableArrayRef<OpPassManager> getPassManagers() { return mgrs; }
/// Returns the adaptor pass name.
std::string getName();
protected:
// A set of adaptors to run.
SmallVector<OpPassManager, 1> mgrs;
};
/// An adaptor pass used to run operation passes over nested operations
/// synchronously on a single thread.
class OpToOpPassAdaptor
: public PassWrapper<OpToOpPassAdaptor, OperationPass<>>,
public OpToOpPassAdaptorBase {
public:
OpToOpPassAdaptor(OpPassManager &&mgr);
/// Run the held pipeline over all operations.
void runOnOperation() override;
};
/// An adaptor pass used to run operation passes over nested operations
/// asynchronously across multiple threads.
class OpToOpPassAdaptorParallel
: public PassWrapper<OpToOpPassAdaptorParallel, OperationPass<>>,
public OpToOpPassAdaptorBase {
public:
OpToOpPassAdaptorParallel(OpPassManager &&mgr);
/// Run the held pipeline over all operations.
void runOnOperation() override;
/// Return the async pass managers held by this parallel adaptor.
MutableArrayRef<SmallVector<OpPassManager, 1>> getParallelPassManagers() {
return asyncExecutors;
}
/// Returns the adaptor pass name.
std::string getAdaptorName();
private:
// A set of executors, cloned from the main executor, that run asynchronously
// on different threads.
/// Run this pass adaptor synchronously.
void runOnOperationImpl();
/// Run this pass adaptor asynchronously.
void runOnOperationAsyncImpl();
/// A set of adaptors to run.
SmallVector<OpPassManager, 1> mgrs;
/// A set of executors, cloned from the main executor, that run asynchronously
/// on different threads. This is used when threading is enabled.
SmallVector<SmallVector<OpPassManager, 1>, 8> asyncExecutors;
};
/// Utility function to convert the given class to the base adaptor it is an
/// adaptor pass, returns nullptr otherwise.
OpToOpPassAdaptorBase *getAdaptorPassBase(Pass *pass);
/// Utility function to return if a pass refers to an adaptor pass. Adaptor
/// passes are those that internally execute a pipeline.
inline bool isAdaptorPass(Pass *pass) {
return isa<OpToOpPassAdaptorParallel>(pass) || isa<OpToOpPassAdaptor>(pass);
}
} // end namespace detail
} // end namespace mlir
#endif // MLIR_PASS_PASSDETAIL_H_

View File

@ -60,7 +60,7 @@ static void printPassEntry(raw_ostream &os, unsigned indent, StringRef pass,
static void printResultsAsList(raw_ostream &os, OpPassManager &pm) {
llvm::StringMap<std::vector<Statistic>> mergedStats;
std::function<void(Pass *)> addStats = [&](Pass *pass) {
auto *adaptor = getAdaptorPassBase(pass);
auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass);
// If this is not an adaptor, add the stats to the list if there are any.
if (!adaptor) {
@ -105,13 +105,12 @@ static void printResultsAsList(raw_ostream &os, OpPassManager &pm) {
static void printResultsAsPipeline(raw_ostream &os, OpPassManager &pm) {
std::function<void(unsigned, Pass *)> printPass = [&](unsigned indent,
Pass *pass) {
// Handle the case of an adaptor pass.
if (auto *adaptor = getAdaptorPassBase(pass)) {
if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass)) {
// If this adaptor has more than one internal pipeline, print an entry for
// it.
auto mgrs = adaptor->getPassManagers();
if (mgrs.size() > 1) {
printPassEntry(os, indent, adaptor->getName());
printPassEntry(os, indent, adaptor->getAdaptorName());
indent += 2;
}
@ -195,8 +194,8 @@ void OpPassManager::mergeStatisticsInto(OpPassManager &other) {
Pass &pass = std::get<0>(passPair), &otherPass = std::get<1>(passPair);
// If this is an adaptor, then recursively merge the pass managers.
if (auto *adaptorPass = getAdaptorPassBase(&pass)) {
auto *otherAdaptorPass = getAdaptorPassBase(&otherPass);
if (auto *adaptorPass = dyn_cast<OpToOpPassAdaptor>(&pass)) {
auto *otherAdaptorPass = cast<OpToOpPassAdaptor>(&otherPass);
for (auto mgrs : llvm::zip(adaptorPass->getPassManagers(),
otherAdaptorPass->getPassManagers()))
std::get<0>(mgrs).mergeStatisticsInto(std::get<1>(mgrs));
@ -217,18 +216,16 @@ void OpPassManager::mergeStatisticsInto(OpPassManager &other) {
/// consumption(e.g. dumping).
static void prepareStatistics(OpPassManager &pm) {
for (Pass &pass : pm.getPasses()) {
OpToOpPassAdaptorBase *adaptor = getAdaptorPassBase(&pass);
OpToOpPassAdaptor *adaptor = dyn_cast<OpToOpPassAdaptor>(&pass);
if (!adaptor)
continue;
MutableArrayRef<OpPassManager> nestedPms = adaptor->getPassManagers();
// If this is a parallel adaptor, merge the statistics from the async
// pass managers into the main nested pass managers.
if (auto *parallelAdaptor = dyn_cast<OpToOpPassAdaptorParallel>(&pass)) {
for (auto &asyncPM : parallelAdaptor->getParallelPassManagers()) {
for (unsigned i = 0, e = asyncPM.size(); i != e; ++i)
asyncPM[i].mergeStatisticsInto(nestedPms[i]);
}
// Merge the statistics from the async pass managers into the main nested
// pass managers.
for (auto &asyncPM : adaptor->getParallelPassManagers()) {
for (unsigned i = 0, e = asyncPM.size(); i != e; ++i)
asyncPM[i].mergeStatisticsInto(nestedPms[i]);
}
// Prepare the statistics of each of the nested passes.

View File

@ -277,17 +277,17 @@ void PassTiming::runAfterPipeline(const OperationName &name,
/// Start a new timer for the given pass.
void PassTiming::startPassTimer(Pass *pass) {
auto kind = isAdaptorPass(pass) ? TimerKind::PipelineCollection
: TimerKind::PassOrAnalysis;
auto kind = isa<OpToOpPassAdaptor>(pass) ? TimerKind::PipelineCollection
: TimerKind::PassOrAnalysis;
Timer *timer = getTimer(pass, kind, [pass]() -> std::string {
if (auto *adaptor = getAdaptorPassBase(pass))
return adaptor->getName();
if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
return adaptor->getAdaptorName();
return std::string(pass->getName());
});
// We don't actually want to time the adaptor passes, they gather their total
// from their held passes.
if (!isAdaptorPass(pass))
if (!isa<OpToOpPassAdaptor>(pass))
timer->start();
}
@ -302,9 +302,9 @@ void PassTiming::startAnalysisTimer(StringRef name, TypeID id) {
void PassTiming::runAfterPass(Pass *pass, Operation *) {
Timer *timer = popLastActiveTimer();
// If this is an OpToOpPassAdaptorParallel, then we need to merge in the
// timing data for the pipelines running on other threads.
if (isa<OpToOpPassAdaptorParallel>(pass)) {
// If this is a pass adaptor, then we need to merge in the timing data for the
// pipelines running on other threads.
if (isa<OpToOpPassAdaptor>(pass)) {
auto toMerge = pipelinesToMerge.find({llvm::get_threadid(), pass});
if (toMerge != pipelinesToMerge.end()) {
for (auto &it : toMerge->second)
@ -314,10 +314,7 @@ void PassTiming::runAfterPass(Pass *pass, Operation *) {
return;
}
// Adaptor passes aren't timed directly, so we don't need to stop their
// timers.
if (!isAdaptorPass(pass))
timer->stop();
timer->stop();
}
/// Stop a timer.