mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-26 23:21:11 +00:00
[MLIR] Support interrupting AffineExpr walks (#74792)
Support WalkResult for AffineExpr walk and support interrupting walks along the lines of Operation::walk. This allows interrupted walks when a condition is met. Also, switch from std::function to llvm::function_ref for the walk function.
This commit is contained in:
parent
5fd18bdef9
commit
c1eef483b2
@ -14,6 +14,7 @@
|
||||
#ifndef MLIR_IR_AFFINEEXPR_H
|
||||
#define MLIR_IR_AFFINEEXPR_H
|
||||
|
||||
#include "mlir/IR/Visitors.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/DenseMapInfo.h"
|
||||
#include "llvm/ADT/Hashing.h"
|
||||
@ -123,8 +124,13 @@ public:
|
||||
/// Return true if the affine expression involves AffineSymbolExpr `position`.
|
||||
bool isFunctionOfSymbol(unsigned position) const;
|
||||
|
||||
/// Walk all of the AffineExpr's in this expression in postorder.
|
||||
void walk(std::function<void(AffineExpr)> callback) const;
|
||||
/// Walk all of the AffineExpr's in this expression in postorder. This allows
|
||||
/// a lambda walk function that can either return `void` or a WalkResult. With
|
||||
/// a WalkResult, interrupting is supported.
|
||||
template <typename FnT, typename RetT = detail::walkResultType<FnT>>
|
||||
RetT walk(FnT &&callback) const {
|
||||
return walk<RetT>(*this, callback);
|
||||
}
|
||||
|
||||
/// This method substitutes any uses of dimensions and symbols (e.g.
|
||||
/// dim#0 with dimReplacements[0]) and returns the modified expression tree.
|
||||
@ -202,6 +208,15 @@ public:
|
||||
|
||||
protected:
|
||||
ImplType *expr{nullptr};
|
||||
|
||||
private:
|
||||
/// A trampoline for the templated non-static AffineExpr::walk method to
|
||||
/// dispatch lambda `callback`'s of either a void result type or a
|
||||
/// WalkResult type. Walk all of the AffineExprs in `e` in postorder. Users
|
||||
/// should use the regular (non-static) `walk` method.
|
||||
template <typename WalkRetTy>
|
||||
static WalkRetTy walk(AffineExpr e,
|
||||
function_ref<WalkRetTy(AffineExpr)> callback);
|
||||
};
|
||||
|
||||
/// Affine binary operation expression. An affine binary operation could be an
|
||||
|
@ -30,6 +30,9 @@ namespace mlir {
|
||||
/// functions in your class. This class is defined in terms of statically
|
||||
/// resolved overloading, not virtual functions.
|
||||
///
|
||||
/// The visitor is templated on its return type (`RetTy`). With a WalkResult
|
||||
/// return type, the visitor supports interrupting walks.
|
||||
///
|
||||
/// For example, here is a visitor that counts the number of for AffineDimExprs
|
||||
/// in an AffineExpr.
|
||||
///
|
||||
@ -65,7 +68,6 @@ namespace mlir {
|
||||
/// virtual function call overhead. Defining and using a AffineExprVisitor is
|
||||
/// just as efficient as having your own switch instruction over the instruction
|
||||
/// opcode.
|
||||
|
||||
template <typename SubClass, typename RetTy>
|
||||
class AffineExprVisitorBase {
|
||||
public:
|
||||
@ -136,6 +138,8 @@ public:
|
||||
RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
|
||||
};
|
||||
|
||||
/// See documentation for AffineExprVisitorBase. This visitor supports
|
||||
/// interrupting walks when a `WalkResult` is used for `RetTy`.
|
||||
template <typename SubClass, typename RetTy = void>
|
||||
class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
|
||||
//===--------------------------------------------------------------------===//
|
||||
@ -150,27 +154,52 @@ public:
|
||||
switch (expr.getKind()) {
|
||||
case AffineExprKind::Add: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
if constexpr (std::is_same<RetTy, WalkResult>::value) {
|
||||
if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
|
||||
return WalkResult::interrupt();
|
||||
} else {
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
}
|
||||
return self->visitAddExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::Mul: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
if constexpr (std::is_same<RetTy, WalkResult>::value) {
|
||||
if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
|
||||
return WalkResult::interrupt();
|
||||
} else {
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
}
|
||||
return self->visitMulExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::Mod: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
if constexpr (std::is_same<RetTy, WalkResult>::value) {
|
||||
if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
|
||||
return WalkResult::interrupt();
|
||||
} else {
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
}
|
||||
return self->visitModExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::FloorDiv: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
if constexpr (std::is_same<RetTy, WalkResult>::value) {
|
||||
if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
|
||||
return WalkResult::interrupt();
|
||||
} else {
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
}
|
||||
return self->visitFloorDivExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::CeilDiv: {
|
||||
auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
if constexpr (std::is_same<RetTy, WalkResult>::value) {
|
||||
if (walkOperandsPostOrder(binOpExpr).wasInterrupted())
|
||||
return WalkResult::interrupt();
|
||||
} else {
|
||||
walkOperandsPostOrder(binOpExpr);
|
||||
}
|
||||
return self->visitCeilDivExpr(binOpExpr);
|
||||
}
|
||||
case AffineExprKind::Constant:
|
||||
@ -186,8 +215,19 @@ public:
|
||||
private:
|
||||
// Walk the operands - each operand is itself walked in post order.
|
||||
RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
|
||||
walkPostOrder(expr.getLHS());
|
||||
walkPostOrder(expr.getRHS());
|
||||
if constexpr (std::is_same<RetTy, WalkResult>::value) {
|
||||
if (walkPostOrder(expr.getLHS()).wasInterrupted())
|
||||
return WalkResult::interrupt();
|
||||
} else {
|
||||
walkPostOrder(expr.getLHS());
|
||||
}
|
||||
if constexpr (std::is_same<RetTy, WalkResult>::value) {
|
||||
if (walkPostOrder(expr.getLHS()).wasInterrupted())
|
||||
return WalkResult::interrupt();
|
||||
return WalkResult::advance();
|
||||
} else {
|
||||
return walkPostOrder(expr.getRHS());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1561,22 +1561,21 @@ static LogicalResult getTileSizePos(
|
||||
/// memref<4x?xf32, #map0> ==> memref<4x?x?xf32>
|
||||
static bool
|
||||
isNormalizedMemRefDynamicDim(unsigned dim, AffineMap layoutMap,
|
||||
SmallVectorImpl<unsigned> &inMemrefTypeDynDims,
|
||||
MLIRContext *context) {
|
||||
bool isDynamicDim = false;
|
||||
SmallVectorImpl<unsigned> &inMemrefTypeDynDims) {
|
||||
AffineExpr expr = layoutMap.getResults()[dim];
|
||||
// Check if affine expr of the dimension includes dynamic dimension of input
|
||||
// memrefType.
|
||||
expr.walk([&inMemrefTypeDynDims, &isDynamicDim, &context](AffineExpr e) {
|
||||
if (isa<AffineDimExpr>(e)) {
|
||||
for (unsigned dm : inMemrefTypeDynDims) {
|
||||
if (e == getAffineDimExpr(dm, context)) {
|
||||
isDynamicDim = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
return isDynamicDim;
|
||||
MLIRContext *context = layoutMap.getContext();
|
||||
return expr
|
||||
.walk([&](AffineExpr e) {
|
||||
if (isa<AffineDimExpr>(e) &&
|
||||
llvm::any_of(inMemrefTypeDynDims, [&](unsigned dim) {
|
||||
return e == getAffineDimExpr(dim, context);
|
||||
}))
|
||||
return WalkResult::interrupt();
|
||||
return WalkResult::advance();
|
||||
})
|
||||
.wasInterrupted();
|
||||
}
|
||||
|
||||
/// Create affine expr to calculate dimension size for a tiled-layout map.
|
||||
@ -1792,29 +1791,28 @@ MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) {
|
||||
MLIRContext *context = memrefType.getContext();
|
||||
for (unsigned d = 0; d < newRank; ++d) {
|
||||
// Check if this dimension is dynamic.
|
||||
bool isDynDim =
|
||||
isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims, context);
|
||||
if (isDynDim) {
|
||||
if (bool isDynDim =
|
||||
isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims)) {
|
||||
newShape[d] = ShapedType::kDynamic;
|
||||
} else {
|
||||
// The lower bound for the shape is always zero.
|
||||
std::optional<int64_t> ubConst = fac.getConstantBound64(BoundType::UB, d);
|
||||
// For a static memref and an affine map with no symbols, this is
|
||||
// always bounded. However, when we have symbols, we may not be able to
|
||||
// obtain a constant upper bound. Also, mapping to a negative space is
|
||||
// invalid for normalization.
|
||||
if (!ubConst.has_value() || *ubConst < 0) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "can't normalize map due to unknown/invalid upper bound");
|
||||
return memrefType;
|
||||
}
|
||||
// If dimension of new memrefType is dynamic, the value is -1.
|
||||
newShape[d] = *ubConst + 1;
|
||||
continue;
|
||||
}
|
||||
// The lower bound for the shape is always zero.
|
||||
std::optional<int64_t> ubConst = fac.getConstantBound64(BoundType::UB, d);
|
||||
// For a static memref and an affine map with no symbols, this is
|
||||
// always bounded. However, when we have symbols, we may not be able to
|
||||
// obtain a constant upper bound. Also, mapping to a negative space is
|
||||
// invalid for normalization.
|
||||
if (!ubConst.has_value() || *ubConst < 0) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "can't normalize map due to unknown/invalid upper bound");
|
||||
return memrefType;
|
||||
}
|
||||
// If dimension of new memrefType is dynamic, the value is -1.
|
||||
newShape[d] = *ubConst + 1;
|
||||
}
|
||||
|
||||
// Create the new memref type after trivializing the old layout map.
|
||||
MemRefType newMemRefType =
|
||||
auto newMemRefType =
|
||||
MemRefType::Builder(memrefType)
|
||||
.setShape(newShape)
|
||||
.setLayout(AffineMapAttr::get(
|
||||
|
@ -26,22 +26,37 @@ MLIRContext *AffineExpr::getContext() const { return expr->context; }
|
||||
|
||||
AffineExprKind AffineExpr::getKind() const { return expr->kind; }
|
||||
|
||||
/// Walk all of the AffineExprs in this subgraph in postorder.
|
||||
void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
|
||||
struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> {
|
||||
std::function<void(AffineExpr)> callback;
|
||||
/// Walk all of the AffineExprs in `e` in postorder. This is a private factory
|
||||
/// method to help handle lambda walk functions. Users should use the regular
|
||||
/// (non-static) `walk` method.
|
||||
template <typename WalkRetTy>
|
||||
WalkRetTy mlir::AffineExpr::walk(AffineExpr e,
|
||||
function_ref<WalkRetTy(AffineExpr)> callback) {
|
||||
struct AffineExprWalker
|
||||
: public AffineExprVisitor<AffineExprWalker, WalkRetTy> {
|
||||
function_ref<WalkRetTy(AffineExpr)> callback;
|
||||
|
||||
AffineExprWalker(std::function<void(AffineExpr)> callback)
|
||||
: callback(std::move(callback)) {}
|
||||
AffineExprWalker(function_ref<WalkRetTy(AffineExpr)> callback)
|
||||
: callback(callback) {}
|
||||
|
||||
void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); }
|
||||
void visitConstantExpr(AffineConstantExpr expr) { callback(expr); }
|
||||
void visitDimExpr(AffineDimExpr expr) { callback(expr); }
|
||||
void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); }
|
||||
WalkRetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
|
||||
return callback(expr);
|
||||
}
|
||||
WalkRetTy visitConstantExpr(AffineConstantExpr expr) {
|
||||
return callback(expr);
|
||||
}
|
||||
WalkRetTy visitDimExpr(AffineDimExpr expr) { return callback(expr); }
|
||||
WalkRetTy visitSymbolExpr(AffineSymbolExpr expr) { return callback(expr); }
|
||||
};
|
||||
|
||||
AffineExprWalker(std::move(callback)).walkPostOrder(*this);
|
||||
return AffineExprWalker(callback).walkPostOrder(e);
|
||||
}
|
||||
// Explicitly instantiate for the two supported return types.
|
||||
template void mlir::AffineExpr::walk(AffineExpr e,
|
||||
function_ref<void(AffineExpr)> callback);
|
||||
template WalkResult
|
||||
mlir::AffineExpr::walk(AffineExpr e,
|
||||
function_ref<WalkResult(AffineExpr)> callback);
|
||||
|
||||
// Dispatch affine expression construction based on kind.
|
||||
AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
|
||||
|
9
mlir/test/IR/affine-walk.mlir
Normal file
9
mlir/test/IR/affine-walk.mlir
Normal file
@ -0,0 +1,9 @@
|
||||
// RUN: mlir-opt -test-affine-walk -verify-diagnostics %s
|
||||
|
||||
// Test affine walk interrupt. A remark should be printed only for the first mod
|
||||
// expression encountered in post order.
|
||||
|
||||
#map = affine_map<(i, j) -> ((i mod 4) mod 2, j)>
|
||||
|
||||
"test.check_first_mod"() {"map" = #map} : () -> ()
|
||||
// expected-remark@-1 {{mod expression}}
|
@ -1,5 +1,6 @@
|
||||
# Exclude tests from libMLIR.so
|
||||
add_mlir_library(MLIRTestIR
|
||||
TestAffineWalk.cpp
|
||||
TestBytecodeRoundtrip.cpp
|
||||
TestBuiltinAttributeInterfaces.cpp
|
||||
TestBuiltinDistinctAttributes.cpp
|
||||
|
57
mlir/test/lib/IR/TestAffineWalk.cpp
Normal file
57
mlir/test/lib/IR/TestAffineWalk.cpp
Normal file
@ -0,0 +1,57 @@
|
||||
//===- TestAffineWalk.cpp - Pass to test affine walks
|
||||
//----------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
/// A test pass for verifying walk interrupts.
|
||||
struct TestAffineWalk
|
||||
: public PassWrapper<TestAffineWalk, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAffineWalk)
|
||||
|
||||
void runOnOperation() override;
|
||||
StringRef getArgument() const final { return "test-affine-walk"; }
|
||||
StringRef getDescription() const final { return "Test affine walk method."; }
|
||||
};
|
||||
} // namespace
|
||||
|
||||
/// Emits a remark for the first `map`'s result expression that contains a
|
||||
/// mod expression.
|
||||
static void checkMod(AffineMap map, Location loc) {
|
||||
for (AffineExpr e : map.getResults()) {
|
||||
e.walk([&](AffineExpr s) {
|
||||
if (s.getKind() == mlir::AffineExprKind::Mod) {
|
||||
emitRemark(loc, "mod expression: ");
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void TestAffineWalk::runOnOperation() {
|
||||
auto m = getOperation();
|
||||
// Test whether the walk is being correctly interrupted.
|
||||
m.walk([](Operation *op) {
|
||||
for (NamedAttribute attr : op->getAttrs()) {
|
||||
auto mapAttr = attr.getValue().dyn_cast<AffineMapAttr>();
|
||||
if (!mapAttr)
|
||||
return;
|
||||
checkMod(mapAttr.getAffineMap(), op->getLoc());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
void registerTestAffineWalk() { PassRegistration<TestAffineWalk>(); }
|
||||
} // namespace mlir
|
@ -44,11 +44,12 @@ void registerSymbolTestPasses();
|
||||
void registerRegionTestPasses();
|
||||
void registerTestAffineDataCopyPass();
|
||||
void registerTestAffineReifyValueBoundsPass();
|
||||
void registerTestAffineLoopUnswitchingPass();
|
||||
void registerTestAffineWalk();
|
||||
void registerTestBytecodeRoundtripPasses();
|
||||
void registerTestDecomposeAffineOpPass();
|
||||
void registerTestAffineLoopUnswitchingPass();
|
||||
void registerTestGpuLoweringPasses();
|
||||
void registerTestFunc();
|
||||
void registerTestGpuLoweringPasses();
|
||||
void registerTestGpuMemoryPromotionPass();
|
||||
void registerTestLoopPermutationPass();
|
||||
void registerTestMatchers();
|
||||
@ -167,12 +168,13 @@ void registerTestPasses() {
|
||||
registerSymbolTestPasses();
|
||||
registerRegionTestPasses();
|
||||
registerTestAffineDataCopyPass();
|
||||
registerTestAffineReifyValueBoundsPass();
|
||||
registerTestDecomposeAffineOpPass();
|
||||
registerTestAffineLoopUnswitchingPass();
|
||||
registerTestGpuLoweringPasses();
|
||||
registerTestAffineReifyValueBoundsPass();
|
||||
registerTestAffineWalk();
|
||||
registerTestBytecodeRoundtripPasses();
|
||||
registerTestDecomposeAffineOpPass();
|
||||
registerTestFunc();
|
||||
registerTestGpuLoweringPasses();
|
||||
registerTestGpuMemoryPromotionPass();
|
||||
registerTestLoopPermutationPass();
|
||||
registerTestMatchers();
|
||||
|
Loading…
Reference in New Issue
Block a user