mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-11-30 09:01:19 +00:00
[mlir] introduce transform.collect_matching (#76724)
Introduce a new match combinator into the transform dialect. This operation collects all operations that are yielded by a satisfactory match into its results. This is a simpler version of `foreach_match` that can be inserted directly into existing transform scripts.
This commit is contained in:
parent
4f7c402d9f
commit
633d9184f5
@ -460,6 +460,39 @@ def NumAssociationsOp : TransformDialectOp<"num_associations",
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def CollectMatchingOp : TransformDialectOp<"collect_matching", [
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
|
||||
DeclareOpInterfaceMethods<TransformOpInterface>]> {
|
||||
let summary = "Collects all payload ops that match the given named matcher";
|
||||
let description = [{
|
||||
Collects operations or other payload IR objects nested under `root`
|
||||
(inclusive) that match the given matcher expressed as a named sequence. The
|
||||
matcher sequence must accept exactly one argument that it is not allowed to
|
||||
modify. It must yield as many values as this op has results. Each of the
|
||||
yielded values must be associated with exactly one payload object. If any
|
||||
operation in the matcher sequence produces a silenceable failure, the
|
||||
matcher advances to the next payload operation in the walk order without
|
||||
finishing the sequence.
|
||||
|
||||
The i-th result of this operation is constructed by concatenating the i-th
|
||||
yielded payload IR objects of all successful matcher sequence applications.
|
||||
All results are guaranteed to be mapped to the same number of payload IR
|
||||
objects.
|
||||
|
||||
The operation succeeds unless the matcher sequence produced a definite
|
||||
failure for any invocation.
|
||||
}];
|
||||
|
||||
let arguments = (ins TransformHandleTypeInterface:$root,
|
||||
SymbolRefAttr:$matcher);
|
||||
let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$matcher `in` $root attr-dict `:` functional-type($root, $results)
|
||||
}];
|
||||
}
|
||||
|
||||
def ForeachMatchOp : TransformDialectOp<"foreach_match", [
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
|
||||
@ -674,7 +707,7 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
|
||||
|
||||
def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
|
||||
[DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
|
||||
NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> {
|
||||
let summary = "Get handle to the producer of this operation's operand number";
|
||||
let description = [{
|
||||
The handle defined by this Transform op corresponds to operation that
|
||||
|
@ -22,6 +22,7 @@
|
||||
#include "mlir/IR/Verifier.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Interfaces/FunctionImplementation.h"
|
||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
@ -783,7 +784,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ForeachMatchOp
|
||||
// CollectMatchingOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Applies matcher operations from the given `block` assigning `op` as the
|
||||
@ -822,6 +823,137 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state,
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
/// Returns `true` if both types implement one of the interfaces provided as
|
||||
/// template parameters.
|
||||
template <typename... Tys>
|
||||
static bool implementSameInterface(Type t1, Type t2) {
|
||||
return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
|
||||
}
|
||||
|
||||
/// Returns `true` if both types implement one of the transform dialect
|
||||
/// interfaces.
|
||||
static bool implementSameTransformInterface(Type t1, Type t2) {
|
||||
return implementSameInterface<transform::TransformHandleTypeInterface,
|
||||
transform::TransformParamTypeInterface,
|
||||
transform::TransformValueHandleTypeInterface>(
|
||||
t1, t2);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CollectMatchingOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
|
||||
transform::TransformResults &results,
|
||||
transform::TransformState &state) {
|
||||
auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
|
||||
getOperation(), getMatcher());
|
||||
if (matcher.isExternal()) {
|
||||
return emitDefiniteFailure()
|
||||
<< "unresolved external symbol " << getMatcher();
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<MappedValue>, 2> rawResults;
|
||||
rawResults.resize(getOperation()->getNumResults());
|
||||
std::optional<DiagnosedSilenceableFailure> maybeFailure;
|
||||
for (Operation *root : state.getPayloadOps(getRoot())) {
|
||||
WalkResult walkResult = root->walk([&](Operation *op) {
|
||||
DEBUG_MATCHER({
|
||||
DBGS_MATCHER() << "matching ";
|
||||
op->print(llvm::dbgs(),
|
||||
OpPrintingFlags().assumeVerified().skipRegions());
|
||||
llvm::dbgs() << " @" << op << "\n";
|
||||
});
|
||||
|
||||
// Try matching.
|
||||
SmallVector<SmallVector<MappedValue>> mappings;
|
||||
DiagnosedSilenceableFailure diag =
|
||||
matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
|
||||
if (diag.isDefiniteFailure())
|
||||
return WalkResult::interrupt();
|
||||
if (diag.isSilenceableFailure()) {
|
||||
DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
|
||||
<< " failed: " << diag.getMessage());
|
||||
return WalkResult::advance();
|
||||
}
|
||||
|
||||
// If succeeded, collect results.
|
||||
for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
|
||||
if (mapping.size() != 1) {
|
||||
maybeFailure.emplace(emitSilenceableError()
|
||||
<< "result #" << i << ", associated with "
|
||||
<< mapping.size()
|
||||
<< " payload objects, expected 1");
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
rawResults[i].push_back(mapping[0]);
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
if (walkResult.wasInterrupted())
|
||||
return std::move(*maybeFailure);
|
||||
assert(!maybeFailure && "failure set but the walk was not interrupted");
|
||||
|
||||
for (auto &&[opResult, rawResult] :
|
||||
llvm::zip_equal(getOperation()->getResults(), rawResults)) {
|
||||
results.setMappedValues(opResult, rawResult);
|
||||
}
|
||||
}
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
void transform::CollectMatchingOp::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
onlyReadsHandle(getRoot(), effects);
|
||||
producesHandle(getResults(), effects);
|
||||
onlyReadsPayload(effects);
|
||||
}
|
||||
|
||||
LogicalResult transform::CollectMatchingOp::verifySymbolUses(
|
||||
SymbolTableCollection &symbolTable) {
|
||||
auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
|
||||
symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
|
||||
if (!matcherSymbol ||
|
||||
!isa<TransformOpInterface>(matcherSymbol.getOperation()))
|
||||
return emitError() << "unresolved matcher symbol " << getMatcher();
|
||||
|
||||
ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
|
||||
if (argumentTypes.size() != 1 ||
|
||||
!isa<TransformHandleTypeInterface>(argumentTypes[0])) {
|
||||
return emitError()
|
||||
<< "expected the matcher to take one operation handle argument";
|
||||
}
|
||||
if (!matcherSymbol.getArgAttr(
|
||||
0, transform::TransformDialect::kArgReadOnlyAttrName)) {
|
||||
return emitError() << "expected the matcher argument to be marked readonly";
|
||||
}
|
||||
|
||||
ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
|
||||
if (resultTypes.size() != getOperation()->getNumResults()) {
|
||||
return emitError()
|
||||
<< "expected the matcher to yield as many values as op has results ("
|
||||
<< getOperation()->getNumResults() << "), got "
|
||||
<< resultTypes.size();
|
||||
}
|
||||
|
||||
for (auto &&[i, matcherType, resultType] :
|
||||
llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
|
||||
if (implementSameTransformInterface(matcherType, resultType))
|
||||
continue;
|
||||
|
||||
return emitError()
|
||||
<< "mismatching type interfaces for matcher result and op result #"
|
||||
<< i;
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ForeachMatchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
|
||||
transform::TransformResults &results,
|
||||
@ -978,22 +1110,6 @@ LogicalResult transform::ForeachMatchOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Returns `true` if both types implement one of the interfaces provided as
|
||||
/// template parameters.
|
||||
template <typename... Tys>
|
||||
static bool implementSameInterface(Type t1, Type t2) {
|
||||
return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
|
||||
}
|
||||
|
||||
/// Returns `true` if both types implement one of the transform dialect
|
||||
/// interfaces.
|
||||
static bool implementSameTransformInterface(Type t1, Type t2) {
|
||||
return implementSameInterface<transform::TransformHandleTypeInterface,
|
||||
transform::TransformParamTypeInterface,
|
||||
transform::TransformValueHandleTypeInterface>(
|
||||
t1, t2);
|
||||
}
|
||||
|
||||
/// Checks that the attributes of the function-like operation have correct
|
||||
/// consumption effect annotations. If `alsoVerifyInternal`, checks for
|
||||
/// annotations being present even if they can be inferred from the body.
|
||||
|
@ -704,3 +704,71 @@ transform.sequence failures(propagate) {
|
||||
// expected-error @below {{expected the type of the parameter attribute ('i64') to match the parameter type ('i32')}}
|
||||
transform.num_associations %arg0 : (!transform.any_op) -> !transform.param<i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes { transform.with_named_sequence } {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
|
||||
// expected-error @below {{unresolved matcher symbol @missing_symbol}}
|
||||
transform.collect_matching @missing_symbol in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes { transform.with_named_sequence } {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
|
||||
// expected-error @below {{expected the matcher to take one operation handle argument}}
|
||||
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
|
||||
transform.named_sequence @matcher() {
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
||||
module attributes { transform.with_named_sequence } {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
|
||||
// expected-error @below {{expected the matcher argument to be marked readonly}}
|
||||
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
|
||||
transform.named_sequence @matcher(%arg0: !transform.any_op) {
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
module attributes { transform.with_named_sequence } {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
|
||||
// expected-error @below {{expected the matcher to yield as many values as op has results (1), got 0}}
|
||||
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
|
||||
transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) {
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes { transform.with_named_sequence } {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
|
||||
// expected-error @below {{mismatching type interfaces for matcher result and op result #0}}
|
||||
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_value
|
||||
transform.yield
|
||||
}
|
||||
|
||||
transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
|
||||
transform.yield %arg0 : !transform.any_op
|
||||
}
|
||||
}
|
||||
|
@ -2380,3 +2380,47 @@ module @named_inclusion attributes { transform.with_named_sequence } {
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes { transform.with_named_sequence } {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
|
||||
// expected-error @below {{result #0, associated with 2 payload objects, expected 1}}
|
||||
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
|
||||
transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
|
||||
%0 = transform.merge_handles %arg0, %arg0 : !transform.any_op
|
||||
transform.yield %0 : !transform.any_op
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes { transform.with_named_sequence } {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
|
||||
// expected-error @below {{unresolved external symbol @matcher}}
|
||||
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
|
||||
transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes { transform.with_named_sequence } {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
|
||||
// expected-remark @below {{matched}}
|
||||
%0 = transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
// expected-remark @below {{matched}}
|
||||
transform.test_print_remark_at_operand %0, "matched" : !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
|
||||
transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
|
||||
transform.match.operation_name %arg0 ["transform.test_print_remark_at_operand", "transform.collect_matching"] : !transform.any_op
|
||||
transform.yield %arg0 : !transform.any_op
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user