[mlir][IR] Change MutableArrayRange to enumerate OpOperand & (#66622)

In line with #66515, change `MutableArrayRange::begin`/`end` to
enumerate `OpOperand &` instead of `Value`. Also remove
`ForOp::getIterOpOperands`/`setIterArg`, which are now redundant.

Note: `MutableOperandRange` cannot be made a derived class of
`indexed_accessor_range_base` (like `OperandRange`), because
`MutableOperandRange::assign` can change the number of operands in the
range.
This commit is contained in:
Matthias Springer 2023-09-19 09:09:21 +02:00 committed by GitHub
parent 45bb45f2ae
commit 6923a31542
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 53 additions and 42 deletions

View File

@ -250,17 +250,10 @@ def ForOp : SCF_Op<"for",
"expected an index less than the number of region iter args");
return getBody()->getArguments().drop_front(getNumInductionVars())[index];
}
MutableArrayRef<OpOperand> getIterOpOperands() {
return
getOperation()->getOpOperands().drop_front(getNumControlOperands());
}
void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
void setStep(Value step) { getOperation()->setOperand(2, step); }
void setIterArg(unsigned iterArgNum, Value iterArgValue) {
getOperation()->setOperand(iterArgNum + getNumControlOperands(), iterArgValue);
}
/// Number of induction variables, always 1 for scf::ForOp.
unsigned getNumInductionVars() { return 1; }

View File

@ -165,13 +165,9 @@ public:
/// Returns the OpOperand at the given index.
OpOperand &operator[](unsigned index) const;
OperandRange::iterator begin() const {
return static_cast<OperandRange>(*this).begin();
}
OperandRange::iterator end() const {
return static_cast<OperandRange>(*this).end();
}
/// Iterators enumerate OpOperands.
MutableArrayRef<OpOperand>::iterator begin() const;
MutableArrayRef<OpOperand>::iterator end() const;
private:
/// Update the length of this range to the one provided.

View File

