mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-03-05 08:58:13 +00:00
[mlir] move PDL-related transform ops into an extension
The initial bring-up of the Transform dialect relied on PDL to provide the default handle type (`!pdl.operation`) and the matching capability. Both are now provided natively by the Transform dialect removing the reason to have a hard dependency on the PDL dialect and its interpreter. Move PDL-related transform operations into a separate extension. This requires us to introduce a dialect state extension mechanism into the Transform dialect so it no longer needs to know about PDL constraint functions that may be injected by extensions similarly to operations and types. This mechanism will be reused to connect pattern application drivers and the Transform dialect. This completes the restructuring of the Transform dialect to remove overrilance on PDL. Note to downstreams: flow that are using `!pdl.operation` with Transform dialect operations will now require `transform::PDLExtension` to be applied to the transform dialect in order to provide the transform handle type interface for `!pdl.operation`. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D151104
This commit is contained in:
parent
3590945a11
commit
94d608d410
@ -1,2 +1,3 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(PDLExtension)
|
||||
add_subdirectory(Transforms)
|
||||
|
@ -12,12 +12,52 @@
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/TypeID.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include <optional>
|
||||
|
||||
namespace mlir {
|
||||
namespace transform {
|
||||
|
||||
namespace detail {
|
||||
/// Concrete base class for CRTP TransformDialectDataBase. Must not be used
|
||||
/// directly.
|
||||
class TransformDialectDataBase {
|
||||
public:
|
||||
virtual ~TransformDialectDataBase() = default;
|
||||
|
||||
/// Returns the dynamic type ID of the subclass.
|
||||
TypeID getTypeID() const { return typeID; }
|
||||
|
||||
protected:
|
||||
/// Must be called by the subclass with the appropriate type ID.
|
||||
explicit TransformDialectDataBase(TypeID typeID) : typeID(typeID) {}
|
||||
|
||||
private:
|
||||
/// The type ID of the subclass.
|
||||
const TypeID typeID;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
/// Base class for additional data owned by the Transform dialect. Extensions
|
||||
/// may communicate with each other using this data. The data object is
|
||||
/// identified by the TypeID of the specific data subclass, querying the data of
|
||||
/// the same subclass returns a reference to the same object. When a Transform
|
||||
/// dialect extension is initialized, it can populate the data in the specific
|
||||
/// subclass. When a Transform op is applied, it can read (but not mutate) the
|
||||
/// data in the specific subclass, including the data provided by other
|
||||
/// extensions.
|
||||
///
|
||||
/// This follows CRTP: derived classes must list themselves as template
|
||||
/// argument.
|
||||
template <typename DerivedTy>
|
||||
class TransformDialectData : public detail::TransformDialectDataBase {
|
||||
protected:
|
||||
/// Forward the TypeID of the derived class to the base.
|
||||
TransformDialectData() : TransformDialectDataBase(TypeID::get<DerivedTy>()) {}
|
||||
};
|
||||
|
||||
#ifndef NDEBUG
|
||||
namespace detail {
|
||||
/// Asserts that the operations provided as template arguments implement the
|
||||
@ -85,9 +125,8 @@ public:
|
||||
for (const DialectLoader &loader : generatedDialectLoaders)
|
||||
loader(context);
|
||||
|
||||
for (const Initializer &init : opInitializers)
|
||||
for (const Initializer &init : initializers)
|
||||
init(transformDialect);
|
||||
transformDialect->mergeInPDLMatchHooks(std::move(pdlMatchConstraintFns));
|
||||
}
|
||||
|
||||
protected:
|
||||
@ -100,6 +139,41 @@ protected:
|
||||
static_cast<DerivedTy *>(this)->init();
|
||||
}
|
||||
|
||||
/// Registers a custom initialization step to be performed when the extension
|
||||
/// is applied to the dialect while loading. This is discouraged in favor of
|
||||
/// more specific calls `declareGeneratedDialect`, `addDialectDataInitializer`
|
||||
/// etc. `Func` must be convertible to the `void (MLIRContext *)` form. It
|
||||
/// will be called during the extension initialization and given the current
|
||||
/// MLIR context. This may be used to attach additional interfaces that cannot
|
||||
/// be attached elsewhere.
|
||||
template <typename Func>
|
||||
void addCustomInitializationStep(Func &&func) {
|
||||
std::function<void(MLIRContext *)> initializer = func;
|
||||
dialectLoaders.push_back(
|
||||
[init = std::move(initializer)](MLIRContext *ctx) { init(ctx); });
|
||||
}
|
||||
|
||||
/// Registers the given function as one of the initializers for the
|
||||
/// dialect-owned data of the kind specified as template argument. The
|
||||
/// function must be convertible to the `void (DataTy &)` form. It will be
|
||||
/// called during the extension initialization and will be given a mutable
|
||||
/// reference to `DataTy`. The callback is expected to append data to the
|
||||
/// given storage, and is not allowed to remove or destructively mutate the
|
||||
/// existing data. The order in which callbacks from different extensions are
|
||||
/// executed is unspecified so the callbacks may not rely on data being
|
||||
/// already present. `DataTy` must be a class deriving `TransformDialectData`.
|
||||
template <typename DataTy, typename Func>
|
||||
void addDialectDataInitializer(Func &&func) {
|
||||
static_assert(std::is_base_of_v<detail::TransformDialectDataBase, DataTy>,
|
||||
"only classes deriving TransformDialectData are accepted");
|
||||
|
||||
std::function<void(DataTy &)> initializer = func;
|
||||
initializers.push_back(
|
||||
[init = std::move(initializer)](TransformDialect *transformDialect) {
|
||||
init(transformDialect->getOrCreateExtraData<DataTy>());
|
||||
});
|
||||
}
|
||||
|
||||
/// Hook for derived classes to inject constructor behavior.
|
||||
void init() {}
|
||||
|
||||
@ -108,7 +182,7 @@ protected:
|
||||
/// implementations must be already available when the operation is injected.
|
||||
template <typename... OpTys>
|
||||
void registerTransformOps() {
|
||||
opInitializers.push_back([](TransformDialect *transformDialect) {
|
||||
initializers.push_back([](TransformDialect *transformDialect) {
|
||||
transformDialect->addOperationsChecked<OpTys...>();
|
||||
});
|
||||
}
|
||||
@ -120,7 +194,7 @@ protected:
|
||||
/// `StringRef` that is unique across all injected types.
|
||||
template <typename... TypeTys>
|
||||
void registerTypes() {
|
||||
opInitializers.push_back([](TransformDialect *transformDialect) {
|
||||
initializers.push_back([](TransformDialect *transformDialect) {
|
||||
transformDialect->addTypesChecked<TypeTys...>();
|
||||
});
|
||||
}
|
||||
@ -151,22 +225,10 @@ protected:
|
||||
[](MLIRContext *context) { context->loadDialect<DialectTy>(); });
|
||||
}
|
||||
|
||||
/// Injects the named constraint to make it available for use with the
|
||||
/// PDLMatchOp in the transform dialect.
|
||||
void registerPDLMatchConstraintFn(StringRef name,
|
||||
PDLConstraintFunction &&fn) {
|
||||
pdlMatchConstraintFns.try_emplace(name,
|
||||
std::forward<PDLConstraintFunction>(fn));
|
||||
}
|
||||
template <typename ConstraintFnTy>
|
||||
void registerPDLMatchConstraintFn(StringRef name, ConstraintFnTy &&fn) {
|
||||
pdlMatchConstraintFns.try_emplace(
|
||||
name, ::mlir::detail::pdl_function_builder::buildConstraintFn(
|
||||
std::forward<ConstraintFnTy>(fn)));
|
||||
}
|
||||
|
||||
private:
|
||||
SmallVector<Initializer> opInitializers;
|
||||
/// Callbacks performing extension initialization, e.g., registering ops,
|
||||
/// types and defining the additional data.
|
||||
SmallVector<Initializer> initializers;
|
||||
|
||||
/// Callbacks loading the dependent dialects, i.e. the dialect needed for the
|
||||
/// extension ops.
|
||||
@ -176,13 +238,6 @@ private:
|
||||
/// applying the transformations.
|
||||
SmallVector<DialectLoader> generatedDialectLoaders;
|
||||
|
||||
/// A list of constraints that should be made available to PDL patterns
|
||||
/// processed by PDLMatchOp in the Transform dialect.
|
||||
///
|
||||
/// Declared as mutable so its contents can be moved in the `apply` const
|
||||
/// method, which is only called once.
|
||||
mutable llvm::StringMap<PDLConstraintFunction> pdlMatchConstraintFns;
|
||||
|
||||
/// Indicates that the extension is in build-only mode.
|
||||
bool buildOnly;
|
||||
};
|
||||
@ -232,6 +287,17 @@ void TransformDialect::addTypeIfNotRegistered() {
|
||||
#endif // NDEBUG
|
||||
}
|
||||
|
||||
template <typename DataTy>
|
||||
DataTy &TransformDialect::getOrCreateExtraData() {
|
||||
TypeID typeID = TypeID::get<DataTy>();
|
||||
auto it = extraData.find(typeID);
|
||||
if (it != extraData.end())
|
||||
return static_cast<DataTy &>(*it->getSecond());
|
||||
|
||||
auto emplaced = extraData.try_emplace(typeID, std::make_unique<DataTy>());
|
||||
return static_cast<DataTy &>(*emplaced.first->getSecond());
|
||||
}
|
||||
|
||||
/// A wrapper for transform dialect extensions that forces them to be
|
||||
/// constructed in the build-only mode.
|
||||
template <typename DerivedTy>
|
||||
|
@ -18,36 +18,31 @@ def Transform_Dialect : Dialect {
|
||||
let name = "transform";
|
||||
let cppNamespace = "::mlir::transform";
|
||||
|
||||
let dependentDialects = [
|
||||
"::mlir::pdl::PDLDialect",
|
||||
"::mlir::pdl_interp::PDLInterpDialect",
|
||||
];
|
||||
|
||||
let hasOperationAttrVerify = 1;
|
||||
let usePropertiesForAttributes = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Name of the attribute attachable to the symbol table operation
|
||||
/// containing named sequences. This is used to trigger verification.
|
||||
constexpr const static llvm::StringLiteral
|
||||
constexpr const static ::llvm::StringLiteral
|
||||
kWithNamedSequenceAttrName = "transform.with_named_sequence";
|
||||
|
||||
/// Names of the attribute attachable to an operation so it can be
|
||||
/// identified as root by the default interpreter pass.
|
||||
constexpr const static llvm::StringLiteral
|
||||
constexpr const static ::llvm::StringLiteral
|
||||
kTargetTagAttrName = "transform.target_tag";
|
||||
|
||||
/// Names of the attributes indicating whether an argument of an external
|
||||
/// transform dialect symbol is consumed or only read.
|
||||
constexpr const static llvm::StringLiteral
|
||||
constexpr const static ::llvm::StringLiteral
|
||||
kArgConsumedAttrName = "transform.consumed";
|
||||
constexpr const static llvm::StringLiteral
|
||||
constexpr const static ::llvm::StringLiteral
|
||||
kArgReadOnlyAttrName = "transform.readonly";
|
||||
|
||||
/// Returns the named PDL constraint functions available in the dialect
|
||||
/// as a map from their name to the function.
|
||||
const ::llvm::StringMap<::mlir::PDLConstraintFunction> &
|
||||
getPDLConstraintHooks() const;
|
||||
template <typename DataTy>
|
||||
const DataTy &getExtraData() const {
|
||||
return *static_cast<const DataTy *>(extraData.at(::mlir::TypeID::get<DataTy>()).get());
|
||||
}
|
||||
|
||||
/// Parses a type registered by this dialect or one of its extensions.
|
||||
::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override;
|
||||
@ -92,23 +87,27 @@ def Transform_Dialect : Dialect {
|
||||
/// mnemonic.
|
||||
[[noreturn]] void reportDuplicateTypeRegistration(StringRef mnemonic);
|
||||
|
||||
/// Registers dialect types with the context.
|
||||
void initializeTypes();
|
||||
|
||||
// Give extensions access to injection functions.
|
||||
template <typename, typename...>
|
||||
friend class TransformDialectExtension;
|
||||
|
||||
/// Takes ownership of the named PDL constraint function from the given
|
||||
/// map and makes them available for use by the operations in the dialect.
|
||||
void mergeInPDLMatchHooks(
|
||||
::llvm::StringMap<::mlir::PDLConstraintFunction> &&constraintFns);
|
||||
/// Gets a mutable reference to extra data of the kind specified as
|
||||
/// template argument. Allocates the data on the first call.
|
||||
template <typename DataTy>
|
||||
DataTy &getOrCreateExtraData();
|
||||
|
||||
//===----------------------------------------------------------------===//
|
||||
// Data fields
|
||||
//===----------------------------------------------------------------===//
|
||||
|
||||
/// A container for PDL constraint function that can be used by
|
||||
/// operations in this dialect.
|
||||
::mlir::PDLPatternModule pdlMatchHooks;
|
||||
/// Additional data associated with and owned by the dialect. Accessible
|
||||
/// to extensions.
|
||||
::llvm::DenseMap<::mlir::TypeID, std::unique_ptr<
|
||||
::mlir::transform::detail::TransformDialectDataBase>>
|
||||
extraData;
|
||||
|
||||
/// A map from type mnemonic to its parsing function for the remainder of
|
||||
/// the syntax. The parser has access to the mnemonic, so it is used for
|
||||
|
@ -38,6 +38,14 @@ mapPossibleTopLevelTransformOpBlockArguments(TransformState &state,
|
||||
/// Verification hook for PossibleTopLevelTransformOpTrait.
|
||||
LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);
|
||||
|
||||
/// Populates `effects` with side effects implied by
|
||||
/// PossibleTopLevelTransformOpTrait for the given operation. The operation may
|
||||
/// have an optional `root` operand, indicating it is not in fact top-level. It
|
||||
/// is also expected to have a single-block body.
|
||||
void getPotentialTopLevelEffects(
|
||||
Operation *operation, Value root, Block &body,
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
|
||||
|
||||
/// Verification hook for TransformOpInterface.
|
||||
LogicalResult verifyTransformOpInterface(Operation *op);
|
||||
|
||||
@ -753,15 +761,16 @@ TransformState::make_isolated_region_scope(Region ®ion) {
|
||||
/// can be standalone top-level transforms. Such operations typically contain
|
||||
/// other Transform dialect operations that can be executed following some
|
||||
/// control flow logic specific to the current operation. The operations with
|
||||
/// this trait are expected to have at least one single-block region with one
|
||||
/// argument of PDL Operation type. The operations are also expected to be valid
|
||||
/// without operands, in which case they are considered top-level, and with one
|
||||
/// or more arguments, in which case they are considered nested. Top-level
|
||||
/// operations have the block argument of the entry block in the Transform IR
|
||||
/// correspond to the root operation of Payload IR. Nested operations have the
|
||||
/// block argument of the entry block in the Transform IR correspond to a list
|
||||
/// of Payload IR operations mapped to the first operand of the Transform IR
|
||||
/// operation. The operation must implement TransformOpInterface.
|
||||
/// this trait are expected to have at least one single-block region with at
|
||||
/// least one argument of type implementing TransformHandleTypeInterface. The
|
||||
/// operations are also expected to be valid without operands, in which case
|
||||
/// they are considered top-level, and with one or more arguments, in which case
|
||||
/// they are considered nested. Top-level operations have the block argument of
|
||||
/// the entry block in the Transform IR correspond to the root operation of
|
||||
/// Payload IR. Nested operations have the block argument of the entry block in
|
||||
/// the Transform IR correspond to a list of Payload IR operations mapped to the
|
||||
/// first operand of the Transform IR operation. The operation must implement
|
||||
/// TransformOpInterface.
|
||||
template <typename OpTy>
|
||||
class PossibleTopLevelTransformOpTrait
|
||||
: public OpTrait::TraitBase<OpTy, PossibleTopLevelTransformOpTrait> {
|
||||
@ -777,6 +786,14 @@ public:
|
||||
return &this->getOperation()->getRegion(region).front();
|
||||
}
|
||||
|
||||
/// Populates `effects` with side effects implied by this trait.
|
||||
void getPotentialTopLevelEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
detail::getPotentialTopLevelEffects(
|
||||
this->getOperation(), cast<OpTy>(this->getOperation()).getRoot(),
|
||||
*getBodyBlock(), effects);
|
||||
}
|
||||
|
||||
/// Sets up the mapping between the entry block of the given region of this op
|
||||
/// and the relevant list of Payload IR operations in the given state. The
|
||||
/// state is expected to be already scoped at the region of this operation.
|
||||
|
@ -9,7 +9,6 @@
|
||||
#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H
|
||||
#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H
|
||||
|
||||
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
|
||||
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
||||
|
@ -575,37 +575,6 @@ def ParamConstantOp : Op<Transform_Dialect, "param.constant", [
|
||||
let assemblyFormat = "$value attr-dict `->` type($param)";
|
||||
}
|
||||
|
||||
def PDLMatchOp : TransformDialectOp<"pdl_match",
|
||||
[DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
|
||||
let summary = "Finds ops that match the named PDL pattern";
|
||||
let description = [{
|
||||
Find Payload IR ops nested within the Payload IR op associated with the
|
||||
operand that match the PDL pattern identified by its name. The pattern is
|
||||
expected to be defined in the closest surrounding `WithPDLPatternsOp`.
|
||||
|
||||
Produces a Transform IR value associated with the list of Payload IR ops
|
||||
that matched the pattern. The order of results in the list is that of the
|
||||
Operation::walk, clients are advised not to rely on a specific order though.
|
||||
If the operand is associated with multiple Payload IR ops, finds matching
|
||||
ops nested within each of those and produces a single list containing all
|
||||
of the matched ops.
|
||||
|
||||
The transformation is considered successful regardless of whether some
|
||||
Payload IR ops actually matched the pattern and only fails if the pattern
|
||||
could not be looked up or compiled.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TransformHandleTypeInterface, "Payload IR scope to match within">:$root,
|
||||
SymbolRefAttr:$pattern_name);
|
||||
let results = (outs
|
||||
Res<TransformHandleTypeInterface, "Handle to the matched Payload IR ops">:$matched);
|
||||
|
||||
let assemblyFormat = "$pattern_name `in` $root attr-dict `:` "
|
||||
"functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def PrintOp : TransformDialectOp<"print",
|
||||
[DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
|
||||
@ -753,61 +722,6 @@ def SequenceOp : TransformDialectOp<"sequence",
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns",
|
||||
[DeclareOpInterfaceMethods<TransformOpInterface>, NoTerminator,
|
||||
OpAsmOpInterface, PossibleTopLevelTransformOpTrait,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
SymbolTable]> {
|
||||
let summary = "Contains PDL patterns available for use in transforms";
|
||||
let description = [{
|
||||
This op contains a set of named PDL patterns that are available for the
|
||||
Transform dialect operations to be used for pattern matching. For example,
|
||||
PDLMatchOp can be used to produce a Transform IR value associated with all
|
||||
Payload IR operations that match the pattern as follows:
|
||||
|
||||
```mlir
|
||||
transform.with_pdl_patterns {
|
||||
^bb0(%arg0: !transform.any_op):
|
||||
pdl.pattern @my_pattern : benefit(1) {
|
||||
%0 = pdl.operation //...
|
||||
// Regular PDL goes here.
|
||||
pdl.rewrite %0 with "transform.dialect"
|
||||
}
|
||||
|
||||
sequence %arg0 failures(propagate) {
|
||||
^bb0(%arg1: !transform.any_op):
|
||||
%1 = pdl_match @my_pattern in %arg1
|
||||
// Use %1 as handle
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Note that the pattern is expected to finish with a `pdl.rewrite` terminator
|
||||
that points to the custom rewriter named "transform.dialect". The rewriter
|
||||
actually does nothing, but the transform application will keep track of the
|
||||
operations that matched the pattern.
|
||||
|
||||
This op is expected to contain `pdl.pattern` operations and exactly one
|
||||
another Transform dialect operation that gets executed with all patterns
|
||||
available. This op is a possible top-level Transform IR op, the argument of
|
||||
its entry block corresponds to either the root op of the payload IR or the
|
||||
ops associated with its operand when provided.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Arg<Optional<TransformHandleTypeInterface>, "Root operation of the Payload IR"
|
||||
>:$root);
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
let assemblyFormat = "($root^ `:` type($root))? attr-dict-with-keyword regions";
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Allow the dialect prefix to be omitted.
|
||||
static StringRef getDefaultDialect() { return "transform"; }
|
||||
}];
|
||||
}
|
||||
|
||||
def YieldOp : TransformDialectOp<"yield",
|
||||
[Terminator, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
|
||||
let summary = "Yields operation handles from a transform IR region";
|
||||
|
@ -0,0 +1,6 @@
|
||||
set(LLVM_TARGET_DEFINITIONS PDLExtensionOps.td)
|
||||
mlir_tablegen(PDLExtensionOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(PDLExtensionOps.cpp.inc -gen-op-defs)
|
||||
add_public_tablegen_target(MLIRTransformDialectPDLExtensionOpsIncGen)
|
||||
|
||||
add_mlir_doc(PDLExtensionOps PDLExtensionOps Dialects/ -gen-op-doc)
|
@ -0,0 +1,16 @@
|
||||
//===- PDLExtension.h - PDL extension for Transform dialect -----*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace mlir {
|
||||
class DialectRegistry;
|
||||
|
||||
namespace transform {
|
||||
/// Registers the PDL extension of the Transform dialect in the given registry.
|
||||
void registerPDLExtension(DialectRegistry &dialectRegistry);
|
||||
} // namespace transform
|
||||
} // namespace mlir
|
@ -0,0 +1,49 @@
|
||||
//===- PDLExtensionOps.h - PDL extension for Transform dialect --*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS_H
|
||||
#define MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS_H
|
||||
|
||||
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace transform {
|
||||
/// PDL constraint callbacks that can be used by the PDL extension of the
|
||||
/// Transform dialect. These are owned by the Transform dialect and can be
|
||||
/// populated by extensions.
|
||||
class PDLMatchHooks : public TransformDialectData<PDLMatchHooks> {
|
||||
public:
|
||||
/// Takes ownership of the named PDL constraint function from the given
|
||||
/// map and makes them available for use by the operations in the dialect.
|
||||
void
|
||||
mergeInPDLMatchHooks(llvm::StringMap<PDLConstraintFunction> &&constraintFns);
|
||||
|
||||
/// Returns the named PDL constraint functions available in the dialect
|
||||
/// as a map from their name to the function.
|
||||
const llvm::StringMap<::mlir::PDLConstraintFunction> &
|
||||
getPDLConstraintHooks() const;
|
||||
|
||||
private:
|
||||
/// A container for PDL constraint function that can be used by
|
||||
/// operations in this dialect.
|
||||
PDLPatternModule pdlMatchHooks;
|
||||
};
|
||||
} // namespace transform
|
||||
} // namespace mlir
|
||||
|
||||
MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::transform::PDLMatchHooks)
|
||||
|
||||
#endif // MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS_H
|
@ -0,0 +1,104 @@
|
||||
//===- TransformOps.td - Transform dialect operations ------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS
|
||||
#define MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS
|
||||
|
||||
include "mlir/Dialect/Transform/IR/TransformDialect.td"
|
||||
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/IR/OpAsmInterface.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
|
||||
def PDLMatchOp : TransformDialectOp<"pdl_match",
|
||||
[DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
|
||||
let summary = "Finds ops that match the named PDL pattern";
|
||||
let description = [{
|
||||
Find Payload IR ops nested within the Payload IR op associated with the
|
||||
operand that match the PDL pattern identified by its name. The pattern is
|
||||
expected to be defined in the closest surrounding `WithPDLPatternsOp`.
|
||||
|
||||
Produces a Transform IR value associated with the list of Payload IR ops
|
||||
that matched the pattern. The order of results in the list is that of the
|
||||
Operation::walk, clients are advised not to rely on a specific order though.
|
||||
If the operand is associated with multiple Payload IR ops, finds matching
|
||||
ops nested within each of those and produces a single list containing all
|
||||
of the matched ops.
|
||||
|
||||
The transformation is considered successful regardless of whether some
|
||||
Payload IR ops actually matched the pattern and only fails if the pattern
|
||||
could not be looked up or compiled.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Arg<TransformHandleTypeInterface, "Payload IR scope to match within">:$root,
|
||||
SymbolRefAttr:$pattern_name);
|
||||
let results = (outs
|
||||
Res<TransformHandleTypeInterface, "Handle to the matched Payload IR ops">:$matched);
|
||||
|
||||
let assemblyFormat = "$pattern_name `in` $root attr-dict `:` "
|
||||
"functional-type(operands, results)";
|
||||
}
|
||||
|
||||
def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns",
|
||||
[DeclareOpInterfaceMethods<TransformOpInterface>, NoTerminator,
|
||||
OpAsmOpInterface, PossibleTopLevelTransformOpTrait,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
SymbolTable]> {
|
||||
let summary = "Contains PDL patterns available for use in transforms";
|
||||
let description = [{
|
||||
This op contains a set of named PDL patterns that are available for the
|
||||
Transform dialect operations to be used for pattern matching. For example,
|
||||
PDLMatchOp can be used to produce a Transform IR value associated with all
|
||||
Payload IR operations that match the pattern as follows:
|
||||
|
||||
```mlir
|
||||
transform.with_pdl_patterns {
|
||||
^bb0(%arg0: !transform.any_op):
|
||||
pdl.pattern @my_pattern : benefit(1) {
|
||||
%0 = pdl.operation //...
|
||||
// Regular PDL goes here.
|
||||
pdl.rewrite %0 with "transform.dialect"
|
||||
}
|
||||
|
||||
sequence %arg0 failures(propagate) {
|
||||
^bb0(%arg1: !transform.any_op):
|
||||
%1 = pdl_match @my_pattern in %arg1
|
||||
// Use %1 as handle
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Note that the pattern is expected to finish with a `pdl.rewrite` terminator
|
||||
that points to the custom rewriter named "transform.dialect". The rewriter
|
||||
actually does nothing, but the transform application will keep track of the
|
||||
operations that matched the pattern.
|
||||
|
||||
This op is expected to contain `pdl.pattern` operations and exactly one
|
||||
another Transform dialect operation that gets executed with all patterns
|
||||
available. This op is a possible top-level Transform IR op, the argument of
|
||||
its entry block corresponds to either the root op of the payload IR or the
|
||||
ops associated with its operand when provided.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Arg<Optional<TransformHandleTypeInterface>, "Root operation of the Payload IR"
|
||||
>:$root);
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
let assemblyFormat = "($root^ `:` type($root))? attr-dict-with-keyword regions";
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Allow the dialect prefix to be omitted.
|
||||
static StringRef getDefaultDialect() { return "transform"; }
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS
|
@ -76,6 +76,7 @@
|
||||
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
||||
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
|
||||
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
@ -135,6 +136,7 @@ inline void registerAllDialects(DialectRegistry ®istry) {
|
||||
memref::registerTransformDialectExtension(registry);
|
||||
scf::registerTransformDialectExtension(registry);
|
||||
tensor::registerTransformDialectExtension(registry);
|
||||
transform::registerPDLExtension(registry);
|
||||
vector::registerTransformDialectExtension(registry);
|
||||
|
||||
// Register all external models.
|
||||
|
@ -1,3 +1,4 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(PDLExtension)
|
||||
add_subdirectory(Transforms)
|
||||
add_subdirectory(Utils)
|
||||
|
@ -14,8 +14,6 @@ add_mlir_dialect_library(MLIRTransformDialect
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRParser
|
||||
MLIRPDLDialect
|
||||
MLIRPDLInterpDialect
|
||||
MLIRRewrite
|
||||
MLIRSideEffectInterfaces
|
||||
MLIRTransforms
|
||||
|
@ -8,8 +8,6 @@
|
||||
|
||||
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
||||
#include "mlir/Analysis/CallGraph.h"
|
||||
#include "mlir/Dialect/PDL/IR/PDL.h"
|
||||
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformOps.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
|
||||
@ -51,18 +49,6 @@ void transform::detail::checkImplementsTransformHandleTypeInterface(
|
||||
}
|
||||
#endif // NDEBUG
|
||||
|
||||
namespace {
|
||||
struct PDLOperationTypeTransformHandleTypeInterfaceImpl
|
||||
: public transform::TransformHandleTypeInterface::ExternalModel<
|
||||
PDLOperationTypeTransformHandleTypeInterfaceImpl,
|
||||
pdl::OperationType> {
|
||||
DiagnosedSilenceableFailure
|
||||
checkPayload(Type type, Location loc, ArrayRef<Operation *> payload) const {
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void transform::TransformDialect::initialize() {
|
||||
// Using the checked versions to enable the same assertions as for the ops
|
||||
// from extensions.
|
||||
@ -71,21 +57,6 @@ void transform::TransformDialect::initialize() {
|
||||
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
|
||||
>();
|
||||
initializeTypes();
|
||||
|
||||
pdl::OperationType::attachInterface<
|
||||
PDLOperationTypeTransformHandleTypeInterfaceImpl>(*getContext());
|
||||
}
|
||||
|
||||
void transform::TransformDialect::mergeInPDLMatchHooks(
|
||||
llvm::StringMap<PDLConstraintFunction> &&constraintFns) {
|
||||
// Steal the constraint functions from the given map.
|
||||
for (auto &it : constraintFns)
|
||||
pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second));
|
||||
}
|
||||
|
||||
const llvm::StringMap<PDLConstraintFunction> &
|
||||
transform::TransformDialect::getPDLConstraintHooks() const {
|
||||
return pdlMatchHooks.getConstraintFunctions();
|
||||
}
|
||||
|
||||
Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
|
||||
|
@ -1242,6 +1242,61 @@ void transform::detail::forwardTerminatorOperands(
|
||||
// Utilities for PossibleTopLevelTransformOpTrait.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Appends to `effects` the memory effect instances on `target` with the same
|
||||
/// resource and effect as the ones the operation `iface` having on `source`.
|
||||
static void
|
||||
remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target,
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
SmallVector<MemoryEffects::EffectInstance> nestedEffects;
|
||||
iface.getEffectsOnValue(source, nestedEffects);
|
||||
for (const auto &effect : nestedEffects)
|
||||
effects.emplace_back(effect.getEffect(), target, effect.getResource());
|
||||
}
|
||||
|
||||
/// Appends to `effects` the same effects as the operations of `block` have on
|
||||
/// block arguments but associated with `operands.`
|
||||
static void
|
||||
remapArgumentEffects(Block &block, ValueRange operands,
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
for (Operation &op : block) {
|
||||
auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
|
||||
if (!iface)
|
||||
continue;
|
||||
|
||||
for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) {
|
||||
remapEffects(iface, source, target, effects);
|
||||
}
|
||||
|
||||
SmallVector<MemoryEffects::EffectInstance> nestedEffects;
|
||||
iface.getEffectsOnResource(transform::PayloadIRResource::get(),
|
||||
nestedEffects);
|
||||
llvm::append_range(effects, nestedEffects);
|
||||
}
|
||||
}
|
||||
|
||||
void transform::detail::getPotentialTopLevelEffects(
|
||||
Operation *operation, Value root, Block &body,
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
transform::onlyReadsHandle(operation->getOperands(), effects);
|
||||
transform::producesHandle(operation->getResults(), effects);
|
||||
|
||||
if (!root) {
|
||||
for (Operation &op : body) {
|
||||
auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
|
||||
if (!iface)
|
||||
continue;
|
||||
|
||||
SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
|
||||
iface.getEffects(effects);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Carry over all effects on arguments of the entry block as those on the
|
||||
// operands, this is the same value just remapped.
|
||||
remapArgumentEffects(body, operation->getOperands(), effects);
|
||||
}
|
||||
|
||||
LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
|
||||
TransformState &state, Operation *op, Region ®ion) {
|
||||
SmallVector<Operation *> targets;
|
||||
|
@ -7,7 +7,6 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Transform/IR/TransformOps.h"
|
||||
#include "mlir/Dialect/PDL/IR/PDLOps.h"
|
||||
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
||||
@ -17,8 +16,6 @@
|
||||
#include "mlir/IR/FunctionImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
|
||||
#include "mlir/Rewrite/PatternApplicator.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/ScopeExit.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
@ -52,99 +49,6 @@ static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PatternApplicatorExtension
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// A TransformState extension that keeps track of compiled PDL pattern sets.
|
||||
/// This is intended to be used along the WithPDLPatterns op. The extension
|
||||
/// can be constructed given an operation that has a SymbolTable trait and
|
||||
/// contains pdl::PatternOp instances. The patterns are compiled lazily and one
|
||||
/// by one when requested; this behavior is subject to change.
|
||||
class PatternApplicatorExtension : public transform::TransformState::Extension {
|
||||
public:
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
|
||||
|
||||
/// Creates the extension for patterns contained in `patternContainer`.
|
||||
explicit PatternApplicatorExtension(transform::TransformState &state,
|
||||
Operation *patternContainer)
|
||||
: Extension(state), patterns(patternContainer) {}
|
||||
|
||||
/// Appends to `results` the operations contained in `root` that matched the
|
||||
/// PDL pattern with the given name. Note that `root` may or may not be the
|
||||
/// operation that contains PDL patterns. Reports an error if the pattern
|
||||
/// cannot be found. Note that when no operations are matched, this still
|
||||
/// succeeds as long as the pattern exists.
|
||||
LogicalResult findAllMatches(StringRef patternName, Operation *root,
|
||||
SmallVectorImpl<Operation *> &results);
|
||||
|
||||
private:
|
||||
/// Map from the pattern name to a singleton set of rewrite patterns that only
|
||||
/// contains the pattern with this name. Populated when the pattern is first
|
||||
/// requested.
|
||||
// TODO: reconsider the efficiency of this storage when more usage data is
|
||||
// available. Storing individual patterns in a set and triggering compilation
|
||||
// for each of them has overhead. So does compiling a large set of patterns
|
||||
// only to apply a handlful of them.
|
||||
llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
|
||||
|
||||
/// A symbol table operation containing the relevant PDL patterns.
|
||||
SymbolTable patterns;
|
||||
};
|
||||
|
||||
LogicalResult PatternApplicatorExtension::findAllMatches(
|
||||
StringRef patternName, Operation *root,
|
||||
SmallVectorImpl<Operation *> &results) {
|
||||
auto it = compiledPatterns.find(patternName);
|
||||
if (it == compiledPatterns.end()) {
|
||||
auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
|
||||
if (!patternOp)
|
||||
return failure();
|
||||
|
||||
// Copy the pattern operation into a new module that is compiled and
|
||||
// consumed by the PDL interpreter.
|
||||
OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
|
||||
auto builder = OpBuilder::atBlockEnd(pdlModuleOp->getBody());
|
||||
builder.clone(*patternOp);
|
||||
PDLPatternModule patternModule(std::move(pdlModuleOp));
|
||||
|
||||
// Merge in the hooks owned by the dialect. Make a copy as they may be
|
||||
// also used by the following operations.
|
||||
auto *dialect =
|
||||
root->getContext()->getLoadedDialect<transform::TransformDialect>();
|
||||
for (const auto &[name, constraintFn] : dialect->getPDLConstraintHooks())
|
||||
patternModule.registerConstraintFunction(name, constraintFn);
|
||||
|
||||
// Register a noop rewriter because PDL requires patterns to end with some
|
||||
// rewrite call.
|
||||
patternModule.registerRewriteFunction(
|
||||
"transform.dialect", [](PatternRewriter &, Operation *) {});
|
||||
|
||||
it = compiledPatterns
|
||||
.try_emplace(patternOp.getName(), std::move(patternModule))
|
||||
.first;
|
||||
}
|
||||
|
||||
PatternApplicator applicator(it->second);
|
||||
// We want to discourage direct use of PatternRewriter in APIs but In this
|
||||
// very specific case, an IRRewriter is not enough.
|
||||
struct TrivialPatternRewriter : public PatternRewriter {
|
||||
public:
|
||||
explicit TrivialPatternRewriter(MLIRContext *context)
|
||||
: PatternRewriter(context) {}
|
||||
};
|
||||
TrivialPatternRewriter rewriter(root->getContext());
|
||||
applicator.applyDefaultCostModel();
|
||||
root->walk([&](Operation *op) {
|
||||
if (succeeded(applicator.matchAndRewrite(op, rewriter)))
|
||||
results.push_back(op);
|
||||
});
|
||||
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TrackingListener
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -420,10 +324,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
||||
assert(outputs.size() == 1 && "expected one output");
|
||||
return llvm::all_of(
|
||||
std::initializer_list<Type>{inputs.front(), outputs.front()},
|
||||
[](Type ty) {
|
||||
return llvm::isa<pdl::OperationType,
|
||||
transform::TransformHandleTypeInterface>(ty);
|
||||
});
|
||||
[](Type ty) { return isa<transform::TransformHandleTypeInterface>(ty); });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1031,38 +932,6 @@ transform::IncludeOp::apply(transform::TransformResults &results,
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Appends to `effects` the memory effect instances on `target` with the same
|
||||
/// resource and effect as the ones the operation `iface` having on `source`.
|
||||
static void
|
||||
remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target,
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
SmallVector<MemoryEffects::EffectInstance> nestedEffects;
|
||||
iface.getEffectsOnValue(source, nestedEffects);
|
||||
for (const auto &effect : nestedEffects)
|
||||
effects.emplace_back(effect.getEffect(), target, effect.getResource());
|
||||
}
|
||||
|
||||
/// Appends to `effects` the same effects as the operations of `block` have on
|
||||
/// block arguments but associated with `operands.`
|
||||
static void
|
||||
remapArgumentEffects(Block &block, ValueRange operands,
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
for (Operation &op : block) {
|
||||
auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
|
||||
if (!iface)
|
||||
continue;
|
||||
|
||||
for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) {
|
||||
remapEffects(iface, source, target, effects);
|
||||
}
|
||||
|
||||
SmallVector<MemoryEffects::EffectInstance> nestedEffects;
|
||||
iface.getEffectsOnResource(transform::PayloadIRResource::get(),
|
||||
nestedEffects);
|
||||
llvm::append_range(effects, nestedEffects);
|
||||
}
|
||||
}
|
||||
|
||||
static DiagnosedSilenceableFailure
|
||||
verifyNamedSequenceOp(transform::NamedSequenceOp op);
|
||||
|
||||
@ -1474,8 +1343,7 @@ LogicalResult transform::NamedSequenceOp::verify() {
|
||||
void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
|
||||
Value target, int64_t numResultHandles) {
|
||||
result.addOperands(target);
|
||||
auto pdlOpType = pdl::OperationType::get(builder.getContext());
|
||||
result.addTypes(SmallVector<pdl::OperationType>(numResultHandles, pdlOpType));
|
||||
result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
|
||||
}
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
@ -1535,35 +1403,6 @@ LogicalResult transform::SplitHandleOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PDLMatchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
transform::PDLMatchOp::apply(transform::TransformResults &results,
|
||||
transform::TransformState &state) {
|
||||
auto *extension = state.getExtension<PatternApplicatorExtension>();
|
||||
assert(extension &&
|
||||
"expected PatternApplicatorExtension to be attached by the parent op");
|
||||
SmallVector<Operation *> targets;
|
||||
for (Operation *root : state.getPayloadOps(getRoot())) {
|
||||
if (failed(extension->findAllMatches(
|
||||
getPatternName().getLeafReference().getValue(), root, targets))) {
|
||||
emitDefiniteFailure()
|
||||
<< "could not find pattern '" << getPatternName() << "'";
|
||||
}
|
||||
}
|
||||
results.set(llvm::cast<OpResult>(getResult()), targets);
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
void transform::PDLMatchOp::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
onlyReadsHandle(getRoot(), effects);
|
||||
producesHandle(getMatched(), effects);
|
||||
onlyReadsPayload(effects);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReplicateOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1776,37 +1615,9 @@ LogicalResult transform::SequenceOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Populate `effects` with transform dialect memory effects for the potential
|
||||
/// top-level operation. Such operations have recursive effects from nested
|
||||
/// operations. When they have an operand, we can additionally remap effects on
|
||||
/// the block argument to be effects on the operand.
|
||||
template <typename OpTy>
|
||||
static void getPotentialTopLevelEffects(
|
||||
OpTy operation, SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
transform::onlyReadsHandle(operation->getOperands(), effects);
|
||||
transform::producesHandle(operation->getResults(), effects);
|
||||
|
||||
if (!operation.getRoot()) {
|
||||
for (Operation &op : *operation.getBodyBlock()) {
|
||||
auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
|
||||
if (!iface)
|
||||
continue;
|
||||
|
||||
SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
|
||||
iface.getEffects(effects);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Carry over all effects on arguments of the entry block as those on the
|
||||
// operands, this is the same value just remapped.
|
||||
remapArgumentEffects(*operation.getBodyBlock(), operation->getOperands(),
|
||||
effects);
|
||||
}
|
||||
|
||||
void transform::SequenceOp::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
getPotentialTopLevelEffects(*this, effects);
|
||||
getPotentialTopLevelEffects(effects);
|
||||
}
|
||||
|
||||
OperandRange transform::SequenceOp::getSuccessorEntryOperands(
|
||||
@ -1908,77 +1719,6 @@ void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
|
||||
buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// WithPDLPatternsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
|
||||
transform::TransformState &state) {
|
||||
TransformOpInterface transformOp = nullptr;
|
||||
for (Operation &nested : getBody().front()) {
|
||||
if (!isa<pdl::PatternOp>(nested)) {
|
||||
transformOp = cast<TransformOpInterface>(nested);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
state.addExtension<PatternApplicatorExtension>(getOperation());
|
||||
auto guard = llvm::make_scope_exit(
|
||||
[&]() { state.removeExtension<PatternApplicatorExtension>(); });
|
||||
|
||||
auto scope = state.make_region_scope(getBody());
|
||||
if (failed(mapBlockArguments(state)))
|
||||
return DiagnosedSilenceableFailure::definiteFailure();
|
||||
return state.applyTransform(transformOp);
|
||||
}
|
||||
|
||||
void transform::WithPDLPatternsOp::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
getPotentialTopLevelEffects(*this, effects);
|
||||
}
|
||||
|
||||
LogicalResult transform::WithPDLPatternsOp::verify() {
|
||||
Block *body = getBodyBlock();
|
||||
Operation *topLevelOp = nullptr;
|
||||
for (Operation &op : body->getOperations()) {
|
||||
if (isa<pdl::PatternOp>(op))
|
||||
continue;
|
||||
|
||||
if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
|
||||
if (topLevelOp) {
|
||||
InFlightDiagnostic diag =
|
||||
emitOpError() << "expects only one non-pattern op in its body";
|
||||
diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
|
||||
diag.attachNote(op.getLoc()) << "second non-pattern op";
|
||||
return diag;
|
||||
}
|
||||
topLevelOp = &op;
|
||||
continue;
|
||||
}
|
||||
|
||||
InFlightDiagnostic diag =
|
||||
emitOpError()
|
||||
<< "expects only pattern and top-level transform ops in its body";
|
||||
diag.attachNote(op.getLoc()) << "offending op";
|
||||
return diag;
|
||||
}
|
||||
|
||||
if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
|
||||
InFlightDiagnostic diag = emitOpError() << "cannot be nested";
|
||||
diag.attachNote(parent.getLoc()) << "parent operation";
|
||||
return diag;
|
||||
}
|
||||
|
||||
if (!topLevelOp) {
|
||||
InFlightDiagnostic diag = emitOpError()
|
||||
<< "expects at least one non-pattern op";
|
||||
return diag;
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PrintOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
13
mlir/lib/Dialect/Transform/PDLExtension/CMakeLists.txt
Normal file
13
mlir/lib/Dialect/Transform/PDLExtension/CMakeLists.txt
Normal file
@ -0,0 +1,13 @@
|
||||
add_mlir_dialect_library(MLIRTransformPDLExtension
|
||||
PDLExtension.cpp
|
||||
PDLExtensionOps.cpp
|
||||
|
||||
DEPENDS
|
||||
MLIRTransformDialectPDLExtensionOpsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRTransformDialect
|
||||
MLIRPDLDialect
|
||||
MLIRPDLInterpDialect
|
||||
)
|
69
mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp
Normal file
69
mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp
Normal file
@ -0,0 +1,69 @@
|
||||
//===- PDLExtension.cpp - PDL extension for the Transform dialect ---------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
|
||||
#include "mlir/Dialect/PDL/IR/PDL.h"
|
||||
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
|
||||
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
||||
#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h"
|
||||
#include "mlir/IR/DialectRegistry.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
/// Implementation of the TransformHandleTypeInterface for the PDL
|
||||
/// OperationType. Accepts any payload operation.
|
||||
struct PDLOperationTypeTransformHandleTypeInterfaceImpl
|
||||
: public transform::TransformHandleTypeInterface::ExternalModel<
|
||||
PDLOperationTypeTransformHandleTypeInterfaceImpl,
|
||||
pdl::OperationType> {
|
||||
|
||||
/// Accept any operation.
|
||||
DiagnosedSilenceableFailure
|
||||
checkPayload(Type type, Location loc, ArrayRef<Operation *> payload) const {
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
/// PDL extension of the Transform dialect. This provides transform operations
|
||||
/// that connect to PDL matching as well as interfaces for PDL types to be used
|
||||
/// with Transform dialect operations.
|
||||
class PDLExtension : public transform::TransformDialectExtension<PDLExtension> {
|
||||
public:
|
||||
void init() {
|
||||
registerTransformOps<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc"
|
||||
>();
|
||||
|
||||
addDialectDataInitializer<transform::PDLMatchHooks>(
|
||||
[](transform::PDLMatchHooks &) {});
|
||||
|
||||
// Declare PDL as dependent so we can attach an interface to its type in the
|
||||
// later step.
|
||||
declareDependentDialect<pdl::PDLDialect>();
|
||||
|
||||
// PDLInterp is only relevant if we actually apply the transform IR so
|
||||
// declare it as generated.
|
||||
declareGeneratedDialect<pdl_interp::PDLInterpDialect>();
|
||||
|
||||
// Make PDL OperationType usable as a transform dialect type.
|
||||
addCustomInitializationStep([](MLIRContext *context) {
|
||||
pdl::OperationType::attachInterface<
|
||||
PDLOperationTypeTransformHandleTypeInterfaceImpl>(*context);
|
||||
});
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::transform::registerPDLExtension(DialectRegistry &dialectRegistry) {
|
||||
dialectRegistry.addExtensions<PDLExtension>();
|
||||
}
|
234
mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp
Normal file
234
mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp
Normal file
@ -0,0 +1,234 @@
|
||||
//===- PDLExtensionOps.cpp - PDL extension for the Transform dialect ------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h"
|
||||
#include "mlir/Dialect/PDL/IR/PDLOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
|
||||
#include "mlir/Rewrite/PatternApplicator.h"
|
||||
#include "llvm/ADT/ScopeExit.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::transform::PDLMatchHooks)
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PatternApplicatorExtension
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// A TransformState extension that keeps track of compiled PDL pattern sets.
|
||||
/// This is intended to be used along the WithPDLPatterns op. The extension
|
||||
/// can be constructed given an operation that has a SymbolTable trait and
|
||||
/// contains pdl::PatternOp instances. The patterns are compiled lazily and one
|
||||
/// by one when requested; this behavior is subject to change.
|
||||
class PatternApplicatorExtension : public transform::TransformState::Extension {
|
||||
public:
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
|
||||
|
||||
/// Creates the extension for patterns contained in `patternContainer`.
|
||||
explicit PatternApplicatorExtension(transform::TransformState &state,
|
||||
Operation *patternContainer)
|
||||
: Extension(state), patterns(patternContainer) {}
|
||||
|
||||
/// Appends to `results` the operations contained in `root` that matched the
|
||||
/// PDL pattern with the given name. Note that `root` may or may not be the
|
||||
/// operation that contains PDL patterns. Reports an error if the pattern
|
||||
/// cannot be found. Note that when no operations are matched, this still
|
||||
/// succeeds as long as the pattern exists.
|
||||
LogicalResult findAllMatches(StringRef patternName, Operation *root,
|
||||
SmallVectorImpl<Operation *> &results);
|
||||
|
||||
private:
|
||||
/// Map from the pattern name to a singleton set of rewrite patterns that only
|
||||
/// contains the pattern with this name. Populated when the pattern is first
|
||||
/// requested.
|
||||
// TODO: reconsider the efficiency of this storage when more usage data is
|
||||
// available. Storing individual patterns in a set and triggering compilation
|
||||
// for each of them has overhead. So does compiling a large set of patterns
|
||||
// only to apply a handful of them.
|
||||
llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
|
||||
|
||||
/// A symbol table operation containing the relevant PDL patterns.
|
||||
SymbolTable patterns;
|
||||
};
|
||||
|
||||
LogicalResult PatternApplicatorExtension::findAllMatches(
|
||||
StringRef patternName, Operation *root,
|
||||
SmallVectorImpl<Operation *> &results) {
|
||||
auto it = compiledPatterns.find(patternName);
|
||||
if (it == compiledPatterns.end()) {
|
||||
auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
|
||||
if (!patternOp)
|
||||
return failure();
|
||||
|
||||
// Copy the pattern operation into a new module that is compiled and
|
||||
// consumed by the PDL interpreter.
|
||||
OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
|
||||
auto builder = OpBuilder::atBlockEnd(pdlModuleOp->getBody());
|
||||
builder.clone(*patternOp);
|
||||
PDLPatternModule patternModule(std::move(pdlModuleOp));
|
||||
|
||||
// Merge in the hooks owned by the dialect. Make a copy as they may be
|
||||
// also used by the following operations.
|
||||
auto *dialect =
|
||||
root->getContext()->getLoadedDialect<transform::TransformDialect>();
|
||||
for (const auto &[name, constraintFn] :
|
||||
dialect->getExtraData<transform::PDLMatchHooks>()
|
||||
.getPDLConstraintHooks()) {
|
||||
patternModule.registerConstraintFunction(name, constraintFn);
|
||||
}
|
||||
|
||||
// Register a noop rewriter because PDL requires patterns to end with some
|
||||
// rewrite call.
|
||||
patternModule.registerRewriteFunction(
|
||||
"transform.dialect", [](PatternRewriter &, Operation *) {});
|
||||
|
||||
it = compiledPatterns
|
||||
.try_emplace(patternOp.getName(), std::move(patternModule))
|
||||
.first;
|
||||
}
|
||||
|
||||
PatternApplicator applicator(it->second);
|
||||
// We want to discourage direct use of PatternRewriter in APIs but In this
|
||||
// very specific case, an IRRewriter is not enough.
|
||||
struct TrivialPatternRewriter : public PatternRewriter {
|
||||
public:
|
||||
explicit TrivialPatternRewriter(MLIRContext *context)
|
||||
: PatternRewriter(context) {}
|
||||
};
|
||||
TrivialPatternRewriter rewriter(root->getContext());
|
||||
applicator.applyDefaultCostModel();
|
||||
root->walk([&](Operation *op) {
|
||||
if (succeeded(applicator.matchAndRewrite(op, rewriter)))
|
||||
results.push_back(op);
|
||||
});
|
||||
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PDLMatchHooks
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void transform::PDLMatchHooks::mergeInPDLMatchHooks(
|
||||
llvm::StringMap<PDLConstraintFunction> &&constraintFns) {
|
||||
// Steal the constraint functions from the given map.
|
||||
for (auto &it : constraintFns)
|
||||
pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second));
|
||||
}
|
||||
|
||||
const llvm::StringMap<PDLConstraintFunction> &
|
||||
transform::PDLMatchHooks::getPDLConstraintHooks() const {
|
||||
return pdlMatchHooks.getConstraintFunctions();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PDLMatchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
transform::PDLMatchOp::apply(transform::TransformResults &results,
|
||||
transform::TransformState &state) {
|
||||
auto *extension = state.getExtension<PatternApplicatorExtension>();
|
||||
assert(extension &&
|
||||
"expected PatternApplicatorExtension to be attached by the parent op");
|
||||
SmallVector<Operation *> targets;
|
||||
for (Operation *root : state.getPayloadOps(getRoot())) {
|
||||
if (failed(extension->findAllMatches(
|
||||
getPatternName().getLeafReference().getValue(), root, targets))) {
|
||||
emitDefiniteFailure()
|
||||
<< "could not find pattern '" << getPatternName() << "'";
|
||||
}
|
||||
}
|
||||
results.set(llvm::cast<OpResult>(getResult()), targets);
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
void transform::PDLMatchOp::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
onlyReadsHandle(getRoot(), effects);
|
||||
producesHandle(getMatched(), effects);
|
||||
onlyReadsPayload(effects);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// WithPDLPatternsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
|
||||
transform::TransformState &state) {
|
||||
TransformOpInterface transformOp = nullptr;
|
||||
for (Operation &nested : getBody().front()) {
|
||||
if (!isa<pdl::PatternOp>(nested)) {
|
||||
transformOp = cast<TransformOpInterface>(nested);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
state.addExtension<PatternApplicatorExtension>(getOperation());
|
||||
auto guard = llvm::make_scope_exit(
|
||||
[&]() { state.removeExtension<PatternApplicatorExtension>(); });
|
||||
|
||||
auto scope = state.make_region_scope(getBody());
|
||||
if (failed(mapBlockArguments(state)))
|
||||
return DiagnosedSilenceableFailure::definiteFailure();
|
||||
return state.applyTransform(transformOp);
|
||||
}
|
||||
|
||||
void transform::WithPDLPatternsOp::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
getPotentialTopLevelEffects(effects);
|
||||
}
|
||||
|
||||
LogicalResult transform::WithPDLPatternsOp::verify() {
|
||||
Block *body = getBodyBlock();
|
||||
Operation *topLevelOp = nullptr;
|
||||
for (Operation &op : body->getOperations()) {
|
||||
if (isa<pdl::PatternOp>(op))
|
||||
continue;
|
||||
|
||||
if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
|
||||
if (topLevelOp) {
|
||||
InFlightDiagnostic diag =
|
||||
emitOpError() << "expects only one non-pattern op in its body";
|
||||
diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
|
||||
diag.attachNote(op.getLoc()) << "second non-pattern op";
|
||||
return diag;
|
||||
}
|
||||
topLevelOp = &op;
|
||||
continue;
|
||||
}
|
||||
|
||||
InFlightDiagnostic diag =
|
||||
emitOpError()
|
||||
<< "expects only pattern and top-level transform ops in its body";
|
||||
diag.attachNote(op.getLoc()) << "offending op";
|
||||
return diag;
|
||||
}
|
||||
|
||||
if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
|
||||
InFlightDiagnostic diag = emitOpError() << "cannot be nested";
|
||||
diag.attachNote(parent.getLoc()) << "parent operation";
|
||||
return diag;
|
||||
}
|
||||
|
||||
if (!topLevelOp) {
|
||||
InFlightDiagnostic diag = emitOpError()
|
||||
<< "expects at least one non-pattern op";
|
||||
return diag;
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
@ -114,6 +114,16 @@ declare_mlir_dialect_python_bindings(
|
||||
DIALECT_NAME linalg
|
||||
DEPENDS LinalgOdsGen)
|
||||
|
||||
declare_mlir_dialect_extension_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
TD_FILE dialects/TransformPDLExtensionOps.td
|
||||
SOURCES
|
||||
dialects/_transform_pdl_extension_ops_ext.py
|
||||
dialects/transform/pdl.py
|
||||
DIALECT_NAME transform
|
||||
EXTENSION_NAME transform_pdl_extension)
|
||||
|
||||
declare_mlir_dialect_python_bindings(
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects
|
||||
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
|
||||
|
20
mlir/python/mlir/dialects/TransformPDLExtensionOps.td
Normal file
20
mlir/python/mlir/dialects/TransformPDLExtensionOps.td
Normal file
@ -0,0 +1,20 @@
|
||||
//===-- TransformPDLExtensionOps.td - Binding entry point --*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Entry point of the generated Python bindings for the PDL extension of the
|
||||
// Transform dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS
|
||||
#define PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS
|
||||
|
||||
include "mlir/Bindings/Python/Attributes.td"
|
||||
include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td"
|
||||
|
||||
#endif // PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS
|
@ -60,26 +60,6 @@ class MergeHandlesOp:
|
||||
)
|
||||
|
||||
|
||||
class PDLMatchOp:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
pattern_name: Union[Attribute, str],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
super().__init__(
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
pattern_name,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class ReplicateOp:
|
||||
|
||||
def __init__(
|
||||
@ -152,28 +132,6 @@ class SequenceOp:
|
||||
return self.body.arguments[1:]
|
||||
|
||||
|
||||
class WithPDLPatternsOp:
|
||||
|
||||
def __init__(self,
|
||||
target: Union[Operation, Value, Type],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None):
|
||||
root = _get_op_result_or_value(target) if not isinstance(target,
|
||||
Type) else None
|
||||
root_type = target if isinstance(target, Type) else root.type
|
||||
super().__init__(root=root, loc=loc, ip=ip)
|
||||
self.regions[0].blocks.append(root_type)
|
||||
|
||||
@property
|
||||
def body(self) -> Block:
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
@property
|
||||
def bodyTarget(self) -> Value:
|
||||
return self.body.arguments[0]
|
||||
|
||||
|
||||
class YieldOp:
|
||||
|
||||
def __init__(
|
||||
|
@ -0,0 +1,55 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import (
|
||||
get_op_result_or_value as _get_op_result_or_value,
|
||||
get_op_results_or_values as _get_op_results_or_values,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
from typing import Union
|
||||
|
||||
class PDLMatchOp:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
result_type: Type,
|
||||
target: Union[Operation, Value],
|
||||
pattern_name: Union[Attribute, str],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
super().__init__(
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
pattern_name,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
class WithPDLPatternsOp:
|
||||
|
||||
def __init__(self,
|
||||
target: Union[Operation, Value, Type],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None):
|
||||
root = _get_op_result_or_value(target) if not isinstance(target,
|
||||
Type) else None
|
||||
root_type = target if isinstance(target, Type) else root.type
|
||||
super().__init__(root=root, loc=loc, ip=ip)
|
||||
self.regions[0].blocks.append(root_type)
|
||||
|
||||
@property
|
||||
def body(self) -> Block:
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
@property
|
||||
def bodyTarget(self) -> Value:
|
||||
return self.body.arguments[0]
|
5
mlir/python/mlir/dialects/transform/pdl.py
Normal file
5
mlir/python/mlir/dialects/transform/pdl.py
Normal file
@ -0,0 +1,5 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from .._transform_pdl_extension_ops_gen import *
|
@ -83,33 +83,6 @@ transform.sequence failures(propagate) {
|
||||
|
||||
// -----
|
||||
|
||||
transform.with_pdl_patterns {
|
||||
^bb0(%arg0: !transform.any_op):
|
||||
sequence %arg0 : !transform.any_op failures(propagate) {
|
||||
^bb0(%arg1: !transform.any_op):
|
||||
%0 = pdl_match @some in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
test_print_remark_at_operand %0, "matched" : !transform.any_op
|
||||
}
|
||||
|
||||
pdl.pattern @some : benefit(1) {
|
||||
%0 = pdl.operation "test.some_op"
|
||||
pdl.rewrite %0 with "transform.dialect"
|
||||
}
|
||||
|
||||
pdl.pattern @other : benefit(1) {
|
||||
%0 = pdl.operation "test.other_op"
|
||||
pdl.rewrite %0 with "transform.dialect"
|
||||
}
|
||||
}
|
||||
|
||||
// expected-remark @below {{matched}}
|
||||
"test.some_op"() : () -> ()
|
||||
"test.other_op"() : () -> ()
|
||||
// expected-remark @below {{matched}}
|
||||
"test.some_op"() : () -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-remark @below {{parent function}}
|
||||
func.func @foo() {
|
||||
%0 = arith.constant 0 : i32
|
||||
|
47
mlir/test/Dialect/Transform/test-pdl-extension.mlir
Normal file
47
mlir/test/Dialect/Transform/test-pdl-extension.mlir
Normal file
@ -0,0 +1,47 @@
|
||||
// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
|
||||
|
||||
transform.with_pdl_patterns {
|
||||
^bb0(%arg0: !transform.any_op):
|
||||
sequence %arg0 : !transform.any_op failures(propagate) {
|
||||
^bb0(%arg1: !transform.any_op):
|
||||
%0 = pdl_match @some in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
test_print_remark_at_operand %0, "matched" : !transform.any_op
|
||||
}
|
||||
|
||||
pdl.pattern @some : benefit(1) {
|
||||
%0 = pdl.operation "test.some_op"
|
||||
pdl.rewrite %0 with "transform.dialect"
|
||||
}
|
||||
|
||||
pdl.pattern @other : benefit(1) {
|
||||
%0 = pdl.operation "test.other_op"
|
||||
pdl.rewrite %0 with "transform.dialect"
|
||||
}
|
||||
}
|
||||
|
||||
// expected-remark @below {{matched}}
|
||||
"test.some_op"() : () -> ()
|
||||
"test.other_op"() : () -> ()
|
||||
// expected-remark @below {{matched}}
|
||||
"test.some_op"() : () -> ()
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
transform.with_pdl_patterns {
|
||||
^bb0(%arg0: !transform.any_op):
|
||||
sequence %arg0 : !transform.any_op failures(propagate) {
|
||||
^bb1(%arg1: !transform.any_op):
|
||||
%0 = pdl_match @some in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
}
|
||||
|
||||
pdl.pattern @some : benefit(1) {
|
||||
%0 = pdl.operation "test.some_op"
|
||||
pdl.apply_native_constraint "verbose_constraint"(%0 : !pdl.operation)
|
||||
pdl.rewrite %0 with "transform.dialect"
|
||||
}
|
||||
}
|
||||
|
||||
// expected-warning @below {{from PDL constraint}}
|
||||
"test.some_op"() : () -> ()
|
||||
"test.other_op"() : () -> ()
|
@ -21,4 +21,5 @@ add_mlir_library(MLIRTestTransformDialect
|
||||
MLIRPDLDialect
|
||||
MLIRTransformDialect
|
||||
MLIRTransformDialectTransforms
|
||||
MLIRTransformPDLExtension
|
||||
)
|
||||
|
@ -17,7 +17,9 @@
|
||||
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformOps.h"
|
||||
#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/Compiler.h"
|
||||
@ -754,6 +756,23 @@ public:
|
||||
#define GET_TYPEDEF_LIST
|
||||
#include "TestTransformDialectExtensionTypes.cpp.inc"
|
||||
>();
|
||||
|
||||
auto verboseConstraint = [](PatternRewriter &rewriter,
|
||||
ArrayRef<PDLValue> pdlValues) {
|
||||
for (const PDLValue &pdlValue : pdlValues) {
|
||||
if (Operation *op = pdlValue.dyn_cast<Operation *>()) {
|
||||
op->emitWarning() << "from PDL constraint";
|
||||
}
|
||||
}
|
||||
return success();
|
||||
};
|
||||
|
||||
addDialectDataInitializer<transform::PDLMatchHooks>(
|
||||
[&](transform::PDLMatchHooks &hooks) {
|
||||
llvm::StringMap<PDLConstraintFunction> constraints;
|
||||
constraints.try_emplace("verbose_constraint", verboseConstraint);
|
||||
hooks.mergeInPDLMatchHooks(std::move(constraints));
|
||||
});
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from mlir.ir import *
|
||||
from mlir.dialects import transform
|
||||
from mlir.dialects import pdl
|
||||
from mlir.dialects.transform import pdl as transform_pdl
|
||||
|
||||
|
||||
def run(f):
|
||||
@ -103,13 +103,13 @@ def testNestedSequenceOpWithExtras():
|
||||
|
||||
@run
|
||||
def testTransformPDLOps():
|
||||
withPdl = transform.WithPDLPatternsOp(transform.AnyOpType.get())
|
||||
withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
|
||||
with InsertionPoint(withPdl.body):
|
||||
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
|
||||
[transform.AnyOpType.get()],
|
||||
withPdl.bodyTarget)
|
||||
with InsertionPoint(sequence.body):
|
||||
match = transform.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher")
|
||||
match = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher")
|
||||
transform.YieldOp(match)
|
||||
# CHECK-LABEL: TEST: testTransformPDLOps
|
||||
# CHECK: transform.with_pdl_patterns {
|
||||
@ -148,13 +148,13 @@ def testMergeHandlesOp():
|
||||
|
||||
@run
|
||||
def testReplicateOp():
|
||||
with_pdl = transform.WithPDLPatternsOp(transform.AnyOpType.get())
|
||||
with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get())
|
||||
with InsertionPoint(with_pdl.body):
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget)
|
||||
with InsertionPoint(sequence.body):
|
||||
m1 = transform.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "first")
|
||||
m2 = transform.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "second")
|
||||
m1 = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "first")
|
||||
m2 = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "second")
|
||||
transform.ReplicateOp(m1, [m2])
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: testReplicateOp
|
||||
|
@ -4,6 +4,7 @@ from mlir.ir import *
|
||||
from mlir.dialects import transform
|
||||
from mlir.dialects import pdl
|
||||
from mlir.dialects.transform import structured
|
||||
from mlir.dialects.transform import pdl as transform_pdl
|
||||
|
||||
|
||||
def run(f):
|
||||
@ -151,13 +152,13 @@ def testTileZero():
|
||||
|
||||
@run
|
||||
def testTileDynamic():
|
||||
with_pdl = transform.WithPDLPatternsOp(pdl.OperationType.get())
|
||||
with_pdl = transform_pdl.WithPDLPatternsOp(pdl.OperationType.get())
|
||||
with InsertionPoint(with_pdl.body):
|
||||
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [],
|
||||
with_pdl.bodyTarget)
|
||||
with InsertionPoint(sequence.body):
|
||||
m1 = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "first")
|
||||
m2 = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "second")
|
||||
m1 = transform_pdl.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "first")
|
||||
m2 = transform_pdl.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "second")
|
||||
structured.TileOp(sequence.bodyTarget,
|
||||
sizes=[m1, 3, m2, 0])
|
||||
transform.YieldOp()
|
||||
|
@ -7495,6 +7495,7 @@ cc_library(
|
||||
":TosaToLinalg",
|
||||
":TransformDialect",
|
||||
":TransformDialectTransforms",
|
||||
":TransformPDLExtension",
|
||||
":Transforms",
|
||||
":TransformsPassIncGen",
|
||||
":VectorDialect",
|
||||
@ -9732,7 +9733,6 @@ td_library(
|
||||
":ControlFlowInterfacesTdFiles",
|
||||
":InferTypeOpInterfaceTdFiles",
|
||||
":OpBaseTdFiles",
|
||||
":PDLDialectTdFiles",
|
||||
":SideEffectInterfacesTdFiles",
|
||||
],
|
||||
)
|
||||
@ -9889,8 +9889,6 @@ cc_library(
|
||||
":CallOpInterfaces",
|
||||
":ControlFlowInterfaces",
|
||||
":IR",
|
||||
":PDLDialect",
|
||||
":PDLInterpDialect",
|
||||
":Rewrite",
|
||||
":SideEffectInterfaces",
|
||||
":Support",
|
||||
@ -9906,6 +9904,54 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
td_library(
|
||||
name = "TransformPDLExtensionTdFiles",
|
||||
srcs = glob(["include/mlir/Dialect/Transform/PDLExtension/*.td"]),
|
||||
deps = [
|
||||
":PDLDialectTdFiles",
|
||||
":TransformDialectTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
gentbl_cc_library(
|
||||
name = "TransformPDLExtensionOpsIncGen",
|
||||
strip_include_prefix = "include",
|
||||
tbl_outs = [
|
||||
(
|
||||
[
|
||||
"-gen-op-decls",
|
||||
],
|
||||
"include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h.inc",
|
||||
),
|
||||
(
|
||||
[
|
||||
"-gen-op-defs",
|
||||
],
|
||||
"include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc",
|
||||
),
|
||||
],
|
||||
tblgen = ":mlir-tblgen",
|
||||
td_file = "include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td",
|
||||
deps = [":TransformPDLExtensionTdFiles"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "TransformPDLExtension",
|
||||
srcs = glob(["lib/Dialect/Transform/PDLExtension/*.cpp"]),
|
||||
hdrs = glob(["include/mlir/Dialect/Transform/PDLExtension/*.h"]),
|
||||
deps = [
|
||||
":IR",
|
||||
":PDLDialect",
|
||||
":PDLInterpDialect",
|
||||
":SideEffectInterfaces",
|
||||
":Support",
|
||||
":TransformDialect",
|
||||
":TransformPDLExtensionOpsIncGen",
|
||||
":Rewrite",
|
||||
"//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
td_library(
|
||||
name = "TransformDialectTransformsTdFiles",
|
||||
srcs = glob(["include/mlir/Dialect/Transform/Transforms/*.td"]),
|
||||
|
@ -927,6 +927,26 @@ gentbl_filegroup(
|
||||
],
|
||||
)
|
||||
|
||||
gentbl_filegroup(
|
||||
name = "PDLTransformOpsPyGen",
|
||||
tbl_outs = [
|
||||
(
|
||||
[
|
||||
"-gen-python-op-bindings",
|
||||
"-bind-dialect=transform",
|
||||
"-dialect-extension=transform_pdl_extension",
|
||||
],
|
||||
"mlir/dialects/_transform_pdl_extension_ops_gen.py",
|
||||
),
|
||||
],
|
||||
tblgen = "//mlir:mlir-tblgen",
|
||||
td_file = "mlir/dialects/TransformPDLExtensionOps.td",
|
||||
deps = [
|
||||
":TransformOpsPyTdFiles",
|
||||
"//mlir:TransformPDLExtensionTdFiles",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "TransformOpsPyFiles",
|
||||
srcs = [
|
||||
@ -934,6 +954,7 @@ filegroup(
|
||||
"mlir/dialects/_structured_transform_ops_ext.py",
|
||||
"mlir/dialects/_transform_ops_ext.py",
|
||||
":LoopTransformOpsPyGen",
|
||||
":PDLTransformOpsPyGen",
|
||||
":StructuredTransformOpsPyGen",
|
||||
":TransformOpsPyGen",
|
||||
],
|
||||
|
@ -317,6 +317,7 @@ gentbl_cc_library(
|
||||
":TransformDialectTdFiles",
|
||||
"//mlir:PDLDialectTdFiles",
|
||||
"//mlir:TransformDialectTdFiles",
|
||||
"//mlir:TransformPDLExtension",
|
||||
],
|
||||
)
|
||||
|
||||
@ -333,6 +334,7 @@ cc_library(
|
||||
"//mlir:Pass",
|
||||
"//mlir:TransformDialect",
|
||||
"//mlir:TransformDialectTransforms",
|
||||
"//mlir:TransformPDLExtension",
|
||||
],
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user