[mlir] Move bufferization-related passes to bufferization dialect.

[RFC](https://llvm.discourse.group/t/rfc-dialect-for-bufferization-related-ops/4712)

Differential Revision: https://reviews.llvm.org/D114698
This commit is contained in:
Alexander Belyaev 2021-11-29 13:45:01 +01:00
parent 0d0371f58f
commit f89bb3c012
39 changed files with 344 additions and 174 deletions

View File

@ -10,14 +10,18 @@
#define MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES_H_
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Bufferize.h"
namespace mlir {
namespace bufferization {
class BufferizeTypeConverter;
} // end namespace bufferization
namespace arith {
/// Add patterns to bufferize Arithmetic ops.
void populateArithmeticBufferizePatterns(BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns);
void populateArithmeticBufferizePatterns(
bufferization::BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns);
/// Create a pass to bufferize Arithmetic ops.
std::unique_ptr<Pass> createArithmeticBufferizePass();

View File

@ -1 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@ -20,19 +20,13 @@
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TRANSFORMS_BUFFERIZE_H
#define MLIR_TRANSFORMS_BUFFERIZE_H
#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_BUFFERIZE_H
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_BUFFERIZE_H
#include "mlir/Analysis/BufferViewFlowAnalysis.h"
#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace bufferization {
/// A helper type converter class that automatically populates the relevant
/// materializations and type conversions for bufferization.
@ -58,6 +52,7 @@ void populateBufferizeMaterializationLegality(ConversionTarget &target);
void populateEliminateBufferizeMaterializationsPatterns(
BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns);
} // end namespace bufferization
} // end namespace mlir
#endif // MLIR_TRANSFORMS_BUFFERIZE_H
#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_BUFFERIZE_H

View File

@ -0,0 +1,6 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Bufferization)
add_public_tablegen_target(MLIRBufferizationPassIncGen)
add_dependencies(mlir-headers MLIRBufferizationPassIncGen)
add_mlir_doc(Passes BufferizationPasses ./ -gen-pass-doc)

View File

@ -0,0 +1,32 @@
#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace bufferization {
//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
/// Creates an instance of the BufferDeallocation pass to free all allocated
/// buffers.
std::unique_ptr<Pass> createBufferDeallocationPass();
/// Creates a pass that finalizes a partial bufferization by removing remaining
/// bufferization.to_tensor and bufferization.to_memref operations.
std::unique_ptr<FunctionPass> createFinalizingBufferizePass();
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // end namespace bufferization
} // end namespace mlir
#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES_H

View File

@ -0,0 +1,107 @@
//===-- Passes.td - Bufferization passes definition file ---*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES
include "mlir/Pass/PassBase.td"
def BufferDeallocation : FunctionPass<"buffer-deallocation"> {
let summary = "Adds all required dealloc operations for all allocations in "
"the input program";
let description = [{
This pass implements an algorithm to automatically introduce all required
deallocation operations for all buffers in the input program. This ensures
that the resulting program does not have any memory leaks.
Input
```mlir
#map0 = affine_map<(d0) -> (d0)>
module {
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2
^bb1:
br ^bb3(%arg1 : memref<2xf32>)
^bb2:
%0 = alloc() : memref<2xf32>
linalg.generic {
args_in = 1 : i64,
args_out = 1 : i64,
indexing_maps = [#map0, #map0],
iterator_types = ["parallel"]} %arg1, %0 {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
}: memref<2xf32>, memref<2xf32>
br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
"linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return
}
}
```
Output
```mlir
#map0 = affine_map<(d0) -> (d0)>
module {
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2
^bb1: // pred: ^bb0
%0 = alloc() : memref<2xf32>
linalg.copy(%arg1, %0) : memref<2xf32>, memref<2xf32>
br ^bb3(%0 : memref<2xf32>)
^bb2: // pred: ^bb0
%1 = alloc() : memref<2xf32>
linalg.generic {
args_in = 1 : i64,
args_out = 1 : i64,
indexing_maps = [#map0, #map0],
iterator_types = ["parallel"]} %arg1, %1 {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%4 = exp %arg3 : f32
linalg.yield %4 : f32
}: memref<2xf32>, memref<2xf32>
%2 = alloc() : memref<2xf32>
linalg.copy(%1, %2) : memref<2xf32>, memref<2xf32>
dealloc %1 : memref<2xf32>
br ^bb3(%2 : memref<2xf32>)
^bb3(%3: memref<2xf32>): // 2 preds: ^bb1, ^bb2
linalg.copy(%3, %arg2) : memref<2xf32>, memref<2xf32>
dealloc %3 : memref<2xf32>
return
}
}
```
}];
let constructor = "mlir::bufferization::createBufferDeallocationPass()";
}
def FinalizingBufferize : FunctionPass<"finalizing-bufferize"> {
let summary = "Finalize a partial bufferization";
let description = [{
A bufferize pass that finalizes a partial bufferization by removing
remaining `bufferization.to_tensor` and `bufferization.to_buffer` operations.
The removal of those operations is only possible if the operations only
exist in pairs, i.e., all uses of `bufferization.to_tensor` operations are
`bufferization.to_buffer` operations.
This pass will fail if not all operations can be removed or if any operation
with tensor typed operands remains.
}];
let constructor = "mlir::bufferization::createFinalizingBufferizePass()";
}
#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES

View File

@ -18,12 +18,15 @@
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/Bufferize.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallSet.h"
namespace mlir {
namespace bufferization {
class BufferizeTypeConverter;
} // namespace bufferization
class FrozenRewritePatternSet;
namespace linalg {
@ -90,8 +93,9 @@ void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns);
/// Populates the given list with patterns to bufferize linalg ops.
void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
RewritePatternSet &patterns);
void populateLinalgBufferizePatterns(
bufferization::BufferizeTypeConverter &converter,
RewritePatternSet &patterns);
/// Create linalg op on buffers given the original tensor-based operation and
/// the buffers for the outputs.

View File

@ -15,16 +15,19 @@
#define MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Bufferize.h"
namespace mlir {
namespace bufferization {
class BufferizeTypeConverter;
} // end namespace bufferization
class GlobalCreator;
class RewritePatternSet;
using OwningRewritePatternList = RewritePatternSet;
void populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns);
void populateStdBufferizePatterns(
bufferization::BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns);
/// Creates an instance of std bufferization pass.
std::unique_ptr<Pass> createStdBufferizePass();
@ -35,7 +38,8 @@ std::unique_ptr<Pass> createFuncBufferizePass();
/// Add patterns to bufferize tensor constants into global memrefs to the given
/// pattern list.
void populateTensorConstantBufferizePatterns(
GlobalCreator &globalCreator, BufferizeTypeConverter &typeConverter,
GlobalCreator &globalCreator,
bufferization::BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns);
/// Creates an instance of tensor constant bufferization pass.

View File

