[mlir][Interfaces] LoopLikeOpInterface: Expose tied loop results (#70535)

Expose loop results, which correspond to the region iter_arg values that
are returned from the loop when there are no more iterations. Exposing
loop results is optional because some loops (e.g., `scf.while`) do not
have a 1-to-1 mapping between region iter_args and op results.

Also add additional helper functions to query tied
results/iter_args/inits.
This commit is contained in:
Matthias Springer 2023-11-01 08:34:14 +09:00 committed by GitHub
parent e599978760
commit 98a6edd38f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 142 additions and 46 deletions

View File

@ -269,28 +269,6 @@ def ForOp : SCF_Op<"for",
/// Number of operands controlling the loop: lb, ub, step
unsigned getNumControlOperands() { return 3; }
/// Get the OpResult that corresponds to an OpOperand.
/// Assert that opOperand is an iterArg.
/// This helper prevents internal op implementation detail leakage to
/// clients by hiding the operand / block argument mapping.
OpResult getResultForOpOperand(OpOperand &opOperand) {
assert(opOperand.getOperandNumber() >= getNumControlOperands() &&
"expected an iter args operand");
assert(opOperand.getOwner() == getOperation() &&
"opOperand does not belong to this scf::ForOp operation");
return getOperation()->getResult(
opOperand.getOperandNumber() - getNumControlOperands());
}
/// Get the OpOperand& that corresponds to an OpResultOpOperand.
/// This helper prevents internal op implementation detail leakage to
/// clients by hiding the operand / block argument mapping.
OpOperand &getOpOperandForResult(OpResult opResult) {
assert(opResult.getDefiningOp() == getOperation() &&
"opResult does not belong to the scf::ForOp operation");
return getOperation()->getOpOperand(
getNumControlOperands() + opResult.getResultNumber());
}
/// Returns the step as an `APInt` if it is constant.
std::optional<APInt> getConstantStep();
@ -942,7 +920,7 @@ def WhileOp : SCF_Op<"while",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands"]>,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getLoopResults", "getRegionIterArgs", "getYieldedValuesMutable"]>,
["getRegionIterArgs", "getYieldedValuesMutable"]>,
RecursiveMemoryEffects, SingleBlock]> {
let summary = "a generic 'while' loop";
let description = [{
@ -1156,7 +1134,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
//===----------------------------------------------------------------------===//
def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator,
ParentOneOf<["ExecuteRegionOp, ForOp", "IfOp", "IndexSwitchOp",
ParentOneOf<["ExecuteRegionOp", "ForOp", "IfOp", "IndexSwitchOp",
"ParallelOp", "WhileOp"]>]> {
let summary = "loop yield and termination operation";
let description = [{

View File

@ -33,6 +33,13 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
If one of the respective interface methods is implemented, so must the other
two. The interface verifier ensures that the number and types of the region
iter_args, init values and yielded values match.
Optionally, "loop results" can be exposed through this interface. These are
the values that are returned from the loop op when there are no more
iterations. The number and types of the loop results must match with the
region iter_args. Note: Loop results are optional because some loops
(e.g., `scf.while`) may produce results that do not match 1-to-1 with the
region iter_args.
}];
let cppNamespace = "::mlir";
@ -166,6 +173,26 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
return {};
}]
>,
InterfaceMethod<[{
Return the range of results that are returned from this loop and
correspond to the "init" operands.
Note: This interface method is optional. If loop results are not
exposed via this interface, "std::nullopt" should be returned.
Otherwise, the number and types of results must match with the
region iter_args, inits and yielded values that are exposed via this
interface. If loop results are exposed but this loop op has no
loop-carried variables, an empty result range (and not "std::nullopt")
should be returned.
}],
/*retTy=*/"::std::optional<::mlir::ResultRange>",
/*methodName=*/"getLoopResults",
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return ::std::nullopt;
}]
>,
InterfaceMethod<[{
Append the specified additional "init" operands: replace this loop with
a new loop that has the additional init operands. The loop body of
@ -242,6 +269,8 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
}
/// Return the region iter_arg that corresponds to the given init operand.
/// Return an "empty" block argument if the given operand is not an init
/// operand of this loop op.
BlockArgument getTiedLoopRegionIterArg(OpOperand *opOperand) {
auto initsMutable = $_op.getInitsMutable();
auto it = llvm::find(initsMutable, *opOperand);
@ -250,7 +279,22 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
return $_op.getRegionIterArgs()[std::distance(initsMutable.begin(), it)];
}
/// Map a loop result to the region iter_arg it is tied to.
/// An "empty" block argument is produced when this op does not expose loop
/// results, or when `opResult` is not one of the exposed loop results.
BlockArgument getTiedLoopRegionIterArg(OpResult opResult) {
  auto exposedResults = $_op.getLoopResults();
  if (exposedResults) {
    auto pos = llvm::find(*exposedResults, opResult);
    if (pos != exposedResults->end()) {
      // The position within the loop results selects the iter_arg.
      auto index = std::distance(exposedResults->begin(), pos);
      return $_op.getRegionIterArgs()[index];
    }
  }
  return {};
}
/// Return the init operand that corresponds to the given region iter_arg.
/// Return "nullptr" if the given block argument is not a region iter_arg
/// of this loop op.
OpOperand *getTiedLoopInit(BlockArgument bbArg) {
auto iterArgs = $_op.getRegionIterArgs();
auto it = llvm::find(iterArgs, bbArg);
@ -259,7 +303,22 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
return &$_op.getInitsMutable()[std::distance(iterArgs.begin(), it)];
}
/// Map a loop result to the init operand it is tied to.
/// "nullptr" is produced when this op does not expose loop results, or when
/// `opResult` is not one of the exposed loop results.
OpOperand *getTiedLoopInit(OpResult opResult) {
  auto exposedResults = $_op.getLoopResults();
  if (exposedResults) {
    auto pos = llvm::find(*exposedResults, opResult);
    if (pos != exposedResults->end()) {
      // The position within the loop results selects the init operand.
      auto index = std::distance(exposedResults->begin(), pos);
      return &$_op.getInitsMutable()[index];
    }
  }
  return nullptr;
}
/// Return the yielded value that corresponds to the given region iter_arg.
/// Return "nullptr" if the given block argument is not a region iter_arg
/// of this loop op.
OpOperand *getTiedLoopYieldedValue(BlockArgument bbArg) {
auto iterArgs = $_op.getRegionIterArgs();
auto it = llvm::find(iterArgs, bbArg);
@ -268,6 +327,34 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
return
&$_op.getYieldedValuesMutable()[std::distance(iterArgs.begin(), it)];
}
/// Map an init operand to the loop result it is tied to.
/// An "empty" OpResult is produced when this op does not expose loop
/// results, or when `opOperand` is not one of the init operands.
OpResult getTiedLoopResult(OpOperand *opOperand) {
  auto exposedResults = $_op.getLoopResults();
  if (exposedResults) {
    auto inits = $_op.getInitsMutable();
    auto pos = llvm::find(inits, *opOperand);
    if (pos != inits.end()) {
      // The position within the inits selects the loop result.
      auto index = std::distance(inits.begin(), pos);
      return (*exposedResults)[index];
    }
  }
  return {};
}
/// Map a region iter_arg to the loop result it is tied to.
/// An "empty" OpResult is produced when this op does not expose loop
/// results, or when `bbArg` is not one of the region iter_args.
OpResult getTiedLoopResult(BlockArgument bbArg) {
  auto exposedResults = $_op.getLoopResults();
  if (exposedResults) {
    auto regionArgs = $_op.getRegionIterArgs();
    auto pos = llvm::find(regionArgs, bbArg);
    if (pos != regionArgs.end()) {
      // The position within the iter_args selects the loop result.
      auto index = std::distance(regionArgs.begin(), pos);
      return (*exposedResults)[index];
    }
  }
  return {};
}
}];
let verifyWithRegions = 1;

