mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-24 06:10:12 +00:00
[mlir][Interfaces] LoopLikeOpInterface
: Expose tied loop results (#70535)
Expose loop results, which correspond to the region iter_arg values that are returned from the loop when there are no more iterations. Exposing loop results is optional because some loops (e.g., `scf.while`) do not have a 1-to-1 mapping between region iter_args and op results. Also add additional helper functions to query tied results/iter_args/inits.
This commit is contained in:
parent
e599978760
commit
98a6edd38f
@ -269,28 +269,6 @@ def ForOp : SCF_Op<"for",
|
||||
/// Number of operands controlling the loop: lb, ub, step
|
||||
unsigned getNumControlOperands() { return 3; }
|
||||
|
||||
/// Get the OpResult that corresponds to an OpOperand.
|
||||
/// Assert that opOperand is an iterArg.
|
||||
/// This helper prevents internal op implementation detail leakage to
|
||||
/// clients by hiding the operand / block argument mapping.
|
||||
OpResult getResultForOpOperand(OpOperand &opOperand) {
|
||||
assert(opOperand.getOperandNumber() >= getNumControlOperands() &&
|
||||
"expected an iter args operand");
|
||||
assert(opOperand.getOwner() == getOperation() &&
|
||||
"opOperand does not belong to this scf::ForOp operation");
|
||||
return getOperation()->getResult(
|
||||
opOperand.getOperandNumber() - getNumControlOperands());
|
||||
}
|
||||
/// Get the OpOperand& that corresponds to an OpResultOpOperand.
|
||||
/// This helper prevents internal op implementation detail leakage to
|
||||
/// clients by hiding the operand / block argument mapping.
|
||||
OpOperand &getOpOperandForResult(OpResult opResult) {
|
||||
assert(opResult.getDefiningOp() == getOperation() &&
|
||||
"opResult does not belong to the scf::ForOp operation");
|
||||
return getOperation()->getOpOperand(
|
||||
getNumControlOperands() + opResult.getResultNumber());
|
||||
}
|
||||
|
||||
/// Returns the step as an `APInt` if it is constant.
|
||||
std::optional<APInt> getConstantStep();
|
||||
|
||||
@ -942,7 +920,7 @@ def WhileOp : SCF_Op<"while",
|
||||
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
|
||||
["getEntrySuccessorOperands"]>,
|
||||
DeclareOpInterfaceMethods<LoopLikeOpInterface,
|
||||
["getLoopResults", "getRegionIterArgs", "getYieldedValuesMutable"]>,
|
||||
["getRegionIterArgs", "getYieldedValuesMutable"]>,
|
||||
RecursiveMemoryEffects, SingleBlock]> {
|
||||
let summary = "a generic 'while' loop";
|
||||
let description = [{
|
||||
@ -1156,7 +1134,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def YieldOp : SCF_Op<"yield", [Pure, ReturnLike, Terminator,
|
||||
ParentOneOf<["ExecuteRegionOp, ForOp", "IfOp", "IndexSwitchOp",
|
||||
ParentOneOf<["ExecuteRegionOp", "ForOp", "IfOp", "IndexSwitchOp",
|
||||
"ParallelOp", "WhileOp"]>]> {
|
||||
let summary = "loop yield and termination operation";
|
||||
let description = [{
|
||||
|
@ -33,6 +33,13 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
|
||||
If one of the respective interface methods is implemented, so must the other
|
||||
two. The interface verifier ensures that the number of types of the region
|
||||
iter_args, init values and yielded values match.
|
||||
|
||||
Optionally, "loop results" can be exposed through this interface. These are
|
||||
the values that are returned from the loop op when there are no more
|
||||
iterations. The number and types of the loop results must match with the
|
||||
region iter_args. Note: Loop results are optional because some loops
|
||||
(e.g., `scf.while`) may produce results that do match 1-to-1 with the
|
||||
region iter_args.
|
||||
}];
|
||||
let cppNamespace = "::mlir";
|
||||
|
||||
@ -166,6 +173,26 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
|
||||
return {};
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
Return the range of results that are return from this loop and
|
||||
correspond to the "init" operands.
|
||||
|
||||
Note: This interface method is optional. If loop results are not
|
||||
exposed via this interface, "std::nullopt" should be returned.
|
||||
Otherwise, the number and types of results must match with the
|
||||
region iter_args, inits and yielded values that are exposed via this
|
||||
interface. If loop results are exposed but this loop op has no
|
||||
loop-carried variables, an empty result range (and not "std::nullopt")
|
||||
should be returned.
|
||||
}],
|
||||
/*retTy=*/"::std::optional<::mlir::ResultRange>",
|
||||
/*methodName=*/"getLoopResults",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return ::std::nullopt;
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
Append the specified additional "init" operands: replace this loop with
|
||||
a new loop that has the additional init operands. The loop body of
|
||||
@ -242,6 +269,8 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
|
||||
}
|
||||
|
||||
/// Return the region iter_arg that corresponds to the given init operand.
|
||||
/// Return an "empty" block argument if the given operand is not an init
|
||||
/// operand of this loop op.
|
||||
BlockArgument getTiedLoopRegionIterArg(OpOperand *opOperand) {
|
||||
auto initsMutable = $_op.getInitsMutable();
|
||||
auto it = llvm::find(initsMutable, *opOperand);
|
||||
@ -250,7 +279,22 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
|
||||
return $_op.getRegionIterArgs()[std::distance(initsMutable.begin(), it)];
|
||||
}
|
||||
|
||||
/// Return the region iter_arg that corresponds to the given loop result.
|
||||
/// Return an "empty" block argument if the given OpResult is not a loop
|
||||
/// result or if this op does not expose any loop results.
|
||||
BlockArgument getTiedLoopRegionIterArg(OpResult opResult) {
|
||||
auto loopResults = $_op.getLoopResults();
|
||||
if (!loopResults)
|
||||
return {};
|
||||
auto it = llvm::find(*loopResults, opResult);
|
||||
if (it == loopResults->end())
|
||||
return {};
|
||||
return $_op.getRegionIterArgs()[std::distance(loopResults->begin(), it)];
|
||||
}
|
||||
|
||||
/// Return the init operand that corresponds to the given region iter_arg.
|
||||
/// Return "nullptr" if the given block argument is not a region iter_arg
|
||||
/// of this loop op.
|
||||
OpOperand *getTiedLoopInit(BlockArgument bbArg) {
|
||||
auto iterArgs = $_op.getRegionIterArgs();
|
||||
auto it = llvm::find(iterArgs, bbArg);
|
||||
@ -259,7 +303,22 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
|
||||
return &$_op.getInitsMutable()[std::distance(iterArgs.begin(), it)];
|
||||
}
|
||||
|
||||
/// Return the init operand that corresponds to the given loop result.
|
||||
/// Return "nullptr" if the given OpResult is not a loop result or if this
|
||||
/// op does not expose any loop results.
|
||||
OpOperand *getTiedLoopInit(OpResult opResult) {
|
||||
auto loopResults = $_op.getLoopResults();
|
||||
if (!loopResults)
|
||||
return nullptr;
|
||||
auto it = llvm::find(*loopResults, opResult);
|
||||
if (it == loopResults->end())
|
||||
return nullptr;
|
||||
return &$_op.getInitsMutable()[std::distance(loopResults->begin(), it)];
|
||||
}
|
||||
|
||||
/// Return the yielded value that corresponds to the given region iter_arg.
|
||||
/// Return "nullptr" if the given block argument is not a region iter_arg
|
||||
/// of this loop op.
|
||||
OpOperand *getTiedLoopYieldedValue(BlockArgument bbArg) {
|
||||
auto iterArgs = $_op.getRegionIterArgs();
|
||||
auto it = llvm::find(iterArgs, bbArg);
|
||||
@ -268,6 +327,34 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
|
||||
return
|
||||
&$_op.getYieldedValuesMutable()[std::distance(iterArgs.begin(), it)];
|
||||
}
|
||||
|
||||
/// Return the loop result that corresponds to the given init operand.
|
||||
/// Return an "empty" OpResult if the given operand is not an init operand
|
||||
/// of this loop op or if this op does not expose any loop results.
|
||||
OpResult getTiedLoopResult(OpOperand *opOperand) {
|
||||
auto loopResults = $_op.getLoopResults();
|
||||
if (!loopResults)
|
||||
return {};
|
||||
auto initsMutable = $_op.getInitsMutable();
|
||||
auto it = llvm::find(initsMutable, *opOperand);
|
||||
if (it == initsMutable.end())
|
||||
return {};
|
||||
return (*loopResults)[std::distance(initsMutable.begin(), it)];
|
||||
}
|
||||
|
||||
/// Return the loop result that corresponds to the given region iter_arg.
|
||||
/// Return an "empty" OpResult if the given block argument is not a region
|
||||
/// iter_arg of this loop op or if this op does not expose any loop results.
|
||||
OpResult getTiedLoopResult(BlockArgument bbArg) {
|
||||
auto loopResults = $_op.getLoopResults();
|
||||
if (!loopResults)
|
||||
return {};
|
||||
auto iterArgs = $_op.getRegionIterArgs();
|
||||
auto it = llvm::find(iterArgs, bbArg);
|
||||
if (it == iterArgs.end())
|
||||
return {};
|
||||
return (*loopResults)[std::distance(iterArgs.begin(), it)];
|
||||
}
|
||||
}];
|
||||
|
||||
let verifyWithRegions = 1;
|
||||
|
@ -810,7 +810,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
|
||||
OpBuilder::InsertionGuard g(rewriter);
|
||||
rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());
|
||||
|
||||
unsigned iterArgNumber = forOp.getResultForOpOperand(*pUse).getResultNumber();
|
||||
unsigned iterArgNumber = forOp.getTiedLoopResult(pUse).getResultNumber();
|
||||
auto yieldingExtractSliceOp = forOp.getYieldedValues()[iterArgNumber]
|
||||
.getDefiningOp<tensor::ExtractSliceOp>();
|
||||
if (!yieldingExtractSliceOp)
|
||||
|
@ -390,6 +390,8 @@ std::optional<OpFoldResult> ForOp::getSingleUpperBound() {
|
||||
return OpFoldResult(getUpperBound());
|
||||
}
|
||||
|
||||
std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
|
||||
|
||||
/// Promotes the loop body of a forOp to its containing block if the forOp
|
||||
/// it can be determined that the loop has a single iteration.
|
||||
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
|
||||
|
@ -614,7 +614,7 @@ struct ForOpInterface
|
||||
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
auto forOp = cast<scf::ForOp>(op);
|
||||
OpResult opResult = forOp.getResultForOpOperand(opOperand);
|
||||
OpResult opResult = forOp.getTiedLoopResult(&opOperand);
|
||||
BufferRelation relation = bufferRelation(op, opResult, state);
|
||||
return {{opResult, relation,
|
||||
/*isDefinite=*/relation == BufferRelation::Equivalent}};
|
||||
@ -625,10 +625,9 @@ struct ForOpInterface
|
||||
// ForOp results are equivalent to their corresponding init_args if the
|
||||
// corresponding iter_args and yield values are equivalent.
|
||||
auto forOp = cast<scf::ForOp>(op);
|
||||
OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
|
||||
auto bbArg = forOp.getTiedLoopRegionIterArg(&forOperand);
|
||||
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
|
||||
bool equivalentYield = state.areEquivalentBufferizedValues(
|
||||
bbArg, forOp.getYieldedValues()[opResult.getResultNumber()]);
|
||||
bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get());
|
||||
return equivalentYield ? BufferRelation::Equivalent
|
||||
: BufferRelation::Unknown;
|
||||
}
|
||||
@ -703,16 +702,13 @@ struct ForOpInterface
|
||||
|
||||
if (auto opResult = dyn_cast<OpResult>(value)) {
|
||||
// The type of an OpResult must match the corresponding iter_arg type.
|
||||
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(
|
||||
&forOp.getOpOperandForResult(opResult));
|
||||
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
|
||||
return bufferization::getBufferType(bbArg, options, invocationStack);
|
||||
}
|
||||
|
||||
// Compute result/argument number.
|
||||
BlockArgument bbArg = cast<BlockArgument>(value);
|
||||
unsigned resultNum =
|
||||
forOp.getResultForOpOperand(*forOp.getTiedLoopInit(bbArg))
|
||||
.getResultNumber();
|
||||
unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber();
|
||||
|
||||
// Compute the bufferized type.
|
||||
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
|
@ -609,8 +609,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
|
||||
if (destinationInitArg &&
|
||||
(*destinationInitArg)->getOwner() == outerMostLoop) {
|
||||
unsigned iterArgNumber =
|
||||
outerMostLoop.getResultForOpOperand(**destinationInitArg)
|
||||
.getResultNumber();
|
||||
outerMostLoop.getTiedLoopResult(*destinationInitArg).getResultNumber();
|
||||
int64_t resultNumber = fusableProducer.getResultNumber();
|
||||
if (auto dstOp =
|
||||
dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
|
||||
|
@ -58,7 +58,7 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
|
||||
// but the LoopLikeOpInterface provides better error messages.
|
||||
auto loopLikeOp = cast<LoopLikeOpInterface>(op);
|
||||
|
||||
// Verify number of inits/iter_args/yielded values.
|
||||
// Verify number of inits/iter_args/yielded values/loop results.
|
||||
if (loopLikeOp.getInits().size() != loopLikeOp.getRegionIterArgs().size())
|
||||
return op->emitOpError("different number of inits and region iter_args: ")
|
||||
<< loopLikeOp.getInits().size()
|
||||
@ -69,21 +69,43 @@ LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
|
||||
"different number of region iter_args and yielded values: ")
|
||||
<< loopLikeOp.getRegionIterArgs().size()
|
||||
<< " != " << loopLikeOp.getYieldedValues().size();
|
||||
if (loopLikeOp.getLoopResults() && loopLikeOp.getLoopResults()->size() !=
|
||||
loopLikeOp.getRegionIterArgs().size())
|
||||
return op->emitOpError(
|
||||
"different number of loop results and region iter_args: ")
|
||||
<< loopLikeOp.getLoopResults()->size()
|
||||
<< " != " << loopLikeOp.getRegionIterArgs().size();
|
||||
|
||||
// Verify types of inits/iter_args/yielded values.
|
||||
// Verify types of inits/iter_args/yielded values/loop results.
|
||||
int64_t i = 0;
|
||||
for (const auto it :
|
||||
llvm::zip_equal(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs(),
|
||||
loopLikeOp.getYieldedValues())) {
|
||||
if (std::get<0>(it).getType() != std::get<1>(it).getType())
|
||||
op->emitOpError(std::to_string(i))
|
||||
<< "-th init and " << i << "-th region iter_arg have different type: "
|
||||
<< std::get<0>(it).getType() << " != " << std::get<1>(it).getType();
|
||||
return op->emitOpError(std::to_string(i))
|
||||
<< "-th init and " << i
|
||||
<< "-th region iter_arg have different type: "
|
||||
<< std::get<0>(it).getType()
|
||||
<< " != " << std::get<1>(it).getType();
|
||||
if (std::get<1>(it).getType() != std::get<2>(it).getType())
|
||||
op->emitOpError(std::to_string(i))
|
||||
<< "-th region iter_arg and " << i
|
||||
<< "-th yielded value have different type: "
|
||||
<< std::get<1>(it).getType() << " != " << std::get<2>(it).getType();
|
||||
return op->emitOpError(std::to_string(i))
|
||||
<< "-th region iter_arg and " << i
|
||||
<< "-th yielded value have different type: "
|
||||
<< std::get<1>(it).getType()
|
||||
<< " != " << std::get<2>(it).getType();
|
||||
++i;
|
||||
}
|
||||
i = 0;
|
||||
if (loopLikeOp.getLoopResults()) {
|
||||
for (const auto it : llvm::zip_equal(loopLikeOp.getRegionIterArgs(),
|
||||
*loopLikeOp.getLoopResults())) {
|
||||
if (std::get<0>(it).getType() != std::get<1>(it).getType())
|
||||
return op->emitOpError(std::to_string(i))
|
||||
<< "-th region iter_arg and " << i
|
||||
<< "-th loop result have different type: "
|
||||
<< std::get<0>(it).getType()
|
||||
<< " != " << std::get<1>(it).getType();
|
||||
}
|
||||
++i;
|
||||
}
|
||||
|
||||
|
@ -96,6 +96,19 @@ func.func @not_enough_loop_results(%arg0: index, %init: f32) {
|
||||
|
||||
// -----
|
||||
|
||||
func.func @scf_for_incorrect_result_type(%arg0: index, %init: f32) {
|
||||
// expected-error @below{{0-th region iter_arg and 0-th loop result have different type: 'f32' != 'f64'}}
|
||||
"scf.for"(%arg0, %arg0, %arg0, %init) (
|
||||
{
|
||||
^bb0(%i0 : index, %iter: f32):
|
||||
scf.yield %iter : f32
|
||||
}
|
||||
) : (index, index, index, f32) -> (f64)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @too_many_iter_args(%arg0: index, %init: f32) {
|
||||
// expected-error @below{{different number of inits and region iter_args: 1 != 2}}
|
||||
%x = "scf.for"(%arg0, %arg0, %arg0, %init) (
|
||||
@ -449,7 +462,6 @@ func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : ind
|
||||
%s0 = arith.constant 0.0 : f32
|
||||
%t0 = arith.constant 1.0 : f32
|
||||
// expected-error @below {{1-th region iter_arg and 1-th yielded value have different type: 'f32' != 'i32'}}
|
||||
// expected-error @below {{along control flow edge from Region #0 to Region #0: source type #1 'i32' should match input type #1 'f32'}}
|
||||
%result1:2 = scf.for %i0 = %arg0 to %arg1 step %arg2
|
||||
iter_args(%si = %s0, %ti = %t0) -> (f32, f32) {
|
||||
%sn = arith.addf %si, %si : f32
|
||||
|
Loading…
Reference in New Issue
Block a user