[mlir][transform] Add transform.get_operand op (#78397)

Similar to `transform.get_result`, except it returns a handle to the
operand indicated by a positional specification, same as is defined for
the linalg match ops.

Additionally updates `get_result` to take the same positional specification.
This makes the use case of wanting to get all of the results of an
operation easier by no longer requiring the user to reconstruct the list
of results one-by-one.
This commit is contained in:
Quinn Dawkins 2024-01-18 06:33:14 -08:00 committed by GitHub
parent e90e43fb9c
commit 5caab8bbc0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 399 additions and 192 deletions

View File

@ -288,7 +288,7 @@ def MatchStructuredDimOp : Op<Transform_Dialect, "match.structured.dim", [
let results = (outs Optional<TransformParamTypeInterface>:$result);
let assemblyFormat =
"$operand_handle `[`"
"custom<StructuredTransformDims>($raw_dim_list, $is_inverted, $is_all)"
"custom<TransformMatchDims>($raw_dim_list, $is_inverted, $is_all)"
"`]` attr-dict `:` "
"custom<SemiFunctionType>(type($operand_handle), type($result))";
@ -347,7 +347,7 @@ class MatchStructuredOperandOp<string opname> : Op<Transform_Dialect, opname, [
(outs Optional<AnyTypeOf<[TransformAnyHandle,Transform_AffineMapParamType]>>:$result);
let assemblyFormat =
"$operand_handle `[`"
"custom<StructuredTransformDims>($raw_position_list, $is_inverted, $is_all)"
"custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
"`]` attr-dict "
"`:` custom<SemiFunctionType>(type($operand_handle), type($result))";

View File

@ -9,11 +9,12 @@
#ifndef MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
#define MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
#include <optional>
#include <type_traits>
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/STLExtras.h"
#include <optional>
#include <type_traits>
namespace mlir {
namespace transform {
@ -168,6 +169,52 @@ public:
}
};
//===----------------------------------------------------------------------===//
// Printing/parsing for positional specification matchers
//===----------------------------------------------------------------------===//
/// Parses a positional index specification for transform match operations.
/// The following forms are accepted:
///
/// - `all`: sets `isAll` and returns;
/// - comma-separated-integer-list: populates `rawDimList` with the values;
/// - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
/// with the values and sets `isInverted`.
ParseResult parseTransformMatchDims(OpAsmParser &parser,
DenseI64ArrayAttr &rawDimList,
UnitAttr &isInverted, UnitAttr &isAll);
/// Prints a positional index specification for transform match operations.
void printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
DenseI64ArrayAttr rawDimList, UnitAttr isInverted,
UnitAttr isAll);
//===----------------------------------------------------------------------===//
// Utilities for positional specification matchers
//===----------------------------------------------------------------------===//
/// Checks if the positional specification defined is valid and reports errors
/// otherwise.
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef<int64_t> raw,
bool inverted, bool all);
/// Populates `result` with the positional identifiers relative to `maxNumber`.
/// If `isAll` is set, the result will contain all numbers from `0` to
/// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
/// values from `rawList` are are interpreted as counting backwards from
/// `maxNumber`, i.e., `-1` is interpreted a `maxNumber - 1`, while positive
/// numbers remain as is. If `isInverted` is set, populates `result` with those
/// values from the `0` to `maxNumber - 1` inclusive range that don't appear in
/// `rawList`. If `rawList` contains values that are greater than or equal to
/// `maxNumber` or less than `-maxNumber`, produces a silenceable error at the
/// given location. `maxNumber` must be positive. If `rawList` contains
/// duplicate numbers or numbers that become duplicate after negative value
/// remapping, emits a silenceable error.
DiagnosedSilenceableFailure
expandTargetSpecification(Location loc, bool isAll, bool isInverted,
ArrayRef<int64_t> rawList, int64_t maxNumber,
SmallVectorImpl<int64_t> &result);
} // namespace transform
} // namespace mlir

View File

@ -728,24 +728,75 @@ def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
"functional-type(operands, results)";
}
def GetResultOp : TransformDialectOp<"get_result",
def GetOperandOp : TransformDialectOp<"get_operand",
[DeclareOpInterfaceMethods<TransformOpInterface>,
NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
let summary = "Get handle to the a result of the targeted op";
NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> {
let summary = "Get a handle to the operand(s) of the targeted op";
let description = [{
The handle defined by this Transform op corresponds to the OpResult with
`result_number` that is defined by the given `target` operation.
The handle defined by this Transform op corresponds to the operands of the
given `target` operation specified by the given set of positions. There are
three possible modes:
This transform produces a silenceable failure if the targeted operation
does not have enough results. It reads the target handle and produces the
result handle.
- Position list directly, i.e. `%target[0, 1, 2]`. This will return the
operands at the specified positions.
- Inverted position list, i.e. `%target[except(0, 1, 2)]`. This will return
all operands except those at the given positions.
- All, i.e. `%target[all]`. This will return all operands of the operation.
This transform produces a silenceable failure if any of the operand indices
exceeds the number of operands in the target. It reads the target handle and
produces the result handle.
}];
let arguments = (ins TransformHandleTypeInterface:$target,
I64Attr:$result_number);
DenseI64ArrayAttr:$raw_position_list,
UnitAttr:$is_inverted,
UnitAttr:$is_all);
let results = (outs TransformValueHandleTypeInterface:$result);
let assemblyFormat = "$target `[` $result_number `]` attr-dict `:` "
"functional-type(operands, results)";
let assemblyFormat =
"$target `[`"
"custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
"`]` attr-dict `:` functional-type(operands, results)";
let hasVerifier = 1;
}
def GetResultOp : TransformDialectOp<"get_result",
[DeclareOpInterfaceMethods<TransformOpInterface>,
NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
let summary = "Get a handle to the result(s) of the targeted op";
let description = [{
The handle defined by this Transform op correspond to the OpResults of the
given `target` operation. Optionally `result_number` can be specified to
select a specific result.
This transform fails silently if the targeted operation does not have enough
results. It reads the target handle and produces the result handle.
The handle defined by this Transform op corresponds to the results of the
given `target` operation specified by the given set of positions. There are
three possible modes:
- Position list directly, i.e. `%target[0, 1, 2]`. This will return the
results at the specified positions.
- Inverted position list, i.e. `%target[except(0, 1, 2)]`. This will return
all results except those at the given positions.
- All, i.e. `%target[all]`. This will return all results of the operation.
This transform produces a silenceable failure if any of the result indices
exceeds the number of results returned by the target. It reads the target
handle and produces the result handle.
}];
let arguments = (ins TransformHandleTypeInterface:$target,
DenseI64ArrayAttr:$raw_position_list,
UnitAttr:$is_inverted,
UnitAttr:$is_all);
let results = (outs TransformValueHandleTypeInterface:$result);
let assemblyFormat =
"$target `[`"
"custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
"`]` attr-dict `:` functional-type(operands, results)";
let hasVerifier = 1;
}
def GetTypeOp : TransformDialectOp<"get_type",

View File

@ -330,91 +330,6 @@ static DiagnosedSilenceableFailure containsAll(ArrayRef<unsigned> reference,
return DiagnosedSilenceableFailure::success();
}
/// Populates `result` with the positional identifiers relative to `maxNumber`.
/// If `isAll` is set, the result will contain all numbers from `0` to
/// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
/// values from `rawList` are are interpreted as counting backwards from
/// `maxNumber`, i.e., `-1` is interpreted a `maxNumber - 1`, while positive
/// numbers remain as is. If `isInverted` is set, populates `result` with those
/// values from the `0` to `maxNumber - 1` inclusive range that don't appear in
/// `rawList`. If `rawList` contains values that are greater than or equal to
/// `maxNumber` or less than `-maxNumber`, produces a silenceable error at the
/// given location. `maxNumber` must be positive. If `rawList` contains
/// duplicate numbers or numbers that become duplicate after negative value
/// remapping, emits a silenceable error.
static DiagnosedSilenceableFailure
expandTargetSpecification(Location loc, bool isAll, bool isInverted,
ArrayRef<int64_t> rawList, int64_t maxNumber,
SmallVectorImpl<int64_t> &result) {
assert(maxNumber > 0 && "expected size to be positive");
assert(!(isAll && isInverted) && "cannot invert all");
if (isAll) {
result = llvm::to_vector(llvm::seq<int64_t>(0, maxNumber));
return DiagnosedSilenceableFailure::success();
}
SmallVector<int64_t> expanded;
llvm::SmallDenseSet<int64_t> visited;
expanded.reserve(rawList.size());
SmallVectorImpl<int64_t> &target = isInverted ? expanded : result;
for (int64_t raw : rawList) {
int64_t updated = raw < 0 ? maxNumber + raw : raw;
if (updated >= maxNumber) {
return emitSilenceableFailure(loc)
<< "position overflow " << updated << " (updated from " << raw
<< ") for maximum " << maxNumber;
}
if (updated < 0) {
return emitSilenceableFailure(loc) << "position underflow " << updated
<< " (updated from " << raw << ")";
}
if (!visited.insert(updated).second) {
return emitSilenceableFailure(loc) << "repeated position " << updated
<< " (updated from " << raw << ")";
}
target.push_back(updated);
}
if (!isInverted)
return DiagnosedSilenceableFailure::success();
result.reserve(result.size() + (maxNumber - expanded.size()));
for (int64_t candidate : llvm::seq<int64_t>(0, maxNumber)) {
if (llvm::is_contained(expanded, candidate))
continue;
result.push_back(candidate);
}
return DiagnosedSilenceableFailure::success();
}
/// Checks if the positional specification defined is valid and reports errors
/// otherwise.
LogicalResult verifyStructuredTransformDimsOp(Operation *op,
ArrayRef<int64_t> raw,
bool inverted, bool all) {
if (all) {
if (inverted) {
return op->emitOpError()
<< "cannot request both 'all' and 'inverted' values in the list";
}
if (!raw.empty()) {
return op->emitOpError()
<< "cannot both request 'all' and specific values in the list";
}
}
if (!all && raw.empty()) {
return op->emitOpError() << "must request specific values in the list if "
"'all' is not specified";
}
SmallVector<int64_t> rawVector = llvm::to_vector(raw);
auto *it = std::unique(rawVector.begin(), rawVector.end());
if (it != rawVector.end())
return op->emitOpError() << "expected the listed values to be unique";
return success();
}
//===----------------------------------------------------------------------===//
// MatchStructuredDimOp
//===----------------------------------------------------------------------===//
@ -475,8 +390,8 @@ LogicalResult transform::MatchStructuredDimOp::verify() {
return emitOpError() << "cannot request the same dimension to be both "
"parallel and reduction";
}
return verifyStructuredTransformDimsOp(getOperation(), getRawDimList(),
getIsInverted(), getIsAll());
return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
getIsInverted(), getIsAll());
}
//===----------------------------------------------------------------------===//
@ -592,8 +507,8 @@ LogicalResult verifyStructuredOperandOp(OpTy op) {
LogicalResult transform::MatchStructuredInputOp::verify() {
if (failed(verifyStructuredOperandOp(*this)))
return failure();
return verifyStructuredTransformDimsOp(getOperation(), getRawPositionList(),
getIsInverted(), getIsAll());
return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
getIsInverted(), getIsAll());
}
//===----------------------------------------------------------------------===//
@ -665,8 +580,8 @@ DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
LogicalResult transform::MatchStructuredInitOp::verify() {
if (failed(verifyStructuredOperandOp(*this)))
return failure();
return verifyStructuredTransformDimsOp(getOperation(), getRawPositionList(),
getIsInverted(), getIsAll());
return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
getIsInverted(), getIsAll());
}
//===----------------------------------------------------------------------===//
@ -793,78 +708,5 @@ void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
build(builder, state, ValueRange());
}
//===----------------------------------------------------------------------===//
// Printing and parsing for structured match ops.
//===----------------------------------------------------------------------===//
/// Keyword syntax for positional specification inversion.
constexpr const static llvm::StringLiteral kDimExceptKeyword = "except";
/// Keyword syntax for full inclusion in positional specification.
constexpr const static llvm::StringLiteral kDimAllKeyword = "all";
/// Parses a positional specification for structured transform operations. The
/// following forms are accepted:
///
/// - `all`: sets `isAll` and returns;
/// - comma-separated-integer-list: populates `rawDimList` with the values;
/// - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
/// with the values and sets `isInverted`.
static ParseResult parseStructuredTransformDims(OpAsmParser &parser,
DenseI64ArrayAttr &rawDimList,
UnitAttr &isInverted,
UnitAttr &isAll) {
Builder &builder = parser.getBuilder();
if (parser.parseOptionalKeyword(kDimAllKeyword).succeeded()) {
rawDimList = builder.getDenseI64ArrayAttr({});
isInverted = nullptr;
isAll = builder.getUnitAttr();
return success();
}
isAll = nullptr;
isInverted = nullptr;
if (parser.parseOptionalKeyword(kDimExceptKeyword).succeeded()) {
isInverted = builder.getUnitAttr();
}
if (isInverted) {
if (parser.parseLParen().failed())
return failure();
}
SmallVector<int64_t> values;
ParseResult listResult = parser.parseCommaSeparatedList(
[&]() { return parser.parseInteger(values.emplace_back()); });
if (listResult.failed())
return failure();
rawDimList = builder.getDenseI64ArrayAttr(values);
if (isInverted) {
if (parser.parseRParen().failed())
return failure();
}
return success();
}
/// Prints a positional specification for structured transform operations.
static void printStructuredTransformDims(OpAsmPrinter &printer, Operation *op,
DenseI64ArrayAttr rawDimList,
UnitAttr isInverted, UnitAttr isAll) {
if (isAll) {
printer << kDimAllKeyword;
return;
}
if (isInverted) {
printer << kDimExceptKeyword << "(";
}
llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(),
[&](int64_t value) { printer << value; });
if (isInverted) {
printer << ")";
}
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"

