[mlir] Simplify BranchOpInterface by using MutableOperandRange

This range allows for performing many different operations on successor operands, including erasing/adding/setting. This removes the need for the explicit canEraseSuccessorOperand and eraseSuccessorOperand methods.

Differential Revision: https://reviews.llvm.org/D79077
This commit is contained in:
River Riddle 2020-04-29 16:09:43 -07:00
parent 91dae57087
commit 0752d98ccf
14 changed files with 102 additions and 123 deletions

View File

@ -585,6 +585,7 @@ class fir_SwitchTerminatorOp<string mnemonic, list<OpTrait> traits = []> :
llvm::Optional<llvm::ArrayRef<mlir::Value>> getSuccessorOperands(
llvm::ArrayRef<mlir::Value> operands, unsigned cond);
using BranchOpInterfaceTrait::getSuccessorOperands;
// Helper function to deal with Optional operand forms
void printSuccessorAtIndex(mlir::OpAsmPrinter &p, unsigned i) {

View File

@ -997,14 +997,26 @@ static constexpr llvm::StringRef getTargetOffsetAttr() {
return "target_operand_offsets";
}
template <typename A>
template <typename A, typename... AdditionalArgs>
static A getSubOperands(unsigned pos, A allArgs,
mlir::DenseIntElementsAttr ranges) {
mlir::DenseIntElementsAttr ranges,
AdditionalArgs &&... additionalArgs) {
unsigned start = 0;
for (unsigned i = 0; i < pos; ++i)
start += (*(ranges.begin() + i)).getZExtValue();
unsigned end = start + (*(ranges.begin() + pos)).getZExtValue();
return {std::next(allArgs.begin(), start), std::next(allArgs.begin(), end)};
return allArgs.slice(start, (*(ranges.begin() + pos)).getZExtValue(),
std::forward<AdditionalArgs>(additionalArgs)...);
}
static mlir::MutableOperandRange
getMutableSuccessorOperands(unsigned pos, mlir::MutableOperandRange operands,
StringRef offsetAttr) {
Operation *owner = operands.getOwner();
NamedAttribute targetOffsetAttr =
*owner->getMutableAttrDict().getNamed(offsetAttr);
return getSubOperands(
pos, operands, targetOffsetAttr.second.cast<DenseIntElementsAttr>(),
mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr));
}
static unsigned denseElementsSize(mlir::DenseIntElementsAttr attr) {
@ -1020,10 +1032,10 @@ fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
return {};
}
llvm::Optional<mlir::OperandRange>
fir::SelectOp::getSuccessorOperands(unsigned oper) {
auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
return {getSubOperands(oper, targetArgs(), a)};
llvm::Optional<mlir::MutableOperandRange>
fir::SelectOp::getMutableSuccessorOperands(unsigned oper) {
return ::getMutableSuccessorOperands(oper, targetArgsMutable(),
getTargetOffsetAttr());
}
llvm::Optional<llvm::ArrayRef<mlir::Value>>
@ -1035,8 +1047,6 @@ fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}
bool fir::SelectOp::canEraseSuccessorOperand() { return true; }
unsigned fir::SelectOp::targetOffsetSize() {
return denseElementsSize(
getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()));
@ -1061,10 +1071,10 @@ fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef<mlir::Value> operands,
return {getSubOperands(cond, getSubOperands(1, operands, segments), a)};
}
llvm::Optional<mlir::OperandRange>
fir::SelectCaseOp::getSuccessorOperands(unsigned oper) {
auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
return {getSubOperands(oper, targetArgs(), a)};
llvm::Optional<mlir::MutableOperandRange>
fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) {
return ::getMutableSuccessorOperands(oper, targetArgsMutable(),
getTargetOffsetAttr());
}
llvm::Optional<llvm::ArrayRef<mlir::Value>>
@ -1076,8 +1086,6 @@ fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}
bool fir::SelectCaseOp::canEraseSuccessorOperand() { return true; }
// parser for fir.select_case Op
static mlir::ParseResult parseSelectCase(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
@ -1254,10 +1262,10 @@ fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
return {};
}
llvm::Optional<mlir::OperandRange>
fir::SelectRankOp::getSuccessorOperands(unsigned oper) {
auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
return {getSubOperands(oper, targetArgs(), a)};
llvm::Optional<mlir::MutableOperandRange>
fir::SelectRankOp::getMutableSuccessorOperands(unsigned oper) {
return ::getMutableSuccessorOperands(oper, targetArgsMutable(),
getTargetOffsetAttr());
}
llvm::Optional<llvm::ArrayRef<mlir::Value>>
@ -1269,8 +1277,6 @@ fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}
bool fir::SelectRankOp::canEraseSuccessorOperand() { return true; }
unsigned fir::SelectRankOp::targetOffsetSize() {
return denseElementsSize(
getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr()));
@ -1290,10 +1296,10 @@ fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) {
return {};
}
llvm::Optional<mlir::OperandRange>
fir::SelectTypeOp::getSuccessorOperands(unsigned oper) {
auto a = getAttrOfType<mlir::DenseIntElementsAttr>(getTargetOffsetAttr());
return {getSubOperands(oper, targetArgs(), a)};
llvm::Optional<mlir::MutableOperandRange>
fir::SelectTypeOp::getMutableSuccessorOperands(unsigned oper) {
return ::getMutableSuccessorOperands(oper, targetArgsMutable(),
getTargetOffsetAttr());
}
llvm::Optional<llvm::ArrayRef<mlir::Value>>
@ -1305,8 +1311,6 @@ fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
}
bool fir::SelectTypeOp::canEraseSuccessorOperand() { return true; }
static ParseResult parseSelectType(OpAsmParser &parser,
OperationState &result) {
mlir::OpAsmParser::OperandType selector;

View File

@ -1074,7 +1074,7 @@ def CondBranchOp : Std_Op<"cond_br",
/// Erase the operand at 'index' from the true operand list.
void eraseTrueOperand(unsigned index) {
eraseSuccessorOperand(trueIndex, index);
trueDestOperandsMutable().erase(index);
}
// Accessors for operands to the 'false' destination.
@ -1093,7 +1093,7 @@ def CondBranchOp : Std_Op<"cond_br",
/// Erase the operand at 'index' from the false operand list.
void eraseFalseOperand(unsigned index) {
eraseSuccessorOperand(falseIndex, index);
falseDestOperandsMutable().erase(index);
}
private:

View File

@ -678,6 +678,10 @@ public:
ArrayRef<OperandSegment> operandSegments = llvm::None);
MutableOperandRange(Operation *owner);
/// Slice this range into a sub range, with the additional operand segment.
MutableOperandRange slice(unsigned subStart, unsigned subLen,
Optional<OperandSegment> segment = llvm::None);
/// Append the given values to the range.
void append(ValueRange values);
@ -699,6 +703,9 @@ public:
/// Allow implicit conversion to an OperandRange.
operator OperandRange() const;
/// Returns the owning operation.
Operation *getOwner() const { return owner; }
private:
/// Update the length of this range to the one provided.
void updateLength(unsigned newLength);

