[mlir] move PDL-related transform ops into an extension

The initial bring-up of the Transform dialect relied on PDL to provide
the default handle type (`!pdl.operation`) and the matching capability.
Both are now provided natively by the Transform dialect removing the
reason to have a hard dependency on the PDL dialect and its interpreter.
Move PDL-related transform operations into a separate extension.

This requires us to introduce a dialect state extension mechanism into
the Transform dialect so it no longer needs to know about PDL constraint
functions that may be injected by extensions similarly to operations and
types. This mechanism will be reused to connect pattern application
drivers and the Transform dialect.

This completes the restructuring of the Transform dialect to remove
overrilance on PDL.

Note to downstreams: flow that are using `!pdl.operation` with Transform
dialect operations will now require `transform::PDLExtension` to be
applied to the transform dialect in order to provide the transform
handle type interface for `!pdl.operation`.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D151104
This commit is contained in:
Alex Zinenko 2023-05-22 14:36:58 +00:00
parent 3590945a11
commit 94d608d410
33 changed files with 929 additions and 517 deletions

View File

@ -1,2 +1,3 @@
add_subdirectory(IR)
add_subdirectory(PDLExtension)
add_subdirectory(Transforms)

View File

@ -12,12 +12,52 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringMap.h"
#include <optional>
namespace mlir {
namespace transform {
namespace detail {
/// Concrete base class for CRTP TransformDialectDataBase. Must not be used
/// directly.
class TransformDialectDataBase {
public:
virtual ~TransformDialectDataBase() = default;
/// Returns the dynamic type ID of the subclass.
TypeID getTypeID() const { return typeID; }
protected:
/// Must be called by the subclass with the appropriate type ID.
explicit TransformDialectDataBase(TypeID typeID) : typeID(typeID) {}
private:
/// The type ID of the subclass.
const TypeID typeID;
};
} // namespace detail
/// Base class for additional data owned by the Transform dialect. Extensions
/// may communicate with each other using this data. The data object is
/// identified by the TypeID of the specific data subclass, querying the data of
/// the same subclass returns a reference to the same object. When a Transform
/// dialect extension is initialized, it can populate the data in the specific
/// subclass. When a Transform op is applied, it can read (but not mutate) the
/// data in the specific subclass, including the data provided by other
/// extensions.
///
/// This follows CRTP: derived classes must list themselves as template
/// argument.
template <typename DerivedTy>
class TransformDialectData : public detail::TransformDialectDataBase {
protected:
/// Forward the TypeID of the derived class to the base.
TransformDialectData() : TransformDialectDataBase(TypeID::get<DerivedTy>()) {}
};
#ifndef NDEBUG
namespace detail {
/// Asserts that the operations provided as template arguments implement the
@ -85,9 +125,8 @@ public:
for (const DialectLoader &loader : generatedDialectLoaders)
loader(context);
for (const Initializer &init : opInitializers)
for (const Initializer &init : initializers)
init(transformDialect);
transformDialect->mergeInPDLMatchHooks(std::move(pdlMatchConstraintFns));
}
protected:
@ -100,6 +139,41 @@ protected:
static_cast<DerivedTy *>(this)->init();
}
/// Registers a custom initialization step to be performed when the extension
/// is applied to the dialect while loading. This is discouraged in favor of
/// more specific calls `declareGeneratedDialect`, `addDialectDataInitializer`
/// etc. `Func` must be convertible to the `void (MLIRContext *)` form. It
/// will be called during the extension initialization and given the current
/// MLIR context. This may be used to attach additional interfaces that cannot
/// be attached elsewhere.
template <typename Func>
void addCustomInitializationStep(Func &&func) {
std::function<void(MLIRContext *)> initializer = func;
dialectLoaders.push_back(
[init = std::move(initializer)](MLIRContext *ctx) { init(ctx); });
}
/// Registers the given function as one of the initializers for the
/// dialect-owned data of the kind specified as template argument. The
/// function must be convertible to the `void (DataTy &)` form. It will be
/// called during the extension initialization and will be given a mutable
/// reference to `DataTy`. The callback is expected to append data to the
/// given storage, and is not allowed to remove or destructively mutate the
/// existing data. The order in which callbacks from different extensions are
/// executed is unspecified so the callbacks may not rely on data being
/// already present. `DataTy` must be a class deriving `TransformDialectData`.
template <typename DataTy, typename Func>
void addDialectDataInitializer(Func &&func) {
static_assert(std::is_base_of_v<detail::TransformDialectDataBase, DataTy>,
"only classes deriving TransformDialectData are accepted");
std::function<void(DataTy &)> initializer = func;
initializers.push_back(
[init = std::move(initializer)](TransformDialect *transformDialect) {
init(transformDialect->getOrCreateExtraData<DataTy>());
});
}
/// Hook for derived classes to inject constructor behavior.
void init() {}
@ -108,7 +182,7 @@ protected:
/// implementations must be already available when the operation is injected.
template <typename... OpTys>
void registerTransformOps() {
opInitializers.push_back([](TransformDialect *transformDialect) {
initializers.push_back([](TransformDialect *transformDialect) {
transformDialect->addOperationsChecked<OpTys...>();
});
}
@ -120,7 +194,7 @@ protected:
/// `StringRef` that is unique across all injected types.
template <typename... TypeTys>
void registerTypes() {
opInitializers.push_back([](TransformDialect *transformDialect) {
initializers.push_back([](TransformDialect *transformDialect) {
transformDialect->addTypesChecked<TypeTys...>();
});
}
@ -151,22 +225,10 @@ protected:
[](MLIRContext *context) { context->loadDialect<DialectTy>(); });
}
/// Injects the named constraint to make it available for use with the
/// PDLMatchOp in the transform dialect.
void registerPDLMatchConstraintFn(StringRef name,
PDLConstraintFunction &&fn) {
pdlMatchConstraintFns.try_emplace(name,
std::forward<PDLConstraintFunction>(fn));
}
template <typename ConstraintFnTy>
void registerPDLMatchConstraintFn(StringRef name, ConstraintFnTy &&fn) {
pdlMatchConstraintFns.try_emplace(
name, ::mlir::detail::pdl_function_builder::buildConstraintFn(
std::forward<ConstraintFnTy>(fn)));
}
private:
SmallVector<Initializer> opInitializers;
/// Callbacks performing extension initialization, e.g., registering ops,
/// types and defining the additional data.
SmallVector<Initializer> initializers;
/// Callbacks loading the dependent dialects, i.e. the dialect needed for the
/// extension ops.
@ -176,13 +238,6 @@ private:
/// applying the transformations.
SmallVector<DialectLoader> generatedDialectLoaders;
/// A list of constraints that should be made available to PDL patterns
/// processed by PDLMatchOp in the Transform dialect.
///
/// Declared as mutable so its contents can be moved in the `apply` const
/// method, which is only called once.
mutable llvm::StringMap<PDLConstraintFunction> pdlMatchConstraintFns;
/// Indicates that the extension is in build-only mode.
bool buildOnly;
};
@ -232,6 +287,17 @@ void TransformDialect::addTypeIfNotRegistered() {
#endif // NDEBUG
}
template <typename DataTy>
DataTy &TransformDialect::getOrCreateExtraData() {
TypeID typeID = TypeID::get<DataTy>();
auto it = extraData.find(typeID);
if (it != extraData.end())
return static_cast<DataTy &>(*it->getSecond());
auto emplaced = extraData.try_emplace(typeID, std::make_unique<DataTy>());
return static_cast<DataTy &>(*emplaced.first->getSecond());
}
/// A wrapper for transform dialect extensions that forces them to be
/// constructed in the build-only mode.
template <typename DerivedTy>

View File

@ -18,36 +18,31 @@ def Transform_Dialect : Dialect {
let name = "transform";
let cppNamespace = "::mlir::transform";
let dependentDialects = [
"::mlir::pdl::PDLDialect",
"::mlir::pdl_interp::PDLInterpDialect",
];
let hasOperationAttrVerify = 1;
let usePropertiesForAttributes = 1;
let extraClassDeclaration = [{
/// Name of the attribute attachable to the symbol table operation
/// containing named sequences. This is used to trigger verification.
constexpr const static llvm::StringLiteral
constexpr const static ::llvm::StringLiteral
kWithNamedSequenceAttrName = "transform.with_named_sequence";
/// Names of the attribute attachable to an operation so it can be
/// identified as root by the default interpreter pass.
constexpr const static llvm::StringLiteral
constexpr const static ::llvm::StringLiteral
kTargetTagAttrName = "transform.target_tag";
/// Names of the attributes indicating whether an argument of an external
/// transform dialect symbol is consumed or only read.
constexpr const static llvm::StringLiteral
constexpr const static ::llvm::StringLiteral
kArgConsumedAttrName = "transform.consumed";
constexpr const static llvm::StringLiteral
constexpr const static ::llvm::StringLiteral
kArgReadOnlyAttrName = "transform.readonly";
/// Returns the named PDL constraint functions available in the dialect
/// as a map from their name to the function.
const ::llvm::StringMap<::mlir::PDLConstraintFunction> &
getPDLConstraintHooks() const;
template <typename DataTy>
const DataTy &getExtraData() const {
return *static_cast<const DataTy *>(extraData.at(::mlir::TypeID::get<DataTy>()).get());
}
/// Parses a type registered by this dialect or one of its extensions.
::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;
@ -92,23 +87,27 @@ def Transform_Dialect : Dialect {
/// mnemonic.
[[noreturn]] void reportDuplicateTypeRegistration(StringRef mnemonic);
/// Registers dialect types with the context.
void initializeTypes();
// Give extensions access to injection functions.
template <typename, typename...>
friend class TransformDialectExtension;
/// Takes ownership of the named PDL constraint function from the given
/// map and makes them available for use by the operations in the dialect.
void mergeInPDLMatchHooks(
::llvm::StringMap<::mlir::PDLConstraintFunction> &&constraintFns);
/// Gets a mutable reference to extra data of the kind specified as
/// template argument. Allocates the data on the first call.
template <typename DataTy>
DataTy &getOrCreateExtraData();
//===----------------------------------------------------------------===//
// Data fields
//===----------------------------------------------------------------===//
/// A container for PDL constraint function that can be used by
/// operations in this dialect.
::mlir::PDLPatternModule pdlMatchHooks;
/// Additional data associated with and owned by the dialect. Accessible
/// to extensions.
::llvm::DenseMap<::mlir::TypeID, std::unique_ptr<
::mlir::transform::detail::TransformDialectDataBase>>
extraData;
/// A map from type mnemonic to its parsing function for the remainder of
/// the syntax. The parser has access to the mnemonic, so it is used for

View File

@ -38,6 +38,14 @@ mapPossibleTopLevelTransformOpBlockArguments(TransformState &state,
/// Verification hook for PossibleTopLevelTransformOpTrait.
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);
/// Populates `effects` with side effects implied by
/// PossibleTopLevelTransformOpTrait for the given operation. The operation may
/// have an optional `root` operand, indicating it is not in fact top-level. It
/// is also expected to have a single-block body.
void getPotentialTopLevelEffects(
Operation *operation, Value root, Block &body,
SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
/// Verification hook for TransformOpInterface.
LogicalResult verifyTransformOpInterface(Operation *op);
@ -753,15 +761,16 @@ TransformState::make_isolated_region_scope(Region &region) {
/// can be standalone top-level transforms. Such operations typically contain
/// other Transform dialect operations that can be executed following some
/// control flow logic specific to the current operation. The operations with
/// this trait are expected to have at least one single-block region with one
/// argument of PDL Operation type. The operations are also expected to be valid
/// without operands, in which case they are considered top-level, and with one
/// or more arguments, in which case they are considered nested. Top-level
/// operations have the block argument of the entry block in the Transform IR
/// correspond to the root operation of Payload IR. Nested operations have the
/// block argument of the entry block in the Transform IR correspond to a list
/// of Payload IR operations mapped to the first operand of the Transform IR
/// operation. The operation must implement TransformOpInterface.
/// this trait are expected to have at least one single-block region with at
/// least one argument of type implementing TransformHandleTypeInterface. The
/// operations are also expected to be valid without operands, in which case
/// they are considered top-level, and with one or more arguments, in which case
/// they are considered nested. Top-level operations have the block argument of
/// the entry block in the Transform IR correspond to the root operation of
/// Payload IR. Nested operations have the block argument of the entry block in
/// the Transform IR correspond to a list of Payload IR operations mapped to the
/// first operand of the Transform IR operation. The operation must implement
/// TransformOpInterface.
template <typename OpTy>
class PossibleTopLevelTransformOpTrait
: public OpTrait::TraitBase<OpTy, PossibleTopLevelTransformOpTrait> {
@ -777,6 +786,14 @@ public:
return &this->getOperation()->getRegion(region).front();
}
/// Populates `effects` with side effects implied by this trait.
void getPotentialTopLevelEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
detail::getPotentialTopLevelEffects(
this->getOperation(), cast<OpTy>(this->getOperation()).getRoot(),
*getBodyBlock(), effects);
}
/// Sets up the mapping between the entry block of the given region of this op
/// and the relevant list of Payload IR operations in the given state. The
/// state is expected to be already scoped at the region of this operation.

View File

@ -9,7 +9,6 @@
#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H
#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"

View File

@ -575,37 +575,6 @@ def ParamConstantOp : Op<Transform_Dialect, "param.constant", [
let assemblyFormat = "$value attr-dict `->` type($param)";
}
def PDLMatchOp : TransformDialectOp<"pdl_match",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Finds ops that match the named PDL pattern";
let description = [{
Find Payload IR ops nested within the Payload IR op associated with the
operand that match the PDL pattern identified by its name. The pattern is
expected to be defined in the closest surrounding `WithPDLPatternsOp`.
Produces a Transform IR value associated with the list of Payload IR ops
that matched the pattern. The order of results in the list is that of the
Operation::walk, clients are advised not to rely on a specific order though.
If the operand is associated with multiple Payload IR ops, finds matching
ops nested within each of those and produces a single list containing all
of the matched ops.
The transformation is considered successful regardless of whether some
Payload IR ops actually matched the pattern and only fails if the pattern
could not be looked up or compiled.
}];
let arguments = (ins
Arg<TransformHandleTypeInterface, "Payload IR scope to match within">:$root,
SymbolRefAttr:$pattern_name);
let results = (outs
Res<TransformHandleTypeInterface, "Handle to the matched Payload IR ops">:$matched);
let assemblyFormat = "$pattern_name `in` $root attr-dict `:` "
"functional-type(operands, results)";
}
def PrintOp : TransformDialectOp<"print",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
@ -753,61 +722,6 @@ def SequenceOp : TransformDialectOp<"sequence",
let hasVerifier = 1;
}
def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns",
[DeclareOpInterfaceMethods<TransformOpInterface>, NoTerminator,
OpAsmOpInterface, PossibleTopLevelTransformOpTrait,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
SymbolTable]> {
let summary = "Contains PDL patterns available for use in transforms";
let description = [{
This op contains a set of named PDL patterns that are available for the
Transform dialect operations to be used for pattern matching. For example,
PDLMatchOp can be used to produce a Transform IR value associated with all
Payload IR operations that match the pattern as follows:
```mlir
transform.with_pdl_patterns {
^bb0(%arg0: !transform.any_op):
pdl.pattern @my_pattern : benefit(1) {
%0 = pdl.operation //...
// Regular PDL goes here.
pdl.rewrite %0 with "transform.dialect"
}
sequence %arg0 failures(propagate) {
^bb0(%arg1: !transform.any_op):
%1 = pdl_match @my_pattern in %arg1
// Use %1 as handle
}
}
```
Note that the pattern is expected to finish with a `pdl.rewrite` terminator
that points to the custom rewriter named "transform.dialect". The rewriter
actually does nothing, but the transform application will keep track of the
operations that matched the pattern.
This op is expected to contain `pdl.pattern` operations and exactly one
another Transform dialect operation that gets executed with all patterns
available. This op is a possible top-level Transform IR op, the argument of
its entry block corresponds to either the root op of the payload IR or the
ops associated with its operand when provided.
}];
let arguments = (ins
Arg<Optional<TransformHandleTypeInterface>, "Root operation of the Payload IR"
>:$root);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat = "($root^ `:` type($root))? attr-dict-with-keyword regions";
let hasVerifier = 1;
let extraClassDeclaration = [{
/// Allow the dialect prefix to be omitted.
static StringRef getDefaultDialect() { return "transform"; }
}];
}
def YieldOp : TransformDialectOp<"yield",
[Terminator, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Yields operation handles from a transform IR region";

View File

@ -0,0 +1,6 @@
set(LLVM_TARGET_DEFINITIONS PDLExtensionOps.td)
mlir_tablegen(PDLExtensionOps.h.inc -gen-op-decls)
mlir_tablegen(PDLExtensionOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRTransformDialectPDLExtensionOpsIncGen)
add_mlir_doc(PDLExtensionOps PDLExtensionOps Dialects/ -gen-op-doc)

View File

@ -0,0 +1,16 @@
//===- PDLExtension.h - PDL extension for Transform dialect -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
namespace mlir {
class DialectRegistry;
namespace transform {
/// Registers the PDL extension of the Transform dialect in the given registry.
void registerPDLExtension(DialectRegistry &dialectRegistry);
} // namespace transform
} // namespace mlir

View File

@ -0,0 +1,49 @@
//===- PDLExtensionOps.h - PDL extension for Transform dialect --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS_H
#define MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS_H
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h.inc"
namespace mlir {
namespace transform {
/// PDL constraint callbacks that can be used by the PDL extension of the
/// Transform dialect. These are owned by the Transform dialect and can be
/// populated by extensions.
class PDLMatchHooks : public TransformDialectData<PDLMatchHooks> {
public:
/// Takes ownership of the named PDL constraint function from the given
/// map and makes them available for use by the operations in the dialect.
void
mergeInPDLMatchHooks(llvm::StringMap<PDLConstraintFunction> &&constraintFns);
/// Returns the named PDL constraint functions available in the dialect
/// as a map from their name to the function.
const llvm::StringMap<::mlir::PDLConstraintFunction> &
getPDLConstraintHooks() const;
private:
/// A container for PDL constraint function that can be used by
/// operations in this dialect.
PDLPatternModule pdlMatchHooks;
};
} // namespace transform
} // namespace mlir
MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::transform::PDLMatchHooks)
#endif // MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS_H

View File

@ -0,0 +1,104 @@
//===- TransformOps.td - Transform dialect operations ------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS
#define MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SymbolInterfaces.td"
def PDLMatchOp : TransformDialectOp<"pdl_match",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Finds ops that match the named PDL pattern";
let description = [{
Find Payload IR ops nested within the Payload IR op associated with the
operand that match the PDL pattern identified by its name. The pattern is
expected to be defined in the closest surrounding `WithPDLPatternsOp`.
Produces a Transform IR value associated with the list of Payload IR ops
that matched the pattern. The order of results in the list is that of the
Operation::walk, clients are advised not to rely on a specific order though.
If the operand is associated with multiple Payload IR ops, finds matching
ops nested within each of those and produces a single list containing all
of the matched ops.
The transformation is considered successful regardless of whether some
Payload IR ops actually matched the pattern and only fails if the pattern
could not be looked up or compiled.
}];
let arguments = (ins
Arg<TransformHandleTypeInterface, "Payload IR scope to match within">:$root,
SymbolRefAttr:$pattern_name);
let results = (outs
Res<TransformHandleTypeInterface, "Handle to the matched Payload IR ops">:$matched);
let assemblyFormat = "$pattern_name `in` $root attr-dict `:` "
"functional-type(operands, results)";
}
def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns",
[DeclareOpInterfaceMethods<TransformOpInterface>, NoTerminator,
OpAsmOpInterface, PossibleTopLevelTransformOpTrait,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
SymbolTable]> {
let summary = "Contains PDL patterns available for use in transforms";
let description = [{
This op contains a set of named PDL patterns that are available for the
Transform dialect operations to be used for pattern matching. For example,
PDLMatchOp can be used to produce a Transform IR value associated with all
Payload IR operations that match the pattern as follows:
```mlir
transform.with_pdl_patterns {
^bb0(%arg0: !transform.any_op):
pdl.pattern @my_pattern : benefit(1) {
%0 = pdl.operation //...
// Regular PDL goes here.
pdl.rewrite %0 with "transform.dialect"
}
sequence %arg0 failures(propagate) {
^bb0(%arg1: !transform.any_op):
%1 = pdl_match @my_pattern in %arg1
// Use %1 as handle
}
}
```
Note that the pattern is expected to finish with a `pdl.rewrite` terminator
that points to the custom rewriter named "transform.dialect". The rewriter
actually does nothing, but the transform application will keep track of the
operations that matched the pattern.
This op is expected to contain `pdl.pattern` operations and exactly one
another Transform dialect operation that gets executed with all patterns
available. This op is a possible top-level Transform IR op, the argument of
its entry block corresponds to either the root op of the payload IR or the
ops associated with its operand when provided.
}];
let arguments = (ins
Arg<Optional<TransformHandleTypeInterface>, "Root operation of the Payload IR"
>:$root);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat = "($root^ `:` type($root))? attr-dict-with-keyword regions";
let hasVerifier = 1;
let extraClassDeclaration = [{
/// Allow the dialect prefix to be omitted.
static StringRef getDefaultDialect() { return "transform"; }
}];
}
#endif // MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS

View File

@ -76,6 +76,7 @@
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
@ -135,6 +136,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
memref::registerTransformDialectExtension(registry);
scf::registerTransformDialectExtension(registry);
tensor::registerTransformDialectExtension(registry);
transform::registerPDLExtension(registry);
vector::registerTransformDialectExtension(registry);
// Register all external models.

View File

@ -1,3 +1,4 @@
add_subdirectory(IR)
add_subdirectory(PDLExtension)
add_subdirectory(Transforms)
add_subdirectory(Utils)

View File

@ -14,8 +14,6 @@ add_mlir_dialect_library(MLIRTransformDialect
LINK_LIBS PUBLIC
MLIRIR
MLIRParser
MLIRPDLDialect
MLIRPDLInterpDialect
MLIRRewrite
MLIRSideEffectInterfaces
MLIRTransforms

View File

@ -8,8 +8,6 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Analysis/CallGraph.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
@ -51,18 +49,6 @@ void transform::detail::checkImplementsTransformHandleTypeInterface(
}
#endif // NDEBUG
namespace {
struct PDLOperationTypeTransformHandleTypeInterfaceImpl
: public transform::TransformHandleTypeInterface::ExternalModel<
PDLOperationTypeTransformHandleTypeInterfaceImpl,
pdl::OperationType> {
DiagnosedSilenceableFailure
checkPayload(Type type, Location loc, ArrayRef<Operation *> payload) const {
return DiagnosedSilenceableFailure::success();
}
};
} // namespace
void transform::TransformDialect::initialize() {
// Using the checked versions to enable the same assertions as for the ops
// from extensions.
@ -71,21 +57,6 @@ void transform::TransformDialect::initialize() {
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
>();
initializeTypes();
pdl::OperationType::attachInterface<
PDLOperationTypeTransformHandleTypeInterfaceImpl>(*getContext());
}
void transform::TransformDialect::mergeInPDLMatchHooks(
llvm::StringMap<PDLConstraintFunction> &&constraintFns) {
// Steal the constraint functions from the given map.
for (auto &it : constraintFns)
pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second));
}
const llvm::StringMap<PDLConstraintFunction> &
transform::TransformDialect::getPDLConstraintHooks() const {
return pdlMatchHooks.getConstraintFunctions();
}
Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {

View File

@ -1242,6 +1242,61 @@ void transform::detail::forwardTerminatorOperands(
// Utilities for PossibleTopLevelTransformOpTrait.
//===----------------------------------------------------------------------===//
/// Appends to `effects` the memory effect instances on `target` with the same
/// resource and effect as the ones the operation `iface` having on `source`.
static void
remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target,
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
SmallVector<MemoryEffects::EffectInstance> nestedEffects;
iface.getEffectsOnValue(source, nestedEffects);
for (const auto &effect : nestedEffects)
effects.emplace_back(effect.getEffect(), target, effect.getResource());
}
/// Appends to `effects` the same effects as the operations of `block` have on
/// block arguments but associated with `operands.`
static void
remapArgumentEffects(Block &block, ValueRange operands,
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
for (Operation &op : block) {
auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
if (!iface)
continue;
for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) {
remapEffects(iface, source, target, effects);
}
SmallVector<MemoryEffects::EffectInstance> nestedEffects;
iface.getEffectsOnResource(transform::PayloadIRResource::get(),
nestedEffects);
llvm::append_range(effects, nestedEffects);
}
}
void transform::detail::getPotentialTopLevelEffects(
Operation *operation, Value root, Block &body,
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(operation->getOperands(), effects);
transform::producesHandle(operation->getResults(), effects);
if (!root) {
for (Operation &op : body) {
auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
if (!iface)
continue;
SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
iface.getEffects(effects);
}
return;
}
// Carry over all effects on arguments of the entry block as those on the
// operands, this is the same value just remapped.
remapArgumentEffects(body, operation->getOperands(), effects);
}
LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
TransformState &state, Operation *op, Region &region) {
SmallVector<Operation *> targets;

View File

@ -7,7 +7,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/PDL/IR/PDLOps.h"
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
@ -17,8 +16,6 @@
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
@ -52,99 +49,6 @@ static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
//===----------------------------------------------------------------------===//
// PatternApplicatorExtension
//===----------------------------------------------------------------------===//
namespace {
/// A TransformState extension that keeps track of compiled PDL pattern sets.
/// This is intended to be used along the WithPDLPatterns op. The extension
/// can be constructed given an operation that has a SymbolTable trait and
/// contains pdl::PatternOp instances. The patterns are compiled lazily and one
/// by one when requested; this behavior is subject to change.
class PatternApplicatorExtension : public transform::TransformState::Extension {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
/// Creates the extension for patterns contained in `patternContainer`.
explicit PatternApplicatorExtension(transform::TransformState &state,
Operation *patternContainer)
: Extension(state), patterns(patternContainer) {}
/// Appends to `results` the operations contained in `root` that matched the
/// PDL pattern with the given name. Note that `root` may or may not be the
/// operation that contains PDL patterns. Reports an error if the pattern
/// cannot be found. Note that when no operations are matched, this still
/// succeeds as long as the pattern exists.
LogicalResult findAllMatches(StringRef patternName, Operation *root,
SmallVectorImpl<Operation *> &results);
private:
/// Map from the pattern name to a singleton set of rewrite patterns that only
/// contains the pattern with this name. Populated when the pattern is first
/// requested.
// TODO: reconsider the efficiency of this storage when more usage data is
// available. Storing individual patterns in a set and triggering compilation
// for each of them has overhead. So does compiling a large set of patterns
// only to apply a handlful of them.
llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
/// A symbol table operation containing the relevant PDL patterns.
SymbolTable patterns;
};
LogicalResult PatternApplicatorExtension::findAllMatches(
StringRef patternName, Operation *root,
SmallVectorImpl<Operation *> &results) {
auto it = compiledPatterns.find(patternName);
if (it == compiledPatterns.end()) {
auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
if (!patternOp)
return failure();
// Copy the pattern operation into a new module that is compiled and
// consumed by the PDL interpreter.
OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
auto builder = OpBuilder::atBlockEnd(pdlModuleOp->getBody());
builder.clone(*patternOp);
PDLPatternModule patternModule(std::move(pdlModuleOp));
// Merge in the hooks owned by the dialect. Make a copy as they may be
// also used by the following operations.
auto *dialect =
root->getContext()->getLoadedDialect<transform::TransformDialect>();
for (const auto &[name, constraintFn] : dialect->getPDLConstraintHooks())
patternModule.registerConstraintFunction(name, constraintFn);
// Register a noop rewriter because PDL requires patterns to end with some
// rewrite call.
patternModule.registerRewriteFunction(
"transform.dialect", [](PatternRewriter &, Operation *) {});
it = compiledPatterns
.try_emplace(patternOp.getName(), std::move(patternModule))
.first;
}
PatternApplicator applicator(it->second);
// We want to discourage direct use of PatternRewriter in APIs but In this
// very specific case, an IRRewriter is not enough.
struct TrivialPatternRewriter : public PatternRewriter {
public:
explicit TrivialPatternRewriter(MLIRContext *context)
: PatternRewriter(context) {}
};
TrivialPatternRewriter rewriter(root->getContext());
applicator.applyDefaultCostModel();
root->walk([&](Operation *op) {
if (succeeded(applicator.matchAndRewrite(op, rewriter)))
results.push_back(op);
});
return success();
}
} // namespace
//===----------------------------------------------------------------------===//
// TrackingListener
//===----------------------------------------------------------------------===//
@ -420,10 +324,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
assert(outputs.size() == 1 && "expected one output");
return llvm::all_of(
std::initializer_list<Type>{inputs.front(), outputs.front()},
[](Type ty) {
return llvm::isa<pdl::OperationType,
transform::TransformHandleTypeInterface>(ty);
});
[](Type ty) { return isa<transform::TransformHandleTypeInterface>(ty); });
}
//===----------------------------------------------------------------------===//
@ -1031,38 +932,6 @@ transform::IncludeOp::apply(transform::TransformResults &results,
return result;
}
/// Appends to `effects` the memory effect instances on `target` with the same
/// resource and effect as the ones the operation `iface` having on `source`.
static void
remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target,
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
SmallVector<MemoryEffects::EffectInstance> nestedEffects;
iface.getEffectsOnValue(source, nestedEffects);
for (const auto &effect : nestedEffects)
effects.emplace_back(effect.getEffect(), target, effect.getResource());
}
/// Appends to `effects` the same effects as the operations of `block` have on
/// block arguments but associated with `operands.`
static void
remapArgumentEffects(Block &block, ValueRange operands,
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
for (Operation &op : block) {
auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
if (!iface)
continue;
for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) {
remapEffects(iface, source, target, effects);
}
SmallVector<MemoryEffects::EffectInstance> nestedEffects;
iface.getEffectsOnResource(transform::PayloadIRResource::get(),
nestedEffects);
llvm::append_range(effects, nestedEffects);
}
}
static DiagnosedSilenceableFailure
verifyNamedSequenceOp(transform::NamedSequenceOp op);
@ -1474,8 +1343,7 @@ LogicalResult transform::NamedSequenceOp::verify() {
void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
Value target, int64_t numResultHandles) {
result.addOperands(target);
auto pdlOpType = pdl::OperationType::get(builder.getContext());
result.addTypes(SmallVector<pdl::OperationType>(numResultHandles, pdlOpType));
result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
}
DiagnosedSilenceableFailure
@ -1535,35 +1403,6 @@ LogicalResult transform::SplitHandleOp::verify() {
return success();
}
//===----------------------------------------------------------------------===//
// PDLMatchOp
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
transform::PDLMatchOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
auto *extension = state.getExtension<PatternApplicatorExtension>();
assert(extension &&
"expected PatternApplicatorExtension to be attached by the parent op");
SmallVector<Operation *> targets;
for (Operation *root : state.getPayloadOps(getRoot())) {
if (failed(extension->findAllMatches(
getPatternName().getLeafReference().getValue(), root, targets))) {
emitDefiniteFailure()
<< "could not find pattern '" << getPatternName() << "'";
}
}
results.set(llvm::cast<OpResult>(getResult()), targets);
return DiagnosedSilenceableFailure::success();
}
void transform::PDLMatchOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getRoot(), effects);
producesHandle(getMatched(), effects);
onlyReadsPayload(effects);
}
//===----------------------------------------------------------------------===//
// ReplicateOp
//===----------------------------------------------------------------------===//
@ -1776,37 +1615,9 @@ LogicalResult transform::SequenceOp::verify() {
return success();
}
/// Populate `effects` with transform dialect memory effects for the potential
/// top-level operation. Such operations have recursive effects from nested
/// operations. When they have an operand, we can additionally remap effects on
/// the block argument to be effects on the operand.
template <typename OpTy>
static void getPotentialTopLevelEffects(
OpTy operation, SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(operation->getOperands(), effects);
transform::producesHandle(operation->getResults(), effects);
if (!operation.getRoot()) {
for (Operation &op : *operation.getBodyBlock()) {
auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
if (!iface)
continue;
SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
iface.getEffects(effects);
}
return;
}
// Carry over all effects on arguments of the entry block as those on the
// operands, this is the same value just remapped.
remapArgumentEffects(*operation.getBodyBlock(), operation->getOperands(),
effects);
}
void transform::SequenceOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
getPotentialTopLevelEffects(*this, effects);
getPotentialTopLevelEffects(effects);
}
OperandRange transform::SequenceOp::getSuccessorEntryOperands(
@ -1908,77 +1719,6 @@ void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
}
//===----------------------------------------------------------------------===//
// WithPDLPatternsOp
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
TransformOpInterface transformOp = nullptr;
for (Operation &nested : getBody().front()) {
if (!isa<pdl::PatternOp>(nested)) {
transformOp = cast<TransformOpInterface>(nested);
break;
}
}
state.addExtension<PatternApplicatorExtension>(getOperation());
auto guard = llvm::make_scope_exit(
[&]() { state.removeExtension<PatternApplicatorExtension>(); });
auto scope = state.make_region_scope(getBody());
if (failed(mapBlockArguments(state)))
return DiagnosedSilenceableFailure::definiteFailure();
return state.applyTransform(transformOp);
}
void transform::WithPDLPatternsOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
getPotentialTopLevelEffects(*this, effects);
}
LogicalResult transform::WithPDLPatternsOp::verify() {
Block *body = getBodyBlock();
Operation *topLevelOp = nullptr;
for (Operation &op : body->getOperations()) {
if (isa<pdl::PatternOp>(op))
continue;
if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
if (topLevelOp) {
InFlightDiagnostic diag =
emitOpError() << "expects only one non-pattern op in its body";
diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
diag.attachNote(op.getLoc()) << "second non-pattern op";
return diag;
}
topLevelOp = &op;
continue;
}
InFlightDiagnostic diag =
emitOpError()
<< "expects only pattern and top-level transform ops in its body";
diag.attachNote(op.getLoc()) << "offending op";
return diag;
}
if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
InFlightDiagnostic diag = emitOpError() << "cannot be nested";
diag.attachNote(parent.getLoc()) << "parent operation";
return diag;
}
if (!topLevelOp) {
InFlightDiagnostic diag = emitOpError()
<< "expects at least one non-pattern op";
return diag;
}
return success();
}
//===----------------------------------------------------------------------===//
// PrintOp
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,13 @@
add_mlir_dialect_library(MLIRTransformPDLExtension
PDLExtension.cpp
PDLExtensionOps.cpp
DEPENDS
MLIRTransformDialectPDLExtensionOpsIncGen
LINK_LIBS PUBLIC
MLIRIR
MLIRTransformDialect
MLIRPDLDialect
MLIRPDLInterpDialect
)

