[mlir] Add config for PDL (#69927)

Make it so that PDL in pattern rewrites can be optionally disabled.

PDL is still enabled by default and not optional bazel. So this should
be a NOP for most folks, while enabling other to disable.

This only works with tests disabled. With tests enabled this still
compiles but tests fail as there is no lit config to disable tests that
depend on PDL rewrites yet.
This commit is contained in:
Jacques Pienaar 2024-01-03 20:37:19 -08:00
parent cda388c440
commit 6ae7f66ff5
27 changed files with 1327 additions and 1091 deletions

View File

@ -133,6 +133,8 @@ set(MLIR_ENABLE_NVPTXCOMPILER 0 CACHE BOOL
"Statically link the nvptxlibrary instead of calling ptxas as a subprocess \
for compiling PTX to cubin")
set(MLIR_ENABLE_PDL_IN_PATTERNMATCH 1 CACHE BOOL "Enable PDL in PatternMatch")
option(MLIR_INCLUDE_TESTS
"Generate build targets for the MLIR unit tests."
${LLVM_INCLUDE_TESTS})
@ -178,10 +180,9 @@ include_directories( ${MLIR_INCLUDE_DIR})
# Adding tools/mlir-tblgen here as calling add_tablegen sets some variables like
# MLIR_TABLEGEN_EXE in PARENT_SCOPE which gets lost if that folder is included
# from another directory like tools
add_subdirectory(tools/mlir-tblgen)
add_subdirectory(tools/mlir-linalg-ods-gen)
add_subdirectory(tools/mlir-pdll)
add_subdirectory(tools/mlir-tblgen)
set(MLIR_TABLEGEN_EXE "${MLIR_TABLEGEN_EXE}" CACHE INTERNAL "")
set(MLIR_TABLEGEN_TARGET "${MLIR_TABLEGEN_TARGET}" CACHE INTERNAL "")
set(MLIR_PDLL_TABLEGEN_EXE "${MLIR_PDLL_TABLEGEN_EXE}" CACHE INTERNAL "")

View File

@ -14,10 +14,10 @@ Below are some example measurements taken at the time of the LLVM 17 release,
using clang-14 on a X86 Ubuntu and [bloaty](https://github.com/google/bloaty).
| | Base | Os | Oz | Os LTO | Oz LTO |
| :-----------------------------: | ------ | ------ | ------ | ------ | ------ |
| `mlir-cat` | 1018kB | 836KB | 879KB | 697KB | 649KB |
| `mlir-minimal-opt` | 1.54MB | 1.25MB | 1.29MB | 1.10MB | 1.00MB |
| `mlir-minimal-opt-canonicalize` | 2.24MB | 1.81MB | 1.86MB | 1.62MB | 1.48MB |
| :------------------------------: | ------ | ------ | ------ | ------ | ------ |
| `mlir-cat` | 1024KB | 840KB | 885KB | 706KB | 657KB |
| `mlir-minimal-opt` | 1.62MB | 1.32MB | 1.36MB | 1.17MB | 1.07MB |
| `mlir-minimal-opt-canonicalize` | 1.83MB | 1.40MB | 1.45MB | 1.25MB | 1.14MB |
Base configuration:
@ -32,6 +32,7 @@ cmake ../llvm/ -G Ninja \
-DCMAKE_CXX_COMPILER=clang++ \
-DLLVM_ENABLE_LLD=ON \
-DLLVM_ENABLE_BACKTRACES=OFF \
-DMLIR_ENABLE_PDL_IN_PATTERNMATCH=OFF \
-DCMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=-Wl,-icf=all
```

View File

@ -26,4 +26,7 @@
numeric seed that is passed to the random number generator. */
#cmakedefine MLIR_GREEDY_REWRITE_RANDOMIZER_SEED ${MLIR_GREEDY_REWRITE_RANDOMIZER_SEED}
/* If set, enables PDL usage. */
#cmakedefine01 MLIR_ENABLE_PDL_IN_PATTERNMATCH
#endif

View File

@ -15,6 +15,7 @@
#define MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {

View File

@ -29,6 +29,7 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
// Pull in all enum type definitions and utility function declarations.

View File

@ -0,0 +1,995 @@
//===- PDLPatternMatch.h - PDLPatternMatcher classes -------==---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_PDLPATTERNMATCH_H
#define MLIR_IR_PDLPATTERNMATCH_H
#include "mlir/Config/mlir-config.h"
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
namespace mlir {
//===----------------------------------------------------------------------===//
// PDL Patterns
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// PDLValue
/// Storage type of byte-code interpreter values. These are passed to constraint
/// functions as arguments.
class PDLValue {
public:
/// The underlying kind of a PDL value.
enum class Kind { Attribute, Operation, Type, TypeRange, Value, ValueRange };
/// Construct a new PDL value.
PDLValue(const PDLValue &other) = default;
PDLValue(std::nullptr_t = nullptr) {}
PDLValue(Attribute value)
: value(value.getAsOpaquePointer()), kind(Kind::Attribute) {}
PDLValue(Operation *value) : value(value), kind(Kind::Operation) {}
PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {}
PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {}
PDLValue(Value value)
: value(value.getAsOpaquePointer()), kind(Kind::Value) {}
PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {}
/// Returns true if the type of the held value is `T`.
template <typename T>
bool isa() const {
assert(value && "isa<> used on a null value");
return kind == getKindOf<T>();
}
/// Attempt to dynamically cast this value to type `T`, returns null if this
/// value is not an instance of `T`.
template <typename T,
typename ResultT = std::conditional_t<
std::is_convertible<T, bool>::value, T, std::optional<T>>>
ResultT dyn_cast() const {
return isa<T>() ? castImpl<T>() : ResultT();
}
/// Cast this value to type `T`, asserts if this value is not an instance of
/// `T`.
template <typename T>
T cast() const {
assert(isa<T>() && "expected value to be of type `T`");
return castImpl<T>();
}
/// Get an opaque pointer to the value.
const void *getAsOpaquePointer() const { return value; }
/// Return if this value is null or not.
explicit operator bool() const { return value; }
/// Return the kind of this value.
Kind getKind() const { return kind; }
/// Print this value to the provided output stream.
void print(raw_ostream &os) const;
/// Print the specified value kind to an output stream.
static void print(raw_ostream &os, Kind kind);
private:
/// Find the index of a given type in a range of other types.
template <typename...>
struct index_of_t;
template <typename T, typename... R>
struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
template <typename T, typename F, typename... R>
struct index_of_t<T, F, R...>
: std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
/// Return the kind used for the given T.
template <typename T>
static Kind getKindOf() {
return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
TypeRange, Value, ValueRange>::value);
}
/// The internal implementation of `cast`, that returns the underlying value
/// as the given type `T`.
template <typename T>
std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T>
castImpl() const {
return T::getFromOpaquePointer(value);
}
template <typename T>
std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T>
castImpl() const {
return *reinterpret_cast<T *>(const_cast<void *>(value));
}
template <typename T>
std::enable_if_t<std::is_pointer<T>::value, T> castImpl() const {
return reinterpret_cast<T>(const_cast<void *>(value));
}
/// The internal opaque representation of a PDLValue.
const void *value{nullptr};
/// The kind of the opaque value.
Kind kind{Kind::Attribute};
};
inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
value.print(os);
return os;
}
inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) {
PDLValue::print(os, kind);
return os;
}
//===----------------------------------------------------------------------===//
// PDLResultList
/// The class represents a list of PDL results, returned by a native rewrite
/// method. It provides the mechanism with which to pass PDLValues back to the
/// PDL bytecode.
class PDLResultList {
public:
/// Push a new Attribute value onto the result list.
void push_back(Attribute value) { results.push_back(value); }
/// Push a new Operation onto the result list.
void push_back(Operation *value) { results.push_back(value); }
/// Push a new Type onto the result list.
void push_back(Type value) { results.push_back(value); }
/// Push a new TypeRange onto the result list.
void push_back(TypeRange value) {
// The lifetime of a TypeRange can't be guaranteed, so we'll need to
// allocate a storage for it.
llvm::OwningArrayRef<Type> storage(value.size());
llvm::copy(value, storage.begin());
allocatedTypeRanges.emplace_back(std::move(storage));
typeRanges.push_back(allocatedTypeRanges.back());
results.push_back(&typeRanges.back());
}
void push_back(ValueTypeRange<OperandRange> value) {
typeRanges.push_back(value);
results.push_back(&typeRanges.back());
}
void push_back(ValueTypeRange<ResultRange> value) {
typeRanges.push_back(value);
results.push_back(&typeRanges.back());
}
/// Push a new Value onto the result list.
void push_back(Value value) { results.push_back(value); }
/// Push a new ValueRange onto the result list.
void push_back(ValueRange value) {
// The lifetime of a ValueRange can't be guaranteed, so we'll need to
// allocate a storage for it.
llvm::OwningArrayRef<Value> storage(value.size());
llvm::copy(value, storage.begin());
allocatedValueRanges.emplace_back(std::move(storage));
valueRanges.push_back(allocatedValueRanges.back());
results.push_back(&valueRanges.back());
}
void push_back(OperandRange value) {
valueRanges.push_back(value);
results.push_back(&valueRanges.back());
}
void push_back(ResultRange value) {
valueRanges.push_back(value);
results.push_back(&valueRanges.back());
}
protected:
/// Create a new result list with the expected number of results.
PDLResultList(unsigned maxNumResults) {
// For now just reserve enough space for all of the results. We could do
// separate counts per range type, but it isn't really worth it unless there
// are a "large" number of results.
typeRanges.reserve(maxNumResults);
valueRanges.reserve(maxNumResults);
}
/// The PDL results held by this list.
SmallVector<PDLValue> results;
/// Memory used to store ranges held by the list.
SmallVector<TypeRange> typeRanges;
SmallVector<ValueRange> valueRanges;
/// Memory allocated to store ranges in the result list whose lifetime was
/// generated in the native function.
SmallVector<llvm::OwningArrayRef<Type>> allocatedTypeRanges;
SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges;
};
//===----------------------------------------------------------------------===//
// PDLPatternConfig
/// An individual configuration for a pattern, which can be accessed by native
/// functions via the PDLPatternConfigSet. This allows for injecting additional
/// configuration into PDL patterns that is specific to certain compilation
/// flows.
class PDLPatternConfig {
public:
virtual ~PDLPatternConfig() = default;
/// Hooks that are invoked at the beginning and end of a rewrite of a matched
/// pattern. These can be used to setup any specific state necessary for the
/// rewrite.
virtual void notifyRewriteBegin(PatternRewriter &rewriter) {}
virtual void notifyRewriteEnd(PatternRewriter &rewriter) {}
/// Return the TypeID that represents this configuration.
TypeID getTypeID() const { return id; }
protected:
PDLPatternConfig(TypeID id) : id(id) {}
private:
TypeID id;
};
/// This class provides a base class for users implementing a type of pattern
/// configuration.
template <typename T>
class PDLPatternConfigBase : public PDLPatternConfig {
public:
/// Support LLVM style casting.
static bool classof(const PDLPatternConfig *config) {
return config->getTypeID() == getConfigID();
}
/// Return the type id used for this configuration.
static TypeID getConfigID() { return TypeID::get<T>(); }
protected:
PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {}
};
/// This class contains a set of configurations for a specific pattern.
/// Configurations are uniqued by TypeID, meaning that only one configuration of
/// each type is allowed.
class PDLPatternConfigSet {
public:
PDLPatternConfigSet() = default;
/// Construct a set with the given configurations.
template <typename... ConfigsT>
PDLPatternConfigSet(ConfigsT &&...configs) {
(addConfig(std::forward<ConfigsT>(configs)), ...);
}
/// Get the configuration defined by the given type. Asserts that the
/// configuration of the provided type exists.
template <typename T>
const T &get() const {
const T *config = tryGet<T>();
assert(config && "configuration not found");
return *config;
}
/// Get the configuration defined by the given type, returns nullptr if the
/// configuration does not exist.
template <typename T>
const T *tryGet() const {
for (const auto &configIt : configs)
if (const T *config = dyn_cast<T>(configIt.get()))
return config;
return nullptr;
}
/// Notify the configurations within this set at the beginning or end of a
/// rewrite of a matched pattern.
void notifyRewriteBegin(PatternRewriter &rewriter) {
for (const auto &config : configs)
config->notifyRewriteBegin(rewriter);
}
void notifyRewriteEnd(PatternRewriter &rewriter) {
for (const auto &config : configs)
config->notifyRewriteEnd(rewriter);
}
protected:
/// Add a configuration to the set.
template <typename T>
void addConfig(T &&config) {
assert(!tryGet<std::decay_t<T>>() && "configuration already exists");
configs.emplace_back(
std::make_unique<std::decay_t<T>>(std::forward<T>(config)));
}
/// The set of configurations for this pattern. This uses a vector instead of
/// a map with the expectation that the number of configurations per set is
/// small (<= 1).
SmallVector<std::unique_ptr<PDLPatternConfig>> configs;
};
//===----------------------------------------------------------------------===//
// PDLPatternModule
/// A generic PDL pattern constraint function. This function applies a
/// constraint to a given set of opaque PDLValue entities. Returns success if
/// the constraint successfully held, failure otherwise.
using PDLConstraintFunction =
std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
/// A native PDL rewrite function. This function performs a rewrite on the
/// given set of values. Any results from this rewrite that should be passed
/// back to PDL should be added to the provided result list. This method is only
/// invoked when the corresponding match was successful. Returns failure if an
/// invariant of the rewrite was broken (certain rewriters may recover from
/// partial pattern application).
using PDLRewriteFunction = std::function<LogicalResult(
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
namespace detail {
namespace pdl_function_builder {
/// A utility variable that always resolves to false. This is useful for static
/// asserts that are always false, but only should fire in certain templated
/// constructs. For example, if a templated function should never be called, the
/// function could be defined as:
///
/// template <typename T>
/// void foo() {
/// static_assert(always_false<T>, "This function should never be called");
/// }
///
template <class... T>
constexpr bool always_false = false;
//===----------------------------------------------------------------------===//
// PDL Function Builder: Type Processing
//===----------------------------------------------------------------------===//
/// This struct provides a convenient way to determine how to process a given
/// type as either a PDL parameter, or a result value. This allows for
/// supporting complex types in constraint and rewrite functions, without
/// requiring the user to hand-write the necessary glue code themselves.
/// Specializations of this class should implement the following methods to
/// enable support as a PDL argument or result type:
///
/// static LogicalResult verifyAsArg(
/// function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue,
/// size_t argIdx);
///
/// * This method verifies that the given PDLValue is valid for use as a
/// value of `T`.
///
/// static T processAsArg(PDLValue pdlValue);
///
/// * This method processes the given PDLValue as a value of `T`.
///
/// static void processAsResult(PatternRewriter &, PDLResultList &results,
/// const T &value);
///
/// * This method processes the given value of `T` as the result of a
/// function invocation. The method should package the value into an
/// appropriate form and append it to the given result list.
///
/// If the type `T` is based on a higher order value, consider using
/// `ProcessPDLValueBasedOn` as a base class of the specialization to simplify
/// the implementation.
///
template <typename T, typename Enable = void>
struct ProcessPDLValue;
/// This struct provides a simplified model for processing types that are based
/// on another type, e.g. APInt is based on the handling for IntegerAttr. This
/// allows for building the necessary processing functions on top of the base
/// value instead of a PDLValue. Derived users should implement the following
/// (which subsume the ProcessPDLValue variants):
///
/// static LogicalResult verifyAsArg(
/// function_ref<LogicalResult(const Twine &)> errorFn,
/// const BaseT &baseValue, size_t argIdx);
///
/// * This method verifies that the given PDLValue is valid for use as a
/// value of `T`.
///
/// static T processAsArg(BaseT baseValue);
///
/// * This method processes the given base value as a value of `T`.
///
template <typename T, typename BaseT>
struct ProcessPDLValueBasedOn {
static LogicalResult
verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
PDLValue pdlValue, size_t argIdx) {
// Verify the base class before continuing.
if (failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx)))
return failure();
return ProcessPDLValue<T>::verifyAsArg(
errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx);
}
static T processAsArg(PDLValue pdlValue) {
return ProcessPDLValue<T>::processAsArg(
ProcessPDLValue<BaseT>::processAsArg(pdlValue));
}
/// Explicitly add the expected parent API to ensure the parent class
/// implements the necessary API (and doesn't implicitly inherit it from
/// somewhere else).
static LogicalResult
verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, BaseT value,
size_t argIdx) {
return success();
}
static T processAsArg(BaseT baseValue);
};
/// This struct provides a simplified model for processing types that have
/// "builtin" PDLValue support:
/// * Attribute, Operation *, Type, TypeRange, ValueRange
template <typename T>
struct ProcessBuiltinPDLValue {
static LogicalResult
verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
PDLValue pdlValue, size_t argIdx) {
if (pdlValue)
return success();
return errorFn("expected a non-null value for argument " + Twine(argIdx) +
" of type: " + llvm::getTypeName<T>());
}
static T processAsArg(PDLValue pdlValue) { return pdlValue.cast<T>(); }
static void processAsResult(PatternRewriter &, PDLResultList &results,
T value) {
results.push_back(value);
}
};
/// This struct provides a simplified model for processing types that inherit
/// from builtin PDLValue types. For example, derived attributes like
/// IntegerAttr, derived types like IntegerType, derived operations like
/// ModuleOp, Interfaces, etc.
template <typename T, typename BaseT>
struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn<T, BaseT> {
static LogicalResult
verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
BaseT baseValue, size_t argIdx) {
return TypeSwitch<BaseT, LogicalResult>(baseValue)
.Case([&](T) { return success(); })
.Default([&](BaseT) {
return errorFn("expected argument " + Twine(argIdx) +
" to be of type: " + llvm::getTypeName<T>());
});
}
using ProcessPDLValueBasedOn<T, BaseT>::verifyAsArg;
static T processAsArg(BaseT baseValue) {
return baseValue.template cast<T>();
}
using ProcessPDLValueBasedOn<T, BaseT>::processAsArg;
static void processAsResult(PatternRewriter &, PDLResultList &results,
T value) {
results.push_back(value);
}
};
//===----------------------------------------------------------------------===//
// Attribute
template <>
struct ProcessPDLValue<Attribute> : public ProcessBuiltinPDLValue<Attribute> {};
template <typename T>
struct ProcessPDLValue<T,
std::enable_if_t<std::is_base_of<Attribute, T>::value>>
: public ProcessDerivedPDLValue<T, Attribute> {};
/// Handling for various Attribute value types.
template <>
struct ProcessPDLValue<StringRef>
: public ProcessPDLValueBasedOn<StringRef, StringAttr> {
static StringRef processAsArg(StringAttr value) { return value.getValue(); }
using ProcessPDLValueBasedOn<StringRef, StringAttr>::processAsArg;
static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
StringRef value) {
results.push_back(rewriter.getStringAttr(value));
}
};
template <>
struct ProcessPDLValue<std::string>
: public ProcessPDLValueBasedOn<std::string, StringAttr> {
template <typename T>
static std::string processAsArg(T value) {
static_assert(always_false<T>,
"`std::string` arguments require a string copy, use "
"`StringRef` for string-like arguments instead");
return {};
}
static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
StringRef value) {
results.push_back(rewriter.getStringAttr(value));
}
};
//===----------------------------------------------------------------------===//
// Operation
template <>
struct ProcessPDLValue<Operation *>
: public ProcessBuiltinPDLValue<Operation *> {};
template <typename T>
struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>>
: public ProcessDerivedPDLValue<T, Operation *> {
static T processAsArg(Operation *value) { return cast<T>(value); }
};
//===----------------------------------------------------------------------===//
// Type
template <>
struct ProcessPDLValue<Type> : public ProcessBuiltinPDLValue<Type> {};
template <typename T>
struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>>
: public ProcessDerivedPDLValue<T, Type> {};
//===----------------------------------------------------------------------===//
// TypeRange
template <>
struct ProcessPDLValue<TypeRange> : public ProcessBuiltinPDLValue<TypeRange> {};
template <>
struct ProcessPDLValue<ValueTypeRange<OperandRange>> {
static void processAsResult(PatternRewriter &, PDLResultList &results,
ValueTypeRange<OperandRange> types) {
results.push_back(types);
}
};
template <>
struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
static void processAsResult(PatternRewriter &, PDLResultList &results,
ValueTypeRange<ResultRange> types) {
results.push_back(types);
}
};
template <unsigned N>
struct ProcessPDLValue<SmallVector<Type, N>> {
static void processAsResult(PatternRewriter &, PDLResultList &results,
SmallVector<Type, N> values) {
results.push_back(TypeRange(values));
}
};
//===----------------------------------------------------------------------===//
// Value
template <>
struct ProcessPDLValue<Value> : public ProcessBuiltinPDLValue<Value> {};
//===----------------------------------------------------------------------===//
// ValueRange
template <>
struct ProcessPDLValue<ValueRange> : public ProcessBuiltinPDLValue<ValueRange> {
};
template <>
struct ProcessPDLValue<OperandRange> {
static void processAsResult(PatternRewriter &, PDLResultList &results,
OperandRange values) {
results.push_back(values);
}
};
template <>
struct ProcessPDLValue<ResultRange> {
static void processAsResult(PatternRewriter &, PDLResultList &results,
ResultRange values) {
results.push_back(values);
}
};
template <unsigned N>
struct ProcessPDLValue<SmallVector<Value, N>> {
static void processAsResult(PatternRewriter &, PDLResultList &results,
SmallVector<Value, N> values) {
results.push_back(ValueRange(values));
}
};
//===----------------------------------------------------------------------===//
// PDL Function Builder: Argument Handling
//===----------------------------------------------------------------------===//
/// Validate the given PDLValues match the constraints defined by the argument
/// types of the given function. In the case of failure, a match failure
/// diagnostic is emitted.
/// FIXME: This should be completely removed in favor of `assertArgs`, but PDL
/// does not currently preserve Constraint application ordering.
template <typename PDLFnT, std::size_t... I>
LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
std::index_sequence<I...>) {
using FnTraitsT = llvm::function_traits<PDLFnT>;
auto errorFn = [&](const Twine &msg) {
return rewriter.notifyMatchFailure(rewriter.getUnknownLoc(), msg);
};
return success(
(succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
verifyAsArg(errorFn, values[I], I)) &&
...));
}
/// Assert that the given PDLValues match the constraints defined by the
/// arguments of the given function. In the case of failure, a fatal error
/// is emitted.
template <typename PDLFnT, std::size_t... I>
void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
std::index_sequence<I...>) {
// We only want to do verification in debug builds, same as with `assert`.
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
using FnTraitsT = llvm::function_traits<PDLFnT>;
auto errorFn = [&](const Twine &msg) -> LogicalResult {
llvm::report_fatal_error(msg);
};
(void)errorFn;
assert((succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
verifyAsArg(errorFn, values[I], I)) &&
...));
#endif
(void)values;
}
//===----------------------------------------------------------------------===//
// PDL Function Builder: Results Handling
//===----------------------------------------------------------------------===//
/// Store a single result within the result list.
template <typename T>
static LogicalResult processResults(PatternRewriter &rewriter,
PDLResultList &results, T &&value) {
ProcessPDLValue<T>::processAsResult(rewriter, results,
std::forward<T>(value));
return success();
}
/// Store a std::pair<> as individual results within the result list.
template <typename T1, typename T2>
static LogicalResult processResults(PatternRewriter &rewriter,
PDLResultList &results,
std::pair<T1, T2> &&pair) {
if (failed(processResults(rewriter, results, std::move(pair.first))) ||
failed(processResults(rewriter, results, std::move(pair.second))))
return failure();
return success();
}
/// Store a std::tuple<> as individual results within the result list.
template <typename... Ts>
static LogicalResult processResults(PatternRewriter &rewriter,
PDLResultList &results,
std::tuple<Ts...> &&tuple) {
auto applyFn = [&](auto &&...args) {
return (succeeded(processResults(rewriter, results, std::move(args))) &&
...);
};
return success(std::apply(applyFn, std::move(tuple)));
}
/// Handle LogicalResult propagation.
inline LogicalResult processResults(PatternRewriter &rewriter,
PDLResultList &results,
LogicalResult &&result) {
return result;
}
template <typename T>
static LogicalResult processResults(PatternRewriter &rewriter,
PDLResultList &results,
FailureOr<T> &&result) {
if (failed(result))
return failure();
return processResults(rewriter, results, std::move(*result));
}
//===----------------------------------------------------------------------===//
// PDL Constraint Builder
//===----------------------------------------------------------------------===//
/// Process the arguments of a native constraint and invoke it.
template <typename PDLFnT, std::size_t... I,
typename FnTraitsT = llvm::function_traits<PDLFnT>>
typename FnTraitsT::result_t
processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter,
ArrayRef<PDLValue> values,
std::index_sequence<I...>) {
return fn(
rewriter,
(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
values[I]))...);
}
/// Build a constraint function from the given function `ConstraintFnT`. This
/// allows for enabling the user to define simpler, more direct constraint
/// functions without needing to handle the low-level PDL goop.
///
/// If the constraint function is already in the correct form, we just forward
/// it directly.
template <typename ConstraintFnT>
std::enable_if_t<
std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
PDLConstraintFunction>
buildConstraintFn(ConstraintFnT &&constraintFn) {
return std::forward<ConstraintFnT>(constraintFn);
}
/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
/// we desire.
template <typename ConstraintFnT>
std::enable_if_t<
!std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
PDLConstraintFunction>
buildConstraintFn(ConstraintFnT &&constraintFn) {
return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
PatternRewriter &rewriter,
ArrayRef<PDLValue> values) -> LogicalResult {
auto argIndices = std::make_index_sequence<
llvm::function_traits<ConstraintFnT>::num_args - 1>();
if (failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices)))
return failure();
return processArgsAndInvokeConstraint(constraintFn, rewriter, values,
argIndices);
};
}
//===----------------------------------------------------------------------===//
// PDL Rewrite Builder
//===----------------------------------------------------------------------===//
/// Process the arguments of a native rewrite and invoke it.
/// This overload handles the case of no return values.
template <typename PDLFnT, std::size_t... I,
typename FnTraitsT = llvm::function_traits<PDLFnT>>
std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value,
LogicalResult>
processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
PDLResultList &, ArrayRef<PDLValue> values,
std::index_sequence<I...>) {
fn(rewriter,
(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
values[I]))...);
return success();
}
/// This overload handles the case of return values, which need to be packaged
/// into the result list.
template <typename PDLFnT, std::size_t... I,
typename FnTraitsT = llvm::function_traits<PDLFnT>>
std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value,
LogicalResult>
processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
PDLResultList &results, ArrayRef<PDLValue> values,
std::index_sequence<I...>) {
return processResults(
rewriter, results,
fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
processAsArg(values[I]))...));
(void)values;
}
/// Build a rewrite function from the given function `RewriteFnT`. This
/// allows for enabling the user to define simpler, more direct rewrite
/// functions without needing to handle the low-level PDL goop.
///
/// If the rewrite function is already in the correct form, we just forward
/// it directly.
template <typename RewriteFnT>
std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
PDLRewriteFunction>
buildRewriteFn(RewriteFnT &&rewriteFn) {
return std::forward<RewriteFnT>(rewriteFn);
}
/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
/// we desire.
template <typename RewriteFnT>
std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
PDLRewriteFunction>
buildRewriteFn(RewriteFnT &&rewriteFn) {
return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)](
PatternRewriter &rewriter, PDLResultList &results,
ArrayRef<PDLValue> values) {
auto argIndices =
std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
1>();
assertArgs<RewriteFnT>(rewriter, values, argIndices);
return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
argIndices);
};
}
} // namespace pdl_function_builder
} // namespace detail
//===----------------------------------------------------------------------===//
// PDLPatternModule
/// This class contains all of the necessary data for a set of PDL patterns, or
/// pattern rewrites specified in the form of the PDL dialect. This PDL module
/// contained by this pattern may contain any number of `pdl.pattern`
/// operations.
class PDLPatternModule {
public:
PDLPatternModule() = default;
/// Construct a PDL pattern with the given module and configurations.
PDLPatternModule(OwningOpRef<ModuleOp> module)
: pdlModule(std::move(module)) {}
template <typename... ConfigsT>
PDLPatternModule(OwningOpRef<ModuleOp> module, ConfigsT &&...patternConfigs)
: PDLPatternModule(std::move(module)) {
auto configSet = std::make_unique<PDLPatternConfigSet>(
std::forward<ConfigsT>(patternConfigs)...);
attachConfigToPatterns(*pdlModule, *configSet);
configs.emplace_back(std::move(configSet));
}
/// Merge the state in `other` into this pattern module.
void mergeIn(PDLPatternModule &&other);
/// Return the internal PDL module of this pattern.
ModuleOp getModule() { return pdlModule.get(); }
/// Return the MLIR context of this pattern.
MLIRContext *getContext() { return getModule()->getContext(); }
//===--------------------------------------------------------------------===//
// Function Registry
/// Register a constraint function with PDL. A constraint function may be
/// specified in one of two ways:
///
/// * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
///
/// In this overload the arguments of the constraint function are passed via
/// the low-level PDLValue form.
///
/// * `LogicalResult (PatternRewriter &, ValueTs... values)`
///
/// In this form the arguments of the constraint function are passed via the
/// expected high level C++ type. In this form, the framework will
/// automatically unwrap PDLValues and convert them to the expected ValueTs.
/// For example, if the constraint function accepts a `Operation *`, the
/// framework will automatically cast the input PDLValue. In the case of a
/// `StringRef`, the framework will automatically unwrap the argument as a
/// StringAttr and pass the underlying string value. To see the full list of
/// supported types, or to see how to add handling for custom types, view
/// the definition of `ProcessPDLValue` above.
void registerConstraintFunction(StringRef name,
PDLConstraintFunction constraintFn);
template <typename ConstraintFnT>
void registerConstraintFunction(StringRef name,
ConstraintFnT &&constraintFn) {
registerConstraintFunction(name,
detail::pdl_function_builder::buildConstraintFn(
std::forward<ConstraintFnT>(constraintFn)));
}
/// Register a rewrite function with PDL. A rewrite function may be specified
/// in one of two ways:
///
/// * `void (PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)`
///
/// In this overload the arguments of the constraint function are passed via
/// the low-level PDLValue form, and the results are manually appended to
/// the given result list.
///
/// * `ResultT (PatternRewriter &, ValueTs... values)`
///
/// In this form the arguments and result of the rewrite function are passed
/// via the expected high level C++ type. In this form, the framework will
/// automatically unwrap the PDLValues arguments and convert them to the
/// expected ValueTs. It will also automatically handle the processing and
/// packaging of the result value to the result list. For example, if the
/// rewrite function takes a `Operation *`, the framework will automatically
/// cast the input PDLValue. In the case of a `StringRef`, the framework
/// will automatically unwrap the argument as a StringAttr and pass the
/// underlying string value. In the reverse case, if the rewrite returns a
/// StringRef or std::string, it will automatically package this as a
/// StringAttr and append it to the result list. To see the full list of
/// supported types, or to see how to add handling for custom types, view
/// the definition of `ProcessPDLValue` above.
void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
template <typename RewriteFnT>
void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {
registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn(
std::forward<RewriteFnT>(rewriteFn)));
}
/// Return the set of the registered constraint functions.
const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
return constraintFunctions;
}
llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
return constraintFunctions;
}
/// Return the set of the registered rewrite functions.
const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
return rewriteFunctions;
}
llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
return rewriteFunctions;
}
/// Return the set of the registered pattern configs.
SmallVector<std::unique_ptr<PDLPatternConfigSet>> takeConfigs() {
return std::move(configs);
}
DenseMap<Operation *, PDLPatternConfigSet *> takeConfigMap() {
return std::move(configMap);
}
/// Clear out the patterns and functions within this module.
void clear() {
pdlModule = nullptr;
constraintFunctions.clear();
rewriteFunctions.clear();
}
private:
/// Attach the given pattern config set to the patterns defined within the
/// given module.
void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet);
/// The module containing the `pdl.pattern` operations.
OwningOpRef<ModuleOp> pdlModule;
/// The set of configuration sets referenced by patterns within `pdlModule`.
SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
DenseMap<Operation *, PDLPatternConfigSet *> configMap;
/// The external functions referenced from within the PDL module.
llvm::StringMap<PDLConstraintFunction> constraintFunctions;
llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
};
} // namespace mlir
#else
namespace mlir {
// Stubs for when PDL in pattern rewrites is not enabled.
class PDLValue {
public:
template <typename T>
T dyn_cast() const {
return nullptr;
}
};
class PDLResultList {};
using PDLConstraintFunction =
std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
using PDLRewriteFunction = std::function<LogicalResult(
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
class PDLPatternModule {
public:
PDLPatternModule() = default;
PDLPatternModule(OwningOpRef<ModuleOp> /*module*/) {}
MLIRContext *getContext() {
llvm_unreachable("Error: PDL for rewrites when PDL is not enabled");
}
void mergeIn(PDLPatternModule &&other) {}
void clear() {}
template <typename ConstraintFnT>
void registerConstraintFunction(StringRef name,
ConstraintFnT &&constraintFn) {}
void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn) {}
template <typename RewriteFnT>
void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {}
const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
return constraintFunctions;
}
private:
llvm::StringMap<PDLConstraintFunction> constraintFunctions;
};
} // namespace mlir
#endif
#endif // MLIR_IR_PDLPATTERNMATCH_H

View File

@ -735,932 +735,12 @@ public:
virtual bool canRecoverFromRewriteFailure() const { return false; }
};
//===----------------------------------------------------------------------===//
// PDL Patterns
//===----------------------------------------------------------------------===//
} // namespace mlir
//===----------------------------------------------------------------------===//
// PDLValue
// Optionally expose PDL pattern matching methods.
#include "PDLPatternMatch.h.inc"
/// Storage type of byte-code interpreter values. These are passed to constraint
/// functions as arguments.
class PDLValue {
public:
/// The underlying kind of a PDL value.
enum class Kind { Attribute, Operation, Type, TypeRange, Value, ValueRange };
/// Construct a new PDL value.
PDLValue(const PDLValue &other) = default;
PDLValue(std::nullptr_t = nullptr) {}
PDLValue(Attribute value)
: value(value.getAsOpaquePointer()), kind(Kind::Attribute) {}
PDLValue(Operation *value) : value(value), kind(Kind::Operation) {}
PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {}
PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {}
PDLValue(Value value)
: value(value.getAsOpaquePointer()), kind(Kind::Value) {}
PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {}
/// Returns true if the type of the held value is `T`.
template <typename T>
bool isa() const {
assert(value && "isa<> used on a null value");
return kind == getKindOf<T>();
}
/// Attempt to dynamically cast this value to type `T`, returns null if this
/// value is not an instance of `T`.
template <typename T,
typename ResultT = std::conditional_t<
std::is_convertible<T, bool>::value, T, std::optional<T>>>
ResultT dyn_cast() const {
return isa<T>() ? castImpl<T>() : ResultT();
}
/// Cast this value to type `T`, asserts if this value is not an instance of
/// `T`.
template <typename T>
T cast() const {
assert(isa<T>() && "expected value to be of type `T`");
return castImpl<T>();
}
/// Get an opaque pointer to the value.
const void *getAsOpaquePointer() const { return value; }
/// Return if this value is null or not.
explicit operator bool() const { return value; }
/// Return the kind of this value.
Kind getKind() const { return kind; }
/// Print this value to the provided output stream.
void print(raw_ostream &os) const;
/// Print the specified value kind to an output stream.
static void print(raw_ostream &os, Kind kind);
private:
/// Find the index of a given type in a range of other types.
template <typename...>
struct index_of_t;
template <typename T, typename... R>
struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
template <typename T, typename F, typename... R>
struct index_of_t<T, F, R...>
: std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
/// Return the kind used for the given T.
template <typename T>
static Kind getKindOf() {
return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
TypeRange, Value, ValueRange>::value);
}
/// The internal implementation of `cast`, that returns the underlying value
/// as the given type `T`.
template <typename T>
std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T>
castImpl() const {
return T::getFromOpaquePointer(value);
}
template <typename T>
std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T>
castImpl() const {
return *reinterpret_cast<T *>(const_cast<void *>(value));
}
template <typename T>
std::enable_if_t<std::is_pointer<T>::value, T> castImpl() const {
return reinterpret_cast<T>(const_cast<void *>(value));
}
/// The internal opaque representation of a PDLValue.
const void *value{nullptr};
/// The kind of the opaque value.
Kind kind{Kind::Attribute};
};
inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
value.print(os);
return os;
}
inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) {
PDLValue::print(os, kind);
return os;
}
//===----------------------------------------------------------------------===//
// PDLResultList
/// The class represents a list of PDL results, returned by a native rewrite
/// method. It provides the mechanism with which to pass PDLValues back to the
/// PDL bytecode.
class PDLResultList {
public:
/// Push a new Attribute value onto the result list.
void push_back(Attribute value) { results.push_back(value); }
/// Push a new Operation onto the result list.
void push_back(Operation *value) { results.push_back(value); }
/// Push a new Type onto the result list.
void push_back(Type value) { results.push_back(value); }
/// Push a new TypeRange onto the result list.
void push_back(TypeRange value) {
// The lifetime of a TypeRange can't be guaranteed, so we'll need to
// allocate a storage for it.
llvm::OwningArrayRef<Type> storage(value.size());
llvm::copy(value, storage.begin());
allocatedTypeRanges.emplace_back(std::move(storage));
typeRanges.push_back(allocatedTypeRanges.back());
results.push_back(&typeRanges.back());
}
void push_back(ValueTypeRange<OperandRange> value) {
typeRanges.push_back(value);
results.push_back(&typeRanges.back());
}
void push_back(ValueTypeRange<ResultRange> value) {
typeRanges.push_back(value);
results.push_back(&typeRanges.back());
}
/// Push a new Value onto the result list.
void push_back(Value value) { results.push_back(value); }
/// Push a new ValueRange onto the result list.
void push_back(ValueRange value) {
// The lifetime of a ValueRange can't be guaranteed, so we'll need to
// allocate a storage for it.
llvm::OwningArrayRef<Value> storage(value.size());
llvm::copy(value, storage.begin());
allocatedValueRanges.emplace_back(std::move(storage));
valueRanges.push_back(allocatedValueRanges.back());
results.push_back(&valueRanges.back());
}
void push_back(OperandRange value) {
valueRanges.push_back(value);
results.push_back(&valueRanges.back());
}
void push_back(ResultRange value) {
valueRanges.push_back(value);
results.push_back(&valueRanges.back());
}
protected:
/// Create a new result list with the expected number of results.
PDLResultList(unsigned maxNumResults) {
// For now just reserve enough space for all of the results. We could do
// separate counts per range type, but it isn't really worth it unless there
// are a "large" number of results.
typeRanges.reserve(maxNumResults);
valueRanges.reserve(maxNumResults);
}
/// The PDL results held by this list.
SmallVector<PDLValue> results;
/// Memory used to store ranges held by the list.
SmallVector<TypeRange> typeRanges;
SmallVector<ValueRange> valueRanges;
/// Memory allocated to store ranges in the result list whose lifetime was
/// generated in the native function.
SmallVector<llvm::OwningArrayRef<Type>> allocatedTypeRanges;
SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges;
};
//===----------------------------------------------------------------------===//
// PDLPatternConfig
/// An individual configuration for a pattern, which can be accessed by native
/// functions via the PDLPatternConfigSet. This allows for injecting additional
/// configuration into PDL patterns that is specific to certain compilation
/// flows.
class PDLPatternConfig {
public:
virtual ~PDLPatternConfig() = default;
/// Hooks that are invoked at the beginning and end of a rewrite of a matched
/// pattern. These can be used to setup any specific state necessary for the
/// rewrite.
virtual void notifyRewriteBegin(PatternRewriter &rewriter) {}
virtual void notifyRewriteEnd(PatternRewriter &rewriter) {}
/// Return the TypeID that represents this configuration.
TypeID getTypeID() const { return id; }
protected:
PDLPatternConfig(TypeID id) : id(id) {}
private:
TypeID id;
};
/// This class provides a base class for users implementing a type of pattern
/// configuration.
template <typename T>
class PDLPatternConfigBase : public PDLPatternConfig {
public:
/// Support LLVM style casting.
static bool classof(const PDLPatternConfig *config) {
return config->getTypeID() == getConfigID();
}
/// Return the type id used for this configuration.
static TypeID getConfigID() { return TypeID::get<T>(); }
protected:
PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {}
};
/// This class contains a set of configurations for a specific pattern.
/// Configurations are uniqued by TypeID, meaning that only one configuration of
/// each type is allowed.
class PDLPatternConfigSet {
public:
PDLPatternConfigSet() = default;
/// Construct a set with the given configurations.
template <typename... ConfigsT>
PDLPatternConfigSet(ConfigsT &&...configs) {
(addConfig(std::forward<ConfigsT>(configs)), ...);
}
/// Get the configuration defined by the given type. Asserts that the
/// configuration of the provided type exists.
template <typename T>
const T &get() const {
const T *config = tryGet<T>();
assert(config && "configuration not found");
return *config;
}
/// Get the configuration defined by the given type, returns nullptr if the
/// configuration does not exist.
template <typename T>
const T *tryGet() const {
for (const auto &configIt : configs)
if (const T *config = dyn_cast<T>(configIt.get()))
return config;
return nullptr;
}
/// Notify the configurations within this set at the beginning or end of a
/// rewrite of a matched pattern.
void notifyRewriteBegin(PatternRewriter &rewriter) {
for (const auto &config : configs)
config->notifyRewriteBegin(rewriter);
}
void notifyRewriteEnd(PatternRewriter &rewriter) {
for (const auto &config : configs)
config->notifyRewriteEnd(rewriter);
}
protected:
/// Add a configuration to the set.
template <typename T>
void addConfig(T &&config) {
assert(!tryGet<std::decay_t<T>>() && "configuration already exists");
configs.emplace_back(
std::make_unique<std::decay_t<T>>(std::forward<T>(config)));
}
/// The set of configurations for this pattern. This uses a vector instead of
/// a map with the expectation that the number of configurations per set is
/// small (<= 1).
SmallVector<std::unique_ptr<PDLPatternConfig>> configs;
};
//===----------------------------------------------------------------------===//
// PDLPatternModule
/// A generic PDL pattern constraint function. This function applies a
/// constraint to a given set of opaque PDLValue entities. Returns success if
/// the constraint successfully held, failure otherwise.
using PDLConstraintFunction =
std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
/// A native PDL rewrite function. This function performs a rewrite on the
/// given set of values. Any results from this rewrite that should be passed
/// back to PDL should be added to the provided result list. This method is only
/// invoked when the corresponding match was successful. Returns failure if an
/// invariant of the rewrite was broken (certain rewriters may recover from
/// partial pattern application).
using PDLRewriteFunction = std::function<LogicalResult(
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
namespace detail {
namespace pdl_function_builder {
/// A utility variable that always resolves to false. This is useful for static
/// asserts that are always false, but only should fire in certain templated
/// constructs. For example, if a templated function should never be called, the
/// function could be defined as:
///
/// template <typename T>
/// void foo() {
/// static_assert(always_false<T>, "This function should never be called");
/// }
///
template <class... T>
constexpr bool always_false = false;
//===----------------------------------------------------------------------===//
// PDL Function Builder: Type Processing
//===----------------------------------------------------------------------===//
/// This struct provides a convenient way to determine how to process a given
/// type as either a PDL parameter, or a result value. This allows for
/// supporting complex types in constraint and rewrite functions, without
/// requiring the user to hand-write the necessary glue code themselves.
/// Specializations of this class should implement the following methods to
/// enable support as a PDL argument or result type:
///
/// static LogicalResult verifyAsArg(
/// function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue,
/// size_t argIdx);
///
/// * This method verifies that the given PDLValue is valid for use as a
/// value of `T`.
///
/// static T processAsArg(PDLValue pdlValue);
///
/// * This method processes the given PDLValue as a value of `T`.
///
/// static void processAsResult(PatternRewriter &, PDLResultList &results,
/// const T &value);
///
/// * This method processes the given value of `T` as the result of a
/// function invocation. The method should package the value into an
/// appropriate form and append it to the given result list.
///
/// If the type `T` is based on a higher order value, consider using
/// `ProcessPDLValueBasedOn` as a base class of the specialization to simplify
/// the implementation.
///
template <typename T, typename Enable = void>
struct ProcessPDLValue;
/// This struct provides a simplified model for processing types that are based
/// on another type, e.g. APInt is based on the handling for IntegerAttr. This
/// allows for building the necessary processing functions on top of the base
/// value instead of a PDLValue. Derived users should implement the following
/// (which subsume the ProcessPDLValue variants):
///
/// static LogicalResult verifyAsArg(
/// function_ref<LogicalResult(const Twine &)> errorFn,
/// const BaseT &baseValue, size_t argIdx);
///
/// * This method verifies that the given PDLValue is valid for use as a
/// value of `T`.
///
/// static T processAsArg(BaseT baseValue);
///
/// * This method processes the given base value as a value of `T`.
///
template <typename T, typename BaseT>
struct ProcessPDLValueBasedOn {
static LogicalResult
verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
PDLValue pdlValue, size_t argIdx) {
// Verify the base class before continuing.
if (failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx)))
return failure();
return ProcessPDLValue<T>::verifyAsArg(
errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx);
}
static T processAsArg(PDLValue pdlValue) {
return ProcessPDLValue<T>::processAsArg(
ProcessPDLValue<BaseT>::processAsArg(pdlValue));
}
/// Explicitly add the expected parent API to ensure the parent class
/// implements the necessary API (and doesn't implicitly inherit it from
/// somewhere else).
static LogicalResult
verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, BaseT value,
size_t argIdx) {
return success();
}
static T processAsArg(BaseT baseValue);
};
/// This struct provides a simplified model for processing types that have
/// "builtin" PDLValue support:
/// * Attribute, Operation *, Type, TypeRange, ValueRange
template <typename T>
struct ProcessBuiltinPDLValue {
static LogicalResult
verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
PDLValue pdlValue, size_t argIdx) {
if (pdlValue)
return success();
return errorFn("expected a non-null value for argument " + Twine(argIdx) +
" of type: " + llvm::getTypeName<T>());
}
static T processAsArg(PDLValue pdlValue) { return pdlValue.cast<T>(); }
static void processAsResult(PatternRewriter &, PDLResultList &results,
T value) {
results.push_back(value);
}
};
/// This struct provides a simplified model for processing types that inherit
/// from builtin PDLValue types. For example, derived attributes like
/// IntegerAttr, derived types like IntegerType, derived operations like
/// ModuleOp, Interfaces, etc.
template <typename T, typename BaseT>
struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn<T, BaseT> {
static LogicalResult
verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
BaseT baseValue, size_t argIdx) {
return TypeSwitch<BaseT, LogicalResult>(baseValue)
.Case([&](T) { return success(); })
.Default([&](BaseT) {
return errorFn("expected argument " + Twine(argIdx) +
" to be of type: " + llvm::getTypeName<T>());
});
}
using ProcessPDLValueBasedOn<T, BaseT>::verifyAsArg;
static T processAsArg(BaseT baseValue) {
return baseValue.template cast<T>();
}
using ProcessPDLValueBasedOn<T, BaseT>::processAsArg;
static void processAsResult(PatternRewriter &, PDLResultList &results,
T value) {
results.push_back(value);
}
};
//===----------------------------------------------------------------------===//
// Attribute
template <>
struct ProcessPDLValue<Attribute> : public ProcessBuiltinPDLValue<Attribute> {};
template <typename T>
struct ProcessPDLValue<T,
std::enable_if_t<std::is_base_of<Attribute, T>::value>>
: public ProcessDerivedPDLValue<T, Attribute> {};
/// Handling for various Attribute value types.
template <>
struct ProcessPDLValue<StringRef>
: public ProcessPDLValueBasedOn<StringRef, StringAttr> {
static StringRef processAsArg(StringAttr value) { return value.getValue(); }
using ProcessPDLValueBasedOn<StringRef, StringAttr>::processAsArg;
static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
StringRef value) {
results.push_back(rewriter.getStringAttr(value));
}
};
template <>
struct ProcessPDLValue<std::string>
: public ProcessPDLValueBasedOn<std::string, StringAttr> {
template <typename T>
static std::string processAsArg(T value) {
static_assert(always_false<T>,
"`std::string` arguments require a string copy, use "
"`StringRef` for string-like arguments instead");
return {};
}
static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
StringRef value) {
results.push_back(rewriter.getStringAttr(value));
}
};
//===----------------------------------------------------------------------===//
// Operation
template <>
struct ProcessPDLValue<Operation *>
: public ProcessBuiltinPDLValue<Operation *> {};
template <typename T>
struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>>
: public ProcessDerivedPDLValue<T, Operation *> {
static T processAsArg(Operation *value) { return cast<T>(value); }
};
//===----------------------------------------------------------------------===//
// Type
template <>
struct ProcessPDLValue<Type> : public ProcessBuiltinPDLValue<Type> {};
template <typename T>
struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>>
: public ProcessDerivedPDLValue<T, Type> {};
//===----------------------------------------------------------------------===//
// TypeRange
template <>
struct ProcessPDLValue<TypeRange> : public ProcessBuiltinPDLValue<TypeRange> {};
template <>
struct ProcessPDLValue<ValueTypeRange<OperandRange>> {
static void processAsResult(PatternRewriter &, PDLResultList &results,
ValueTypeRange<OperandRange> types) {
results.push_back(types);
}
};
template <>
struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
static void processAsResult(PatternRewriter &, PDLResultList &results,
ValueTypeRange<ResultRange> types) {
results.push_back(types);
}
};
template <unsigned N>
struct ProcessPDLValue<SmallVector<Type, N>> {
static void processAsResult(PatternRewriter &, PDLResultList &results,
SmallVector<Type, N> values) {
results.push_back(TypeRange(values));
}
};
//===----------------------------------------------------------------------===//
// Value
template <>
struct ProcessPDLValue<Value> : public ProcessBuiltinPDLValue<Value> {};
//===----------------------------------------------------------------------===//
// ValueRange
template <>
struct ProcessPDLValue<ValueRange> : public ProcessBuiltinPDLValue<ValueRange> {
};
template <>
struct ProcessPDLValue<OperandRange> {
static void processAsResult(PatternRewriter &, PDLResultList &results,
OperandRange values) {
results.push_back(values);
}
};
template <>
struct ProcessPDLValue<ResultRange> {
static void processAsResult(PatternRewriter &, PDLResultList &results,
ResultRange values) {
results.push_back(values);
}
};
template <unsigned N>
struct ProcessPDLValue<SmallVector<Value, N>> {
static void processAsResult(PatternRewriter &, PDLResultList &results,
SmallVector<Value, N> values) {
results.push_back(ValueRange(values));
}
};
//===----------------------------------------------------------------------===//
// PDL Function Builder: Argument Handling
//===----------------------------------------------------------------------===//
/// Validate the given PDLValues match the constraints defined by the argument
/// types of the given function. In the case of failure, a match failure
/// diagnostic is emitted.
/// FIXME: This should be completely removed in favor of `assertArgs`, but PDL
/// does not currently preserve Constraint application ordering.
template <typename PDLFnT, std::size_t... I>
LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
std::index_sequence<I...>) {
using FnTraitsT = llvm::function_traits<PDLFnT>;
auto errorFn = [&](const Twine &msg) {
return rewriter.notifyMatchFailure(rewriter.getUnknownLoc(), msg);
};
return success(
(succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
verifyAsArg(errorFn, values[I], I)) &&
...));
}
/// Assert that the given PDLValues match the constraints defined by the
/// arguments of the given function. In the case of failure, a fatal error
/// is emitted.
template <typename PDLFnT, std::size_t... I>
void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
std::index_sequence<I...>) {
// We only want to do verification in debug builds, same as with `assert`.
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
using FnTraitsT = llvm::function_traits<PDLFnT>;
auto errorFn = [&](const Twine &msg) -> LogicalResult {
llvm::report_fatal_error(msg);
};
(void)errorFn;
assert((succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
verifyAsArg(errorFn, values[I], I)) &&
...));
#endif
(void)values;
}
//===----------------------------------------------------------------------===//
// PDL Function Builder: Results Handling
//===----------------------------------------------------------------------===//
/// Store a single result within the result list.
template <typename T>
static LogicalResult processResults(PatternRewriter &rewriter,
PDLResultList &results, T &&value) {
ProcessPDLValue<T>::processAsResult(rewriter, results,
std::forward<T>(value));
return success();
}
/// Store a std::pair<> as individual results within the result list.
template <typename T1, typename T2>
static LogicalResult processResults(PatternRewriter &rewriter,
PDLResultList &results,
std::pair<T1, T2> &&pair) {
if (failed(processResults(rewriter, results, std::move(pair.first))) ||
failed(processResults(rewriter, results, std::move(pair.second))))
return failure();
return success();
}
/// Store a std::tuple<> as individual results within the result list.
template <typename... Ts>
static LogicalResult processResults(PatternRewriter &rewriter,
PDLResultList &results,
std::tuple<Ts...> &&tuple) {
auto applyFn = [&](auto &&...args) {
return (succeeded(processResults(rewriter, results, std::move(args))) &&
...);
};
return success(std::apply(applyFn, std::move(tuple)));
}
/// Handle LogicalResult propagation.
inline LogicalResult processResults(PatternRewriter &rewriter,
PDLResultList &results,
LogicalResult &&result) {
return result;
}
template <typename T>
static LogicalResult processResults(PatternRewriter &rewriter,
PDLResultList &results,
FailureOr<T> &&result) {
if (failed(result))
return failure();
return processResults(rewriter, results, std::move(*result));
}
//===----------------------------------------------------------------------===//
// PDL Constraint Builder
//===----------------------------------------------------------------------===//
/// Process the arguments of a native constraint and invoke it.
template <typename PDLFnT, std::size_t... I,
typename FnTraitsT = llvm::function_traits<PDLFnT>>
typename FnTraitsT::result_t
processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter,
ArrayRef<PDLValue> values,
std::index_sequence<I...>) {
return fn(
rewriter,
(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
values[I]))...);
}
/// Build a constraint function from the given function `ConstraintFnT`. This
/// allows for enabling the user to define simpler, more direct constraint
/// functions without needing to handle the low-level PDL goop.
///
/// If the constraint function is already in the correct form, we just forward
/// it directly.
template <typename ConstraintFnT>
std::enable_if_t<
std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
PDLConstraintFunction>
buildConstraintFn(ConstraintFnT &&constraintFn) {
return std::forward<ConstraintFnT>(constraintFn);
}
/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
/// we desire.
template <typename ConstraintFnT>
std::enable_if_t<
!std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
PDLConstraintFunction>
buildConstraintFn(ConstraintFnT &&constraintFn) {
return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
PatternRewriter &rewriter,
ArrayRef<PDLValue> values) -> LogicalResult {
auto argIndices = std::make_index_sequence<
llvm::function_traits<ConstraintFnT>::num_args - 1>();
if (failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices)))
return failure();
return processArgsAndInvokeConstraint(constraintFn, rewriter, values,
argIndices);
};
}
//===----------------------------------------------------------------------===//
// PDL Rewrite Builder
//===----------------------------------------------------------------------===//
/// Process the arguments of a native rewrite and invoke it.
/// This overload handles the case of no return values.
template <typename PDLFnT, std::size_t... I,
typename FnTraitsT = llvm::function_traits<PDLFnT>>
std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value,
LogicalResult>
processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
PDLResultList &, ArrayRef<PDLValue> values,
std::index_sequence<I...>) {
fn(rewriter,
(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
values[I]))...);
return success();
}
/// This overload handles the case of return values, which need to be packaged
/// into the result list.
template <typename PDLFnT, std::size_t... I,
typename FnTraitsT = llvm::function_traits<PDLFnT>>
std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value,
LogicalResult>
processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
PDLResultList &results, ArrayRef<PDLValue> values,
std::index_sequence<I...>) {
return processResults(
rewriter, results,
fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
processAsArg(values[I]))...));
(void)values;
}
/// Build a rewrite function from the given function `RewriteFnT`. This
/// allows for enabling the user to define simpler, more direct rewrite
/// functions without needing to handle the low-level PDL goop.
///
/// If the rewrite function is already in the correct form, we just forward
/// it directly.
template <typename RewriteFnT>
std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
PDLRewriteFunction>
buildRewriteFn(RewriteFnT &&rewriteFn) {
return std::forward<RewriteFnT>(rewriteFn);
}
/// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
/// we desire.
template <typename RewriteFnT>
std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
PDLRewriteFunction>
buildRewriteFn(RewriteFnT &&rewriteFn) {
return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)](
PatternRewriter &rewriter, PDLResultList &results,
ArrayRef<PDLValue> values) {
auto argIndices =
std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
1>();
assertArgs<RewriteFnT>(rewriter, values, argIndices);
return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
argIndices);
};
}
} // namespace pdl_function_builder
} // namespace detail
//===----------------------------------------------------------------------===//
// PDLPatternModule
/// This class contains all of the necessary data for a set of PDL patterns, or
/// pattern rewrites specified in the form of the PDL dialect. This PDL module
/// contained by this pattern may contain any number of `pdl.pattern`
/// operations.
class PDLPatternModule {
public:
PDLPatternModule() = default;
/// Construct a PDL pattern with the given module and configurations.
PDLPatternModule(OwningOpRef<ModuleOp> module)
: pdlModule(std::move(module)) {}
template <typename... ConfigsT>
PDLPatternModule(OwningOpRef<ModuleOp> module, ConfigsT &&...patternConfigs)
: PDLPatternModule(std::move(module)) {
auto configSet = std::make_unique<PDLPatternConfigSet>(
std::forward<ConfigsT>(patternConfigs)...);
attachConfigToPatterns(*pdlModule, *configSet);
configs.emplace_back(std::move(configSet));
}
/// Merge the state in `other` into this pattern module.
void mergeIn(PDLPatternModule &&other);
/// Return the internal PDL module of this pattern.
ModuleOp getModule() { return pdlModule.get(); }
//===--------------------------------------------------------------------===//
// Function Registry
/// Register a constraint function with PDL. A constraint function may be
/// specified in one of two ways:
///
/// * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
///
/// In this overload the arguments of the constraint function are passed via
/// the low-level PDLValue form.
///
/// * `LogicalResult (PatternRewriter &, ValueTs... values)`
///
/// In this form the arguments of the constraint function are passed via the
/// expected high level C++ type. In this form, the framework will
/// automatically unwrap PDLValues and convert them to the expected ValueTs.
/// For example, if the constraint function accepts a `Operation *`, the
/// framework will automatically cast the input PDLValue. In the case of a
/// `StringRef`, the framework will automatically unwrap the argument as a
/// StringAttr and pass the underlying string value. To see the full list of
/// supported types, or to see how to add handling for custom types, view
/// the definition of `ProcessPDLValue` above.
void registerConstraintFunction(StringRef name,
PDLConstraintFunction constraintFn);
template <typename ConstraintFnT>
void registerConstraintFunction(StringRef name,
ConstraintFnT &&constraintFn) {
registerConstraintFunction(name,
detail::pdl_function_builder::buildConstraintFn(
std::forward<ConstraintFnT>(constraintFn)));
}
/// Register a rewrite function with PDL. A rewrite function may be specified
/// in one of two ways:
///
/// * `void (PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)`
///
/// In this overload the arguments of the constraint function are passed via
/// the low-level PDLValue form, and the results are manually appended to
/// the given result list.
///
/// * `ResultT (PatternRewriter &, ValueTs... values)`
///
/// In this form the arguments and result of the rewrite function are passed
/// via the expected high level C++ type. In this form, the framework will
/// automatically unwrap the PDLValues arguments and convert them to the
/// expected ValueTs. It will also automatically handle the processing and
/// packaging of the result value to the result list. For example, if the
/// rewrite function takes a `Operation *`, the framework will automatically
/// cast the input PDLValue. In the case of a `StringRef`, the framework
/// will automatically unwrap the argument as a StringAttr and pass the
/// underlying string value. In the reverse case, if the rewrite returns a
/// StringRef or std::string, it will automatically package this as a
/// StringAttr and append it to the result list. To see the full list of
/// supported types, or to see how to add handling for custom types, view
/// the definition of `ProcessPDLValue` above.
void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
template <typename RewriteFnT>
void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {
registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn(
std::forward<RewriteFnT>(rewriteFn)));
}
/// Return the set of the registered constraint functions.
const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
return constraintFunctions;
}
llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
return constraintFunctions;
}
/// Return the set of the registered rewrite functions.
const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
return rewriteFunctions;
}
llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
return rewriteFunctions;
}
/// Return the set of the registered pattern configs.
SmallVector<std::unique_ptr<PDLPatternConfigSet>> takeConfigs() {
return std::move(configs);
}
DenseMap<Operation *, PDLPatternConfigSet *> takeConfigMap() {
return std::move(configMap);
}
/// Clear out the patterns and functions within this module.
void clear() {
pdlModule = nullptr;
constraintFunctions.clear();
rewriteFunctions.clear();
}
private:
/// Attach the given pattern config set to the patterns defined within the
/// given module.
void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet);
/// The module containing the `pdl.pattern` operations.
OwningOpRef<ModuleOp> pdlModule;
/// The set of configuration sets referenced by patterns within `pdlModule`.
SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
DenseMap<Operation *, PDLPatternConfigSet *> configMap;
/// The external functions referenced from within the PDL module.
llvm::StringMap<PDLConstraintFunction> constraintFunctions;
llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
};
namespace mlir {
//===----------------------------------------------------------------------===//
// RewritePatternSet
@ -1679,8 +759,7 @@ public:
nativePatterns.emplace_back(std::move(pattern));
}
RewritePatternSet(PDLPatternModule &&pattern)
: context(pattern.getModule()->getContext()),
pdlPatterns(std::move(pattern)) {}
: context(pattern.getContext()), pdlPatterns(std::move(pattern)) {}
MLIRContext *getContext() const { return context; }
@ -1853,6 +932,7 @@ private:
pattern->addDebugLabels(debugLabels);
nativePatterns.emplace_back(std::move(pattern));
}
template <typename T, typename... Args>
std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
@ -1863,6 +943,9 @@ private:
MLIRContext *const context;
NativePatternListT nativePatterns;
// Patterns expressed with PDL. This will compile to a stub class when PDL is
// not enabled.
PDLPatternModule pdlPatterns;
};

View File

@ -13,6 +13,7 @@
#ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
#define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
#include "mlir/Config/mlir-config.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/StringMap.h"
@ -1015,6 +1016,7 @@ private:
MLIRContext &ctx;
};
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
//===----------------------------------------------------------------------===//
// PDL Configuration
//===----------------------------------------------------------------------===//
@ -1044,6 +1046,19 @@ private:
/// Register the dialect conversion PDL functions with the given pattern set.
void registerConversionPDLFunctions(RewritePatternSet &patterns);
#else
// Stubs for when PDL in rewriting is not enabled.
inline void registerConversionPDLFunctions(RewritePatternSet &patterns) {}
class PDLConversionConfig final {
public:
PDLConversionConfig(const TypeConverter * /*converter*/) {}
};
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
//===----------------------------------------------------------------------===//
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//

View File

@ -11,8 +11,9 @@ add_mlir_conversion_library(MLIRComplexToLibm
Core
LINK_LIBS PUBLIC
MLIRComplexDialect
MLIRDialectUtils
MLIRFuncDialect
MLIRComplexDialect
MLIRPass
MLIRTransformUtils
)

View File

@ -15,6 +15,7 @@ add_mlir_conversion_library(MLIRMathToLibm
MLIRDialectUtils
MLIRFuncDialect
MLIRMathDialect
MLIRPass
MLIRTransformUtils
MLIRVectorDialect
MLIRVectorUtils

View File

@ -14,7 +14,6 @@ add_mlir_dialect_library(MLIRBufferizationTransformOps
MLIRFunctionInterfaces
MLIRLinalgDialect
MLIRParser
MLIRPDLDialect
MLIRSideEffectInterfaces
MLIRTransformDialect
)

View File

@ -1,3 +1,9 @@
if(MLIR_ENABLE_PDL_IN_PATTERNMATCH)
set(pdl_src
PDL/PDLPatternMatch.cpp
)
endif()
add_mlir_library(MLIRIR
AffineExpr.cpp
AffineMap.cpp
@ -36,6 +42,7 @@ add_mlir_library(MLIRIR
ValueRange.cpp
Verifier.cpp
Visitors.cpp
${pdl_src}
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
@ -61,3 +68,4 @@ add_mlir_library(MLIRIR
LINK_LIBS PUBLIC
MLIRSupport
)

View File

@ -0,0 +1,133 @@
//===- PDLPatternMatch.cpp - Base classes for PDL pattern match
//------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/RegionKindInterface.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// PDLValue
//===----------------------------------------------------------------------===//
void PDLValue::print(raw_ostream &os) const {
if (!value) {
os << "<NULL-PDLValue>";
return;
}
switch (kind) {
case Kind::Attribute:
os << cast<Attribute>();
break;
case Kind::Operation:
os << *cast<Operation *>();
break;
case Kind::Type:
os << cast<Type>();
break;
case Kind::TypeRange:
llvm::interleaveComma(cast<TypeRange>(), os);
break;
case Kind::Value:
os << cast<Value>();
break;
case Kind::ValueRange:
llvm::interleaveComma(cast<ValueRange>(), os);
break;
}
}
void PDLValue::print(raw_ostream &os, Kind kind) {
switch (kind) {
case Kind::Attribute:
os << "Attribute";
break;
case Kind::Operation:
os << "Operation";
break;
case Kind::Type:
os << "Type";
break;
case Kind::TypeRange:
os << "TypeRange";
break;
case Kind::Value:
os << "Value";
break;
case Kind::ValueRange:
os << "ValueRange";
break;
}
}
//===----------------------------------------------------------------------===//
// PDLPatternModule
//===----------------------------------------------------------------------===//
void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
// Ignore the other module if it has no patterns.
if (!other.pdlModule)
return;
// Steal the functions and config of the other module.
for (auto &it : other.constraintFunctions)
registerConstraintFunction(it.first(), std::move(it.second));
for (auto &it : other.rewriteFunctions)
registerRewriteFunction(it.first(), std::move(it.second));
for (auto &it : other.configs)
configs.emplace_back(std::move(it));
for (auto &it : other.configMap)
configMap.insert(it);
// Steal the other state if we have no patterns.
if (!pdlModule) {
pdlModule = std::move(other.pdlModule);
return;
}
// Merge the pattern operations from the other module into this one.
Block *block = pdlModule->getBody();
block->getOperations().splice(block->end(),
other.pdlModule->getBody()->getOperations());
}
void PDLPatternModule::attachConfigToPatterns(ModuleOp module,
PDLPatternConfigSet &configSet) {
// Attach the configuration to the symbols within the module. We only add
// to symbols to avoid hardcoding any specific operation names here (given
// that we don't depend on any PDL dialect). We can't use
// cast<SymbolOpInterface> here because patterns may be optional symbols.
module->walk([&](Operation *op) {
if (op->hasTrait<SymbolOpInterface::Trait>())
configMap[op] = &configSet;
});
}
//===----------------------------------------------------------------------===//
// Function Registry
void PDLPatternModule::registerConstraintFunction(
StringRef name, PDLConstraintFunction constraintFn) {
// TODO: Is it possible to diagnose when `name` is already registered to
// a function that is not equivalent to `constraintFn`?
// Allow existing mappings in the case multiple patterns depend on the same
// constraint.
constraintFunctions.try_emplace(name, std::move(constraintFn));
}
void PDLPatternModule::registerRewriteFunction(StringRef name,
PDLRewriteFunction rewriteFn) {
// TODO: Is it possible to diagnose when `name` is already registered to
// a function that is not equivalent to `rewriteFn`?
// Allow existing mappings in the case multiple patterns depend on the same
// rewrite.
rewriteFunctions.try_emplace(name, std::move(rewriteFn));
}

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/PatternMatch.h"
#include "mlir/Config/mlir-config.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/RegionKindInterface.h"
@ -97,124 +98,6 @@ LogicalResult RewritePattern::match(Operation *op) const {
/// Out-of-line vtable anchor.
void RewritePattern::anchor() {}
//===----------------------------------------------------------------------===//
// PDLValue
//===----------------------------------------------------------------------===//
void PDLValue::print(raw_ostream &os) const {
if (!value) {
os << "<NULL-PDLValue>";
return;
}
switch (kind) {
case Kind::Attribute:
os << cast<Attribute>();
break;
case Kind::Operation:
os << *cast<Operation *>();
break;
case Kind::Type:
os << cast<Type>();
break;
case Kind::TypeRange:
llvm::interleaveComma(cast<TypeRange>(), os);
break;
case Kind::Value:
os << cast<Value>();
break;
case Kind::ValueRange:
llvm::interleaveComma(cast<ValueRange>(), os);
break;
}
}
void PDLValue::print(raw_ostream &os, Kind kind) {
switch (kind) {
case Kind::Attribute:
os << "Attribute";
break;
case Kind::Operation:
os << "Operation";
break;
case Kind::Type:
os << "Type";
break;
case Kind::TypeRange:
os << "TypeRange";
break;
case Kind::Value:
os << "Value";
break;
case Kind::ValueRange:
os << "ValueRange";
break;
}
}
//===----------------------------------------------------------------------===//
// PDLPatternModule
//===----------------------------------------------------------------------===//
void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
// Ignore the other module if it has no patterns.
if (!other.pdlModule)
return;
// Steal the functions and config of the other module.
for (auto &it : other.constraintFunctions)
registerConstraintFunction(it.first(), std::move(it.second));
for (auto &it : other.rewriteFunctions)
registerRewriteFunction(it.first(), std::move(it.second));
for (auto &it : other.configs)
configs.emplace_back(std::move(it));
for (auto &it : other.configMap)
configMap.insert(it);
// Steal the other state if we have no patterns.
if (!pdlModule) {
pdlModule = std::move(other.pdlModule);
return;
}
// Merge the pattern operations from the other module into this one.
Block *block = pdlModule->getBody();
block->getOperations().splice(block->end(),
other.pdlModule->getBody()->getOperations());
}
void PDLPatternModule::attachConfigToPatterns(ModuleOp module,
PDLPatternConfigSet &configSet) {
// Attach the configuration to the symbols within the module. We only add
// to symbols to avoid hardcoding any specific operation names here (given
// that we don't depend on any PDL dialect). We can't use
// cast<SymbolOpInterface> here because patterns may be optional symbols.
module->walk([&](Operation *op) {
if (op->hasTrait<SymbolOpInterface::Trait>())
configMap[op] = &configSet;
});
}
//===----------------------------------------------------------------------===//
// Function Registry
void PDLPatternModule::registerConstraintFunction(
StringRef name, PDLConstraintFunction constraintFn) {
// TODO: Is it possible to diagnose when `name` is already registered to
// a function that is not equivalent to `constraintFn`?
// Allow existing mappings in the case multiple patterns depend on the same
// constraint.
constraintFunctions.try_emplace(name, std::move(constraintFn));
}
void PDLPatternModule::registerRewriteFunction(StringRef name,
PDLRewriteFunction rewriteFn) {
// TODO: Is it possible to diagnose when `name` is already registered to
// a function that is not equivalent to `rewriteFn`?
// Allow existing mappings in the case multiple patterns depend on the same
// rewrite.
rewriteFunctions.try_emplace(name, std::move(rewriteFn));
}
//===----------------------------------------------------------------------===//
// RewriterBase
//===----------------------------------------------------------------------===//

View File

@ -16,6 +16,8 @@
#include "mlir/IR/PatternMatch.h"
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
namespace mlir {
namespace pdl_interp {
class RecordMatchOp;
@ -224,4 +226,38 @@ private:
} // namespace detail
} // namespace mlir
#else
namespace mlir::detail {
class PDLByteCodeMutableState {
public:
void cleanupAfterMatchAndRewrite() {}
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit) {}
};
class PDLByteCodePattern : public Pattern {};
class PDLByteCode {
public:
struct MatchResult {
const PDLByteCodePattern *pattern = nullptr;
PatternBenefit benefit;
};
void initializeMutableState(PDLByteCodeMutableState &state) const {}
void match(Operation *op, PatternRewriter &rewriter,
SmallVectorImpl<MatchResult> &matches,
PDLByteCodeMutableState &state) const {}
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match,
PDLByteCodeMutableState &state) const {
return failure();
}
ArrayRef<PDLByteCodePattern> getPatterns() const { return {}; }
};
} // namespace mlir::detail
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
#endif // MLIR_REWRITE_BYTECODE_H_

View File

@ -1,5 +1,6 @@
set(LLVM_OPTIONAL_SOURCES ByteCode.cpp)
add_mlir_library(MLIRRewrite
ByteCode.cpp
FrozenRewritePatternSet.cpp
PatternApplicator.cpp
@ -11,8 +12,31 @@ add_mlir_library(MLIRRewrite
LINK_LIBS PUBLIC
MLIRIR
MLIRPDLDialect
MLIRPDLInterpDialect
MLIRPDLToPDLInterp
MLIRSideEffectInterfaces
)
if(MLIR_ENABLE_PDL_IN_PATTERNMATCH)
add_mlir_library(MLIRRewritePDL
ByteCode.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Rewrite
DEPENDS
mlir-generic-headers
LINK_LIBS PUBLIC
MLIRIR
MLIRPDLDialect
MLIRPDLInterpDialect
MLIRPDLToPDLInterp
MLIRSideEffectInterfaces
)
target_link_libraries(MLIRRewrite PUBLIC
MLIRPDLDialect
MLIRPDLInterpDialect
MLIRPDLToPDLInterp
MLIRRewritePDL)
endif()

View File

@ -8,8 +8,6 @@
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "ByteCode.h"
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
#include "mlir/Dialect/PDL/IR/PDLOps.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
@ -17,6 +15,11 @@
using namespace mlir;
// Include the PDL rewrite support.
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
#include "mlir/Dialect/PDL/IR/PDLOps.h"
static LogicalResult
convertPDLToPDLInterp(ModuleOp pdlModule,
DenseMap<Operation *, PDLPatternConfigSet *> &configMap) {
@ -48,6 +51,7 @@ convertPDLToPDLInterp(ModuleOp pdlModule,
pdlModule.getBody()->walk(simplifyFn);
return success();
}
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
//===----------------------------------------------------------------------===//
// FrozenRewritePatternSet
@ -121,6 +125,7 @@ FrozenRewritePatternSet::FrozenRewritePatternSet(
impl->nativeAnyOpPatterns.push_back(std::move(pat));
}
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
// Generate the bytecode for the PDL patterns if any were provided.
PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
ModuleOp pdlModule = pdlPatterns.getModule();
@ -137,6 +142,7 @@ FrozenRewritePatternSet::FrozenRewritePatternSet(
pdlModule, pdlPatterns.takeConfigs(), configMap,
pdlPatterns.takeConstraintFunctions(),
pdlPatterns.takeRewriteFunctions());
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
}
FrozenRewritePatternSet::~FrozenRewritePatternSet() = default;

View File

@ -152,7 +152,6 @@ LogicalResult PatternApplicator::matchAndRewrite(
// Find the next pattern with the highest benefit.
const Pattern *bestPattern = nullptr;
unsigned *bestPatternIt = &opIt;
const PDLByteCode::MatchResult *pdlMatch = nullptr;
/// Operation specific patterns.
if (opIt < opE)
@ -164,6 +163,8 @@ LogicalResult PatternApplicator::matchAndRewrite(
bestPatternIt = &anyIt;
bestPattern = anyOpPatterns[anyIt];
}
const PDLByteCode::MatchResult *pdlMatch = nullptr;
/// PDL patterns.
if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() <
pdlMatches[pdlIt].benefit)) {
@ -171,6 +172,7 @@ LogicalResult PatternApplicator::matchAndRewrite(
pdlMatch = &pdlMatches[pdlIt];
bestPattern = pdlMatch->pattern;
}
if (!bestPattern)
break;

View File

@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Config/mlir-config.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@ -3312,6 +3313,7 @@ auto ConversionTarget::getOpInfo(OperationName op) const
return std::nullopt;
}
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
//===----------------------------------------------------------------------===//
// PDL Configuration
//===----------------------------------------------------------------------===//
@ -3382,6 +3384,7 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
return std::move(remappedTypes);
});
}
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
//===----------------------------------------------------------------------===//
// Op Conversion Entry Points

View File

@ -97,16 +97,13 @@ set(MLIR_TEST_DEPENDS
mlir-capi-ir-test
mlir-capi-llvm-test
mlir-capi-pass-test
mlir-capi-pdl-test
mlir-capi-quant-test
mlir-capi-sparse-tensor-test
mlir-capi-transform-test
mlir-capi-translation-test
mlir-linalg-ods-yaml-gen
mlir-lsp-server
mlir-pdll-lsp-server
mlir-opt
mlir-pdll
mlir-query
mlir-reduce
mlir-tblgen
@ -115,6 +112,12 @@ set(MLIR_TEST_DEPENDS
tblgen-to-irdl
)
set(MLIR_TEST_DEPENDS ${MLIR_TEST_DEPENDS}
mlir-capi-pdl-test
mlir-pdll-lsp-server
mlir-pdll
)
# The native target may not be enabled, in this case we won't
# run tests that involves executing on the host: do not build
# useless binaries.
@ -159,9 +162,10 @@ if(LLVM_BUILD_EXAMPLES)
toyc-ch3
toyc-ch4
toyc-ch5
)
list(APPEND MLIR_TEST_DEPENDS
transform-opt-ch2
transform-opt-ch3
mlir-minimal-opt
)
if(MLIR_ENABLE_EXECUTION_ENGINE)
list(APPEND MLIR_TEST_DEPENDS

View File

@ -1,16 +1,17 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestRewrite
TestPDLByteCode.cpp
if (MLIR_ENABLE_PDL_IN_PATTERNMATCH)
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestRewrite
TestPDLByteCode.cpp
EXCLUDE_FROM_LIBMLIR
EXCLUDE_FROM_LIBMLIR
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Rewrite
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRSupport
MLIRTransformUtils
)
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Rewrite
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRSupport
MLIRTransformUtils
)
endif()

View File

@ -23,6 +23,8 @@ add_mlir_library(MLIRTestPDLL
MLIRCastInterfaces
MLIRIR
MLIRPass
MLIRPDLInterpDialect
MLIRPDLDialect
MLIRSupport
MLIRTestDialect
MLIRTransformUtils

View File

@ -1,3 +1,8 @@
set(LLVM_OPTIONAL_SOURCES
TestDialectConversion.cpp)
set(MLIRTestTransformsPDLDep)
set(MLIRTestTransformsPDLSrc)
if(MLIR_ENABLE_PDL_IN_PATTERNMATCH)
add_mlir_pdll_library(MLIRTestDialectConversionPDLLPatternsIncGen
TestDialectConversion.pdll
TestDialectConversionPDLLPatterns.h.inc
@ -6,17 +11,22 @@ add_mlir_pdll_library(MLIRTestDialectConversionPDLLPatternsIncGen
${CMAKE_CURRENT_SOURCE_DIR}/../Dialect/Test
${CMAKE_CURRENT_BINARY_DIR}/../Dialect/Test
)
set(MLIRTestTransformsPDLSrc
TestDialectConversion.cpp)
set(MLIRTestTransformsPDLDep
MLIRTestDialectConversionPDLLPatternsIncGen)
endif()
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestTransforms
TestCommutativityUtils.cpp
TestConstantFold.cpp
TestControlFlowSink.cpp
TestDialectConversion.cpp
TestInlining.cpp
TestIntRangeInference.cpp
TestMakeIsolatedFromAbove.cpp
TestTopologicalSort.cpp
${MLIRTestTransformsPDLSrc}
EXCLUDE_FROM_LIBMLIR
@ -24,7 +34,7 @@ add_mlir_library(MLIRTestTransforms
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
DEPENDS
MLIRTestDialectConversionPDLLPatternsIncGen
${MLIRTestTransformsPDLDep}
LINK_LIBS PUBLIC
MLIRAnalysis

View File

@ -21,10 +21,17 @@ if(MLIR_INCLUDE_TESTS)
MLIRTestIR
MLIRTestPass
MLIRTestReducer
MLIRTestRewrite
MLIRTestTransformDialect
MLIRTestTransforms
)
set(test_libs
${test_libs}
MLIRTestTransformDialect
MLIRTestTransforms)
if (MLIR_ENABLE_PDL_IN_PATTERNMATCH)
set(test_libs
${test_libs}
MLIRTestRewrite)
endif()
endif()
set(LIBS

View File

@ -38,16 +38,24 @@ if(MLIR_INCLUDE_TESTS)
MLIRTestIR
MLIRTestOneToNTypeConversionPass
MLIRTestPass
MLIRTestPDLL
MLIRTestReducer
MLIRTestRewrite
MLIRTestTransformDialect
MLIRTestTransforms
MLIRTilingInterfaceTestPasses
MLIRVectorTestPasses
MLIRTestVectorToSPIRV
MLIRLLVMTestPasses
)
set(test_libs ${test_libs}
MLIRTestPDLL
MLIRTestTransformDialect
)
if (MLIR_ENABLE_PDL_IN_PATTERNMATCH)
set(test_libs ${test_libs}
MLIRTestPDLL
MLIRTestRewrite
)
endif()
endif()
set(LIBS

View File

@ -85,7 +85,6 @@ void registerTestDataLayoutQuery();
void registerTestDeadCodeAnalysisPass();
void registerTestDecomposeCallGraphTypes();
void registerTestDiagnosticsPass();
void registerTestDialectConversionPasses();
void registerTestDominancePass();
void registerTestDynamicPipelinePass();
void registerTestEmulateNarrowTypePass();
@ -124,8 +123,6 @@ void registerTestNextAccessPass();
void registerTestOneToNTypeConversionPass();
void registerTestOpaqueLoc();
void registerTestPadFusion();
void registerTestPDLByteCodePass();
void registerTestPDLLPasses();
void registerTestPreparationPassWithAllowedMemrefResults();
void registerTestRecursiveTypesPass();
void registerTestSCFUtilsPass();
@ -142,13 +139,18 @@ void registerTestWrittenToPass();
void registerTestVectorLowerings();
void registerTestVectorReductionToSPIRVDotProd();
void registerTestNvgpuLowerings();
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
void registerTestDialectConversionPasses();
void registerTestPDLByteCodePass();
void registerTestPDLLPasses();
#endif
} // namespace test
} // namespace mlir
namespace test {
void registerTestDialect(DialectRegistry &);
void registerTestTransformDialectExtension(DialectRegistry &);
void registerTestDynDialect(DialectRegistry &);
void registerTestTransformDialectExtension(DialectRegistry &);
} // namespace test
#ifdef MLIR_INCLUDE_TESTS
@ -202,7 +204,6 @@ void registerTestPasses() {
mlir::test::registerTestConstantFold();
mlir::test::registerTestControlFlowSink();
mlir::test::registerTestDiagnosticsPass();
mlir::test::registerTestDialectConversionPasses();
mlir::test::registerTestDecomposeCallGraphTypes();
mlir::test::registerTestDataLayoutPropagation();
mlir::test::registerTestDataLayoutQuery();
@ -243,8 +244,6 @@ void registerTestPasses() {
mlir::test::registerTestOneToNTypeConversionPass();
mlir::test::registerTestOpaqueLoc();
mlir::test::registerTestPadFusion();
mlir::test::registerTestPDLByteCodePass();
mlir::test::registerTestPDLLPasses();
mlir::test::registerTestRecursiveTypesPass();
mlir::test::registerTestSCFUtilsPass();
mlir::test::registerTestSCFWhileOpBuilderPass();
@ -260,6 +259,11 @@ void registerTestPasses() {
mlir::test::registerTestVectorReductionToSPIRVDotProd();
mlir::test::registerTestNvgpuLowerings();
mlir::test::registerTestWrittenToPass();
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
mlir::test::registerTestDialectConversionPasses();
mlir::test::registerTestPDLByteCodePass();
mlir::test::registerTestPDLLPasses();
#endif
}
#endif

View File

@ -35,6 +35,7 @@ expand_template(
substitutions = {
"#cmakedefine01 MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS": "#define MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 0",
"#cmakedefine MLIR_GREEDY_REWRITE_RANDOMIZER_SEED ${MLIR_GREEDY_REWRITE_RANDOMIZER_SEED}": "/* #undef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED */",
"#cmakedefine01 MLIR_ENABLE_PDL_IN_PATTERNMATCH": "#define MLIR_ENABLE_PDL_IN_PATTERNMATCH 1",
},
template = "include/mlir/Config/mlir-config.h.cmake",
)
@ -318,11 +319,13 @@ cc_library(
srcs = glob([
"lib/IR/*.cpp",
"lib/IR/*.h",
"lib/IR/PDL/*.cpp",
"lib/Bytecode/Reader/*.h",
"lib/Bytecode/Writer/*.h",
"lib/Bytecode/*.h",
]) + [
"lib/Bytecode/BytecodeOpInterface.cpp",
"include/mlir/IR/PDLPatternMatch.h.inc",
],
hdrs = glob([
"include/mlir/IR/*.h",
@ -345,6 +348,7 @@ cc_library(
":BuiltinTypesIncGen",
":BytecodeOpInterfaceIncGen",
":CallOpInterfacesIncGen",
":config",
":DataLayoutInterfacesIncGen",
":InferTypeOpInterfaceIncGen",
":OpAsmInterfaceIncGen",