mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-02-24 12:34:19 +00:00
[mlir][interfaces] Add insideMutuallyExclusiveRegions helper
Add a helper function to ControlFlowInterfaces for checking if two ops are in mutually exclusive regions according to RegionBranchOpInterface. Utilize this new helper in Linalg ComprehensiveBufferize. This makes the analysis independent of the SCF dialect and generalizes it to other ops that implement RegionBranchOpInterface. Differential Revision: https://reviews.llvm.org/D114220
This commit is contained in:
parent
72e4f4a2a1
commit
a5c2f78287
@ -87,6 +87,10 @@ private:
|
||||
ValueRange inputs;
|
||||
};
|
||||
|
||||
/// Return `true` if `a` and `b` are in mutually exclusive regions as per
|
||||
/// RegionBranchOpInterface.
|
||||
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RegionBranchTerminatorOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -430,9 +430,8 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
|
||||
aliasInfo))
|
||||
continue;
|
||||
|
||||
// Special rules for branches.
|
||||
// TODO: Use an interface.
|
||||
if (scf::insideMutuallyExclusiveBranches(readingOp, conflictingWritingOp))
|
||||
// Ops are not conflicting if they are in mutually exclusive regions.
|
||||
if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
|
||||
continue;
|
||||
|
||||
LDBG("WRITE = #" << printValueInfo(lastWrite) << "\n");
|
||||
|
@ -219,6 +219,78 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Return `true` if `a` and `b` are in mutually exclusive regions.
|
||||
///
|
||||
/// 1. Find the first common of `a` and `b` (ancestor) that implements
|
||||
/// RegionBranchOpInterface.
|
||||
/// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
|
||||
/// contained.
|
||||
/// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
|
||||
/// mutually exclusive if they are not reachable from each other as per
|
||||
/// RegionBranchOpInterface::getSuccessorRegions.
|
||||
bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
|
||||
assert(a && "expected non-empty operation");
|
||||
assert(b && "expected non-empty operation");
|
||||
|
||||
auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
|
||||
while (branchOp) {
|
||||
// Check if b is inside branchOp. (We already know that a is.)
|
||||
if (!branchOp->isProperAncestor(b)) {
|
||||
// Check next enclosing RegionBranchOpInterface.
|
||||
branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
|
||||
continue;
|
||||
}
|
||||
|
||||
// b is contained in branchOp. Retrieve the regions in which `a` and `b`
|
||||
// are contained.
|
||||
Region *regionA = nullptr, *regionB = nullptr;
|
||||
for (Region &r : branchOp->getRegions()) {
|
||||
if (r.findAncestorOpInRegion(*a)) {
|
||||
assert(!regionA && "already found a region for a");
|
||||
regionA = &r;
|
||||
}
|
||||
if (r.findAncestorOpInRegion(*b)) {
|
||||
assert(!regionB && "already found a region for b");
|
||||
regionB = &r;
|
||||
}
|
||||
}
|
||||
assert(regionA && regionB && "could not find region of op");
|
||||
|
||||
// Helper function that checks if region `r` is reachable from region
|
||||
// `begin`.
|
||||
std::function<bool(Region *, Region *)> isRegionReachable =
|
||||
[&](Region *begin, Region *r) {
|
||||
if (begin == r)
|
||||
return true;
|
||||
if (begin == nullptr)
|
||||
return false;
|
||||
// Compute index of region.
|
||||
int64_t beginIndex = -1;
|
||||
for (auto it : llvm::enumerate(branchOp->getRegions()))
|
||||
if (&it.value() == begin)
|
||||
beginIndex = it.index();
|
||||
assert(beginIndex != -1 && "could not find region in op");
|
||||
// Retrieve all successors of the region.
|
||||
SmallVector<RegionSuccessor> successors;
|
||||
branchOp.getSuccessorRegions(beginIndex, successors);
|
||||
// Call function recursively on all successors.
|
||||
for (RegionSuccessor successor : successors)
|
||||
if (isRegionReachable(successor.getSuccessor(), r))
|
||||
return true;
|
||||
return false;
|
||||
};
|
||||
|
||||
// `a` and `b` are in mutually exclusive regions if neither region is
|
||||
// reachable from the other region.
|
||||
return !isRegionReachable(regionA, regionB) &&
|
||||
!isRegionReachable(regionB, regionA);
|
||||
}
|
||||
|
||||
// Could not find a common RegionBranchOpInterface among a's and b's
|
||||
// ancestors.
|
||||
return false;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RegionBranchTerminatorOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1,10 +1,12 @@
|
||||
add_mlir_unittest(MLIRInterfacesTests
|
||||
ControlFlowInterfacesTest.cpp
|
||||
DataLayoutInterfacesTest.cpp
|
||||
InferTypeOpInterfaceTest.cpp
|
||||
)
|
||||
|
||||
target_link_libraries(MLIRInterfacesTests
|
||||
PRIVATE
|
||||
MLIRControlFlowInterfaces
|
||||
MLIRDataLayoutInterfaces
|
||||
MLIRDLTI
|
||||
MLIRInferTypeOpInterface
|
||||
|
145
mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
Normal file
145
mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
Normal file
@ -0,0 +1,145 @@
|
||||
//===- ControlFlowInterfacesTest.cpp - Unit Tests for Control Flow Interf. ===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/Parser.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
/// A dummy op that is also a terminator.
|
||||
struct DummyOp : public Op<DummyOp, OpTrait::IsTerminator> {
|
||||
using Op::Op;
|
||||
static ArrayRef<StringRef> getAttributeNames() { return {}; }
|
||||
|
||||
static StringRef getOperationName() { return "cftest.dummy_op"; }
|
||||
};
|
||||
|
||||
/// All regions of this op are mutually exclusive.
|
||||
struct MutuallyExclusiveRegionsOp
|
||||
: public Op<MutuallyExclusiveRegionsOp, RegionBranchOpInterface::Trait> {
|
||||
using Op::Op;
|
||||
static ArrayRef<StringRef> getAttributeNames() { return {}; }
|
||||
|
||||
static StringRef getOperationName() {
|
||||
return "cftest.mutually_exclusive_regions_op";
|
||||
}
|
||||
|
||||
// Regions have no successors.
|
||||
void getSuccessorRegions(Optional<unsigned> index,
|
||||
ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {}
|
||||
};
|
||||
|
||||
/// Regions are executed sequentially.
|
||||
struct SequentialRegionsOp
|
||||
: public Op<SequentialRegionsOp, RegionBranchOpInterface::Trait> {
|
||||
using Op::Op;
|
||||
static ArrayRef<StringRef> getAttributeNames() { return {}; }
|
||||
|
||||
static StringRef getOperationName() { return "cftest.sequential_regions_op"; }
|
||||
|
||||
// Region 0 has Region 1 as a successor.
|
||||
void getSuccessorRegions(Optional<unsigned> index,
|
||||
ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
assert(index.hasValue() && "expected index");
|
||||
if (*index == 0) {
|
||||
Operation *thisOp = this->getOperation();
|
||||
regions.push_back(RegionSuccessor(&thisOp->getRegion(1)));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// A dialect putting all the above together.
|
||||
struct CFTestDialect : Dialect {
|
||||
explicit CFTestDialect(MLIRContext *ctx)
|
||||
: Dialect(getDialectNamespace(), ctx, TypeID::get<CFTestDialect>()) {
|
||||
addOperations<DummyOp, MutuallyExclusiveRegionsOp, SequentialRegionsOp>();
|
||||
}
|
||||
static StringRef getDialectNamespace() { return "cftest"; }
|
||||
};
|
||||
|
||||
TEST(RegionBranchOpInterface, MutuallyExclusiveOps) {
|
||||
const char *ir = R"MLIR(
|
||||
"cftest.mutually_exclusive_regions_op"() (
|
||||
{"cftest.dummy_op"() : () -> ()}, // op1
|
||||
{"cftest.dummy_op"() : () -> ()} // op2
|
||||
) : () -> ()
|
||||
)MLIR";
|
||||
|
||||
DialectRegistry registry;
|
||||
registry.insert<CFTestDialect>();
|
||||
MLIRContext ctx(registry);
|
||||
|
||||
OwningModuleRef module = parseSourceString(ir, &ctx);
|
||||
Operation *testOp = &module->getBody()->getOperations().front();
|
||||
Operation *op1 = &testOp->getRegion(0).front().front();
|
||||
Operation *op2 = &testOp->getRegion(1).front().front();
|
||||
|
||||
EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
|
||||
EXPECT_TRUE(insideMutuallyExclusiveRegions(op2, op1));
|
||||
}
|
||||
|
||||
TEST(RegionBranchOpInterface, NotMutuallyExclusiveOps) {
|
||||
const char *ir = R"MLIR(
|
||||
"cftest.sequential_regions_op"() (
|
||||
{"cftest.dummy_op"() : () -> ()}, // op1
|
||||
{"cftest.dummy_op"() : () -> ()} // op2
|
||||
) : () -> ()
|
||||
)MLIR";
|
||||
|
||||
DialectRegistry registry;
|
||||
registry.insert<CFTestDialect>();
|
||||
MLIRContext ctx(registry);
|
||||
|
||||
OwningModuleRef module = parseSourceString(ir, &ctx);
|
||||
Operation *testOp = &module->getBody()->getOperations().front();
|
||||
Operation *op1 = &testOp->getRegion(0).front().front();
|
||||
Operation *op2 = &testOp->getRegion(1).front().front();
|
||||
|
||||
EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op2));
|
||||
EXPECT_FALSE(insideMutuallyExclusiveRegions(op2, op1));
|
||||
}
|
||||
|
||||
TEST(RegionBranchOpInterface, NestedMutuallyExclusiveOps) {
|
||||
const char *ir = R"MLIR(
|
||||
"cftest.mutually_exclusive_regions_op"() (
|
||||
{
|
||||
"cftest.sequential_regions_op"() (
|
||||
{"cftest.dummy_op"() : () -> ()}, // op1
|
||||
{"cftest.dummy_op"() : () -> ()} // op3
|
||||
) : () -> ()
|
||||
"cftest.dummy_op"() : () -> ()
|
||||
},
|
||||
{"cftest.dummy_op"() : () -> ()} // op2
|
||||
) : () -> ()
|
||||
)MLIR";
|
||||
|
||||
DialectRegistry registry;
|
||||
registry.insert<CFTestDialect>();
|
||||
MLIRContext ctx(registry);
|
||||
|
||||
OwningModuleRef module = parseSourceString(ir, &ctx);
|
||||
Operation *testOp = &module->getBody()->getOperations().front();
|
||||
Operation *op1 =
|
||||
&testOp->getRegion(0).front().front().getRegion(0).front().front();
|
||||
Operation *op2 = &testOp->getRegion(1).front().front();
|
||||
Operation *op3 =
|
||||
&testOp->getRegion(0).front().front().getRegion(1).front().front();
|
||||
|
||||
EXPECT_TRUE(insideMutuallyExclusiveRegions(op1, op2));
|
||||
EXPECT_TRUE(insideMutuallyExclusiveRegions(op3, op2));
|
||||
EXPECT_FALSE(insideMutuallyExclusiveRegions(op1, op3));
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user