[mlir][transform] TrackingListener: Improve dead handles detection (#74290)

The tracking listener should not report op replacement errors for
payload ops that are not mapped to any live handles. The handle liveless
analysis did not work properly with transform IR that has named
sequences.

A handle is live if it has a user after the transform op that is
currently being applied. With named sequences, we need to maintain a
stack of currently applied transform ops. That stack already exists
(`regionStack`), the only thing that's missing is the current transform
op for each stack frame.

This commit fixes #72931.
This commit is contained in:
Matthias Springer 2023-12-06 16:32:22 +09:00 committed by GitHub
parent c630f95f33
commit e8ae0e72b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 112 additions and 55 deletions

View File

@ -310,10 +310,8 @@ public:
/// with the type of the handle value.
LogicalResult mapBlockArguments(BlockArgument argument,
ArrayRef<Operation *> operations) {
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
assert(argument.getParentRegion() == regionStack.back() &&
assert(argument.getParentRegion() == regionStack.back()->region &&
"mapping block arguments from a region other than the active one");
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
return setPayloadOps(argument, operations);
}
LogicalResult mapBlockArgument(BlockArgument argument,
@ -350,9 +348,7 @@ public:
std::make_pair(&region, std::make_unique<Mappings>()));
assert(res.second && "the region scope is already present");
(void)res;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
state.regionStack.push_back(&region);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
state.regionStack.push_back(this);
}
/// Back-reference to the transform state.
@ -361,7 +357,10 @@ public:
/// The region this scope is associated with.
Region *region;
friend RegionScope TransformState::make_region_scope(Region &);
/// The transform op within this region that is currently being applied.
TransformOpInterface currentTransform;
friend class transform::TransformState;
};
friend class RegionScope;
@ -784,12 +783,14 @@ private:
/// location.
InvalidatedHandleMap invalidatedHandles;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
/// A stack of nested regions that are being processed in the transform IR.
/// Each region must be an ancestor of the following regions in this list.
/// These are also the keys for "mappings".
SmallVector<Region *> regionStack;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
SmallVector<RegionScope *> regionStack;
/// The top-level region scope. The first (bottom) element of `regionStack`
/// is the top-level region scope object.
std::unique_ptr<RegionScope> topLevelRegionScope;
};
/// Local mapping between values defined by a specific op implementing the
@ -926,8 +927,14 @@ TransformState::RegionScope TransformState::make_region_scope(Region &region) {
class TrackingListener : public RewriterBase::Listener,
public TransformState::Extension {
public:
/// A function that returns "true" for handles that do not have to be updated.
using SkipHandleFn = std::function<bool(Value)>;
/// Create a new TrackingListener for usage in the specified transform op.
TrackingListener(TransformState &state, TransformOpInterface op);
/// Optionally, a function can be specified to identify handles that should
/// do not have to be updated.
TrackingListener(TransformState &state, TransformOpInterface op,
SkipHandleFn skipHandleFn = nullptr);
protected:
/// Return a replacement payload op for the given op, which is going to be
@ -1015,6 +1022,10 @@ private:
/// The handles that are consumed by the transform op.
DenseSet<Value> consumedHandles;
/// Handles for which this function evaluates to "true" do not have to be
/// updated. These are typically dead or consumed handles.
SkipHandleFn skipHandleFn;
};
/// A specialized listener that keeps track of cases in which no replacement

View File

@ -30,6 +30,23 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//
/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
/// properly dominates `b` and `b` is not inside `a`.
static bool happensBefore(Operation *a, Operation *b) {
do {
if (a->isProperAncestor(b))
return false;
if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
return a->isBeforeInBlock(bAncestor);
}
} while ((a = a->getParentOp()));
return false;
}
//===----------------------------------------------------------------------===//
// TransformState
//===----------------------------------------------------------------------===//
@ -44,14 +61,10 @@ transform::TransformState::TransformState(
topLevelMappedValues.reserve(extraMappings.size());
for (ArrayRef<MappedValue> mapping : extraMappings)
topLevelMappedValues.push_back(mapping);
auto result =
mappings.insert(std::make_pair(region, std::make_unique<Mappings>()));
assert(result.second && "the region scope is already present");
(void)result;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
regionStack.push_back(region);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
if (region) {
RegionScope *scope = new RegionScope(*this, *region);
topLevelRegionScope.reset(scope);
}
}
Operation *transform::TransformState::getTopLevel() const { return topLevel; }
@ -811,6 +824,11 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
});
// Set current transform op.
regionStack.back()->currentTransform = transform;
// Expensive checks to detect invalid transform IR.
if (options.getExpensiveChecksEnabled()) {
FULL_LDBG("ExpensiveChecksEnabled\n");
if (failed(checkAndRecordHandleInvalidation(transform)))
@ -899,7 +917,24 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
}
// Prepare rewriter and listener.
transform::ErrorCheckingTrackingListener trackingListener(*this, transform);
TrackingListener::SkipHandleFn skipHandleFn = [&](Value handle) {
// Skip handle if it is dead.
auto scopeIt =
llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
return handle.getParentRegion() == scope->region;
});
assert(scopeIt != regionStack.rend() &&
"could not find region scope for handle");
RegionScope *scope = *scopeIt;
for (Operation *user : handle.getUsers()) {
if (user != scope->currentTransform &&
!happensBefore(user, scope->currentTransform))
return false;
}
return true;
};
transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
skipHandleFn);
transform::TransformRewriter rewriter(transform->getContext(),
&trackingListener);
@ -1040,10 +1075,7 @@ transform::TransformState::RegionScope::~RegionScope() {
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
state.mappings.erase(region);
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
state.regionStack.pop_back();
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}
//===----------------------------------------------------------------------===//
@ -1150,8 +1182,10 @@ bool transform::TransformResults::isSet(unsigned resultNumber) const {
//===----------------------------------------------------------------------===//
transform::TrackingListener::TrackingListener(TransformState &state,
TransformOpInterface op)
: TransformState::Extension(state), transformOp(op) {
TransformOpInterface op,
SkipHandleFn skipHandleFn)
: TransformState::Extension(state), transformOp(op),
skipHandleFn(skipHandleFn) {
if (op) {
for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
consumedHandles.insert(opOperand->get());
@ -1251,19 +1285,6 @@ void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
});
}
/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
/// properly dominates `b` and `b` is not inside `a`.
static bool happensBefore(Operation *a, Operation *b) {
do {
if (a->isProperAncestor(b))
return false;
if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
return a->isBeforeInBlock(bAncestor);
}
} while ((a = a->getParentOp()));
return false;
}
void transform::TrackingListener::notifyOperationReplaced(
Operation *op, ValueRange newValues) {
assert(op->getNumResults() == newValues.size() &&
@ -1295,18 +1316,17 @@ void transform::TrackingListener::notifyOperationReplaced(
[&](Value h) { return consumedHandles.contains(h); });
};
// Helper function to check if the handle is alive.
auto firstAliveUser = [&]() -> std::optional<OpOperand *> {
for (Value v : opHandles) {
for (OpOperand &use : v.getUses())
if (use.getOwner() != transformOp &&
!happensBefore(use.getOwner(), transformOp))
return &use;
}
return std::nullopt;
}();
if (!firstAliveUser.has_value() || handleWasConsumed()) {
// Check if there are any handles that must be updated.
Value aliveHandle;
if (skipHandleFn) {
auto it =
llvm::find_if(opHandles, [&](Value v) { return !skipHandleFn(v); });
if (it != opHandles.end())
aliveHandle = *it;
} else if (!opHandles.empty()) {
aliveHandle = opHandles.front();
}
if (!aliveHandle || handleWasConsumed()) {
// The op is tracked but the corresponding handles are dead or were
// consumed. Drop the op form the mapping.
(void)replacePayloadOp(op, nullptr);
@ -1319,10 +1339,8 @@ void transform::TrackingListener::notifyOperationReplaced(
// If the op is tracked but no replacement op was found, send a
// notification.
if (!diag.succeeded()) {
diag.attachNote((*firstAliveUser)->getOwner()->getLoc())
<< "replacement is required because alive handle(s) exist "
<< "(first use in this op as operand number "
<< (*firstAliveUser)->getOperandNumber() << ")";
diag.attachNote(aliveHandle.getLoc())
<< "replacement is required because this handle must be updated";
notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
(void)replacePayloadOp(op, nullptr);
return;

View File

@ -36,6 +36,7 @@ func.func @replacement_op_not_found() {
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-note @below {{replacement is required because this handle must be updated}}
%1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @below {{tracking listener failed to find replacement op during application of this transform op}}
// expected-note @below {{ran out of suitable replacement values}}
@ -44,7 +45,6 @@ transform.sequence failures(propagate) {
} : !transform.any_op
// %1 must be used in some way. If no replacement payload op could be found,
// an error is thrown only if the handle is not dead.
// expected-note @below {{replacement is required because alive handle(s) exist (first use in this op as operand number 0)}}
transform.annotate %1 "annotated" : !transform.any_op
}
@ -363,3 +363,31 @@ transform.sequence failures(propagate) {
legal_ops = ["func.func", "func.return", "test.new_op"]}
: !transform.any_op
}
// -----
module attributes { transform.with_named_sequence } {
func.func @replacement_op_not_found() {
// No op replacement can be found, but there are no handles that must be
// updated. No error should be reported.
"test.container"() ({
%0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32)
}) : () -> ()
return
}
transform.named_sequence @patterns(%container: !transform.any_op {transform.readonly}) {
transform.apply_patterns to %container {
transform.apply_patterns.transform.test_patterns
} : !transform.any_op
transform.yield
}
transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.annotate %1 "annotated" : !transform.any_op
transform.include @patterns failures(propagate) (%0) : (!transform.any_op) -> ()
}
}