mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-02-04 08:16:49 +00:00
[mlir] Add support for operation-produced successor arguments in BranchOpInterface
This patch revamps the BranchOpInterface a bit and allows a proper implementation of what was previously `getMutableSuccessorOperands` for operations, which internally produce arguments to some of the block arguments. A motivating example for this would be an invoke op with a error handling path: ``` invoke %function(%0) label ^success ^error(%1 : i32) ^error(%e: !error, %arg0 : i32): ... ``` The advantages of this are that any users of `BranchOpInterface` can still argue over remaining block argument operands (such as `%1` in the example above), as well as make use of the modifying capabilities to add more operands, erase an operand etc. The way this patch implements that functionality is via a new class called `SuccessorOperands`, which is now returned by `getSuccessorOperands`. It basically contains an `unsigned` denoting how many operator produced operands exist, as well as a `MutableOperandRange`, which are the usual forwarded operands we are used to. The produced operands are assumed to the first few block arguments, followed by the forwarded operands afterwards. The role of `SuccessorOperands` is to provide various utility functions to modify and query the successor arguments from a `BranchOpInterface`. Differential Revision: https://reviews.llvm.org/D123062
This commit is contained in:
parent
795b07f549
commit
0c789db541
@ -489,16 +489,12 @@ class fir_SwitchTerminatorOp<string mnemonic, list<Trait> traits = []> :
|
||||
llvm::ArrayRef<mlir::Value> operands, unsigned cond);
|
||||
llvm::Optional<mlir::ValueRange> getSuccessorOperands(
|
||||
mlir::ValueRange operands, unsigned cond);
|
||||
using BranchOpInterfaceTrait::getSuccessorOperands;
|
||||
|
||||
// Helper function to deal with Optional operand forms
|
||||
void printSuccessorAtIndex(mlir::OpAsmPrinter &p, unsigned i) {
|
||||
auto *succ = getSuccessor(i);
|
||||
auto ops = getSuccessorOperands(i);
|
||||
if (ops.hasValue())
|
||||
p.printSuccessorAndUseList(succ, ops.getValue());
|
||||
else
|
||||
p.printSuccessor(succ);
|
||||
p.printSuccessorAndUseList(succ, ops.getForwardedOperands());
|
||||
}
|
||||
|
||||
mlir::ArrayAttr getCases() {
|
||||
|
@ -2401,10 +2401,9 @@ fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
|
||||
return {};
|
||||
}
|
||||
|
||||
llvm::Optional<mlir::MutableOperandRange>
|
||||
fir::SelectOp::getMutableSuccessorOperands(unsigned oper) {
|
||||
return ::getMutableSuccessorOperands(oper, getTargetArgsMutable(),
|
||||
getTargetOffsetAttr());
|
||||
mlir::SuccessorOperands fir::SelectOp::getSuccessorOperands(unsigned oper) {
|
||||
return mlir::SuccessorOperands(::getMutableSuccessorOperands(
|
||||
oper, getTargetArgsMutable(), getTargetOffsetAttr()));
|
||||
}
|
||||
|
||||
llvm::Optional<llvm::ArrayRef<mlir::Value>>
|
||||
@ -2462,10 +2461,9 @@ fir::SelectCaseOp::getCompareOperands(mlir::ValueRange operands,
|
||||
return {getSubOperands(cond, getSubOperands(1, operands, segments), a)};
|
||||
}
|
||||
|
||||
llvm::Optional<mlir::MutableOperandRange>
|
||||
fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) {
|
||||
return ::getMutableSuccessorOperands(oper, getTargetArgsMutable(),
|
||||
getTargetOffsetAttr());
|
||||
mlir::SuccessorOperands fir::SelectCaseOp::getSuccessorOperands(unsigned oper) {
|
||||
return mlir::SuccessorOperands(::getMutableSuccessorOperands(
|
||||
oper, getTargetArgsMutable(), getTargetOffsetAttr()));
|
||||
}
|
||||
|
||||
llvm::Optional<llvm::ArrayRef<mlir::Value>>
|
||||
@ -2734,10 +2732,9 @@ fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
|
||||
return {};
|
||||
}
|
||||
|
||||
llvm::Optional<mlir::MutableOperandRange>
|
||||
fir::SelectRankOp::getMutableSuccessorOperands(unsigned oper) {
|
||||
return ::getMutableSuccessorOperands(oper, getTargetArgsMutable(),
|
||||
getTargetOffsetAttr());
|
||||
mlir::SuccessorOperands fir::SelectRankOp::getSuccessorOperands(unsigned oper) {
|
||||
return mlir::SuccessorOperands(::getMutableSuccessorOperands(
|
||||
oper, getTargetArgsMutable(), getTargetOffsetAttr()));
|
||||
}
|
||||
|
||||
llvm::Optional<llvm::ArrayRef<mlir::Value>>
|
||||
@ -2779,10 +2776,9 @@ fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
|
||||
return {};
|
||||
}
|
||||
|
||||
llvm::Optional<mlir::MutableOperandRange>
|
||||
fir::SelectTypeOp::getMutableSuccessorOperands(unsigned oper) {
|
||||
return ::getMutableSuccessorOperands(oper, getTargetArgsMutable(),
|
||||
getTargetOffsetAttr());
|
||||
mlir::SuccessorOperands fir::SelectTypeOp::getSuccessorOperands(unsigned oper) {
|
||||
return mlir::SuccessorOperands(::getMutableSuccessorOperands(
|
||||
oper, getTargetArgsMutable(), getTargetOffsetAttr()));
|
||||
}
|
||||
|
||||
llvm::Optional<llvm::ArrayRef<mlir::Value>>
|
||||
|
@ -907,6 +907,11 @@ public:
|
||||
/// elements attribute, which contains the sizes of the sub ranges.
|
||||
MutableOperandRangeRange split(NamedAttribute segmentSizes) const;
|
||||
|
||||
/// Returns the value at the given index.
|
||||
Value operator[](unsigned index) const {
|
||||
return static_cast<OperandRange>(*this)[index];
|
||||
}
|
||||
|
||||
private:
|
||||
/// Update the length of this range to the one provided.
|
||||
void updateLength(unsigned newLength);
|
||||
|
@ -20,6 +20,106 @@ namespace mlir {
|
||||
class BranchOpInterface;
|
||||
class RegionBranchOpInterface;
|
||||
|
||||
/// This class models how operands are forwarded to block arguments in control
|
||||
/// flow. It consists of a number, denoting how many of the successors block
|
||||
/// arguments are produced by the operation, followed by a range of operands
|
||||
/// that are forwarded. The produced operands are passed to the first few
|
||||
/// block arguments of the successor, followed by the forwarded operands.
|
||||
/// It is unsupported to pass them in a different order.
|
||||
///
|
||||
/// An example operation with both of these concepts would be a branch-on-error
|
||||
/// operation, that internally produces an error object on the error path:
|
||||
///
|
||||
/// invoke %function(%0)
|
||||
/// label ^success ^error(%1 : i32)
|
||||
///
|
||||
/// ^error(%e: !error, %arg0 : i32):
|
||||
/// ...
|
||||
///
|
||||
/// This operation would return an instance of SuccessorOperands with a produced
|
||||
/// operand count of 1 (mapped to %e in the successor) and a forwarded
|
||||
/// operands range consisting of %1 in the example above (mapped to %arg0 in the
|
||||
/// successor).
|
||||
class SuccessorOperands {
|
||||
public:
|
||||
/// Constructs a SuccessorOperands with no produced operands that simply
|
||||
/// forwards operands to the successor.
|
||||
explicit SuccessorOperands(MutableOperandRange forwardedOperands);
|
||||
|
||||
/// Constructs a SuccessorOperands with the given amount of produced operands
|
||||
/// and forwarded operands.
|
||||
SuccessorOperands(unsigned producedOperandCount,
|
||||
MutableOperandRange forwardedOperands);
|
||||
|
||||
/// Returns the amount of operands passed to the successor. This consists both
|
||||
/// of produced operands by the operation as well as forwarded ones.
|
||||
unsigned size() const {
|
||||
return producedOperandCount + forwardedOperands.size();
|
||||
}
|
||||
|
||||
/// Returns true if there are no successor operands.
|
||||
bool empty() const { return size() == 0; }
|
||||
|
||||
/// Returns the amount of operands that are produced internally by the
|
||||
/// operation. These are passed to the first few block arguments.
|
||||
unsigned getProducedOperandCount() const { return producedOperandCount; }
|
||||
|
||||
/// Returns true if the successor operand denoted by `index` is produced by
|
||||
/// the operation.
|
||||
bool isOperandProduced(unsigned index) const {
|
||||
return index < producedOperandCount;
|
||||
}
|
||||
|
||||
/// Returns the Value that is passed to the successors block argument denoted
|
||||
/// by `index`. If it is produced by the operation, no such value exists and
|
||||
/// a null Value is returned.
|
||||
Value operator[](unsigned index) const {
|
||||
if (isOperandProduced(index))
|
||||
return Value();
|
||||
return forwardedOperands[index - producedOperandCount];
|
||||
}
|
||||
|
||||
/// Get the range of operands that are simply forwarded to the successor.
|
||||
OperandRange getForwardedOperands() const { return forwardedOperands; }
|
||||
|
||||
/// Get a slice of the operands forwarded to the successor. The given range
|
||||
/// must not contain any operands produced by the operation.
|
||||
MutableOperandRange slice(unsigned subStart, unsigned subLen) const {
|
||||
assert(!isOperandProduced(subStart) &&
|
||||
"can't slice operands produced by the operation");
|
||||
return forwardedOperands.slice(subStart - producedOperandCount, subLen);
|
||||
}
|
||||
|
||||
/// Erase operands forwarded to the successor. The given range must
|
||||
/// not contain any operands produced by the operation.
|
||||
void erase(unsigned subStart, unsigned subLen = 1) {
|
||||
assert(!isOperandProduced(subStart) &&
|
||||
"can't erase operands produced by the operation");
|
||||
forwardedOperands.erase(subStart - producedOperandCount, subLen);
|
||||
}
|
||||
|
||||
/// Add new operands that are forwarded to the successor.
|
||||
void append(ValueRange valueRange) { forwardedOperands.append(valueRange); }
|
||||
|
||||
/// Gets the index of the forwarded operand within the operation which maps
|
||||
/// to the block argument denoted by `blockArgumentIndex`. The block argument
|
||||
/// must be mapped to a forwarded operand.
|
||||
unsigned getOperandIndex(unsigned blockArgumentIndex) const {
|
||||
assert(!isOperandProduced(blockArgumentIndex) &&
|
||||
"can't map operand produced by the operation");
|
||||
return static_cast<mlir::OperandRange>(forwardedOperands)
|
||||
.getBeginOperandIndex() +
|
||||
(blockArgumentIndex - producedOperandCount);
|
||||
}
|
||||
|
||||
private:
|
||||
/// Amount of operands that are produced internally within the operation and
|
||||
/// passed to the first few block arguments.
|
||||
unsigned producedOperandCount;
|
||||
/// Range of operands that are forwarded to the remaining block arguments.
|
||||
MutableOperandRange forwardedOperands;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BranchOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -29,12 +129,12 @@ namespace detail {
|
||||
/// successor if `operandIndex` is within the range of `operands`, or None if
|
||||
/// `operandIndex` isn't a successor operand index.
|
||||
Optional<BlockArgument>
|
||||
getBranchSuccessorArgument(Optional<OperandRange> operands,
|
||||
getBranchSuccessorArgument(const SuccessorOperands &operands,
|
||||
unsigned operandIndex, Block *successor);
|
||||
|
||||
/// Verify that the given operands match those of the given successor block.
|
||||
LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
|
||||
Optional<OperandRange> operands);
|
||||
const SuccessorOperands &operands);
|
||||
} // namespace detail
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -36,26 +36,35 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<[{
|
||||
Returns a mutable range of operands that correspond to the arguments of
|
||||
successor at the given index. Returns None if the operands to the
|
||||
successor are non-materialized values, i.e. they are internal to the
|
||||
operation.
|
||||
Returns the operands that correspond to the arguments of the successor
|
||||
at the given index. It consists of a number of operands that are
|
||||
internally produced by the operation, followed by a range of operands
|
||||
that are forwarded. An example operation making use of produced
|
||||
operands would be:
|
||||
|
||||
```mlir
|
||||
invoke %function(%0)
|
||||
label ^success ^error(%1 : i32)
|
||||
|
||||
^error(%e: !error, %arg0: i32):
|
||||
...
|
||||
```
|
||||
|
||||
The operand that would map to the `^error`s `%e` operand is produced
|
||||
by the `invoke` operation, while `%1` is a forwarded operand that maps
|
||||
to `%arg0` in the successor.
|
||||
|
||||
Produced operands always map to the first few block arguments of the
|
||||
successor, followed by the forwarded operands. Mapping them in any
|
||||
other order is not supported by the interface.
|
||||
|
||||
By having the forwarded operands last allows users of the interface
|
||||
to append more forwarded operands to the branch operation without
|
||||
interfering with other successor operands.
|
||||
}],
|
||||
"::mlir::Optional<::mlir::MutableOperandRange>", "getMutableSuccessorOperands",
|
||||
"::mlir::SuccessorOperands", "getSuccessorOperands",
|
||||
(ins "unsigned":$index)
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
Returns a range of operands that correspond to the arguments of
|
||||
successor at the given index. Returns None if the operands to the
|
||||
successor are non-materialized values, i.e. they are internal to the
|
||||
operation.
|
||||
}],
|
||||
"::mlir::Optional<::mlir::OperandRange>", "getSuccessorOperands",
|
||||
(ins "unsigned":$index), [{}], [{
|
||||
auto operands = $_op.getMutableSuccessorOperands(index);
|
||||
return operands ? ::mlir::Optional<::mlir::OperandRange>(*operands) : ::llvm::None;
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
Returns the `BlockArgument` corresponding to operand `operandIndex` in
|
||||
some successor, or None if `operandIndex` isn't a successor operand
|
||||
@ -94,7 +103,7 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
|
||||
let verify = [{
|
||||
auto concreteOp = ::mlir::cast<ConcreteOp>($_op);
|
||||
for (unsigned i = 0, e = $_op->getNumSuccessors(); i != e; ++i) {
|
||||
::mlir::Optional<OperandRange> operands = concreteOp.getSuccessorOperands(i);
|
||||
::mlir::SuccessorOperands operands = concreteOp.getSuccessorOperands(i);
|
||||
if (::mlir::failed(::mlir::detail::verifyBranchSuccessorOperands($_op, i, operands)))
|
||||
return ::mlir::failure();
|
||||
}
|
||||
|
@ -149,14 +149,13 @@ static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth,
|
||||
|
||||
// Try to get the operand passed for this argument.
|
||||
unsigned index = it.getSuccessorIndex();
|
||||
Optional<OperandRange> operands = branch.getSuccessorOperands(index);
|
||||
if (!operands) {
|
||||
Value operand = branch.getSuccessorOperands(index)[argNumber];
|
||||
if (!operand) {
|
||||
// We can't analyze the control flow, so bail out early.
|
||||
output.push_back(arg);
|
||||
return;
|
||||
}
|
||||
collectUnderlyingAddressValues((*operands)[argNumber], maxDepth, visited,
|
||||
output);
|
||||
collectUnderlyingAddressValues(operand, maxDepth, visited, output);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
@ -70,10 +70,10 @@ void BufferViewFlowAnalysis::build(Operation *op) {
|
||||
// Query the branch op interface to get the successor operands.
|
||||
auto successorOperands =
|
||||
branchInterface.getSuccessorOperands(it.getIndex());
|
||||
if (!successorOperands.hasValue())
|
||||
continue;
|
||||
// Build the actual mapping of values to their immediate dependencies.
|
||||
registerDependencies(successorOperands.getValue(), (*it)->getArguments());
|
||||
registerDependencies(successorOperands.getForwardedOperands(),
|
||||
(*it)->getArguments().drop_front(
|
||||
successorOperands.getProducedOperandCount()));
|
||||
}
|
||||
});
|
||||
|
||||
|
@ -681,10 +681,13 @@ void ForwardDataFlowSolver::visitBlockArgument(Block *block, int i) {
|
||||
// Try to get the operand forwarded by the predecessor. If we can't reason
|
||||
// about the terminator of the predecessor, mark as having reached a
|
||||
// fixpoint.
|
||||
Optional<OperandRange> branchOperands;
|
||||
if (auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator()))
|
||||
branchOperands = branch.getSuccessorOperands(it.getSuccessorIndex());
|
||||
if (!branchOperands) {
|
||||
auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator());
|
||||
if (!branch) {
|
||||
updatedLattice |= argLattice.markPessimisticFixpoint();
|
||||
break;
|
||||
}
|
||||
Value operand = branch.getSuccessorOperands(it.getSuccessorIndex())[i];
|
||||
if (!operand) {
|
||||
updatedLattice |= argLattice.markPessimisticFixpoint();
|
||||
break;
|
||||
}
|
||||
@ -692,7 +695,7 @@ void ForwardDataFlowSolver::visitBlockArgument(Block *block, int i) {
|
||||
// If the operand hasn't been resolved, it is uninitialized and can merge
|
||||
// with anything.
|
||||
AbstractLatticeElement *operandLattice =
|
||||
analysis.lookupLatticeElement((*branchOperands)[i]);
|
||||
analysis.lookupLatticeElement(operand);
|
||||
if (!operandLattice)
|
||||
continue;
|
||||
|
||||
|
@ -325,25 +325,20 @@ private:
|
||||
// argument.
|
||||
Operation *terminator = (*it)->getTerminator();
|
||||
auto branchInterface = cast<BranchOpInterface>(terminator);
|
||||
SuccessorOperands operands =
|
||||
branchInterface.getSuccessorOperands(it.getSuccessorIndex());
|
||||
|
||||
// Query the associated source value.
|
||||
Value sourceValue =
|
||||
branchInterface.getSuccessorOperands(it.getSuccessorIndex())
|
||||
.getValue()[blockArg.getArgNumber()];
|
||||
// Wire new clone and successor operand.
|
||||
auto mutableOperands =
|
||||
branchInterface.getMutableSuccessorOperands(it.getSuccessorIndex());
|
||||
if (!mutableOperands) {
|
||||
terminator->emitError() << "terminators with immutable successor "
|
||||
"operands are not supported";
|
||||
continue;
|
||||
Value sourceValue = operands[blockArg.getArgNumber()];
|
||||
if (!sourceValue) {
|
||||
return failure();
|
||||
}
|
||||
// Wire new clone and successor operand.
|
||||
// Create a new clone at the current location of the terminator.
|
||||
auto clone = introduceCloneBuffers(sourceValue, terminator);
|
||||
if (failed(clone))
|
||||
return failure();
|
||||
mutableOperands.getValue()
|
||||
.slice(blockArg.getArgNumber(), 1)
|
||||
.assign(*clone);
|
||||
operands.slice(blockArg.getArgNumber(), 1).assign(*clone);
|
||||
}
|
||||
|
||||
// Check whether the block argument has implicitly defined predecessors via
|
||||
|
@ -186,10 +186,9 @@ void BranchOp::setDest(Block *block) { return setSuccessor(block); }
|
||||
|
||||
void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
|
||||
|
||||
Optional<MutableOperandRange>
|
||||
BranchOp::getMutableSuccessorOperands(unsigned index) {
|
||||
SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
|
||||
assert(index == 0 && "invalid successor index");
|
||||
return getDestOperandsMutable();
|
||||
return SuccessorOperands(getDestOperandsMutable());
|
||||
}
|
||||
|
||||
Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
|
||||
@ -437,11 +436,10 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
CondBranchTruthPropagation>(context);
|
||||
}
|
||||
|
||||
Optional<MutableOperandRange>
|
||||
CondBranchOp::getMutableSuccessorOperands(unsigned index) {
|
||||
SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
|
||||
assert(index < getNumSuccessors() && "invalid successor index");
|
||||
return index == trueIndex ? getTrueDestOperandsMutable()
|
||||
: getFalseDestOperandsMutable();
|
||||
return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
|
||||
: getFalseDestOperandsMutable());
|
||||
}
|
||||
|
||||
Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
|
||||
@ -575,11 +573,10 @@ LogicalResult SwitchOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
Optional<MutableOperandRange>
|
||||
SwitchOp::getMutableSuccessorOperands(unsigned index) {
|
||||
SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
|
||||
assert(index < getNumSuccessors() && "invalid successor index");
|
||||
return index == 0 ? getDefaultOperandsMutable()
|
||||
: getCaseOperandsMutable(index - 1);
|
||||
return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
|
||||
: getCaseOperandsMutable(index - 1));
|
||||
}
|
||||
|
||||
Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
|
||||
|
@ -67,12 +67,13 @@ public:
|
||||
SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
|
||||
for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
|
||||
succIdx < succEnd; ++succIdx) {
|
||||
auto successorOperands = op.getSuccessorOperands(succIdx);
|
||||
if (!successorOperands || successorOperands->empty())
|
||||
OperandRange forwardedOperands =
|
||||
op.getSuccessorOperands(succIdx).getForwardedOperands();
|
||||
if (forwardedOperands.empty())
|
||||
continue;
|
||||
|
||||
for (int idx = successorOperands->getBeginOperandIndex(),
|
||||
eidx = idx + successorOperands->size();
|
||||
for (int idx = forwardedOperands.getBeginOperandIndex(),
|
||||
eidx = idx + forwardedOperands.size();
|
||||
idx < eidx; ++idx) {
|
||||
if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx))
|
||||
newOperands[idx] = operands[idx];
|
||||
@ -121,8 +122,8 @@ bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
|
||||
if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
|
||||
for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) {
|
||||
auto successorOperands = branchOp.getSuccessorOperands(p);
|
||||
if (successorOperands.hasValue() &&
|
||||
!converter.isLegal(successorOperands.getValue().getTypes()))
|
||||
if (!converter.isLegal(
|
||||
successorOperands.getForwardedOperands().getTypes()))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
@ -240,21 +240,19 @@ ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
// LLVM::BrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Optional<MutableOperandRange>
|
||||
BrOp::getMutableSuccessorOperands(unsigned index) {
|
||||
SuccessorOperands BrOp::getSuccessorOperands(unsigned index) {
|
||||
assert(index == 0 && "invalid successor index");
|
||||
return getDestOperandsMutable();
|
||||
return SuccessorOperands(getDestOperandsMutable());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LLVM::CondBrOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Optional<MutableOperandRange>
|
||||
CondBrOp::getMutableSuccessorOperands(unsigned index) {
|
||||
SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) {
|
||||
assert(index < getNumSuccessors() && "invalid successor index");
|
||||
return index == 0 ? getTrueDestOperandsMutable()
|
||||
: getFalseDestOperandsMutable();
|
||||
return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable()
|
||||
: getFalseDestOperandsMutable());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -356,11 +354,10 @@ LogicalResult SwitchOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
Optional<MutableOperandRange>
|
||||
SwitchOp::getMutableSuccessorOperands(unsigned index) {
|
||||
SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
|
||||
assert(index < getNumSuccessors() && "invalid successor index");
|
||||
return index == 0 ? getDefaultOperandsMutable()
|
||||
: getCaseOperandsMutable(index - 1);
|
||||
return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
|
||||
: getCaseOperandsMutable(index - 1));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -735,11 +732,10 @@ ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
/// LLVM::InvokeOp
|
||||
///===---------------------------------------------------------------------===//
|
||||
|
||||
Optional<MutableOperandRange>
|
||||
InvokeOp::getMutableSuccessorOperands(unsigned index) {
|
||||
SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
|
||||
assert(index < getNumSuccessors() && "invalid successor index");
|
||||
return index == 0 ? getNormalDestOperandsMutable()
|
||||
: getUnwindDestOperandsMutable();
|
||||
return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable()
|
||||
: getUnwindDestOperandsMutable());
|
||||
}
|
||||
|
||||
LogicalResult InvokeOp::verify() {
|
||||
|
@ -223,12 +223,12 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
|
||||
auto blockOperands =
|
||||
terminator.getSuccessorOperands(pred.getSuccessorIndex());
|
||||
|
||||
if (!blockOperands || blockOperands->empty())
|
||||
if (blockOperands.empty() ||
|
||||
blockOperands.isOperandProduced(blockArgumentElem.getArgNumber()))
|
||||
continue;
|
||||
|
||||
detensorableBranchOps[terminator].insert(
|
||||
blockOperands->getBeginOperandIndex() +
|
||||
blockArgumentElem.getArgNumber());
|
||||
blockOperands.getOperandIndex(blockArgumentElem.getArgNumber()));
|
||||
}
|
||||
}
|
||||
|
||||
@ -343,14 +343,15 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
|
||||
auto ownerBlockOperands =
|
||||
predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
|
||||
|
||||
if (!ownerBlockOperands || ownerBlockOperands->empty())
|
||||
if (ownerBlockOperands.empty() ||
|
||||
ownerBlockOperands.isOperandProduced(
|
||||
currentItemBlockArgument.getArgNumber()))
|
||||
continue;
|
||||
|
||||
// For each predecessor, add the value it passes to that argument to
|
||||
// workList to find out how it's computed.
|
||||
workList.push_back(
|
||||
ownerBlockOperands
|
||||
.getValue()[currentItemBlockArgument.getArgNumber()]);
|
||||
ownerBlockOperands[currentItemBlockArgument.getArgNumber()]);
|
||||
}
|
||||
|
||||
continue;
|
||||
@ -418,18 +419,16 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
|
||||
auto blockOperands =
|
||||
terminator.getSuccessorOperands(pred.getSuccessorIndex());
|
||||
|
||||
if (!blockOperands || blockOperands->empty())
|
||||
if (blockOperands.empty() ||
|
||||
blockOperands.isOperandProduced(blockArg.getArgNumber()))
|
||||
continue;
|
||||
|
||||
Operation *definingOp =
|
||||
terminator
|
||||
->getOperand(blockOperands->getBeginOperandIndex() +
|
||||
blockArg.getArgNumber())
|
||||
.getDefiningOp();
|
||||
blockOperands[blockArg.getArgNumber()].getDefiningOp();
|
||||
|
||||
// If the operand is defined by a GenericOp that will not be
|
||||
// detensored, then do not detensor the corresponding block argument.
|
||||
if (dyn_cast_or_null<GenericOp>(definingOp) &&
|
||||
if (isa_and_nonnull<GenericOp>(definingOp) &&
|
||||
opsToDetensor.count(definingOp) == 0) {
|
||||
blockArgsToRemove.insert(blockArg);
|
||||
break;
|
||||
|
@ -1515,21 +1515,20 @@ LogicalResult spirv::BitcastOp::verify() {
|
||||
// spv.BranchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Optional<MutableOperandRange>
|
||||
spirv::BranchOp::getMutableSuccessorOperands(unsigned index) {
|
||||
SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) {
|
||||
assert(index == 0 && "invalid successor index");
|
||||
return targetOperandsMutable();
|
||||
return SuccessorOperands(0, targetOperandsMutable());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.BranchConditionalOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Optional<MutableOperandRange>
|
||||
spirv::BranchConditionalOp::getMutableSuccessorOperands(unsigned index) {
|
||||
SuccessorOperands
|
||||
spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
|
||||
assert(index < 2 && "invalid successor index");
|
||||
return index == kTrueIndex ? trueTargetOperandsMutable()
|
||||
: falseTargetOperandsMutable();
|
||||
return SuccessorOperands(index == kTrueIndex ? trueTargetOperandsMutable()
|
||||
: falseTargetOperandsMutable());
|
||||
}
|
||||
|
||||
ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser,
|
||||
|
@ -18,6 +18,14 @@ using namespace mlir;
|
||||
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
|
||||
|
||||
SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands)
|
||||
: producedOperandCount(0), forwardedOperands(forwardedOperands) {}
|
||||
|
||||
SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
|
||||
MutableOperandRange forwardedOperands)
|
||||
: producedOperandCount(producedOperandCount),
|
||||
forwardedOperands(std::move(forwardedOperands)) {}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BranchOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -26,32 +34,31 @@ using namespace mlir;
|
||||
/// successor if 'operandIndex' is within the range of 'operands', or None if
|
||||
/// `operandIndex` isn't a successor operand index.
|
||||
Optional<BlockArgument>
|
||||
detail::getBranchSuccessorArgument(Optional<OperandRange> operands,
|
||||
detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
|
||||
unsigned operandIndex, Block *successor) {
|
||||
OperandRange forwardedOperands = operands.getForwardedOperands();
|
||||
// Check that the operands are valid.
|
||||
if (!operands || operands->empty())
|
||||
if (forwardedOperands.empty())
|
||||
return llvm::None;
|
||||
|
||||
// Check to ensure that this operand is within the range.
|
||||
unsigned operandsStart = operands->getBeginOperandIndex();
|
||||
unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
|
||||
if (operandIndex < operandsStart ||
|
||||
operandIndex >= (operandsStart + operands->size()))
|
||||
operandIndex >= (operandsStart + forwardedOperands.size()))
|
||||
return llvm::None;
|
||||
|
||||
// Index the successor.
|
||||
unsigned argIndex = operandIndex - operandsStart;
|
||||
unsigned argIndex =
|
||||
operands.getProducedOperandCount() + operandIndex - operandsStart;
|
||||
return successor->getArgument(argIndex);
|
||||
}
|
||||
|
||||
/// Verify that the given operands match those of the given successor block.
|
||||
LogicalResult
|
||||
detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
|
||||
Optional<OperandRange> operands) {
|
||||
if (!operands)
|
||||
return success();
|
||||
|
||||
const SuccessorOperands &operands) {
|
||||
// Check the count.
|
||||
unsigned operandCount = operands->size();
|
||||
unsigned operandCount = operands.size();
|
||||
Block *destBB = op->getSuccessor(succNo);
|
||||
if (operandCount != destBB->getNumArguments())
|
||||
return op->emitError() << "branch has " << operandCount
|
||||
@ -60,10 +67,10 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
|
||||
<< destBB->getNumArguments();
|
||||
|
||||
// Check the types.
|
||||
auto operandIt = operands->begin();
|
||||
for (unsigned i = 0; i != operandCount; ++i, ++operandIt) {
|
||||
for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
|
||||
++i) {
|
||||
if (!cast<BranchOpInterface>(op).areTypesCompatible(
|
||||
(*operandIt).getType(), destBB->getArgument(i).getType()))
|
||||
operands[i].getType(), destBB->getArgument(i).getType()))
|
||||
return op->emitError() << "type mismatch for bb argument #" << i
|
||||
<< " of successor #" << succNo;
|
||||
}
|
||||
|
@ -441,10 +441,9 @@ static Value getPHISourceValue(Block *current, Block *pred,
|
||||
for (unsigned i = 0, e = terminator.getNumSuccessors(); i < e; ++i) {
|
||||
Block *successor = terminator.getSuccessor(i);
|
||||
auto branch = cast<BranchOpInterface>(terminator);
|
||||
Optional<OperandRange> successorOperands = branch.getSuccessorOperands(i);
|
||||
SuccessorOperands successorOperands = branch.getSuccessorOperands(i);
|
||||
assert(
|
||||
(!seenSuccessors.contains(successor) ||
|
||||
(successorOperands && successorOperands->empty())) &&
|
||||
(!seenSuccessors.contains(successor) || successorOperands.empty()) &&
|
||||
"successors with arguments in LLVM branches must be different blocks");
|
||||
seenSuccessors.insert(successor);
|
||||
}
|
||||
|
@ -223,12 +223,14 @@ static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
|
||||
return;
|
||||
}
|
||||
|
||||
// If we can't reason about the operands to a successor, conservatively mark
|
||||
// all arguments as live.
|
||||
// If we can't reason about the operand to a successor, conservatively mark
|
||||
// it as live.
|
||||
for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
|
||||
if (!branchInterface.getMutableSuccessorOperands(i))
|
||||
for (BlockArgument arg : op->getSuccessor(i)->getArguments())
|
||||
liveMap.setProvedLive(arg);
|
||||
SuccessorOperands successorOperands =
|
||||
branchInterface.getSuccessorOperands(i);
|
||||
for (unsigned opI = 0, opE = successorOperands.getProducedOperandCount();
|
||||
opI != opE; ++opI)
|
||||
liveMap.setProvedLive(op->getSuccessor(i)->getArgument(opI));
|
||||
}
|
||||
}
|
||||
|
||||
@ -291,18 +293,15 @@ static void eraseTerminatorSuccessorOperands(Operation *terminator,
|
||||
// since it will promote later operands of the terminator being erased
|
||||
// first, reducing the quadratic-ness.
|
||||
unsigned succ = succE - succI - 1;
|
||||
Optional<MutableOperandRange> succOperands =
|
||||
branchOp.getMutableSuccessorOperands(succ);
|
||||
if (!succOperands)
|
||||
continue;
|
||||
SuccessorOperands succOperands = branchOp.getSuccessorOperands(succ);
|
||||
Block *successor = terminator->getSuccessor(succ);
|
||||
|
||||
for (unsigned argI = 0, argE = succOperands->size(); argI < argE; ++argI) {
|
||||
for (unsigned argI = 0, argE = succOperands.size(); argI < argE; ++argI) {
|
||||
// Iterating args in reverse is needed for correctness, to avoid
|
||||
// shifting later args when earlier args are erased.
|
||||
unsigned arg = argE - argI - 1;
|
||||
if (!liveMap.wasProvenLive(successor->getArgument(arg)))
|
||||
succOperands->erase(arg);
|
||||
succOperands.erase(arg);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -570,8 +569,7 @@ LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
|
||||
/// their operands updated.
|
||||
static bool ableToUpdatePredOperands(Block *block) {
|
||||
for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
|
||||
auto branch = dyn_cast<BranchOpInterface>((*it)->getTerminator());
|
||||
if (!branch || !branch.getMutableSuccessorOperands(it.getSuccessorIndex()))
|
||||
if (!isa<BranchOpInterface>((*it)->getTerminator()))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
@ -631,7 +629,7 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
|
||||
predIt != predE; ++predIt) {
|
||||
auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
|
||||
unsigned succIndex = predIt.getSuccessorIndex();
|
||||
branch.getMutableSuccessorOperands(succIndex)->append(
|
||||
branch.getSuccessorOperands(succIndex).append(
|
||||
newArguments[clusterIndex]);
|
||||
}
|
||||
};
|
||||
|
@ -198,3 +198,21 @@ func @recheck_executable_edge(%cond0: i1) -> (i1, i1) {
|
||||
// CHECK: return %[[X]], %[[Y]]
|
||||
return %x, %y : i1, i1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @simple_produced_operand
|
||||
func @simple_produced_operand() -> (i32, i32) {
|
||||
// CHECK: %[[ONE:.*]] = arith.constant 1
|
||||
%1 = arith.constant 1 : i32
|
||||
"test.internal_br"(%1) [^bb1, ^bb2] {
|
||||
operand_segment_sizes = dense<[0, 1]> : vector<2 x i32>
|
||||
} : (i32) -> ()
|
||||
|
||||
^bb1:
|
||||
cf.br ^bb2(%1, %1 : i32, i32)
|
||||
|
||||
^bb2(%arg1 : i32, %arg2 : i32):
|
||||
// CHECK: ^bb2(%[[ARG:.*]]: i32, %{{.*}}: i32):
|
||||
// CHECK: return %[[ARG]], %[[ONE]] : i32, i32
|
||||
|
||||
return %arg1, %arg2 : i32, i32
|
||||
}
|
||||
|
@ -335,22 +335,31 @@ TestDialect::getOperationPrinter(Operation *op) const {
|
||||
// TestBranchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Optional<MutableOperandRange>
|
||||
TestBranchOp::getMutableSuccessorOperands(unsigned index) {
|
||||
SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
|
||||
assert(index == 0 && "invalid successor index");
|
||||
return getTargetOperandsMutable();
|
||||
return SuccessorOperands(getTargetOperandsMutable());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TestProducingBranchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Optional<MutableOperandRange>
|
||||
TestProducingBranchOp::getMutableSuccessorOperands(unsigned index) {
|
||||
SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
|
||||
assert(index <= 1 && "invalid successor index");
|
||||
if (index == 1)
|
||||
return getFirstOperandsMutable();
|
||||
return getSecondOperandsMutable();
|
||||
return SuccessorOperands(getFirstOperandsMutable());
|
||||
return SuccessorOperands(getSecondOperandsMutable());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TestProducingBranchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
|
||||
assert(index <= 1 && "invalid successor index");
|
||||
if (index == 0)
|
||||
return SuccessorOperands(0, getSuccessOperandsMutable());
|
||||
return SuccessorOperands(1, getErrorOperandsMutable());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -642,6 +642,17 @@ def TestProducingBranchOp : TEST_Op<"producing_br",
|
||||
let successors = (successor AnySuccessor:$first,AnySuccessor:$second);
|
||||
}
|
||||
|
||||
// Produces an error value on the error path
|
||||
def TestInternalBranchOp : TEST_Op<"internal_br",
|
||||
[DeclareOpInterfaceMethods<BranchOpInterface>, Terminator,
|
||||
AttrSizedOperandSegments]> {
|
||||
|
||||
let arguments = (ins Variadic<AnyType>:$successOperands,
|
||||
Variadic<AnyType>:$errorOperands);
|
||||
|
||||
let successors = (successor AnySuccessor:$successPath, AnySuccessor:$errorPath);
|
||||
}
|
||||
|
||||
def AttrSizedOperandOp : TEST_Op<"attr_sized_operands",
|
||||
[AttrSizedOperandSegments]> {
|
||||
let arguments = (ins
|
||||
|
Loading…
x
Reference in New Issue
Block a user