View File

@ -0,0 +1,69 @@
//===- PDLExtension.cpp - PDL extension for the Transform dialect ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h"
#include "mlir/IR/DialectRegistry.h"
using namespace mlir;
namespace {
/// Implementation of the TransformHandleTypeInterface for the PDL
/// OperationType. Accepts any payload operation.
struct PDLOperationTypeTransformHandleTypeInterfaceImpl
: public transform::TransformHandleTypeInterface::ExternalModel<
PDLOperationTypeTransformHandleTypeInterfaceImpl,
pdl::OperationType> {
/// Accept any operation.
DiagnosedSilenceableFailure
checkPayload(Type type, Location loc, ArrayRef<Operation *> payload) const {
return DiagnosedSilenceableFailure::success();
}
};
} // namespace
namespace {
/// PDL extension of the Transform dialect. This provides transform operations
/// that connect to PDL matching as well as interfaces for PDL types to be used
/// with Transform dialect operations.
class PDLExtension : public transform::TransformDialectExtension<PDLExtension> {
public:
void init() {
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc"
>();
addDialectDataInitializer<transform::PDLMatchHooks>(
[](transform::PDLMatchHooks &) {});
// Declare PDL as dependent so we can attach an interface to its type in the
// later step.
declareDependentDialect<pdl::PDLDialect>();
// PDLInterp is only relevant if we actually apply the transform IR so
// declare it as generated.
declareGeneratedDialect<pdl_interp::PDLInterpDialect>();
// Make PDL OperationType usable as a transform dialect type.
addCustomInitializationStep([](MLIRContext *context) {
pdl::OperationType::attachInterface<
PDLOperationTypeTransformHandleTypeInterfaceImpl>(*context);
});
}
};
} // namespace
void mlir::transform::registerPDLExtension(DialectRegistry &dialectRegistry) {
dialectRegistry.addExtensions<PDLExtension>();
}

