[mlir] introduce transform.collect_matching (#76724)

Introduce a new match combinator into the transform dialect. This
operation collects all operations that are yielded by a satisfactory
match into its results. This is a simpler version of `foreach_match`
that can be inserted directly into existing transform scripts.
This commit is contained in:
Oleksandr "Alex" Zinenko 2024-01-09 13:18:57 +01:00 committed by GitHub
parent 4f7c402d9f
commit 633d9184f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 279 additions and 18 deletions

View File

@ -460,6 +460,39 @@ def NumAssociationsOp : TransformDialectOp<"num_associations",
let hasVerifier = 1;
}
def CollectMatchingOp : TransformDialectOp<"collect_matching", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let summary = "Collects all payload ops that match the given named matcher";
let description = [{
Collects operations or other payload IR objects nested under `root`
(inclusive) that match the given matcher expressed as a named sequence. The
matcher sequence must accept exactly one argument that it is not allowed to
modify. It must yield as many values as this op has results. Each of the
yielded values must be associated with exactly one payload object. If any
operation in the matcher sequence produces a silenceable failure, the
matcher advances to the next payload operation in the walk order without
finishing the sequence.
The i-th result of this operation is constructed by concatenating the i-th
yielded payload IR objects of all successful matcher sequence applications.
All results are guaranteed to be mapped to the same number of payload IR
objects.
The operation succeeds unless the matcher sequence produced a definite
failure for any invocation.
}];
let arguments = (ins TransformHandleTypeInterface:$root,
SymbolRefAttr:$matcher);
let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
let assemblyFormat = [{
$matcher `in` $root attr-dict `:` functional-type($root, $results)
}];
}
def ForeachMatchOp : TransformDialectOp<"foreach_match", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
@ -674,7 +707,7 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
[DeclareOpInterfaceMethods<TransformOpInterface>,
NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> {
let summary = "Get handle to the producer of this operation's operand number";
let description = [{
The handle defined by this Transform op corresponds to operation that

View File

@ -22,6 +22,7 @@
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
@ -783,7 +784,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
}
//===----------------------------------------------------------------------===//
// ForeachMatchOp
// CollectMatchingOp
//===----------------------------------------------------------------------===//
/// Applies matcher operations from the given `block` assigning `op` as the
@ -822,6 +823,137 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state,
return DiagnosedSilenceableFailure::success();
}
/// Returns `true` if both types implement one of the interfaces provided as
/// template parameters.
template <typename... Tys>
static bool implementSameInterface(Type t1, Type t2) {
return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
}
/// Returns `true` if both types implement one of the transform dialect
/// interfaces.
static bool implementSameTransformInterface(Type t1, Type t2) {
return implementSameInterface<transform::TransformHandleTypeInterface,
transform::TransformParamTypeInterface,
transform::TransformValueHandleTypeInterface>(
t1, t2);
}
//===----------------------------------------------------------------------===//
// CollectMatchingOp
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
getOperation(), getMatcher());
if (matcher.isExternal()) {
return emitDefiniteFailure()
<< "unresolved external symbol " << getMatcher();
}
SmallVector<SmallVector<MappedValue>, 2> rawResults;
rawResults.resize(getOperation()->getNumResults());
std::optional<DiagnosedSilenceableFailure> maybeFailure;
for (Operation *root : state.getPayloadOps(getRoot())) {
WalkResult walkResult = root->walk([&](Operation *op) {
DEBUG_MATCHER({
DBGS_MATCHER() << "matching ";
op->print(llvm::dbgs(),
OpPrintingFlags().assumeVerified().skipRegions());
llvm::dbgs() << " @" << op << "\n";
});
// Try matching.
SmallVector<SmallVector<MappedValue>> mappings;
DiagnosedSilenceableFailure diag =
matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
if (diag.isDefiniteFailure())
return WalkResult::interrupt();
if (diag.isSilenceableFailure()) {
DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
<< " failed: " << diag.getMessage());
return WalkResult::advance();
}
// If succeeded, collect results.
for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
if (mapping.size() != 1) {
maybeFailure.emplace(emitSilenceableError()
<< "result #" << i << ", associated with "
<< mapping.size()
<< " payload objects, expected 1");
return WalkResult::interrupt();
}
rawResults[i].push_back(mapping[0]);
}
return WalkResult::advance();
});
if (walkResult.wasInterrupted())
return std::move(*maybeFailure);
assert(!maybeFailure && "failure set but the walk was not interrupted");
for (auto &&[opResult, rawResult] :
llvm::zip_equal(getOperation()->getResults(), rawResults)) {
results.setMappedValues(opResult, rawResult);
}
}
return DiagnosedSilenceableFailure::success();
}
void transform::CollectMatchingOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getRoot(), effects);
producesHandle(getResults(), effects);
onlyReadsPayload(effects);
}
LogicalResult transform::CollectMatchingOp::verifySymbolUses(
SymbolTableCollection &symbolTable) {
auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
if (!matcherSymbol ||
!isa<TransformOpInterface>(matcherSymbol.getOperation()))
return emitError() << "unresolved matcher symbol " << getMatcher();
ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
if (argumentTypes.size() != 1 ||
!isa<TransformHandleTypeInterface>(argumentTypes[0])) {
return emitError()
<< "expected the matcher to take one operation handle argument";
}
if (!matcherSymbol.getArgAttr(
0, transform::TransformDialect::kArgReadOnlyAttrName)) {
return emitError() << "expected the matcher argument to be marked readonly";
}
ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
if (resultTypes.size() != getOperation()->getNumResults()) {
return emitError()
<< "expected the matcher to yield as many values as op has results ("
<< getOperation()->getNumResults() << "), got "
<< resultTypes.size();
}
for (auto &&[i, matcherType, resultType] :
llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
if (implementSameTransformInterface(matcherType, resultType))
continue;
return emitError()
<< "mismatching type interfaces for matcher result and op result #"
<< i;
}
return success();
}
//===----------------------------------------------------------------------===//
// ForeachMatchOp
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
@ -978,22 +1110,6 @@ LogicalResult transform::ForeachMatchOp::verify() {
return success();
}
/// Returns `true` if both types implement one of the interfaces provided as
/// template parameters.
template <typename... Tys>
static bool implementSameInterface(Type t1, Type t2) {
return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
}
/// Returns `true` if both types implement one of the transform dialect
/// interfaces.
static bool implementSameTransformInterface(Type t1, Type t2) {
return implementSameInterface<transform::TransformHandleTypeInterface,
transform::TransformParamTypeInterface,
transform::TransformValueHandleTypeInterface>(
t1, t2);
}
/// Checks that the attributes of the function-like operation have correct
/// consumption effect annotations. If `alsoVerifyInternal`, checks for
/// annotations being present even if they can be inferred from the body.

