[mlir][transform] Check for invalidated iterators on payload IR mappings (#66369)

Add extra error checking (in debug mode) to detect cases where an
iterator on "direct" payload IR mappings is invalidated (due to elements
being removed). Such errors are hard to debug: they are often
non-deterministic; sometimes the program crashes, sometimes it produces
wrong results. Even when it crashes, the stack trace often points to
completely unrelated code locations.

Store a timestamp with each "direct" mapping. The timestamp is increased
whenever an operation is performed that invaldiates an iterator on that
mapping. A debug iterator is added that checks the timestamp as payload
IR is enumerated.
This commit is contained in:
Matthias Springer 2023-09-14 16:34:32 +02:00 committed by GitHub
parent 66aa9a2517
commit aca9019be0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 47 additions and 2 deletions

View File

@ -170,6 +170,12 @@ private:
/// should be emitted when the value is used.
using InvalidatedHandleMap = DenseMap<Value, std::function<void(Location)>>;
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
/// Debug only: A timestamp is associated with each transform IR value, so
/// that invalid iterator usage can be detected more reliably.
using TransformIRTimestampMapping = DenseMap<Value, int64_t>;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
/// The bidirectional mappings between transform IR values and payload IR
/// operations, and the mapping between transform IR values and parameters.
struct Mappings {
@ -178,6 +184,11 @@ private:
ParamMapping params;
ValueMapping values;
ValueMapping reverseValues;
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
TransformIRTimestampMapping timestamps;
void incrementTimestamp(Value value) { ++timestamps[value]; }
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
};
friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
@ -207,10 +218,26 @@ public:
/// not enumerated. This function is helpful for transformations that apply to
/// a particular handle.
auto getPayloadOps(Value value) const {
ArrayRef<Operation *> view = getPayloadOpsView(value);
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Memorize the current timestamp and make sure that it has not changed
// when incrementing or dereferencing the iterator returned by this
// function. The timestamp is incremented when the "direct" mapping is
// resized; this would invalidate the iterator returned by this function.
int64_t currentTimestamp = getMapping(value).timestamps.lookup(value);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
// When ops are replaced/erased, they are replaced with nullptr (until
// the data structure is compacted). Do not enumerate these ops.
return llvm::make_filter_range(getPayloadOpsView(value),
[](Operation *op) { return op != nullptr; });
return llvm::make_filter_range(view, [=](Operation *op) {
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
bool sameTimestamp =
currentTimestamp == this->getMapping(value).timestamps.lookup(value);
assert(sameTimestamp && "iterator was invalidated during iteration");
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
return op != nullptr;
});
}
/// Returns the list of parameters that the given transform IR value

View File

@ -310,6 +310,11 @@ void transform::TransformState::forgetMapping(Value opHandle,
for (Operation *op : mappings.direct[opHandle])
dropMappingEntry(mappings.reverse, op, opHandle);
mappings.direct.erase(opHandle);
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Payload IR is removed from the mapping. This invalidates the respective
// iterators.
mappings.incrementTimestamp(opHandle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
for (Value opResult : origOpFlatResults) {
SmallVector<Value> resultHandles;
@ -336,6 +341,12 @@ void transform::TransformState::forgetValueMapping(
Mappings &localMappings = getMapping(opHandle);
dropMappingEntry(localMappings.direct, opHandle, payloadOp);
dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
// Payload IR is removed from the mapping. This invalidates the respective
// iterators.
localMappings.incrementTimestamp(opHandle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}
}
}
@ -774,6 +785,13 @@ checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
void transform::TransformState::compactOpHandles() {
for (Value handle : opHandlesToCompact) {
Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
#ifndef LLVM_ENABLE_ABI_BREAKING_CHECKS
if (llvm::find(mappings.direct[handle], nullptr) !=
mappings.direct[handle].end())
// Payload IR is removed from the mapping. This invalidates the respective
// iterators.
mappings.incrementTimestamp(handle);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
llvm::erase_value(mappings.direct[handle], nullptr);
}
opHandlesToCompact.clear();