Add async_funcs_only option to AsyncToAsyncRuntime pass

This change adds async_funcs_only option to AsyncToAsyncRuntimePass. The goal is to convert async functions to regular functions in early stages of compilation pipeline.

Differential Revision: https://reviews.llvm.org/D138611
This commit is contained in:
yijiagu 2022-11-30 10:14:56 -08:00 committed by Eugene Zhulenev
parent a4c466766d
commit 6cca6b9ab9
5 changed files with 130 additions and 50 deletions

View File

@ -17,6 +17,7 @@
namespace mlir {
class ModuleOp;
class ConversionTarget;
#define GEN_PASS_DECL
#include "mlir/Dialect/Async/Passes.h.inc"
@ -27,6 +28,11 @@ std::unique_ptr<Pass> createAsyncParallelForPass(bool asyncDispatch,
int32_t numWorkerThreads,
int32_t minTaskSize);
void populateAsyncFuncToAsyncRuntimeConversionPatterns(
RewritePatternSet &patterns, ConversionTarget &target);
std::unique_ptr<OperationPass<ModuleOp>> createAsyncFuncToAsyncRuntimePass();
std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();
std::unique_ptr<Pass> createAsyncRuntimeRefCountingPass();

View File

@ -41,12 +41,19 @@ def AsyncParallelFor : Pass<"async-parallel-for", "ModuleOp"> {
}
def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> {
let summary = "Lower high level async operations (e.g. async.execute) to the"
"explicit async.runtime and async.coro operations";
let summary = "Lower all high level async operations (e.g. async.execute) to"
"the explicit async.runtime and async.coro operations";
let constructor = "mlir::createAsyncToAsyncRuntimePass()";
let dependentDialects = ["async::AsyncDialect", "func::FuncDialect"];
}
def AsyncFuncToAsyncRuntime : Pass<"async-func-to-async-runtime", "ModuleOp"> {
let summary = "Lower async.func operations to the explicit async.runtime and"
"async.coro operations";
let constructor = "mlir::createAsyncFuncToAsyncRuntimePass()";
let dependentDialects = ["async::AsyncDialect", "func::FuncDialect"];
}
def AsyncRuntimeRefCounting : Pass<"async-runtime-ref-counting"> {
let summary = "Automatic reference counting for Async runtime operations";
let description = [{

View File

@ -30,6 +30,7 @@
namespace mlir {
#define GEN_PASS_DEF_ASYNCTOASYNCRUNTIME
#define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIME
#include "mlir/Dialect/Async/Passes.h.inc"
} // namespace mlir
@ -51,6 +52,17 @@ public:
} // namespace
namespace {
class AsyncFuncToAsyncRuntimePass
: public impl::AsyncFuncToAsyncRuntimeBase<AsyncFuncToAsyncRuntimePass> {
public:
AsyncFuncToAsyncRuntimePass() = default;
void runOnOperation() override;
};
} // namespace
/// Function targeted for coroutine transformation has two additional blocks at
/// the end: coroutine cleanup and coroutine suspension.
///
@ -84,6 +96,9 @@ struct CoroMachinery {
};
} // namespace
using FuncCoroMapPtr =
std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;
/// Utility to partially update the regular function CFG to the coroutine CFG
/// compatible with LLVM coroutines switched-resume lowering using
/// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
@ -399,9 +414,8 @@ namespace {
class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> {
public:
AsyncFuncOpLowering(MLIRContext *ctx,
llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
: OpConversionPattern<async::FuncOp>(ctx), coros(coros) {}
AsyncFuncOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
: OpConversionPattern<async::FuncOp>(ctx), coros_(coros) {}
LogicalResult
matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
@ -423,7 +437,7 @@ public:
newFuncOp.end());
CoroMachinery coro = setupCoroMachinery(newFuncOp);
coros[newFuncOp] = coro;
(*coros_)[newFuncOp] = coro;
// no initial suspend, we should hot-start
rewriter.eraseOp(op);
@ -431,7 +445,7 @@ public:
}
private:
llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
FuncCoroMapPtr coros_;
};
//===----------------------------------------------------------------------===//
@ -458,16 +472,15 @@ public:
class AsyncReturnOpLowering : public OpConversionPattern<async::ReturnOp> {
public:
AsyncReturnOpLowering(MLIRContext *ctx,
llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
: OpConversionPattern<async::ReturnOp>(ctx), coros(coros) {}
AsyncReturnOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
: OpConversionPattern<async::ReturnOp>(ctx), coros_(coros) {}
LogicalResult
matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto func = op->template getParentOfType<func::FuncOp>();
auto funcCoro = coros.find(func);
if (funcCoro == coros.end())
auto funcCoro = coros_->find(func);
if (funcCoro == coros_->end())
return rewriter.notifyMatchFailure(
op, "operation is not inside the async coroutine function");
@ -494,7 +507,7 @@ public:
}
private:
llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
FuncCoroMapPtr coros_;
};
} // namespace
@ -509,9 +522,10 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
using AwaitAdaptor = typename AwaitType::Adaptor;
public:
AwaitOpLoweringBase(MLIRContext *ctx,
llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
: OpConversionPattern<AwaitType>(ctx), coros(coros) {}
AwaitOpLoweringBase(MLIRContext *ctx, FuncCoroMapPtr coros,
bool should_lower_blocking_wait)
: OpConversionPattern<AwaitType>(ctx), coros_(coros),
should_lower_blocking_wait_(should_lower_blocking_wait) {}
LogicalResult
matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
@ -521,16 +535,20 @@ public:
if (!op.getOperand().getType().template isa<AwaitableType>())
return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
// Check if await operation is inside the outlined coroutine function.
// Check if await operation is inside the coroutine function.
auto func = op->template getParentOfType<func::FuncOp>();
auto funcCoro = coros.find(func);
const bool isInCoroutine = funcCoro != coros.end();
auto funcCoro = coros_->find(func);
const bool isInCoroutine = funcCoro != coros_->end();
Location loc = op->getLoc();
Value operand = adaptor.getOperand();
Type i1 = rewriter.getI1Type();
// Delay lowering to block wait in case await op is inside async.execute
if (!isInCoroutine && !should_lower_blocking_wait_)
return failure();
// Inside regular functions we use the blocking wait operation to wait for
// the async object (token, value or group) to become available.
if (!isInCoroutine) {
@ -602,7 +620,8 @@ public:
}
private:
llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
FuncCoroMapPtr coros_;
bool should_lower_blocking_wait_;
};
/// Lowering for `async.await` with a token operand.
@ -645,17 +664,16 @@ public:
class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
public:
YieldOpLowering(MLIRContext *ctx,
const llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
: OpConversionPattern<async::YieldOp>(ctx), coros(coros) {}
YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
: OpConversionPattern<async::YieldOp>(ctx), coros_(coros) {}
LogicalResult
matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Check if yield operation is inside the async coroutine function.
auto func = op->template getParentOfType<func::FuncOp>();
auto funcCoro = coros.find(func);
if (funcCoro == coros.end())
auto funcCoro = coros_->find(func);
if (funcCoro == coros_->end())
return rewriter.notifyMatchFailure(
op, "operation is not inside the async coroutine function");
@ -682,7 +700,7 @@ public:
}
private:
const llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
FuncCoroMapPtr coros_;
};
//===----------------------------------------------------------------------===//
@ -691,17 +709,16 @@ private:
class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
public:
AssertOpLowering(MLIRContext *ctx,
llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
: OpConversionPattern<cf::AssertOp>(ctx), coros(coros) {}
AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
: OpConversionPattern<cf::AssertOp>(ctx), coros_(coros) {}
LogicalResult
matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Check if assert operation is inside the async coroutine function.
auto func = op->template getParentOfType<func::FuncOp>();
auto funcCoro = coros.find(func);
if (funcCoro == coros.end())
auto funcCoro = coros_->find(func);
if (funcCoro == coros_->end())
return rewriter.notifyMatchFailure(
op, "operation is not inside the async coroutine function");
@ -721,7 +738,7 @@ public:
}
private:
llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
FuncCoroMapPtr coros_;
};
//===----------------------------------------------------------------------===//
@ -730,22 +747,23 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
SymbolTable symbolTable(module);
// Functions with coroutine CFG setups, which are results of outlining
// `async.execute` body regions and converting async.func.
llvm::DenseMap<func::FuncOp, CoroMachinery> coros;
// `async.execute` body regions
FuncCoroMapPtr coros =
std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
module.walk([&](ExecuteOp execute) {
coros.insert(outlineExecuteOp(symbolTable, execute));
coros->insert(outlineExecuteOp(symbolTable, execute));
});
LLVM_DEBUG({
llvm::dbgs() << "Outlined " << coros.size()
llvm::dbgs() << "Outlined " << coros->size()
<< " functions built from async.execute operations\n";
});
// Returns true if operation is inside the coroutine.
auto isInCoroutine = [&](Operation *op) -> bool {
auto parentFunc = op->getParentOfType<func::FuncOp>();
return coros.find(parentFunc) != coros.end();
return coros->find(parentFunc) != coros->end();
};
// Lower async operations to async.runtime operations.
@ -762,22 +780,18 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
// types for async.runtime operations.
asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
// Lower async.func to func.func with coroutine cfg.
asyncPatterns.add<AsyncCallOpLowering>(ctx);
asyncPatterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering,
AwaitAllOpLowering, YieldOpLowering>(ctx, coros);
asyncPatterns
.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
ctx, coros, /*should_lower_blocking_wait=*/true);
// Lower assertions to conditional branches into error blocks.
asyncPatterns.add<AssertOpLowering>(ctx, coros);
asyncPatterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
// All high level async operations must be lowered to the runtime operations.
ConversionTarget runtimeTarget(*ctx);
runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp,
async::FuncOp, async::CallOp, async::ReturnOp>();
runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
// Decide if structured control flow has to be lowered to branch-based CFG.
runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
@ -795,7 +809,7 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
[&](cf::AssertOp op) -> bool {
auto func = op->getParentOfType<func::FuncOp>();
return coros.find(func) == coros.end();
return coros->find(func) == coros->end();
});
if (failed(applyPartialConversion(module, runtimeTarget,
@ -805,6 +819,59 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
}
}
//===----------------------------------------------------------------------===//
void mlir::populateAsyncFuncToAsyncRuntimeConversionPatterns(
RewritePatternSet &patterns, ConversionTarget &target) {
// Functions with coroutine CFG setups, which are results of converting
// async.func.
FuncCoroMapPtr coros =
std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
MLIRContext *ctx = patterns.getContext();
// Lower async.func to func.func with coroutine cfg.
patterns.add<AsyncCallOpLowering>(ctx);
patterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
patterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
ctx, coros, /*should_lower_blocking_wait=*/false);
patterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
target.addDynamicallyLegalOp<AwaitOp, AwaitAllOp, YieldOp, cf::AssertOp>(
[coros](Operation *op) {
auto func = op->getParentOfType<func::FuncOp>();
return coros->find(func) == coros->end();
});
}
void AsyncFuncToAsyncRuntimePass::runOnOperation() {
ModuleOp module = getOperation();
// Lower async operations to async.runtime operations.
MLIRContext *ctx = module->getContext();
RewritePatternSet asyncPatterns(ctx);
ConversionTarget runtimeTarget(*ctx);
// Lower async.func to func.func with coroutine cfg.
populateAsyncFuncToAsyncRuntimeConversionPatterns(asyncPatterns,
runtimeTarget);
runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
runtimeTarget.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
cf::BranchOp, cf::CondBranchOp>();
if (failed(applyPartialConversion(module, runtimeTarget,
std::move(asyncPatterns)))) {
signalPassFailure();
return;
}
}
std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
return std::make_unique<AsyncToAsyncRuntimePass>();
}
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createAsyncFuncToAsyncRuntimePass() {
return std::make_unique<AsyncFuncToAsyncRuntimePass>();
}

View File

@ -1,5 +1,5 @@
// RUN: mlir-opt %s -split-input-file -async-to-async-runtime \
// RUN: | FileCheck %s --dump-input=always
// RUN: mlir-opt %s -split-input-file -async-func-to-async-runtime \
// RUN: -async-to-async-runtime | FileCheck %s --dump-input=always
// CHECK-LABEL: @execute_no_async_args
func.func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -pass-pipeline="builtin.module(async-to-async-runtime,func.func(async-runtime-ref-counting,async-runtime-ref-counting-opt),convert-async-to-llvm,func.func(convert-linalg-to-loops,convert-scf-to-cf),convert-linalg-to-llvm,convert-vector-to-llvm,func.func(convert-arith-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts)" \
// RUN: mlir-opt %s -pass-pipeline="builtin.module(async-func-to-async-runtime,async-to-async-runtime,func.func(async-runtime-ref-counting,async-runtime-ref-counting-opt),convert-async-to-llvm,func.func(convert-linalg-to-loops,convert-scf-to-cf),convert-linalg-to-llvm,convert-vector-to-llvm,func.func(convert-arith-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts)" \
// RUN: | mlir-cpu-runner \
// RUN: -e main -entry-point-result=void -O0 \
// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext \