mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-26 23:21:11 +00:00
[mlir][transform] Add transform.get_operand op (#78397)
Similar to `transform.get_result`, except it returns a handle to the operand indicated by a positional specification, same as is defined for the linalg match ops. Additionally updates `get_result` to take the same positional specification. This makes the use case of wanting to get all of the results of an operation easier by no longer requiring the user to reconstruct the list of results one-by-one.
This commit is contained in:
parent
e90e43fb9c
commit
5caab8bbc0
@ -288,7 +288,7 @@ def MatchStructuredDimOp : Op<Transform_Dialect, "match.structured.dim", [
|
||||
let results = (outs Optional<TransformParamTypeInterface>:$result);
|
||||
let assemblyFormat =
|
||||
"$operand_handle `[`"
|
||||
"custom<StructuredTransformDims>($raw_dim_list, $is_inverted, $is_all)"
|
||||
"custom<TransformMatchDims>($raw_dim_list, $is_inverted, $is_all)"
|
||||
"`]` attr-dict `:` "
|
||||
"custom<SemiFunctionType>(type($operand_handle), type($result))";
|
||||
|
||||
@ -347,7 +347,7 @@ class MatchStructuredOperandOp<string opname> : Op<Transform_Dialect, opname, [
|
||||
(outs Optional<AnyTypeOf<[TransformAnyHandle,Transform_AffineMapParamType]>>:$result);
|
||||
let assemblyFormat =
|
||||
"$operand_handle `[`"
|
||||
"custom<StructuredTransformDims>($raw_position_list, $is_inverted, $is_all)"
|
||||
"custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
|
||||
"`]` attr-dict "
|
||||
"`:` custom<SemiFunctionType>(type($operand_handle), type($result))";
|
||||
|
||||
|
@ -9,11 +9,12 @@
|
||||
#ifndef MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
|
||||
#define MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
|
||||
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
namespace mlir {
|
||||
namespace transform {
|
||||
@ -168,6 +169,52 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Printing/parsing for positional specification matchers
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Parses a positional index specification for transform match operations.
|
||||
/// The following forms are accepted:
|
||||
///
|
||||
/// - `all`: sets `isAll` and returns;
|
||||
/// - comma-separated-integer-list: populates `rawDimList` with the values;
|
||||
/// - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
|
||||
/// with the values and sets `isInverted`.
|
||||
ParseResult parseTransformMatchDims(OpAsmParser &parser,
|
||||
DenseI64ArrayAttr &rawDimList,
|
||||
UnitAttr &isInverted, UnitAttr &isAll);
|
||||
|
||||
/// Prints a positional index specification for transform match operations.
|
||||
void printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
|
||||
DenseI64ArrayAttr rawDimList, UnitAttr isInverted,
|
||||
UnitAttr isAll);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utilities for positional specification matchers
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Checks if the positional specification defined is valid and reports errors
|
||||
/// otherwise.
|
||||
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef<int64_t> raw,
|
||||
bool inverted, bool all);
|
||||
|
||||
/// Populates `result` with the positional identifiers relative to `maxNumber`.
|
||||
/// If `isAll` is set, the result will contain all numbers from `0` to
|
||||
/// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
|
||||
/// values from `rawList` are are interpreted as counting backwards from
|
||||
/// `maxNumber`, i.e., `-1` is interpreted a `maxNumber - 1`, while positive
|
||||
/// numbers remain as is. If `isInverted` is set, populates `result` with those
|
||||
/// values from the `0` to `maxNumber - 1` inclusive range that don't appear in
|
||||
/// `rawList`. If `rawList` contains values that are greater than or equal to
|
||||
/// `maxNumber` or less than `-maxNumber`, produces a silenceable error at the
|
||||
/// given location. `maxNumber` must be positive. If `rawList` contains
|
||||
/// duplicate numbers or numbers that become duplicate after negative value
|
||||
/// remapping, emits a silenceable error.
|
||||
DiagnosedSilenceableFailure
|
||||
expandTargetSpecification(Location loc, bool isAll, bool isInverted,
|
||||
ArrayRef<int64_t> rawList, int64_t maxNumber,
|
||||
SmallVectorImpl<int64_t> &result);
|
||||
|
||||
} // namespace transform
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -728,24 +728,75 @@ def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
|
||||
"functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def GetResultOp : TransformDialectOp<"get_result",
|
||||
def GetOperandOp : TransformDialectOp<"get_operand",
|
||||
[DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
|
||||
let summary = "Get handle to the a result of the targeted op";
|
||||
NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> {
|
||||
let summary = "Get a handle to the operand(s) of the targeted op";
|
||||
let description = [{
|
||||
The handle defined by this Transform op corresponds to the OpResult with
|
||||
`result_number` that is defined by the given `target` operation.
|
||||
The handle defined by this Transform op corresponds to the operands of the
|
||||
given `target` operation specified by the given set of positions. There are
|
||||
three possible modes:
|
||||
|
||||
This transform produces a silenceable failure if the targeted operation
|
||||
does not have enough results. It reads the target handle and produces the
|
||||
result handle.
|
||||
- Position list directly, i.e. `%target[0, 1, 2]`. This will return the
|
||||
operands at the specified positions.
|
||||
- Inverted position list, i.e. `%target[except(0, 1, 2)]`. This will return
|
||||
all operands except those at the given positions.
|
||||
- All, i.e. `%target[all]`. This will return all operands of the operation.
|
||||
|
||||
This transform produces a silenceable failure if any of the operand indices
|
||||
exceeds the number of operands in the target. It reads the target handle and
|
||||
produces the result handle.
|
||||
}];
|
||||
|
||||
let arguments = (ins TransformHandleTypeInterface:$target,
|
||||
I64Attr:$result_number);
|
||||
DenseI64ArrayAttr:$raw_position_list,
|
||||
UnitAttr:$is_inverted,
|
||||
UnitAttr:$is_all);
|
||||
let results = (outs TransformValueHandleTypeInterface:$result);
|
||||
let assemblyFormat = "$target `[` $result_number `]` attr-dict `:` "
|
||||
"functional-type(operands, results)";
|
||||
let assemblyFormat =
|
||||
"$target `[`"
|
||||
"custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
|
||||
"`]` attr-dict `:` functional-type(operands, results)";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def GetResultOp : TransformDialectOp<"get_result",
|
||||
[DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
|
||||
let summary = "Get a handle to the result(s) of the targeted op";
|
||||
let description = [{
|
||||
The handle defined by this Transform op correspond to the OpResults of the
|
||||
given `target` operation. Optionally `result_number` can be specified to
|
||||
select a specific result.
|
||||
|
||||
This transform fails silently if the targeted operation does not have enough
|
||||
results. It reads the target handle and produces the result handle.
|
||||
|
||||
The handle defined by this Transform op corresponds to the results of the
|
||||
given `target` operation specified by the given set of positions. There are
|
||||
three possible modes:
|
||||
|
||||
- Position list directly, i.e. `%target[0, 1, 2]`. This will return the
|
||||
results at the specified positions.
|
||||
- Inverted position list, i.e. `%target[except(0, 1, 2)]`. This will return
|
||||
all results except those at the given positions.
|
||||
- All, i.e. `%target[all]`. This will return all results of the operation.
|
||||
|
||||
This transform produces a silenceable failure if any of the result indices
|
||||
exceeds the number of results returned by the target. It reads the target
|
||||
handle and produces the result handle.
|
||||
}];
|
||||
|
||||
let arguments = (ins TransformHandleTypeInterface:$target,
|
||||
DenseI64ArrayAttr:$raw_position_list,
|
||||
UnitAttr:$is_inverted,
|
||||
UnitAttr:$is_all);
|
||||
let results = (outs TransformValueHandleTypeInterface:$result);
|
||||
let assemblyFormat =
|
||||
"$target `[`"
|
||||
"custom<TransformMatchDims>($raw_position_list, $is_inverted, $is_all)"
|
||||
"`]` attr-dict `:` functional-type(operands, results)";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def GetTypeOp : TransformDialectOp<"get_type",
|
||||
|
@ -330,91 +330,6 @@ static DiagnosedSilenceableFailure containsAll(ArrayRef<unsigned> reference,
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
/// Populates `result` with the positional identifiers relative to `maxNumber`.
|
||||
/// If `isAll` is set, the result will contain all numbers from `0` to
|
||||
/// `maxNumber - 1` inclusive regardless of `rawList`. Otherwise, negative
|
||||
/// values from `rawList` are are interpreted as counting backwards from
|
||||
/// `maxNumber`, i.e., `-1` is interpreted a `maxNumber - 1`, while positive
|
||||
/// numbers remain as is. If `isInverted` is set, populates `result` with those
|
||||
/// values from the `0` to `maxNumber - 1` inclusive range that don't appear in
|
||||
/// `rawList`. If `rawList` contains values that are greater than or equal to
|
||||
/// `maxNumber` or less than `-maxNumber`, produces a silenceable error at the
|
||||
/// given location. `maxNumber` must be positive. If `rawList` contains
|
||||
/// duplicate numbers or numbers that become duplicate after negative value
|
||||
/// remapping, emits a silenceable error.
|
||||
static DiagnosedSilenceableFailure
|
||||
expandTargetSpecification(Location loc, bool isAll, bool isInverted,
|
||||
ArrayRef<int64_t> rawList, int64_t maxNumber,
|
||||
SmallVectorImpl<int64_t> &result) {
|
||||
assert(maxNumber > 0 && "expected size to be positive");
|
||||
assert(!(isAll && isInverted) && "cannot invert all");
|
||||
if (isAll) {
|
||||
result = llvm::to_vector(llvm::seq<int64_t>(0, maxNumber));
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
SmallVector<int64_t> expanded;
|
||||
llvm::SmallDenseSet<int64_t> visited;
|
||||
expanded.reserve(rawList.size());
|
||||
SmallVectorImpl<int64_t> &target = isInverted ? expanded : result;
|
||||
for (int64_t raw : rawList) {
|
||||
int64_t updated = raw < 0 ? maxNumber + raw : raw;
|
||||
if (updated >= maxNumber) {
|
||||
return emitSilenceableFailure(loc)
|
||||
<< "position overflow " << updated << " (updated from " << raw
|
||||
<< ") for maximum " << maxNumber;
|
||||
}
|
||||
if (updated < 0) {
|
||||
return emitSilenceableFailure(loc) << "position underflow " << updated
|
||||
<< " (updated from " << raw << ")";
|
||||
}
|
||||
if (!visited.insert(updated).second) {
|
||||
return emitSilenceableFailure(loc) << "repeated position " << updated
|
||||
<< " (updated from " << raw << ")";
|
||||
}
|
||||
target.push_back(updated);
|
||||
}
|
||||
|
||||
if (!isInverted)
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
|
||||
result.reserve(result.size() + (maxNumber - expanded.size()));
|
||||
for (int64_t candidate : llvm::seq<int64_t>(0, maxNumber)) {
|
||||
if (llvm::is_contained(expanded, candidate))
|
||||
continue;
|
||||
result.push_back(candidate);
|
||||
}
|
||||
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
/// Checks if the positional specification defined is valid and reports errors
|
||||
/// otherwise.
|
||||
LogicalResult verifyStructuredTransformDimsOp(Operation *op,
|
||||
ArrayRef<int64_t> raw,
|
||||
bool inverted, bool all) {
|
||||
if (all) {
|
||||
if (inverted) {
|
||||
return op->emitOpError()
|
||||
<< "cannot request both 'all' and 'inverted' values in the list";
|
||||
}
|
||||
if (!raw.empty()) {
|
||||
return op->emitOpError()
|
||||
<< "cannot both request 'all' and specific values in the list";
|
||||
}
|
||||
}
|
||||
if (!all && raw.empty()) {
|
||||
return op->emitOpError() << "must request specific values in the list if "
|
||||
"'all' is not specified";
|
||||
}
|
||||
SmallVector<int64_t> rawVector = llvm::to_vector(raw);
|
||||
auto *it = std::unique(rawVector.begin(), rawVector.end());
|
||||
if (it != rawVector.end())
|
||||
return op->emitOpError() << "expected the listed values to be unique";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MatchStructuredDimOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -475,8 +390,8 @@ LogicalResult transform::MatchStructuredDimOp::verify() {
|
||||
return emitOpError() << "cannot request the same dimension to be both "
|
||||
"parallel and reduction";
|
||||
}
|
||||
return verifyStructuredTransformDimsOp(getOperation(), getRawDimList(),
|
||||
getIsInverted(), getIsAll());
|
||||
return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
|
||||
getIsInverted(), getIsAll());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -592,8 +507,8 @@ LogicalResult verifyStructuredOperandOp(OpTy op) {
|
||||
LogicalResult transform::MatchStructuredInputOp::verify() {
|
||||
if (failed(verifyStructuredOperandOp(*this)))
|
||||
return failure();
|
||||
return verifyStructuredTransformDimsOp(getOperation(), getRawPositionList(),
|
||||
getIsInverted(), getIsAll());
|
||||
return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
|
||||
getIsInverted(), getIsAll());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -665,8 +580,8 @@ DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
|
||||
LogicalResult transform::MatchStructuredInitOp::verify() {
|
||||
if (failed(verifyStructuredOperandOp(*this)))
|
||||
return failure();
|
||||
return verifyStructuredTransformDimsOp(getOperation(), getRawPositionList(),
|
||||
getIsInverted(), getIsAll());
|
||||
return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
|
||||
getIsInverted(), getIsAll());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -793,78 +708,5 @@ void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
|
||||
build(builder, state, ValueRange());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Printing and parsing for structured match ops.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Keyword syntax for positional specification inversion.
|
||||
constexpr const static llvm::StringLiteral kDimExceptKeyword = "except";
|
||||
|
||||
/// Keyword syntax for full inclusion in positional specification.
|
||||
constexpr const static llvm::StringLiteral kDimAllKeyword = "all";
|
||||
|
||||
/// Parses a positional specification for structured transform operations. The
|
||||
/// following forms are accepted:
|
||||
///
|
||||
/// - `all`: sets `isAll` and returns;
|
||||
/// - comma-separated-integer-list: populates `rawDimList` with the values;
|
||||
/// - `except` `(` comma-separated-integer-list `)`: populates `rawDimList`
|
||||
/// with the values and sets `isInverted`.
|
||||
static ParseResult parseStructuredTransformDims(OpAsmParser &parser,
|
||||
DenseI64ArrayAttr &rawDimList,
|
||||
UnitAttr &isInverted,
|
||||
UnitAttr &isAll) {
|
||||
Builder &builder = parser.getBuilder();
|
||||
if (parser.parseOptionalKeyword(kDimAllKeyword).succeeded()) {
|
||||
rawDimList = builder.getDenseI64ArrayAttr({});
|
||||
isInverted = nullptr;
|
||||
isAll = builder.getUnitAttr();
|
||||
return success();
|
||||
}
|
||||
|
||||
isAll = nullptr;
|
||||
isInverted = nullptr;
|
||||
if (parser.parseOptionalKeyword(kDimExceptKeyword).succeeded()) {
|
||||
isInverted = builder.getUnitAttr();
|
||||
}
|
||||
|
||||
if (isInverted) {
|
||||
if (parser.parseLParen().failed())
|
||||
return failure();
|
||||
}
|
||||
|
||||
SmallVector<int64_t> values;
|
||||
ParseResult listResult = parser.parseCommaSeparatedList(
|
||||
[&]() { return parser.parseInteger(values.emplace_back()); });
|
||||
if (listResult.failed())
|
||||
return failure();
|
||||
|
||||
rawDimList = builder.getDenseI64ArrayAttr(values);
|
||||
|
||||
if (isInverted) {
|
||||
if (parser.parseRParen().failed())
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Prints a positional specification for structured transform operations.
|
||||
static void printStructuredTransformDims(OpAsmPrinter &printer, Operation *op,
|
||||
DenseI64ArrayAttr rawDimList,
|
||||
UnitAttr isInverted, UnitAttr isAll) {
|
||||
if (isAll) {
|
||||
printer << kDimAllKeyword;
|
||||
return;
|
||||
}
|
||||
if (isInverted) {
|
||||
printer << kDimExceptKeyword << "(";
|
||||
}
|
||||
llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(),
|
||||
[&](int64_t value) { printer << value; });
|
||||
if (isInverted) {
|
||||
printer << ")";
|
||||
}
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
|
||||
|
@ -10,6 +10,141 @@
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Printing and parsing for match ops.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Keyword syntax for positional specification inversion.
|
||||
constexpr const static llvm::StringLiteral kDimExceptKeyword = "except";
|
||||
|
||||
/// Keyword syntax for full inclusion in positional specification.
|
||||
constexpr const static llvm::StringLiteral kDimAllKeyword = "all";
|
||||
|
||||
ParseResult transform::parseTransformMatchDims(OpAsmParser &parser,
|
||||
DenseI64ArrayAttr &rawDimList,
|
||||
UnitAttr &isInverted,
|
||||
UnitAttr &isAll) {
|
||||
Builder &builder = parser.getBuilder();
|
||||
if (parser.parseOptionalKeyword(kDimAllKeyword).succeeded()) {
|
||||
rawDimList = builder.getDenseI64ArrayAttr({});
|
||||
isInverted = nullptr;
|
||||
isAll = builder.getUnitAttr();
|
||||
return success();
|
||||
}
|
||||
|
||||
isAll = nullptr;
|
||||
isInverted = nullptr;
|
||||
if (parser.parseOptionalKeyword(kDimExceptKeyword).succeeded()) {
|
||||
isInverted = builder.getUnitAttr();
|
||||
}
|
||||
|
||||
if (isInverted) {
|
||||
if (parser.parseLParen().failed())
|
||||
return failure();
|
||||
}
|
||||
|
||||
SmallVector<int64_t> values;
|
||||
ParseResult listResult = parser.parseCommaSeparatedList(
|
||||
[&]() { return parser.parseInteger(values.emplace_back()); });
|
||||
if (listResult.failed())
|
||||
return failure();
|
||||
|
||||
rawDimList = builder.getDenseI64ArrayAttr(values);
|
||||
|
||||
if (isInverted) {
|
||||
if (parser.parseRParen().failed())
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void transform::printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
|
||||
DenseI64ArrayAttr rawDimList,
|
||||
UnitAttr isInverted, UnitAttr isAll) {
|
||||
if (isAll) {
|
||||
printer << kDimAllKeyword;
|
||||
return;
|
||||
}
|
||||
if (isInverted) {
|
||||
printer << kDimExceptKeyword << "(";
|
||||
}
|
||||
llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(),
|
||||
[&](int64_t value) { printer << value; });
|
||||
if (isInverted) {
|
||||
printer << ")";
|
||||
}
|
||||
}
|
||||
|
||||
LogicalResult transform::verifyTransformMatchDimsOp(Operation *op,
|
||||
ArrayRef<int64_t> raw,
|
||||
bool inverted, bool all) {
|
||||
if (all) {
|
||||
if (inverted) {
|
||||
return op->emitOpError()
|
||||
<< "cannot request both 'all' and 'inverted' values in the list";
|
||||
}
|
||||
if (!raw.empty()) {
|
||||
return op->emitOpError()
|
||||
<< "cannot both request 'all' and specific values in the list";
|
||||
}
|
||||
}
|
||||
if (!all && raw.empty()) {
|
||||
return op->emitOpError() << "must request specific values in the list if "
|
||||
"'all' is not specified";
|
||||
}
|
||||
SmallVector<int64_t> rawVector = llvm::to_vector(raw);
|
||||
auto *it = std::unique(rawVector.begin(), rawVector.end());
|
||||
if (it != rawVector.end())
|
||||
return op->emitOpError() << "expected the listed values to be unique";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
DiagnosedSilenceableFailure transform::expandTargetSpecification(
|
||||
Location loc, bool isAll, bool isInverted, ArrayRef<int64_t> rawList,
|
||||
int64_t maxNumber, SmallVectorImpl<int64_t> &result) {
|
||||
assert(maxNumber > 0 && "expected size to be positive");
|
||||
assert(!(isAll && isInverted) && "cannot invert all");
|
||||
if (isAll) {
|
||||
result = llvm::to_vector(llvm::seq<int64_t>(0, maxNumber));
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
SmallVector<int64_t> expanded;
|
||||
llvm::SmallDenseSet<int64_t> visited;
|
||||
expanded.reserve(rawList.size());
|
||||
SmallVectorImpl<int64_t> &target = isInverted ? expanded : result;
|
||||
for (int64_t raw : rawList) {
|
||||
int64_t updated = raw < 0 ? maxNumber + raw : raw;
|
||||
if (updated >= maxNumber) {
|
||||
return emitSilenceableFailure(loc)
|
||||
<< "position overflow " << updated << " (updated from " << raw
|
||||
<< ") for maximum " << maxNumber;
|
||||
}
|
||||
if (updated < 0) {
|
||||
return emitSilenceableFailure(loc) << "position underflow " << updated
|
||||
<< " (updated from " << raw << ")";
|
||||
}
|
||||
if (!visited.insert(updated).second) {
|
||||
return emitSilenceableFailure(loc) << "repeated position " << updated
|
||||
<< " (updated from " << raw << ")";
|
||||
}
|
||||
target.push_back(updated);
|
||||
}
|
||||
|
||||
if (!isInverted)
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
|
||||
result.reserve(result.size() + (maxNumber - expanded.size()));
|
||||
for (int64_t candidate : llvm::seq<int64_t>(0, maxNumber)) {
|
||||
if (llvm::is_contained(expanded, candidate))
|
||||
continue;
|
||||
result.push_back(candidate);
|
||||
}
|
||||
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Generated interface implementation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1464,6 +1464,39 @@ transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GetOperandOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
|
||||
transform::TransformResults &results,
|
||||
transform::TransformState &state) {
|
||||
SmallVector<Value> operands;
|
||||
for (Operation *target : state.getPayloadOps(getTarget())) {
|
||||
SmallVector<int64_t> operandPositions;
|
||||
DiagnosedSilenceableFailure diag = expandTargetSpecification(
|
||||
getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
|
||||
target->getNumOperands(), operandPositions);
|
||||
if (diag.isSilenceableFailure()) {
|
||||
diag.attachNote(target->getLoc())
|
||||
<< "while considering positions of this payload operation";
|
||||
return diag;
|
||||
}
|
||||
llvm::append_range(operands,
|
||||
llvm::map_range(operandPositions, [&](int64_t pos) {
|
||||
return target->getOperand(pos);
|
||||
}));
|
||||
}
|
||||
results.setValues(cast<OpResult>(getResult()), operands);
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
LogicalResult transform::GetOperandOp::verify() {
|
||||
return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
|
||||
getIsInverted(), getIsAll());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GetResultOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1472,21 +1505,31 @@ DiagnosedSilenceableFailure
|
||||
transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
|
||||
transform::TransformResults &results,
|
||||
transform::TransformState &state) {
|
||||
int64_t resultNumber = getResultNumber();
|
||||
SmallVector<Value> opResults;
|
||||
for (Operation *target : state.getPayloadOps(getTarget())) {
|
||||
if (resultNumber >= target->getNumResults()) {
|
||||
DiagnosedSilenceableFailure diag =
|
||||
emitSilenceableError() << "targeted op does not have enough results";
|
||||
diag.attachNote(target->getLoc()) << "target op";
|
||||
SmallVector<int64_t> resultPositions;
|
||||
DiagnosedSilenceableFailure diag = expandTargetSpecification(
|
||||
getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
|
||||
target->getNumResults(), resultPositions);
|
||||
if (diag.isSilenceableFailure()) {
|
||||
diag.attachNote(target->getLoc())
|
||||
<< "while considering positions of this payload operation";
|
||||
return diag;
|
||||
}
|
||||
opResults.push_back(target->getOpResult(resultNumber));
|
||||
llvm::append_range(opResults,
|
||||
llvm::map_range(resultPositions, [&](int64_t pos) {
|
||||
return target->getResult(pos);
|
||||
}));
|
||||
}
|
||||
results.setValues(llvm::cast<OpResult>(getResult()), opResults);
|
||||
results.setValues(cast<OpResult>(getResult()), opResults);
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
LogicalResult transform::GetResultOp::verify() {
|
||||
return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
|
||||
getIsInverted(), getIsAll());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GetTypeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -43,7 +43,6 @@ class Handle(ir.Value):
|
||||
self.parent = parent
|
||||
self.children = children if children is not None else []
|
||||
|
||||
|
||||
@ir.register_value_caster(AnyOpType.get_static_typeid())
|
||||
@ir.register_value_caster(OperationType.get_static_typeid())
|
||||
class OpHandle(Handle):
|
||||
@ -61,16 +60,16 @@ class OpHandle(Handle):
|
||||
):
|
||||
super().__init__(v, parent=parent, children=children)
|
||||
|
||||
def get_result(self, idx: int = 0) -> "ValueHandle":
|
||||
def get_result(self, indices: Sequence[int] = [0]) -> "ValueHandle":
|
||||
"""
|
||||
Emits a `transform.GetResultOp`.
|
||||
Returns a handle to the result of the payload operation at the given
|
||||
index.
|
||||
indices.
|
||||
"""
|
||||
get_result_op = transform.GetResultOp(
|
||||
AnyValueType.get(),
|
||||
self,
|
||||
idx,
|
||||
indices,
|
||||
)
|
||||
return get_result_op.result
|
||||
|
||||
|
@ -1483,6 +1483,78 @@ module attributes {transform.with_named_sequence} {
|
||||
|
||||
// -----
|
||||
|
||||
// expected-remark @below {{addi operand}}
|
||||
// expected-note @below {{value handle points to a block argument #0}}
|
||||
func.func @get_operand_of_op(%arg0: index, %arg1: index) -> index {
|
||||
%r = arith.addi %arg0, %arg1 : index
|
||||
return %r : index
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
|
||||
%addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
%operand = transform.get_operand %addi[0] : (!transform.any_op) -> !transform.any_value
|
||||
transform.debug.emit_remark_at %operand, "addi operand" : !transform.any_value
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @get_out_of_bounds_operand_of_op(%arg0: index, %arg1: index) -> index {
|
||||
// expected-note @below {{while considering positions of this payload operation}}
|
||||
%r = arith.addi %arg0, %arg1 : index
|
||||
return %r : index
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
|
||||
%addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
// expected-error @below {{position overflow 2 (updated from 2) for maximum 2}}
|
||||
%operand = transform.get_operand %addi[2] : (!transform.any_op) -> !transform.any_value
|
||||
transform.debug.emit_remark_at %operand, "addi operand" : !transform.any_value
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-remark @below {{addi operand}}
|
||||
// expected-note @below {{value handle points to a block argument #1}}
|
||||
func.func @get_inverted_operand_of_op(%arg0: index, %arg1: index) -> index {
|
||||
%r = arith.addi %arg0, %arg1 : index
|
||||
return %r : index
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
|
||||
%addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
%operand = transform.get_operand %addi[except(0)] : (!transform.any_op) -> !transform.any_value
|
||||
transform.debug.emit_remark_at %operand, "addi operand" : !transform.any_value
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @get_multiple_operands_of_op(%arg0: index, %arg1: index) -> index {
|
||||
%r = arith.addi %arg0, %arg1 : index
|
||||
return %r : index
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
|
||||
%addui = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
%operands = transform.get_operand %addui[all] : (!transform.any_op) -> !transform.any_value
|
||||
%p = transform.num_associations %operands : (!transform.any_value) -> !transform.param<i64>
|
||||
// expected-remark @below {{2}}
|
||||
transform.debug.emit_param_as_remark %p : !transform.param<i64>
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @get_result_of_op(%arg0: index, %arg1: index) -> index {
|
||||
// expected-remark @below {{addi result}}
|
||||
// expected-note @below {{value handle points to an op result #0}}
|
||||
@ -1502,7 +1574,7 @@ module attributes {transform.with_named_sequence} {
|
||||
// -----
|
||||
|
||||
func.func @get_out_of_bounds_result_of_op(%arg0: index, %arg1: index) -> index {
|
||||
// expected-note @below {{target op}}
|
||||
// expected-note @below {{while considering positions of this payload operation}}
|
||||
%r = arith.addi %arg0, %arg1 : index
|
||||
return %r : index
|
||||
}
|
||||
@ -1510,7 +1582,7 @@ func.func @get_out_of_bounds_result_of_op(%arg0: index, %arg1: index) -> index {
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
|
||||
%addi = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
// expected-error @below {{targeted op does not have enough results}}
|
||||
// expected-error @below {{position overflow 1 (updated from 1) for maximum 1}}
|
||||
%result = transform.get_result %addi[1] : (!transform.any_op) -> !transform.any_value
|
||||
transform.debug.emit_remark_at %result, "addi result" : !transform.any_value
|
||||
transform.yield
|
||||
@ -1537,6 +1609,24 @@ module attributes {transform.with_named_sequence} {
|
||||
|
||||
// -----
|
||||
|
||||
func.func @get_multiple_result_of_op(%arg0: index, %arg1: index) -> (index, i1) {
|
||||
%r, %b = arith.addui_extended %arg0, %arg1 : index, i1
|
||||
return %r, %b : index, i1
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
|
||||
%addui = transform.structured.match ops{["arith.addui_extended"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
%results = transform.get_result %addui[all] : (!transform.any_op) -> !transform.any_value
|
||||
%p = transform.num_associations %results : (!transform.any_value) -> !transform.param<i64>
|
||||
// expected-remark @below {{2}}
|
||||
transform.debug.emit_param_as_remark %p : !transform.param<i64>
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-note @below {{target value}}
|
||||
func.func @get_result_of_op_bbarg(%arg0: index, %arg1: index) -> index {
|
||||
%r = arith.addi %arg0, %arg1 : index
|
||||
|
Loading…
Reference in New Issue
Block a user