mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-04-12 11:23:33 +00:00
[mlir] Fix missing verification after running an OpToOpAdaptorPass
The current decision of when to run the verifier is running on the assumption that nested passes can't affect the validity of the parent operation, which isn't true. Parent operations may attach any number of constraints on nested operations, which may not necessarily be captured (or shouldn't be captured) at a smaller granularity. This commit rectifies this by properly running the verifier after an OpToOpAdaptor pass. To avoid an explosive increase in compile time, we only run verification on the parent operation itself. To do this, a flag to mlir::verify is added to avoid recursive verification if it isn't desired. Fixes #54288 Differential Revision: https://reviews.llvm.org/D121836
This commit is contained in:
parent
79f661edc1
commit
50f82e6847
@ -15,8 +15,12 @@ class Operation;
|
|||||||
|
|
||||||
/// Perform (potentially expensive) checks of invariants, used to detect
|
/// Perform (potentially expensive) checks of invariants, used to detect
|
||||||
/// compiler bugs, on this operation and any nested operations. On error, this
|
/// compiler bugs, on this operation and any nested operations. On error, this
|
||||||
/// reports the error through the MLIRContext and returns failure.
|
/// reports the error through the MLIRContext and returns failure. If
|
||||||
LogicalResult verify(Operation *op);
|
/// `verifyRecursively` is false, this assumes that nested operations have
|
||||||
|
/// already been properly verified, and does not recursively invoke the verifier
|
||||||
|
/// on nested operations.
|
||||||
|
LogicalResult verify(Operation *op, bool verifyRecursively = true);
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@ -43,6 +43,11 @@ namespace {
|
|||||||
/// This class encapsulates all the state used to verify an operation region.
|
/// This class encapsulates all the state used to verify an operation region.
|
||||||
class OperationVerifier {
|
class OperationVerifier {
|
||||||
public:
|
public:
|
||||||
|
/// If `verifyRecursively` is true, then this will also recursively verify
|
||||||
|
/// nested operations.
|
||||||
|
explicit OperationVerifier(bool verifyRecursively)
|
||||||
|
: verifyRecursively(verifyRecursively) {}
|
||||||
|
|
||||||
/// Verify the given operation.
|
/// Verify the given operation.
|
||||||
LogicalResult verifyOpAndDominance(Operation &op);
|
LogicalResult verifyOpAndDominance(Operation &op);
|
||||||
|
|
||||||
@ -61,6 +66,10 @@ private:
|
|||||||
/// Operation.
|
/// Operation.
|
||||||
LogicalResult verifyDominanceOfContainedRegions(Operation &op,
|
LogicalResult verifyDominanceOfContainedRegions(Operation &op,
|
||||||
DominanceInfo &domInfo);
|
DominanceInfo &domInfo);
|
||||||
|
|
||||||
|
/// A flag indicating if this verifier should recursively verify nested
|
||||||
|
/// operations.
|
||||||
|
bool verifyRecursively;
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
@ -81,8 +90,12 @@ LogicalResult OperationVerifier::verifyOpAndDominance(Operation &op) {
|
|||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the dominance properties and invariants of any operations in the
|
// If we aren't verifying nested operations, then we're done.
|
||||||
// regions contained by the 'opsWithIsolatedRegions' operations.
|
if (!verifyRecursively)
|
||||||
|
return success();
|
||||||
|
|
||||||
|
// Otherwise, check the dominance properties and invariants of any operations
|
||||||
|
// in the regions contained by the 'opsWithIsolatedRegions' operations.
|
||||||
return failableParallelForEach(
|
return failableParallelForEach(
|
||||||
op.getContext(), opsWithIsolatedRegions,
|
op.getContext(), opsWithIsolatedRegions,
|
||||||
[&](Operation *op) { return verifyOpAndDominance(*op); });
|
[&](Operation *op) { return verifyOpAndDominance(*op); });
|
||||||
@ -120,21 +133,25 @@ LogicalResult OperationVerifier::verifyBlock(
|
|||||||
|
|
||||||
// Check each operation, and make sure there are no branches out of the
|
// Check each operation, and make sure there are no branches out of the
|
||||||
// middle of this block.
|
// middle of this block.
|
||||||
for (auto &op : block) {
|
for (Operation &op : block) {
|
||||||
// Only the last instructions is allowed to have successors.
|
// Only the last instructions is allowed to have successors.
|
||||||
if (op.getNumSuccessors() != 0 && &op != &block.back())
|
if (op.getNumSuccessors() != 0 && &op != &block.back())
|
||||||
return op.emitError(
|
return op.emitError(
|
||||||
"operation with block successors must terminate its parent block");
|
"operation with block successors must terminate its parent block");
|
||||||
|
|
||||||
|
// If we aren't verifying recursievly, there is nothing left to check.
|
||||||
|
if (!verifyRecursively)
|
||||||
|
continue;
|
||||||
|
|
||||||
// If this operation has regions and is IsolatedFromAbove, we defer
|
// If this operation has regions and is IsolatedFromAbove, we defer
|
||||||
// checking. This allows us to parallelize verification better.
|
// checking. This allows us to parallelize verification better.
|
||||||
if (op.getNumRegions() != 0 &&
|
if (op.getNumRegions() != 0 &&
|
||||||
op.hasTrait<OpTrait::IsIsolatedFromAbove>()) {
|
op.hasTrait<OpTrait::IsIsolatedFromAbove>()) {
|
||||||
opsWithIsolatedRegions.push_back(&op);
|
opsWithIsolatedRegions.push_back(&op);
|
||||||
} else {
|
|
||||||
// Otherwise, check the operation inline.
|
// Otherwise, check the operation inline.
|
||||||
if (failed(verifyOperation(op, opsWithIsolatedRegions)))
|
} else if (failed(verifyOperation(op, opsWithIsolatedRegions))) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -185,8 +202,9 @@ LogicalResult OperationVerifier::verifyOperation(
|
|||||||
auto kindInterface = dyn_cast<RegionKindInterface>(op);
|
auto kindInterface = dyn_cast<RegionKindInterface>(op);
|
||||||
|
|
||||||
// Verify that all child regions are ok.
|
// Verify that all child regions are ok.
|
||||||
|
MutableArrayRef<Region> regions = op.getRegions();
|
||||||
for (unsigned i = 0; i < numRegions; ++i) {
|
for (unsigned i = 0; i < numRegions; ++i) {
|
||||||
Region ®ion = op.getRegion(i);
|
Region ®ion = regions[i];
|
||||||
RegionKind kind =
|
RegionKind kind =
|
||||||
kindInterface ? kindInterface.getRegionKind(i) : RegionKind::SSACFG;
|
kindInterface ? kindInterface.getRegionKind(i) : RegionKind::SSACFG;
|
||||||
// Check that Graph Regions only have a single basic block. This is
|
// Check that Graph Regions only have a single basic block. This is
|
||||||
@ -210,10 +228,13 @@ LogicalResult OperationVerifier::verifyOperation(
|
|||||||
return emitError(op.getLoc(),
|
return emitError(op.getLoc(),
|
||||||
"entry block of region may not have predecessors");
|
"entry block of region may not have predecessors");
|
||||||
|
|
||||||
// Verify each of the blocks within the region.
|
// Verify each of the blocks within the region if we are verifying
|
||||||
for (Block &block : region)
|
// recursively.
|
||||||
if (failed(verifyBlock(block, opsWithIsolatedRegions)))
|
if (verifyRecursively) {
|
||||||
return failure();
|
for (Block &block : region)
|
||||||
|
if (failed(verifyBlock(block, opsWithIsolatedRegions)))
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -330,10 +351,10 @@ OperationVerifier::verifyDominanceOfContainedRegions(Operation &op,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Recursively verify dominance within each operation in the
|
// Recursively verify dominance within each operation in the block, even
|
||||||
// block, even if the block itself is not reachable, or we are in
|
// if the block itself is not reachable, or we are in a region which
|
||||||
// a region which doesn't respect dominance.
|
// doesn't respect dominance.
|
||||||
if (op.getNumRegions() != 0) {
|
if (verifyRecursively && op.getNumRegions() != 0) {
|
||||||
// If this operation is IsolatedFromAbove, then we'll handle it in the
|
// If this operation is IsolatedFromAbove, then we'll handle it in the
|
||||||
// outer verification loop.
|
// outer verification loop.
|
||||||
if (op.hasTrait<OpTrait::IsIsolatedFromAbove>())
|
if (op.hasTrait<OpTrait::IsIsolatedFromAbove>())
|
||||||
@ -352,9 +373,7 @@ OperationVerifier::verifyDominanceOfContainedRegions(Operation &op,
|
|||||||
// Entrypoint
|
// Entrypoint
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Perform (potentially expensive) checks of invariants, used to detect
|
LogicalResult mlir::verify(Operation *op, bool verifyRecursively) {
|
||||||
/// compiler bugs. On error, this reports the error through the MLIRContext and
|
OperationVerifier verifier(verifyRecursively);
|
||||||
/// returns failure.
|
return verifier.verifyOpAndDominance(*op);
|
||||||
LogicalResult mlir::verify(Operation *op) {
|
|
||||||
return OperationVerifier().verifyOpAndDominance(*op);
|
|
||||||
}
|
}
|
||||||
|
@ -408,22 +408,24 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
|
|||||||
// failed).
|
// failed).
|
||||||
if (!passFailed && verifyPasses) {
|
if (!passFailed && verifyPasses) {
|
||||||
bool runVerifierNow = true;
|
bool runVerifierNow = true;
|
||||||
|
|
||||||
|
// If the pass is an adaptor pass, we don't run the verifier recursively
|
||||||
|
// because the nested operations should have already been verified after
|
||||||
|
// nested passes had run.
|
||||||
|
bool runVerifierRecursively = !isa<OpToOpPassAdaptor>(pass);
|
||||||
|
|
||||||
// Reduce compile time by avoiding running the verifier if the pass didn't
|
// Reduce compile time by avoiding running the verifier if the pass didn't
|
||||||
// change the IR since the last time the verifier was run:
|
// change the IR since the last time the verifier was run:
|
||||||
//
|
//
|
||||||
// 1) If the pass said that it preserved all analyses then it can't have
|
// 1) If the pass said that it preserved all analyses then it can't have
|
||||||
// permuted the IR.
|
// permuted the IR.
|
||||||
// 2) If we just ran an OpToOpPassAdaptor (e.g. to run function passes
|
|
||||||
// within a module) then each sub-unit will have been verified on the
|
|
||||||
// subunit (and those passes aren't allowed to modify the parent).
|
|
||||||
//
|
//
|
||||||
// We run these checks in EXPENSIVE_CHECKS mode out of caution.
|
// We run these checks in EXPENSIVE_CHECKS mode out of caution.
|
||||||
#ifndef EXPENSIVE_CHECKS
|
#ifndef EXPENSIVE_CHECKS
|
||||||
runVerifierNow = !isa<OpToOpPassAdaptor>(pass) &&
|
runVerifierNow = !pass->passState->preservedAnalyses.isAll();
|
||||||
!pass->passState->preservedAnalyses.isAll();
|
|
||||||
#endif
|
#endif
|
||||||
if (runVerifierNow)
|
if (runVerifierNow)
|
||||||
passFailed = failed(verify(op));
|
passFailed = failed(verify(op, runVerifierRecursively));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Instrument after the pass has run.
|
// Instrument after the pass has run.
|
||||||
|
8
mlir/test/Pass/invalid-parent.mlir
Normal file
8
mlir/test/Pass/invalid-parent.mlir
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
// RUN: mlir-opt %s -pass-pipeline='builtin.func(test-pass-invalid-parent)' -verify-diagnostics
|
||||||
|
|
||||||
|
// Test that we properly report errors when the parent becomes invalid after running a pass
|
||||||
|
// on a child operation.
|
||||||
|
// expected-error@below {{'some_unknown_func' does not reference a valid function}}
|
||||||
|
func @TestCreateInvalidCallInPass() {
|
||||||
|
return
|
||||||
|
}
|
@ -358,6 +358,21 @@ void TestDialect::getCanonicalizationPatterns(
|
|||||||
results.add(&dialectCanonicalizationPattern);
|
results.add(&dialectCanonicalizationPattern);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TestCallOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||||
|
// Check that the callee attribute was specified.
|
||||||
|
auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
|
||||||
|
if (!fnAttr)
|
||||||
|
return emitOpError("requires a 'callee' symbol reference attribute");
|
||||||
|
if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
|
||||||
|
return emitOpError() << "'" << fnAttr.getValue()
|
||||||
|
<< "' does not reference a valid function";
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TestFoldToCallOp
|
// TestFoldToCallOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -375,6 +375,14 @@ def VariadicNoTerminatorOp : TEST_Op<"variadic_no_terminator_op",
|
|||||||
// Test Call Interfaces
|
// Test Call Interfaces
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def TestCallOp : TEST_Op<"call", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
|
||||||
|
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
|
||||||
|
let results = (outs Variadic<AnyType>);
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def ConversionCallOp : TEST_Op<"conversion_call_op",
|
def ConversionCallOp : TEST_Op<"conversion_call_op",
|
||||||
[CallOpInterface]> {
|
[CallOpInterface]> {
|
||||||
let arguments = (ins Variadic<AnyType>:$arg_operands, SymbolRefAttr:$callee);
|
let arguments = (ins Variadic<AnyType>:$arg_operands, SymbolRefAttr:$callee);
|
||||||
|
@ -11,4 +11,11 @@ add_mlir_library(MLIRTestPass
|
|||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRPass
|
MLIRPass
|
||||||
|
MLIRTestDialect
|
||||||
|
)
|
||||||
|
|
||||||
|
target_include_directories(MLIRTestPass
|
||||||
|
PRIVATE
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/../Dialect/Test
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}/../Dialect/Test
|
||||||
)
|
)
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "TestDialect.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
@ -98,6 +99,27 @@ class TestFailurePass : public PassWrapper<TestFailurePass, OperationPass<>> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// A test pass that always fails to enable testing the failure recovery
|
||||||
|
/// mechanisms of the pass manager.
|
||||||
|
class TestInvalidParentPass
|
||||||
|
: public PassWrapper<TestInvalidParentPass,
|
||||||
|
InterfacePass<FunctionOpInterface>> {
|
||||||
|
StringRef getArgument() const final { return "test-pass-invalid-parent"; }
|
||||||
|
StringRef getDescription() const final {
|
||||||
|
return "Test a pass in the pass manager that makes the parent operation "
|
||||||
|
"invalid";
|
||||||
|
}
|
||||||
|
void getDependentDialects(DialectRegistry ®istry) const final {
|
||||||
|
registry.insert<test::TestDialect>();
|
||||||
|
}
|
||||||
|
void runOnOperation() final {
|
||||||
|
FunctionOpInterface op = getOperation();
|
||||||
|
OpBuilder b(getOperation().getBody());
|
||||||
|
b.create<test::TestCallOp>(op.getLoc(), TypeRange(), "some_unknown_func",
|
||||||
|
ValueRange());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/// A test pass that contains a statistic.
|
/// A test pass that contains a statistic.
|
||||||
struct TestStatisticPass
|
struct TestStatisticPass
|
||||||
: public PassWrapper<TestStatisticPass, OperationPass<>> {
|
: public PassWrapper<TestStatisticPass, OperationPass<>> {
|
||||||
@ -144,6 +166,7 @@ void registerPassManagerTestPass() {
|
|||||||
|
|
||||||
PassRegistration<TestCrashRecoveryPass>();
|
PassRegistration<TestCrashRecoveryPass>();
|
||||||
PassRegistration<TestFailurePass>();
|
PassRegistration<TestFailurePass>();
|
||||||
|
PassRegistration<TestInvalidParentPass>();
|
||||||
|
|
||||||
PassRegistration<TestStatisticPass>();
|
PassRegistration<TestStatisticPass>();
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user