[mlir][NFC] Add helper for common pattern of replaceAllUsesExcept

This covers the extremely common case of replacing all uses of a Value
with a new op that is itself a user of the original Value.

This should also be a little bit more efficient than the
`SmallPtrSet<Operation *, 1>{op}` idiom that was being used before.

Differential Revision: https://reviews.llvm.org/D102373
This commit is contained in:
Sean Silva 2021-05-12 14:59:12 -07:00
parent b42fb6811e
commit 12874e93a1
6 changed files with 22 additions and 10 deletions

View File

@ -166,6 +166,11 @@ public:
replaceAllUsesExcept(Value newValue,
const SmallPtrSetImpl<Operation *> &exceptions) const;
/// Replace all uses of 'this' value with 'newValue', updating anything in the
/// IR that uses 'this' to use the other value instead except if the user is
/// 'exceptedUser'.
void replaceAllUsesExcept(Value newValue, Operation *exceptedUser) const;
/// Replace all uses of 'this' value with 'newValue' if the given callback
/// returns true.
void replaceUsesWithIf(Value newValue,

View File

@ -72,7 +72,7 @@ void mlir::normalizeAffineParallel(AffineParallelOp op) {
applyOperands.push_back(iv);
applyOperands.append(symbolOperands.begin(), symbolOperands.end());
auto apply = builder.create<AffineApplyOp>(op.getLoc(), map, applyOperands);
iv.replaceAllUsesExcept(apply, SmallPtrSet<Operation *, 1>{apply});
iv.replaceAllUsesExcept(apply, apply);
}
SmallVector<int64_t, 8> newSteps(op.getNumDims(), 1);
@ -181,8 +181,7 @@ static void normalizeAffineFor(AffineForOp op) {
AffineMap ivMap = AffineMap::get(origLbMap.getNumDims() + 1,
origLbMap.getNumSymbols(), newIVExpr);
Operation *newIV = opBuilder.create<AffineApplyOp>(loc, ivMap, lbOperands);
op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0),
SmallPtrSet<Operation *, 1>{newIV});
op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV);
}
namespace {

View File

@ -191,8 +191,7 @@ static LinalgOp fuse(OpBuilder &builder, LinalgOp producer,
AffineApplyOp applyOp = builder.create<AffineApplyOp>(
indexOp.getLoc(), index + offset,
ValueRange{indexOp.getResult(), loopRanges[indexOp.dim()].offset});
indexOp.getResult().replaceAllUsesExcept(
applyOp, SmallPtrSet<Operation *, 1>{applyOp});
indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp);
}
}

View File

@ -155,8 +155,7 @@ transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
AffineApplyOp applyOp = b.create<AffineApplyOp>(
indexOp.getLoc(), index + iv,
ValueRange{indexOp.getResult(), ivs[rangeIndex->second]});
indexOp.getResult().replaceAllUsesExcept(
applyOp.getResult(), SmallPtrSet<Operation *, 1>{applyOp});
indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp);
}
}

View File

@ -121,8 +121,7 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
Value inner_index = std::get<0>(ivs);
AddIOp newIndex =
b.create<AddIOp>(op.getLoc(), std::get<0>(ivs), std::get<1>(ivs));
inner_index.replaceAllUsesExcept(
newIndex, SmallPtrSet<Operation *, 1>{newIndex.getOperation()});
inner_index.replaceAllUsesExcept(newIndex, newIndex);
}
op.erase();

View File

@ -63,12 +63,23 @@ void Value::replaceAllUsesWith(Value newValue) const {
/// listed in 'exceptions' .
void Value::replaceAllUsesExcept(
Value newValue, const SmallPtrSetImpl<Operation *> &exceptions) const {
for (auto &use : llvm::make_early_inc_range(getUses())) {
for (OpOperand &use : llvm::make_early_inc_range(getUses())) {
if (exceptions.count(use.getOwner()) == 0)
use.set(newValue);
}
}
/// Replace all uses of 'this' value with 'newValue', updating anything in the
/// IR that uses 'this' to use the other value instead except if the user is
/// 'exceptedUser'.
void Value::replaceAllUsesExcept(Value newValue,
Operation *exceptedUser) const {
for (OpOperand &use : llvm::make_early_inc_range(getUses())) {
if (use.getOwner() != exceptedUser)
use.set(newValue);
}
}
/// Replace all uses of 'this' value with 'newValue' if the given callback
/// returns true.
void Value::replaceUsesWithIf(Value newValue,