[mlir] Add support for operation-produced successor arguments in BranchOpInterface

This patch revamps the BranchOpInterface a bit and allows a proper implementation of what was previously `getMutableSuccessorOperands` for operations, which internally produce arguments to some of the block arguments. A motivating example for this would be an invoke op with a error handling path:
```
invoke %function(%0)
  label ^success ^error(%1 : i32)

^error(%e: !error, %arg0 : i32):
  ...
```
The advantages of this are that any users of `BranchOpInterface` can still argue over remaining block argument operands (such as `%1` in the example above), as well as make use of the modifying capabilities to add more operands, erase an operand etc.

The way this patch implements that functionality is via a new class called `SuccessorOperands`, which is now returned by `getSuccessorOperands`. It basically contains an `unsigned` denoting how many operator produced operands exist, as well as a `MutableOperandRange`, which are the usual forwarded operands we are used to. The produced operands are assumed to the first few block arguments, followed by the forwarded operands afterwards. The role of `SuccessorOperands` is to provide various utility functions to modify and query the successor arguments from a `BranchOpInterface`.

Differential Revision: https://reviews.llvm.org/D123062
This commit is contained in:
Markus Böck 2022-04-08 08:17:36 +02:00
parent 795b07f549
commit 0c789db541
20 changed files with 291 additions and 154 deletions

View File

