mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-02-12 04:43:48 +00:00
[mlir][IR] Rename "update root" to "modify op" in rewriter API (#78260)
This commit renames 4 pattern rewriter API functions: * `updateRootInPlace` -> `modifyOpInPlace` * `startRootUpdate` -> `startOpModification` * `finalizeRootUpdate` -> `finalizeOpModification` * `cancelRootUpdate` -> `cancelOpModification` The term "root" is a misnomer. The root is the op that a rewrite pattern matches against (https://mlir.llvm.org/docs/PatternRewriter/#root-operation-name-optional). A rewriter must be notified of all in-place op modifications, not just in-place modifications of the root (https://mlir.llvm.org/docs/PatternRewriter/#pattern-rewriter). The old function names were confusing and have contributed to various broken rewrite patterns. Note: The new function names use the term "modify" instead of "update" for consistency with the `RewriterBase::Listener` terminology (`notifyOperationModified`).
This commit is contained in:
parent
57b50ef017
commit
5fcf907b34
@ -215,14 +215,14 @@ public:
|
||||
rewriter.replaceOpWithNewOp<ConvertOp>(
|
||||
addr, typeConverter.convertType(addr.getType()), addr.getVal());
|
||||
} else if (typeConverter.needsConversion(resTy)) {
|
||||
rewriter.startRootUpdate(op);
|
||||
rewriter.startOpModification(op);
|
||||
op->getResult(0).setType(typeConverter.convertType(resTy));
|
||||
rewriter.finalizeRootUpdate(op);
|
||||
rewriter.finalizeOpModification(op);
|
||||
}
|
||||
} else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
|
||||
mlir::FunctionType ty = func.getFunctionType();
|
||||
if (typeConverter.needsConversion(ty)) {
|
||||
rewriter.startRootUpdate(func);
|
||||
rewriter.startOpModification(func);
|
||||
auto toTy =
|
||||
typeConverter.convertType(ty).cast<mlir::FunctionType>();
|
||||
if (!func.empty())
|
||||
@ -235,7 +235,7 @@ public:
|
||||
block.eraseArgument(i + 1);
|
||||
}
|
||||
func.setType(toTy);
|
||||
rewriter.finalizeRootUpdate(func);
|
||||
rewriter.finalizeOpModification(func);
|
||||
}
|
||||
} else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
|
||||
// Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
|
||||
@ -273,10 +273,10 @@ public:
|
||||
} else if (auto global = mlir::dyn_cast<GlobalOp>(op)) {
|
||||
auto ty = global.getType();
|
||||
if (typeConverter.needsConversion(ty)) {
|
||||
rewriter.startRootUpdate(global);
|
||||
rewriter.startOpModification(global);
|
||||
auto toTy = typeConverter.convertType(ty);
|
||||
global.setType(toTy);
|
||||
rewriter.finalizeRootUpdate(global);
|
||||
rewriter.finalizeOpModification(global);
|
||||
}
|
||||
} else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) {
|
||||
auto ty = mem.getType();
|
||||
@ -339,17 +339,17 @@ public:
|
||||
mem, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
|
||||
}
|
||||
} else if (op->getDialect() == firDialect) {
|
||||
rewriter.startRootUpdate(op);
|
||||
rewriter.startOpModification(op);
|
||||
for (auto i : llvm::enumerate(op->getResultTypes()))
|
||||
if (typeConverter.needsConversion(i.value())) {
|
||||
auto toTy = typeConverter.convertType(i.value());
|
||||
op->getResult(i.index()).setType(toTy);
|
||||
}
|
||||
rewriter.finalizeRootUpdate(op);
|
||||
rewriter.finalizeOpModification(op);
|
||||
}
|
||||
// Ensure block arguments are updated if needed.
|
||||
if (op->getNumRegions() != 0) {
|
||||
rewriter.startRootUpdate(op);
|
||||
rewriter.startOpModification(op);
|
||||
for (mlir::Region ®ion : op->getRegions())
|
||||
for (mlir::Block &block : region.getBlocks())
|
||||
for (mlir::BlockArgument blockArg : block.getArguments())
|
||||
@ -358,7 +358,7 @@ public:
|
||||
typeConverter.convertType(blockArg.getType());
|
||||
blockArg.setType(toTy);
|
||||
}
|
||||
rewriter.finalizeRootUpdate(op);
|
||||
rewriter.finalizeOpModification(op);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -3763,13 +3763,13 @@ public:
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::LLVM::CallOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
rewriter.startRootUpdate(op);
|
||||
rewriter.startOpModification(op);
|
||||
auto callee = op.getCallee();
|
||||
if (callee)
|
||||
if (callee->equals("hypotf"))
|
||||
op.setCalleeAttr(mlir::SymbolRefAttr::get(op.getContext(), "_hypotf"));
|
||||
|
||||
rewriter.finalizeRootUpdate(op);
|
||||
rewriter.finalizeOpModification(op);
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
@ -3782,10 +3782,10 @@ public:
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::LLVM::LLVMFuncOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
rewriter.startRootUpdate(op);
|
||||
rewriter.startOpModification(op);
|
||||
if (op.getSymName().equals("hypotf"))
|
||||
op.setSymNameAttr(rewriter.getStringAttr("_hypotf"));
|
||||
rewriter.finalizeRootUpdate(op);
|
||||
rewriter.finalizeOpModification(op);
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
@ -256,9 +256,9 @@ struct AssignOpConversion : public mlir::OpConversionPattern<hlfir::AssignOp> {
|
||||
llvm::SmallVector<mlir::Value> newOperands;
|
||||
for (mlir::Value operand : adaptor.getOperands())
|
||||
newOperands.push_back(getBufferizedExprStorage(operand));
|
||||
rewriter.startRootUpdate(assign);
|
||||
rewriter.startOpModification(assign);
|
||||
assign->setOperands(newOperands);
|
||||
rewriter.finalizeRootUpdate(assign);
|
||||
rewriter.finalizeOpModification(assign);
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
@ -834,9 +834,9 @@ struct ElementalOpConversion
|
||||
// Explicitly delete the body of the elemental to get rid
|
||||
// of any users of hlfir.expr values inside the body as early
|
||||
// as possible.
|
||||
rewriter.startRootUpdate(elemental);
|
||||
rewriter.startOpModification(elemental);
|
||||
rewriter.eraseBlock(elemental.getBody());
|
||||
rewriter.finalizeRootUpdate(elemental);
|
||||
rewriter.finalizeOpModification(elemental);
|
||||
rewriter.replaceOp(elemental, bufferizedExpr);
|
||||
return mlir::success();
|
||||
}
|
||||
|
@ -114,9 +114,9 @@ public:
|
||||
op.getValue());
|
||||
return success();
|
||||
}
|
||||
rewriter.startRootUpdate(op->getParentOp());
|
||||
rewriter.startOpModification(op->getParentOp());
|
||||
op.getResult().replaceAllUsesWith(op.getValue());
|
||||
rewriter.finalizeRootUpdate(op->getParentOp());
|
||||
rewriter.finalizeOpModification(op->getParentOp());
|
||||
rewriter.eraseOp(op);
|
||||
}
|
||||
return success();
|
||||
|
@ -464,15 +464,15 @@ public:
|
||||
auto affineFor = loopAndIndex.first;
|
||||
auto inductionVar = loopAndIndex.second;
|
||||
|
||||
rewriter.startRootUpdate(affineFor.getOperation());
|
||||
rewriter.startOpModification(affineFor.getOperation());
|
||||
affineFor.getBody()->getOperations().splice(
|
||||
std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(),
|
||||
std::prev(loopOps.end()));
|
||||
rewriter.finalizeRootUpdate(affineFor.getOperation());
|
||||
rewriter.finalizeOpModification(affineFor.getOperation());
|
||||
|
||||
rewriter.startRootUpdate(loop.getOperation());
|
||||
rewriter.startOpModification(loop.getOperation());
|
||||
loop.getInductionVar().replaceAllUsesWith(inductionVar);
|
||||
rewriter.finalizeRootUpdate(loop.getOperation());
|
||||
rewriter.finalizeOpModification(loop.getOperation());
|
||||
|
||||
rewriteMemoryOps(affineFor.getBody(), rewriter);
|
||||
|
||||
@ -561,7 +561,7 @@ public:
|
||||
auto affineIf = rewriter.create<affine::AffineIfOp>(
|
||||
op.getLoc(), affineCondition.getIntegerSet(),
|
||||
affineCondition.getAffineArgs(), !op.getElseRegion().empty());
|
||||
rewriter.startRootUpdate(affineIf);
|
||||
rewriter.startOpModification(affineIf);
|
||||
affineIf.getThenBlock()->getOperations().splice(
|
||||
std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(),
|
||||
std::prev(ifOps.end()));
|
||||
@ -571,7 +571,7 @@ public:
|
||||
std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(),
|
||||
std::prev(otherOps.end()));
|
||||
}
|
||||
rewriter.finalizeRootUpdate(affineIf);
|
||||
rewriter.finalizeOpModification(affineIf);
|
||||
rewriteMemoryOps(affineIf.getBody(), rewriter);
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n";
|
||||
|
@ -76,7 +76,7 @@ public:
|
||||
matchAndRewrite(mlir::func::FuncOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::LogicalResult ret = success();
|
||||
rewriter.startRootUpdate(op);
|
||||
rewriter.startOpModification(op);
|
||||
llvm::StringRef oldName = op.getSymName();
|
||||
auto result = fir::NameUniquer::deconstruct(oldName);
|
||||
if (fir::NameUniquer::isExternalFacingUniquedName(result)) {
|
||||
@ -95,7 +95,7 @@ public:
|
||||
}
|
||||
|
||||
updateEarlyOutliningParentName(op, appendUnderscore);
|
||||
rewriter.finalizeRootUpdate(op);
|
||||
rewriter.finalizeOpModification(op);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@ -114,7 +114,7 @@ public:
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(fir::GlobalOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
rewriter.startRootUpdate(op);
|
||||
rewriter.startOpModification(op);
|
||||
auto result = fir::NameUniquer::deconstruct(
|
||||
op.getSymref().getRootReference().getValue());
|
||||
if (fir::NameUniquer::isExternalFacingUniquedName(result)) {
|
||||
@ -122,7 +122,7 @@ public:
|
||||
op.setSymrefAttr(mlir::SymbolRefAttr::get(op.getContext(), newName));
|
||||
SymbolTable::setSymbolName(op, newName);
|
||||
}
|
||||
rewriter.finalizeRootUpdate(op);
|
||||
rewriter.finalizeOpModification(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -213,15 +213,15 @@ user is determined by the specific pattern driver.
|
||||
This method replaces an operation's results with a set of provided values, and
|
||||
erases the operation.
|
||||
|
||||
* Update an Operation in-place : `(start|cancel|finalize)RootUpdate`
|
||||
* Update an Operation in-place : `(start|cancel|finalize)OpModification`
|
||||
|
||||
This is a collection of methods that provide a transaction-like API for updating
|
||||
the attributes, location, operands, or successors of an operation in-place
|
||||
within a pattern. An in-place update transaction is started with
|
||||
`startRootUpdate`, and may either be canceled or finalized with
|
||||
`cancelRootUpdate` and `finalizeRootUpdate` respectively. A convenience wrapper,
|
||||
`updateRootInPlace`, is provided that wraps a `start` and `finalize` around a
|
||||
callback.
|
||||
`startOpModification`, and may either be canceled or finalized with
|
||||
`cancelOpModification` and `finalizeOpModification` respectively. A convenience
|
||||
wrapper, `modifyOpInPlace`, is provided that wraps a `start` and `finalize`
|
||||
around a callback.
|
||||
|
||||
* OpBuilder API
|
||||
|
||||
|
@ -24,7 +24,7 @@ public:
|
||||
LogicalResult matchAndRewrite(func::FuncOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
if (op.getSymName() == "bar") {
|
||||
rewriter.updateRootInPlace(op, [&op]() { op.setSymName("foo"); });
|
||||
rewriter.modifyOpInPlace(op, [&op]() { op.setSymName("foo"); });
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
|
@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// We don't lower "toy.print" in this pass, but we need to update its
|
||||
// operands.
|
||||
rewriter.updateRootInPlace(op,
|
||||
[&] { op->setOperands(adaptor.getOperands()); });
|
||||
rewriter.modifyOpInPlace(op,
|
||||
[&] { op->setOperands(adaptor.getOperands()); });
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// We don't lower "toy.print" in this pass, but we need to update its
|
||||
// operands.
|
||||
rewriter.updateRootInPlace(op,
|
||||
[&] { op->setOperands(adaptor.getOperands()); });
|
||||
rewriter.modifyOpInPlace(op,
|
||||
[&] { op->setOperands(adaptor.getOperands()); });
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -260,8 +260,8 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// We don't lower "toy.print" in this pass, but we need to update its
|
||||
// operands.
|
||||
rewriter.updateRootInPlace(op,
|
||||
[&] { op->setOperands(adaptor.getOperands()); });
|
||||
rewriter.modifyOpInPlace(op,
|
||||
[&] { op->setOperands(adaptor.getOperands()); });
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -585,28 +585,30 @@ public:
|
||||
|
||||
/// This method is used to notify the rewriter that an in-place operation
|
||||
/// modification is about to happen. A call to this function *must* be
|
||||
/// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
|
||||
/// This is a minor efficiency win (it avoids creating a new operation and
|
||||
/// removing the old one) but also often allows simpler code in the client.
|
||||
virtual void startRootUpdate(Operation *op) {}
|
||||
/// followed by a call to either `finalizeOpModification` or
|
||||
/// `cancelOpModification`. This is a minor efficiency win (it avoids creating
|
||||
/// a new operation and removing the old one) but also often allows simpler
|
||||
/// code in the client.
|
||||
virtual void startOpModification(Operation *op) {}
|
||||
|
||||
/// This method is used to signal the end of a root update on the given
|
||||
/// operation. This can only be called on operations that were provided to a
|
||||
/// call to `startRootUpdate`.
|
||||
virtual void finalizeRootUpdate(Operation *op);
|
||||
/// This method is used to signal the end of an in-place modification of the
|
||||
/// given operation. This can only be called on operations that were provided
|
||||
/// to a call to `startOpModification`.
|
||||
virtual void finalizeOpModification(Operation *op);
|
||||
|
||||
/// This method cancels a pending root update. This can only be called on
|
||||
/// operations that were provided to a call to `startRootUpdate`.
|
||||
virtual void cancelRootUpdate(Operation *op) {}
|
||||
/// This method cancels a pending in-place modification. This can only be
|
||||
/// called on operations that were provided to a call to
|
||||
/// `startOpModification`.
|
||||
virtual void cancelOpModification(Operation *op) {}
|
||||
|
||||
/// This method is a utility wrapper around a root update of an operation. It
|
||||
/// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
|
||||
/// callable.
|
||||
/// This method is a utility wrapper around an in-place modification of an
|
||||
/// operation. It wraps calls to `startOpModification` and
|
||||
/// `finalizeOpModification` around the given callable.
|
||||
template <typename CallableT>
|
||||
void updateRootInPlace(Operation *root, CallableT &&callable) {
|
||||
startRootUpdate(root);
|
||||
void modifyOpInPlace(Operation *root, CallableT &&callable) {
|
||||
startOpModification(root);
|
||||
callable();
|
||||
finalizeRootUpdate(root);
|
||||
finalizeOpModification(root);
|
||||
}
|
||||
|
||||
/// Find uses of `from` and replace them with `to`. It also marks every
|
||||
@ -619,7 +621,7 @@ public:
|
||||
void replaceAllUsesWith(IRObjectWithUseList<OperandType> *from, ValueT &&to) {
|
||||
for (OperandType &operand : llvm::make_early_inc_range(from->getUses())) {
|
||||
Operation *op = operand.getOwner();
|
||||
updateRootInPlace(op, [&]() { operand.set(to); });
|
||||
modifyOpInPlace(op, [&]() { operand.set(to); });
|
||||
}
|
||||
}
|
||||
void replaceAllUsesWith(ValueRange from, ValueRange to) {
|
||||
|
@ -739,17 +739,17 @@ public:
|
||||
/// PatternRewriter hook for inserting a new operation.
|
||||
void notifyOperationInserted(Operation *op) override;
|
||||
|
||||
/// PatternRewriter hook for updating the root operation in-place.
|
||||
/// Note: These methods only track updates to the top-level operation itself,
|
||||
/// PatternRewriter hook for updating the given operation in-place.
|
||||
/// Note: These methods only track updates to the given operation itself,
|
||||
/// and not nested regions. Updates to regions will still require notification
|
||||
/// through other more specific hooks above.
|
||||
void startRootUpdate(Operation *op) override;
|
||||
void startOpModification(Operation *op) override;
|
||||
|
||||
/// PatternRewriter hook for updating the root operation in-place.
|
||||
void finalizeRootUpdate(Operation *op) override;
|
||||
/// PatternRewriter hook for updating the given operation in-place.
|
||||
void finalizeOpModification(Operation *op) override;
|
||||
|
||||
/// PatternRewriter hook for updating the root operation in-place.
|
||||
void cancelRootUpdate(Operation *op) override;
|
||||
/// PatternRewriter hook for updating the given operation in-place.
|
||||
void cancelOpModification(Operation *op) override;
|
||||
|
||||
/// PatternRewriter hook for notifying match failure reasons.
|
||||
LogicalResult
|
||||
|
@ -255,7 +255,7 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
|
||||
// Step 2. Assign the op a real tile ID.
|
||||
// For simplicity, we always use tile 0 (which always exists).
|
||||
auto zeroTileId = rewriter.getI32IntegerAttr(0);
|
||||
rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
|
||||
rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });
|
||||
|
||||
VectorType tileVectorType = tileOp.getTileType();
|
||||
auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
|
||||
|
@ -918,8 +918,7 @@ LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
|
||||
for (auto stream : streams)
|
||||
streamDestroyCallBuilder.create(loc, rewriter, {stream});
|
||||
|
||||
rewriter.updateRootInPlace(yieldOp,
|
||||
[&] { yieldOp->setOperands(newOperands); });
|
||||
rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperands(newOperands); });
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -43,14 +43,13 @@ class ExpandIfCondition : public OpRewritePattern<OpTy> {
|
||||
if (!matchPattern(op.getIfCond(), m_Constant(&constAttr))) {
|
||||
auto ifOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(),
|
||||
op.getIfCond(), false);
|
||||
rewriter.updateRootInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
|
||||
rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
|
||||
auto thenBodyBuilder = ifOp.getThenBodyBuilder(rewriter.getListener());
|
||||
thenBodyBuilder.clone(*op.getOperation());
|
||||
rewriter.eraseOp(op);
|
||||
} else {
|
||||
if (constAttr.getInt())
|
||||
rewriter.updateRootInPlace(op,
|
||||
[&]() { op.getIfCondMutable().erase(0); });
|
||||
rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
|
||||
else
|
||||
rewriter.eraseOp(op);
|
||||
}
|
||||
|
@ -645,13 +645,13 @@ struct PrepareTransferWriteConversion
|
||||
rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
|
||||
buffers.dataBuffer);
|
||||
auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
|
||||
rewriter.updateRootInPlace(xferOp, [&]() {
|
||||
rewriter.modifyOpInPlace(xferOp, [&]() {
|
||||
xferOp.getVectorMutable().assign(loadedVec);
|
||||
xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
|
||||
});
|
||||
|
||||
if (xferOp.getMask()) {
|
||||
rewriter.updateRootInPlace(xferOp, [&]() {
|
||||
rewriter.modifyOpInPlace(xferOp, [&]() {
|
||||
xferOp.getMaskMutable().assign(buffers.maskBuffer);
|
||||
});
|
||||
}
|
||||
@ -966,7 +966,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
|
||||
loadIndices, iv);
|
||||
auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
|
||||
loadIndices);
|
||||
rewriter.updateRootInPlace(newXfer, [&]() {
|
||||
rewriter.modifyOpInPlace(newXfer, [&]() {
|
||||
newXfer.getMaskMutable().assign(mask);
|
||||
});
|
||||
}
|
||||
|
@ -2493,7 +2493,7 @@ FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
|
||||
newYieldValuesFn(rewriter, getLoc(), newIterArgs);
|
||||
assert(newInitOperands.size() == newYieldedValues.size() &&
|
||||
"expected as many new yield values as new iter operands");
|
||||
rewriter.updateRootInPlace(yieldOp, [&]() {
|
||||
rewriter.modifyOpInPlace(yieldOp, [&]() {
|
||||
yieldOp.getOperandsMutable().append(newYieldedValues);
|
||||
});
|
||||
}
|
||||
@ -2686,9 +2686,9 @@ struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
|
||||
!llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
|
||||
return failure();
|
||||
|
||||
rewriter.startRootUpdate(ifOp);
|
||||
rewriter.startOpModification(ifOp);
|
||||
rewriter.eraseBlock(ifOp.getElseBlock());
|
||||
rewriter.finalizeRootUpdate(ifOp);
|
||||
rewriter.finalizeOpModification(ifOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -71,10 +71,10 @@ void mlir::affine::reorderOperandsByHoistability(RewriterBase &rewriter,
|
||||
op->getContext());
|
||||
canonicalizeMapAndOperands(&map, &operands);
|
||||
|
||||
rewriter.startRootUpdate(op);
|
||||
rewriter.startOpModification(op);
|
||||
op.setMap(map);
|
||||
op->setOperands(operands);
|
||||
rewriter.finalizeRootUpdate(op);
|
||||
rewriter.finalizeOpModification(op);
|
||||
}
|
||||
|
||||
/// Build an affine.apply that is a subexpression `expr` of `originalOp`s affine
|
||||
|
@ -218,7 +218,7 @@ struct AssignTileIDsPattern
|
||||
return defaultVal;
|
||||
};
|
||||
auto setDiscardableIntAttr = [&](StringRef name, auto value) {
|
||||
rewriter.updateRootInPlace(tileOp, [&] {
|
||||
rewriter.modifyOpInPlace(tileOp, [&] {
|
||||
func->setDiscardableAttr(name,
|
||||
rewriter.getI32IntegerAttr((unsigned)value));
|
||||
});
|
||||
@ -274,10 +274,10 @@ struct AssignTileIDsPattern
|
||||
setDiscardableIntAttr(kTilesInUseAttr, tilesInUse);
|
||||
else
|
||||
setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1);
|
||||
rewriter.updateRootInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
|
||||
rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
|
||||
for (auto *op : dependantOps) {
|
||||
if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
|
||||
rewriter.updateRootInPlace(
|
||||
rewriter.modifyOpInPlace(
|
||||
dependantTileOp, [&] { dependantTileOp.setTileId(tileIDAttr); });
|
||||
}
|
||||
}
|
||||
|
@ -30,8 +30,8 @@ class ForwardOperands : public OpConversionPattern<OpTy> {
|
||||
if (adaptor.getOperands().getTypes() == op->getOperands().getTypes())
|
||||
return rewriter.notifyMatchFailure(op, "operand types already match");
|
||||
|
||||
rewriter.updateRootInPlace(
|
||||
op, [&]() { op->setOperands(adaptor.getOperands()); });
|
||||
rewriter.modifyOpInPlace(op,
|
||||
[&]() { op->setOperands(adaptor.getOperands()); });
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -106,8 +106,8 @@ struct RelaxScalableVectorAllocaAlignment
|
||||
|
||||
// Set alignment based on the defaults for SVE vectors and predicates.
|
||||
unsigned aligment = vectorType.getElementType().isInteger(1) ? 2 : 16;
|
||||
rewriter.updateRootInPlace(allocaOp,
|
||||
[&] { allocaOp.setAlignment(aligment); });
|
||||
rewriter.modifyOpInPlace(allocaOp,
|
||||
[&] { allocaOp.setAlignment(aligment); });
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -253,7 +253,7 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
|
||||
copiedOpOperands.contains(opOperand));
|
||||
if (failed(copy))
|
||||
return failure();
|
||||
rewriter.updateRootInPlace(op, [&]() { opOperand->set(*copy); });
|
||||
rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); });
|
||||
}
|
||||
|
||||
// Insert copies of Values.
|
||||
@ -274,7 +274,7 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
|
||||
// dynamic extents. Do not update these either.
|
||||
if (isa<tensor::DimOp>(use->getOwner()))
|
||||
continue;
|
||||
rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(*copy); });
|
||||
rewriter.modifyOpInPlace(use->getOwner(), [&]() { use->set(*copy); });
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -895,7 +895,7 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
|
||||
deallocOp.getConditions() == conditions)
|
||||
return failure();
|
||||
|
||||
rewriter.updateRootInPlace(deallocOp, [&]() {
|
||||
rewriter.modifyOpInPlace(deallocOp, [&]() {
|
||||
deallocOp.getMemrefsMutable().assign(memrefs);
|
||||
deallocOp.getConditionsMutable().assign(conditions);
|
||||
});
|
||||
|
@ -42,7 +42,7 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
|
||||
deallocOp.getConditions() == conditions)
|
||||
return failure();
|
||||
|
||||
rewriter.updateRootInPlace(deallocOp, [&]() {
|
||||
rewriter.modifyOpInPlace(deallocOp, [&]() {
|
||||
deallocOp.getMemrefsMutable().assign(memrefs);
|
||||
deallocOp.getConditionsMutable().assign(conditions);
|
||||
});
|
||||
|
@ -403,8 +403,8 @@ struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
|
||||
constantTrue = rewriter.create<arith::ConstantOp>(
|
||||
condbr.getLoc(), ty, rewriter.getBoolAttr(true));
|
||||
|
||||
rewriter.updateRootInPlace(use.getOwner(),
|
||||
[&] { use.set(constantTrue); });
|
||||
rewriter.modifyOpInPlace(use.getOwner(),
|
||||
[&] { use.set(constantTrue); });
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -418,8 +418,8 @@ struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
|
||||
constantFalse = rewriter.create<arith::ConstantOp>(
|
||||
condbr.getLoc(), ty, rewriter.getBoolAttr(false));
|
||||
|
||||
rewriter.updateRootInPlace(use.getOwner(),
|
||||
[&] { use.set(constantFalse); });
|
||||
rewriter.modifyOpInPlace(use.getOwner(),
|
||||
[&] { use.set(constantFalse); });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -86,7 +86,7 @@ struct DecomposeCallGraphTypesForFuncArgs
|
||||
if (failed(typeConverter->convertTypes(functionType.getResults(),
|
||||
newResultTypes)))
|
||||
return failure();
|
||||
rewriter.updateRootInPlace(op, [&] {
|
||||
rewriter.modifyOpInPlace(op, [&] {
|
||||
op.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
|
||||
newResultTypes));
|
||||
});
|
||||
|
@ -84,7 +84,7 @@ public:
|
||||
newOperands[idx] = operands[idx];
|
||||
}
|
||||
}
|
||||
rewriter.updateRootInPlace(
|
||||
rewriter.modifyOpInPlace(
|
||||
op, [newOperands, op]() { op->setOperands(newOperands); });
|
||||
return success();
|
||||
}
|
||||
@ -107,8 +107,8 @@ public:
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// For a return, all operands go to the results of the parent, so
|
||||
// rewrite them all.
|
||||
rewriter.updateRootInPlace(op,
|
||||
[&] { op->setOperands(adaptor.getOperands()); });
|
||||
rewriter.modifyOpInPlace(op,
|
||||
[&] { op->setOperands(adaptor.getOperands()); });
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -80,7 +80,7 @@ public:
|
||||
auto newType = FunctionType::get(rewriter.getContext(),
|
||||
argumentMapping.getConvertedTypes(),
|
||||
funcResultMapping.getConvertedTypes());
|
||||
rewriter.updateRootInPlace(op, [&] { op.setType(newType); });
|
||||
rewriter.modifyOpInPlace(op, [&] { op.setType(newType); });
|
||||
|
||||
// Update block signatures.
|
||||
if (!op.isExternal()) {
|
||||
@ -105,7 +105,7 @@ public:
|
||||
return failure();
|
||||
|
||||
// Convert operands.
|
||||
rewriter.updateRootInPlace(
|
||||
rewriter.modifyOpInPlace(
|
||||
op, [&] { op->setOperands(adaptor.getFlatOperands()); });
|
||||
|
||||
return success();
|
||||
|
@ -2030,7 +2030,7 @@ public:
|
||||
continue;
|
||||
validOperands.push_back(operand);
|
||||
}
|
||||
rewriter.updateRootInPlace(op, [&]() { op->setOperands(validOperands); });
|
||||
rewriter.modifyOpInPlace(op, [&]() { op->setOperands(validOperands); });
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -301,7 +301,7 @@ DeletionKind LLVM::DbgValueOp::removeBlockingUses(
|
||||
// the variable has been optimized out.
|
||||
auto undef =
|
||||
rewriter.create<UndefOp>(getValue().getLoc(), getValue().getType());
|
||||
rewriter.updateRootInPlace(*this, [&] { getValueMutable().assign(undef); });
|
||||
rewriter.modifyOpInPlace(*this, [&] { getValueMutable().assign(undef); });
|
||||
return DeletionKind::Keep;
|
||||
}
|
||||
|
||||
@ -394,7 +394,7 @@ DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
|
||||
return DeletionKind::Delete;
|
||||
}
|
||||
|
||||
rewriter.updateRootInPlace(*this, [&]() {
|
||||
rewriter.modifyOpInPlace(*this, [&]() {
|
||||
// Rewire the indices by popping off the second index.
|
||||
// Start with a single zero, then add the indices beyond the second.
|
||||
SmallVector<int32_t> newIndices(1);
|
||||
|
@ -83,8 +83,8 @@ static void insertFieldIndirection(MemOp op, PatternRewriter &rewriter,
|
||||
op->getLoc(), LLVM::LLVMPointerType::get(op.getContext()), elemType,
|
||||
op.getAddr(), firstTypeIndices);
|
||||
|
||||
rewriter.updateRootInPlace(op,
|
||||
[&]() { op.getAddrMutable().assign(properPtr); });
|
||||
rewriter.modifyOpInPlace(op,
|
||||
[&]() { op.getAddrMutable().assign(properPtr); });
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -111,8 +111,8 @@ LogicalResult AddFieldGetterToStructDirectUse<LoadOp>::matchAndRewrite(
|
||||
rewriter.setInsertionPointAfterValue(load.getResult());
|
||||
BitcastOp bitcast = rewriter.create<BitcastOp>(
|
||||
load->getLoc(), load.getResult().getType(), load.getResult());
|
||||
rewriter.updateRootInPlace(load,
|
||||
[&]() { load.getResult().setType(firstType); });
|
||||
rewriter.modifyOpInPlace(load,
|
||||
[&]() { load.getResult().setType(firstType); });
|
||||
rewriter.replaceAllUsesExcept(load.getResult(), bitcast.getResult(),
|
||||
bitcast);
|
||||
}
|
||||
@ -141,7 +141,7 @@ LogicalResult AddFieldGetterToStructDirectUse<StoreOp>::matchAndRewrite(
|
||||
|
||||
insertFieldIndirection<StoreOp>(store, rewriter, inconsistentElementType);
|
||||
|
||||
rewriter.updateRootInPlace(
|
||||
rewriter.modifyOpInPlace(
|
||||
store, [&]() { store.getValueMutable().assign(store.getValue()); });
|
||||
|
||||
return success();
|
||||
@ -630,8 +630,8 @@ LogicalResult BitcastStores::matchAndRewrite(StoreOp store,
|
||||
|
||||
auto bitcastOp =
|
||||
rewriter.create<BitcastOp>(store.getLoc(), typeHint, store.getValue());
|
||||
rewriter.updateRootInPlace(
|
||||
store, [&] { store.getValueMutable().assign(bitcastOp); });
|
||||
rewriter.modifyOpInPlace(store,
|
||||
[&] { store.getValueMutable().assign(bitcastOp); });
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -785,7 +785,7 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
|
||||
rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
|
||||
|
||||
// Replace the use in containingOp.
|
||||
rewriter.updateRootInPlace(containingOp, [&]() {
|
||||
rewriter.modifyOpInPlace(containingOp, [&]() {
|
||||
containingOp->setOperand(pUse->getOperandNumber(),
|
||||
destinationTensors.front());
|
||||
});
|
||||
@ -835,7 +835,7 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPoint(use->getOwner());
|
||||
fusedOp = rewriter.clone(*producerOp);
|
||||
rewriter.updateRootInPlace(
|
||||
rewriter.modifyOpInPlace(
|
||||
use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
|
||||
|
||||
return fusedOp;
|
||||
|
@ -311,7 +311,7 @@ Value linalg::bufferizeToAllocation(
|
||||
auto toTensorOp =
|
||||
resultUse->get().getDefiningOp<bufferization::ToTensorOp>();
|
||||
assert(toTensorOp && "expected to_tensor op");
|
||||
rewriter.updateRootInPlace(toTensorOp, [&]() {
|
||||
rewriter.modifyOpInPlace(toTensorOp, [&]() {
|
||||
toTensorOp.setRestrict(true);
|
||||
toTensorOp.setWritable(true);
|
||||
});
|
||||
@ -559,11 +559,11 @@ Value linalg::bufferizeToAllocation(
|
||||
// tensor is uninitialized.
|
||||
createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
|
||||
}
|
||||
rewriter.updateRootInPlace(op, [&]() {
|
||||
rewriter.modifyOpInPlace(op, [&]() {
|
||||
auto toTensorOp = rewriter.create<ToTensorOp>(op->getLoc(), alloc);
|
||||
operand->set(toTensorOp);
|
||||
if (options.bufferizeDestinationOnly) {
|
||||
rewriter.updateRootInPlace(toTensorOp, [&]() {
|
||||
rewriter.modifyOpInPlace(toTensorOp, [&]() {
|
||||
toTensorOp.setRestrict(true);
|
||||
toTensorOp.setWritable(true);
|
||||
});
|
||||
@ -584,7 +584,7 @@ Value linalg::bufferizeToAllocation(
|
||||
for (OpOperand *resultUse : resultUses) {
|
||||
auto toTensorOp = resultUse->get().getDefiningOp<ToTensorOp>();
|
||||
assert(toTensorOp && "expected to_tensor op");
|
||||
rewriter.updateRootInPlace(toTensorOp, [&]() {
|
||||
rewriter.modifyOpInPlace(toTensorOp, [&]() {
|
||||
toTensorOp.setRestrict(true);
|
||||
toTensorOp.setWritable(true);
|
||||
});
|
||||
|
@ -104,7 +104,7 @@ struct FunctionNonEntryBlockConversion
|
||||
LogicalResult
|
||||
matchAndRewrite(FunctionOpInterface op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.startRootUpdate(op);
|
||||
rewriter.startOpModification(op);
|
||||
Region ®ion = op.getFunctionBody();
|
||||
SmallVector<TypeConverter::SignatureConversion, 2> conversions;
|
||||
|
||||
@ -125,11 +125,11 @@ struct FunctionNonEntryBlockConversion
|
||||
|
||||
if (failed(rewriter.convertNonEntryRegionTypes(®ion, *typeConverter,
|
||||
conversions))) {
|
||||
rewriter.cancelRootUpdate(op);
|
||||
rewriter.cancelOpModification(op);
|
||||
return failure();
|
||||
}
|
||||
|
||||
rewriter.finalizeRootUpdate(op);
|
||||
rewriter.finalizeOpModification(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -1816,7 +1816,7 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
|
||||
|
||||
LogicalResult matchAndRewrite(GenericOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
rewriter.startRootUpdate(op);
|
||||
rewriter.startOpModification(op);
|
||||
bool modifiedOutput = false;
|
||||
Location loc = op.getLoc();
|
||||
for (OpOperand &opOperand : op.getDpsInitsMutable()) {
|
||||
@ -1843,10 +1843,10 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
|
||||
}
|
||||
}
|
||||
if (!modifiedOutput) {
|
||||
rewriter.cancelRootUpdate(op);
|
||||
rewriter.cancelOpModification(op);
|
||||
return failure();
|
||||
}
|
||||
rewriter.finalizeRootUpdate(op);
|
||||
rewriter.finalizeOpModification(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -87,7 +87,7 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
|
||||
}
|
||||
|
||||
// Turn the "in" into an "out".
|
||||
rewriter.updateRootInPlace(op, [&]() {
|
||||
rewriter.modifyOpInPlace(op, [&]() {
|
||||
out->set(in->get());
|
||||
// The original "in" could be removed entirely here (because it will no
|
||||
// longer have any uses in the payload), but we delegate this to
|
||||
|
@ -354,7 +354,7 @@ struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
|
||||
// Directly replace the cycle with the blockArg such that
|
||||
// Deduplicate pattern can eliminate it along with unused yield.
|
||||
rewriter.replaceOp(cycleOp, outputArg);
|
||||
rewriter.updateRootInPlace(genericOp, [] {});
|
||||
rewriter.modifyOpInPlace(genericOp, [] {});
|
||||
hasRemovedCycles = true;
|
||||
}
|
||||
|
||||
@ -404,7 +404,7 @@ struct FoldDuplicateInputBbArgs : public OpRewritePattern<GenericOp> {
|
||||
return failure();
|
||||
|
||||
// Rewrite the op.
|
||||
rewriter.updateRootInPlace(genericOp, [&]() {
|
||||
rewriter.modifyOpInPlace(genericOp, [&]() {
|
||||
for (auto [before, after] : replacements) {
|
||||
BlockArgument bbArg = genericOp.getBody()->getArgument(before);
|
||||
BlockArgument replacement = genericOp.getBody()->getArgument(after);
|
||||
|
@ -854,10 +854,10 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
|
||||
LLVM_DEBUG(DBGS() << "with result #"
|
||||
<< numOriginalForOpResults + iterArgNumber
|
||||
<< " of forOp, giving us: " << extracted << "\n");
|
||||
rewriter.startRootUpdate(extracted);
|
||||
rewriter.startOpModification(extracted);
|
||||
extracted.getSourceMutable().assign(
|
||||
newForOp.getResult(numOriginalForOpResults + iterArgNumber));
|
||||
rewriter.finalizeRootUpdate(extracted);
|
||||
rewriter.finalizeOpModification(extracted);
|
||||
|
||||
LLVM_DEBUG(DBGS() << "replace uses of: " << paddedValueBeforeHoisting
|
||||
<< "\n");
|
||||
|
@ -60,9 +60,9 @@ mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
|
||||
assert(permutationMap && "unexpected null map");
|
||||
|
||||
// Start a guarded inplace update.
|
||||
rewriter.startRootUpdate(genericOp);
|
||||
auto guard =
|
||||
llvm::make_scope_exit([&]() { rewriter.finalizeRootUpdate(genericOp); });
|
||||
rewriter.startOpModification(genericOp);
|
||||
auto guard = llvm::make_scope_exit(
|
||||
[&]() { rewriter.finalizeOpModification(genericOp); });
|
||||
|
||||
// 2. Compute the interchanged indexing maps.
|
||||
SmallVector<AffineMap> newIndexingMaps;
|
||||
|
@ -113,7 +113,7 @@ linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension,
|
||||
// Need to pretend that the original op now takes as operands firstResults,
|
||||
// otherwise tiling interface implementation will take the wrong value to
|
||||
// produce data tiles.
|
||||
rewriter.updateRootInPlace(op, [&]() {
|
||||
rewriter.modifyOpInPlace(op, [&]() {
|
||||
unsigned numTotalOperands = op->getNumOperands();
|
||||
unsigned numOutputOperands = firstResults.size();
|
||||
op->setOperands(numTotalOperands - numOutputOperands, numOutputOperands,
|
||||
|
@ -722,7 +722,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
|
||||
// We cannot use a IRMapping here because it can replace
|
||||
// different OpOperands with the same value.
|
||||
Operation *clonedOp = b.clone(*op.getOperation());
|
||||
b.updateRootInPlace(clonedOp, [&]() {
|
||||
b.modifyOpInPlace(clonedOp, [&]() {
|
||||
for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal(
|
||||
cast<DestinationStyleOpInterface>(clonedOp).getDpsInitsMutable(),
|
||||
tiledDpsInitOperands)) {
|
||||
|
@ -1952,7 +1952,7 @@ struct PadOpVectorizationWithTransferReadPattern
|
||||
if (xferOp.hasOutOfBoundsDim() || xferOp.getMask())
|
||||
return failure();
|
||||
|
||||
rewriter.updateRootInPlace(xferOp, [&]() {
|
||||
rewriter.modifyOpInPlace(xferOp, [&]() {
|
||||
SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
|
||||
xferOp->setAttr(xferOp.getInBoundsAttrName(),
|
||||
rewriter.getBoolArrayAttr(inBounds));
|
||||
|
@ -227,7 +227,7 @@ DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
|
||||
Attribute index = getAttributeIndexFromIndexOperands(
|
||||
getContext(), getIndices(), getMemRefType());
|
||||
const MemorySlot &memorySlot = subslots.at(index);
|
||||
rewriter.updateRootInPlace(*this, [&]() {
|
||||
rewriter.modifyOpInPlace(*this, [&]() {
|
||||
setMemRef(memorySlot.ptr);
|
||||
getIndicesMutable().clear();
|
||||
});
|
||||
@ -280,7 +280,7 @@ DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
|
||||
Attribute index = getAttributeIndexFromIndexOperands(
|
||||
getContext(), getIndices(), getMemRefType());
|
||||
const MemorySlot &memorySlot = subslots.at(index);
|
||||
rewriter.updateRootInPlace(*this, [&]() {
|
||||
rewriter.modifyOpInPlace(*this, [&]() {
|
||||
setMemRef(memorySlot.ptr);
|
||||
getIndicesMutable().clear();
|
||||
});
|
||||
|
@ -792,7 +792,7 @@ struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
|
||||
if (fromType && toType) {
|
||||
if (fromType.getShape() == toType.getShape() &&
|
||||
fromType.getElementType() == toType.getElementType()) {
|
||||
rewriter.updateRootInPlace(copyOp, [&] {
|
||||
rewriter.modifyOpInPlace(copyOp, [&] {
|
||||
copyOp.getSourceMutable().assign(castOp.getSource());
|
||||
});
|
||||
modified = true;
|
||||
@ -808,7 +808,7 @@ struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
|
||||
if (fromType && toType) {
|
||||
if (fromType.getShape() == toType.getShape() &&
|
||||
fromType.getElementType() == toType.getElementType()) {
|
||||
rewriter.updateRootInPlace(copyOp, [&] {
|
||||
rewriter.modifyOpInPlace(copyOp, [&] {
|
||||
copyOp.getTargetMutable().assign(castOp.getSource());
|
||||
});
|
||||
modified = true;
|
||||
@ -1366,7 +1366,7 @@ static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
|
||||
loc, llvm::cast<IntegerAttr>(maybeConstant.template get<Attribute>())
|
||||
.getInt());
|
||||
for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
|
||||
// updateRootInplace: lambda cannot capture structured bindings in C++17
|
||||
// modifyOpInPlace: lambda cannot capture structured bindings in C++17
|
||||
// yet.
|
||||
op->replaceUsesOfWith(result, constantVal);
|
||||
atLeastOneReplacement = true;
|
||||
@ -2436,7 +2436,7 @@ public:
|
||||
op.getReassociationIndices());
|
||||
|
||||
if (newResultType == op.getResultType()) {
|
||||
rewriter.updateRootInPlace(
|
||||
rewriter.modifyOpInPlace(
|
||||
op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
|
||||
} else {
|
||||
Value newOp = rewriter.create<CollapseShapeOp>(
|
||||
|
@ -797,7 +797,7 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
|
||||
extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
|
||||
if (!viewLikeOp)
|
||||
return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source");
|
||||
rewriter.updateRootInPlace(extractOp, [&]() {
|
||||
rewriter.modifyOpInPlace(extractOp, [&]() {
|
||||
extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
|
||||
});
|
||||
return success();
|
||||
|
@ -154,7 +154,7 @@ static void replaceAndPropagateMemRefType(RewriterBase &rewriter,
|
||||
for (OpOperand &operand : user->getOpOperands()) {
|
||||
if ([[maybe_unused]] auto castOp =
|
||||
operand.get().getDefiningOp<UnrealizedConversionCastOp>()) {
|
||||
rewriter.updateRootInPlace(
|
||||
rewriter.modifyOpInPlace(
|
||||
user, [&]() { operand.set(conversion->getOperand(0)); });
|
||||
}
|
||||
}
|
||||
|
@ -79,9 +79,9 @@ static void replaceUsesAndPropagateType(RewriterBase &rewriter,
|
||||
// TODO: can we use an early_inc iterator?
|
||||
for (OpOperand *operand : operandsToReplace) {
|
||||
Operation *op = operand->getOwner();
|
||||
rewriter.startRootUpdate(op);
|
||||
rewriter.startOpModification(op);
|
||||
operand->set(val);
|
||||
rewriter.finalizeRootUpdate(op);
|
||||
rewriter.finalizeOpModification(op);
|
||||
}
|
||||
|
||||
// Perform late op erasure.
|
||||
|
@ -54,7 +54,7 @@ struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
|
||||
"for nvgpu.mma.sync on f32 datatype");
|
||||
|
||||
if (precision == MmaSyncF32Lowering::TF32) {
|
||||
rewriter.updateRootInPlace(
|
||||
rewriter.modifyOpInPlace(
|
||||
op, [&]() { op.setTf32EnabledAttr(rewriter.getUnitAttr()); });
|
||||
}
|
||||
|
||||
|
@ -359,7 +359,7 @@ struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
|
||||
if (!matchPattern(ifCond, m_Constant(&constAttr)))
|
||||
return failure();
|
||||
if (constAttr.getInt())
|
||||
rewriter.updateRootInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
|
||||
rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
|
||||
else
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
@ -398,7 +398,7 @@ struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
|
||||
if (!matchPattern(ifCond, m_Constant(&constAttr)))
|
||||
return failure();
|
||||
if (constAttr.getInt())
|
||||
rewriter.updateRootInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
|
||||
rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
|
||||
else
|
||||
replaceOpWithRegion(rewriter, op, op.getRegion());
|
||||
|
||||
|
@ -552,7 +552,7 @@ ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
|
||||
newYieldValuesFn(rewriter, getLoc(), newIterArgs);
|
||||
assert(newInitOperands.size() == newYieldedValues.size() &&
|
||||
"expected as many new yield values as new iter operands");
|
||||
rewriter.updateRootInPlace(yieldOp, [&]() {
|
||||
rewriter.modifyOpInPlace(yieldOp, [&]() {
|
||||
yieldOp.getResultsMutable().append(newYieldedValues);
|
||||
});
|
||||
}
|
||||
@ -1444,7 +1444,7 @@ struct DimOfForallOp : public OpRewritePattern<tensor::DimOp> {
|
||||
Value sharedOut =
|
||||
forallOp.getTiedOpOperand(llvm::cast<OpResult>(dimOp.getSource()))
|
||||
->get();
|
||||
rewriter.updateRootInPlace(
|
||||
rewriter.modifyOpInPlace(
|
||||
dimOp, [&]() { dimOp.getSourceMutable().assign(sharedOut); });
|
||||
return success();
|
||||
}
|
||||
@ -1464,7 +1464,7 @@ public:
|
||||
failed(foldDynamicIndexList(mixedStep)))
|
||||
return failure();
|
||||
|
||||
rewriter.updateRootInPlace(op, [&]() {
|
||||
rewriter.modifyOpInPlace(op, [&]() {
|
||||
SmallVector<Value> dynamicLowerBound, dynamicUpperBound, dynamicStep;
|
||||
SmallVector<int64_t> staticLowerBound, staticUpperBound, staticStep;
|
||||
dispatchIndexOpFoldResults(mixedLowerBound, dynamicLowerBound,
|
||||
@ -1556,7 +1556,7 @@ struct ForallOpSingleOrZeroIterationDimsFolder
|
||||
for (const auto &namedAttr : op->getAttrs()) {
|
||||
if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
|
||||
continue;
|
||||
rewriter.updateRootInPlace(newOp, [&]() {
|
||||
rewriter.modifyOpInPlace(newOp, [&]() {
|
||||
newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
|
||||
});
|
||||
}
|
||||
@ -2023,8 +2023,8 @@ struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
|
||||
[&](OpResult result) {
|
||||
return yieldOp.getOperand(result.getResultNumber());
|
||||
});
|
||||
rewriter.updateRootInPlace(yieldOp,
|
||||
[&]() { yieldOp->setOperands(usedOperands); });
|
||||
rewriter.modifyOpInPlace(yieldOp,
|
||||
[&]() { yieldOp->setOperands(usedOperands); });
|
||||
}
|
||||
|
||||
LogicalResult matchAndRewrite(IfOp op,
|
||||
@ -2189,8 +2189,8 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
|
||||
constantTrue = rewriter.create<arith::ConstantOp>(
|
||||
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
|
||||
|
||||
rewriter.updateRootInPlace(use.getOwner(),
|
||||
[&]() { use.set(constantTrue); });
|
||||
rewriter.modifyOpInPlace(use.getOwner(),
|
||||
[&]() { use.set(constantTrue); });
|
||||
} else if (op.getElseRegion().isAncestor(
|
||||
use.getOwner()->getParentRegion())) {
|
||||
changed = true;
|
||||
@ -2199,8 +2199,8 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> {
|
||||
constantFalse = rewriter.create<arith::ConstantOp>(
|
||||
op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
|
||||
|
||||
rewriter.updateRootInPlace(use.getOwner(),
|
||||
[&]() { use.set(constantFalse); });
|
||||
rewriter.modifyOpInPlace(use.getOwner(),
|
||||
[&]() { use.set(constantFalse); });
|
||||
}
|
||||
}
|
||||
|
||||
@ -2383,14 +2383,14 @@ struct CombineIfs : public OpRewritePattern<IfOp> {
|
||||
llvm::make_early_inc_range(std::get<0>(it).getUses())) {
|
||||
if (nextThen && nextThen->getParent()->isAncestor(
|
||||
use.getOwner()->getParentRegion())) {
|
||||
rewriter.startRootUpdate(use.getOwner());
|
||||
rewriter.startOpModification(use.getOwner());
|
||||
use.set(std::get<1>(it));
|
||||
rewriter.finalizeRootUpdate(use.getOwner());
|
||||
rewriter.finalizeOpModification(use.getOwner());
|
||||
} else if (nextElse && nextElse->getParent()->isAncestor(
|
||||
use.getOwner()->getParentRegion())) {
|
||||
rewriter.startRootUpdate(use.getOwner());
|
||||
rewriter.startOpModification(use.getOwner());
|
||||
use.set(std::get<2>(it));
|
||||
rewriter.finalizeRootUpdate(use.getOwner());
|
||||
rewriter.finalizeOpModification(use.getOwner());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -688,7 +688,7 @@ struct ForOpInterface
|
||||
yieldValues.push_back(*alloc);
|
||||
}
|
||||
|
||||
rewriter.updateRootInPlace(
|
||||
rewriter.modifyOpInPlace(
|
||||
yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); });
|
||||
return success();
|
||||
}
|
||||
@ -928,7 +928,7 @@ struct WhileOpInterface
|
||||
return failure();
|
||||
beforeYieldValues.push_back(*alloc);
|
||||
}
|
||||
rewriter.updateRootInPlace(conditionOp, [&]() {
|
||||
rewriter.modifyOpInPlace(conditionOp, [&]() {
|
||||
conditionOp.getArgsMutable().assign(beforeYieldValues);
|
||||
});
|
||||
|
||||
|
@ -89,8 +89,8 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
|
||||
for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
|
||||
SmallVector<Value> yieldOperands = yieldOp.getOperands();
|
||||
yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
|
||||
rewriter.updateRootInPlace(
|
||||
yieldOp, [&]() { yieldOp->setOperands(yieldOperands); });
|
||||
rewriter.modifyOpInPlace(yieldOp,
|
||||
[&]() { yieldOp->setOperands(yieldOperands); });
|
||||
}
|
||||
|
||||
// We cannot do a direct replacement of the forOp since the while op returns
|
||||
|
@ -99,7 +99,7 @@ struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
|
||||
return failure();
|
||||
|
||||
Value initArg = forOp.getTiedLoopInit(blockArg)->get();
|
||||
rewriter.updateRootInPlace(
|
||||
rewriter.modifyOpInPlace(
|
||||
dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
|
||||
|
||||
return success();
|
||||
@ -141,7 +141,7 @@ struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
|
||||
unsigned resultNumber = opResult.getResultNumber();
|
||||
if (!isShapePreserving(forOp, resultNumber))
|
||||
return failure();
|
||||
rewriter.updateRootInPlace(dimOp, [&]() {
|
||||
rewriter.modifyOpInPlace(dimOp, [&]() {
|
||||
dimOp.getSourceMutable().assign(forOp.getInitArgs()[resultNumber]);
|
||||
});
|
||||
return success();
|
||||
|
@ -160,8 +160,8 @@ static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
|
||||
partialIteration.getInitArgsMutable().assign(forOp->getResults());
|
||||
|
||||
// Set new upper loop bound.
|
||||
b.updateRootInPlace(
|
||||
forOp, [&]() { forOp.getUpperBoundMutable().assign(splitBound); });
|
||||
b.modifyOpInPlace(forOp,
|
||||
[&]() { forOp.getUpperBoundMutable().assign(splitBound); });
|
||||
|
||||
return success();
|
||||
}
|
||||
@ -239,7 +239,7 @@ LogicalResult mlir::scf::peelForLoopFirstIteration(RewriterBase &b, ForOp forOp,
|
||||
firstIteration = cast<ForOp>(b.clone(*forOp.getOperation(), map));
|
||||
|
||||
// Update main loop with new lower bound.
|
||||
b.updateRootInPlace(forOp, [&]() {
|
||||
b.modifyOpInPlace(forOp, [&]() {
|
||||
forOp.getInitArgsMutable().assign(firstIteration->getResults());
|
||||
forOp.getLowerBoundMutable().assign(splitBound);
|
||||
});
|
||||
@ -286,11 +286,11 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
|
||||
}
|
||||
|
||||
// Apply label, so that the same loop is not rewritten a second time.
|
||||
rewriter.updateRootInPlace(partialIteration, [&]() {
|
||||
rewriter.modifyOpInPlace(partialIteration, [&]() {
|
||||
partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
|
||||
partialIteration->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
|
||||
});
|
||||
rewriter.updateRootInPlace(forOp, [&]() {
|
||||
rewriter.modifyOpInPlace(forOp, [&]() {
|
||||
forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
|
||||
});
|
||||
return success();
|
||||
|
@ -111,7 +111,7 @@ public:
|
||||
return failure();
|
||||
|
||||
// Convert operands.
|
||||
rewriter.updateRootInPlace(
|
||||
rewriter.modifyOpInPlace(
|
||||
op, [&] { op->setOperands(adaptor.getFlatOperands()); });
|
||||
|
||||
return success();
|
||||
@ -131,7 +131,7 @@ public:
|
||||
return failure();
|
||||
|
||||
// Convert operands.
|
||||
rewriter.updateRootInPlace(
|
||||
rewriter.modifyOpInPlace(
|
||||
op, [&] { op->setOperands(adaptor.getFlatOperands()); });
|
||||
|
||||
return success();
|
||||
|
@ -241,7 +241,7 @@ public:
|
||||
for (Value operand : adaptor.getOperands())
|
||||
unpackUnrealizedConversionCast(operand, unpackedYield);
|
||||
|
||||
rewriter.updateRootInPlace(op, [&]() { op->setOperands(unpackedYield); });
|
||||
rewriter.modifyOpInPlace(op, [&]() { op->setOperands(unpackedYield); });
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -692,7 +692,7 @@ void mlir::scf::yieldReplacementForFusedProducer(
|
||||
sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
|
||||
sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
|
||||
unsigned resultNumber = fusableProducer.getResultNumber();
|
||||
rewriter.updateRootInPlace(tiledDestStyleOp, [&]() {
|
||||
rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
|
||||
tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
|
||||
});
|
||||
}
|
||||
|
@ -91,8 +91,8 @@ public:
|
||||
LogicalResult
|
||||
matchAndRewrite(OpT op, typename OpT::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.updateRootInPlace(op,
|
||||
[&] { op->setOperands(adaptor.getOperands()); });
|
||||
rewriter.modifyOpInPlace(op,
|
||||
[&] { op->setOperands(adaptor.getOperands()); });
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -261,7 +261,7 @@ LogicalResult ProcessInterfaceVarABI::matchAndRewrite(
|
||||
return failure();
|
||||
|
||||
// Creates a new function with the update signature.
|
||||
rewriter.updateRootInPlace(funcOp, [&] {
|
||||
rewriter.modifyOpInPlace(funcOp, [&] {
|
||||
funcOp.setType(rewriter.getFunctionType(
|
||||
signatureConverter.getConvertedTypes(), std::nullopt));
|
||||
});
|
||||
|
@ -29,7 +29,7 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
|
||||
|
||||
// Clones the original operation but changing the output to an unordered COO.
|
||||
Operation *cloned = rewriter.clone(*op.getOperation());
|
||||
rewriter.updateRootInPlace(cloned, [cloned, srcCOOTp]() {
|
||||
rewriter.modifyOpInPlace(cloned, [cloned, srcCOOTp]() {
|
||||
cloned->getOpResult(0).setType(srcCOOTp);
|
||||
});
|
||||
Value srcCOO = cloned->getOpResult(0);
|
||||
|
@ -389,14 +389,14 @@ public:
|
||||
auto stt = tryGetSparseTensorType(res);
|
||||
auto [idxMap, itTp] = *transMap;
|
||||
|
||||
rewriter.startRootUpdate(linalgOp);
|
||||
rewriter.startOpModification(linalgOp);
|
||||
linalgOp.setIndexingMapsAttr(idxMap);
|
||||
linalgOp.setIteratorTypesAttr(itTp);
|
||||
// Use demapped arguments.
|
||||
linalgOp.getInputsMutable().assign(adaptor.getInputs());
|
||||
linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
|
||||
res.setType(adaptor.getOutputs()[0].getType());
|
||||
rewriter.finalizeRootUpdate(linalgOp);
|
||||
rewriter.finalizeOpModification(linalgOp);
|
||||
|
||||
rewriter.setInsertionPointAfter(linalgOp);
|
||||
if (stt && stt->hasEncoding()) {
|
||||
@ -458,7 +458,7 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
|
||||
}
|
||||
|
||||
// Marks the GenericOp to avoid recursive matching.
|
||||
rewriter.updateRootInPlace(linalgOp, [&]() {
|
||||
rewriter.modifyOpInPlace(linalgOp, [&]() {
|
||||
linalgOp->setAttr(sorted, rewriter.getBoolAttr(true));
|
||||
});
|
||||
|
||||
@ -482,10 +482,10 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
|
||||
for (AffineMap &idxMap : idxMaps)
|
||||
idxMap = idxMap.compose(order); // sorted loop -> lvl map
|
||||
|
||||
rewriter.startRootUpdate(linalgOp);
|
||||
rewriter.startOpModification(linalgOp);
|
||||
linalgOp.setIndexingMapsAttr(rewriter.getAffineMapArrayAttr(idxMaps));
|
||||
linalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(curItTypes));
|
||||
rewriter.finalizeRootUpdate(linalgOp);
|
||||
rewriter.finalizeOpModification(linalgOp);
|
||||
|
||||
return success();
|
||||
}
|
||||
@ -570,7 +570,7 @@ private:
|
||||
rewriter.setInsertionPoint(linalgOp);
|
||||
RankedTensorType dstTp = stt.withDimToLvl(dimToLvl).getRankedTensorType();
|
||||
Value dst = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval);
|
||||
rewriter.updateRootInPlace(linalgOp, [&]() {
|
||||
rewriter.modifyOpInPlace(linalgOp, [&]() {
|
||||
linalgOp->setOperand(t->getOperandNumber(), dst);
|
||||
});
|
||||
return success();
|
||||
@ -623,10 +623,10 @@ struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
|
||||
}
|
||||
|
||||
assert(dynSz.empty()); // should have consumed all.
|
||||
rewriter.startRootUpdate(op);
|
||||
rewriter.startOpModification(op);
|
||||
op->setOperands(dynLvlSzs);
|
||||
op.getResult().setType(stt.getDemappedType());
|
||||
rewriter.finalizeRootUpdate(op);
|
||||
rewriter.finalizeOpModification(op);
|
||||
rewriter.setInsertionPointAfter(op);
|
||||
|
||||
Value t = genRemap(rewriter, stt.getEncoding(), op.getResult());
|
||||
@ -676,7 +676,7 @@ struct ForeachOpDemapper
|
||||
auto srcStt = getSparseTensorType(op.getTensor());
|
||||
SmallVector<Type> prevRetTps(op.getResultTypes());
|
||||
|
||||
rewriter.startRootUpdate(op);
|
||||
rewriter.startOpModification(op);
|
||||
op.getTensorMutable().assign(adaptor.getTensor());
|
||||
op.getInitArgsMutable().assign(adaptor.getInitArgs());
|
||||
// Update results' types.
|
||||
@ -731,7 +731,7 @@ struct ForeachOpDemapper
|
||||
rewriter.eraseOp(yield);
|
||||
}
|
||||
}
|
||||
rewriter.finalizeRootUpdate(op);
|
||||
rewriter.finalizeOpModification(op);
|
||||
|
||||
rewriter.setInsertionPointAfter(op);
|
||||
SmallVector<Value> outs =
|
||||
|
@ -329,7 +329,7 @@ public:
|
||||
.getCopy();
|
||||
AllocTensorOp a =
|
||||
op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
|
||||
rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(init); });
|
||||
rewriter.modifyOpInPlace(a, [&]() { a.getCopyMutable().assign(init); });
|
||||
}
|
||||
// Replace consumer with fused operation. Old producer
|
||||
// and consumer ops will be removed by DCE.
|
||||
@ -366,7 +366,7 @@ public:
|
||||
if (tensor::isSameTypeWithoutEncoding(srcType, dstType)) {
|
||||
if (Operation *def = op.getSource().getDefiningOp()) {
|
||||
if (def->hasOneUse() && isa<tensor::ExtractSliceOp>(def)) {
|
||||
rewriter.updateRootInPlace(def, [&]() {
|
||||
rewriter.modifyOpInPlace(def, [&]() {
|
||||
def->getResult(0).setType(op->getResultTypes()[0]);
|
||||
});
|
||||
rewriter.replaceOp(op, def->getResult(0));
|
||||
@ -804,7 +804,7 @@ public:
|
||||
auto denseTp =
|
||||
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
|
||||
auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
|
||||
rewriter.updateRootInPlace(op, [&]() { op->setOperand(0, convert); });
|
||||
rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, convert); });
|
||||
return success();
|
||||
}
|
||||
if (encDst) {
|
||||
|
@ -545,7 +545,7 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
|
||||
forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
|
||||
rewriter.setInsertionPointToStart(forOpNew.getBody());
|
||||
} else {
|
||||
rewriter.updateRootInPlace(forOp, [&]() { forOp.setStep(step); });
|
||||
rewriter.modifyOpInPlace(forOp, [&]() { forOp.setStep(step); });
|
||||
rewriter.setInsertionPoint(yield);
|
||||
}
|
||||
vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
|
||||
|
@ -583,7 +583,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
|
||||
if (def->getBlock() == block) {
|
||||
rewriter.setInsertionPoint(def);
|
||||
for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
|
||||
rewriter.updateRootInPlace(def, [&]() {
|
||||
rewriter.modifyOpInPlace(def, [&]() {
|
||||
def->setOperand(
|
||||
i, relinkBranch(env, rewriter, block, def->getOperand(i)));
|
||||
});
|
||||
|
@ -1416,7 +1416,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
|
||||
Operation *newRed = rewriter.clone(*redExp);
|
||||
// Replaces arguments of the reduction expression by using the block
|
||||
// arguments from scf.reduce.
|
||||
rewriter.updateRootInPlace(
|
||||
rewriter.modifyOpInPlace(
|
||||
newRed, [&]() { newRed->setOperands(redBlock->getArguments()); });
|
||||
// Erases the out-dated reduction expression.
|
||||
rewriter.eraseOp(redExp);
|
||||
|
@ -819,7 +819,7 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
|
||||
auto resultIndex = source.cast<OpResult>().getResultNumber();
|
||||
auto initOperand = destOp.getDpsInitOperand(resultIndex);
|
||||
|
||||
rewriter.updateRootInPlace(
|
||||
rewriter.modifyOpInPlace(
|
||||
dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
|
||||
return success();
|
||||
}
|
||||
@ -1752,7 +1752,7 @@ struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
|
||||
srcType, collapseShapeOp.getReassociationMaps());
|
||||
|
||||
if (newResultType == collapseShapeOp.getResultType()) {
|
||||
rewriter.updateRootInPlace(collapseShapeOp, [&]() {
|
||||
rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
|
||||
collapseShapeOp.getSrcMutable().assign(castOp.getSource());
|
||||
});
|
||||
} else {
|
||||
@ -2930,7 +2930,7 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
|
||||
padTensorOp.getResultType().getShape());
|
||||
|
||||
if (newResultType == padTensorOp.getResultType()) {
|
||||
rewriter.updateRootInPlace(padTensorOp, [&]() {
|
||||
rewriter.modifyOpInPlace(padTensorOp, [&]() {
|
||||
padTensorOp.getSourceMutable().assign(castOp.getSource());
|
||||
});
|
||||
} else {
|
||||
@ -3994,9 +3994,9 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
|
||||
|
||||
// Fold optional PaddingValue operand away if padding is not needed.
|
||||
if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
|
||||
rewriter.startRootUpdate(packOp);
|
||||
rewriter.startOpModification(packOp);
|
||||
packOp.getPaddingValueMutable().clear();
|
||||
rewriter.finalizeRootUpdate(packOp);
|
||||
rewriter.finalizeOpModification(packOp);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
@ -4166,8 +4166,8 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
|
||||
unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
|
||||
auto destValue = unPackOp.getDest().cast<OpResult>();
|
||||
Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
|
||||
rewriter.updateRootInPlace(
|
||||
unPackOp, [&]() { unPackOp.setDpsInitOperand(0, newDest); });
|
||||
rewriter.modifyOpInPlace(unPackOp,
|
||||
[&]() { unPackOp.setDpsInitOperand(0, newDest); });
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
|
@ -68,7 +68,7 @@ LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
|
||||
auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
|
||||
if (!notOp)
|
||||
return failure();
|
||||
rewriter.updateRootInPlace(op, [&]() {
|
||||
rewriter.modifyOpInPlace(op, [&]() {
|
||||
op.getOperation()->setOperands(
|
||||
{notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
|
||||
});
|
||||
|
@ -4416,7 +4416,7 @@ public:
|
||||
writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
|
||||
while (defWrite) {
|
||||
if (checkSameValueWAW(writeOp, defWrite)) {
|
||||
rewriter.updateRootInPlace(writeToModify, [&]() {
|
||||
rewriter.modifyOpInPlace(writeToModify, [&]() {
|
||||
writeToModify.getSourceMutable().assign(defWrite.getSource());
|
||||
});
|
||||
return success();
|
||||
@ -4533,7 +4533,7 @@ public:
|
||||
transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
|
||||
transferOp.getIndices(), transferOp.getPermutationMapAttr(),
|
||||
rewriter.getBoolArrayAttr(newInBounds));
|
||||
rewriter.updateRootInPlace(insertOp, [&]() {
|
||||
rewriter.modifyOpInPlace(insertOp, [&]() {
|
||||
insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
|
||||
});
|
||||
return success();
|
||||
|
@ -225,7 +225,7 @@ struct MaskOpInterface
|
||||
newReturnValues[it.index()] = it.value();
|
||||
}
|
||||
}
|
||||
rewriter.updateRootInPlace(yieldOp, [&]() {
|
||||
rewriter.modifyOpInPlace(yieldOp, [&]() {
|
||||
yieldOp.getOperandsMutable().assign(newYieldedValues);
|
||||
});
|
||||
|
||||
|
@ -182,7 +182,7 @@ static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
|
||||
auto yield =
|
||||
cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
|
||||
|
||||
rewriter.updateRootInPlace(
|
||||
rewriter.modifyOpInPlace(
|
||||
yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
|
||||
return newWarpOp;
|
||||
}
|
||||
@ -724,7 +724,7 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
|
||||
return failure();
|
||||
// Notify the rewriter that the warp op is changing (see the comment on
|
||||
// the WarpOpTransferRead pattern).
|
||||
rewriter.startRootUpdate(warpOp);
|
||||
rewriter.startOpModification(warpOp);
|
||||
unsigned operandIndex = yieldOperand->getOperandNumber();
|
||||
Attribute scalarAttr = dense.getSplatValue<Attribute>();
|
||||
auto newAttr = DenseElementsAttr::get(
|
||||
@ -733,7 +733,7 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
|
||||
rewriter.setInsertionPointAfter(warpOp);
|
||||
Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
|
||||
rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), distConstant);
|
||||
rewriter.finalizeRootUpdate(warpOp);
|
||||
rewriter.finalizeOpModification(warpOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@ -1017,9 +1017,9 @@ struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
|
||||
return failure();
|
||||
// Notify the rewriter that the warp op is changing (see the comment on
|
||||
// the WarpOpTransferRead pattern).
|
||||
rewriter.startRootUpdate(warpOp);
|
||||
rewriter.startOpModification(warpOp);
|
||||
rewriter.replaceAllUsesWith(warpOp.getResult(resultIndex), valForwarded);
|
||||
rewriter.finalizeRootUpdate(warpOp);
|
||||
rewriter.finalizeOpModification(warpOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@ -1159,7 +1159,7 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
|
||||
|
||||
// Notify the rewriter that the warp op is changing (see the comment on
|
||||
// the WarpOpTransferRead pattern).
|
||||
rewriter.startRootUpdate(warpOp);
|
||||
rewriter.startOpModification(warpOp);
|
||||
|
||||
AffineExpr s0, s1;
|
||||
bindSymbols(rewriter.getContext(), s0, s1);
|
||||
@ -1179,7 +1179,7 @@ struct WarpOpCreateMask : public OpRewritePattern<WarpExecuteOnLane0Op> {
|
||||
auto newMask =
|
||||
rewriter.create<vector::CreateMaskOp>(loc, distType, newOperands);
|
||||
rewriter.replaceAllUsesWith(warpOp.getResult(operandIndex), newMask);
|
||||
rewriter.finalizeRootUpdate(warpOp);
|
||||
rewriter.finalizeOpModification(warpOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -525,7 +525,7 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
|
||||
SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
|
||||
auto inBoundsAttr = b.getBoolArrayAttr(bools);
|
||||
if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
|
||||
b.updateRootInPlace(xferOp, [&]() {
|
||||
b.modifyOpInPlace(xferOp, [&]() {
|
||||
xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
|
||||
});
|
||||
return success();
|
||||
@ -598,7 +598,7 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
|
||||
for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
|
||||
xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
|
||||
|
||||
b.updateRootInPlace(xferOp, [&]() {
|
||||
b.modifyOpInPlace(xferOp, [&]() {
|
||||
xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
|
||||
});
|
||||
|
||||
|
@ -1050,7 +1050,7 @@ public:
|
||||
mask = rewriter.create<arith::AndIOp>(loc, mask, xferOp.getMask());
|
||||
}
|
||||
|
||||
rewriter.updateRootInPlace(xferOp, [&]() {
|
||||
rewriter.modifyOpInPlace(xferOp, [&]() {
|
||||
xferOp.getMaskMutable().assign(mask);
|
||||
xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
|
||||
});
|
||||
|
@ -263,7 +263,7 @@ void RewriterBase::eraseBlock(Block *block) {
|
||||
block->erase();
|
||||
}
|
||||
|
||||
void RewriterBase::finalizeRootUpdate(Operation *op) {
|
||||
void RewriterBase::finalizeOpModification(Operation *op) {
|
||||
// Notify the listener that the operation was modified.
|
||||
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
|
||||
rewriteListener->notifyOperationModified(op);
|
||||
@ -276,7 +276,7 @@ void RewriterBase::replaceUsesWithIf(Value from, Value to,
|
||||
function_ref<bool(OpOperand &)> functor) {
|
||||
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
|
||||
if (functor(operand))
|
||||
updateRootInPlace(operand.getOwner(), [&]() { operand.set(to); });
|
||||
modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -506,7 +506,7 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
|
||||
if (info.mergePoints.contains(blockOperand.get())) {
|
||||
if (!job.reachingDef)
|
||||
job.reachingDef = getLazyDefaultValue();
|
||||
rewriter.updateRootInPlace(terminator, [&]() {
|
||||
rewriter.modifyOpInPlace(terminator, [&]() {
|
||||
terminator.getSuccessorOperands(blockOperand.getOperandNumber())
|
||||
.append(job.reachingDef);
|
||||
});
|
||||
@ -596,7 +596,7 @@ void MemorySlotPromoter::promoteSlot() {
|
||||
assert(succOperands.size() == mergePoint->getNumArguments() ||
|
||||
succOperands.size() + 1 == mergePoint->getNumArguments());
|
||||
if (succOperands.size() + 1 == mergePoint->getNumArguments())
|
||||
rewriter.updateRootInPlace(
|
||||
rewriter.modifyOpInPlace(
|
||||
user, [&]() { succOperands.append(getLazyDefaultValue()); });
|
||||
}
|
||||
}
|
||||
|
@ -304,7 +304,7 @@ public:
|
||||
sortedOperands.push_back(commOperand->operand);
|
||||
if (sortedOperands == operands)
|
||||
return failure();
|
||||
rewriter.updateRootInPlace(op, [&] { op->setOperands(sortedOperands); });
|
||||
rewriter.modifyOpInPlace(op, [&] { op->setOperands(sortedOperands); });
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -1614,15 +1614,15 @@ void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
|
||||
impl->createdOps.push_back(op);
|
||||
}
|
||||
|
||||
void ConversionPatternRewriter::startRootUpdate(Operation *op) {
|
||||
void ConversionPatternRewriter::startOpModification(Operation *op) {
|
||||
#ifndef NDEBUG
|
||||
impl->pendingRootUpdates.insert(op);
|
||||
#endif
|
||||
impl->rootUpdates.emplace_back(op);
|
||||
}
|
||||
|
||||
void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
|
||||
PatternRewriter::finalizeRootUpdate(op);
|
||||
void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
|
||||
PatternRewriter::finalizeOpModification(op);
|
||||
// There is nothing to do here, we only need to track the operation at the
|
||||
// start of the update.
|
||||
#ifndef NDEBUG
|
||||
@ -1631,7 +1631,7 @@ void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
|
||||
#endif
|
||||
}
|
||||
|
||||
void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
|
||||
void ConversionPatternRewriter::cancelOpModification(Operation *op) {
|
||||
#ifndef NDEBUG
|
||||
assert(impl->pendingRootUpdates.erase(op) &&
|
||||
"operation did not have a pending in-place update");
|
||||
@ -3115,7 +3115,7 @@ static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
|
||||
auto newType = FunctionType::get(rewriter.getContext(),
|
||||
result.getConvertedTypes(), newResults);
|
||||
|
||||
rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(newType); });
|
||||
rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); });
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -160,7 +160,7 @@ struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
|
||||
int64_t val = intAttr.getInt();
|
||||
if (val >= MaxVal)
|
||||
return failure();
|
||||
rewriter.updateRootInPlace(
|
||||
rewriter.modifyOpInPlace(
|
||||
op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); });
|
||||
return success();
|
||||
}
|
||||
@ -175,7 +175,7 @@ struct MakeOpEligible : public RewritePattern {
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (op->hasAttr("eligible"))
|
||||
return failure();
|
||||
rewriter.updateRootInPlace(
|
||||
rewriter.modifyOpInPlace(
|
||||
op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); });
|
||||
return success();
|
||||
}
|
||||
@ -195,7 +195,7 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
|
||||
return failure();
|
||||
// Hoisting means removing an op from the enclosing op. I.e., the enclosing
|
||||
// op is modified.
|
||||
rewriter.updateRootInPlace(op, [&]() { toBeHoisted->moveBefore(op); });
|
||||
rewriter.modifyOpInPlace(op, [&]() { toBeHoisted->moveBefore(op); });
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@ -327,7 +327,7 @@ private:
|
||||
Operation *newOp =
|
||||
rewriter.create(op->getLoc(), op->getName().getIdentifier(),
|
||||
op->getOperands(), op->getResultTypes());
|
||||
rewriter.updateRootInPlace(
|
||||
rewriter.modifyOpInPlace(
|
||||
op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); });
|
||||
newOp->setAttr("skip", rewriter.getBoolAttr(true));
|
||||
|
||||
@ -415,8 +415,8 @@ private:
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock())
|
||||
return failure();
|
||||
rewriter.updateRootInPlace(
|
||||
op, [&]() { op->setSuccessor(op->getBlock(), 0); });
|
||||
rewriter.modifyOpInPlace(op,
|
||||
[&]() { op->setSuccessor(op->getBlock(), 0); });
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@ -650,7 +650,7 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
|
||||
rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
|
||||
rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0),
|
||||
illegalOp->getResult(0));
|
||||
rewriter.updateRootInPlace(op, [] {});
|
||||
rewriter.modifyOpInPlace(op, [] {});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@ -667,7 +667,7 @@ struct TestUndoBlockErase : public ConversionPattern {
|
||||
rewriter.setInsertionPointToStart(secondBlock);
|
||||
rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
|
||||
rewriter.eraseBlock(secondBlock);
|
||||
rewriter.updateRootInPlace(op, [] {});
|
||||
rewriter.modifyOpInPlace(op, [] {});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@ -827,7 +827,7 @@ struct TestBoundedRecursiveRewrite
|
||||
LogicalResult matchAndRewrite(TestRecursiveRewriteOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
// Decrement the depth of the op in-place.
|
||||
rewriter.updateRootInPlace(op, [&] {
|
||||
rewriter.modifyOpInPlace(op, [&] {
|
||||
op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1));
|
||||
});
|
||||
return success();
|
||||
@ -1333,7 +1333,7 @@ struct TestTestSignatureConversionNoConverter
|
||||
if (failed(
|
||||
converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
|
||||
return failure();
|
||||
rewriter.updateRootInPlace(
|
||||
rewriter.modifyOpInPlace(
|
||||
op, [&] { rewriter.applySignatureConversion(®ion, result); });
|
||||
return success();
|
||||
}
|
||||
@ -1350,8 +1350,8 @@ struct TestTypeConsumerForward
|
||||
LogicalResult
|
||||
matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
rewriter.updateRootInPlace(op,
|
||||
[&] { op->setOperands(adaptor.getOperands()); });
|
||||
rewriter.modifyOpInPlace(op,
|
||||
[&] { op->setOperands(adaptor.getOperands()); });
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@ -1567,7 +1567,7 @@ struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
|
||||
SmallVector<Value, 2> replacements(succOperands);
|
||||
rewriter.eraseOp(branchOp);
|
||||
rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
|
||||
rewriter.updateRootInPlace(op, [] {});
|
||||
rewriter.modifyOpInPlace(op, [] {});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@ -1588,7 +1588,7 @@ struct TestUndoBlocksMerge : public ConversionPattern {
|
||||
SmallVector<Value, 2> replacements(succOperands);
|
||||
rewriter.eraseOp(branchOp);
|
||||
rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
|
||||
rewriter.updateRootInPlace(op, [] {});
|
||||
rewriter.modifyOpInPlace(op, [] {});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@ -1613,7 +1613,7 @@ struct TestMergeSingleBlockOps
|
||||
rewriter.inlineBlockBefore(&innerBlock, op);
|
||||
rewriter.eraseOp(innerTerminator);
|
||||
rewriter.eraseOp(op);
|
||||
rewriter.updateRootInPlace(op, [] {});
|
||||
rewriter.modifyOpInPlace(op, [] {});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
Loading…
x
Reference in New Issue
Block a user