@ -47,6 +47,10 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
static bool isMemrefOperand(OpOperand &operand) {
return isMemref(operand.get());
}
//===----------------------------------------------------------------------===//
// Backedges analysis
//===----------------------------------------------------------------------===//
@ -937,7 +941,7 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
// Add an additional operand for every MemRef for the ownership indicator.
if (!funcWithoutDynamicOwnership) {
unsigned numMemRefs = llvm::count_if(operands, isMemref);
unsigned numMemRefs = llvm::count_if(operands, isMemrefOperand);
SmallVector<Value> newOperands{OperandRange(operands)};
auto ownershipValues =
deallocOp.getUpdatedConditions().take_front(numMemRefs);

View File

@ -96,12 +96,12 @@ struct CondBranchOpInterface
mapping[retained] = ownership;
}
SmallVector<Value> replacements, ownerships;
for (Value operand : destOperands) {
replacements.push_back(operand);
if (isMemref(operand)) {
assert(mapping.contains(operand) &&
for (OpOperand &operand : destOperands) {
replacements.push_back(operand.get());
if (isMemref(operand.get())) {
assert(mapping.contains(operand.get()) &&
"Should be contained at this point");
ownerships.push_back(mapping[operand]);
ownerships.push_back(mapping[operand.get()]);
}
}
replacements.append(ownerships);

View File

@ -932,7 +932,7 @@ replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
assert(operand.get().getType() != replacement.getType() &&
"Expected a different type");
SmallVector<Value> newIterOperands;
for (OpOperand &opOperand : forOp.getIterOpOperands()) {
for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
newIterOperands.push_back(replacement);
continue;
@ -1015,7 +1015,7 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
LogicalResult matchAndRewrite(ForOp op,
PatternRewriter &rewriter) const override {
for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
OpOperand &iterOpOperand = std::get<0>(it);
auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
if (!incomingCast ||

View File

@ -325,7 +325,7 @@ DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
/// Helper function for loop bufferization. Return the bufferized values of the
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
getBuffers(RewriterBase &rewriter, MutableOperandRange operands,
const BufferizationOptions &options) {
SmallVector<Value> result;
for (OpOperand &opOperand : operands) {
@ -598,7 +598,7 @@ struct ForOpInterface
// The new memref init_args of the loop.
FailureOr<SmallVector<Value>> maybeInitArgs =
getBuffers(rewriter, forOp.getIterOpOperands(), options);
getBuffers(rewriter, forOp.getInitArgsMutable(), options);
if (failed(maybeInitArgs))
return failure();
SmallVector<Value> initArgs = *maybeInitArgs;
@ -816,7 +816,7 @@ struct WhileOpInterface
// The new memref init_args of the loop.
FailureOr<SmallVector<Value>> maybeInitArgs =
getBuffers(rewriter, whileOp->getOpOperands(), options);
getBuffers(rewriter, whileOp.getInitsMutable(), options);
if (failed(maybeInitArgs))
return failure();
SmallVector<Value> initArgs = *maybeInitArgs;

View File

@ -500,7 +500,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
MutableArrayRef<scf::ForOp> loops) {
// 1. Get the producer of the source (potentially walking through
// `iter_args` of nested `scf.for`)
auto [fusableProducer, destinationIterArg] =
auto [fusableProducer, destinationInitArg] =
getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0],
loops);
if (!fusableProducer)
@ -567,17 +567,15 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
// TODO: This can be modeled better if the `DestinationStyleOpInterface`.
// Update to use that when it does become available.
scf::ForOp outerMostLoop = loops.front();
std::optional<unsigned> iterArgNumber;
if (destinationIterArg) {
iterArgNumber =
outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value());
}
if (iterArgNumber) {
if (destinationInitArg &&
(*destinationInitArg)->getOwner() == outerMostLoop) {
std::optional<unsigned> iterArgNumber =
outerMostLoop.getIterArgNumberForOpOperand(**destinationInitArg);
int64_t resultNumber = fusableProducer.getResultNumber();
if (auto dstOp =
dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
outerMostLoop.setIterArg(iterArgNumber.value(),
dstOp.getTiedOpOperand(fusableProducer)->get());
(*destinationInitArg)
->set(dstOp.getTiedOpOperand(fusableProducer)->get());
}
for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);

View File

@ -522,6 +522,14 @@ OpOperand &MutableOperandRange::operator[](unsigned index) const {
return owner->getOpOperand(start + index);
}
MutableArrayRef<OpOperand>::iterator MutableOperandRange::begin() const {
return owner->getOpOperands().slice(start, length).begin();
}
MutableArrayRef<OpOperand>::iterator MutableOperandRange::end() const {
return owner->getOpOperands().slice(start, length).end();
}
//===----------------------------------------------------------------------===//
// MutableOperandRangeRange

View File

@ -137,6 +137,13 @@ getMutableSuccessorOperands(Block *block, unsigned successorIndex) {
return succOps.getMutableForwardedOperands();
}
/// Return the operand range used to transfer operands from `block` to its
/// successor with the given index.
static OperandRange getSuccessorOperands(Block *block,
unsigned successorIndex) {
return getMutableSuccessorOperands(block, successorIndex);
}
/// Appends all the block arguments from `other` to the block arguments of
/// `block`, copying their types and locations.
static void addBlockArgumentsFromOther(Block *block, Block *other) {
@ -175,8 +182,14 @@ public:
/// Returns the arguments of this edge that are passed to the block arguments
/// of the successor.
MutableOperandRange getSuccessorOperands() const {
return getMutableSuccessorOperands(fromBlock, successorIndex);
MutableOperandRange getMutableSuccessorOperands() const {
return ::getMutableSuccessorOperands(fromBlock, successorIndex);
}
/// Returns the arguments of this edge that are passed to the block arguments
/// of the successor.
OperandRange getSuccessorOperands() const {
return ::getSuccessorOperands(fromBlock, successorIndex);
}
};
@ -262,7 +275,7 @@ public:
assert(result != blockArgMapping.end() &&
"Edge was not originally passed to `create` method.");
MutableOperandRange successorOperands = edge.getSuccessorOperands();
MutableOperandRange successorOperands = edge.getMutableSuccessorOperands();
// Extra arguments are always appended at the end of the block arguments.
unsigned extraArgsBeginIndex =
@ -666,7 +679,7 @@ transformToReduceLoop(Block *loopHeader, Block *exitBlock,
// invalidated when mutating the operands through a different
// `MutableOperandRange` of the same operation.
SmallVector<Value> loopHeaderSuccessorOperands =
llvm::to_vector(getMutableSuccessorOperands(latch, loopHeaderIndex));
llvm::to_vector(getSuccessorOperands(latch, loopHeaderIndex));
// Add all values used in the next iteration to the exit block. Replace
// any uses that are outside the loop with the newly created exit block.
@ -742,7 +755,7 @@ transformToReduceLoop(Block *loopHeader, Block *exitBlock,
loopHeaderSuccessorOperands.push_back(argument);
for (Edge edge : successorEdges(latch))
edge.getSuccessorOperands().append(argument);
edge.getMutableSuccessorOperands().append(argument);
}
use.set(blockArgument);
@ -939,9 +952,8 @@ static FailureOr<SmallVector<Block *>> transformToStructuredCFBranches(
if (regionEntry->getNumSuccessors() == 1) {
// Single successor we can just splice together.
Block *successor = regionEntry->getSuccessor(0);
for (auto &&[oldValue, newValue] :
llvm::zip(successor->getArguments(),
getMutableSuccessorOperands(regionEntry, 0)))
for (auto &&[oldValue, newValue] : llvm::zip(
successor->getArguments(), getSuccessorOperands(regionEntry, 0)))
oldValue.replaceAllUsesWith(newValue);
regionEntry->getTerminator()->erase();