@ -489,16 +489,12 @@ class fir_SwitchTerminatorOp<string mnemonic, list<Trait> traits = []> :
llvm::ArrayRef<mlir::Value> operands, unsigned cond);
llvm::Optional<mlir::ValueRange> getSuccessorOperands(
mlir::ValueRange operands, unsigned cond);
using BranchOpInterfaceTrait::getSuccessorOperands;
// Helper function to deal with Optional operand forms
void printSuccessorAtIndex(mlir::OpAsmPrinter &p, unsigned i) {
auto *succ = getSuccessor(i);
auto ops = getSuccessorOperands(i);
if (ops.hasValue())
p.printSuccessorAndUseList(succ, ops.getValue());
else
p.printSuccessor(succ);
p.printSuccessorAndUseList(succ, ops.getForwardedOperands());
}
mlir::ArrayAttr getCases() {

View File

@ -2401,10 +2401,9 @@ fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
return {};
}
llvm::Optional<mlir::MutableOperandRange>
fir::SelectOp::getMutableSuccessorOperands(unsigned oper) {
return ::getMutableSuccessorOperands(oper, getTargetArgsMutable(),
getTargetOffsetAttr());
mlir::SuccessorOperands fir::SelectOp::getSuccessorOperands(unsigned oper) {
return mlir::SuccessorOperands(::getMutableSuccessorOperands(
oper, getTargetArgsMutable(), getTargetOffsetAttr()));
}
llvm::Optional<llvm::ArrayRef<mlir::Value>>
@ -2462,10 +2461,9 @@ fir::SelectCaseOp::getCompareOperands(mlir::ValueRange operands,
return {getSubOperands(cond, getSubOperands(1, operands, segments), a)};
}
llvm::Optional<mlir::MutableOperandRange>
fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) {
return ::getMutableSuccessorOperands(oper, getTargetArgsMutable(),
getTargetOffsetAttr());
mlir::SuccessorOperands fir::SelectCaseOp::getSuccessorOperands(unsigned oper) {
return mlir::SuccessorOperands(::getMutableSuccessorOperands(
oper, getTargetArgsMutable(), getTargetOffsetAttr()));
}
llvm::Optional<llvm::ArrayRef<mlir::Value>>
@ -2734,10 +2732,9 @@ fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
return {};
}
llvm::Optional<mlir::MutableOperandRange>
fir::SelectRankOp::getMutableSuccessorOperands(unsigned oper) {
return ::getMutableSuccessorOperands(oper, getTargetArgsMutable(),
getTargetOffsetAttr());
mlir::SuccessorOperands fir::SelectRankOp::getSuccessorOperands(unsigned oper) {
return mlir::SuccessorOperands(::getMutableSuccessorOperands(
oper, getTargetArgsMutable(), getTargetOffsetAttr()));
}
llvm::Optional<llvm::ArrayRef<mlir::Value>>
@ -2779,10 +2776,9 @@ fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
return {};
}
llvm::Optional<mlir::MutableOperandRange>
fir::SelectTypeOp::getMutableSuccessorOperands(unsigned oper) {
return ::getMutableSuccessorOperands(oper, getTargetArgsMutable(),
getTargetOffsetAttr());
mlir::SuccessorOperands fir::SelectTypeOp::getSuccessorOperands(unsigned oper) {
return mlir::SuccessorOperands(::getMutableSuccessorOperands(
oper, getTargetArgsMutable(), getTargetOffsetAttr()));
}
llvm::Optional<llvm::ArrayRef<mlir::Value>>

View File

@ -907,6 +907,11 @@ public:
/// elements attribute, which contains the sizes of the sub ranges.
MutableOperandRangeRange split(NamedAttribute segmentSizes) const;
/// Returns the value at the given index.
Value operator[](unsigned index) const {
return static_cast<OperandRange>(*this)[index];
}
private:
/// Update the length of this range to the one provided.
void updateLength(unsigned newLength);

View File

@ -20,6 +20,106 @@ namespace mlir {
class BranchOpInterface;
class RegionBranchOpInterface;
/// This class models how operands are forwarded to block arguments in control
/// flow. It consists of a number, denoting how many of the successors block
/// arguments are produced by the operation, followed by a range of operands
/// that are forwarded. The produced operands are passed to the first few
/// block arguments of the successor, followed by the forwarded operands.
/// It is unsupported to pass them in a different order.
///
/// An example operation with both of these concepts would be a branch-on-error
/// operation, that internally produces an error object on the error path:
///
/// invoke %function(%0)
/// label ^success ^error(%1 : i32)
///
/// ^error(%e: !error, %arg0 : i32):
/// ...
///
/// This operation would return an instance of SuccessorOperands with a produced
/// operand count of 1 (mapped to %e in the successor) and a forwarded
/// operands range consisting of %1 in the example above (mapped to %arg0 in the
/// successor).
class SuccessorOperands {
public:
/// Constructs a SuccessorOperands with no produced operands that simply
/// forwards operands to the successor.
explicit SuccessorOperands(MutableOperandRange forwardedOperands);
/// Constructs a SuccessorOperands with the given amount of produced operands
/// and forwarded operands.
SuccessorOperands(unsigned producedOperandCount,
MutableOperandRange forwardedOperands);
/// Returns the amount of operands passed to the successor. This consists both
/// of produced operands by the operation as well as forwarded ones.
unsigned size() const {
return producedOperandCount + forwardedOperands.size();
}
/// Returns true if there are no successor operands.
bool empty() const { return size() == 0; }
/// Returns the amount of operands that are produced internally by the
/// operation. These are passed to the first few block arguments.
unsigned getProducedOperandCount() const { return producedOperandCount; }
/// Returns true if the successor operand denoted by `index` is produced by
/// the operation.
bool isOperandProduced(unsigned index) const {
return index < producedOperandCount;
}
/// Returns the Value that is passed to the successors block argument denoted
/// by `index`. If it is produced by the operation, no such value exists and
/// a null Value is returned.
Value operator[](unsigned index) const {
if (isOperandProduced(index))
return Value();
return forwardedOperands[index - producedOperandCount];
}
/// Get the range of operands that are simply forwarded to the successor.
OperandRange getForwardedOperands() const { return forwardedOperands; }
/// Get a slice of the operands forwarded to the successor. The given range
/// must not contain any operands produced by the operation.
MutableOperandRange slice(unsigned subStart, unsigned subLen) const {
assert(!isOperandProduced(subStart) &&
"can't slice operands produced by the operation");
return forwardedOperands.slice(subStart - producedOperandCount, subLen);
}
/// Erase operands forwarded to the successor. The given range must
/// not contain any operands produced by the operation.
void erase(unsigned subStart, unsigned subLen = 1) {
assert(!isOperandProduced(subStart) &&
"can't erase operands produced by the operation");
forwardedOperands.erase(subStart - producedOperandCount, subLen);
}
/// Add new operands that are forwarded to the successor.
void append(ValueRange valueRange) { forwardedOperands.append(valueRange); }
/// Gets the index of the forwarded operand within the operation which maps
/// to the block argument denoted by `blockArgumentIndex`. The block argument
/// must be mapped to a forwarded operand.
unsigned getOperandIndex(unsigned blockArgumentIndex) const {
assert(!isOperandProduced(blockArgumentIndex) &&
"can't map operand produced by the operation");
return static_cast<mlir::OperandRange>(forwardedOperands)
.getBeginOperandIndex() +
(blockArgumentIndex - producedOperandCount);
}
private:
/// Amount of operands that are produced internally within the operation and
/// passed to the first few block arguments.
unsigned producedOperandCount;
/// Range of operands that are forwarded to the remaining block arguments.
MutableOperandRange forwardedOperands;
};
//===----------------------------------------------------------------------===//
// BranchOpInterface
//===----------------------------------------------------------------------===//
@ -29,12 +129,12 @@ namespace detail {
/// successor if `operandIndex` is within the range of `operands`, or None if
/// `operandIndex` isn't a successor operand index.
Optional<BlockArgument>
getBranchSuccessorArgument(Optional<OperandRange> operands,
getBranchSuccessorArgument(const SuccessorOperands &operands,
unsigned operandIndex, Block *successor);
/// Verify that the given operands match those of the given successor block.
LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
Optional<OperandRange> operands);
const SuccessorOperands &operands);
} // namespace detail
//===----------------------------------------------------------------------===//

View File

@ -36,26 +36,35 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
let methods = [
InterfaceMethod<[{
Returns a mutable range of operands that correspond to the arguments of
successor at the given index. Returns None if the operands to the
successor are non-materialized values, i.e. they are internal to the
operation.
Returns the operands that correspond to the arguments of the successor
at the given index. It consists of a number of operands that are
internally produced by the operation, followed by a range of operands
that are forwarded. An example operation making use of produced
operands would be:
```mlir
invoke %function(%0)
label ^success ^error(%1 : i32)
^error(%e: !error, %arg0: i32):
...
```
The operand that would map to the `^error`s `%e` operand is produced
by the `invoke` operation, while `%1` is a forwarded operand that maps
to `%arg0` in the successor.
Produced operands always map to the first few block arguments of the
successor, followed by the forwarded operands. Mapping them in any
other order is not supported by the interface.
By having the forwarded operands last allows users of the interface
to append more forwarded operands to the branch operation without
interfering with other successor operands.
}],
"::mlir::Optional<::mlir::MutableOperandRange>", "getMutableSuccessorOperands",
"::mlir::SuccessorOperands", "getSuccessorOperands",
(ins "unsigned":$index)
>,
InterfaceMethod<[{
Returns a range of operands that correspond to the arguments of
successor at the given index. Returns None if the operands to the
successor are non-materialized values, i.e. they are internal to the
operation.
}],
"::mlir::Optional<::mlir::OperandRange>", "getSuccessorOperands",
(ins "unsigned":$index), [{}], [{
auto operands = $_op.getMutableSuccessorOperands(index);
return operands ? ::mlir::Optional<::mlir::OperandRange>(*operands) : ::llvm::None;
}]
>,
InterfaceMethod<[{
Returns the `BlockArgument` corresponding to operand `operandIndex` in
some successor, or None if `operandIndex` isn't a successor operand
@ -94,7 +103,7 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
let verify = [{
auto concreteOp = ::mlir::cast<ConcreteOp>($_op);
for (unsigned i = 0, e = $_op->getNumSuccessors(); i != e; ++i) {
::mlir::Optional<OperandRange> operands = concreteOp.getSuccessorOperands(i);
::mlir::SuccessorOperands operands = concreteOp.getSuccessorOperands(i);
if (::mlir::failed(::mlir::detail::verifyBranchSuccessorOperands($_op, i, operands)))
return ::mlir::failure();
}

View File

@ -149,14 +149,13 @@ static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth,
// Try to get the operand passed for this argument.
unsigned index = it.getSuccessorIndex();
Optional<OperandRange> operands = branch.getSuccessorOperands(index);
if (!operands) {
Value operand = branch.getSuccessorOperands(index)[argNumber];
if (!operand) {
// We can't analyze the control flow, so bail out early.
output.push_back(arg);
return;
}
collectUnderlyingAddressValues((*operands)[argNumber], maxDepth, visited,
output);
collectUnderlyingAddressValues(operand, maxDepth, visited, output);
}
return;
}

View File

@ -70,10 +70,10 @@ void BufferViewFlowAnalysis::build(Operation *op) {
// Query the branch op interface to get the successor operands.
auto successorOperands =
branchInterface.getSuccessorOperands(it.getIndex());
if (!successorOperands.hasValue())
continue;
// Build the actual mapping of values to their immediate dependencies.
registerDependencies(successorOperands.getValue(), (*it)->getArguments());
registerDependencies(successorOperands.getForwardedOperands(),
(*it)->getArguments().drop_front(
successorOperands.getProducedOperandCount()));
}
});

View File

@ -681,10 +681,13 @@ void ForwardDataFlowSolver::visitBlockArgument(Block *block, int i) {
// Try to get the operand forwarded by the predecessor. If we can't reason
// about the terminator of the predecessor, mark as having reached a
// fixpoint.
Optional<OperandRange> branchOperands;
if (auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator()))
branchOperands = branch.getSuccessorOperands(it.getSuccessorIndex());
if (!branchOperands) {
auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator());
if (!branch) {
updatedLattice |= argLattice.markPessimisticFixpoint();
break;
}
Value operand = branch.getSuccessorOperands(it.getSuccessorIndex())[i];
if (!operand) {
updatedLattice |= argLattice.markPessimisticFixpoint();
break;
}
@ -692,7 +695,7 @@ void ForwardDataFlowSolver::visitBlockArgument(Block *block, int i) {
// If the operand hasn't been resolved, it is uninitialized and can merge
// with anything.
AbstractLatticeElement *operandLattice =
analysis.lookupLatticeElement((*branchOperands)[i]);
analysis.lookupLatticeElement(operand);
if (!operandLattice)
continue;

View File

@ -325,25 +325,20 @@ private:
// argument.
Operation *terminator = (*it)->getTerminator();
auto branchInterface = cast<BranchOpInterface>(terminator);
SuccessorOperands operands =
branchInterface.getSuccessorOperands(it.getSuccessorIndex());
// Query the associated source value.
Value sourceValue =
branchInterface.getSuccessorOperands(it.getSuccessorIndex())
.getValue()[blockArg.getArgNumber()];
// Wire new clone and successor operand.
auto mutableOperands =
branchInterface.getMutableSuccessorOperands(it.getSuccessorIndex());
if (!mutableOperands) {
terminator->emitError() << "terminators with immutable successor "
"operands are not supported";
continue;
Value sourceValue = operands[blockArg.getArgNumber()];
if (!sourceValue) {
return failure();
}
// Wire new clone and successor operand.
// Create a new clone at the current location of the terminator.
auto clone = introduceCloneBuffers(sourceValue, terminator);
if (failed(clone))
return failure();
mutableOperands.getValue()
.slice(blockArg.getArgNumber(), 1)
.assign(*clone);
operands.slice(blockArg.getArgNumber(), 1).assign(*clone);
}
// Check whether the block argument has implicitly defined predecessors via

View File

@ -186,10 +186,9 @@ void BranchOp::setDest(Block *block) { return setSuccessor(block); }
void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
Optional<MutableOperandRange>
BranchOp::getMutableSuccessorOperands(unsigned index) {
SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
return getDestOperandsMutable();
return SuccessorOperands(getDestOperandsMutable());
}
Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
@ -437,11 +436,10 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
CondBranchTruthPropagation>(context);
}
Optional<MutableOperandRange>
CondBranchOp::getMutableSuccessorOperands(unsigned index) {
SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
return index == trueIndex ? getTrueDestOperandsMutable()
: getFalseDestOperandsMutable();
return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
: getFalseDestOperandsMutable());
}
Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
@ -575,11 +573,10 @@ LogicalResult SwitchOp::verify() {
return success();
}
Optional<MutableOperandRange>
SwitchOp::getMutableSuccessorOperands(unsigned index) {
SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
return index == 0 ? getDefaultOperandsMutable()
: getCaseOperandsMutable(index - 1);
return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
: getCaseOperandsMutable(index - 1));
}
Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {

View File

@ -67,12 +67,13 @@ public:
SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
succIdx < succEnd; ++succIdx) {
auto successorOperands = op.getSuccessorOperands(succIdx);
if (!successorOperands || successorOperands->empty())
OperandRange forwardedOperands =
op.getSuccessorOperands(succIdx).getForwardedOperands();
if (forwardedOperands.empty())
continue;
for (int idx = successorOperands->getBeginOperandIndex(),
eidx = idx + successorOperands->size();
for (int idx = forwardedOperands.getBeginOperandIndex(),
eidx = idx + forwardedOperands.size();
idx < eidx; ++idx) {
if (!shouldConvertBranchOperand || shouldConvertBranchOperand(op, idx))
newOperands[idx] = operands[idx];
@ -121,8 +122,8 @@ bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) {
auto successorOperands = branchOp.getSuccessorOperands(p);
if (successorOperands.hasValue() &&
!converter.isLegal(successorOperands.getValue().getTypes()))
if (!converter.isLegal(
successorOperands.getForwardedOperands().getTypes()))
return false;
}
return true;

View File

@ -240,21 +240,19 @@ ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) {
// LLVM::BrOp
//===----------------------------------------------------------------------===//
Optional<MutableOperandRange>
BrOp::getMutableSuccessorOperands(unsigned index) {
SuccessorOperands BrOp::getSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
return getDestOperandsMutable();
return SuccessorOperands(getDestOperandsMutable());
}
//===----------------------------------------------------------------------===//
// LLVM::CondBrOp
//===----------------------------------------------------------------------===//
Optional<MutableOperandRange>
CondBrOp::getMutableSuccessorOperands(unsigned index) {
SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
return index == 0 ? getTrueDestOperandsMutable()
: getFalseDestOperandsMutable();
return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable()
: getFalseDestOperandsMutable());
}
//===----------------------------------------------------------------------===//
@ -356,11 +354,10 @@ LogicalResult SwitchOp::verify() {
return success();
}
Optional<MutableOperandRange>
SwitchOp::getMutableSuccessorOperands(unsigned index) {
SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
return index == 0 ? getDefaultOperandsMutable()
: getCaseOperandsMutable(index - 1);
return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
: getCaseOperandsMutable(index - 1));
}
//===----------------------------------------------------------------------===//
@ -735,11 +732,10 @@ ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
/// LLVM::InvokeOp
///===---------------------------------------------------------------------===//
Optional<MutableOperandRange>
InvokeOp::getMutableSuccessorOperands(unsigned index) {
SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
return index == 0 ? getNormalDestOperandsMutable()
: getUnwindDestOperandsMutable();
return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable()
: getUnwindDestOperandsMutable());
}
LogicalResult InvokeOp::verify() {

View File

@ -223,12 +223,12 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
auto blockOperands =
terminator.getSuccessorOperands(pred.getSuccessorIndex());
if (!blockOperands || blockOperands->empty())
if (blockOperands.empty() ||
blockOperands.isOperandProduced(blockArgumentElem.getArgNumber()))
continue;
detensorableBranchOps[terminator].insert(
blockOperands->getBeginOperandIndex() +
blockArgumentElem.getArgNumber());
blockOperands.getOperandIndex(blockArgumentElem.getArgNumber()));
}
}
@ -343,14 +343,15 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
auto ownerBlockOperands =
predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
if (!ownerBlockOperands || ownerBlockOperands->empty())
if (ownerBlockOperands.empty() ||
ownerBlockOperands.isOperandProduced(
currentItemBlockArgument.getArgNumber()))
continue;
// For each predecessor, add the value it passes to that argument to
// workList to find out how it's computed.
workList.push_back(
ownerBlockOperands
.getValue()[currentItemBlockArgument.getArgNumber()]);
ownerBlockOperands[currentItemBlockArgument.getArgNumber()]);
}
continue;
@ -418,18 +419,16 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
auto blockOperands =
terminator.getSuccessorOperands(pred.getSuccessorIndex());
if (!blockOperands || blockOperands->empty())
if (blockOperands.empty() ||
blockOperands.isOperandProduced(blockArg.getArgNumber()))
continue;
Operation *definingOp =
terminator
->getOperand(blockOperands->getBeginOperandIndex() +
blockArg.getArgNumber())
.getDefiningOp();
blockOperands[blockArg.getArgNumber()].getDefiningOp();
// If the operand is defined by a GenericOp that will not be
// detensored, then do not detensor the corresponding block argument.
if (dyn_cast_or_null<GenericOp>(definingOp) &&
if (isa_and_nonnull<GenericOp>(definingOp) &&
opsToDetensor.count(definingOp) == 0) {
blockArgsToRemove.insert(blockArg);
break;

View File

@ -1515,21 +1515,20 @@ LogicalResult spirv::BitcastOp::verify() {
// spv.BranchOp
//===----------------------------------------------------------------------===//
Optional<MutableOperandRange>
spirv::BranchOp::getMutableSuccessorOperands(unsigned index) {
SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
return targetOperandsMutable();
return SuccessorOperands(0, targetOperandsMutable());
}
//===----------------------------------------------------------------------===//
// spv.BranchConditionalOp
//===----------------------------------------------------------------------===//
Optional<MutableOperandRange>
spirv::BranchConditionalOp::getMutableSuccessorOperands(unsigned index) {
SuccessorOperands
spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
assert(index < 2 && "invalid successor index");
return index == kTrueIndex ? trueTargetOperandsMutable()
: falseTargetOperandsMutable();
return SuccessorOperands(index == kTrueIndex ? trueTargetOperandsMutable()
: falseTargetOperandsMutable());
}
ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser,

View File

@ -18,6 +18,14 @@ using namespace mlir;
#include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands)
: producedOperandCount(0), forwardedOperands(forwardedOperands) {}
SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
MutableOperandRange forwardedOperands)
: producedOperandCount(producedOperandCount),
forwardedOperands(std::move(forwardedOperands)) {}
//===----------------------------------------------------------------------===//
// BranchOpInterface
//===----------------------------------------------------------------------===//
@ -26,32 +34,31 @@ using namespace mlir;
/// successor if 'operandIndex' is within the range of 'operands', or None if
/// `operandIndex` isn't a successor operand index.
Optional<BlockArgument>
detail::getBranchSuccessorArgument(Optional<OperandRange> operands,
detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
unsigned operandIndex, Block *successor) {
OperandRange forwardedOperands = operands.getForwardedOperands();
// Check that the operands are valid.
if (!operands || operands->empty())
if (forwardedOperands.empty())
return llvm::None;
// Check to ensure that this operand is within the range.
unsigned operandsStart = operands->getBeginOperandIndex();
unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
if (operandIndex < operandsStart ||
operandIndex >= (operandsStart + operands->size()))
operandIndex >= (operandsStart + forwardedOperands.size()))
return llvm::None;
// Index the successor.
unsigned argIndex = operandIndex - operandsStart;
unsigned argIndex =
operands.getProducedOperandCount() + operandIndex - operandsStart;
return successor->getArgument(argIndex);
}
/// Verify that the given operands match those of the given successor block.
LogicalResult
detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
Optional<OperandRange> operands) {
if (!operands)
return success();
const SuccessorOperands &operands) {
// Check the count.
unsigned operandCount = operands->size();
unsigned operandCount = operands.size();
Block *destBB = op->getSuccessor(succNo);
if (operandCount != destBB->getNumArguments())
return op->emitError() << "branch has " << operandCount
@ -60,10 +67,10 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
<< destBB->getNumArguments();
// Check the types.
auto operandIt = operands->begin();
for (unsigned i = 0; i != operandCount; ++i, ++operandIt) {
for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
++i) {
if (!cast<BranchOpInterface>(op).areTypesCompatible(
(*operandIt).getType(), destBB->getArgument(i).getType()))
operands[i].getType(), destBB->getArgument(i).getType()))
return op->emitError() << "type mismatch for bb argument #" << i
<< " of successor #" << succNo;
}

