mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-04-01 12:43:47 +00:00
[mlir] Refactor DialectRegistry delayed interface support into a general DialectExtension mechanism
The current dialect registry allows for attaching delayed interfaces, that are added to attrs/dialects/ops/etc. when the owning dialect gets loaded. This is clunky for quite a few reasons, e.g. each interface type has a separate tracking structure, and is also quite limiting. This commit refactors this delayed mutation of dialect constructs into a more general DialectExtension mechanism. This mechanism is essentially a registration callback that is invoked when a set of dialects have been loaded. This allows for attaching interfaces directly on the loaded constructs, and also allows for loading new dependent dialects. The latter of which is extremely useful as it will now enable dependent dialects to only apply in the contexts in which they are necessary. For example, a dialect dependency can now be conditional on if a user actually needs the interface that relies on it. Differential Revision: https://reviews.llvm.org/D120367
This commit is contained in:
parent
8212b41b7b
commit
77eee5795e
@ -13,6 +13,7 @@
|
||||
#ifndef MLIR_IR_DIALECT_H
|
||||
#define MLIR_IR_DIALECT_H
|
||||
|
||||
#include "mlir/IR/DialectRegistry.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/Support/TypeID.h"
|
||||
|
||||
@ -26,11 +27,9 @@ class DialectInterface;
|
||||
class OpBuilder;
|
||||
class Type;
|
||||
|
||||
using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
|
||||
using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>;
|
||||
using DialectInterfaceAllocatorFunction =
|
||||
std::function<std::unique_ptr<DialectInterface>(Dialect *)>;
|
||||
using ObjectInterfaceAllocatorFunction = std::function<void(MLIRContext *)>;
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Dialects are groups of MLIR operations, types and attributes, as well as
|
||||
/// behavior associated with the entire group. For example, hooks into other
|
||||
@ -180,6 +179,16 @@ public:
|
||||
getRegisteredInterfaceForOp(InterfaceT::getInterfaceID(), opName));
|
||||
}
|
||||
|
||||
/// Register a dialect interface with this dialect instance.
|
||||
void addInterface(std::unique_ptr<DialectInterface> interface);
|
||||
|
||||
/// Register a set of dialect interfaces with this dialect instance.
|
||||
template <typename... Args>
|
||||
void addInterfaces() {
|
||||
(void)std::initializer_list<int>{
|
||||
0, (addInterface(std::make_unique<Args>(this)), 0)...};
|
||||
}
|
||||
|
||||
protected:
|
||||
/// The constructor takes a unique namespace for this dialect as well as the
|
||||
/// context to bind to.
|
||||
@ -218,15 +227,6 @@ protected:
|
||||
/// Enable support for unregistered types.
|
||||
void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; }
|
||||
|
||||
/// Register a dialect interface with this dialect instance.
|
||||
void addInterface(std::unique_ptr<DialectInterface> interface);
|
||||
|
||||
/// Register a set of dialect interfaces with this dialect instance.
|
||||
template <typename... Args> void addInterfaces() {
|
||||
(void)std::initializer_list<int>{
|
||||
0, (addInterface(std::make_unique<Args>(this)), 0)...};
|
||||
}
|
||||
|
||||
private:
|
||||
Dialect(const Dialect &) = delete;
|
||||
void operator=(Dialect &) = delete;
|
||||
@ -274,168 +274,6 @@ private:
|
||||
friend class MLIRContext;
|
||||
};
|
||||
|
||||
/// The DialectRegistry maps a dialect namespace to a constructor for the
|
||||
/// matching dialect.
|
||||
/// This allows for decoupling the list of dialects "available" from the
|
||||
/// dialects loaded in the Context. The parser in particular will lazily load
|
||||
/// dialects in the Context as operations are encountered.
|
||||
class DialectRegistry {
|
||||
/// Lists of interfaces that need to be registered when the dialect is loaded.
|
||||
struct DelayedInterfaces {
|
||||
/// Dialect interfaces.
|
||||
SmallVector<std::pair<TypeID, DialectInterfaceAllocatorFunction>, 2>
|
||||
dialectInterfaces;
|
||||
/// Attribute/Operation/Type interfaces.
|
||||
SmallVector<std::tuple<TypeID, TypeID, ObjectInterfaceAllocatorFunction>, 2>
|
||||
objectInterfaces;
|
||||
};
|
||||
|
||||
using MapTy =
|
||||
std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
|
||||
using InterfaceMapTy = DenseMap<TypeID, DelayedInterfaces>;
|
||||
|
||||
public:
|
||||
explicit DialectRegistry();
|
||||
|
||||
template <typename ConcreteDialect> void insert() {
|
||||
insert(TypeID::get<ConcreteDialect>(),
|
||||
ConcreteDialect::getDialectNamespace(),
|
||||
static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
|
||||
// Just allocate the dialect, the context
|
||||
// takes ownership of it.
|
||||
return ctx->getOrLoadDialect<ConcreteDialect>();
|
||||
})));
|
||||
}
|
||||
|
||||
template <typename ConcreteDialect, typename OtherDialect,
|
||||
typename... MoreDialects>
|
||||
void insert() {
|
||||
insert<ConcreteDialect>();
|
||||
insert<OtherDialect, MoreDialects...>();
|
||||
}
|
||||
|
||||
/// Add a new dialect constructor to the registry. The constructor must be
|
||||
/// calling MLIRContext::getOrLoadDialect in order for the context to take
|
||||
/// ownership of the dialect and for delayed interface registration to happen.
|
||||
void insert(TypeID typeID, StringRef name,
|
||||
const DialectAllocatorFunction &ctor);
|
||||
|
||||
/// Return an allocation function for constructing the dialect identified by
|
||||
/// its namespace, or nullptr if the namespace is not in this registry.
|
||||
DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const;
|
||||
|
||||
// Register all dialects available in the current registry with the registry
|
||||
// in the provided context.
|
||||
void appendTo(DialectRegistry &destination) const {
|
||||
for (const auto &nameAndRegistrationIt : registry)
|
||||
destination.insert(nameAndRegistrationIt.second.first,
|
||||
nameAndRegistrationIt.first,
|
||||
nameAndRegistrationIt.second.second);
|
||||
// Merge interfaces.
|
||||
for (auto it : interfaces) {
|
||||
TypeID dialect = it.first;
|
||||
auto destInterfaces = destination.interfaces.find(dialect);
|
||||
if (destInterfaces == destination.interfaces.end()) {
|
||||
destination.interfaces[dialect] = it.second;
|
||||
continue;
|
||||
}
|
||||
// The destination already has delayed interface registrations for this
|
||||
// dialect. Merge registrations into the destination registry.
|
||||
destInterfaces->second.dialectInterfaces.append(
|
||||
it.second.dialectInterfaces.begin(),
|
||||
it.second.dialectInterfaces.end());
|
||||
destInterfaces->second.objectInterfaces.append(
|
||||
it.second.objectInterfaces.begin(), it.second.objectInterfaces.end());
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the names of dialects known to this registry.
|
||||
auto getDialectNames() const {
|
||||
return llvm::map_range(
|
||||
registry,
|
||||
[](const MapTy::value_type &item) -> StringRef { return item.first; });
|
||||
}
|
||||
|
||||
/// Add an interface constructed with the given allocation function to the
|
||||
/// dialect provided as template parameter. The dialect must be present in
|
||||
/// the registry.
|
||||
template <typename DialectTy>
|
||||
void addDialectInterface(TypeID interfaceTypeID,
|
||||
DialectInterfaceAllocatorFunction allocator) {
|
||||
addDialectInterface(DialectTy::getDialectNamespace(), interfaceTypeID,
|
||||
allocator);
|
||||
}
|
||||
|
||||
/// Add an interface to the dialect, both provided as template parameter. The
|
||||
/// dialect must be present in the registry.
|
||||
template <typename DialectTy, typename InterfaceTy>
|
||||
void addDialectInterface() {
|
||||
addDialectInterface<DialectTy>(
|
||||
InterfaceTy::getInterfaceID(), [](Dialect *dialect) {
|
||||
return std::make_unique<InterfaceTy>(dialect);
|
||||
});
|
||||
}
|
||||
|
||||
/// Add an external op interface model for an op that belongs to a dialect,
|
||||
/// both provided as template parameters. The dialect must be present in the
|
||||
/// registry.
|
||||
template <typename OpTy, typename ModelTy> void addOpInterface() {
|
||||
StringRef opName = OpTy::getOperationName();
|
||||
StringRef dialectName = opName.split('.').first;
|
||||
addObjectInterface(dialectName, TypeID::get<OpTy>(),
|
||||
ModelTy::Interface::getInterfaceID(),
|
||||
[](MLIRContext *context) {
|
||||
OpTy::template attachInterface<ModelTy>(*context);
|
||||
});
|
||||
}
|
||||
|
||||
/// Add an external attribute interface model for an attribute type `AttrTy`
|
||||
/// that is going to belong to `DialectTy`. The dialect must be present in the
|
||||
/// registry.
|
||||
template <typename DialectTy, typename AttrTy, typename ModelTy>
|
||||
void addAttrInterface() {
|
||||
addStorageUserInterface<AttrTy, ModelTy>(DialectTy::getDialectNamespace());
|
||||
}
|
||||
|
||||
/// Add an external type interface model for an type class `TypeTy` that is
|
||||
/// going to belong to `DialectTy`. The dialect must be present in the
|
||||
/// registry.
|
||||
template <typename DialectTy, typename TypeTy, typename ModelTy>
|
||||
void addTypeInterface() {
|
||||
addStorageUserInterface<TypeTy, ModelTy>(DialectTy::getDialectNamespace());
|
||||
}
|
||||
|
||||
/// Register any interfaces required for the given dialect (based on its
|
||||
/// TypeID). Users are not expected to call this directly.
|
||||
void registerDelayedInterfaces(Dialect *dialect) const;
|
||||
|
||||
private:
|
||||
/// Add an interface constructed with the given allocation function to the
|
||||
/// dialect identified by its namespace.
|
||||
void addDialectInterface(StringRef dialectName, TypeID interfaceTypeID,
|
||||
const DialectInterfaceAllocatorFunction &allocator);
|
||||
|
||||
/// Add an attribute/operation/type interface constructible with the given
|
||||
/// allocation function to the dialect identified by its namespace.
|
||||
void addObjectInterface(StringRef dialectName, TypeID objectID,
|
||||
TypeID interfaceTypeID,
|
||||
const ObjectInterfaceAllocatorFunction &allocator);
|
||||
|
||||
/// Add an external model for an attribute/type interface to the dialect
|
||||
/// identified by its namespace.
|
||||
template <typename ObjectTy, typename ModelTy>
|
||||
void addStorageUserInterface(StringRef dialectName) {
|
||||
addObjectInterface(dialectName, TypeID::get<ObjectTy>(),
|
||||
ModelTy::Interface::getInterfaceID(),
|
||||
[](MLIRContext *context) {
|
||||
ObjectTy::template attachInterface<ModelTy>(*context);
|
||||
});
|
||||
}
|
||||
|
||||
MapTy registry;
|
||||
InterfaceMapTy interfaces;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
namespace llvm {
|
||||
|
222
mlir/include/mlir/IR/DialectRegistry.h
Normal file
222
mlir/include/mlir/IR/DialectRegistry.h
Normal file
@ -0,0 +1,222 @@
|
||||
//===- DialectRegistry.h - Dialect Registration and Extension ---*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines functionality for registring and extending dialects.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_IR_DIALECTREGISTRY_H
|
||||
#define MLIR_IR_DIALECTREGISTRY_H
|
||||
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <map>
|
||||
#include <tuple>
|
||||
|
||||
namespace mlir {
|
||||
class Dialect;
|
||||
|
||||
using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
|
||||
using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DialectExtension
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This class represents an opaque dialect extension. It contains a set of
|
||||
/// required dialects and an application function. The required dialects control
|
||||
/// when the extension is applied, i.e. the extension is applied when all
|
||||
/// required dialects are loaded. The application function can be used to attach
|
||||
/// additional functionality to attributes, dialects, operations, types, etc.,
|
||||
/// and may also load additional necessary dialects.
|
||||
class DialectExtensionBase {
|
||||
public:
|
||||
virtual ~DialectExtensionBase();
|
||||
|
||||
/// Return the dialects that our required by this extension to be loaded
|
||||
/// before applying.
|
||||
ArrayRef<StringRef> getRequiredDialects() const { return dialectNames; }
|
||||
|
||||
/// Apply this extension to the given context and the required dialects.
|
||||
virtual void apply(MLIRContext *context,
|
||||
MutableArrayRef<Dialect *> dialects) const = 0;
|
||||
|
||||
/// Return a copy of this extension.
|
||||
virtual std::unique_ptr<DialectExtensionBase> clone() const = 0;
|
||||
|
||||
protected:
|
||||
/// Initialize the extension with a set of required dialects. Note that there
|
||||
/// should always be at least one affected dialect.
|
||||
DialectExtensionBase(ArrayRef<StringRef> dialectNames)
|
||||
: dialectNames(dialectNames.begin(), dialectNames.end()) {
|
||||
assert(!dialectNames.empty() && "expected at least one affected dialect");
|
||||
}
|
||||
|
||||
private:
|
||||
/// The names of the dialects affected by this extension.
|
||||
SmallVector<StringRef> dialectNames;
|
||||
};
|
||||
|
||||
/// This class represents a dialect extension anchored on the given set of
|
||||
/// dialects. When all of the specified dialects have been loaded, the
|
||||
/// application function of this extension will be executed.
|
||||
template <typename DerivedT, typename... DialectsT>
|
||||
class DialectExtension : public DialectExtensionBase {
|
||||
public:
|
||||
/// Applies this extension to the given context and set of required dialects.
|
||||
virtual void apply(MLIRContext *context, DialectsT *...dialects) const = 0;
|
||||
|
||||
/// Return a copy of this extension.
|
||||
std::unique_ptr<DialectExtensionBase> clone() const final {
|
||||
return std::make_unique<DerivedT>(static_cast<const DerivedT &>(*this));
|
||||
}
|
||||
|
||||
protected:
|
||||
DialectExtension()
|
||||
: DialectExtensionBase(
|
||||
ArrayRef<StringRef>({DialectsT::getDialectNamespace()...})) {}
|
||||
|
||||
/// Override the base apply method to allow providing the exact dialect types.
|
||||
void apply(MLIRContext *context,
|
||||
MutableArrayRef<Dialect *> dialects) const final {
|
||||
unsigned dialectIdx = 0;
|
||||
auto derivedDialects = std::tuple<DialectsT *...>{
|
||||
static_cast<DialectsT *>(dialects[dialectIdx++])...};
|
||||
llvm::apply_tuple(
|
||||
[&](DialectsT *...dialect) { apply(context, dialect...); },
|
||||
derivedDialects);
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DialectRegistry
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// The DialectRegistry maps a dialect namespace to a constructor for the
|
||||
/// matching dialect. This allows for decoupling the list of dialects
|
||||
/// "available" from the dialects loaded in the Context. The parser in
|
||||
/// particular will lazily load dialects in the Context as operations are
|
||||
/// encountered.
|
||||
class DialectRegistry {
|
||||
using MapTy =
|
||||
std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
|
||||
|
||||
public:
|
||||
explicit DialectRegistry();
|
||||
|
||||
template <typename ConcreteDialect>
|
||||
void insert() {
|
||||
insert(TypeID::get<ConcreteDialect>(),
|
||||
ConcreteDialect::getDialectNamespace(),
|
||||
static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
|
||||
// Just allocate the dialect, the context
|
||||
// takes ownership of it.
|
||||
return ctx->getOrLoadDialect<ConcreteDialect>();
|
||||
})));
|
||||
}
|
||||
|
||||
template <typename ConcreteDialect, typename OtherDialect,
|
||||
typename... MoreDialects>
|
||||
void insert() {
|
||||
insert<ConcreteDialect>();
|
||||
insert<OtherDialect, MoreDialects...>();
|
||||
}
|
||||
|
||||
/// Add a new dialect constructor to the registry. The constructor must be
|
||||
/// calling MLIRContext::getOrLoadDialect in order for the context to take
|
||||
/// ownership of the dialect and for delayed interface registration to happen.
|
||||
void insert(TypeID typeID, StringRef name,
|
||||
const DialectAllocatorFunction &ctor);
|
||||
|
||||
/// Return an allocation function for constructing the dialect identified by
|
||||
/// its namespace, or nullptr if the namespace is not in this registry.
|
||||
DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const;
|
||||
|
||||
// Register all dialects available in the current registry with the registry
|
||||
// in the provided context.
|
||||
void appendTo(DialectRegistry &destination) const {
|
||||
for (const auto &nameAndRegistrationIt : registry)
|
||||
destination.insert(nameAndRegistrationIt.second.first,
|
||||
nameAndRegistrationIt.first,
|
||||
nameAndRegistrationIt.second.second);
|
||||
// Merge the extensions.
|
||||
for (const auto &extension : extensions)
|
||||
destination.extensions.push_back(extension->clone());
|
||||
}
|
||||
|
||||
/// Return the names of dialects known to this registry.
|
||||
auto getDialectNames() const {
|
||||
return llvm::map_range(
|
||||
registry,
|
||||
[](const MapTy::value_type &item) -> StringRef { return item.first; });
|
||||
}
|
||||
|
||||
/// Apply any held extensions that require the given dialect. Users are not
|
||||
/// expected to call this directly.
|
||||
void applyExtensions(Dialect *dialect) const;
|
||||
|
||||
/// Apply any applicable extensions to the given context. Users are not
|
||||
/// expected to call this directly.
|
||||
void applyExtensions(MLIRContext *ctx) const;
|
||||
|
||||
/// Add the given extension to the registry.
|
||||
void addExtension(std::unique_ptr<DialectExtensionBase> extension) {
|
||||
extensions.push_back(std::move(extension));
|
||||
}
|
||||
|
||||
/// Add the given extensions to the registry.
|
||||
template <typename... ExtensionsT>
|
||||
void addExtensions() {
|
||||
(void)std::initializer_list<int>{
|
||||
addExtension(std::make_unique<ExtensionsT>())...};
|
||||
}
|
||||
|
||||
/// Add an extension function that requires the given dialects.
|
||||
/// Note: This bare functor overload is provided in addition to the
|
||||
/// std::function variant to enable dialect type deduction, e.g.:
|
||||
/// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { ... })
|
||||
///
|
||||
/// is equivalent to:
|
||||
/// registry.addExtension<MyDialect>(
|
||||
/// [](MLIRContext *ctx, MyDialect *dialect){ ... }
|
||||
/// )
|
||||
template <typename... DialectsT>
|
||||
void addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
|
||||
addExtension<DialectsT...>(
|
||||
std::function<void(MLIRContext *, DialectsT * ...)>(extensionFn));
|
||||
}
|
||||
template <typename... DialectsT>
|
||||
void
|
||||
addExtension(std::function<void(MLIRContext *, DialectsT *...)> extensionFn) {
|
||||
using ExtensionFnT = std::function<void(MLIRContext *, DialectsT * ...)>;
|
||||
|
||||
struct Extension : public DialectExtension<Extension, DialectsT...> {
|
||||
Extension(const Extension &) = default;
|
||||
Extension(ExtensionFnT extensionFn)
|
||||
: extensionFn(std::move(extensionFn)) {}
|
||||
~Extension() override = default;
|
||||
|
||||
void apply(MLIRContext *context, DialectsT *...dialects) const final {
|
||||
extensionFn(context, dialects...);
|
||||
}
|
||||
ExtensionFnT extensionFn;
|
||||
};
|
||||
addExtension(std::make_unique<Extension>(std::move(extensionFn)));
|
||||
}
|
||||
|
||||
private:
|
||||
MapTy registry;
|
||||
std::vector<std::unique_ptr<DialectExtensionBase>> extensions;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_IR_DIALECTREGISTRY_H
|
@ -154,7 +154,9 @@ struct SelectOpInterface
|
||||
|
||||
void mlir::arith::registerBufferizableOpInterfaceExternalModels(
|
||||
DialectRegistry ®istry) {
|
||||
registry.addOpInterface<ConstantOp, ConstantOpInterface>();
|
||||
registry.addOpInterface<IndexCastOp, IndexCastOpInterface>();
|
||||
registry.addOpInterface<SelectOp, SelectOpInterface>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, ArithmeticDialect *dialect) {
|
||||
ConstantOp::attachInterface<ConstantOpInterface>(*ctx);
|
||||
IndexCastOp::attachInterface<IndexCastOpInterface>(*ctx);
|
||||
SelectOp::attachInterface<SelectOpInterface>(*ctx);
|
||||
});
|
||||
}
|
||||
|
@ -695,7 +695,9 @@ LogicalResult bufferization::deallocateBuffers(Operation *op) {
|
||||
|
||||
void bufferization::registerAllocationOpInterfaceExternalModels(
|
||||
DialectRegistry ®istry) {
|
||||
registry.addOpInterface<memref::AllocOp, DefaultAllocationInterface>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
|
||||
memref::AllocOp::attachInterface<DefaultAllocationInterface>(*ctx);
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -962,9 +962,11 @@ struct FuncOpInterface
|
||||
|
||||
void mlir::linalg::comprehensive_bufferize::std_ext::
|
||||
registerModuleBufferizationExternalModels(DialectRegistry ®istry) {
|
||||
registry.addOpInterface<func::CallOp, std_ext::CallOpInterface>();
|
||||
registry.addOpInterface<func::ReturnOp, std_ext::ReturnOpInterface>();
|
||||
registry.addOpInterface<FuncOp, std_ext::FuncOpInterface>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, func::FuncDialect *dialect) {
|
||||
func::CallOp::attachInterface<std_ext::CallOpInterface>(*ctx);
|
||||
func::ReturnOp::attachInterface<std_ext::ReturnOpInterface>(*ctx);
|
||||
func::FuncOp::attachInterface<std_ext::FuncOpInterface>(*ctx);
|
||||
});
|
||||
}
|
||||
|
||||
/// Set the attribute that triggers inplace bufferization on a FuncOp argument
|
||||
|
@ -246,22 +246,13 @@ struct InitTensorOpInterface
|
||||
|
||||
/// Helper structure that iterates over all LinalgOps in `OpTys` and registers
|
||||
/// the `BufferizableOpInterface` with each of them.
|
||||
template <typename... OpTys>
|
||||
struct LinalgOpInterfaceHelper;
|
||||
|
||||
template <typename First, typename... Others>
|
||||
struct LinalgOpInterfaceHelper<First, Others...> {
|
||||
static void registerOpInterface(DialectRegistry ®istry) {
|
||||
registry.addOpInterface<First, LinalgOpInterface<First>>();
|
||||
LinalgOpInterfaceHelper<Others...>::registerOpInterface(registry);
|
||||
template <typename... Ops>
|
||||
struct LinalgOpInterfaceHelper {
|
||||
static void registerOpInterface(MLIRContext *ctx) {
|
||||
(void)std::initializer_list<int>{
|
||||
0, (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), 0)...};
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct LinalgOpInterfaceHelper<> {
|
||||
static void registerOpInterface(DialectRegistry ®istry) {}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Return true if all `neededValues` are in scope at the given
|
||||
@ -501,13 +492,15 @@ LogicalResult mlir::linalg::insertSliceAnchoredInitTensorEliminationStep(
|
||||
|
||||
void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
|
||||
DialectRegistry ®istry) {
|
||||
registry.addOpInterface<linalg::InitTensorOp, InitTensorOpInterface>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
|
||||
linalg::InitTensorOp::attachInterface<InitTensorOpInterface>(*ctx);
|
||||
|
||||
// Register all Linalg structured ops. `LinalgOp` is an interface and it is
|
||||
// not possible to attach an external interface to an existing interface.
|
||||
// Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
|
||||
LinalgOpInterfaceHelper<
|
||||
// Register all Linalg structured ops. `LinalgOp` is an interface and it is
|
||||
// not possible to attach an external interface to an existing interface.
|
||||
// Therefore, attach the `BufferizableOpInterface` to all ops one-by-one.
|
||||
LinalgOpInterfaceHelper<
|
||||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
|
||||
>::registerOpInterface(registry);
|
||||
>::registerOpInterface(ctx);
|
||||
});
|
||||
}
|
||||
|
@ -503,8 +503,10 @@ struct YieldOpInterface
|
||||
|
||||
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
|
||||
DialectRegistry ®istry) {
|
||||
registry.addOpInterface<ExecuteRegionOp, ExecuteRegionOpInterface>();
|
||||
registry.addOpInterface<ForOp, ForOpInterface>();
|
||||
registry.addOpInterface<IfOp, IfOpInterface>();
|
||||
registry.addOpInterface<YieldOp, YieldOpInterface>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
|
||||
ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
|
||||
ForOp::attachInterface<ForOpInterface>(*ctx);
|
||||
IfOp::attachInterface<IfOpInterface>(*ctx);
|
||||
YieldOp::attachInterface<YieldOpInterface>(*ctx);
|
||||
});
|
||||
}
|
||||
|
@ -168,6 +168,8 @@ struct AssumingYieldOpInterface
|
||||
|
||||
void mlir::shape::registerBufferizableOpInterfaceExternalModels(
|
||||
DialectRegistry ®istry) {
|
||||
registry.addOpInterface<shape::AssumingOp, AssumingOpInterface>();
|
||||
registry.addOpInterface<shape::AssumingYieldOp, AssumingYieldOpInterface>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, shape::ShapeDialect *dialect) {
|
||||
shape::AssumingOp::attachInterface<AssumingOpInterface>(*ctx);
|
||||
shape::AssumingYieldOp::attachInterface<AssumingYieldOpInterface>(*ctx);
|
||||
});
|
||||
}
|
||||
|
@ -205,11 +205,11 @@ struct ReifyPadOp
|
||||
|
||||
void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
|
||||
DialectRegistry ®istry) {
|
||||
registry
|
||||
.addOpInterface<tensor::ExpandShapeOp,
|
||||
ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>();
|
||||
registry
|
||||
.addOpInterface<tensor::CollapseShapeOp,
|
||||
ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>();
|
||||
registry.addOpInterface<tensor::PadOp, ReifyPadOp>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
|
||||
ExpandShapeOp::attachInterface<
|
||||
ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
|
||||
CollapseShapeOp::attachInterface<
|
||||
ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
|
||||
PadOp::attachInterface<ReifyPadOp>(*ctx);
|
||||
});
|
||||
}
|
||||
|
@ -283,5 +283,7 @@ Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
|
||||
|
||||
void mlir::tensor::registerTilingOpInterfaceExternalModels(
|
||||
DialectRegistry ®istry) {
|
||||
registry.addOpInterface<tensor::PadOp, PadOpTiling>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
|
||||
tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
|
||||
});
|
||||
}
|
||||
|
@ -700,15 +700,17 @@ struct RankOpInterface
|
||||
|
||||
void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
|
||||
DialectRegistry ®istry) {
|
||||
registry.addOpInterface<CastOp, CastOpInterface>();
|
||||
registry.addOpInterface<CollapseShapeOp, CollapseShapeOpInterface>();
|
||||
registry.addOpInterface<DimOp, DimOpInterface>();
|
||||
registry.addOpInterface<ExpandShapeOp, ExpandShapeOpInterface>();
|
||||
registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
|
||||
registry.addOpInterface<ExtractOp, ExtractOpInterface>();
|
||||
registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();
|
||||
registry.addOpInterface<GenerateOp, GenerateOpInterface>();
|
||||
registry.addOpInterface<InsertOp, InsertOpInterface>();
|
||||
registry.addOpInterface<InsertSliceOp, InsertSliceOpInterface>();
|
||||
registry.addOpInterface<RankOp, RankOpInterface>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
|
||||
CastOp::attachInterface<CastOpInterface>(*ctx);
|
||||
CollapseShapeOp::attachInterface<CollapseShapeOpInterface>(*ctx);
|
||||
DimOp::attachInterface<DimOpInterface>(*ctx);
|
||||
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
|
||||
ExtractSliceOp::attachInterface<ExtractSliceOpInterface>(*ctx);
|
||||
ExtractOp::attachInterface<ExtractOpInterface>(*ctx);
|
||||
FromElementsOp::attachInterface<FromElementsOpInterface>(*ctx);
|
||||
GenerateOp::attachInterface<GenerateOpInterface>(*ctx);
|
||||
InsertOp::attachInterface<InsertOpInterface>(*ctx);
|
||||
InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
|
||||
RankOp::attachInterface<RankOpInterface>(*ctx);
|
||||
});
|
||||
}
|
||||
|
@ -121,6 +121,8 @@ struct TransferWriteOpInterface
|
||||
|
||||
void mlir::vector::registerBufferizableOpInterfaceExternalModels(
|
||||
DialectRegistry ®istry) {
|
||||
registry.addOpInterface<TransferReadOp, TransferReadOpInterface>();
|
||||
registry.addOpInterface<TransferWriteOp, TransferWriteOpInterface>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
|
||||
TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
|
||||
TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
|
||||
});
|
||||
}
|
||||
|
@ -24,97 +24,6 @@
|
||||
using namespace mlir;
|
||||
using namespace detail;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DialectRegistry
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }
|
||||
|
||||
void DialectRegistry::addDialectInterface(
|
||||
StringRef dialectName, TypeID interfaceTypeID,
|
||||
const DialectInterfaceAllocatorFunction &allocator) {
|
||||
assert(allocator && "unexpected null interface allocation function");
|
||||
auto it = registry.find(dialectName.str());
|
||||
assert(it != registry.end() &&
|
||||
"adding an interface for an unregistered dialect");
|
||||
|
||||
// Bail out if the interface with the given ID is already in the registry for
|
||||
// the given dialect. We expect a small number (dozens) of interfaces so a
|
||||
// linear search is fine here.
|
||||
auto &ifaces = interfaces[it->second.first];
|
||||
for (const auto &kvp : ifaces.dialectInterfaces) {
|
||||
if (kvp.first == interfaceTypeID) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "[" DEBUG_TYPE
|
||||
"] repeated interface registration for dialect "
|
||||
<< dialectName);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
ifaces.dialectInterfaces.emplace_back(interfaceTypeID, allocator);
|
||||
}
|
||||
|
||||
void DialectRegistry::addObjectInterface(
|
||||
StringRef dialectName, TypeID objectID, TypeID interfaceTypeID,
|
||||
const ObjectInterfaceAllocatorFunction &allocator) {
|
||||
assert(allocator && "unexpected null interface allocation function");
|
||||
|
||||
auto it = registry.find(dialectName.str());
|
||||
assert(it != registry.end() &&
|
||||
"adding an interface for an op from an unregistered dialect");
|
||||
|
||||
auto dialectID = it->second.first;
|
||||
auto &ifaces = interfaces[dialectID];
|
||||
|
||||
for (const auto &info : ifaces.objectInterfaces) {
|
||||
if (std::get<0>(info) == objectID && std::get<1>(info) == interfaceTypeID) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "[" DEBUG_TYPE
|
||||
"] repeated interface object interface registration");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
ifaces.objectInterfaces.emplace_back(objectID, interfaceTypeID, allocator);
|
||||
}
|
||||
|
||||
DialectAllocatorFunctionRef
|
||||
DialectRegistry::getDialectAllocator(StringRef name) const {
|
||||
auto it = registry.find(name.str());
|
||||
if (it == registry.end())
|
||||
return nullptr;
|
||||
return it->second.second;
|
||||
}
|
||||
|
||||
void DialectRegistry::insert(TypeID typeID, StringRef name,
|
||||
const DialectAllocatorFunction &ctor) {
|
||||
auto inserted = registry.insert(
|
||||
std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
|
||||
if (!inserted.second && inserted.first->second.first != typeID) {
|
||||
llvm::report_fatal_error(
|
||||
"Trying to register different dialects for the same namespace: " +
|
||||
name);
|
||||
}
|
||||
}
|
||||
|
||||
void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const {
|
||||
auto it = interfaces.find(dialect->getTypeID());
|
||||
if (it == interfaces.end())
|
||||
return;
|
||||
|
||||
// Add an interface if it is not already present.
|
||||
for (const auto &kvp : it->getSecond().dialectInterfaces) {
|
||||
if (dialect->getRegisteredInterface(kvp.first))
|
||||
continue;
|
||||
dialect->addInterface(kvp.second(dialect));
|
||||
}
|
||||
|
||||
// Add attribute, operation and type interfaces.
|
||||
for (const auto &info : it->getSecond().objectInterfaces)
|
||||
std::get<2>(info)(dialect->getContext());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -189,7 +98,13 @@ void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
|
||||
auto it = registeredInterfaces.try_emplace(interface->getID(),
|
||||
std::move(interface));
|
||||
(void)it;
|
||||
assert(it.second && "interface kind has already been registered");
|
||||
LLVM_DEBUG({
|
||||
if (!it.second) {
|
||||
llvm::dbgs() << "[" DEBUG_TYPE
|
||||
"] repeated interface registration for dialect "
|
||||
<< getNamespace();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -216,3 +131,100 @@ const DialectInterface *
|
||||
DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
|
||||
return getInterfaceFor(op->getDialect());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DialectExtension
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DialectExtensionBase::~DialectExtensionBase() = default;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DialectRegistry
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }
|
||||
|
||||
DialectAllocatorFunctionRef
|
||||
DialectRegistry::getDialectAllocator(StringRef name) const {
|
||||
auto it = registry.find(name.str());
|
||||
if (it == registry.end())
|
||||
return nullptr;
|
||||
return it->second.second;
|
||||
}
|
||||
|
||||
void DialectRegistry::insert(TypeID typeID, StringRef name,
|
||||
const DialectAllocatorFunction &ctor) {
|
||||
auto inserted = registry.insert(
|
||||
std::make_pair(std::string(name), std::make_pair(typeID, ctor)));
|
||||
if (!inserted.second && inserted.first->second.first != typeID) {
|
||||
llvm::report_fatal_error(
|
||||
"Trying to register different dialects for the same namespace: " +
|
||||
name);
|
||||
}
|
||||
}
|
||||
|
||||
void DialectRegistry::applyExtensions(Dialect *dialect) const {
|
||||
MLIRContext *ctx = dialect->getContext();
|
||||
StringRef dialectName = dialect->getNamespace();
|
||||
|
||||
// Functor used to try to apply the given extension.
|
||||
auto applyExtension = [&](const DialectExtensionBase &extension) {
|
||||
ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
|
||||
|
||||
// Handle the simple case of a single dialect name. In this case, the
|
||||
// required dialect should be the current dialect.
|
||||
if (dialectNames.size() == 1) {
|
||||
if (dialectNames.front() == dialectName)
|
||||
extension.apply(ctx, dialect);
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, check to see if this extension requires this dialect.
|
||||
const StringRef *nameIt = llvm::find(dialectNames, dialectName);
|
||||
if (nameIt == dialectNames.end())
|
||||
return;
|
||||
|
||||
// If it does, ensure that all of the other required dialects have been
|
||||
// loaded.
|
||||
SmallVector<Dialect *> requiredDialects;
|
||||
requiredDialects.reserve(dialectNames.size());
|
||||
for (auto it = dialectNames.begin(), e = dialectNames.end(); it != e;
|
||||
++it) {
|
||||
// The current dialect is known to be loaded.
|
||||
if (it == nameIt) {
|
||||
requiredDialects.push_back(dialect);
|
||||
continue;
|
||||
}
|
||||
// Otherwise, check if it is loaded.
|
||||
Dialect *loadedDialect = ctx->getLoadedDialect(*it);
|
||||
if (!loadedDialect)
|
||||
return;
|
||||
requiredDialects.push_back(loadedDialect);
|
||||
}
|
||||
extension.apply(ctx, requiredDialects);
|
||||
};
|
||||
|
||||
for (const auto &extension : extensions)
|
||||
applyExtension(*extension);
|
||||
}
|
||||
|
||||
void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
|
||||
// Functor used to try to apply the given extension.
|
||||
auto applyExtension = [&](const DialectExtensionBase &extension) {
|
||||
ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
|
||||
|
||||
// Check to see if all of the dialects for this extension are loaded.
|
||||
SmallVector<Dialect *> requiredDialects;
|
||||
requiredDialects.reserve(dialectNames.size());
|
||||
for (StringRef dialectName : dialectNames) {
|
||||
Dialect *loadedDialect = ctx->getLoadedDialect(dialectName);
|
||||
if (!loadedDialect)
|
||||
return;
|
||||
requiredDialects.push_back(loadedDialect);
|
||||
}
|
||||
extension.apply(ctx, requiredDialects);
|
||||
};
|
||||
|
||||
for (const auto &extension : extensions)
|
||||
applyExtension(*extension);
|
||||
}
|
||||
|
@ -357,9 +357,8 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
|
||||
void MLIRContext::appendDialectRegistry(const DialectRegistry ®istry) {
|
||||
registry.appendTo(impl->dialectsRegistry);
|
||||
|
||||
// For the already loaded dialects, register the interfaces immediately.
|
||||
for (const auto &kvp : impl->loadedDialects)
|
||||
registry.registerDelayedInterfaces(kvp.second.get());
|
||||
// For the already loaded dialects, apply any possible extensions immediately.
|
||||
registry.applyExtensions(this);
|
||||
}
|
||||
|
||||
const DialectRegistry &MLIRContext::getDialectRegistry() {
|
||||
@ -437,8 +436,8 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
|
||||
impl.dialectReferencingStrAttrs.erase(stringAttrsIt);
|
||||
}
|
||||
|
||||
// Actually register the interfaces with delayed registration.
|
||||
impl.dialectsRegistry.registerDelayedInterfaces(dialect.get());
|
||||
// Apply any extensions to this newly loaded dialect.
|
||||
impl.dialectsRegistry.applyExtensions(dialect.get());
|
||||
return dialect.get();
|
||||
}
|
||||
|
||||
|
@ -44,8 +44,9 @@ public:
|
||||
|
||||
void mlir::registerAMXDialectTranslation(DialectRegistry ®istry) {
|
||||
registry.insert<amx::AMXDialect>();
|
||||
registry.addDialectInterface<amx::AMXDialect,
|
||||
AMXDialectLLVMIRTranslationInterface>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
|
||||
dialect->addInterfaces<AMXDialectLLVMIRTranslationInterface>();
|
||||
});
|
||||
}
|
||||
|
||||
void mlir::registerAMXDialectTranslation(MLIRContext &context) {
|
||||
|
@ -45,8 +45,10 @@ public:
|
||||
|
||||
void mlir::registerArmNeonDialectTranslation(DialectRegistry ®istry) {
|
||||
registry.insert<arm_neon::ArmNeonDialect>();
|
||||
registry.addDialectInterface<arm_neon::ArmNeonDialect,
|
||||
ArmNeonDialectLLVMIRTranslationInterface>();
|
||||
registry.addExtension(
|
||||
+[](MLIRContext *ctx, arm_neon::ArmNeonDialect *dialect) {
|
||||
dialect->addInterfaces<ArmNeonDialectLLVMIRTranslationInterface>();
|
||||
});
|
||||
}
|
||||
|
||||
void mlir::registerArmNeonDialectTranslation(MLIRContext &context) {
|
||||
|
@ -44,8 +44,9 @@ public:
|
||||
|
||||
void mlir::registerArmSVEDialectTranslation(DialectRegistry ®istry) {
|
||||
registry.insert<arm_sve::ArmSVEDialect>();
|
||||
registry.addDialectInterface<arm_sve::ArmSVEDialect,
|
||||
ArmSVEDialectLLVMIRTranslationInterface>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, arm_sve::ArmSVEDialect *dialect) {
|
||||
dialect->addInterfaces<ArmSVEDialectLLVMIRTranslationInterface>();
|
||||
});
|
||||
}
|
||||
|
||||
void mlir::registerArmSVEDialectTranslation(MLIRContext &context) {
|
||||
|
@ -503,8 +503,9 @@ public:
|
||||
|
||||
void mlir::registerLLVMDialectTranslation(DialectRegistry ®istry) {
|
||||
registry.insert<LLVM::LLVMDialect>();
|
||||
registry.addDialectInterface<LLVM::LLVMDialect,
|
||||
LLVMDialectLLVMIRTranslationInterface>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
|
||||
dialect->addInterfaces<LLVMDialectLLVMIRTranslationInterface>();
|
||||
});
|
||||
}
|
||||
|
||||
void mlir::registerLLVMDialectTranslation(MLIRContext &context) {
|
||||
|
@ -141,8 +141,9 @@ public:
|
||||
|
||||
void mlir::registerNVVMDialectTranslation(DialectRegistry ®istry) {
|
||||
registry.insert<NVVM::NVVMDialect>();
|
||||
registry.addDialectInterface<NVVM::NVVMDialect,
|
||||
NVVMDialectLLVMIRTranslationInterface>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) {
|
||||
dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>();
|
||||
});
|
||||
}
|
||||
|
||||
void mlir::registerNVVMDialectTranslation(MLIRContext &context) {
|
||||
|
@ -533,8 +533,9 @@ LogicalResult OpenACCDialectLLVMIRTranslationInterface::convertOperation(
|
||||
|
||||
void mlir::registerOpenACCDialectTranslation(DialectRegistry ®istry) {
|
||||
registry.insert<acc::OpenACCDialect>();
|
||||
registry.addDialectInterface<acc::OpenACCDialect,
|
||||
OpenACCDialectLLVMIRTranslationInterface>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, acc::OpenACCDialect *dialect) {
|
||||
dialect->addInterfaces<OpenACCDialectLLVMIRTranslationInterface>();
|
||||
});
|
||||
}
|
||||
|
||||
void mlir::registerOpenACCDialectTranslation(MLIRContext &context) {
|
||||
|
@ -1270,8 +1270,9 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
|
||||
|
||||
void mlir::registerOpenMPDialectTranslation(DialectRegistry ®istry) {
|
||||
registry.insert<omp::OpenMPDialect>();
|
||||
registry.addDialectInterface<omp::OpenMPDialect,
|
||||
OpenMPDialectLLVMIRTranslationInterface>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
|
||||
dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
|
||||
});
|
||||
}
|
||||
|
||||
void mlir::registerOpenMPDialectTranslation(MLIRContext &context) {
|
||||
|
@ -107,8 +107,9 @@ public:
|
||||
|
||||
void mlir::registerROCDLDialectTranslation(DialectRegistry ®istry) {
|
||||
registry.insert<ROCDL::ROCDLDialect>();
|
||||
registry.addDialectInterface<ROCDL::ROCDLDialect,
|
||||
ROCDLDialectLLVMIRTranslationInterface>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, ROCDL::ROCDLDialect *dialect) {
|
||||
dialect->addInterfaces<ROCDLDialectLLVMIRTranslationInterface>();
|
||||
});
|
||||
}
|
||||
|
||||
void mlir::registerROCDLDialectTranslation(MLIRContext &context) {
|
||||
|
@ -45,8 +45,10 @@ public:
|
||||
|
||||
void mlir::registerX86VectorDialectTranslation(DialectRegistry ®istry) {
|
||||
registry.insert<x86vector::X86VectorDialect>();
|
||||
registry.addDialectInterface<x86vector::X86VectorDialect,
|
||||
X86VectorDialectLLVMIRTranslationInterface>();
|
||||
registry.addExtension(
|
||||
+[](MLIRContext *ctx, x86vector::X86VectorDialect *dialect) {
|
||||
dialect->addInterfaces<X86VectorDialectLLVMIRTranslationInterface>();
|
||||
});
|
||||
}
|
||||
|
||||
void mlir::registerX86VectorDialectTranslation(MLIRContext &context) {
|
||||
|
@ -63,7 +63,9 @@ TEST(Dialect, DelayedInterfaceRegistration) {
|
||||
registry.insert<TestDialect, SecondTestDialect>();
|
||||
|
||||
// Delayed registration of an interface for TestDialect.
|
||||
registry.addDialectInterface<TestDialect, TestDialectInterface>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) {
|
||||
dialect->addInterfaces<TestDialectInterface>();
|
||||
});
|
||||
|
||||
MLIRContext context(registry);
|
||||
|
||||
@ -85,8 +87,10 @@ TEST(Dialect, DelayedInterfaceRegistration) {
|
||||
// loaded dialect and check that the interface is now registered.
|
||||
DialectRegistry secondRegistry;
|
||||
secondRegistry.insert<SecondTestDialect>();
|
||||
secondRegistry
|
||||
.addDialectInterface<SecondTestDialect, SecondTestDialectInterface>();
|
||||
secondRegistry.addExtension(
|
||||
+[](MLIRContext *ctx, SecondTestDialect *dialect) {
|
||||
dialect->addInterfaces<SecondTestDialectInterface>();
|
||||
});
|
||||
context.appendDialectRegistry(secondRegistry);
|
||||
secondTestDialectInterface =
|
||||
dyn_cast<SecondTestDialectInterface>(secondTestDialect);
|
||||
@ -97,7 +101,9 @@ TEST(Dialect, RepeatedDelayedRegistration) {
|
||||
// Set up the delayed registration.
|
||||
DialectRegistry registry;
|
||||
registry.insert<TestDialect>();
|
||||
registry.addDialectInterface<TestDialect, TestDialectInterface>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) {
|
||||
dialect->addInterfaces<TestDialectInterface>();
|
||||
});
|
||||
MLIRContext context(registry);
|
||||
|
||||
// Load the TestDialect and check that the interface got registered for it.
|
||||
@ -110,33 +116,12 @@ TEST(Dialect, RepeatedDelayedRegistration) {
|
||||
// on repeated interface registration.
|
||||
DialectRegistry secondRegistry;
|
||||
secondRegistry.insert<TestDialect>();
|
||||
secondRegistry.addDialectInterface<TestDialect, TestDialectInterface>();
|
||||
secondRegistry.addExtension(+[](MLIRContext *ctx, TestDialect *dialect) {
|
||||
dialect->addInterfaces<TestDialectInterface>();
|
||||
});
|
||||
context.appendDialectRegistry(secondRegistry);
|
||||
testDialectInterface = dyn_cast<TestDialectInterfaceBase>(testDialect);
|
||||
EXPECT_TRUE(testDialectInterface != nullptr);
|
||||
}
|
||||
|
||||
// A dialect that registers two interfaces with the same InterfaceID, triggering
|
||||
// an assertion failure.
|
||||
struct RepeatedRegistrationDialect : public Dialect {
|
||||
static StringRef getDialectNamespace() { return "repeatedreg"; }
|
||||
RepeatedRegistrationDialect(MLIRContext *context)
|
||||
: Dialect(getDialectNamespace(), context,
|
||||
TypeID::get<RepeatedRegistrationDialect>()) {
|
||||
addInterfaces<TestDialectInterface>();
|
||||
addInterfaces<SecondTestDialectInterface>();
|
||||
}
|
||||
};
|
||||
|
||||
TEST(Dialect, RepeatedInterfaceRegistrationDeath) {
|
||||
MLIRContext context;
|
||||
(void)context;
|
||||
|
||||
// This triggers an assertion in debug mode.
|
||||
#ifndef NDEBUG
|
||||
ASSERT_DEATH(context.loadDialect<RepeatedRegistrationDialect>(),
|
||||
"interface kind has already been registered");
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -102,7 +102,9 @@ TEST(InterfaceAttachment, TypeDelayedContextConstruct) {
|
||||
// Put the interface in the registry.
|
||||
DialectRegistry registry;
|
||||
registry.insert<test::TestDialect>();
|
||||
registry.addTypeInterface<test::TestDialect, test::TestType, TestTypeModel>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
|
||||
test::TestType::attachInterface<TestTypeModel>(*ctx);
|
||||
});
|
||||
|
||||
// Check that when a context is constructed with the given registry, the type
|
||||
// interface gets registered.
|
||||
@ -119,7 +121,9 @@ TEST(InterfaceAttachment, TypeDelayedContextAppend) {
|
||||
// Put the interface in the registry.
|
||||
DialectRegistry registry;
|
||||
registry.insert<test::TestDialect>();
|
||||
registry.addTypeInterface<test::TestDialect, test::TestType, TestTypeModel>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
|
||||
test::TestType::attachInterface<TestTypeModel>(*ctx);
|
||||
});
|
||||
|
||||
// Check that when the registry gets appended to the context, the interface
|
||||
// becomes available for objects in loaded dialects.
|
||||
@ -133,7 +137,9 @@ TEST(InterfaceAttachment, TypeDelayedContextAppend) {
|
||||
|
||||
TEST(InterfaceAttachment, RepeatedRegistration) {
|
||||
DialectRegistry registry;
|
||||
registry.addTypeInterface<BuiltinDialect, IntegerType, Model>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
|
||||
IntegerType::attachInterface<Model>(*ctx);
|
||||
});
|
||||
MLIRContext context(registry);
|
||||
|
||||
// Should't fail on repeated registration through the dialect registry.
|
||||
@ -144,7 +150,9 @@ TEST(InterfaceAttachment, TypeBuiltinDelayed) {
|
||||
// Builtin dialect needs to registration or loading, but delayed interface
|
||||
// registration must still work.
|
||||
DialectRegistry registry;
|
||||
registry.addTypeInterface<BuiltinDialect, IntegerType, Model>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
|
||||
IntegerType::attachInterface<Model>(*ctx);
|
||||
});
|
||||
|
||||
MLIRContext context(registry);
|
||||
IntegerType i16 = IntegerType::get(&context, 16);
|
||||
@ -238,8 +246,9 @@ TEST(InterfaceAttachmentTest, AttributeDelayed) {
|
||||
// that the delayed registration work for attributes.
|
||||
DialectRegistry registry;
|
||||
registry.insert<test::TestDialect>();
|
||||
registry.addAttrInterface<test::TestDialect, test::SimpleAAttr,
|
||||
TestExternalSimpleAAttrModel>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
|
||||
test::SimpleAAttr::attachInterface<TestExternalSimpleAAttrModel>(*ctx);
|
||||
});
|
||||
|
||||
MLIRContext context(registry);
|
||||
context.loadDialect<test::TestDialect>();
|
||||
@ -343,12 +352,16 @@ struct TestExternalTestOpModel
|
||||
TEST(InterfaceAttachment, OperationDelayedContextConstruct) {
|
||||
DialectRegistry registry;
|
||||
registry.insert<test::TestDialect>();
|
||||
registry.addOpInterface<ModuleOp, TestExternalOpModel>();
|
||||
registry.addOpInterface<test::OpJ, TestExternalTestOpModel<test::OpJ>>();
|
||||
registry.addOpInterface<test::OpH, TestExternalTestOpModel<test::OpH>>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
|
||||
ModuleOp::attachInterface<TestExternalOpModel>(*ctx);
|
||||
});
|
||||
registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
|
||||
test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx);
|
||||
test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx);
|
||||
});
|
||||
|
||||
// Construct the context directly from a registry. The interfaces are expected
|
||||
// to be readily available on operations.
|
||||
// Construct the context directly from a registry. The interfaces are
|
||||
// expected to be readily available on operations.
|
||||
MLIRContext context(registry);
|
||||
context.loadDialect<test::TestDialect>();
|
||||
|
||||
@ -370,9 +383,13 @@ TEST(InterfaceAttachment, OperationDelayedContextConstruct) {
|
||||
TEST(InterfaceAttachment, OperationDelayedContextAppend) {
|
||||
DialectRegistry registry;
|
||||
registry.insert<test::TestDialect>();
|
||||
registry.addOpInterface<ModuleOp, TestExternalOpModel>();
|
||||
registry.addOpInterface<test::OpJ, TestExternalTestOpModel<test::OpJ>>();
|
||||
registry.addOpInterface<test::OpH, TestExternalTestOpModel<test::OpH>>();
|
||||
registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
|
||||
ModuleOp::attachInterface<TestExternalOpModel>(*ctx);
|
||||
});
|
||||
registry.addExtension(+[](MLIRContext *ctx, test::TestDialect *dialect) {
|
||||
test::OpJ::attachInterface<TestExternalTestOpModel<test::OpJ>>(*ctx);
|
||||
test::OpH::attachInterface<TestExternalTestOpModel<test::OpH>>(*ctx);
|
||||
});
|
||||
|
||||
// Construct the context, create ops, and only then append the registry. The
|
||||
// interfaces are expected to be available after appending the registry.
|
||||
|
Loading…
x
Reference in New Issue
Block a user