[mlir] Fix missing verification after running an OpToOpAdaptorPass

The current decision of when to run the verifier is running on the
assumption that nested passes can't affect the validity of the parent
operation, which isn't true. Parent operations may attach any number
of constraints on nested operations, which may not necessarily be
captured (or shouldn't be captured) at a smaller granularity.

This commit rectifies this by properly running the verifier after an
OpToOpAdaptor pass. To avoid an explosive increase in compile time,
we only run verification on the parent operation itself. To do this, a
flag to mlir::verify is added to avoid recursive verification if it isn't
desired.

Fixes #54288

Differential Revision: https://reviews.llvm.org/D121836
This commit is contained in:
River Riddle 2022-03-16 11:45:14 -07:00
parent 79f661edc1
commit 50f82e6847
8 changed files with 114 additions and 28 deletions

View File

@ -15,8 +15,12 @@ class Operation;
/// Perform (potentially expensive) checks of invariants, used to detect
/// compiler bugs, on this operation and any nested operations. On error, this
/// reports the error through the MLIRContext and returns failure.
LogicalResult verify(Operation *op);
/// reports the error through the MLIRContext and returns failure. If
/// `verifyRecursively` is false, this assumes that nested operations have
/// already been properly verified, and does not recursively invoke the verifier
/// on nested operations.
LogicalResult verify(Operation *op, bool verifyRecursively = true);
} // namespace mlir
#endif

View File