View File

@ -0,0 +1,234 @@
//===- PDLExtensionOps.cpp - PDL extension for the Transform dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h"
#include "mlir/Dialect/PDL/IR/PDLOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "llvm/ADT/ScopeExit.h"
using namespace mlir;
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::transform::PDLMatchHooks)
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc"
//===----------------------------------------------------------------------===//
// PatternApplicatorExtension
//===----------------------------------------------------------------------===//
namespace {
/// A TransformState extension that keeps track of compiled PDL pattern sets.
/// This is intended to be used along the WithPDLPatterns op. The extension
/// can be constructed given an operation that has a SymbolTable trait and
/// contains pdl::PatternOp instances. The patterns are compiled lazily and one
/// by one when requested; this behavior is subject to change.
class PatternApplicatorExtension : public transform::TransformState::Extension {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
/// Creates the extension for patterns contained in `patternContainer`.
explicit PatternApplicatorExtension(transform::TransformState &state,
Operation *patternContainer)
: Extension(state), patterns(patternContainer) {}
/// Appends to `results` the operations contained in `root` that matched the
/// PDL pattern with the given name. Note that `root` may or may not be the
/// operation that contains PDL patterns. Reports an error if the pattern
/// cannot be found. Note that when no operations are matched, this still
/// succeeds as long as the pattern exists.
LogicalResult findAllMatches(StringRef patternName, Operation *root,
SmallVectorImpl<Operation *> &results);
private:
/// Map from the pattern name to a singleton set of rewrite patterns that only
/// contains the pattern with this name. Populated when the pattern is first
/// requested.
// TODO: reconsider the efficiency of this storage when more usage data is
// available. Storing individual patterns in a set and triggering compilation
// for each of them has overhead. So does compiling a large set of patterns
// only to apply a handful of them.
llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
/// A symbol table operation containing the relevant PDL patterns.
SymbolTable patterns;
};
LogicalResult PatternApplicatorExtension::findAllMatches(
StringRef patternName, Operation *root,
SmallVectorImpl<Operation *> &results) {
auto it = compiledPatterns.find(patternName);
if (it == compiledPatterns.end()) {
auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
if (!patternOp)
return failure();
// Copy the pattern operation into a new module that is compiled and
// consumed by the PDL interpreter.
OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
auto builder = OpBuilder::atBlockEnd(pdlModuleOp->getBody());
builder.clone(*patternOp);
PDLPatternModule patternModule(std::move(pdlModuleOp));
// Merge in the hooks owned by the dialect. Make a copy as they may be
// also used by the following operations.
auto *dialect =
root->getContext()->getLoadedDialect<transform::TransformDialect>();
for (const auto &[name, constraintFn] :
dialect->getExtraData<transform::PDLMatchHooks>()
.getPDLConstraintHooks()) {
patternModule.registerConstraintFunction(name, constraintFn);
}
// Register a noop rewriter because PDL requires patterns to end with some
// rewrite call.
patternModule.registerRewriteFunction(
"transform.dialect", [](PatternRewriter &, Operation *) {});
it = compiledPatterns
.try_emplace(patternOp.getName(), std::move(patternModule))
.first;
}
PatternApplicator applicator(it->second);
// We want to discourage direct use of PatternRewriter in APIs but In this
// very specific case, an IRRewriter is not enough.
struct TrivialPatternRewriter : public PatternRewriter {
public:
explicit TrivialPatternRewriter(MLIRContext *context)
: PatternRewriter(context) {}
};
TrivialPatternRewriter rewriter(root->getContext());
applicator.applyDefaultCostModel();
root->walk([&](Operation *op) {
if (succeeded(applicator.matchAndRewrite(op, rewriter)))
results.push_back(op);
});
return success();
}
} // namespace
//===----------------------------------------------------------------------===//
// PDLMatchHooks
//===----------------------------------------------------------------------===//
void transform::PDLMatchHooks::mergeInPDLMatchHooks(
llvm::StringMap<PDLConstraintFunction> &&constraintFns) {
// Steal the constraint functions from the given map.
for (auto &it : constraintFns)
pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second));
}
const llvm::StringMap<PDLConstraintFunction> &
transform::PDLMatchHooks::getPDLConstraintHooks() const {
return pdlMatchHooks.getConstraintFunctions();
}
//===----------------------------------------------------------------------===//
// PDLMatchOp
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
transform::PDLMatchOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
auto *extension = state.getExtension<PatternApplicatorExtension>();
assert(extension &&
"expected PatternApplicatorExtension to be attached by the parent op");
SmallVector<Operation *> targets;
for (Operation *root : state.getPayloadOps(getRoot())) {
if (failed(extension->findAllMatches(
getPatternName().getLeafReference().getValue(), root, targets))) {
emitDefiniteFailure()
<< "could not find pattern '" << getPatternName() << "'";
}
}
results.set(llvm::cast<OpResult>(getResult()), targets);
return DiagnosedSilenceableFailure::success();
}
void transform::PDLMatchOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getRoot(), effects);
producesHandle(getMatched(), effects);
onlyReadsPayload(effects);
}
//===----------------------------------------------------------------------===//
// WithPDLPatternsOp
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
TransformOpInterface transformOp = nullptr;
for (Operation &nested : getBody().front()) {
if (!isa<pdl::PatternOp>(nested)) {
transformOp = cast<TransformOpInterface>(nested);
break;
}
}
state.addExtension<PatternApplicatorExtension>(getOperation());
auto guard = llvm::make_scope_exit(
[&]() { state.removeExtension<PatternApplicatorExtension>(); });
auto scope = state.make_region_scope(getBody());
if (failed(mapBlockArguments(state)))
return DiagnosedSilenceableFailure::definiteFailure();
return state.applyTransform(transformOp);
}
void transform::WithPDLPatternsOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
getPotentialTopLevelEffects(effects);
}
LogicalResult transform::WithPDLPatternsOp::verify() {
Block *body = getBodyBlock();
Operation *topLevelOp = nullptr;
for (Operation &op : body->getOperations()) {
if (isa<pdl::PatternOp>(op))
continue;
if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
if (topLevelOp) {
InFlightDiagnostic diag =
emitOpError() << "expects only one non-pattern op in its body";
diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
diag.attachNote(op.getLoc()) << "second non-pattern op";
return diag;
}
topLevelOp = &op;
continue;
}
InFlightDiagnostic diag =
emitOpError()
<< "expects only pattern and top-level transform ops in its body";
diag.attachNote(op.getLoc()) << "offending op";
return diag;
}
if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
InFlightDiagnostic diag = emitOpError() << "cannot be nested";
diag.attachNote(parent.getLoc()) << "parent operation";
return diag;
}
if (!topLevelOp) {
InFlightDiagnostic diag = emitOpError()
<< "expects at least one non-pattern op";
return diag;
}
return success();
}

