mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-04-01 12:43:47 +00:00
[mlir][IR] Change MutableArrayRange
to enumerate OpOperand &
(#66622)
In line with #66515, change `MutableArrayRange::begin`/`end` to enumerate `OpOperand &` instead of `Value`. Also remove `ForOp::getIterOpOperands`/`setIterArg`, which are now redundant. Note: `MutableOperandRange` cannot be made a derived class of `indexed_accessor_range_base` (like `OperandRange`), because `MutableOperandRange::assign` can change the number of operands in the range.
This commit is contained in:
parent
45bb45f2ae
commit
6923a31542
@ -250,17 +250,10 @@ def ForOp : SCF_Op<"for",
|
||||
"expected an index less than the number of region iter args");
|
||||
return getBody()->getArguments().drop_front(getNumInductionVars())[index];
|
||||
}
|
||||
MutableArrayRef<OpOperand> getIterOpOperands() {
|
||||
return
|
||||
getOperation()->getOpOperands().drop_front(getNumControlOperands());
|
||||
}
|
||||
|
||||
void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
|
||||
void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
|
||||
void setStep(Value step) { getOperation()->setOperand(2, step); }
|
||||
void setIterArg(unsigned iterArgNum, Value iterArgValue) {
|
||||
getOperation()->setOperand(iterArgNum + getNumControlOperands(), iterArgValue);
|
||||
}
|
||||
|
||||
/// Number of induction variables, always 1 for scf::ForOp.
|
||||
unsigned getNumInductionVars() { return 1; }
|
||||
|
@ -165,13 +165,9 @@ public:
|
||||
/// Returns the OpOperand at the given index.
|
||||
OpOperand &operator[](unsigned index) const;
|
||||
|
||||
OperandRange::iterator begin() const {
|
||||
return static_cast<OperandRange>(*this).begin();
|
||||
}
|
||||
|
||||
OperandRange::iterator end() const {
|
||||
return static_cast<OperandRange>(*this).end();
|
||||
}
|
||||
/// Iterators enumerate OpOperands.
|
||||
MutableArrayRef<OpOperand>::iterator begin() const;
|
||||
MutableArrayRef<OpOperand>::iterator end() const;
|
||||
|
||||
private:
|
||||
/// Update the length of this range to the one provided.
|
||||
|
@ -47,6 +47,10 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
|
||||
|
||||
static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
|
||||
|
||||
static bool isMemrefOperand(OpOperand &operand) {
|
||||
return isMemref(operand.get());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Backedges analysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -937,7 +941,7 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
|
||||
|
||||
// Add an additional operand for every MemRef for the ownership indicator.
|
||||
if (!funcWithoutDynamicOwnership) {
|
||||
unsigned numMemRefs = llvm::count_if(operands, isMemref);
|
||||
unsigned numMemRefs = llvm::count_if(operands, isMemrefOperand);
|
||||
SmallVector<Value> newOperands{OperandRange(operands)};
|
||||
auto ownershipValues =
|
||||
deallocOp.getUpdatedConditions().take_front(numMemRefs);
|
||||
|
@ -96,12 +96,12 @@ struct CondBranchOpInterface
|
||||
mapping[retained] = ownership;
|
||||
}
|
||||
SmallVector<Value> replacements, ownerships;
|
||||
for (Value operand : destOperands) {
|
||||
replacements.push_back(operand);
|
||||
if (isMemref(operand)) {
|
||||
assert(mapping.contains(operand) &&
|
||||
for (OpOperand &operand : destOperands) {
|
||||
replacements.push_back(operand.get());
|
||||
if (isMemref(operand.get())) {
|
||||
assert(mapping.contains(operand.get()) &&
|
||||
"Should be contained at this point");
|
||||
ownerships.push_back(mapping[operand]);
|
||||
ownerships.push_back(mapping[operand.get()]);
|
||||
}
|
||||
}
|
||||
replacements.append(ownerships);
|
||||
|
@ -932,7 +932,7 @@ replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
|
||||
assert(operand.get().getType() != replacement.getType() &&
|
||||
"Expected a different type");
|
||||
SmallVector<Value> newIterOperands;
|
||||
for (OpOperand &opOperand : forOp.getIterOpOperands()) {
|
||||
for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
|
||||
if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
|
||||
newIterOperands.push_back(replacement);
|
||||
continue;
|
||||
@ -1015,7 +1015,7 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
|
||||
|
||||
LogicalResult matchAndRewrite(ForOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
|
||||
for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
|
||||
OpOperand &iterOpOperand = std::get<0>(it);
|
||||
auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
|
||||
if (!incomingCast ||
|
||||
|
@ -325,7 +325,7 @@ DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
|
||||
/// Helper function for loop bufferization. Return the bufferized values of the
|
||||
/// given OpOperands. If an operand is not a tensor, return the original value.
|
||||
static FailureOr<SmallVector<Value>>
|
||||
getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
|
||||
getBuffers(RewriterBase &rewriter, MutableOperandRange operands,
|
||||
const BufferizationOptions &options) {
|
||||
SmallVector<Value> result;
|
||||
for (OpOperand &opOperand : operands) {
|
||||
@ -598,7 +598,7 @@ struct ForOpInterface
|
||||
|
||||
// The new memref init_args of the loop.
|
||||
FailureOr<SmallVector<Value>> maybeInitArgs =
|
||||
getBuffers(rewriter, forOp.getIterOpOperands(), options);
|
||||
getBuffers(rewriter, forOp.getInitArgsMutable(), options);
|
||||
if (failed(maybeInitArgs))
|
||||
return failure();
|
||||
SmallVector<Value> initArgs = *maybeInitArgs;
|
||||
@ -816,7 +816,7 @@ struct WhileOpInterface
|
||||
|
||||
// The new memref init_args of the loop.
|
||||
FailureOr<SmallVector<Value>> maybeInitArgs =
|
||||
getBuffers(rewriter, whileOp->getOpOperands(), options);
|
||||
getBuffers(rewriter, whileOp.getInitsMutable(), options);
|
||||
if (failed(maybeInitArgs))
|
||||
return failure();
|
||||
SmallVector<Value> initArgs = *maybeInitArgs;
|
||||
|
@ -500,7 +500,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
|
||||
MutableArrayRef<scf::ForOp> loops) {
|
||||
// 1. Get the producer of the source (potentially walking through
|
||||
// `iter_args` of nested `scf.for`)
|
||||
auto [fusableProducer, destinationIterArg] =
|
||||
auto [fusableProducer, destinationInitArg] =
|
||||
getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0],
|
||||
loops);
|
||||
if (!fusableProducer)
|
||||
@ -567,17 +567,15 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
|
||||
// TODO: This can be modeled better if the `DestinationStyleOpInterface`.
|
||||
// Update to use that when it does become available.
|
||||
scf::ForOp outerMostLoop = loops.front();
|
||||
std::optional<unsigned> iterArgNumber;
|
||||
if (destinationIterArg) {
|
||||
iterArgNumber =
|
||||
outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value());
|
||||
}
|
||||
if (iterArgNumber) {
|
||||
if (destinationInitArg &&
|
||||
(*destinationInitArg)->getOwner() == outerMostLoop) {
|
||||
std::optional<unsigned> iterArgNumber =
|
||||
outerMostLoop.getIterArgNumberForOpOperand(**destinationInitArg);
|
||||
int64_t resultNumber = fusableProducer.getResultNumber();
|
||||
if (auto dstOp =
|
||||
dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
|
||||
outerMostLoop.setIterArg(iterArgNumber.value(),
|
||||
dstOp.getTiedOpOperand(fusableProducer)->get());
|
||||
(*destinationInitArg)
|
||||
->set(dstOp.getTiedOpOperand(fusableProducer)->get());
|
||||
}
|
||||
for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
|
||||
auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
|
||||
|
@ -522,6 +522,14 @@ OpOperand &MutableOperandRange::operator[](unsigned index) const {
|
||||
return owner->getOpOperand(start + index);
|
||||
}
|
||||
|
||||
MutableArrayRef<OpOperand>::iterator MutableOperandRange::begin() const {
|
||||
return owner->getOpOperands().slice(start, length).begin();
|
||||
}
|
||||
|
||||
MutableArrayRef<OpOperand>::iterator MutableOperandRange::end() const {
|
||||
return owner->getOpOperands().slice(start, length).end();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MutableOperandRangeRange
|
||||
|
||||
|
@ -137,6 +137,13 @@ getMutableSuccessorOperands(Block *block, unsigned successorIndex) {
|
||||
return succOps.getMutableForwardedOperands();
|
||||
}
|
||||
|
||||
/// Return the operand range used to transfer operands from `block` to its
|
||||
/// successor with the given index.
|
||||
static OperandRange getSuccessorOperands(Block *block,
|
||||
unsigned successorIndex) {
|
||||
return getMutableSuccessorOperands(block, successorIndex);
|
||||
}
|
||||
|
||||
/// Appends all the block arguments from `other` to the block arguments of
|
||||
/// `block`, copying their types and locations.
|
||||
static void addBlockArgumentsFromOther(Block *block, Block *other) {
|
||||
@ -175,8 +182,14 @@ public:
|
||||
|
||||
/// Returns the arguments of this edge that are passed to the block arguments
|
||||
/// of the successor.
|
||||
MutableOperandRange getSuccessorOperands() const {
|
||||
return getMutableSuccessorOperands(fromBlock, successorIndex);
|
||||
MutableOperandRange getMutableSuccessorOperands() const {
|
||||
return ::getMutableSuccessorOperands(fromBlock, successorIndex);
|
||||
}
|
||||
|
||||
/// Returns the arguments of this edge that are passed to the block arguments
|
||||
/// of the successor.
|
||||
OperandRange getSuccessorOperands() const {
|
||||
return ::getSuccessorOperands(fromBlock, successorIndex);
|
||||
}
|
||||
};
|
||||
|
||||
@ -262,7 +275,7 @@ public:
|
||||
assert(result != blockArgMapping.end() &&
|
||||
"Edge was not originally passed to `create` method.");
|
||||
|
||||
MutableOperandRange successorOperands = edge.getSuccessorOperands();
|
||||
MutableOperandRange successorOperands = edge.getMutableSuccessorOperands();
|
||||
|
||||
// Extra arguments are always appended at the end of the block arguments.
|
||||
unsigned extraArgsBeginIndex =
|
||||
@ -666,7 +679,7 @@ transformToReduceLoop(Block *loopHeader, Block *exitBlock,
|
||||
// invalidated when mutating the operands through a different
|
||||
// `MutableOperandRange` of the same operation.
|
||||
SmallVector<Value> loopHeaderSuccessorOperands =
|
||||
llvm::to_vector(getMutableSuccessorOperands(latch, loopHeaderIndex));
|
||||
llvm::to_vector(getSuccessorOperands(latch, loopHeaderIndex));
|
||||
|
||||
// Add all values used in the next iteration to the exit block. Replace
|
||||
// any uses that are outside the loop with the newly created exit block.
|
||||
@ -742,7 +755,7 @@ transformToReduceLoop(Block *loopHeader, Block *exitBlock,
|
||||
|
||||
loopHeaderSuccessorOperands.push_back(argument);
|
||||
for (Edge edge : successorEdges(latch))
|
||||
edge.getSuccessorOperands().append(argument);
|
||||
edge.getMutableSuccessorOperands().append(argument);
|
||||
}
|
||||
|
||||
use.set(blockArgument);
|
||||
@ -939,9 +952,8 @@ static FailureOr<SmallVector<Block *>> transformToStructuredCFBranches(
|
||||
if (regionEntry->getNumSuccessors() == 1) {
|
||||
// Single successor we can just splice together.
|
||||
Block *successor = regionEntry->getSuccessor(0);
|
||||
for (auto &&[oldValue, newValue] :
|
||||
llvm::zip(successor->getArguments(),
|
||||
getMutableSuccessorOperands(regionEntry, 0)))
|
||||
for (auto &&[oldValue, newValue] : llvm::zip(
|
||||
successor->getArguments(), getSuccessorOperands(regionEntry, 0)))
|
||||
oldValue.replaceAllUsesWith(newValue);
|
||||
regionEntry->getTerminator()->erase();
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user