View File

@ -704,3 +704,71 @@ transform.sequence failures(propagate) {
// expected-error @below {{expected the type of the parameter attribute ('i64') to match the parameter type ('i32')}}
transform.num_associations %arg0 : (!transform.any_op) -> !transform.param<i32>
}
// -----
module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-error @below {{unresolved matcher symbol @missing_symbol}}
transform.collect_matching @missing_symbol in %arg0 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
// -----
module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-error @below {{expected the matcher to take one operation handle argument}}
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
transform.yield
}
transform.named_sequence @matcher() {
transform.yield
}
}
// -----
module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-error @below {{expected the matcher argument to be marked readonly}}
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
transform.yield
}
transform.named_sequence @matcher(%arg0: !transform.any_op) {
transform.yield
}
}
// -----
module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-error @below {{expected the matcher to yield as many values as op has results (1), got 0}}
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
transform.yield
}
transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) {
transform.yield
}
}
// -----
module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-error @below {{mismatching type interfaces for matcher result and op result #0}}
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_value
transform.yield
}
transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
transform.yield %arg0 : !transform.any_op
}
}

View File

@ -2380,3 +2380,47 @@ module @named_inclusion attributes { transform.with_named_sequence } {
transform.yield
}
}
// -----
module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-error @below {{result #0, associated with 2 payload objects, expected 1}}
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
transform.yield
}
transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
%0 = transform.merge_handles %arg0, %arg0 : !transform.any_op
transform.yield %0 : !transform.any_op
}
}
// -----
module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-error @below {{unresolved external symbol @matcher}}
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
transform.yield
}
transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
}
// -----
module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
// expected-remark @below {{matched}}
%0 = transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
// expected-remark @below {{matched}}
transform.test_print_remark_at_operand %0, "matched" : !transform.any_op
transform.yield
}
transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
transform.match.operation_name %arg0 ["transform.test_print_remark_at_operand", "transform.collect_matching"] : !transform.any_op
transform.yield %arg0 : !transform.any_op
}
}