mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-02-28 06:17:32 +00:00
[mlir] avoid exposing mutable DialectRegistry from MLIRContext
MLIRContext allows its users to access directly to the DialectRegistry it contains. While sometimes useful for registering additional dialects on an already existing context, this breaks the encapsulation by essentially giving raw accesses to a part of the context's internal state. Remove this mutable access and instead provide a method to append a given DialectRegistry to the one already contained in the context. Also provide a shortcut mechanism to construct a context from an already existing registry, which seems to be a common use case in the wild. Keep read-only access to the registry contained in the context in case it needs to be copied or used for constructing another context. With this change, DialectRegistry is no longer concerned with loading the dialects and deciding whether to invoke delayed interface registration. Loading is concentrated in the MLIRContext, and the functionality of the registry better reflects its name. Depends On D96137 Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D96331
This commit is contained in:
parent
3da51522fb
commit
2996a8d675
@ -61,8 +61,9 @@ compileFIR(const mlir::PassPipelineCLParser &passPipeline) {
|
||||
// load the file into a module
|
||||
SourceMgr sourceMgr;
|
||||
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), SMLoc());
|
||||
mlir::MLIRContext context;
|
||||
fir::registerFIRDialects(context.getDialectRegistry());
|
||||
mlir::DialectRegistry registry;
|
||||
fir::registerFIRDialects(registry);
|
||||
mlir::MLIRContext context(registry);
|
||||
auto owningRef = mlir::parseSourceFile(sourceMgr, &context);
|
||||
|
||||
if (!owningRef) {
|
||||
|
@ -35,7 +35,9 @@ typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks;
|
||||
|
||||
#define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Name, Namespace, ClassName) \
|
||||
static void mlirContextRegister##Name##Dialect(MlirContext context) { \
|
||||
unwrap(context)->getDialectRegistry().insert<ClassName>(); \
|
||||
mlir::DialectRegistry registry; \
|
||||
registry.insert<ClassName>(); \
|
||||
unwrap(context)->appendDialectRegistry(registry); \
|
||||
} \
|
||||
static MlirDialect mlirContextLoad##Name##Dialect(MlirContext context) { \
|
||||
return wrap(unwrap(context)->getOrLoadDialect<ClassName>()); \
|
||||
|
@ -26,6 +26,7 @@ class OpBuilder;
|
||||
class Type;
|
||||
|
||||
using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
|
||||
using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>;
|
||||
using InterfaceAllocatorFunction =
|
||||
std::function<std::unique_ptr<DialectInterface>(Dialect *)>;
|
||||
|
||||
@ -241,8 +242,7 @@ class DialectRegistry {
|
||||
DenseMap<TypeID, SmallVector<InterfaceAllocatorFunction, 2>>;
|
||||
|
||||
public:
|
||||
explicit DialectRegistry(MLIRContext *context = nullptr)
|
||||
: owningContext(context) {}
|
||||
explicit DialectRegistry() {}
|
||||
|
||||
template <typename ConcreteDialect>
|
||||
void insert() {
|
||||
@ -267,42 +267,37 @@ public:
|
||||
/// ownership of the dialect and for delayed interface registration to happen.
|
||||
void insert(TypeID typeID, StringRef name, DialectAllocatorFunction ctor);
|
||||
|
||||
/// Load a dialect for this namespace in the provided context.
|
||||
Dialect *loadByName(StringRef name, MLIRContext *context);
|
||||
/// Return an allocation function for constructing the dialect identified by
|
||||
/// its namespace, or nullptr if the namespace is not in this registry.
|
||||
DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const;
|
||||
|
||||
// Register all dialects available in the current registry with the registry
|
||||
// in the provided context.
|
||||
void appendTo(DialectRegistry &destination) {
|
||||
void appendTo(DialectRegistry &destination) const {
|
||||
for (const auto &nameAndRegistrationIt : registry)
|
||||
destination.insert(nameAndRegistrationIt.second.first,
|
||||
nameAndRegistrationIt.first,
|
||||
nameAndRegistrationIt.second.second);
|
||||
destination.interfaces.insert(interfaces.begin(), interfaces.end());
|
||||
}
|
||||
// Load all dialects available in the registry in the provided context.
|
||||
void loadAll(MLIRContext *context) {
|
||||
for (const auto &nameAndRegistrationIt : registry)
|
||||
nameAndRegistrationIt.second.second(context);
|
||||
}
|
||||
|
||||
/// Return the names of dialects known to this registry.
|
||||
auto getDialectNames() {
|
||||
auto getDialectNames() const {
|
||||
return llvm::map_range(
|
||||
registry, [](const MapTy::value_type &item) { return item.first; });
|
||||
registry,
|
||||
[](const MapTy::value_type &item) -> StringRef { return item.first; });
|
||||
}
|
||||
|
||||
/// Add an interface constructed with the given allocation function to the
|
||||
/// dialect provided as template parameter. The dialect must be present in
|
||||
/// the registry, but may or may not be loaded. If it is not loaded, the
|
||||
/// interface registration is delayed until the loading.
|
||||
/// the registry.
|
||||
template <typename DialectTy>
|
||||
void addDialectInterface(InterfaceAllocatorFunction allocator) {
|
||||
addDialectInterface(DialectTy::getDialectNamespace(), allocator);
|
||||
}
|
||||
|
||||
/// Add an interface to the dialect, both provided as template parameter. The
|
||||
/// dialect must be present in the registry, but may or may not be loaded. If
|
||||
/// it is not loaded, the interface registration is delayed until the loading.
|
||||
/// dialect must be present in the registry.
|
||||
template <typename DialectTy, typename InterfaceTy>
|
||||
void addDialectInterface() {
|
||||
addDialectInterface<DialectTy>([](Dialect *dialect) {
|
||||
@ -312,7 +307,7 @@ public:
|
||||
|
||||
/// Register any interfaces required for the given dialect (based on its
|
||||
/// TypeID). Users are not expected to call this directly.
|
||||
void registerDelayedInterfaces(Dialect *dialect);
|
||||
void registerDelayedInterfaces(Dialect *dialect) const;
|
||||
|
||||
private:
|
||||
/// Add an interface constructed with the given allocation function to the
|
||||
@ -322,10 +317,6 @@ private:
|
||||
|
||||
MapTy registry;
|
||||
InterfaceMapTy interfaces;
|
||||
|
||||
/// If this registry belongs to a context, this points back to the context.
|
||||
/// Useful for checking if a dialect is loaded in the context.
|
||||
MLIRContext *owningContext;
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
@ -36,17 +36,19 @@ class StorageUniquer;
|
||||
class MLIRContext {
|
||||
public:
|
||||
/// Create a new Context.
|
||||
/// The loadAllDialects parameters allows to load all dialects from the global
|
||||
/// registry on Context construction. It is deprecated and will be removed
|
||||
/// soon.
|
||||
explicit MLIRContext();
|
||||
explicit MLIRContext(const DialectRegistry ®istry);
|
||||
~MLIRContext();
|
||||
|
||||
/// Return information about all IR dialects loaded in the context.
|
||||
std::vector<Dialect *> getLoadedDialects();
|
||||
|
||||
/// Return the dialect registry associated with this context.
|
||||
DialectRegistry &getDialectRegistry();
|
||||
const DialectRegistry &getDialectRegistry();
|
||||
|
||||
/// Append the contents of the given dialect registry to the registry
|
||||
/// associated with this context.
|
||||
void appendDialectRegistry(const DialectRegistry ®istry);
|
||||
|
||||
/// Return information about all available dialects in the registry in this
|
||||
/// context.
|
||||
@ -87,6 +89,9 @@ public:
|
||||
loadDialect<OtherDialect, MoreDialects...>();
|
||||
}
|
||||
|
||||
/// Load all dialects available in the registry in this context.
|
||||
void loadAllAvailableDialects();
|
||||
|
||||
/// Get (or create) a dialect for the given derived dialect name.
|
||||
/// The dialect will be loaded from the registry if no dialect is found.
|
||||
/// If no dialect is loaded for this name and none is available in the
|
||||
|
@ -45,7 +45,7 @@
|
||||
|
||||
namespace mlir {
|
||||
|
||||
// Add all the MLIR dialects to the provided registry.
|
||||
/// Add all the MLIR dialects to the provided registry.
|
||||
inline void registerAllDialects(DialectRegistry ®istry) {
|
||||
// clang-format off
|
||||
registry.insert<acc::OpenACCDialect,
|
||||
@ -78,6 +78,13 @@ inline void registerAllDialects(DialectRegistry ®istry) {
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
/// Append all the MLIR dialects to the registry contained in the given context.
|
||||
inline void registerAllDialects(MLIRContext &context) {
|
||||
DialectRegistry registry;
|
||||
registerAllDialects(registry);
|
||||
context.appendDialectRegistry(registry);
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_INITALLDIALECTS_H_
|
||||
|
@ -12,7 +12,7 @@
|
||||
#include "mlir/InitAllDialects.h"
|
||||
|
||||
void mlirRegisterAllDialects(MlirContext context) {
|
||||
registerAllDialects(unwrap(context)->getDialectRegistry());
|
||||
mlir::registerAllDialects(*unwrap(context));
|
||||
// TODO: we may not want to eagerly load here.
|
||||
unwrap(context)->getDialectRegistry().loadAll(unwrap(context));
|
||||
unwrap(context)->loadAllAvailableDialects();
|
||||
}
|
||||
|
@ -331,7 +331,7 @@ int mlir::JitRunnerMain(int argc, char **argv, JitRunnerConfig config) {
|
||||
}
|
||||
|
||||
MLIRContext context;
|
||||
registerAllDialects(context.getDialectRegistry());
|
||||
registerAllDialects(context);
|
||||
|
||||
auto m = parseMLIRInput(options.inputFilename, &context);
|
||||
if (!m) {
|
||||
|
@ -29,27 +29,18 @@ DialectAsmParser::~DialectAsmParser() {}
|
||||
void DialectRegistry::addDialectInterface(
|
||||
StringRef dialectName, InterfaceAllocatorFunction allocator) {
|
||||
assert(allocator && "unexpected null interface allocation function");
|
||||
|
||||
// If the dialect is already loaded, directly add the interface.
|
||||
if (Dialect *dialect = owningContext
|
||||
? owningContext->getLoadedDialect(dialectName)
|
||||
: nullptr) {
|
||||
dialect->addInterface(allocator(dialect));
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, store it in the interface map for delayed registration.
|
||||
auto it = registry.find(dialectName.str());
|
||||
assert(it != registry.end() &&
|
||||
"adding an interface for an unregistered dialect");
|
||||
interfaces[it->second.first].push_back(allocator);
|
||||
}
|
||||
|
||||
Dialect *DialectRegistry::loadByName(StringRef name, MLIRContext *context) {
|
||||
DialectAllocatorFunctionRef
|
||||
DialectRegistry::getDialectAllocator(StringRef name) const {
|
||||
auto it = registry.find(name.str());
|
||||
if (it == registry.end())
|
||||
return nullptr;
|
||||
return it->second.second(context);
|
||||
return it->second.second;
|
||||
}
|
||||
|
||||
void DialectRegistry::insert(TypeID typeID, StringRef name,
|
||||
@ -63,7 +54,7 @@ void DialectRegistry::insert(TypeID typeID, StringRef name,
|
||||
}
|
||||
}
|
||||
|
||||
void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) {
|
||||
void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const {
|
||||
auto it = interfaces.find(dialect->getTypeID());
|
||||
if (it == interfaces.end())
|
||||
return;
|
||||
|
@ -326,8 +326,7 @@ public:
|
||||
DictionaryAttr emptyDictionaryAttr;
|
||||
|
||||
public:
|
||||
MLIRContextImpl(MLIRContext *ctx)
|
||||
: dialectsRegistry(ctx), identifiers(identifierAllocator) {}
|
||||
MLIRContextImpl() : identifiers(identifierAllocator) {}
|
||||
~MLIRContextImpl() {
|
||||
for (auto typeMapping : registeredTypes)
|
||||
typeMapping.second->~AbstractType();
|
||||
@ -337,7 +336,10 @@ public:
|
||||
};
|
||||
} // end namespace mlir
|
||||
|
||||
MLIRContext::MLIRContext() : impl(new MLIRContextImpl(this)) {
|
||||
MLIRContext::MLIRContext() : MLIRContext(DialectRegistry()) {}
|
||||
|
||||
MLIRContext::MLIRContext(const DialectRegistry ®istry)
|
||||
: impl(new MLIRContextImpl) {
|
||||
// Initialize values based on the command line flags if they were provided.
|
||||
if (clOptions.isConstructed()) {
|
||||
disableMultithreading(clOptions->disableThreading);
|
||||
@ -348,6 +350,9 @@ MLIRContext::MLIRContext() : impl(new MLIRContextImpl(this)) {
|
||||
// Ensure the builtin dialect is always pre-loaded.
|
||||
getOrLoadDialect<BuiltinDialect>();
|
||||
|
||||
// Pre-populate the registry.
|
||||
registry.appendTo(impl->dialectsRegistry);
|
||||
|
||||
// Initialize several common attributes and types to avoid the need to lock
|
||||
// the context when accessing them.
|
||||
|
||||
@ -424,7 +429,15 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
|
||||
// Dialect and Operation Registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DialectRegistry &MLIRContext::getDialectRegistry() {
|
||||
void MLIRContext::appendDialectRegistry(const DialectRegistry ®istry) {
|
||||
registry.appendTo(impl->dialectsRegistry);
|
||||
|
||||
// For the already loaded dialects, register the interfaces immediately.
|
||||
for (const auto &kvp : impl->loadedDialects)
|
||||
registry.registerDelayedInterfaces(kvp.second.get());
|
||||
}
|
||||
|
||||
const DialectRegistry &MLIRContext::getDialectRegistry() {
|
||||
return impl->dialectsRegistry;
|
||||
}
|
||||
|
||||
@ -459,7 +472,9 @@ Dialect *MLIRContext::getOrLoadDialect(StringRef name) {
|
||||
Dialect *dialect = getLoadedDialect(name);
|
||||
if (dialect)
|
||||
return dialect;
|
||||
return impl->dialectsRegistry.loadByName(name, this);
|
||||
DialectAllocatorFunctionRef allocator =
|
||||
impl->dialectsRegistry.getDialectAllocator(name);
|
||||
return allocator ? allocator(this) : nullptr;
|
||||
}
|
||||
|
||||
/// Get a dialect for the provided namespace and TypeID: abort the program if a
|
||||
@ -507,6 +522,11 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
|
||||
return dialect.get();
|
||||
}
|
||||
|
||||
void MLIRContext::loadAllAvailableDialects() {
|
||||
for (StringRef name : getAvailableDialects())
|
||||
getOrLoadDialect(name);
|
||||
}
|
||||
|
||||
llvm::hash_code MLIRContext::getRegistryHash() {
|
||||
llvm::hash_code hash(0);
|
||||
// Factor in number of loaded dialects, attributes, operations, types.
|
||||
|
@ -865,7 +865,9 @@ LogicalResult PassManager::run(Operation *op) {
|
||||
// Register all dialects for the current pipeline.
|
||||
DialectRegistry dependentDialects;
|
||||
getDependentDialects(dependentDialects);
|
||||
dependentDialects.loadAll(context);
|
||||
context->appendDialectRegistry(dependentDialects);
|
||||
for (StringRef name : dependentDialects.getDialectNames())
|
||||
context->getOrLoadDialect(name);
|
||||
|
||||
// Initialize all of the passes within the pass manager with a new generation.
|
||||
llvm::hash_code newInitKey = context->getRegistryHash();
|
||||
|
@ -95,10 +95,9 @@ static LogicalResult processBuffer(raw_ostream &os,
|
||||
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
|
||||
|
||||
// Parse the input file.
|
||||
MLIRContext context;
|
||||
registry.appendTo(context.getDialectRegistry());
|
||||
MLIRContext context(registry);
|
||||
if (preloadDialectsInContext)
|
||||
registry.loadAll(&context);
|
||||
context.loadAllAvailableDialects();
|
||||
context.allowUnregisteredDialects(allowUnregisteredDialects);
|
||||
context.printOpOnDiagnostic(!verifyDiagnostics);
|
||||
|
||||
|
@ -136,8 +136,9 @@ static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo,
|
||||
if (failed(spirv::serialize(*spirvModules.begin(), binary, emitDebugInfo)))
|
||||
return failure();
|
||||
|
||||
MLIRContext deserializationContext;
|
||||
context->getDialectRegistry().loadAll(&deserializationContext);
|
||||
MLIRContext deserializationContext(context->getDialectRegistry());
|
||||
// TODO: we should only load the required dialects instead of all dialects.
|
||||
deserializationContext.loadAllAvailableDialects();
|
||||
// Then deserialize to get back a SPIR-V module.
|
||||
spirv::OwningSPIRVModuleRef spirvModule =
|
||||
spirv::deserialize(binary, &deserializationContext);
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include "mlir/Translation.h"
|
||||
#include "mlir/IR/AsmState.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Verifier.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Support/FileUtilities.h"
|
||||
@ -97,7 +98,9 @@ TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
|
||||
registerTranslation(name, [function, dialectRegistration](
|
||||
llvm::SourceMgr &sourceMgr, raw_ostream &output,
|
||||
MLIRContext *context) {
|
||||
dialectRegistration(context->getDialectRegistry());
|
||||
DialectRegistry registry;
|
||||
dialectRegistration(registry);
|
||||
context->appendDialectRegistry(registry);
|
||||
auto module = OwningModuleRef(parseSourceFile(sourceMgr, context));
|
||||
if (!module)
|
||||
return failure();
|
||||
|
@ -89,11 +89,12 @@ int main(int argc, char **argv) {
|
||||
if (!output)
|
||||
llvm::report_fatal_error(errorMessage);
|
||||
|
||||
mlir::MLIRContext context;
|
||||
registerAllDialects(context.getDialectRegistry());
|
||||
mlir::DialectRegistry registry;
|
||||
registerAllDialects(registry);
|
||||
#ifdef MLIR_INCLUDE_TESTS
|
||||
mlir::test::registerTestDialect(context.getDialectRegistry());
|
||||
mlir::test::registerTestDialect(registry);
|
||||
#endif
|
||||
mlir::MLIRContext context(registry);
|
||||
|
||||
mlir::OwningModuleRef moduleRef;
|
||||
if (failed(loadModule(context, moduleRef, inputFilename)))
|
||||
|
@ -65,8 +65,7 @@ TEST(Dialect, DelayedInterfaceRegistration) {
|
||||
// Delayed registration of an interface for TestDialect.
|
||||
registry.addDialectInterface<TestDialect, TestDialectInterface>();
|
||||
|
||||
MLIRContext context;
|
||||
registry.appendTo(context.getDialectRegistry());
|
||||
MLIRContext context(registry);
|
||||
|
||||
// Load the TestDialect and check that the interface got registered for it.
|
||||
auto *testDialect = context.getOrLoadDialect<TestDialect>();
|
||||
@ -85,8 +84,11 @@ TEST(Dialect, DelayedInterfaceRegistration) {
|
||||
|
||||
// Use the same mechanism as for delayed registration but for an already
|
||||
// loaded dialect and check that the interface is now registered.
|
||||
context.getDialectRegistry()
|
||||
DialectRegistry secondRegistry;
|
||||
secondRegistry.insert<SecondTestDialect>();
|
||||
secondRegistry
|
||||
.addDialectInterface<SecondTestDialect, SecondTestDialectInterface>();
|
||||
context.appendDialectRegistry(secondRegistry);
|
||||
secondTestDialectInterface =
|
||||
secondTestDialect->getRegisteredInterface<SecondTestDialectInterface>();
|
||||
EXPECT_TRUE(secondTestDialectInterface != nullptr);
|
||||
|
Loading…
x
Reference in New Issue
Block a user