[mlir][bufferization][NFC] Move memref specific implementation of AllocationOpInterface to memref dialect directory (#66637)

Follow-up on #65578
This commit is contained in:
Martin Erhart 2023-09-20 14:49:52 +02:00 committed by GitHub
parent e88a64f7ab
commit 65341b09b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 93 additions and 62 deletions

View File

@ -211,9 +211,6 @@ std::unique_ptr<Pass> createBufferizationBufferizePass();
// Registration
//===----------------------------------------------------------------------===//
/// Register external models for AllocationOpInterface.
void registerAllocationOpInterfaceExternalModels(DialectRegistry &registry);
/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"

View File

@ -0,0 +1,20 @@
//===- AllocationOpInterfaceImpl.h - Impl. of AllocationOpInterface -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_MEMREF_ALLOCATIONOPINTERFACEIMPL_H
#define MLIR_DIALECT_MEMREF_ALLOCATIONOPINTERFACEIMPL_H
namespace mlir {
class DialectRegistry;
namespace memref {
void registerAllocationOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace memref
} // namespace mlir
#endif // MLIR_DIALECT_MEMREF_ALLOCATIONOPINTERFACEIMPL_H

View File

@ -51,6 +51,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
@ -149,6 +150,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
linalg::registerBufferizableOpInterfaceExternalModels(registry);
linalg::registerTilingInterfaceExternalModels(registry);
linalg::registerValueBoundsOpInterfaceExternalModels(registry);
memref::registerAllocationOpInterfaceExternalModels(registry);
memref::registerBufferizableOpInterfaceExternalModels(registry);
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
memref::registerValueBoundsOpInterfaceExternalModels(registry);

View File

@ -174,5 +174,4 @@ public:
void mlir::bufferization::registerTransformDialectExtension(
DialectRegistry &registry) {
registry.addExtensions<BufferizationTransformDialectExtension>();
bufferization::registerAllocationOpInterfaceExternalModels(registry);
}

View File

@ -634,7 +634,6 @@ struct BufferDeallocationPass
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<bufferization::BufferizationDialect>();
registry.insert<memref::MemRefDialect>();
registerAllocationOpInterfaceExternalModels(registry);
}
void runOnOperation() override {

View File

@ -195,7 +195,6 @@ struct OneShotBufferizePass
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
registerAllocationOpInterfaceExternalModels(registry);
}
void runOnOperation() override {
@ -672,59 +671,3 @@ BufferizationOptions bufferization::getPartialBufferizationOptions() {
options.opFilter.allowDialect<BufferizationDialect>();
return options;
}
//===----------------------------------------------------------------------===//
// Default AllocationOpInterface implementation and registration
//===----------------------------------------------------------------------===//
namespace {
struct DefaultAllocationInterface
: public bufferization::AllocationOpInterface::ExternalModel<
DefaultAllocationInterface, memref::AllocOp> {
static std::optional<Operation *> buildDealloc(OpBuilder &builder,
Value alloc) {
return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
.getOperation();
}
static std::optional<Value> buildClone(OpBuilder &builder, Value alloc) {
return builder.create<bufferization::CloneOp>(alloc.getLoc(), alloc)
.getResult();
}
static ::mlir::HoistingKind getHoistingKind() {
return HoistingKind::Loop | HoistingKind::Block;
}
static ::std::optional<::mlir::Operation *>
buildPromotedAlloc(OpBuilder &builder, Value alloc) {
Operation *definingOp = alloc.getDefiningOp();
return builder.create<memref::AllocaOp>(
definingOp->getLoc(), cast<MemRefType>(definingOp->getResultTypes()[0]),
definingOp->getOperands(), definingOp->getAttrs());
}
};
struct DefaultAutomaticAllocationHoistingInterface
: public bufferization::AllocationOpInterface::ExternalModel<
DefaultAutomaticAllocationHoistingInterface, memref::AllocaOp> {
static ::mlir::HoistingKind getHoistingKind() { return HoistingKind::Loop; }
};
struct DefaultReallocationInterface
: public bufferization::AllocationOpInterface::ExternalModel<
DefaultAllocationInterface, memref::ReallocOp> {
static std::optional<Operation *> buildDealloc(OpBuilder &builder,
Value realloc) {
return builder.create<memref::DeallocOp>(realloc.getLoc(), realloc)
.getOperation();
}
};
} // namespace
void bufferization::registerAllocationOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
memref::AllocOp::attachInterface<DefaultAllocationInterface>(*ctx);
memref::AllocaOp::attachInterface<
DefaultAutomaticAllocationHoistingInterface>(*ctx);
memref::ReallocOp::attachInterface<DefaultReallocationInterface>(*ctx);
});
}

View File

@ -0,0 +1,69 @@
//===- AllocationOpInterfaceImpl.cpp - Impl. of AllocationOpInterface -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
using namespace mlir;
namespace {
struct DefaultAllocationInterface
: public bufferization::AllocationOpInterface::ExternalModel<
DefaultAllocationInterface, memref::AllocOp> {
static std::optional<Operation *> buildDealloc(OpBuilder &builder,
Value alloc) {
return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
.getOperation();
}
static std::optional<Value> buildClone(OpBuilder &builder, Value alloc) {
return builder.create<bufferization::CloneOp>(alloc.getLoc(), alloc)
.getResult();
}
static ::mlir::HoistingKind getHoistingKind() {
return HoistingKind::Loop | HoistingKind::Block;
}
static ::std::optional<::mlir::Operation *>
buildPromotedAlloc(OpBuilder &builder, Value alloc) {
Operation *definingOp = alloc.getDefiningOp();
return builder.create<memref::AllocaOp>(
definingOp->getLoc(), cast<MemRefType>(definingOp->getResultTypes()[0]),
definingOp->getOperands(), definingOp->getAttrs());
}
};
struct DefaultAutomaticAllocationHoistingInterface
: public bufferization::AllocationOpInterface::ExternalModel<
DefaultAutomaticAllocationHoistingInterface, memref::AllocaOp> {
static ::mlir::HoistingKind getHoistingKind() { return HoistingKind::Loop; }
};
struct DefaultReallocationInterface
: public bufferization::AllocationOpInterface::ExternalModel<
DefaultAllocationInterface, memref::ReallocOp> {
static std::optional<Operation *> buildDealloc(OpBuilder &builder,
Value realloc) {
return builder.create<memref::DeallocOp>(realloc.getLoc(), realloc)
.getOperation();
}
};
} // namespace
void mlir::memref::registerAllocationOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
memref::AllocOp::attachInterface<DefaultAllocationInterface>(*ctx);
memref::AllocaOp::attachInterface<
DefaultAutomaticAllocationHoistingInterface>(*ctx);
memref::ReallocOp::attachInterface<DefaultReallocationInterface>(*ctx);
});
}

View File

@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRMemRefTransforms
AllocationOpInterfaceImpl.cpp
BufferizableOpInterfaceImpl.cpp
ComposeSubView.cpp
ExpandOps.cpp

View File

@ -11738,6 +11738,7 @@ cc_library(
":AffineDialect",
":AffineTransforms",
":AffineUtils",
":AllocationOpInterface",
":ArithDialect",
":ArithTransforms",
":ArithUtils",