View File

@ -24,11 +24,6 @@ class BranchOpInterface;
//===----------------------------------------------------------------------===//
namespace detail {
/// Erase an operand from a branch operation that is used as a successor
/// operand. `operandIndex` is the operand within `operands` to be erased.
void eraseBranchSuccessorOperand(OperandRange operands, unsigned operandIndex,
Operation *op);
/// Return the `BlockArgument` corresponding to operand `operandIndex` in some
/// successor if `operandIndex` is within the range of `operands`, or None if
/// `operandIndex` isn't a successor operand index.

View File

@ -27,29 +27,25 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
}];
let methods = [
InterfaceMethod<[{
Returns a set of values that correspond to the arguments to the
Returns a mutable range of operands that correspond to the arguments of
successor at the given index. Returns None if the operands to the
successor are non-materialized values, i.e. they are internal to the
operation.
}],
"Optional<OperandRange>", "getSuccessorOperands", (ins "unsigned":$index)
"Optional<MutableOperandRange>", "getMutableSuccessorOperands",
(ins "unsigned":$index)
>,
InterfaceMethod<[{
Return true if this operation can erase an operand to a successor block.
Returns a range of operands that correspond to the arguments of
successor at the given index. Returns None if the operands to the
successor are non-materialized values, i.e. they are internal to the
operation.
}],
"bool", "canEraseSuccessorOperand"
>,
InterfaceMethod<[{
Erase the operand at `operandIndex` from the `index`-th successor. This
should only be called if `canEraseSuccessorOperand` returns true.
}],
"void", "eraseSuccessorOperand",
(ins "unsigned":$index, "unsigned":$operandIndex), [{}],
/*defaultImplementation=*/[{
"Optional<OperandRange>", "getSuccessorOperands",
(ins "unsigned":$index), [{}], [{
ConcreteOp *op = static_cast<ConcreteOp *>(this);
Optional<OperandRange> operands = op->getSuccessorOperands(index);
assert(operands && "unable to query operands for successor");
detail::eraseBranchSuccessorOperand(*operands, operandIndex, *op);
auto operands = op->getMutableSuccessorOperands(index);
return operands ? Optional<OperandRange>(*operands) : llvm::None;
}]
>,
InterfaceMethod<[{

View File

@ -160,24 +160,22 @@ static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) {
// LLVM::BrOp
//===----------------------------------------------------------------------===//
Optional<OperandRange> BrOp::getSuccessorOperands(unsigned index) {
Optional<MutableOperandRange>
BrOp::getMutableSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
return getOperands();
return destOperandsMutable();
}
bool BrOp::canEraseSuccessorOperand() { return true; }
//===----------------------------------------------------------------------===//
// LLVM::CondBrOp
//===----------------------------------------------------------------------===//
Optional<OperandRange> CondBrOp::getSuccessorOperands(unsigned index) {
Optional<MutableOperandRange>
CondBrOp::getMutableSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
return index == 0 ? trueDestOperands() : falseDestOperands();
return index == 0 ? trueDestOperandsMutable() : falseDestOperandsMutable();
}
bool CondBrOp::canEraseSuccessorOperand() { return true; }
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::LoadOp.
//===----------------------------------------------------------------------===//
@ -257,13 +255,12 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
/// LLVM::InvokeOp
///===---------------------------------------------------------------------===//
Optional<OperandRange> InvokeOp::getSuccessorOperands(unsigned index) {
Optional<MutableOperandRange>
InvokeOp::getMutableSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
return index == 0 ? normalDestOperands() : unwindDestOperands();
return index == 0 ? normalDestOperandsMutable() : unwindDestOperandsMutable();
}
bool InvokeOp::canEraseSuccessorOperand() { return true; }
static LogicalResult verify(InvokeOp op) {
if (op.getNumResults() > 1)
return op.emitOpError("must have 0 or 1 result");

View File

@ -987,26 +987,23 @@ static LogicalResult verify(spirv::BitcastOp bitcastOp) {
// spv.BranchOp
//===----------------------------------------------------------------------===//
Optional<OperandRange> spirv::BranchOp::getSuccessorOperands(unsigned index) {
Optional<MutableOperandRange>
spirv::BranchOp::getMutableSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
return getOperands();
return targetOperandsMutable();
}
bool spirv::BranchOp::canEraseSuccessorOperand() { return true; }
//===----------------------------------------------------------------------===//
// spv.BranchConditionalOp
//===----------------------------------------------------------------------===//
Optional<OperandRange>
spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
Optional<MutableOperandRange>
spirv::BranchConditionalOp::getMutableSuccessorOperands(unsigned index) {
assert(index < 2 && "invalid successor index");
return index == kTrueIndex ? getTrueBlockArguments()
: getFalseBlockArguments();
return index == kTrueIndex ? trueTargetOperandsMutable()
: falseTargetOperandsMutable();
}
bool spirv::BranchConditionalOp::canEraseSuccessorOperand() { return true; }
static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
OperationState &state) {
auto &builder = parser.getBuilder();

View File

@ -677,13 +677,12 @@ void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
context);
}
Optional<OperandRange> BranchOp::getSuccessorOperands(unsigned index) {
Optional<MutableOperandRange>
BranchOp::getMutableSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
return getOperands();
return destOperandsMutable();
}
bool BranchOp::canEraseSuccessorOperand() { return true; }
Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { return dest(); }
//===----------------------------------------------------------------------===//
@ -1021,13 +1020,13 @@ void CondBranchOp::getCanonicalizationPatterns(
SimplifyCondBranchIdenticalSuccessors>(context);
}
Optional<OperandRange> CondBranchOp::getSuccessorOperands(unsigned index) {
Optional<MutableOperandRange>
CondBranchOp::getMutableSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
return index == trueIndex ? getTrueOperands() : getFalseOperands();
return index == trueIndex ? trueDestOperandsMutable()
: falseDestOperandsMutable();
}
bool CondBranchOp::canEraseSuccessorOperand() { return true; }
Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
if (BoolAttr condAttr = operands.front().dyn_cast_or_null<BoolAttr>())
return condAttr.getValue() ? trueDest() : falseDest();

View File

@ -287,6 +287,18 @@ MutableOperandRange::MutableOperandRange(
MutableOperandRange::MutableOperandRange(Operation *owner)
: MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
/// Slice this range into a sub range, with the additional operand segment.
MutableOperandRange
MutableOperandRange::slice(unsigned subStart, unsigned subLen,
Optional<OperandSegment> segment) {
assert((subStart + subLen) <= length && "invalid sub-range");
MutableOperandRange subSlice(owner, start + subStart, subLen,
operandSegments);
if (segment)
subSlice.operandSegments.push_back(*segment);
return subSlice;
}
/// Append the given values to the range.
void MutableOperandRange::append(ValueRange values) {
if (values.empty())

View File

@ -21,39 +21,6 @@ using namespace mlir;
// BranchOpInterface
//===----------------------------------------------------------------------===//
/// Erase an operand from a branch operation that is used as a successor
/// operand. 'operandIndex' is the operand within 'operands' to be erased.
void mlir::detail::eraseBranchSuccessorOperand(OperandRange operands,
unsigned operandIndex,
Operation *op) {
assert(operandIndex < operands.size() &&
"invalid index for successor operands");
// Erase the operand from the operation.
size_t fullOperandIndex = operands.getBeginOperandIndex() + operandIndex;
op->eraseOperand(fullOperandIndex);
// If this operation has an OperandSegmentSizeAttr, keep it up to date.
auto operandSegmentAttr =
op->getAttrOfType<DenseElementsAttr>("operand_segment_sizes");
if (!operandSegmentAttr)
return;
// Find the segment containing the full operand index and decrement it.
// TODO: This seems like a general utility that could be added somewhere.
SmallVector<int32_t, 4> values(operandSegmentAttr.getValues<int32_t>());
unsigned currentSize = 0;
for (unsigned i = 0, e = values.size(); i != e; ++i) {
currentSize += values[i];
if (fullOperandIndex < currentSize) {
--values[i];
break;
}
}
op->setAttr("operand_segment_sizes",
DenseIntElementsAttr::get(operandSegmentAttr.getType(), values));
}
/// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
/// successor if 'operandIndex' is within the range of 'operands', or None if
/// `operandIndex` isn't a successor operand index.

View File

@ -209,7 +209,7 @@ static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
// Check to see if we can reason about the successor operands and mutate them.
BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
if (!branchInterface || !branchInterface.canEraseSuccessorOperand()) {
if (!branchInterface) {
for (Block *successor : op->getSuccessors())
for (BlockArgument arg : successor->getArguments())
liveMap.setProvedLive(arg);
@ -219,7 +219,7 @@ static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
// If we can't reason about the operands to a successor, conservatively mark
// all arguments as live.
for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
if (!branchInterface.getSuccessorOperands(i))
if (!branchInterface.getMutableSuccessorOperands(i))
for (BlockArgument arg : op->getSuccessor(i)->getArguments())
liveMap.setProvedLive(arg);
}
@ -278,7 +278,8 @@ static void eraseTerminatorSuccessorOperands(Operation *terminator,
// since it will promote later operands of the terminator being erased
// first, reducing the quadratic-ness.
unsigned succ = succE - succI - 1;
Optional<OperandRange> succOperands = branchOp.getSuccessorOperands(succ);
Optional<MutableOperandRange> succOperands =
branchOp.getMutableSuccessorOperands(succ);
if (!succOperands)
continue;
Block *successor = terminator->getSuccessor(succ);
@ -288,7 +289,7 @@ static void eraseTerminatorSuccessorOperands(Operation *terminator,
// shifting later args when earlier args are erased.
unsigned arg = argE - argI - 1;
if (!liveMap.wasProvenLive(successor->getArgument(arg)))
branchOp.eraseSuccessorOperand(succ, arg);
succOperands->erase(arg);
}
}
}

View File

@ -167,13 +167,12 @@ TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
// TestBranchOp
//===----------------------------------------------------------------------===//
Optional<OperandRange> TestBranchOp::getSuccessorOperands(unsigned index) {
Optional<MutableOperandRange>
TestBranchOp::getMutableSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
return getOperands();
return targetOperandsMutable();
}
bool TestBranchOp::canEraseSuccessorOperand() { return true; }
//===----------------------------------------------------------------------===//
// Test IsolatedRegionOp - parse passthrough region arguments.
//===----------------------------------------------------------------------===//

View File

@ -146,7 +146,7 @@ static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
StringRef interfaceName,
StringRef interfaceTraitsName) {
os << " template <typename ConcreteOp>\n "
<< llvm::formatv("struct Trait : public OpInterface<{0},"
<< llvm::formatv("struct {0}Trait : public OpInterface<{0},"
" detail::{1}>::Trait<ConcreteOp> {{\n",
interfaceName, interfaceTraitsName);
@ -171,13 +171,17 @@ static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
tblgen::FmtContext traitCtx;
traitCtx.withOp("op");
if (auto verify = interface.getVerify()) {
os << " static LogicalResult verifyTrait(Operation* op) {\n"
os << " static LogicalResult verifyTrait(Operation* op) {\n"
<< std::string(tblgen::tgfmt(*verify, &traitCtx)) << "\n }\n";
}
if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration())
os << extraTraitDecls << "\n";
os << " };\n";
// Emit a utility using directive for the trait class.
os << " template <typename ConcreteOp>\n "
<< llvm::formatv("using Trait = {0}Trait<ConcreteOp>;\n", interfaceName);
}
static void emitInterfaceDecl(OpInterface &interface, raw_ostream &os) {