View File

@ -10,6 +10,141 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
// Printing and parsing for match ops.
//===----------------------------------------------------------------------===//
/// Keyword syntax for positional specification inversion.
constexpr const static llvm::StringLiteral kDimExceptKeyword = "except";
/// Keyword syntax for full inclusion in positional specification.
constexpr const static llvm::StringLiteral kDimAllKeyword = "all";
ParseResult transform::parseTransformMatchDims(OpAsmParser &parser,
DenseI64ArrayAttr &rawDimList,
UnitAttr &isInverted,
UnitAttr &isAll) {
Builder &builder = parser.getBuilder();
if (parser.parseOptionalKeyword(kDimAllKeyword).succeeded()) {
rawDimList = builder.getDenseI64ArrayAttr({});
isInverted = nullptr;
isAll = builder.getUnitAttr();
return success();
}
isAll = nullptr;
isInverted = nullptr;
if (parser.parseOptionalKeyword(kDimExceptKeyword).succeeded()) {
isInverted = builder.getUnitAttr();
}
if (isInverted) {
if (parser.parseLParen().failed())
return failure();
}
SmallVector<int64_t> values;
ParseResult listResult = parser.parseCommaSeparatedList(
[&]() { return parser.parseInteger(values.emplace_back()); });
if (listResult.failed())
return failure();
rawDimList = builder.getDenseI64ArrayAttr(values);
if (isInverted) {
if (parser.parseRParen().failed())
return failure();
}
return success();
}
void transform::printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
DenseI64ArrayAttr rawDimList,
UnitAttr isInverted, UnitAttr isAll) {
if (isAll) {
printer << kDimAllKeyword;
return;
}
if (isInverted) {
printer << kDimExceptKeyword << "(";
}
llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(),
[&](int64_t value) { printer << value; });
if (isInverted) {
printer << ")";
}
}
LogicalResult transform::verifyTransformMatchDimsOp(Operation *op,
ArrayRef<int64_t> raw,
bool inverted, bool all) {
if (all) {
if (inverted) {
return op->emitOpError()
<< "cannot request both 'all' and 'inverted' values in the list";
}
if (!raw.empty()) {
return op->emitOpError()
<< "cannot both request 'all' and specific values in the list";
}
}
if (!all && raw.empty()) {
return op->emitOpError() << "must request specific values in the list if "
"'all' is not specified";
}
SmallVector<int64_t> rawVector = llvm::to_vector(raw);
auto *it = std::unique(rawVector.begin(), rawVector.end());
if (it != rawVector.end())
return op->emitOpError() << "expected the listed values to be unique";
return success();
}
DiagnosedSilenceableFailure transform::expandTargetSpecification(
Location loc, bool isAll, bool isInverted, ArrayRef<int64_t> rawList,
int64_t maxNumber, SmallVectorImpl<int64_t> &result) {
assert(maxNumber > 0 && "expected size to be positive");
assert(!(isAll && isInverted) && "cannot invert all");
if (isAll) {
result = llvm::to_vector(llvm::seq<int64_t>(0, maxNumber));
return DiagnosedSilenceableFailure::success();
}
SmallVector<int64_t> expanded;
llvm::SmallDenseSet<int64_t> visited;
expanded.reserve(rawList.size());
SmallVectorImpl<int64_t> &target = isInverted ? expanded : result;
for (int64_t raw : rawList) {
int64_t updated = raw < 0 ? maxNumber + raw : raw;
if (updated >= maxNumber) {
return emitSilenceableFailure(loc)
<< "position overflow " << updated << " (updated from " << raw
<< ") for maximum " << maxNumber;
}
if (updated < 0) {
return emitSilenceableFailure(loc) << "position underflow " << updated
<< " (updated from " << raw << ")";
}
if (!visited.insert(updated).second) {
return emitSilenceableFailure(loc) << "repeated position " << updated
<< " (updated from " << raw << ")";
}
target.push_back(updated);
}
if (!isInverted)
return DiagnosedSilenceableFailure::success();
result.reserve(result.size() + (maxNumber - expanded.size()));
for (int64_t candidate : llvm::seq<int64_t>(0, maxNumber)) {
if (llvm::is_contained(expanded, candidate))
continue;
result.push_back(candidate);
}
return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// Generated interface implementation.
//===----------------------------------------------------------------------===//

View File

@ -1464,6 +1464,39 @@ transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// GetOperandOp
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
SmallVector<Value> operands;
for (Operation *target : state.getPayloadOps(getTarget())) {
SmallVector<int64_t> operandPositions;
DiagnosedSilenceableFailure diag = expandTargetSpecification(
getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
target->getNumOperands(), operandPositions);
if (diag.isSilenceableFailure()) {
diag.attachNote(target->getLoc())
<< "while considering positions of this payload operation";
return diag;
}
llvm::append_range(operands,
llvm::map_range(operandPositions, [&](int64_t pos) {
return target->getOperand(pos);
}));
}
results.setValues(cast<OpResult>(getResult()), operands);
return DiagnosedSilenceableFailure::success();
}
LogicalResult transform::GetOperandOp::verify() {
return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
getIsInverted(), getIsAll());
}
//===----------------------------------------------------------------------===//
// GetResultOp
//===----------------------------------------------------------------------===//
@ -1472,21 +1505,31 @@ DiagnosedSilenceableFailure
transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
int64_t resultNumber = getResultNumber();
SmallVector<Value> opResults;
for (Operation *target : state.getPayloadOps(getTarget())) {
if (resultNumber >= target->getNumResults()) {
DiagnosedSilenceableFailure diag =
emitSilenceableError() << "targeted op does not have enough results";
diag.attachNote(target->getLoc()) << "target op";
SmallVector<int64_t> resultPositions;
DiagnosedSilenceableFailure diag = expandTargetSpecification(
getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
target->getNumResults(), resultPositions);
if (diag.isSilenceableFailure()) {
diag.attachNote(target->getLoc())
<< "while considering positions of this payload operation";
return diag;
}
opResults.push_back(target->getOpResult(resultNumber));
llvm::append_range(opResults,
llvm::map_range(resultPositions, [&](int64_t pos) {
return target->getResult(pos);
}));
}
results.setValues(llvm::cast<OpResult>(getResult()), opResults);
results.setValues(cast<OpResult>(getResult()), opResults);
return DiagnosedSilenceableFailure::success();
}
LogicalResult transform::GetResultOp::verify() {
return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
getIsInverted(), getIsAll());
}
//===----------------------------------------------------------------------===//
// GetTypeOp
//===----------------------------------------------------------------------===//