@ -10,15 +10,18 @@
#define MLIR_DIALECT_TENSOR_TRANSFORMS_PASSES_H_
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Bufferize.h"
namespace mlir {
namespace bufferization {
class BufferizeTypeConverter;
} // end namespace bufferization
class RewritePatternSet;
using OwningRewritePatternList = RewritePatternSet;
void populateTensorBufferizePatterns(BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns);
void populateTensorBufferizePatterns(
bufferization::BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns);
/// Creates an instance of `tensor` dialect bufferization pass.
std::unique_ptr<Pass> createTensorBufferizePass();

View File

@ -18,6 +18,7 @@
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
@ -54,6 +55,7 @@ inline void registerAllPasses() {
registerAffinePasses();
registerAsyncPasses();
arith::registerArithmeticPasses();
bufferization::registerBufferizationPasses();
registerGPUPasses();
registerGpuSerializeToCubinPass();
registerGpuSerializeToHsacoPass();

View File

@ -33,10 +33,6 @@ enum FusionMode { Greedy, ProducerConsumer, Sibling };
// Passes
//===----------------------------------------------------------------------===//
/// Creates an instance of the BufferDeallocation pass to free all allocated
/// buffers.
std::unique_ptr<Pass> createBufferDeallocationPass();
/// Creates a pass that moves allocations upwards to reduce the number of
/// required copies that are inserted during the BufferDeallocation pass.
std::unique_ptr<Pass> createBufferHoistingPass();
@ -58,10 +54,6 @@ createPromoteBuffersToStackPass(unsigned maxAllocSizeInBytes = 1024,
std::unique_ptr<Pass>
createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc);
/// Creates a pass that finalizes a partial bufferization by removing remaining
/// tensor_load and buffer_cast operations.
std::unique_ptr<FunctionPass> createFinalizingBufferizePass();
/// Creates a pass that converts memref function results to out-params.
std::unique_ptr<Pass> createBufferResultsToOutParamsPass();

View File

@ -217,83 +217,6 @@ def AffinePipelineDataTransfer
let constructor = "mlir::createPipelineDataTransferPass()";
}
def BufferDeallocation : FunctionPass<"buffer-deallocation"> {
let summary = "Adds all required dealloc operations for all allocations in the "
"input program";
let description = [{
This pass implements an algorithm to automatically introduce all required
deallocation operations for all buffers in the input program. This ensures that
the resulting program does not have any memory leaks.
Input
```mlir
#map0 = affine_map<(d0) -> (d0)>
module {
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2
^bb1:
br ^bb3(%arg1 : memref<2xf32>)
^bb2:
%0 = alloc() : memref<2xf32>
linalg.generic {
args_in = 1 : i64,
args_out = 1 : i64,
indexing_maps = [#map0, #map0],
iterator_types = ["parallel"]} %arg1, %0 {
^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
}: memref<2xf32>, memref<2xf32>
br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
"linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return
}
}
```
Output
```mlir
#map0 = affine_map<(d0) -> (d0)>
module {
func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2
^bb1: // pred: ^bb0
%0 = alloc() : memref<2xf32>
linalg.copy(%arg1, %0) : memref<2xf32>, memref<2xf32>
br ^bb3(%0 : memref<2xf32>)
^bb2: // pred: ^bb0
%1 = alloc() : memref<2xf32>
linalg.generic {
args_in = 1 : i64,
args_out = 1 : i64,
indexing_maps = [#map0, #map0],
iterator_types = ["parallel"]} %arg1, %1 {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%4 = exp %arg3 : f32
linalg.yield %4 : f32
}: memref<2xf32>, memref<2xf32>
%2 = alloc() : memref<2xf32>
linalg.copy(%1, %2) : memref<2xf32>, memref<2xf32>
dealloc %1 : memref<2xf32>
br ^bb3(%2 : memref<2xf32>)
^bb3(%3: memref<2xf32>): // 2 preds: ^bb1, ^bb2
linalg.copy(%3, %arg2) : memref<2xf32>, memref<2xf32>
dealloc %3 : memref<2xf32>
return
}
}
```
}];
let constructor = "mlir::createBufferDeallocationPass()";
}
def BufferHoisting : FunctionPass<"buffer-hoisting"> {
let summary = "Optimizes placement of allocation operations by moving them "
"into common dominators and out of nested regions";
@ -416,22 +339,6 @@ def Inliner : Pass<"inline"> {
];
}
def FinalizingBufferize : FunctionPass<"finalizing-bufferize"> {
let summary = "Finalize a partial bufferization";
let description = [{
A bufferize pass that finalizes a partial bufferization by removing
remaining `memref.tensor_load` and `memref.buffer_cast` operations.
The removal of those operations is only possible if the operations only
exist in pairs, i.e., all uses of `memref.tensor_load` operations are
`memref.buffer_cast` operations.
This pass will fail if not all operations can be removed or if any operation
with tensor typed operands remains.
}];
let constructor = "mlir::createFinalizingBufferizePass()";
}
def LocationSnapshot : Pass<"snapshot-op-locations"> {
let summary = "Generate new locations from the current IR";
let description = [{

View File

@ -6,10 +6,11 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
using namespace mlir;
@ -35,7 +36,7 @@ struct BufferizeIndexCastOp : public OpConversionPattern<arith::IndexCastOp> {
struct ArithmeticBufferizePass
: public ArithmeticBufferizeBase<ArithmeticBufferizePass> {
void runOnFunction() override {
BufferizeTypeConverter typeConverter;
bufferization::BufferizeTypeConverter typeConverter;
RewritePatternSet patterns(&getContext());
ConversionTarget target(getContext());
@ -57,7 +58,8 @@ struct ArithmeticBufferizePass
} // end anonymous namespace
void mlir::arith::populateArithmeticBufferizePatterns(
BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
bufferization::BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<BufferizeIndexCastOp>(typeConverter, patterns.getContext());
}

View File

@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRArithmeticTransforms
LINK_LIBS PUBLIC
MLIRArithmetic
MLIRBufferizationTransforms
MLIRIR
MLIRMemRef
MLIRPass

View File

@ -7,8 +7,11 @@
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;

View File

@ -1 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@ -54,14 +54,9 @@
#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/BufferUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SetOperations.h"
using namespace mlir;
@ -676,6 +671,6 @@ struct BufferDeallocationPass : BufferDeallocationBase<BufferDeallocationPass> {
// BufferDeallocationPass construction
//===----------------------------------------------------------------------===//
std::unique_ptr<Pass> mlir::createBufferDeallocationPass() {
std::unique_ptr<Pass> mlir::bufferization::createBufferDeallocationPass() {
return std::make_unique<BufferDeallocationPass>();
}

View File

@ -6,20 +6,22 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::bufferization;
//===----------------------------------------------------------------------===//
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//
static Value materializeTensorLoad(OpBuilder &builder, TensorType type,
ValueRange inputs, Location loc) {
static Value materializeToTensor(OpBuilder &builder, TensorType type,
ValueRange inputs, Location loc) {
assert(inputs.size() == 1);
assert(inputs[0].getType().isa<BaseMemRefType>());
return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]);
@ -37,8 +39,8 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
addConversion([](UnrankedTensorType type) -> Type {
return UnrankedMemRefType::get(type.getElementType(), 0);
});
addArgumentMaterialization(materializeTensorLoad);
addSourceMaterialization(materializeTensorLoad);
addArgumentMaterialization(materializeToTensor);
addSourceMaterialization(materializeToTensor);
addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
ValueRange inputs, Location loc) -> Value {
assert(inputs.size() == 1);
@ -47,14 +49,15 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
});
}
void mlir::populateBufferizeMaterializationLegality(ConversionTarget &target) {
void mlir::bufferization::populateBufferizeMaterializationLegality(
ConversionTarget &target) {
target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>();
}
namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs, thus, this op becomes an identity.
class BufferizeTensorLoadOp
class BufferizeToTensorOp
: public OpConversionPattern<bufferization::ToTensorOp> {
public:
using OpConversionPattern::OpConversionPattern;
@ -70,7 +73,8 @@ public:
namespace {
// In a finalizing bufferize conversion, we know that all tensors have been
// converted to memrefs, thus, this op becomes an identity.
class BufferizeCastOp : public OpConversionPattern<bufferization::ToMemrefOp> {
class BufferizeToMemrefOp
: public OpConversionPattern<bufferization::ToMemrefOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
@ -82,10 +86,10 @@ public:
};
} // namespace
void mlir::populateEliminateBufferizeMaterializationsPatterns(
void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<BufferizeTensorLoadOp, BufferizeCastOp>(typeConverter,
patterns.getContext());
patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
patterns.getContext());
}
namespace {
@ -121,6 +125,7 @@ struct FinalizingBufferizePass
};
} // namespace
std::unique_ptr<FunctionPass> mlir::createFinalizingBufferizePass() {
std::unique_ptr<FunctionPass>
mlir::bufferization::createFinalizingBufferizePass() {
return std::make_unique<FinalizingBufferizePass>();
}

View File

@ -0,0 +1,15 @@
add_mlir_dialect_library(MLIRBufferizationTransforms
Bufferize.cpp
BufferDeallocation.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization
DEPENDS
MLIRBufferizationPassIncGen
LINK_LIBS PUBLIC
MLIRBufferization
MLIRPass
MLIRTransforms
)

View File

@ -0,0 +1,31 @@
//===- PassDetail.h - Bufferization Pass details ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef DIALECT_BUFFERIZATION_TRANSFORMS_PASSDETAIL_H_
#define DIALECT_BUFFERIZATION_TRANSFORMS_PASSDETAIL_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
class StandardOpsDialect;
namespace bufferization {
class BufferizationDialect;
} // end namespace bufferization
namespace memref {
class MemRefDialect;
} // end namespace memref
#define GEN_PASS_CLASSES
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // end namespace mlir
#endif // DIALECT_BUFFERIZATION_TRANSFORMS_PASSDETAIL_H_

View File

@ -6,10 +6,11 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@ -313,7 +314,7 @@ struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
void runOnOperation() override {
MLIRContext &context = getContext();
ConversionTarget target(context);
BufferizeTypeConverter typeConverter;
bufferization::BufferizeTypeConverter typeConverter;
// Mark all Standard operations legal.
target.addLegalDialect<arith::ArithmeticDialect, AffineDialect,
@ -345,7 +346,8 @@ std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
}
void mlir::linalg::populateLinalgBufferizePatterns(
BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
bufferization::BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns) {
// TODO: Drop this once tensor constants work in standard.
// clang-format off
patterns.add<

View File

@ -28,6 +28,7 @@
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
using namespace mlir::linalg;

View File

@ -24,7 +24,6 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/Bufferize.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ArrayRef.h"

View File

@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@ -25,11 +25,11 @@ struct SCFBufferizePass : public SCFBufferizeBase<SCFBufferizePass> {
auto func = getOperation();
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
bufferization::BufferizeTypeConverter typeConverter;
RewritePatternSet patterns(context);
ConversionTarget target(*context);
populateBufferizeMaterializationLegality(target);
bufferization::populateBufferizeMaterializationLegality(target);
populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns,
target);
if (failed(applyPartialConversion(func, target, std::move(patterns))))

View File

@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
LINK_LIBS PUBLIC
MLIRAffine
MLIRArithmetic
MLIRBufferizationTransforms
MLIRIR
MLIRMemRef
MLIRPass

View File

@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@ -21,10 +21,10 @@ struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
MLIRContext &ctx = getContext();
RewritePatternSet patterns(&ctx);
BufferizeTypeConverter typeConverter;
bufferization::BufferizeTypeConverter typeConverter;
ConversionTarget target(ctx);
populateBufferizeMaterializationLegality(target);
bufferization::populateBufferizeMaterializationLegality(target);
populateShapeStructuralTypeConversionsAndLegality(typeConverter, patterns,
target);

View File

@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRShapeOpsTransforms
target_link_libraries(MLIRShapeOpsTransforms
PUBLIC
MLIRArithmetic
MLIRBufferizationTransforms
MLIRIR
MLIRMemRef
MLIRPass

View File

@ -10,7 +10,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@ -41,8 +41,9 @@ public:
};
} // namespace
void mlir::populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns) {
void mlir::populateStdBufferizePatterns(
bufferization::BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<BufferizeSelectOp>(typeConverter, patterns.getContext());
}
@ -50,7 +51,7 @@ namespace {
struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
void runOnFunction() override {
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
bufferization::BufferizeTypeConverter typeConverter;
RewritePatternSet patterns(context);
ConversionTarget target(*context);

View File

@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRStandardOpsTransforms
LINK_LIBS PUBLIC
MLIRArithmeticTransforms
MLIRBufferizationTransforms
MLIRIR
MLIRMemRef
MLIRPass

View File

@ -13,13 +13,14 @@
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;

View File

@ -12,10 +12,10 @@
#include "PassDetail.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Transforms/Bufferize.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
@ -27,7 +27,7 @@ struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
auto module = getOperation();
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
bufferization::BufferizeTypeConverter typeConverter;
RewritePatternSet patterns(context);
ConversionTarget target(*context);

View File

@ -12,12 +12,12 @@
#include "PassDetail.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/BufferUtils.h"
#include "mlir/Transforms/Bufferize.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
@ -25,7 +25,7 @@ using namespace mlir;
memref::GlobalOp GlobalCreator::getGlobalFor(arith::ConstantOp constantOp) {
auto type = constantOp.getType().cast<RankedTensorType>();
BufferizeTypeConverter typeConverter;
bufferization::BufferizeTypeConverter typeConverter;
// If we already have a global for this constant value, no need to do
// anything else.
@ -91,7 +91,8 @@ public:
} // namespace
void mlir::populateTensorConstantBufferizePatterns(
GlobalCreator &globalCreator, BufferizeTypeConverter &typeConverter,
GlobalCreator &globalCreator,
bufferization::BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<BufferizeTensorConstantOp>(globalCreator, typeConverter,
patterns.getContext());
@ -111,7 +112,7 @@ public:
GlobalCreator globals(module, alignment);
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
bufferization::BufferizeTypeConverter typeConverter;
RewritePatternSet patterns(context);
ConversionTarget target(*context);

View File

@ -10,7 +10,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@ -153,7 +153,8 @@ public:
} // namespace
void mlir::populateTensorBufferizePatterns(
BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
bufferization::BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<BufferizeCastOp, BufferizeDimOp, BufferizeExtractOp,
BufferizeFromElementsOp, BufferizeGenerateOp>(
typeConverter, patterns.getContext());
@ -163,11 +164,11 @@ namespace {
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
void runOnFunction() override {
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
bufferization::BufferizeTypeConverter typeConverter;
RewritePatternSet patterns(context);
ConversionTarget target(*context);
populateBufferizeMaterializationLegality(target);
bufferization::populateBufferizeMaterializationLegality(target);
populateTensorBufferizePatterns(typeConverter, patterns);
target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,

View File

@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
LINK_LIBS PUBLIC
MLIRArithmetic
MLIRBufferizationTransforms
MLIRIR
MLIRMemRef
MLIRPass

View File

@ -1,11 +1,9 @@
add_subdirectory(Utils)
add_mlir_library(MLIRTransforms
BufferDeallocation.cpp
BufferOptimizations.cpp
BufferResultsToOutParams.cpp
BufferUtils.cpp
Bufferize.cpp
Canonicalizer.cpp
CSE.cpp
Inliner.cpp

View File

@ -27,6 +27,10 @@ namespace memref {
class MemRefDialect;
} // end namespace memref
namespace bufferization {
class BufferizationDialect;
} // namespace bufferization
#define GEN_PASS_CLASSES
#include "mlir/Transforms/Passes.h.inc"

View File

@ -1664,6 +1664,7 @@ cc_library(
":Analysis",
":ArithmeticDialect",
":BufferizationDialect",
":BufferizationTransforms",
":DialectUtils",
":IR",
":MemRefDialect",
@ -2499,6 +2500,7 @@ cc_library(
deps = [
":ArithmeticDialect",
":BufferizationDialect",
":BufferizationTransforms",
":IR",
":MemRefDialect",
":Pass",
@ -2567,6 +2569,7 @@ cc_library(
":ArithmeticDialect",
":ArithmeticTransforms",
":BufferizationDialect",
":BufferizationTransforms",
":IR",
":MemRefDialect", # TODO: Remove dependency on MemRef dialect
":Pass",
@ -4192,6 +4195,7 @@ cc_library(
":ArithmeticDialect",
":Async",
":BufferizationDialect",
":BufferizationTransforms",
":IR",
":MemRefDialect",
":ParallelLoopMapperAttrGen",
@ -5368,6 +5372,7 @@ cc_library(
":AsyncToLLVM",
":AsyncTransforms",
":BufferizationDialect",
":BufferizationTransforms",
":ComplexDialect",
":ComplexToLLVM",
":ConversionPasses",
@ -6634,6 +6639,7 @@ cc_library(
":ArithmeticDialect",
":BufferizableOpInterface",
":BufferizationDialect",
":BufferizationTransforms",
":ComplexDialect",
":ComprehensiveBufferize",
":DialectUtils",
@ -7327,9 +7333,11 @@ cc_library(
":ArithmeticDialect",
":ArithmeticPassIncGen",
":BufferizationDialect",
":BufferizationTransforms",
":IR",
":MemRefDialect",
":Pass",
":StandardOps",
":Transforms",
],
)
@ -7716,6 +7724,46 @@ cc_library(
],
)
gentbl_cc_library(
name = "BufferizationPassIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=Bufferization",
],
"include/mlir/Dialect/Bufferization/Transforms/Passes.h.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "include/mlir/Dialect/Bufferization/Transforms/Passes.td",
deps = [":PassBaseTdFiles"],
)
cc_library(
name = "BufferizationTransforms",
srcs = glob(
[
"lib/Dialect/Bufferization/Transforms/*.cpp",
"lib/Dialect/Bufferization/Transforms/*.h",
],
),
hdrs = glob(["include/mlir/Dialect/Bufferization/Transforms/*.h"]),
includes = ["include"],
deps = [
":AllocationOpInterface",
":Analysis",
":BufferizationDialect",
":BufferizationPassIncGen",
":IR",
":MemRefDialect",
":Pass",
":Transforms",
"//llvm:Support",
],
)
td_library(
name = "DLTIDialectTdFiles",
srcs = [