View File

@ -114,6 +114,16 @@ declare_mlir_dialect_python_bindings(
DIALECT_NAME linalg
DEPENDS LinalgOdsGen)
declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TransformPDLExtensionOps.td
SOURCES
dialects/_transform_pdl_extension_ops_ext.py
dialects/transform/pdl.py
DIALECT_NAME transform
EXTENSION_NAME transform_pdl_extension)
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"

View File

@ -0,0 +1,20 @@
//===-- TransformPDLExtensionOps.td - Binding entry point --*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Entry point of the generated Python bindings for the PDL extension of the
// Transform dialect.
//
//===----------------------------------------------------------------------===//
#ifndef PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS
#define PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS
include "mlir/Bindings/Python/Attributes.td"
include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td"
#endif // PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS

View File

@ -60,26 +60,6 @@ class MergeHandlesOp:
)
class PDLMatchOp:
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
pattern_name: Union[Attribute, str],
*,
loc=None,
ip=None,
):
super().__init__(
result_type,
_get_op_result_or_value(target),
pattern_name,
loc=loc,
ip=ip,
)
class ReplicateOp:
def __init__(
@ -152,28 +132,6 @@ class SequenceOp:
return self.body.arguments[1:]
class WithPDLPatternsOp:
def __init__(self,
target: Union[Operation, Value, Type],
*,
loc=None,
ip=None):
root = _get_op_result_or_value(target) if not isinstance(target,
Type) else None
root_type = target if isinstance(target, Type) else root.type
super().__init__(root=root, loc=loc, ip=ip)
self.regions[0].blocks.append(root_type)
@property
def body(self) -> Block:
return self.regions[0].blocks[0]
@property
def bodyTarget(self) -> Value:
return self.body.arguments[0]
class YieldOp:
def __init__(

View File

@ -0,0 +1,55 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
try:
from ..ir import *
from ._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
)
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Union
class PDLMatchOp:
def __init__(
self,
result_type: Type,
target: Union[Operation, Value],
pattern_name: Union[Attribute, str],
*,
loc=None,
ip=None,
):
super().__init__(
result_type,
_get_op_result_or_value(target),
pattern_name,
loc=loc,
ip=ip,
)
class WithPDLPatternsOp:
def __init__(self,
target: Union[Operation, Value, Type],
*,
loc=None,
ip=None):
root = _get_op_result_or_value(target) if not isinstance(target,
Type) else None
root_type = target if isinstance(target, Type) else root.type
super().__init__(root=root, loc=loc, ip=ip)
self.regions[0].blocks.append(root_type)
@property
def body(self) -> Block:
return self.regions[0].blocks[0]
@property
def bodyTarget(self) -> Value:
return self.body.arguments[0]

View File

@ -0,0 +1,5 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._transform_pdl_extension_ops_gen import *

View File

@ -83,33 +83,6 @@ transform.sequence failures(propagate) {
// -----
transform.with_pdl_patterns {
^bb0(%arg0: !transform.any_op):
sequence %arg0 : !transform.any_op failures(propagate) {
^bb0(%arg1: !transform.any_op):
%0 = pdl_match @some in %arg1 : (!transform.any_op) -> !transform.any_op
test_print_remark_at_operand %0, "matched" : !transform.any_op
}
pdl.pattern @some : benefit(1) {
%0 = pdl.operation "test.some_op"
pdl.rewrite %0 with "transform.dialect"
}
pdl.pattern @other : benefit(1) {
%0 = pdl.operation "test.other_op"
pdl.rewrite %0 with "transform.dialect"
}
}
// expected-remark @below {{matched}}
"test.some_op"() : () -> ()
"test.other_op"() : () -> ()
// expected-remark @below {{matched}}
"test.some_op"() : () -> ()
// -----
// expected-remark @below {{parent function}}
func.func @foo() {
%0 = arith.constant 0 : i32

View File

@ -0,0 +1,47 @@
// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
transform.with_pdl_patterns {
^bb0(%arg0: !transform.any_op):
sequence %arg0 : !transform.any_op failures(propagate) {
^bb0(%arg1: !transform.any_op):
%0 = pdl_match @some in %arg1 : (!transform.any_op) -> !transform.any_op
test_print_remark_at_operand %0, "matched" : !transform.any_op
}
pdl.pattern @some : benefit(1) {
%0 = pdl.operation "test.some_op"
pdl.rewrite %0 with "transform.dialect"
}
pdl.pattern @other : benefit(1) {
%0 = pdl.operation "test.other_op"
pdl.rewrite %0 with "transform.dialect"
}
}
// expected-remark @below {{matched}}
"test.some_op"() : () -> ()
"test.other_op"() : () -> ()
// expected-remark @below {{matched}}
"test.some_op"() : () -> ()
// -----
transform.with_pdl_patterns {
^bb0(%arg0: !transform.any_op):
sequence %arg0 : !transform.any_op failures(propagate) {
^bb1(%arg1: !transform.any_op):
%0 = pdl_match @some in %arg1 : (!transform.any_op) -> !transform.any_op
}
pdl.pattern @some : benefit(1) {
%0 = pdl.operation "test.some_op"
pdl.apply_native_constraint "verbose_constraint"(%0 : !pdl.operation)
pdl.rewrite %0 with "transform.dialect"
}
}
// expected-warning @below {{from PDL constraint}}
"test.some_op"() : () -> ()
"test.other_op"() : () -> ()

View File

@ -21,4 +21,5 @@ add_mlir_library(MLIRTestTransformDialect
MLIRPDLDialect
MLIRTransformDialect
MLIRTransformDialectTransforms
MLIRTransformPDLExtension
)

View File

@ -17,7 +17,9 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Compiler.h"
@ -754,6 +756,23 @@ public:
#define GET_TYPEDEF_LIST
#include "TestTransformDialectExtensionTypes.cpp.inc"
>();
auto verboseConstraint = [](PatternRewriter &rewriter,
ArrayRef<PDLValue> pdlValues) {
for (const PDLValue &pdlValue : pdlValues) {
if (Operation *op = pdlValue.dyn_cast<Operation *>()) {
op->emitWarning() << "from PDL constraint";
}
}
return success();
};
addDialectDataInitializer<transform::PDLMatchHooks>(
[&](transform::PDLMatchHooks &hooks) {
llvm::StringMap<PDLConstraintFunction> constraints;
constraints.try_emplace("verbose_constraint", verboseConstraint);
hooks.mergeInPDLMatchHooks(std::move(constraints));
});
}
};
} // namespace