View File

@ -441,10 +441,9 @@ static Value getPHISourceValue(Block *current, Block *pred,
for (unsigned i = 0, e = terminator.getNumSuccessors(); i < e; ++i) {
Block *successor = terminator.getSuccessor(i);
auto branch = cast<BranchOpInterface>(terminator);
Optional<OperandRange> successorOperands = branch.getSuccessorOperands(i);
SuccessorOperands successorOperands = branch.getSuccessorOperands(i);
assert(
(!seenSuccessors.contains(successor) ||
(successorOperands && successorOperands->empty())) &&
(!seenSuccessors.contains(successor) || successorOperands.empty()) &&
"successors with arguments in LLVM branches must be different blocks");
seenSuccessors.insert(successor);
}

View File

@ -223,12 +223,14 @@ static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
return;
}
// If we can't reason about the operands to a successor, conservatively mark
// all arguments as live.
// If we can't reason about the operand to a successor, conservatively mark
// it as live.
for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
if (!branchInterface.getMutableSuccessorOperands(i))
for (BlockArgument arg : op->getSuccessor(i)->getArguments())
liveMap.setProvedLive(arg);
SuccessorOperands successorOperands =
branchInterface.getSuccessorOperands(i);
for (unsigned opI = 0, opE = successorOperands.getProducedOperandCount();
opI != opE; ++opI)
liveMap.setProvedLive(op->getSuccessor(i)->getArgument(opI));
}
}
@ -291,18 +293,15 @@ static void eraseTerminatorSuccessorOperands(Operation *terminator,
// since it will promote later operands of the terminator being erased
// first, reducing the quadratic-ness.
unsigned succ = succE - succI - 1;
Optional<MutableOperandRange> succOperands =
branchOp.getMutableSuccessorOperands(succ);
if (!succOperands)
continue;
SuccessorOperands succOperands = branchOp.getSuccessorOperands(succ);
Block *successor = terminator->getSuccessor(succ);
for (unsigned argI = 0, argE = succOperands->size(); argI < argE; ++argI) {
for (unsigned argI = 0, argE = succOperands.size(); argI < argE; ++argI) {
// Iterating args in reverse is needed for correctness, to avoid
// shifting later args when earlier args are erased.
unsigned arg = argE - argI - 1;
if (!liveMap.wasProvenLive(successor->getArgument(arg)))
succOperands->erase(arg);
succOperands.erase(arg);
}
}
}
@ -570,8 +569,7 @@ LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
/// their operands updated.
static bool ableToUpdatePredOperands(Block *block) {
for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
auto branch = dyn_cast<BranchOpInterface>((*it)->getTerminator());
if (!branch || !branch.getMutableSuccessorOperands(it.getSuccessorIndex()))
if (!isa<BranchOpInterface>((*it)->getTerminator()))
return false;
}
return true;
@ -631,7 +629,7 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
predIt != predE; ++predIt) {
auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
unsigned succIndex = predIt.getSuccessorIndex();
branch.getMutableSuccessorOperands(succIndex)->append(
branch.getSuccessorOperands(succIndex).append(
newArguments[clusterIndex]);
}
};