View File

@ -810,7 +810,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());
unsigned iterArgNumber = forOp.getResultForOpOperand(*pUse).getResultNumber();
unsigned iterArgNumber = forOp.getTiedLoopResult(pUse).getResultNumber();
auto yieldingExtractSliceOp = forOp.getYieldedValues()[iterArgNumber]
.getDefiningOp<tensor::ExtractSliceOp>();
if (!yieldingExtractSliceOp)

View File

@ -390,6 +390,8 @@ std::optional<OpFoldResult> ForOp::getSingleUpperBound() {
return OpFoldResult(getUpperBound());
}
/// Expose the results of this op as its loop results.
std::optional<ResultRange> ForOp::getLoopResults() {
  return getResults();
}
/// Promotes the loop body of a forOp to its containing block if the forOp
/// it can be determined that the loop has a single iteration.
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {

View File

@ -614,7 +614,7 @@ struct ForOpInterface
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
auto forOp = cast<scf::ForOp>(op);
OpResult opResult = forOp.getResultForOpOperand(opOperand);
OpResult opResult = forOp.getTiedLoopResult(&opOperand);
BufferRelation relation = bufferRelation(op, opResult, state);
return {{opResult, relation,
/*isDefinite=*/relation == BufferRelation::Equivalent}};
@ -625,10 +625,9 @@ struct ForOpInterface
// ForOp results are equivalent to their corresponding init_args if the
// corresponding iter_args and yield values are equivalent.
auto forOp = cast<scf::ForOp>(op);
OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
auto bbArg = forOp.getTiedLoopRegionIterArg(&forOperand);
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
bool equivalentYield = state.areEquivalentBufferizedValues(
bbArg, forOp.getYieldedValues()[opResult.getResultNumber()]);
bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
return equivalentYield ? BufferRelation::Equivalent
: BufferRelation::Unknown;
}
@ -703,16 +702,13 @@ struct ForOpInterface
if (auto opResult = dyn_cast<OpResult>(value)) {
// The type of an OpResult must match the corresponding iter_arg type.
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(
&forOp.getOpOperandForResult(opResult));
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
return bufferization::getBufferType(bbArg, options, invocationStack);
}
// Compute result/argument number.
BlockArgument bbArg = cast<BlockArgument>(value);
unsigned resultNum =
forOp.getResultForOpOperand(*forOp.getTiedLoopInit(bbArg))
.getResultNumber();
unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();
// Compute the bufferized type.
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());

