mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-11 17:08:42 +00:00
Add async_funcs_only option to AsyncToAsyncRuntime pass
This change adds the async_funcs_only option to AsyncToAsyncRuntimePass. The goal is to convert async functions to regular functions in the early stages of the compilation pipeline. Differential Revision: https://reviews.llvm.org/D138611
This commit is contained in:
parent
a4c466766d
commit
6cca6b9ab9
@ -17,6 +17,7 @@
|
||||
|
||||
namespace mlir {
|
||||
class ModuleOp;
|
||||
class ConversionTarget;
|
||||
|
||||
#define GEN_PASS_DECL
|
||||
#include "mlir/Dialect/Async/Passes.h.inc"
|
||||
@ -27,6 +28,11 @@ std::unique_ptr<Pass> createAsyncParallelForPass(bool asyncDispatch,
|
||||
int32_t numWorkerThreads,
|
||||
int32_t minTaskSize);
|
||||
|
||||
/// Adds to `patterns` the conversion patterns used by the
/// `async-func-to-async-runtime` pass and registers on `target` the dynamic
/// legality rule for await/yield/assert operations that goes with them.
void populateAsyncFuncToAsyncRuntimeConversionPatterns(
|
||||
RewritePatternSet &patterns, ConversionTarget &target);
|
||||
|
||||
/// Creates the module-level pass that lowers `async.func` operations
/// (command-line name `async-func-to-async-runtime`).
std::unique_ptr<OperationPass<ModuleOp>> createAsyncFuncToAsyncRuntimePass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();
|
||||
|
||||
std::unique_ptr<Pass> createAsyncRuntimeRefCountingPass();
|
||||
|
@ -41,12 +41,19 @@ def AsyncParallelFor : Pass<"async-parallel-for", "ModuleOp"> {
|
||||
}
|
||||
|
||||
def AsyncToAsyncRuntime : Pass<"async-to-async-runtime", "ModuleOp"> {
|
||||
let summary = "Lower high level async operations (e.g. async.execute) to the"
|
||||
"explicit async.runtime and async.coro operations";
|
||||
let summary = "Lower all high level async operations (e.g. async.execute) to"
|
||||
"the explicit async.runtime and async.coro operations";
|
||||
let constructor = "mlir::createAsyncToAsyncRuntimePass()";
|
||||
let dependentDialects = ["async::AsyncDialect", "func::FuncDialect"];
|
||||
}
|
||||
|
||||
// Declarative spec for the `async-func-to-async-runtime` pass. It is anchored
// on ModuleOp and is instantiated through
// mlir::createAsyncFuncToAsyncRuntimePass().
def AsyncFuncToAsyncRuntime : Pass<"async-func-to-async-runtime", "ModuleOp"> {
|
||||
let summary = "Lower async.func operations to the explicit async.runtime and"
|
||||
"async.coro operations";
|
||||
let constructor = "mlir::createAsyncFuncToAsyncRuntimePass()";
|
||||
// Dialects that must be loaded before the pass runs.
let dependentDialects = ["async::AsyncDialect", "func::FuncDialect"];
|
||||
}
|
||||
|
||||
def AsyncRuntimeRefCounting : Pass<"async-runtime-ref-counting"> {
|
||||
let summary = "Automatic reference counting for Async runtime operations";
|
||||
let description = [{
|
||||
|
@ -30,6 +30,7 @@
|
||||
|
||||
namespace mlir {
|
||||
#define GEN_PASS_DEF_ASYNCTOASYNCRUNTIME
|
||||
#define GEN_PASS_DEF_ASYNCFUNCTOASYNCRUNTIME
|
||||
#include "mlir/Dialect/Async/Passes.h.inc"
|
||||
} // namespace mlir
|
||||
|
||||
@ -51,6 +52,17 @@ public:
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
|
||||
/// Module pass returned by `mlir::createAsyncFuncToAsyncRuntimePass()`.
/// Derives from the TableGen-generated `impl::AsyncFuncToAsyncRuntimeBase`;
/// the conversion itself lives in the out-of-line `runOnOperation()`.
class AsyncFuncToAsyncRuntimePass
|
||||
: public impl::AsyncFuncToAsyncRuntimeBase<AsyncFuncToAsyncRuntimePass> {
|
||||
public:
|
||||
AsyncFuncToAsyncRuntimePass() = default;
|
||||
// Entry point invoked on the module; defined later in this file.
void runOnOperation() override;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Function targeted for coroutine transformation has two additional blocks at
|
||||
/// the end: coroutine cleanup and coroutine suspension.
|
||||
///
|
||||
@ -84,6 +96,9 @@ struct CoroMachinery {
|
||||
};
|
||||
} // namespace
|
||||
|
||||
/// Shared handle to the map from a `func::FuncOp` to the `CoroMachinery` set
/// up for it. The lowering patterns and the legality callbacks each hold a
/// copy, so the map stays alive after the function that created it returns.
using FuncCoroMapPtr =
|
||||
std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;
|
||||
|
||||
/// Utility to partially update the regular function CFG to the coroutine CFG
|
||||
/// compatible with LLVM coroutines switched-resume lowering using
|
||||
/// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
|
||||
@ -399,9 +414,8 @@ namespace {
|
||||
|
||||
class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> {
|
||||
public:
|
||||
AsyncFuncOpLowering(MLIRContext *ctx,
|
||||
llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
|
||||
: OpConversionPattern<async::FuncOp>(ctx), coros(coros) {}
|
||||
AsyncFuncOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
|
||||
: OpConversionPattern<async::FuncOp>(ctx), coros_(coros) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
|
||||
@ -423,7 +437,7 @@ public:
|
||||
newFuncOp.end());
|
||||
|
||||
CoroMachinery coro = setupCoroMachinery(newFuncOp);
|
||||
coros[newFuncOp] = coro;
|
||||
(*coros_)[newFuncOp] = coro;
|
||||
// no initial suspend, we should hot-start
|
||||
|
||||
rewriter.eraseOp(op);
|
||||
@ -431,7 +445,7 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
|
||||
FuncCoroMapPtr coros_;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -458,16 +472,15 @@ public:
|
||||
|
||||
class AsyncReturnOpLowering : public OpConversionPattern<async::ReturnOp> {
|
||||
public:
|
||||
AsyncReturnOpLowering(MLIRContext *ctx,
|
||||
llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
|
||||
: OpConversionPattern<async::ReturnOp>(ctx), coros(coros) {}
|
||||
AsyncReturnOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
|
||||
: OpConversionPattern<async::ReturnOp>(ctx), coros_(coros) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto func = op->template getParentOfType<func::FuncOp>();
|
||||
auto funcCoro = coros.find(func);
|
||||
if (funcCoro == coros.end())
|
||||
auto funcCoro = coros_->find(func);
|
||||
if (funcCoro == coros_->end())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "operation is not inside the async coroutine function");
|
||||
|
||||
@ -494,7 +507,7 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
|
||||
FuncCoroMapPtr coros_;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
@ -509,9 +522,10 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
|
||||
using AwaitAdaptor = typename AwaitType::Adaptor;
|
||||
|
||||
public:
|
||||
AwaitOpLoweringBase(MLIRContext *ctx,
|
||||
llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
|
||||
: OpConversionPattern<AwaitType>(ctx), coros(coros) {}
|
||||
AwaitOpLoweringBase(MLIRContext *ctx, FuncCoroMapPtr coros,
|
||||
bool should_lower_blocking_wait)
|
||||
: OpConversionPattern<AwaitType>(ctx), coros_(coros),
|
||||
should_lower_blocking_wait_(should_lower_blocking_wait) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
|
||||
@ -521,16 +535,20 @@ public:
|
||||
if (!op.getOperand().getType().template isa<AwaitableType>())
|
||||
return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
|
||||
|
||||
// Check if await operation is inside the outlined coroutine function.
|
||||
// Check if await operation is inside the coroutine function.
|
||||
auto func = op->template getParentOfType<func::FuncOp>();
|
||||
auto funcCoro = coros.find(func);
|
||||
const bool isInCoroutine = funcCoro != coros.end();
|
||||
auto funcCoro = coros_->find(func);
|
||||
const bool isInCoroutine = funcCoro != coros_->end();
|
||||
|
||||
Location loc = op->getLoc();
|
||||
Value operand = adaptor.getOperand();
|
||||
|
||||
Type i1 = rewriter.getI1Type();
|
||||
|
||||
// Delay lowering to block wait in case await op is inside async.execute
|
||||
if (!isInCoroutine && !should_lower_blocking_wait_)
|
||||
return failure();
|
||||
|
||||
// Inside regular functions we use the blocking wait operation to wait for
|
||||
// the async object (token, value or group) to become available.
|
||||
if (!isInCoroutine) {
|
||||
@ -602,7 +620,8 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
|
||||
FuncCoroMapPtr coros_;
|
||||
bool should_lower_blocking_wait_;
|
||||
};
|
||||
|
||||
/// Lowering for `async.await` with a token operand.
|
||||
@ -645,17 +664,16 @@ public:
|
||||
|
||||
class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
|
||||
public:
|
||||
YieldOpLowering(MLIRContext *ctx,
|
||||
const llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
|
||||
: OpConversionPattern<async::YieldOp>(ctx), coros(coros) {}
|
||||
YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
|
||||
: OpConversionPattern<async::YieldOp>(ctx), coros_(coros) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Check if yield operation is inside the async coroutine function.
|
||||
auto func = op->template getParentOfType<func::FuncOp>();
|
||||
auto funcCoro = coros.find(func);
|
||||
if (funcCoro == coros.end())
|
||||
auto funcCoro = coros_->find(func);
|
||||
if (funcCoro == coros_->end())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "operation is not inside the async coroutine function");
|
||||
|
||||
@ -682,7 +700,7 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
const llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
|
||||
FuncCoroMapPtr coros_;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -691,17 +709,16 @@ private:
|
||||
|
||||
class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
|
||||
public:
|
||||
AssertOpLowering(MLIRContext *ctx,
|
||||
llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
|
||||
: OpConversionPattern<cf::AssertOp>(ctx), coros(coros) {}
|
||||
AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
|
||||
: OpConversionPattern<cf::AssertOp>(ctx), coros_(coros) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Check if assert operation is inside the async coroutine function.
|
||||
auto func = op->template getParentOfType<func::FuncOp>();
|
||||
auto funcCoro = coros.find(func);
|
||||
if (funcCoro == coros.end())
|
||||
auto funcCoro = coros_->find(func);
|
||||
if (funcCoro == coros_->end())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "operation is not inside the async coroutine function");
|
||||
|
||||
@ -721,7 +738,7 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
|
||||
FuncCoroMapPtr coros_;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -730,22 +747,23 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
|
||||
SymbolTable symbolTable(module);
|
||||
|
||||
// Functions with coroutine CFG setups, which are results of outlining
|
||||
// `async.execute` body regions and converting async.func.
|
||||
llvm::DenseMap<func::FuncOp, CoroMachinery> coros;
|
||||
// `async.execute` body regions
|
||||
FuncCoroMapPtr coros =
|
||||
std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
|
||||
|
||||
module.walk([&](ExecuteOp execute) {
|
||||
coros.insert(outlineExecuteOp(symbolTable, execute));
|
||||
coros->insert(outlineExecuteOp(symbolTable, execute));
|
||||
});
|
||||
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "Outlined " << coros.size()
|
||||
llvm::dbgs() << "Outlined " << coros->size()
|
||||
<< " functions built from async.execute operations\n";
|
||||
});
|
||||
|
||||
// Returns true if operation is inside the coroutine.
|
||||
auto isInCoroutine = [&](Operation *op) -> bool {
|
||||
auto parentFunc = op->getParentOfType<func::FuncOp>();
|
||||
return coros.find(parentFunc) != coros.end();
|
||||
return coros->find(parentFunc) != coros->end();
|
||||
};
|
||||
|
||||
// Lower async operations to async.runtime operations.
|
||||
@ -762,22 +780,18 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
|
||||
// types for async.runtime operations.
|
||||
asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
|
||||
|
||||
// Lower async.func to func.func with coroutine cfg.
|
||||
asyncPatterns.add<AsyncCallOpLowering>(ctx);
|
||||
asyncPatterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
|
||||
|
||||
asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering,
|
||||
AwaitAllOpLowering, YieldOpLowering>(ctx, coros);
|
||||
asyncPatterns
|
||||
.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
|
||||
ctx, coros, /*should_lower_blocking_wait=*/true);
|
||||
|
||||
// Lower assertions to conditional branches into error blocks.
|
||||
asyncPatterns.add<AssertOpLowering>(ctx, coros);
|
||||
asyncPatterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
|
||||
|
||||
// All high level async operations must be lowered to the runtime operations.
|
||||
ConversionTarget runtimeTarget(*ctx);
|
||||
runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
|
||||
runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
|
||||
runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp,
|
||||
async::FuncOp, async::CallOp, async::ReturnOp>();
|
||||
runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
|
||||
|
||||
// Decide if structured control flow has to be lowered to branch-based CFG.
|
||||
runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
|
||||
@ -795,7 +809,7 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
|
||||
runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
|
||||
[&](cf::AssertOp op) -> bool {
|
||||
auto func = op->getParentOfType<func::FuncOp>();
|
||||
return coros.find(func) == coros.end();
|
||||
return coros->find(func) == coros->end();
|
||||
});
|
||||
|
||||
if (failed(applyPartialConversion(module, runtimeTarget,
|
||||
@ -805,6 +819,59 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Registers the patterns needed to lower `async.func` and the operations
/// inside it, and tells `target` when await/yield/assert operations may stay.
///
/// \param patterns  Pattern set that receives the lowering patterns; its
///                  MLIRContext is used to construct them.
/// \param target    Conversion target that receives the dynamic legality rule
///                  for AwaitOp, AwaitAllOp, YieldOp and cf::AssertOp.
void mlir::populateAsyncFuncToAsyncRuntimeConversionPatterns(
|
||||
RewritePatternSet &patterns, ConversionTarget &target) {
|
||||
// Functions with coroutine CFG setups, which are results of converting
|
||||
// async.func.
// The map starts empty and is filled in by AsyncFuncOpLowering as it
// rewrites functions. It is held by shared_ptr because the patterns and the
// legality callback below keep using it after this function returns.
|
||||
FuncCoroMapPtr coros =
|
||||
std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
|
||||
MLIRContext *ctx = patterns.getContext();
|
||||
// Lower async.func to func.func with coroutine cfg.
|
||||
patterns.add<AsyncCallOpLowering>(ctx);
|
||||
patterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
|
||||
|
||||
// With should_lower_blocking_wait=false the await patterns fail to match
// for awaits outside of a recorded coroutine function, so those awaits are
// left as they are instead of being lowered to a blocking wait.
patterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
|
||||
ctx, coros, /*should_lower_blocking_wait=*/false);
|
||||
patterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
|
||||
|
||||
// These operations are legal (left untouched) unless their enclosing
// func.func has an entry in `coros`. `coros` is captured by value so the
// callback shares ownership of the map.
target.addDynamicallyLegalOp<AwaitOp, AwaitAllOp, YieldOp, cf::AssertOp>(
|
||||
[coros](Operation *op) {
|
||||
auto func = op->getParentOfType<func::FuncOp>();
|
||||
return coros->find(func) == coros->end();
|
||||
});
|
||||
}
|
||||
|
||||
/// Runs a partial conversion over the module in which `async.func`,
/// `async.call` and `async.return` are illegal. Signals pass failure if the
/// conversion does not succeed.
void AsyncFuncToAsyncRuntimePass::runOnOperation() {
|
||||
ModuleOp module = getOperation();
|
||||
|
||||
// Lower async operations to async.runtime operations.
|
||||
MLIRContext *ctx = module->getContext();
|
||||
RewritePatternSet asyncPatterns(ctx);
|
||||
ConversionTarget runtimeTarget(*ctx);
|
||||
|
||||
// Lower async.func to func.func with coroutine cfg.
// This call also installs the dynamic legality rule for await, yield and
// assert operations on `runtimeTarget`.
|
||||
populateAsyncFuncToAsyncRuntimeConversionPatterns(asyncPatterns,
|
||||
runtimeTarget);
|
||||
|
||||
// The async and func dialects are legal as a whole; the three async.func
// related operations are then explicitly marked illegal.
runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
|
||||
runtimeTarget.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
|
||||
|
||||
// NOTE(review): presumably these are the operations the lowering patterns
// emit outside the async/func dialects -- confirm against the patterns.
runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
|
||||
cf::BranchOp, cf::CondBranchOp>();
|
||||
|
||||
if (failed(applyPartialConversion(module, runtimeTarget,
|
||||
std::move(asyncPatterns)))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// Factory for the `async-to-async-runtime` pass; returns a new
/// AsyncToAsyncRuntimePass instance owned by the caller.
std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
|
||||
return std::make_unique<AsyncToAsyncRuntimePass>();
|
||||
}
|
||||
|
||||
/// Factory for the `async-func-to-async-runtime` pass; returns a new
/// AsyncFuncToAsyncRuntimePass instance owned by the caller.
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::createAsyncFuncToAsyncRuntimePass() {
|
||||
return std::make_unique<AsyncFuncToAsyncRuntimePass>();
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
// RUN: mlir-opt %s -split-input-file -async-to-async-runtime \
|
||||
// RUN: | FileCheck %s --dump-input=always
|
||||
// RUN: mlir-opt %s -split-input-file -async-func-to-async-runtime \
|
||||
// RUN: -async-to-async-runtime | FileCheck %s --dump-input=always
|
||||
|
||||
// CHECK-LABEL: @execute_no_async_args
|
||||
func.func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt %s -pass-pipeline="builtin.module(async-to-async-runtime,func.func(async-runtime-ref-counting,async-runtime-ref-counting-opt),convert-async-to-llvm,func.func(convert-linalg-to-loops,convert-scf-to-cf),convert-linalg-to-llvm,convert-vector-to-llvm,func.func(convert-arith-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts)" \
|
||||
// RUN: mlir-opt %s -pass-pipeline="builtin.module(async-func-to-async-runtime,async-to-async-runtime,func.func(async-runtime-ref-counting,async-runtime-ref-counting-opt),convert-async-to-llvm,func.func(convert-linalg-to-loops,convert-scf-to-cf),convert-linalg-to-llvm,convert-vector-to-llvm,func.func(convert-arith-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts)" \
|
||||
// RUN: | mlir-cpu-runner \
|
||||
// RUN: -e main -entry-point-result=void -O0 \
|
||||
// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext \
|
||||
|
Loading…
Reference in New Issue
Block a user