View File

@ -43,7 +43,6 @@ class Handle(ir.Value):
self.parent = parent
self.children = children if children is not None else []
@ir.register_value_caster(AnyOpType.get_static_typeid())
@ir.register_value_caster(OperationType.get_static_typeid())
class OpHandle(Handle):
@ -61,16 +60,16 @@ class OpHandle(Handle):
):
super().__init__(v, parent=parent, children=children)
def get_result(self, idx: int = 0) -> "ValueHandle":
def get_result(self, indices: Sequence[int] = [0]) -> "ValueHandle":
"""
Emits a `transform.GetResultOp`.
Returns a handle to the result of the payload operation at the given
index.
indices.
"""
get_result_op = transform.GetResultOp(
AnyValueType.get(),
self,
idx,
indices,
)
return get_result_op.result

View File

@ -1483,6 +1483,78 @@ module attributes {transform.with_named_sequence} {
// -----
// expected-remark @below {{addi operand}}
// expected-note @below {{value handle points to a block argument #0}}
func.func @get_operand_of_op(%arg0: index, %arg1: index) -> index {
%r = arith.addi %arg0, %arg1 : index
return %r : index
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%operand = transform.get_operand %addi[0] : (!transform.any_op) -> !transform.any_value
transform.debug.emit_remark_at %operand, "addi operand" : !transform.any_value
transform.yield
}
}
// -----
func.func @get_out_of_bounds_operand_of_op(%arg0: index, %arg1: index) -> index {
// expected-note @below {{while considering positions of this payload operation}}
%r = arith.addi %arg0, %arg1 : index
return %r : index
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @below {{position overflow 2 (updated from 2) for maximum 2}}
%operand = transform.get_operand %addi[2] : (!transform.any_op) -> !transform.any_value
transform.debug.emit_remark_at %operand, "addi operand" : !transform.any_value
transform.yield
}
}
// -----
// expected-remark @below {{addi operand}}
// expected-note @below {{value handle points to a block argument #1}}
func.func @get_inverted_operand_of_op(%arg0: index, %arg1: index) -> index {
%r = arith.addi %arg0, %arg1 : index
return %r : index
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%operand = transform.get_operand %addi[except(0)] : (!transform.any_op) -> !transform.any_value
transform.debug.emit_remark_at %operand, "addi operand" : !transform.any_value
transform.yield
}
}
// -----
func.func @get_multiple_operands_of_op(%arg0: index, %arg1: index) -> index {
%r = arith.addi %arg0, %arg1 : index
return %r : index
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%addui = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%operands = transform.get_operand %addui[all] : (!transform.any_op) -> !transform.any_value
%p = transform.num_associations %operands : (!transform.any_value) -> !transform.param<i64>
// expected-remark @below {{2}}
transform.debug.emit_param_as_remark %p : !transform.param<i64>
transform.yield
}
}
// -----
func.func @get_result_of_op(%arg0: index, %arg1: index) -> index {
// expected-remark @below {{addi result}}
// expected-note @below {{value handle points to an op result #0}}
@ -1502,7 +1574,7 @@ module attributes {transform.with_named_sequence} {
// -----
func.func @get_out_of_bounds_result_of_op(%arg0: index, %arg1: index) -> index {
// expected-note @below {{target op}}
// expected-note @below {{while considering positions of this payload operation}}
%r = arith.addi %arg0, %arg1 : index
return %r : index
}
@ -1510,7 +1582,7 @@ func.func @get_out_of_bounds_result_of_op(%arg0: index, %arg1: index) -> index {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @below {{targeted op does not have enough results}}
// expected-error @below {{position overflow 1 (updated from 1) for maximum 1}}
%result = transform.get_result %addi[1] : (!transform.any_op) -> !transform.any_value
transform.debug.emit_remark_at %result, "addi result" : !transform.any_value
transform.yield
@ -1537,6 +1609,24 @@ module attributes {transform.with_named_sequence} {
// -----
func.func @get_multiple_result_of_op(%arg0: index, %arg1: index) -> (index, i1) {
%r, %b = arith.addui_extended %arg0, %arg1 : index, i1
return %r, %b : index, i1
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%addui = transform.structured.match ops{["arith.addui_extended"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%results = transform.get_result %addui[all] : (!transform.any_op) -> !transform.any_value
%p = transform.num_associations %results : (!transform.any_value) -> !transform.param<i64>
// expected-remark @below {{2}}
transform.debug.emit_param_as_remark %p : !transform.param<i64>
transform.yield
}
}
// -----
// expected-note @below {{target value}}
func.func @get_result_of_op_bbarg(%arg0: index, %arg1: index) -> index {
%r = arith.addi %arg0, %arg1 : index