[mlir] Support getSuccessorInputs from parent op

Ops that implement `RegionBranchOpInterface` are allowed to indicate that they can branch back to themselves in `getSuccessorRegions`, but there is no API that allows them to specify the forwarded operands. This patch enables that by changing `getSuccessorEntryOperands` to accept `None`.

Fixes #54928

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D127239
This commit is contained in:
Mogball 2022-06-13 22:02:02 +00:00
parent 68df5c5c13
commit 537f220891
13 changed files with 80 additions and 50 deletions

View File

@ -309,7 +309,7 @@ def ForOp : SCF_Op<"for",
/// correspond to the loop iterator operands, i.e., those exclusing the
/// induction variable. LoopOp only has one region, so 0 is the only valid
/// value for `index`.
OperandRange getSuccessorEntryOperands(unsigned index);
OperandRange getSuccessorEntryOperands(Optional<unsigned> index);
}];
let hasCanonicalizer = 1;
@ -955,7 +955,7 @@ def WhileOp : SCF_Op<"while",
let regions = (region SizedRegion<1>:$before, SizedRegion<1>:$after);
let extraClassDeclaration = [{
OperandRange getSuccessorEntryOperands(unsigned index);
OperandRange getSuccessorEntryOperands(Optional<unsigned> index);
ConditionOp getConditionOp();
YieldOp getYieldOp();
Block::BlockArgListType getBeforeArguments();

View File

@ -134,12 +134,14 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
InterfaceMethod<[{
Returns the operands of this operation used as the entry arguments when
entering the region at `index`, which was specified as a successor of
this operation by `getSuccessorRegions`. These operands should
correspond 1-1 with the successor inputs specified in
this operation by `getSuccessorRegions`, or the operands forwarded to
the operation's results when it branches back to itself. These operands
should correspond 1-1 with the successor inputs specified in
`getSuccessorRegions`.
}],
"::mlir::OperandRange", "getSuccessorEntryOperands",
(ins "unsigned":$index), [{}], /*defaultImplementation=*/[{
(ins "::llvm::Optional<unsigned>":$index), [{}],
/*defaultImplementation=*/[{
auto operandEnd = this->getOperation()->operand_end();
return ::mlir::OperandRange(operandEnd, operandEnd);
}]

View File

@ -78,12 +78,12 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
if (region) {
// Determine the actual region number from the passed region.
regionIndex = region->getRegionNumber();
if (Optional<unsigned> operandIndex =
getOperandIndexIfPred(/*predIndex=*/llvm::None)) {
collectUnderlyingAddressValues(
branch.getSuccessorEntryOperands(*regionIndex)[*operandIndex],
maxDepth, visited, output);
}
}
if (Optional<unsigned> operandIndex =
getOperandIndexIfPred(/*predIndex=*/llvm::None)) {
collectUnderlyingAddressValues(
branch.getSuccessorEntryOperands(regionIndex)[*operandIndex], maxDepth,
visited, output);
}
// Check branches from each child region.
Operation *op = branch.getOperation();

View File

@ -470,11 +470,10 @@ void ForwardDataFlowSolver::visitRegionBranchOperation(
// also allow for the parent operation to have itself as a region successor.
if (successors.empty())
return markAllPessimisticFixpoint(branch, branch->getResults());
return visitRegionSuccessors(
branch, successors, operandLattices, [&](Optional<unsigned> index) {
assert(index && "expected valid region index");
return branch.getSuccessorEntryOperands(*index);
});
return visitRegionSuccessors(branch, successors, operandLattices,
[&](Optional<unsigned> index) {
return branch.getSuccessorEntryOperands(index);
});
}
void ForwardDataFlowSolver::visitRegionSuccessors(

View File

@ -1731,11 +1731,11 @@ void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// correspond to the loop iterator operands, i.e., those excluding the
/// induction variable. AffineForOp only has one region, so zero is the only
/// valid value for `index`.
OperandRange AffineForOp::getSuccessorEntryOperands(unsigned index) {
assert(index == 0 && "invalid region index");
OperandRange AffineForOp::getSuccessorEntryOperands(Optional<unsigned> index) {
assert(!index || *index == 0 && "invalid region index");
// The initial operands map to the loop arguments after the induction
// variable.
// variable or are forwarded to the results when the trip count is zero.
return getIterOperands();
}

View File

@ -59,8 +59,8 @@ YieldOp::getMutableSuccessorOperands(Optional<unsigned> index) {
constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
OperandRange ExecuteOp::getSuccessorEntryOperands(unsigned index) {
assert(index == 0 && "invalid region index");
OperandRange ExecuteOp::getSuccessorEntryOperands(Optional<unsigned> index) {
assert(index && *index == 0 && "invalid region index");
return operands();
}

View File

@ -473,8 +473,8 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) {
/// correspond to the loop iterator operands, i.e., those excluding the
/// induction variable. LoopOp only has one region, so 0 is the only valid value
/// for `index`.
OperandRange ForOp::getSuccessorEntryOperands(unsigned index) {
assert(index == 0 && "invalid region index");
OperandRange ForOp::getSuccessorEntryOperands(Optional<unsigned> index) {
assert(index && *index == 0 && "invalid region index");
// The initial operands map to the loop arguments after the induction
// variable.
@ -2605,8 +2605,8 @@ LogicalResult ReduceReturnOp::verify() {
// WhileOp
//===----------------------------------------------------------------------===//
OperandRange WhileOp::getSuccessorEntryOperands(unsigned index) {
assert(index == 0 &&
OperandRange WhileOp::getSuccessorEntryOperands(Optional<unsigned> index) {
assert(index && *index == 0 &&
"WhileOp is expected to branch only to the first region");
return getInits();

View File

@ -312,8 +312,9 @@ void transform::SequenceOp::getEffects(
}
}
OperandRange transform::SequenceOp::getSuccessorEntryOperands(unsigned index) {
assert(index == 0 && "unexpected region index");
OperandRange
transform::SequenceOp::getSuccessorEntryOperands(Optional<unsigned> index) {
assert(index && *index == 0 && "unexpected region index");
if (getOperation()->getNumOperands() == 1)
return getOperation()->getOperands();
return OperandRange(getOperation()->operand_end(),

View File

@ -8,6 +8,7 @@
#include <utility>
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/SmallPtrSet.h"
@ -151,16 +152,7 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
auto regionInterface = cast<RegionBranchOpInterface>(op);
auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange {
if (regionNo.hasValue()) {
return regionInterface.getSuccessorEntryOperands(regionNo.getValue())
.getTypes();
}
// If the successor of a parent op is the parent itself
// RegionBranchOpInterface does not have an API to query what the entry
// operands will be in that case. Vend out the result types of the op in
// that case so that type checking succeeds for this case.
return op->getResultTypes();
return regionInterface.getSuccessorEntryOperands(regionNo).getTypes();
};
// Verify types along control flow edges originating from the parent.

View File

@ -191,6 +191,31 @@ func.func @region_loop_control_flow(%arg: memref<2xf32>, %loopI0 : index,
// -----
// CHECK-LABEL: Testing : "region_loop_zero_trip_count"
// CHECK-DAG: alloca_1#0 <-> alloca_2#0: NoAlias
// CHECK-DAG: alloca_1#0 <-> for_alloca#0: MustAlias
// CHECK-DAG: alloca_1#0 <-> for_alloca.region0#0: MayAlias
// CHECK-DAG: alloca_1#0 <-> for_alloca.region0#1: MayAlias
// CHECK-DAG: alloca_2#0 <-> for_alloca#0: NoAlias
// CHECK-DAG: alloca_2#0 <-> for_alloca.region0#0: MayAlias
// CHECK-DAG: alloca_2#0 <-> for_alloca.region0#1: MayAlias
// CHECK-DAG: for_alloca#0 <-> for_alloca.region0#0: MayAlias
// CHECK-DAG: for_alloca#0 <-> for_alloca.region0#1: MayAlias
// CHECK-DAG: for_alloca.region0#0 <-> for_alloca.region0#1: MayAlias
func.func @region_loop_zero_trip_count() attributes {test.ptr = "func"} {
%0 = memref.alloca() {test.ptr = "alloca_1"} : memref<i32>
%1 = memref.alloca() {test.ptr = "alloca_2"} : memref<i32>
%result = affine.for %i = 0 to 0 iter_args(%si = %0) -> (memref<i32>) {
affine.yield %si : memref<i32>
} {test.ptr = "for_alloca"}
return
}
// -----
// CHECK-LABEL: Testing : "view_like"
// CHECK-DAG: alloc_1#0 <-> view#0: NoAlias

View File

@ -154,7 +154,7 @@ func.func @loop_region_branch_terminator_op(%arg1 : i32) {
/// interface as well.
// CHECK-LABEL: func @affine_loop_one_iter(
func.func @affine_loop_one_iter(%arg0 : index, %arg1 : index, %arg2 : index) -> i32 {
func.func @affine_loop_one_iter() -> i32 {
// CHECK: %[[C1:.*]] = arith.constant 1 : i32
%s0 = arith.constant 0 : i32
%s1 = arith.constant 1 : i32
@ -167,17 +167,27 @@ func.func @affine_loop_one_iter(%arg0 : index, %arg1 : index, %arg2 : index) ->
}
// CHECK-LABEL: func @affine_loop_zero_iter(
func.func @affine_loop_zero_iter(%arg0 : index, %arg1 : index, %arg2 : index) -> i32 {
// This exposes a crash in sccp/forward data flow analysis: https://github.com/llvm/llvm-project/issues/54928
func.func @affine_loop_zero_iter() -> i32 {
// CHECK: %[[C1:.*]] = arith.constant 1 : i32
%s1 = arith.constant 1 : i32
%result = affine.for %i = 0 to 0 iter_args(%si = %s1) -> (i32) {
%sn = arith.addi %si, %si : i32
affine.yield %sn : i32
}
// CHECK: return %[[C1]] : i32
return %result : i32
}
// CHECK-LABEL: func @affine_loop_unknown_trip_count(
func.func @affine_loop_unknown_trip_count(%ub: index) -> i32 {
// CHECK: %[[C0:.*]] = arith.constant 0 : i32
%s0 = arith.constant 0 : i32
// %result = affine.for %i = 0 to 0 iter_args(%si = %s0) -> (i32) {
// %sn = arith.addi %si, %si : i32
// affine.yield %sn : i32
// }
// return %result : i32
%result = affine.for %i = 0 to %ub iter_args(%si = %s0) -> (i32) {
%sn = arith.addi %si, %si : i32
affine.yield %sn : i32
}
// CHECK: return %[[C0]] : i32
return %s0 : i32
return %result : i32
}
// CHECK-LABEL: func @while_loop_different_arg_count

View File

@ -1301,8 +1301,8 @@ ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
parser.getCurrentLocation(), result.operands);
}
OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
assert(index < 2 && "invalid region index");
OperandRange RegionIfOp::getSuccessorEntryOperands(Optional<unsigned> index) {
assert(index && *index < 2 && "invalid region index");
return getOperands();
}
@ -1339,7 +1339,7 @@ void AnyCondOp::getSuccessorRegions(Optional<unsigned> index,
SmallVectorImpl<RegionSuccessor> &regions) {
// The parent op branches into the only region, and the region branches back
// to the parent op.
if (index)
if (!index)
regions.emplace_back(&getRegion());
else
regions.emplace_back(getResults());

View File

@ -2549,7 +2549,8 @@ def RegionIfOp : TEST_Op<"region_if",
::mlir::Block::BlockArgListType getJoinArgs() {
return getBody(2)->getArguments();
}
::mlir::OperandRange getSuccessorEntryOperands(unsigned index);
::mlir::OperandRange getSuccessorEntryOperands(
::llvm::Optional<unsigned> index);
}];
let hasCustomAssemblyFormat = 1;
}