View File

@ -609,8 +609,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
if (destinationInitArg &&
(*destinationInitArg)->getOwner() == outerMostLoop) {
unsigned iterArgNumber =
outerMostLoop.getResultForOpOperand(**destinationInitArg)
.getResultNumber();
outerMostLoop.getTiedLoopResult(*destinationInitArg).getResultNumber();
int64_t resultNumber = fusableProducer.getResultNumber();
if (auto dstOp =
dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {

View File

@ -58,7 +58,7 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
// but the LoopLikeOpInterface provides better error messages.
auto loopLikeOp = cast<LoopLikeOpInterface>(op);
// Verify number of inits/iter_args/yielded values.
// Verify number of inits/iter_args/yielded values/loop results.
if (loopLikeOp.getInits().size() != loopLikeOp.getRegionIterArgs().size())
return op->emitOpError("different number of inits and region iter_args: ")
<< loopLikeOp.getInits().size()
@ -69,21 +69,43 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
"different number of region iter_args and yielded values: ")
<< loopLikeOp.getRegionIterArgs().size()
<< " != " << loopLikeOp.getYieldedValues().size();
if (loopLikeOp.getLoopResults() && loopLikeOp.getLoopResults()->size() !=
loopLikeOp.getRegionIterArgs().size())
return op->emitOpError(
"different number of loop results and region iter_args: ")
<< loopLikeOp.getLoopResults()->size()
<< " != " << loopLikeOp.getRegionIterArgs().size();
// Verify types of inits/iter_args/yielded values.
// Verify types of inits/iter_args/yielded values/loop results.
int64_t i = 0;
for (const auto it :
llvm::zip_equal(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs(),
loopLikeOp.getYieldedValues())) {
if (std::get<0>(it).getType() != std::get<1>(it).getType())
op->emitOpError(std::to_string(i))
<< "-th init and " << i << "-th region iter_arg have different type: "
<< std::get<0>(it).getType() << " != " << std::get<1>(it).getType();
return op->emitOpError(std::to_string(i))
<< "-th init and " << i
<< "-th region iter_arg have different type: "
<< std::get<0>(it).getType()
<< " != " << std::get<1>(it).getType();
if (std::get<1>(it).getType() != std::get<2>(it).getType())
op->emitOpError(std::to_string(i))
<< "-th region iter_arg and " << i
<< "-th yielded value have different type: "
<< std::get<1>(it).getType() << " != " << std::get<2>(it).getType();
return op->emitOpError(std::to_string(i))
<< "-th region iter_arg and " << i
<< "-th yielded value have different type: "
<< std::get<1>(it).getType()
<< " != " << std::get<2>(it).getType();
++i;
}
// Verify the types of the loop results against the region iter_args. This
// check is skipped when the op does not expose loop results (std::nullopt).
i = 0;
if (auto loopResults = loopLikeOp.getLoopResults()) {
  for (const auto it :
       llvm::zip_equal(loopLikeOp.getRegionIterArgs(), *loopResults)) {
    if (std::get<0>(it).getType() != std::get<1>(it).getType())
      return op->emitOpError(std::to_string(i))
             << "-th region iter_arg and " << i
             << "-th loop result have different type: "
             << std::get<0>(it).getType()
             << " != " << std::get<1>(it).getType();
    // Advance once per (iter_arg, result) pair so the diagnostic names the
    // position of the mismatch instead of always reporting index 0.
    ++i;
  }
}

View File

@ -96,6 +96,19 @@ func.func @not_enough_loop_results(%arg0: index, %init: f32) {
// -----
// Negative test: an "scf.for" written in generic op form whose iter_arg and
// yielded value are f32 while the declared result type is f64. The expected
// diagnostic below is the region iter_arg / loop result type mismatch.
func.func @scf_for_incorrect_result_type(%arg0: index, %init: f32) {
// expected-error @below{{0-th region iter_arg and 0-th loop result have different type: 'f32' != 'f64'}}
"scf.for"(%arg0, %arg0, %arg0, %init) (
{
^bb0(%i0 : index, %iter: f32):
scf.yield %iter : f32
}
) : (index, index, index, f32) -> (f64)
return
}
// -----
func.func @too_many_iter_args(%arg0: index, %init: f32) {
// expected-error @below{{different number of inits and region iter_args: 1 != 2}}
%x = "scf.for"(%arg0, %arg0, %arg0, %init) (
@ -449,7 +462,6 @@ func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : ind
%s0 = arith.constant 0.0 : f32
%t0 = arith.constant 1.0 : f32
// expected-error @below {{1-th region iter_arg and 1-th yielded value have different type: 'f32' != 'i32'}}
// expected-error @below {{along control flow edge from Region #0 to Region #0: source type #1 'i32' should match input type #1 'f32'}}
%result1:2 = scf.for %i0 = %arg0 to %arg1 step %arg2
iter_args(%si = %s0, %ti = %t0) -> (f32, f32) {
%sn = arith.addf %si, %si : f32