View File

@ -198,3 +198,21 @@ func @recheck_executable_edge(%cond0: i1) -> (i1, i1) {
// CHECK: return %[[X]], %[[Y]]
return %x, %y : i1, i1
}
// CHECK-LABEL: func @simple_produced_operand
func @simple_produced_operand() -> (i32, i32) {
// CHECK: %[[ONE:.*]] = arith.constant 1
%1 = arith.constant 1 : i32
"test.internal_br"(%1) [^bb1, ^bb2] {
operand_segment_sizes = dense<[0, 1]> : vector<2 x i32>
} : (i32) -> ()
^bb1:
cf.br ^bb2(%1, %1 : i32, i32)
^bb2(%arg1 : i32, %arg2 : i32):
// CHECK: ^bb2(%[[ARG:.*]]: i32, %{{.*}}: i32):
// CHECK: return %[[ARG]], %[[ONE]] : i32, i32
return %arg1, %arg2 : i32, i32
}

View File

@ -335,22 +335,31 @@ TestDialect::getOperationPrinter(Operation *op) const {
// TestBranchOp
//===----------------------------------------------------------------------===//
Optional<MutableOperandRange>
TestBranchOp::getMutableSuccessorOperands(unsigned index) {
SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
return getTargetOperandsMutable();
return SuccessorOperands(getTargetOperandsMutable());
}
//===----------------------------------------------------------------------===//
// TestProducingBranchOp
//===----------------------------------------------------------------------===//
Optional<MutableOperandRange>
TestProducingBranchOp::getMutableSuccessorOperands(unsigned index) {
SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
assert(index <= 1 && "invalid successor index");
if (index == 1)
return getFirstOperandsMutable();
return getSecondOperandsMutable();
return SuccessorOperands(getFirstOperandsMutable());
return SuccessorOperands(getSecondOperandsMutable());
}
//===----------------------------------------------------------------------===//
// TestProducingBranchOp
//===----------------------------------------------------------------------===//
SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
assert(index <= 1 && "invalid successor index");
if (index == 0)
return SuccessorOperands(0, getSuccessOperandsMutable());
return SuccessorOperands(1, getErrorOperandsMutable());
}
//===----------------------------------------------------------------------===//

View File

@ -642,6 +642,17 @@ def TestProducingBranchOp : TEST_Op<"producing_br",
let successors = (successor AnySuccessor:$first,AnySuccessor:$second);
}
// Produces an error value on the error path
def TestInternalBranchOp : TEST_Op<"internal_br",
[DeclareOpInterfaceMethods<BranchOpInterface>, Terminator,
AttrSizedOperandSegments]> {
let arguments = (ins Variadic<AnyType>:$successOperands,
Variadic<AnyType>:$errorOperands);
let successors = (successor AnySuccessor:$successPath, AnySuccessor:$errorPath);
}
def AttrSizedOperandOp : TEST_Op<"attr_sized_operands",
[AttrSizedOperandSegments]> {
let arguments = (ins