View File

@ -2,7 +2,7 @@
from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects import pdl
from mlir.dialects.transform import pdl as transform_pdl
def run(f):
@ -103,13 +103,13 @@ def testNestedSequenceOpWithExtras():
@run
def testTransformPDLOps():
withPdl = transform.WithPDLPatternsOp(transform.AnyOpType.get())
withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
with InsertionPoint(withPdl.body):
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
[transform.AnyOpType.get()],
withPdl.bodyTarget)
with InsertionPoint(sequence.body):
match = transform.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher")
match = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher")
transform.YieldOp(match)
# CHECK-LABEL: TEST: testTransformPDLOps
# CHECK: transform.with_pdl_patterns {
@ -148,13 +148,13 @@ def testMergeHandlesOp():
@run
def testReplicateOp():
with_pdl = transform.WithPDLPatternsOp(transform.AnyOpType.get())
with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
with InsertionPoint(with_pdl.body):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget)
with InsertionPoint(sequence.body):
m1 = transform.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "first")
m2 = transform.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "second")
m1 = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "first")
m2 = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "second")
transform.ReplicateOp(m1, [m2])
transform.YieldOp()
# CHECK-LABEL: TEST: testReplicateOp

View File

@ -4,6 +4,7 @@ from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects import pdl
from mlir.dialects.transform import structured
from mlir.dialects.transform import pdl as transform_pdl
def run(f):
@ -151,13 +152,13 @@ def testTileZero():
@run
def testTileDynamic():
with_pdl = transform.WithPDLPatternsOp(pdl.OperationType.get())
with_pdl = transform_pdl.WithPDLPatternsOp(pdl.OperationType.get())
with InsertionPoint(with_pdl.body):
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [],
with_pdl.bodyTarget)
with InsertionPoint(sequence.body):
m1 = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "first")
m2 = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "second")
m1 = transform_pdl.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "first")
m2 = transform_pdl.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "second")
structured.TileOp(sequence.bodyTarget,
sizes=[m1, 3, m2, 0])
transform.YieldOp()

