mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-02-18 16:58:23 +00:00
[mlir][linalg][bufferize] Separate pass from ComprehensiveBufferize
This commit separates the bufferization from the bufferization pass in Linalg. This allows other dialects to use ComprehensiveBufferize more easily. This commit mainly moves files to a new directory and adds a new build target. Differential Revision: https://reviews.llvm.org/D112989
This commit is contained in:
parent
005456e5fc
commit
95e62eb430
@ -1,5 +1,5 @@
|
||||
add_subdirectory(ComprehensiveBufferize)
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Linalg)
|
||||
|
@ -6,8 +6,8 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_BUFFERIZABLEOPINTERFACE_H_
|
||||
#define MLIR_DIALECT_LINALG_TRANSFORMS_BUFFERIZABLEOPINTERFACE_H_
|
||||
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZABLEOPINTERFACE_H_
|
||||
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZABLEOPINTERFACE_H_
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
@ -31,6 +31,6 @@ enum class BufferRelation {
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.h.inc"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h.inc"
|
||||
|
||||
#endif // MLIR_DIALECT_LINALG_TRANSFORMS_BUFFERIZABLEOPINTERFACE_H_
|
||||
#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZABLEOPINTERFACE_H_
|
@ -6,8 +6,8 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_BUFFERIZABLEOPINTERFACE
|
||||
#define MLIR_DIALECT_LINALG_TRANSFORMS_BUFFERIZABLEOPINTERFACE
|
||||
#ifndef BUFFERIZABLE_OP_INTERFACE
|
||||
#define BUFFERIZABLE_OP_INTERFACE
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
@ -176,4 +176,4 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
||||
];
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_LINALG_TRANSFORMS_BUFFERIZABLEOPINTERFACE
|
||||
#endif // BUFFERIZABLE_OP_INTERFACE
|
@ -6,8 +6,8 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_COMPREHENSIVE_BUFFERIZE_H
|
||||
#define MLIR_DIALECT_LINALG_TRANSFORMS_COMPREHENSIVE_BUFFERIZE_H
|
||||
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVE_BUFFERIZE_H
|
||||
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVE_BUFFERIZE_H
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
@ -19,9 +19,13 @@ namespace mlir {
|
||||
class DominanceInfo;
|
||||
class FuncOp;
|
||||
class GlobalCreator;
|
||||
class ModuleOp;
|
||||
|
||||
namespace linalg {
|
||||
|
||||
// TODO: from some HW description.
|
||||
static constexpr int64_t kBufferAlignments = 128;
|
||||
|
||||
/// The BufferizationAliasInfo class maintains a list of buffer aliases and
|
||||
/// equivalence classes to support bufferization.
|
||||
/// ExtractSliceOps have special behavior, they act as a level of indirection
|
||||
@ -120,6 +124,7 @@ LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
|
||||
const DominanceInfo &domInfo,
|
||||
unsigned analysisFuzzerSeed = 0);
|
||||
|
||||
// TODO: Do not expose those functions in the header file.
|
||||
/// Default allocation function that is used by the comprehensive bufferization
|
||||
/// pass. The default currently creates a ranked memref using `memref.alloc`.
|
||||
Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc, MemRefType type,
|
||||
@ -149,6 +154,10 @@ struct AllocationCallbacks {
|
||||
MemCpyFn copyFn)
|
||||
: allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn) {}
|
||||
|
||||
AllocationCallbacks()
|
||||
: allocationFn(defaultAllocationFn),
|
||||
deallocationFn(defaultDeallocationFn), memCpyFn(defaultMemCpyFn) {}
|
||||
|
||||
AllocationFn allocationFn;
|
||||
DeallocationFn deallocationFn;
|
||||
MemCpyFn memCpyFn;
|
||||
@ -188,7 +197,20 @@ LogicalResult initTensorElimination(
|
||||
LogicalResult eliminateInsertSliceAnchoredInitTensorOps(
|
||||
FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo);
|
||||
|
||||
struct BufferizationOptions {
|
||||
BufferizationOptions()
|
||||
: allocationFns(std::make_unique<AllocationCallbacks>()) {}
|
||||
|
||||
std::unique_ptr<AllocationCallbacks> allocationFns;
|
||||
bool allowReturnMemref = false;
|
||||
unsigned analysisFuzzerSeed = 0;
|
||||
bool testAnalysisOnly = false;
|
||||
};
|
||||
|
||||
LogicalResult runComprehensiveBufferize(ModuleOp moduleOp,
|
||||
const BufferizationOptions &options);
|
||||
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
||||
|
||||
#endif // define MLIR_DIALECT_LINALG_TRANSFORMS_COMPREHENSIVE_BUFFERIZE_H
|
||||
#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVE_BUFFERIZE_H
|
@ -1,4 +1,5 @@
|
||||
add_subdirectory(Analysis)
|
||||
add_subdirectory(ComprehensiveBufferize)
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
add_subdirectory(Utils)
|
||||
|
@ -6,12 +6,12 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace linalg {
|
||||
|
||||
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.cpp.inc"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp.inc"
|
||||
|
||||
} // namespace linalg
|
||||
} // namespace mlir
|
@ -0,0 +1,30 @@
|
||||
set(LLVM_OPTIONAL_SOURCES
|
||||
BufferizableOpInterface.cpp
|
||||
ComprehensiveBufferize.cpp
|
||||
)
|
||||
|
||||
add_mlir_dialect_library(MLIRBufferizableOpInterface
|
||||
BufferizableOpInterface.cpp
|
||||
|
||||
DEPENDS
|
||||
MLIRBufferizableOpInterfaceIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
)
|
||||
|
||||
add_mlir_dialect_library(MLIRComprehensiveBufferize
|
||||
ComprehensiveBufferize.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRBufferizableOpInterface
|
||||
MLIRInferTypeOpInterface
|
||||
MLIRIR
|
||||
MLIRMemRef
|
||||
MLIRLinalg
|
||||
MLIRSCF
|
||||
MLIRStandard
|
||||
MLIRStandardOpsTransforms
|
||||
MLIRTensor
|
||||
MLIRVector
|
||||
)
|
@ -105,16 +105,12 @@
|
||||
// expected layouts after transformations. Combinations of memref.cast +
|
||||
// canonicalization are responsible for clean ups.
|
||||
|
||||
#include "mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
|
||||
|
||||
#include <random>
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||
@ -125,8 +121,6 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/BufferUtils.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/ScopeExit.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
@ -142,9 +136,6 @@ using namespace tensor;
|
||||
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
|
||||
#define LDBG(X) LLVM_DEBUG(DBGS() << X)
|
||||
|
||||
// TODO: from some HW description.
|
||||
static constexpr int64_t kBufferAlignments = 128;
|
||||
|
||||
// Forward declarations.
|
||||
static std::string printOperationInfo(Operation *, bool prefix = true);
|
||||
static std::string printValueInfo(Value, bool prefix = true);
|
||||
@ -1208,6 +1199,17 @@ Operation *getFirstParentOfType(Value v) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
|
||||
/// the type of `source`.
|
||||
static Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
|
||||
int64_t dim) {
|
||||
if (source.getType().isa<UnrankedMemRefType, MemRefType>())
|
||||
return b.createOrFold<memref::DimOp>(loc, source, dim);
|
||||
if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
|
||||
return b.createOrFold<tensor::DimOp>(loc, source, dim);
|
||||
llvm_unreachable("Expected MemRefType or TensorType");
|
||||
}
|
||||
|
||||
/// Compute the type of the `memref` to use for allocating the buffer for
|
||||
/// `shapedValue`. Also returns (by reference in `dynShape`), the value for the
|
||||
/// dynamic dimensions in the returned `memref` type. The function also sets the
|
||||
@ -1664,14 +1666,6 @@ mlir::linalg::defaultAllocationFn(OpBuilder &b, Location loc, MemRefType type,
|
||||
return allocated;
|
||||
}
|
||||
|
||||
static Optional<Value>
|
||||
allocationFnUsingAlloca(OpBuilder &b, Location loc, MemRefType type,
|
||||
const SmallVector<Value> &dynShape) {
|
||||
Value allocated = b.create<memref::AllocaOp>(
|
||||
loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
|
||||
return allocated;
|
||||
}
|
||||
|
||||
void mlir::linalg::defaultDeallocationFn(OpBuilder &b, Location loc,
|
||||
Value allocatedBuffer) {
|
||||
b.create<memref::DeallocOp>(loc, allocatedBuffer);
|
||||
@ -2018,36 +2012,6 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct LinalgComprehensiveModuleBufferize
|
||||
: public LinalgComprehensiveModuleBufferizeBase<
|
||||
LinalgComprehensiveModuleBufferize> {
|
||||
LinalgComprehensiveModuleBufferize() {}
|
||||
|
||||
LinalgComprehensiveModuleBufferize(
|
||||
const LinalgComprehensiveModuleBufferize &p) {}
|
||||
|
||||
void runOnOperation() override;
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry
|
||||
.insert<linalg::LinalgDialect, memref::MemRefDialect,
|
||||
tensor::TensorDialect, vector::VectorDialect, scf::SCFDialect,
|
||||
arith::ArithmeticDialect, StandardOpsDialect>();
|
||||
registerBufferizableOpInterfaceExternalModels(registry);
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<AllocationCallbacks> allocationFns;
|
||||
};
|
||||
} // end namespace
|
||||
|
||||
static void applyEnablingTransformations(ModuleOp moduleOp) {
|
||||
RewritePatternSet patterns(moduleOp.getContext());
|
||||
patterns.add<GeneralizePadTensorOpPattern>(moduleOp.getContext());
|
||||
(void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
|
||||
}
|
||||
|
||||
static void
|
||||
foreachCaller(const DenseMap<FuncOp, DenseSet<Operation *>> &callerMap,
|
||||
FuncOp callee, llvm::function_ref<void(Operation *)> doit) {
|
||||
@ -2261,33 +2225,14 @@ LogicalResult mlir::linalg::eliminateInsertSliceAnchoredInitTensorOps(
|
||||
});
|
||||
}
|
||||
|
||||
void LinalgComprehensiveModuleBufferize::runOnOperation() {
|
||||
if (!allocationFns) {
|
||||
// The allocation functions to use needs to be set here. The flag for the
|
||||
// pass and flag for the use of alloca map to LLVM command line
|
||||
// options. These being static global objects have no set order in which
|
||||
// they are defined. So ideally this should be in the constructor, but the
|
||||
// constructor might be called before the flag is initialized using the
|
||||
// command line option. So this is set up at the start of the pass.
|
||||
if (useAlloca) {
|
||||
AllocationCallbacks allocaAllocationFns = {
|
||||
allocationFnUsingAlloca, [](OpBuilder &b, Location loc, Value v) {},
|
||||
defaultMemCpyFn};
|
||||
allocationFns =
|
||||
std::make_unique<AllocationCallbacks>(std::move(allocaAllocationFns));
|
||||
} else {
|
||||
allocationFns = std::make_unique<AllocationCallbacks>(
|
||||
defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn);
|
||||
}
|
||||
}
|
||||
ModuleOp moduleOp = getOperation();
|
||||
applyEnablingTransformations(moduleOp);
|
||||
|
||||
LogicalResult
|
||||
mlir::linalg::runComprehensiveBufferize(ModuleOp moduleOp,
|
||||
const BufferizationOptions &options) {
|
||||
SmallVector<FuncOp> orderedFuncOps;
|
||||
DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
|
||||
DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
|
||||
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
|
||||
return signalPassFailure();
|
||||
return failure();
|
||||
|
||||
DominanceInfo domInfo(moduleOp);
|
||||
BufferizationAliasInfo aliasInfo(moduleOp);
|
||||
@ -2313,49 +2258,41 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
|
||||
|
||||
// If the analysis fails, just return.
|
||||
if (failed(inPlaceAnalysisFuncOpBody(funcOp, aliasInfo, domInfo,
|
||||
analysisFuzzerSeed))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
options.analysisFuzzerSeed)))
|
||||
return failure();
|
||||
|
||||
// Try to eliminate InitTensorOps to avoid new allocations during the
|
||||
// bufferization phase.
|
||||
if (failed(eliminateInsertSliceAnchoredInitTensorOps(funcOp, aliasInfo,
|
||||
domInfo))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
domInfo)))
|
||||
return failure();
|
||||
|
||||
// Bufferization phase.
|
||||
if (!testAnalysisOnly) {
|
||||
if (!options.testAnalysisOnly) {
|
||||
BlockAndValueMapping tensorToBufferMap;
|
||||
if (failed(bufferizeFuncOpInternals(funcOp, tensorToBufferMap, aliasInfo,
|
||||
*allocationFns,
|
||||
bufferizedFunctionTypes))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
*options.allocationFns,
|
||||
bufferizedFunctionTypes)))
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
// Don't drop the attributes if we only want to report the analysis.
|
||||
if (testAnalysisOnly)
|
||||
return;
|
||||
if (options.testAnalysisOnly)
|
||||
return success();
|
||||
|
||||
for (FuncOp funcOp : orderedFuncOps) {
|
||||
// Note: It would be good to apply cleanups here but we cannot as aliasInfo
|
||||
// would be invalidated.
|
||||
if (failed(bufferizeFuncOpBoundary(funcOp, aliasInfo,
|
||||
bufferizedFunctionTypes))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
if (!allowReturnMemref &&
|
||||
bufferizedFunctionTypes)))
|
||||
return failure();
|
||||
|
||||
if (!options.allowReturnMemref &&
|
||||
llvm::any_of(funcOp.getType().getResults(), [](Type t) {
|
||||
return t.isa<MemRefType, UnrankedMemRefType>();
|
||||
})) {
|
||||
funcOp->emitError("memref return type is unsupported");
|
||||
signalPassFailure();
|
||||
return;
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
@ -2371,15 +2308,7 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
|
||||
removeBufferizationFuncArguments(bbArg);
|
||||
});
|
||||
|
||||
OpPassManager cleanupPipeline("builtin.module");
|
||||
cleanupPipeline.addPass(createCanonicalizerPass());
|
||||
cleanupPipeline.addPass(createCSEPass());
|
||||
cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
|
||||
(void)runPipeline(cleanupPipeline, moduleOp);
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass() {
|
||||
return std::make_unique<LinalgComprehensiveModuleBufferize>();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
@ -1,42 +1,7 @@
|
||||
set(LLVM_OPTIONAL_SOURCES
|
||||
Bufferize.cpp
|
||||
BufferizableOpInterface.cpp
|
||||
CodegenStrategy.cpp
|
||||
ComprehensiveBufferize.cpp
|
||||
Detensorize.cpp
|
||||
Distribution.cpp
|
||||
DropUnitDims.cpp
|
||||
ElementwiseOpFusion.cpp
|
||||
ElementwiseToLinalg.cpp
|
||||
Fusion.cpp
|
||||
FusionOnTensors.cpp
|
||||
Generalization.cpp
|
||||
Hoisting.cpp
|
||||
HoistPadding.cpp
|
||||
InlineScalarOperands.cpp
|
||||
Interchange.cpp
|
||||
Loops.cpp
|
||||
LinalgStrategyPasses.cpp
|
||||
Promotion.cpp
|
||||
Tiling.cpp
|
||||
Transforms.cpp
|
||||
Vectorization.cpp
|
||||
)
|
||||
|
||||
add_mlir_dialect_library(MLIRBufferizableOpInterface
|
||||
BufferizableOpInterface.cpp
|
||||
|
||||
DEPENDS
|
||||
MLIRBufferizableOpInterfaceIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
)
|
||||
|
||||
add_mlir_dialect_library(MLIRLinalgTransforms
|
||||
Bufferize.cpp
|
||||
CodegenStrategy.cpp
|
||||
ComprehensiveBufferize.cpp
|
||||
ComprehensiveBufferizePass.cpp
|
||||
Detensorize.cpp
|
||||
Distribution.cpp
|
||||
DropUnitDims.cpp
|
||||
@ -69,6 +34,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
|
||||
MLIRArithmetic
|
||||
MLIRBufferizableOpInterface
|
||||
MLIRComplex
|
||||
MLIRComprehensiveBufferize
|
||||
MLIRInferTypeOpInterface
|
||||
MLIRIR
|
||||
MLIRMemRef
|
||||
|
@ -0,0 +1,86 @@
|
||||
//===- ComprehensiveBufferize.cpp - Single pass bufferization -------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
|
||||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
namespace {
|
||||
struct LinalgComprehensiveModuleBufferize
|
||||
: public LinalgComprehensiveModuleBufferizeBase<
|
||||
LinalgComprehensiveModuleBufferize> {
|
||||
LinalgComprehensiveModuleBufferize() {}
|
||||
|
||||
LinalgComprehensiveModuleBufferize(
|
||||
const LinalgComprehensiveModuleBufferize &p) {}
|
||||
|
||||
void runOnOperation() override;
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry
|
||||
.insert<linalg::LinalgDialect, memref::MemRefDialect,
|
||||
tensor::TensorDialect, vector::VectorDialect, scf::SCFDialect,
|
||||
arith::ArithmeticDialect, StandardOpsDialect>();
|
||||
registerBufferizableOpInterfaceExternalModels(registry);
|
||||
}
|
||||
};
|
||||
} // end namespace
|
||||
|
||||
static void applyEnablingTransformations(ModuleOp moduleOp) {
|
||||
RewritePatternSet patterns(moduleOp.getContext());
|
||||
patterns.add<GeneralizePadTensorOpPattern>(moduleOp.getContext());
|
||||
(void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
|
||||
}
|
||||
|
||||
static Optional<Value>
|
||||
allocationFnUsingAlloca(OpBuilder &b, Location loc, MemRefType type,
|
||||
const SmallVector<Value> &dynShape) {
|
||||
Value allocated = b.create<memref::AllocaOp>(
|
||||
loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
|
||||
return allocated;
|
||||
}
|
||||
|
||||
void LinalgComprehensiveModuleBufferize::runOnOperation() {
|
||||
BufferizationOptions options;
|
||||
if (useAlloca) {
|
||||
options.allocationFns->allocationFn = allocationFnUsingAlloca;
|
||||
options.allocationFns->deallocationFn = [](OpBuilder &b, Location loc,
|
||||
Value v) {};
|
||||
}
|
||||
options.allowReturnMemref = allowReturnMemref;
|
||||
options.analysisFuzzerSeed = analysisFuzzerSeed;
|
||||
options.testAnalysisOnly = testAnalysisOnly;
|
||||
|
||||
ModuleOp moduleOp = getOperation();
|
||||
applyEnablingTransformations(moduleOp);
|
||||
|
||||
if (failed(runComprehensiveBufferize(moduleOp, options))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
if (options.testAnalysisOnly)
|
||||
return;
|
||||
|
||||
OpPassManager cleanupPipeline("builtin.module");
|
||||
cleanupPipeline.addPass(createCanonicalizerPass());
|
||||
cleanupPipeline.addPass(createCSEPass());
|
||||
cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
|
||||
(void)runPipeline(cleanupPipeline, moduleOp);
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass() {
|
||||
return std::make_unique<LinalgComprehensiveModuleBufferize>();
|
||||
}
|
@ -6129,7 +6129,7 @@ gentbl_cc_library(
|
||||
td_library(
|
||||
name = "BufferizableOpInterfaceTdFiles",
|
||||
srcs = [
|
||||
"include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td",
|
||||
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td",
|
||||
],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
@ -6143,15 +6143,15 @@ gentbl_cc_library(
|
||||
tbl_outs = [
|
||||
(
|
||||
["-gen-op-interface-decls"],
|
||||
"include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.h.inc",
|
||||
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h.inc",
|
||||
),
|
||||
(
|
||||
["-gen-op-interface-defs"],
|
||||
"include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.cpp.inc",
|
||||
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp.inc",
|
||||
),
|
||||
],
|
||||
tblgen = ":mlir-tblgen",
|
||||
td_file = "include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td",
|
||||
td_file = "include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td",
|
||||
deps = [
|
||||
":BufferizableOpInterfaceTdFiles",
|
||||
],
|
||||
@ -6160,10 +6160,10 @@ gentbl_cc_library(
|
||||
cc_library(
|
||||
name = "BufferizableOpInterface",
|
||||
srcs = [
|
||||
"lib/Dialect/Linalg/Transforms/BufferizableOpInterface.cpp",
|
||||
"lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp",
|
||||
],
|
||||
hdrs = [
|
||||
"include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.h",
|
||||
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h",
|
||||
],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
@ -6352,13 +6352,10 @@ gentbl_cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "LinalgTransforms",
|
||||
srcs = glob(
|
||||
[
|
||||
"lib/Dialect/Linalg/Transforms/*.cpp",
|
||||
"lib/Dialect/Linalg/Transforms/*.h",
|
||||
],
|
||||
exclude = ["lib/Dialect/Linalg/Transforms/BufferizableOpInterface.cpp"],
|
||||
) + [
|
||||
srcs = glob([
|
||||
"lib/Dialect/Linalg/Transforms/*.cpp",
|
||||
"lib/Dialect/Linalg/Transforms/*.h",
|
||||
]) + [
|
||||
"lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp",
|
||||
"lib/Dialect/Linalg/Utils/Utils.cpp",
|
||||
],
|
||||
@ -6366,7 +6363,6 @@ cc_library(
|
||||
"include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h",
|
||||
"include/mlir/Dialect/Linalg/Passes.h",
|
||||
"include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h",
|
||||
"include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h",
|
||||
"include/mlir/Dialect/Linalg/Transforms/HoistPadding.h",
|
||||
"include/mlir/Dialect/Linalg/Transforms/Hoisting.h",
|
||||
"include/mlir/Dialect/Linalg/Transforms/Transforms.h",
|
||||
@ -6378,8 +6374,8 @@ cc_library(
|
||||
":AffineUtils",
|
||||
":Analysis",
|
||||
":ArithmeticDialect",
|
||||
":BufferizableOpInterface",
|
||||
":ComplexDialect",
|
||||
":ComprehensiveBufferize",
|
||||
":DialectUtils",
|
||||
":IR",
|
||||
":InferTypeOpInterface",
|
||||
@ -6402,6 +6398,35 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ComprehensiveBufferize",
|
||||
srcs = [
|
||||
"lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp",
|
||||
],
|
||||
hdrs = [
|
||||
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h",
|
||||
],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":ArithmeticDialect",
|
||||
":BufferizableOpInterface",
|
||||
":DialectUtils",
|
||||
":IR",
|
||||
":InferTypeOpInterface",
|
||||
":LinalgOps",
|
||||
":LinalgStructuredOpsIncGen",
|
||||
":MemRefDialect",
|
||||
":Pass",
|
||||
":SCFDialect",
|
||||
":StandardOps",
|
||||
":Support",
|
||||
":TensorDialect",
|
||||
":TransformUtils",
|
||||
":VectorOps",
|
||||
"//llvm:Support",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "TilingInterface",
|
||||
srcs = ["lib/Interfaces/TilingInterface.cpp"],
|
||||
|
Loading…
x
Reference in New Issue
Block a user