From cd0a923b4c0c96d687303848dc1aaf39b5fe985f Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Wed, 19 Jan 2022 22:19:31 +0900 Subject: [PATCH] [mlir][linalg][bufferize][NFC] Move analysis-related code to Comprehensive Bufferize The code in `BufferizableOpInterface`'s header/source no longer contains any analysis code. This makes it easier to run the bufferization with a different analysis or without any analysis. Differential Revision: https://reviews.llvm.org/D117478 --- .../BufferizableOpInterface.h | 151 ------------------ .../ComprehensiveBufferize.h | 151 ++++++++++++++++++ .../LinalgInterfaceImpl.h | 2 +- .../ComprehensiveBufferize/SCFInterfaceImpl.h | 2 +- .../BufferizableOpInterface.cpp | 121 -------------- .../ComprehensiveBufferize/CMakeLists.txt | 2 + .../ComprehensiveBufferize.cpp | 123 ++++++++++++++ .../llvm-project-overlay/mlir/BUILD.bazel | 2 + 8 files changed, 280 insertions(+), 274 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h index 86a58001ac73..5abec82efe71 100644 --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -18,7 +18,6 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LLVM.h" -#include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/SetVector.h" namespace mlir { @@ -36,23 +35,6 @@ class BufferizationAliasInfo; class BufferizableOpInterface; struct BufferizationOptions; class BufferizationState; -struct PostAnalysisStep; - -/// PostAnalysisSteps can be registered with `BufferizationOptions` and are -/// executed after the analysis, but before bufferization. They can be used to -/// implement custom dialect-specific optimizations. -struct PostAnalysisStep { - virtual ~PostAnalysisStep() = default; - - /// Run the post analysis step. This function may modify the IR, but must keep - /// `aliasInfo` consistent. Newly created operations and operations that - /// should be re-analyzed must be added to `newOps`. - virtual LogicalResult run(Operation *op, BufferizationState &state, - BufferizationAliasInfo &aliasInfo, - SmallVector &newOps) = 0; -}; - -using PostAnalysisStepList = std::vector>; /// Options for ComprehensiveBufferize. struct BufferizationOptions { @@ -146,25 +128,6 @@ private: } }; -/// Options for analysis-enabled bufferization. -struct AnalysisBufferizationOptions : public BufferizationOptions { - AnalysisBufferizationOptions() = default; - - // AnalysisBufferizationOptions cannot be copied. - AnalysisBufferizationOptions(const AnalysisBufferizationOptions &) = delete; - - /// Register a "post analysis" step. Such steps are executed after the - /// analysis, but before bufferization. - template - void addPostAnalysisStep(Args... args) { - postAnalysisSteps.emplace_back( - std::make_unique(std::forward(args)...)); - } - - /// Registered post analysis steps. - PostAnalysisStepList postAnalysisSteps; -}; - /// Specify fine-grain relationship between buffers to enable more analysis. enum class BufferRelation { None, @@ -173,93 +136,6 @@ enum class BufferRelation { Equivalent }; -/// The BufferizationAliasInfo class maintains a list of buffer aliases and -/// equivalence classes to support bufferization. -class BufferizationAliasInfo { -public: - explicit BufferizationAliasInfo(Operation *rootOp); - - // BufferizationAliasInfo should be passed as a reference. - BufferizationAliasInfo(const BufferizationAliasInfo &) = delete; - - /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the - /// beginning the alias and equivalence sets only contain `v` itself. - void createAliasInfoEntry(Value v); - - /// Insert an info entry for `newValue` and merge its alias set with that of - /// `alias`. - void insertNewBufferAlias(Value newValue, Value alias); - - /// Insert an info entry for `newValue` and merge its alias set with that of - /// `alias`. Additionally, merge their equivalence classes. - void insertNewBufferEquivalence(Value newValue, Value alias); - - /// Set the inPlace bufferization spec to true. - /// Merge result's and operand's aliasing sets and iterate to a fixed point. - void bufferizeInPlace(OpOperand &operand, BufferizationState &state); - - /// Set the inPlace bufferization spec to false. - void bufferizeOutOfPlace(OpOperand &operand); - - /// Return true if `v1` and `v2` bufferize to equivalent buffers. - bool areEquivalentBufferizedValues(Value v1, Value v2) const { - return equivalentInfo.isEquivalent(v1, v2); - } - - /// Union the alias sets of `v1` and `v2`. - void unionAliasSets(Value v1, Value v2) { aliasInfo.unionSets(v1, v2); } - - /// Union the equivalence classes of `v1` and `v2`. - void unionEquivalenceClasses(Value v1, Value v2) { - equivalentInfo.unionSets(v1, v2); - } - - /// Apply `fun` to all the members of the equivalence class of `v`. - void applyOnEquivalenceClass(Value v, function_ref fun) const; - - /// Apply `fun` to all aliases of `v`. - void applyOnAliases(Value v, function_ref fun) const; - - /// Mark a value as in-place bufferized. - void markInPlace(OpOperand &o) { inplaceBufferized.insert(&o); } - - /// Return `true` if a value was marked as in-place bufferized. - bool isInPlace(OpOperand &opOperand) const; - -private: - /// llvm::EquivalenceClasses wants comparable elements. This comparator uses - /// uses pointer comparison on the defining op. This is a poor man's - /// comparison but it's not like UnionFind needs ordering anyway. - struct ValueComparator { - bool operator()(const Value &lhs, const Value &rhs) const { - return lhs.getImpl() < rhs.getImpl(); - } - }; - - using EquivalenceClassRangeType = llvm::iterator_range< - llvm::EquivalenceClasses::member_iterator>; - /// Check that aliasInfo for `v` exists and return a reference to it. - EquivalenceClassRangeType getAliases(Value v) const; - - /// Set of all OpResults that were decided to bufferize in-place. - llvm::DenseSet inplaceBufferized; - - /// Auxiliary structure to store all the values a given value may alias with. - /// Alias information is "may be" conservative: In the presence of branches, a - /// value may alias with one of multiple other values. The concrete aliasing - /// value may not even be known at compile time. All such values are - /// considered to be aliases. - llvm::EquivalenceClasses aliasInfo; - - /// Auxiliary structure to store all the equivalent buffer classes. Equivalent - /// buffer information is "must be" conservative: Only if two values are - /// guaranteed to be equivalent at runtime, they said to be equivalent. It is - /// possible that, in the presence of branches, it cannot be determined - /// statically if two values are equivalent. In that case, the values are - /// considered to be not equivalent. - llvm::EquivalenceClasses equivalentInfo; -}; - /// Return `true` if the given value is a BlockArgument of a FuncOp. bool isFunctionArgument(Value value); @@ -391,33 +267,6 @@ private: const BufferizationOptions &options; }; -/// State for analysis-enabled bufferization. This class keeps track of alias -/// (via BufferizationAliasInfo) to decide if tensor OpOperands should bufferize -/// in-place. -class AnalysisBufferizationState : public BufferizationState { -public: - AnalysisBufferizationState(Operation *op, - const AnalysisBufferizationOptions &options); - - AnalysisBufferizationState(const AnalysisBufferizationState &) = delete; - - virtual ~AnalysisBufferizationState() = default; - - /// Return a reference to the BufferizationAliasInfo. - BufferizationAliasInfo &getAliasInfo() { return aliasInfo; } - - /// Return `true` if the given OpResult has been decided to bufferize inplace. - bool isInPlace(OpOperand &opOperand) const override; - - /// Return true if `v1` and `v2` bufferize to equivalent buffers. - bool areEquivalentBufferizedValues(Value v1, Value v2) const override; - -private: - /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal - /// functions and `runComprehensiveBufferize` may access this object. - BufferizationAliasInfo aliasInfo; -}; - /// Replace an op with replacement values. The op is deleted. Tensor OpResults /// must be replaced with memref values. void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h index 6a53295babcd..468f1d638220 100644 --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h @@ -9,7 +9,9 @@ #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVEBUFFERIZE_H #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVEBUFFERIZE_H +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" #include "mlir/IR/BuiltinOps.h" +#include "llvm/ADT/EquivalenceClasses.h" namespace mlir { @@ -21,6 +23,155 @@ class BufferizationAliasInfo; struct AnalysisBufferizationOptions; class BufferizationState; +/// PostAnalysisSteps can be registered with `BufferizationOptions` and are +/// executed after the analysis, but before bufferization. They can be used to +/// implement custom dialect-specific optimizations. +struct PostAnalysisStep { + virtual ~PostAnalysisStep() = default; + + /// Run the post analysis step. This function may modify the IR, but must keep + /// `aliasInfo` consistent. Newly created operations and operations that + /// should be re-analyzed must be added to `newOps`. + virtual LogicalResult run(Operation *op, BufferizationState &state, + BufferizationAliasInfo &aliasInfo, + SmallVector &newOps) = 0; +}; + +using PostAnalysisStepList = std::vector>; + +/// Options for analysis-enabled bufferization. +struct AnalysisBufferizationOptions : public BufferizationOptions { + AnalysisBufferizationOptions() = default; + + // AnalysisBufferizationOptions cannot be copied. + AnalysisBufferizationOptions(const AnalysisBufferizationOptions &) = delete; + + /// Register a "post analysis" step. Such steps are executed after the + /// analysis, but before bufferization. + template + void addPostAnalysisStep(Args... args) { + postAnalysisSteps.emplace_back( + std::make_unique(std::forward(args)...)); + } + + /// Registered post analysis steps. + PostAnalysisStepList postAnalysisSteps; +}; + +/// The BufferizationAliasInfo class maintains a list of buffer aliases and +/// equivalence classes to support bufferization. +class BufferizationAliasInfo { +public: + explicit BufferizationAliasInfo(Operation *rootOp); + + // BufferizationAliasInfo should be passed as a reference. + BufferizationAliasInfo(const BufferizationAliasInfo &) = delete; + + /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the + /// beginning the alias and equivalence sets only contain `v` itself. + void createAliasInfoEntry(Value v); + + /// Insert an info entry for `newValue` and merge its alias set with that of + /// `alias`. + void insertNewBufferAlias(Value newValue, Value alias); + + /// Insert an info entry for `newValue` and merge its alias set with that of + /// `alias`. Additionally, merge their equivalence classes. + void insertNewBufferEquivalence(Value newValue, Value alias); + + /// Set the inPlace bufferization spec to true. + /// Merge result's and operand's aliasing sets and iterate to a fixed point. + void bufferizeInPlace(OpOperand &operand, BufferizationState &state); + + /// Set the inPlace bufferization spec to false. + void bufferizeOutOfPlace(OpOperand &operand); + + /// Return true if `v1` and `v2` bufferize to equivalent buffers. + bool areEquivalentBufferizedValues(Value v1, Value v2) const { + return equivalentInfo.isEquivalent(v1, v2); + } + + /// Union the alias sets of `v1` and `v2`. + void unionAliasSets(Value v1, Value v2) { aliasInfo.unionSets(v1, v2); } + + /// Union the equivalence classes of `v1` and `v2`. + void unionEquivalenceClasses(Value v1, Value v2) { + equivalentInfo.unionSets(v1, v2); + } + + /// Apply `fun` to all the members of the equivalence class of `v`. + void applyOnEquivalenceClass(Value v, function_ref fun) const; + + /// Apply `fun` to all aliases of `v`. + void applyOnAliases(Value v, function_ref fun) const; + + /// Mark a value as in-place bufferized. + void markInPlace(OpOperand &o) { inplaceBufferized.insert(&o); } + + /// Return `true` if a value was marked as in-place bufferized. + bool isInPlace(OpOperand &opOperand) const; + +private: + /// llvm::EquivalenceClasses wants comparable elements. This comparator uses + /// uses pointer comparison on the defining op. This is a poor man's + /// comparison but it's not like UnionFind needs ordering anyway. + struct ValueComparator { + bool operator()(const Value &lhs, const Value &rhs) const { + return lhs.getImpl() < rhs.getImpl(); + } + }; + + using EquivalenceClassRangeType = llvm::iterator_range< + llvm::EquivalenceClasses::member_iterator>; + /// Check that aliasInfo for `v` exists and return a reference to it. + EquivalenceClassRangeType getAliases(Value v) const; + + /// Set of all OpResults that were decided to bufferize in-place. + llvm::DenseSet inplaceBufferized; + + /// Auxiliary structure to store all the values a given value may alias with. + /// Alias information is "may be" conservative: In the presence of branches, a + /// value may alias with one of multiple other values. The concrete aliasing + /// value may not even be known at compile time. All such values are + /// considered to be aliases. + llvm::EquivalenceClasses aliasInfo; + + /// Auxiliary structure to store all the equivalent buffer classes. Equivalent + /// buffer information is "must be" conservative: Only if two values are + /// guaranteed to be equivalent at runtime, they said to be equivalent. It is + /// possible that, in the presence of branches, it cannot be determined + /// statically if two values are equivalent. In that case, the values are + /// considered to be not equivalent. + llvm::EquivalenceClasses equivalentInfo; +}; + +/// State for analysis-enabled bufferization. This class keeps track of alias +/// (via BufferizationAliasInfo) to decide if tensor OpOperands should bufferize +/// in-place. +class AnalysisBufferizationState : public BufferizationState { +public: + AnalysisBufferizationState(Operation *op, + const AnalysisBufferizationOptions &options); + + AnalysisBufferizationState(const AnalysisBufferizationState &) = delete; + + virtual ~AnalysisBufferizationState() = default; + + /// Return a reference to the BufferizationAliasInfo. + BufferizationAliasInfo &getAliasInfo() { return aliasInfo; } + + /// Return `true` if the given OpResult has been decided to bufferize inplace. + bool isInPlace(OpOperand &opOperand) const override; + + /// Return true if `v1` and `v2` bufferize to equivalent buffers. + bool areEquivalentBufferizedValues(Value v1, Value v2) const override; + +private: + /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal + /// functions and `runComprehensiveBufferize` may access this object. + BufferizationAliasInfo aliasInfo; +}; + /// Analyze `op` and its nested ops. Bufferization decisions are stored in /// `state`. LogicalResult analyzeOp(Operation *op, AnalysisBufferizationState &state); diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h index 8c0128b70c96..8614c9d50acf 100644 --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h @@ -9,7 +9,7 @@ #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALGINTERFACEIMPL_H #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALGINTERFACEIMPL_H -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h" namespace mlir { diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h index f86550e35901..a2ba910aeac9 100644 --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h @@ -9,7 +9,7 @@ #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h" namespace mlir { diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp index be9e919fbb62..57bae783f9d7 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -57,95 +57,6 @@ BufferizableOpInterface mlir::linalg::comprehensive_bufferize:: return nullptr; } -//===----------------------------------------------------------------------===// -// BufferizationAliasInfo -//===----------------------------------------------------------------------===// - -BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) { - rootOp->walk([&](Operation *op) { - for (Value v : op->getResults()) - if (v.getType().isa()) - createAliasInfoEntry(v); - for (Region &r : op->getRegions()) - for (Block &b : r.getBlocks()) - for (auto bbArg : b.getArguments()) - if (bbArg.getType().isa()) - createAliasInfoEntry(bbArg); - }); -} - -/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the -/// beginning the alias and equivalence sets only contain `v` itself. -void BufferizationAliasInfo::createAliasInfoEntry(Value v) { - aliasInfo.insert(v); - equivalentInfo.insert(v); -} - -/// Insert an info entry for `newValue` and merge its alias set with that of -/// `alias`. -void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) { - createAliasInfoEntry(newValue); - aliasInfo.unionSets(newValue, alias); -} - -/// Insert an info entry for `newValue` and merge its alias set with that of -/// `alias`. Additionally, merge their equivalence classes. -void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue, - Value alias) { - insertNewBufferAlias(newValue, alias); - equivalentInfo.unionSets(newValue, alias); -} - -/// Return `true` if a value was marked as in-place bufferized. -bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const { - return inplaceBufferized.contains(&operand); -} - -/// Set the inPlace bufferization spec to true. -void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand, - BufferizationState &state) { - markInPlace(operand); - if (OpResult result = state.getAliasingOpResult(operand)) - aliasInfo.unionSets(result, operand.get()); -} - -/// Set the inPlace bufferization spec to false. -void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) { - assert(!inplaceBufferized.contains(&operand) && - "OpOperand was already decided to bufferize inplace"); -} - -/// Apply `fun` to all the members of the equivalence class of `v`. -void BufferizationAliasInfo::applyOnEquivalenceClass( - Value v, function_ref fun) const { - auto leaderIt = equivalentInfo.findLeader(v); - for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit; - ++mit) { - fun(*mit); - } -} - -/// Apply `fun` to all aliases of `v`. -void BufferizationAliasInfo::applyOnAliases( - Value v, function_ref fun) const { - auto leaderIt = aliasInfo.findLeader(v); - for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) { - fun(*mit); - } -} - -BufferizationAliasInfo::EquivalenceClassRangeType -BufferizationAliasInfo::getAliases(Value v) const { - DenseSet res; - auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v)); - for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end(); - mit != meit; ++mit) { - res.insert(static_cast(*mit)); - } - return BufferizationAliasInfo::EquivalenceClassRangeType( - aliasInfo.member_begin(it), aliasInfo.member_end()); -} - //===----------------------------------------------------------------------===// // Helper functions for BufferizableOpInterface //===----------------------------------------------------------------------===// @@ -291,28 +202,6 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState( const BufferizationOptions &options) : options(options) {} -mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState:: - AnalysisBufferizationState(Operation *op, - const AnalysisBufferizationOptions &options) - : BufferizationState(options), aliasInfo(op) { - // Set up alias sets for OpResults that must bufferize in-place. This should - // be done before making any other bufferization decisions. - op->walk([&](BufferizableOpInterface bufferizableOp) { - if (!options.isOpAllowed(bufferizableOp)) - return WalkResult::skip(); - for (OpOperand &opOperand : bufferizableOp->getOpOperands()) { - if (opOperand.get().getType().isa()) - if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) { - if (OpResult opResult = - bufferizableOp.getAliasingOpResult(opOperand, *this)) - aliasInfo.unionAliasSets(opOperand.get(), opResult); - aliasInfo.markInPlace(opOperand); - } - } - return WalkResult::advance(); - }); -} - // bufferization.to_memref is not allowed to change the rank. static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) { #ifndef NDEBUG @@ -602,16 +491,6 @@ bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) { return isa(bbArg.getOwner()->getParentOp()); } -bool mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState:: - isInPlace(OpOperand &opOperand) const { - return aliasInfo.isInPlace(opOperand); -} - -bool mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState:: - areEquivalentBufferizedValues(Value v1, Value v2) const { - return aliasInfo.areEquivalentBufferizedValues(v1, v2); -} - MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType( ShapedType shapedType, MemRefLayoutAttrInterface layout, Attribute memorySpace) { diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt index 1b1467ec36a6..a912d2378dc2 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt @@ -48,6 +48,7 @@ add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl LINK_LIBS PUBLIC MLIRBufferizableOpInterface + MLIRComprehensiveBufferize MLIRIR MLIRLinalg MLIRTensor @@ -58,6 +59,7 @@ add_mlir_dialect_library(MLIRSCFBufferizableOpInterfaceImpl LINK_LIBS PUBLIC MLIRBufferizableOpInterface + MLIRComprehensiveBufferize MLIRIR MLIRSCF ) diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp index 67fd48364166..ae9532e25dc3 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -98,6 +98,129 @@ static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) { OpBuilder(op).getStrArrayAttr(inPlaceVector)); } +//===----------------------------------------------------------------------===// +// BufferizationAliasInfo +//===----------------------------------------------------------------------===// + +BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) { + rootOp->walk([&](Operation *op) { + for (Value v : op->getResults()) + if (v.getType().isa()) + createAliasInfoEntry(v); + for (Region &r : op->getRegions()) + for (Block &b : r.getBlocks()) + for (auto bbArg : b.getArguments()) + if (bbArg.getType().isa()) + createAliasInfoEntry(bbArg); + }); +} + +/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the +/// beginning the alias and equivalence sets only contain `v` itself. +void BufferizationAliasInfo::createAliasInfoEntry(Value v) { + aliasInfo.insert(v); + equivalentInfo.insert(v); +} + +/// Insert an info entry for `newValue` and merge its alias set with that of +/// `alias`. +void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) { + createAliasInfoEntry(newValue); + aliasInfo.unionSets(newValue, alias); +} + +/// Insert an info entry for `newValue` and merge its alias set with that of +/// `alias`. Additionally, merge their equivalence classes. +void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue, + Value alias) { + insertNewBufferAlias(newValue, alias); + equivalentInfo.unionSets(newValue, alias); +} + +/// Return `true` if a value was marked as in-place bufferized. +bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const { + return inplaceBufferized.contains(&operand); +} + +/// Set the inPlace bufferization spec to true. +void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand, + BufferizationState &state) { + markInPlace(operand); + if (OpResult result = state.getAliasingOpResult(operand)) + aliasInfo.unionSets(result, operand.get()); +} + +/// Set the inPlace bufferization spec to false. +void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) { + assert(!inplaceBufferized.contains(&operand) && + "OpOperand was already decided to bufferize inplace"); +} + +/// Apply `fun` to all the members of the equivalence class of `v`. +void BufferizationAliasInfo::applyOnEquivalenceClass( + Value v, function_ref fun) const { + auto leaderIt = equivalentInfo.findLeader(v); + for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit; + ++mit) { + fun(*mit); + } +} + +/// Apply `fun` to all aliases of `v`. +void BufferizationAliasInfo::applyOnAliases( + Value v, function_ref fun) const { + auto leaderIt = aliasInfo.findLeader(v); + for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) { + fun(*mit); + } +} + +BufferizationAliasInfo::EquivalenceClassRangeType +BufferizationAliasInfo::getAliases(Value v) const { + DenseSet res; + auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v)); + for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end(); + mit != meit; ++mit) { + res.insert(static_cast(*mit)); + } + return BufferizationAliasInfo::EquivalenceClassRangeType( + aliasInfo.member_begin(it), aliasInfo.member_end()); +} + +//===----------------------------------------------------------------------===// +// AnalysisBufferizationState +//===----------------------------------------------------------------------===// + +AnalysisBufferizationState::AnalysisBufferizationState( + Operation *op, const AnalysisBufferizationOptions &options) + : BufferizationState(options), aliasInfo(op) { + // Set up alias sets for OpResults that must bufferize in-place. This should + // be done before making any other bufferization decisions. + op->walk([&](BufferizableOpInterface bufferizableOp) { + if (!options.isOpAllowed(bufferizableOp)) + return WalkResult::skip(); + for (OpOperand &opOperand : bufferizableOp->getOpOperands()) { + if (opOperand.get().getType().isa()) + if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) { + if (OpResult opResult = + bufferizableOp.getAliasingOpResult(opOperand, *this)) + aliasInfo.unionAliasSets(opOperand.get(), opResult); + aliasInfo.markInPlace(opOperand); + } + } + return WalkResult::advance(); + }); +} + +bool AnalysisBufferizationState::isInPlace(OpOperand &opOperand) const { + return aliasInfo.isInPlace(opOperand); +} + +bool AnalysisBufferizationState::areEquivalentBufferizedValues(Value v1, + Value v2) const { + return aliasInfo.areEquivalentBufferizedValues(v1, v2); +} + //===----------------------------------------------------------------------===// // Bufferization-specific alias analysis. //===----------------------------------------------------------------------===// diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index eab36aacc0c6..9b67e9a819d1 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6601,6 +6601,7 @@ cc_library( deps = [ ":BufferizableOpInterface", ":BufferizationDialect", + ":ComprehensiveBufferize", ":IR", ":LinalgOps", ":LinalgStructuredOpsIncGen", @@ -6620,6 +6621,7 @@ cc_library( deps = [ ":BufferizableOpInterface", ":BufferizationDialect", + ":ComprehensiveBufferize", ":IR", ":SCFDialect", ":Support",