View File

@ -7495,6 +7495,7 @@ cc_library(
":TosaToLinalg",
":TransformDialect",
":TransformDialectTransforms",
":TransformPDLExtension",
":Transforms",
":TransformsPassIncGen",
":VectorDialect",
@ -9732,7 +9733,6 @@ td_library(
":ControlFlowInterfacesTdFiles",
":InferTypeOpInterfaceTdFiles",
":OpBaseTdFiles",
":PDLDialectTdFiles",
":SideEffectInterfacesTdFiles",
],
)
@ -9889,8 +9889,6 @@ cc_library(
":CallOpInterfaces",
":ControlFlowInterfaces",
":IR",
":PDLDialect",
":PDLInterpDialect",
":Rewrite",
":SideEffectInterfaces",
":Support",
@ -9906,6 +9904,54 @@ cc_library(
],
)
td_library(
name = "TransformPDLExtensionTdFiles",
srcs = glob(["include/mlir/Dialect/Transform/PDLExtension/*.td"]),
deps = [
":PDLDialectTdFiles",
":TransformDialectTdFiles",
],
)
gentbl_cc_library(
name = "TransformPDLExtensionOpsIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
[
"-gen-op-decls",
],
"include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h.inc",
),
(
[
"-gen-op-defs",
],
"include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td",
deps = [":TransformPDLExtensionTdFiles"],
)
cc_library(
name = "TransformPDLExtension",
srcs = glob(["lib/Dialect/Transform/PDLExtension/*.cpp"]),
hdrs = glob(["include/mlir/Dialect/Transform/PDLExtension/*.h"]),
deps = [
":IR",
":PDLDialect",
":PDLInterpDialect",
":SideEffectInterfaces",
":Support",
":TransformDialect",
":TransformPDLExtensionOpsIncGen",
":Rewrite",
"//llvm:Support",
],
)
td_library(
name = "TransformDialectTransformsTdFiles",
srcs = glob(["include/mlir/Dialect/Transform/Transforms/*.td"]),

View File

@ -927,6 +927,26 @@ gentbl_filegroup(
],
)
gentbl_filegroup(
name = "PDLTransformOpsPyGen",
tbl_outs = [
(
[
"-gen-python-op-bindings",
"-bind-dialect=transform",
"-dialect-extension=transform_pdl_extension",
],
"mlir/dialects/_transform_pdl_extension_ops_gen.py",
),
],
tblgen = "//mlir:mlir-tblgen",
td_file = "mlir/dialects/TransformPDLExtensionOps.td",
deps = [
":TransformOpsPyTdFiles",
"//mlir:TransformPDLExtensionTdFiles",
],
)
filegroup(
name = "TransformOpsPyFiles",
srcs = [
@ -934,6 +954,7 @@ filegroup(
"mlir/dialects/_structured_transform_ops_ext.py",
"mlir/dialects/_transform_ops_ext.py",
":LoopTransformOpsPyGen",
":PDLTransformOpsPyGen",
":StructuredTransformOpsPyGen",
":TransformOpsPyGen",
],

View File

@ -317,6 +317,7 @@ gentbl_cc_library(
":TransformDialectTdFiles",
"//mlir:PDLDialectTdFiles",
"//mlir:TransformDialectTdFiles",
"//mlir:TransformPDLExtension",
],
)
@ -333,6 +334,7 @@ cc_library(
"//mlir:Pass",
"//mlir:TransformDialect",
"//mlir:TransformDialectTransforms",
"//mlir:TransformPDLExtension",
],
)