@ -43,6 +43,11 @@ namespace {
/// This class encapsulates all the state used to verify an operation region.
class OperationVerifier {
public:
/// If `verifyRecursively` is true, then this will also recursively verify
/// nested operations.
explicit OperationVerifier(bool verifyRecursively)
: verifyRecursively(verifyRecursively) {}
/// Verify the given operation.
LogicalResult verifyOpAndDominance(Operation &op);
@ -61,6 +66,10 @@ private:
/// Operation.
LogicalResult verifyDominanceOfContainedRegions(Operation &op,
DominanceInfo &domInfo);
/// A flag indicating if this verifier should recursively verify nested
/// operations.
bool verifyRecursively;
};
} // namespace
@ -81,8 +90,12 @@ LogicalResult OperationVerifier::verifyOpAndDominance(Operation &op) {
return failure();
}
// Check the dominance properties and invariants of any operations in the
// regions contained by the 'opsWithIsolatedRegions' operations.
// If we aren't verifying nested operations, then we're done.
if (!verifyRecursively)
return success();
// Otherwise, check the dominance properties and invariants of any operations
// in the regions contained by the 'opsWithIsolatedRegions' operations.
return failableParallelForEach(
op.getContext(), opsWithIsolatedRegions,
[&](Operation *op) { return verifyOpAndDominance(*op); });
@ -120,21 +133,25 @@ LogicalResult OperationVerifier::verifyBlock(
// Check each operation, and make sure there are no branches out of the
// middle of this block.
for (auto &op : block) {
for (Operation &op : block) {
// Only the last instructions is allowed to have successors.
if (op.getNumSuccessors() != 0 && &op != &block.back())
return op.emitError(
"operation with block successors must terminate its parent block");
// If we aren't verifying recursievly, there is nothing left to check.
if (!verifyRecursively)
continue;
// If this operation has regions and is IsolatedFromAbove, we defer
// checking. This allows us to parallelize verification better.
if (op.getNumRegions() != 0 &&
op.hasTrait<OpTrait::IsIsolatedFromAbove>()) {
opsWithIsolatedRegions.push_back(&op);
} else {
// Otherwise, check the operation inline.
if (failed(verifyOperation(op, opsWithIsolatedRegions)))
return failure();
} else if (failed(verifyOperation(op, opsWithIsolatedRegions))) {
return failure();
}
}
@ -185,8 +202,9 @@ LogicalResult OperationVerifier::verifyOperation(
auto kindInterface = dyn_cast<RegionKindInterface>(op);
// Verify that all child regions are ok.
MutableArrayRef<Region> regions = op.getRegions();
for (unsigned i = 0; i < numRegions; ++i) {
Region &region = op.getRegion(i);
Region &region = regions[i];
RegionKind kind =
kindInterface ? kindInterface.getRegionKind(i) : RegionKind::SSACFG;
// Check that Graph Regions only have a single basic block. This is
@ -210,10 +228,13 @@ LogicalResult OperationVerifier::verifyOperation(
return emitError(op.getLoc(),
"entry block of region may not have predecessors");
// Verify each of the blocks within the region.
for (Block &block : region)
if (failed(verifyBlock(block, opsWithIsolatedRegions)))
return failure();
// Verify each of the blocks within the region if we are verifying
// recursively.
if (verifyRecursively) {
for (Block &block : region)
if (failed(verifyBlock(block, opsWithIsolatedRegions)))
return failure();
}
}
}
@ -330,10 +351,10 @@ OperationVerifier::verifyDominanceOfContainedRegions(Operation &op,
}
}
// Recursively verify dominance within each operation in the
// block, even if the block itself is not reachable, or we are in
// a region which doesn't respect dominance.
if (op.getNumRegions() != 0) {
// Recursively verify dominance within each operation in the block, even
// if the block itself is not reachable, or we are in a region which
// doesn't respect dominance.
if (verifyRecursively && op.getNumRegions() != 0) {
// If this operation is IsolatedFromAbove, then we'll handle it in the
// outer verification loop.
if (op.hasTrait<OpTrait::IsIsolatedFromAbove>())
@ -352,9 +373,7 @@ OperationVerifier::verifyDominanceOfContainedRegions(Operation &op,
// Entrypoint
//===----------------------------------------------------------------------===//
/// Perform (potentially expensive) checks of invariants, used to detect
/// compiler bugs. On error, this reports the error through the MLIRContext and
/// returns failure.
LogicalResult mlir::verify(Operation *op) {
return OperationVerifier().verifyOpAndDominance(*op);
LogicalResult mlir::verify(Operation *op, bool verifyRecursively) {
OperationVerifier verifier(verifyRecursively);
return verifier.verifyOpAndDominance(*op);
}

View File

@ -408,22 +408,24 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
// failed).
if (!passFailed && verifyPasses) {
bool runVerifierNow = true;
// If the pass is an adaptor pass, we don't run the verifier recursively
// because the nested operations should have already been verified after
// nested passes had run.
bool runVerifierRecursively = !isa<OpToOpPassAdaptor>(pass);
// Reduce compile time by avoiding running the verifier if the pass didn't
// change the IR since the last time the verifier was run:
//
// 1) If the pass said that it preserved all analyses then it can't have
// permuted the IR.
// 2) If we just ran an OpToOpPassAdaptor (e.g. to run function passes
// within a module) then each sub-unit will have been verified on the
// subunit (and those passes aren't allowed to modify the parent).
//
// We run these checks in EXPENSIVE_CHECKS mode out of caution.
#ifndef EXPENSIVE_CHECKS
runVerifierNow = !isa<OpToOpPassAdaptor>(pass) &&
!pass->passState->preservedAnalyses.isAll();
runVerifierNow = !pass->passState->preservedAnalyses.isAll();
#endif
if (runVerifierNow)
passFailed = failed(verify(op));
passFailed = failed(verify(op, runVerifierRecursively));
}
// Instrument after the pass has run.

View File

@ -0,0 +1,8 @@
// RUN: mlir-opt %s -pass-pipeline='builtin.func(test-pass-invalid-parent)' -verify-diagnostics
// Test that we properly report errors when the parent becomes invalid after running a pass
// on a child operation.
// expected-error@below {{'some_unknown_func' does not reference a valid function}}
func @TestCreateInvalidCallInPass() {
return
}

View File

@ -358,6 +358,21 @@ void TestDialect::getCanonicalizationPatterns(
results.add(&dialectCanonicalizationPattern);
}
//===----------------------------------------------------------------------===//
// TestCallOp
//===----------------------------------------------------------------------===//
LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// Check that the callee attribute was specified.
auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
if (!fnAttr)
return emitOpError("requires a 'callee' symbol reference attribute");
if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
return emitOpError() << "'" << fnAttr.getValue()
<< "' does not reference a valid function";
return success();
}
//===----------------------------------------------------------------------===//
// TestFoldToCallOp
//===----------------------------------------------------------------------===//

View File

@ -375,6 +375,14 @@ def VariadicNoTerminatorOp : TEST_Op<"variadic_no_terminator_op",
// Test Call Interfaces
//===----------------------------------------------------------------------===//
def TestCallOp : TEST_Op<"call", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
let results = (outs Variadic<AnyType>);
let assemblyFormat = [{
$callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
}];
}
def ConversionCallOp : TEST_Op<"conversion_call_op",
[CallOpInterface]> {
let arguments = (ins Variadic<AnyType>:$arg_operands, SymbolRefAttr:$callee);

View File

@ -11,4 +11,11 @@ add_mlir_library(MLIRTestPass
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRTestDialect
)
target_include_directories(MLIRTestPass
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/../Dialect/Test
${CMAKE_CURRENT_BINARY_DIR}/../Dialect/Test
)

View File

@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
@ -98,6 +99,27 @@ class TestFailurePass : public PassWrapper<TestFailurePass, OperationPass<>> {
}
};
/// A test pass that always fails to enable testing the failure recovery
/// mechanisms of the pass manager.
class TestInvalidParentPass
: public PassWrapper<TestInvalidParentPass,
InterfacePass<FunctionOpInterface>> {
StringRef getArgument() const final { return "test-pass-invalid-parent"; }
StringRef getDescription() const final {
return "Test a pass in the pass manager that makes the parent operation "
"invalid";
}
void getDependentDialects(DialectRegistry &registry) const final {
registry.insert<test::TestDialect>();
}
void runOnOperation() final {
FunctionOpInterface op = getOperation();
OpBuilder b(getOperation().getBody());
b.create<test::TestCallOp>(op.getLoc(), TypeRange(), "some_unknown_func",
ValueRange());
}
};
/// A test pass that contains a statistic.
struct TestStatisticPass
: public PassWrapper<TestStatisticPass, OperationPass<>> {
@ -144,6 +166,7 @@ void registerPassManagerTestPass() {
PassRegistration<TestCrashRecoveryPass>();
PassRegistration<TestFailurePass>();
PassRegistration<TestInvalidParentPass>();
PassRegistration<TestStatisticPass>();