[mlir][linalg][bufferize][NFC] Move analysis-related code to Comprehensive Bufferize

The code in `BufferizableOpInterface`'s header/source no longer contains any analysis code. This makes it easier to run the bufferization with a different analysis or without any analysis.

Differential Revision: https://reviews.llvm.org/D117478
This commit is contained in:
Matthias Springer 2022-01-19 22:19:31 +09:00
parent 5ea98988c6
commit cd0a923b4c
8 changed files with 280 additions and 274 deletions

View File

@ -18,7 +18,6 @@
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SetVector.h"
namespace mlir { namespace mlir {
@ -36,23 +35,6 @@ class BufferizationAliasInfo;
class BufferizableOpInterface; class BufferizableOpInterface;
struct BufferizationOptions; struct BufferizationOptions;
class BufferizationState; class BufferizationState;
struct PostAnalysisStep;
/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
/// executed after the analysis, but before bufferization. They can be used to
/// implement custom dialect-specific optimizations.
struct PostAnalysisStep {
virtual ~PostAnalysisStep() = default;
/// Run the post analysis step. This function may modify the IR, but must keep
/// `aliasInfo` consistent. Newly created operations and operations that
/// should be re-analyzed must be added to `newOps`.
virtual LogicalResult run(Operation *op, BufferizationState &state,
BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) = 0;
};
using PostAnalysisStepList = std::vector<std::unique_ptr<PostAnalysisStep>>;
/// Options for ComprehensiveBufferize. /// Options for ComprehensiveBufferize.
struct BufferizationOptions { struct BufferizationOptions {
@ -146,25 +128,6 @@ private:
} }
}; };
/// Options for analysis-enabled bufferization.
struct AnalysisBufferizationOptions : public BufferizationOptions {
AnalysisBufferizationOptions() = default;
// AnalysisBufferizationOptions cannot be copied.
AnalysisBufferizationOptions(const AnalysisBufferizationOptions &) = delete;
/// Register a "post analysis" step. Such steps are executed after the
/// analysis, but before bufferization.
template <typename Step, typename... Args>
void addPostAnalysisStep(Args... args) {
postAnalysisSteps.emplace_back(
std::make_unique<Step>(std::forward<Args>(args)...));
}
/// Registered post analysis steps.
PostAnalysisStepList postAnalysisSteps;
};
/// Specify fine-grain relationship between buffers to enable more analysis. /// Specify fine-grain relationship between buffers to enable more analysis.
enum class BufferRelation { enum class BufferRelation {
None, None,
@ -173,93 +136,6 @@ enum class BufferRelation {
Equivalent Equivalent
}; };
/// The BufferizationAliasInfo class maintains a list of buffer aliases and
/// equivalence classes to support bufferization.
class BufferizationAliasInfo {
public:
explicit BufferizationAliasInfo(Operation *rootOp);
// BufferizationAliasInfo should be passed as a reference.
BufferizationAliasInfo(const BufferizationAliasInfo &) = delete;
/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
/// beginning the alias and equivalence sets only contain `v` itself.
void createAliasInfoEntry(Value v);
/// Insert an info entry for `newValue` and merge its alias set with that of
/// `alias`.
void insertNewBufferAlias(Value newValue, Value alias);
/// Insert an info entry for `newValue` and merge its alias set with that of
/// `alias`. Additionally, merge their equivalence classes.
void insertNewBufferEquivalence(Value newValue, Value alias);
/// Set the inPlace bufferization spec to true.
/// Merge result's and operand's aliasing sets and iterate to a fixed point.
void bufferizeInPlace(OpOperand &operand, BufferizationState &state);
/// Set the inPlace bufferization spec to false.
void bufferizeOutOfPlace(OpOperand &operand);
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
bool areEquivalentBufferizedValues(Value v1, Value v2) const {
return equivalentInfo.isEquivalent(v1, v2);
}
/// Union the alias sets of `v1` and `v2`.
void unionAliasSets(Value v1, Value v2) { aliasInfo.unionSets(v1, v2); }
/// Union the equivalence classes of `v1` and `v2`.
void unionEquivalenceClasses(Value v1, Value v2) {
equivalentInfo.unionSets(v1, v2);
}
/// Apply `fun` to all the members of the equivalence class of `v`.
void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
/// Apply `fun` to all aliases of `v`.
void applyOnAliases(Value v, function_ref<void(Value)> fun) const;
/// Mark a value as in-place bufferized.
void markInPlace(OpOperand &o) { inplaceBufferized.insert(&o); }
/// Return `true` if a value was marked as in-place bufferized.
bool isInPlace(OpOperand &opOperand) const;
private:
/// llvm::EquivalenceClasses wants comparable elements. This comparator uses
/// uses pointer comparison on the defining op. This is a poor man's
/// comparison but it's not like UnionFind needs ordering anyway.
struct ValueComparator {
bool operator()(const Value &lhs, const Value &rhs) const {
return lhs.getImpl() < rhs.getImpl();
}
};
using EquivalenceClassRangeType = llvm::iterator_range<
llvm::EquivalenceClasses<Value, ValueComparator>::member_iterator>;
/// Check that aliasInfo for `v` exists and return a reference to it.
EquivalenceClassRangeType getAliases(Value v) const;
/// Set of all OpResults that were decided to bufferize in-place.
llvm::DenseSet<OpOperand *> inplaceBufferized;
/// Auxiliary structure to store all the values a given value may alias with.
/// Alias information is "may be" conservative: In the presence of branches, a
/// value may alias with one of multiple other values. The concrete aliasing
/// value may not even be known at compile time. All such values are
/// considered to be aliases.
llvm::EquivalenceClasses<Value, ValueComparator> aliasInfo;
/// Auxiliary structure to store all the equivalent buffer classes. Equivalent
/// buffer information is "must be" conservative: Only if two values are
/// guaranteed to be equivalent at runtime, they said to be equivalent. It is
/// possible that, in the presence of branches, it cannot be determined
/// statically if two values are equivalent. In that case, the values are
/// considered to be not equivalent.
llvm::EquivalenceClasses<Value, ValueComparator> equivalentInfo;
};
/// Return `true` if the given value is a BlockArgument of a FuncOp. /// Return `true` if the given value is a BlockArgument of a FuncOp.
bool isFunctionArgument(Value value); bool isFunctionArgument(Value value);
@ -391,33 +267,6 @@ private:
const BufferizationOptions &options; const BufferizationOptions &options;
}; };
/// State for analysis-enabled bufferization. This class keeps track of alias
/// (via BufferizationAliasInfo) to decide if tensor OpOperands should bufferize
/// in-place.
class AnalysisBufferizationState : public BufferizationState {
public:
AnalysisBufferizationState(Operation *op,
const AnalysisBufferizationOptions &options);
AnalysisBufferizationState(const AnalysisBufferizationState &) = delete;
virtual ~AnalysisBufferizationState() = default;
/// Return a reference to the BufferizationAliasInfo.
BufferizationAliasInfo &getAliasInfo() { return aliasInfo; }
/// Return `true` if the given OpResult has been decided to bufferize inplace.
bool isInPlace(OpOperand &opOperand) const override;
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
private:
/// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
/// functions and `runComprehensiveBufferize` may access this object.
BufferizationAliasInfo aliasInfo;
};
/// Replace an op with replacement values. The op is deleted. Tensor OpResults /// Replace an op with replacement values. The op is deleted. Tensor OpResults
/// must be replaced with memref values. /// must be replaced with memref values.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,

View File

@ -9,7 +9,9 @@
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVEBUFFERIZE_H #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVEBUFFERIZE_H
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVEBUFFERIZE_H #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVEBUFFERIZE_H
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/EquivalenceClasses.h"
namespace mlir { namespace mlir {
@ -21,6 +23,155 @@ class BufferizationAliasInfo;
struct AnalysisBufferizationOptions; struct AnalysisBufferizationOptions;
class BufferizationState; class BufferizationState;
/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
/// executed after the analysis, but before bufferization. They can be used to
/// implement custom dialect-specific optimizations.
struct PostAnalysisStep {
virtual ~PostAnalysisStep() = default;
/// Run the post analysis step. This function may modify the IR, but must keep
/// `aliasInfo` consistent. Newly created operations and operations that
/// should be re-analyzed must be added to `newOps`.
virtual LogicalResult run(Operation *op, BufferizationState &state,
BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) = 0;
};
using PostAnalysisStepList = std::vector<std::unique_ptr<PostAnalysisStep>>;
/// Options for analysis-enabled bufferization.
struct AnalysisBufferizationOptions : public BufferizationOptions {
AnalysisBufferizationOptions() = default;
// AnalysisBufferizationOptions cannot be copied.
AnalysisBufferizationOptions(const AnalysisBufferizationOptions &) = delete;
/// Register a "post analysis" step. Such steps are executed after the
/// analysis, but before bufferization.
template <typename Step, typename... Args>
void addPostAnalysisStep(Args... args) {
postAnalysisSteps.emplace_back(
std::make_unique<Step>(std::forward<Args>(args)...));
}
/// Registered post analysis steps.
PostAnalysisStepList postAnalysisSteps;
};
/// The BufferizationAliasInfo class maintains a list of buffer aliases and
/// equivalence classes to support bufferization.
class BufferizationAliasInfo {
public:
explicit BufferizationAliasInfo(Operation *rootOp);
// BufferizationAliasInfo should be passed as a reference.
BufferizationAliasInfo(const BufferizationAliasInfo &) = delete;
/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
/// beginning the alias and equivalence sets only contain `v` itself.
void createAliasInfoEntry(Value v);
/// Insert an info entry for `newValue` and merge its alias set with that of
/// `alias`.
void insertNewBufferAlias(Value newValue, Value alias);
/// Insert an info entry for `newValue` and merge its alias set with that of
/// `alias`. Additionally, merge their equivalence classes.
void insertNewBufferEquivalence(Value newValue, Value alias);
/// Set the inPlace bufferization spec to true.
/// Merge result's and operand's aliasing sets and iterate to a fixed point.
void bufferizeInPlace(OpOperand &operand, BufferizationState &state);
/// Set the inPlace bufferization spec to false.
void bufferizeOutOfPlace(OpOperand &operand);
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
bool areEquivalentBufferizedValues(Value v1, Value v2) const {
return equivalentInfo.isEquivalent(v1, v2);
}
/// Union the alias sets of `v1` and `v2`.
void unionAliasSets(Value v1, Value v2) { aliasInfo.unionSets(v1, v2); }
/// Union the equivalence classes of `v1` and `v2`.
void unionEquivalenceClasses(Value v1, Value v2) {
equivalentInfo.unionSets(v1, v2);
}
/// Apply `fun` to all the members of the equivalence class of `v`.
void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
/// Apply `fun` to all aliases of `v`.
void applyOnAliases(Value v, function_ref<void(Value)> fun) const;
/// Mark a value as in-place bufferized.
void markInPlace(OpOperand &o) { inplaceBufferized.insert(&o); }
/// Return `true` if a value was marked as in-place bufferized.
bool isInPlace(OpOperand &opOperand) const;
private:
/// llvm::EquivalenceClasses wants comparable elements. This comparator uses
/// uses pointer comparison on the defining op. This is a poor man's
/// comparison but it's not like UnionFind needs ordering anyway.
struct ValueComparator {
bool operator()(const Value &lhs, const Value &rhs) const {
return lhs.getImpl() < rhs.getImpl();
}
};
using EquivalenceClassRangeType = llvm::iterator_range<
llvm::EquivalenceClasses<Value, ValueComparator>::member_iterator>;
/// Check that aliasInfo for `v` exists and return a reference to it.
EquivalenceClassRangeType getAliases(Value v) const;
/// Set of all OpResults that were decided to bufferize in-place.
llvm::DenseSet<OpOperand *> inplaceBufferized;
/// Auxiliary structure to store all the values a given value may alias with.
/// Alias information is "may be" conservative: In the presence of branches, a
/// value may alias with one of multiple other values. The concrete aliasing
/// value may not even be known at compile time. All such values are
/// considered to be aliases.
llvm::EquivalenceClasses<Value, ValueComparator> aliasInfo;
/// Auxiliary structure to store all the equivalent buffer classes. Equivalent
/// buffer information is "must be" conservative: Only if two values are
/// guaranteed to be equivalent at runtime, they said to be equivalent. It is
/// possible that, in the presence of branches, it cannot be determined
/// statically if two values are equivalent. In that case, the values are
/// considered to be not equivalent.
llvm::EquivalenceClasses<Value, ValueComparator> equivalentInfo;
};
/// State for analysis-enabled bufferization. This class keeps track of alias
/// (via BufferizationAliasInfo) to decide if tensor OpOperands should bufferize
/// in-place.
class AnalysisBufferizationState : public BufferizationState {
public:
AnalysisBufferizationState(Operation *op,
const AnalysisBufferizationOptions &options);
AnalysisBufferizationState(const AnalysisBufferizationState &) = delete;
virtual ~AnalysisBufferizationState() = default;
/// Return a reference to the BufferizationAliasInfo.
BufferizationAliasInfo &getAliasInfo() { return aliasInfo; }
/// Return `true` if the given OpResult has been decided to bufferize inplace.
bool isInPlace(OpOperand &opOperand) const override;
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
private:
/// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
/// functions and `runComprehensiveBufferize` may access this object.
BufferizationAliasInfo aliasInfo;
};
/// Analyze `op` and its nested ops. Bufferization decisions are stored in /// Analyze `op` and its nested ops. Bufferization decisions are stored in
/// `state`. /// `state`.
LogicalResult analyzeOp(Operation *op, AnalysisBufferizationState &state); LogicalResult analyzeOp(Operation *op, AnalysisBufferizationState &state);

View File

@ -9,7 +9,7 @@
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALGINTERFACEIMPL_H #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALGINTERFACEIMPL_H
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALGINTERFACEIMPL_H #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALGINTERFACEIMPL_H
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
namespace mlir { namespace mlir {

View File

@ -9,7 +9,7 @@
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
namespace mlir { namespace mlir {

View File

@ -57,95 +57,6 @@ BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
return nullptr; return nullptr;
} }
//===----------------------------------------------------------------------===//
// BufferizationAliasInfo
//===----------------------------------------------------------------------===//
BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
rootOp->walk([&](Operation *op) {
for (Value v : op->getResults())
if (v.getType().isa<TensorType>())
createAliasInfoEntry(v);
for (Region &r : op->getRegions())
for (Block &b : r.getBlocks())
for (auto bbArg : b.getArguments())
if (bbArg.getType().isa<TensorType>())
createAliasInfoEntry(bbArg);
});
}
/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
/// beginning the alias and equivalence sets only contain `v` itself.
void BufferizationAliasInfo::createAliasInfoEntry(Value v) {
aliasInfo.insert(v);
equivalentInfo.insert(v);
}
/// Insert an info entry for `newValue` and merge its alias set with that of
/// `alias`.
void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
createAliasInfoEntry(newValue);
aliasInfo.unionSets(newValue, alias);
}
/// Insert an info entry for `newValue` and merge its alias set with that of
/// `alias`. Additionally, merge their equivalence classes.
void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
Value alias) {
insertNewBufferAlias(newValue, alias);
equivalentInfo.unionSets(newValue, alias);
}
/// Return `true` if a value was marked as in-place bufferized.
bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const {
return inplaceBufferized.contains(&operand);
}
/// Set the inPlace bufferization spec to true.
void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand,
BufferizationState &state) {
markInPlace(operand);
if (OpResult result = state.getAliasingOpResult(operand))
aliasInfo.unionSets(result, operand.get());
}
/// Set the inPlace bufferization spec to false.
void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) {
assert(!inplaceBufferized.contains(&operand) &&
"OpOperand was already decided to bufferize inplace");
}
/// Apply `fun` to all the members of the equivalence class of `v`.
void BufferizationAliasInfo::applyOnEquivalenceClass(
Value v, function_ref<void(Value)> fun) const {
auto leaderIt = equivalentInfo.findLeader(v);
for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
++mit) {
fun(*mit);
}
}
/// Apply `fun` to all aliases of `v`.
void BufferizationAliasInfo::applyOnAliases(
Value v, function_ref<void(Value)> fun) const {
auto leaderIt = aliasInfo.findLeader(v);
for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
fun(*mit);
}
}
BufferizationAliasInfo::EquivalenceClassRangeType
BufferizationAliasInfo::getAliases(Value v) const {
DenseSet<Value> res;
auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v));
for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end();
mit != meit; ++mit) {
res.insert(static_cast<Value>(*mit));
}
return BufferizationAliasInfo::EquivalenceClassRangeType(
aliasInfo.member_begin(it), aliasInfo.member_end());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Helper functions for BufferizableOpInterface // Helper functions for BufferizableOpInterface
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -291,28 +202,6 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
const BufferizationOptions &options) const BufferizationOptions &options)
: options(options) {} : options(options) {}
mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState::
AnalysisBufferizationState(Operation *op,
const AnalysisBufferizationOptions &options)
: BufferizationState(options), aliasInfo(op) {
// Set up alias sets for OpResults that must bufferize in-place. This should
// be done before making any other bufferization decisions.
op->walk([&](BufferizableOpInterface bufferizableOp) {
if (!options.isOpAllowed(bufferizableOp))
return WalkResult::skip();
for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
if (opOperand.get().getType().isa<TensorType>())
if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) {
if (OpResult opResult =
bufferizableOp.getAliasingOpResult(opOperand, *this))
aliasInfo.unionAliasSets(opOperand.get(), opResult);
aliasInfo.markInPlace(opOperand);
}
}
return WalkResult::advance();
});
}
// bufferization.to_memref is not allowed to change the rank. // bufferization.to_memref is not allowed to change the rank.
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) { static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#ifndef NDEBUG #ifndef NDEBUG
@ -602,16 +491,6 @@ bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) {
return isa<FuncOp>(bbArg.getOwner()->getParentOp()); return isa<FuncOp>(bbArg.getOwner()->getParentOp());
} }
bool mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState::
isInPlace(OpOperand &opOperand) const {
return aliasInfo.isInPlace(opOperand);
}
bool mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState::
areEquivalentBufferizedValues(Value v1, Value v2) const {
return aliasInfo.areEquivalentBufferizedValues(v1, v2);
}
MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType( MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType(
ShapedType shapedType, MemRefLayoutAttrInterface layout, ShapedType shapedType, MemRefLayoutAttrInterface layout,
Attribute memorySpace) { Attribute memorySpace) {

View File

@ -48,6 +48,7 @@ add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRBufferizableOpInterface MLIRBufferizableOpInterface
MLIRComprehensiveBufferize
MLIRIR MLIRIR
MLIRLinalg MLIRLinalg
MLIRTensor MLIRTensor
@ -58,6 +59,7 @@ add_mlir_dialect_library(MLIRSCFBufferizableOpInterfaceImpl
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRBufferizableOpInterface MLIRBufferizableOpInterface
MLIRComprehensiveBufferize
MLIRIR MLIRIR
MLIRSCF MLIRSCF
) )

View File

@ -98,6 +98,129 @@ static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
OpBuilder(op).getStrArrayAttr(inPlaceVector)); OpBuilder(op).getStrArrayAttr(inPlaceVector));
} }
//===----------------------------------------------------------------------===//
// BufferizationAliasInfo
//===----------------------------------------------------------------------===//
BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
rootOp->walk([&](Operation *op) {
for (Value v : op->getResults())
if (v.getType().isa<TensorType>())
createAliasInfoEntry(v);
for (Region &r : op->getRegions())
for (Block &b : r.getBlocks())
for (auto bbArg : b.getArguments())
if (bbArg.getType().isa<TensorType>())
createAliasInfoEntry(bbArg);
});
}
/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
/// beginning the alias and equivalence sets only contain `v` itself.
void BufferizationAliasInfo::createAliasInfoEntry(Value v) {
aliasInfo.insert(v);
equivalentInfo.insert(v);
}
/// Insert an info entry for `newValue` and merge its alias set with that of
/// `alias`.
void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
createAliasInfoEntry(newValue);
aliasInfo.unionSets(newValue, alias);
}
/// Insert an info entry for `newValue` and merge its alias set with that of
/// `alias`. Additionally, merge their equivalence classes.
void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
Value alias) {
insertNewBufferAlias(newValue, alias);
equivalentInfo.unionSets(newValue, alias);
}
/// Return `true` if a value was marked as in-place bufferized.
bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const {
return inplaceBufferized.contains(&operand);
}
/// Set the inPlace bufferization spec to true.
void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand,
BufferizationState &state) {
markInPlace(operand);
if (OpResult result = state.getAliasingOpResult(operand))
aliasInfo.unionSets(result, operand.get());
}
/// Set the inPlace bufferization spec to false.
void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) {
assert(!inplaceBufferized.contains(&operand) &&
"OpOperand was already decided to bufferize inplace");
}
/// Apply `fun` to all the members of the equivalence class of `v`.
void BufferizationAliasInfo::applyOnEquivalenceClass(
Value v, function_ref<void(Value)> fun) const {
auto leaderIt = equivalentInfo.findLeader(v);
for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
++mit) {
fun(*mit);
}
}
/// Apply `fun` to all aliases of `v`.
void BufferizationAliasInfo::applyOnAliases(
Value v, function_ref<void(Value)> fun) const {
auto leaderIt = aliasInfo.findLeader(v);
for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
fun(*mit);
}
}
BufferizationAliasInfo::EquivalenceClassRangeType
BufferizationAliasInfo::getAliases(Value v) const {
DenseSet<Value> res;
auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v));
for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end();
mit != meit; ++mit) {
res.insert(static_cast<Value>(*mit));
}
return BufferizationAliasInfo::EquivalenceClassRangeType(
aliasInfo.member_begin(it), aliasInfo.member_end());
}
//===----------------------------------------------------------------------===//
// AnalysisBufferizationState
//===----------------------------------------------------------------------===//
AnalysisBufferizationState::AnalysisBufferizationState(
Operation *op, const AnalysisBufferizationOptions &options)
: BufferizationState(options), aliasInfo(op) {
// Set up alias sets for OpResults that must bufferize in-place. This should
// be done before making any other bufferization decisions.
op->walk([&](BufferizableOpInterface bufferizableOp) {
if (!options.isOpAllowed(bufferizableOp))
return WalkResult::skip();
for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
if (opOperand.get().getType().isa<TensorType>())
if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) {
if (OpResult opResult =
bufferizableOp.getAliasingOpResult(opOperand, *this))
aliasInfo.unionAliasSets(opOperand.get(), opResult);
aliasInfo.markInPlace(opOperand);
}
}
return WalkResult::advance();
});
}
bool AnalysisBufferizationState::isInPlace(OpOperand &opOperand) const {
return aliasInfo.isInPlace(opOperand);
}
bool AnalysisBufferizationState::areEquivalentBufferizedValues(Value v1,
Value v2) const {
return aliasInfo.areEquivalentBufferizedValues(v1, v2);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Bufferization-specific alias analysis. // Bufferization-specific alias analysis.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -6601,6 +6601,7 @@ cc_library(
deps = [ deps = [
":BufferizableOpInterface", ":BufferizableOpInterface",
":BufferizationDialect", ":BufferizationDialect",
":ComprehensiveBufferize",
":IR", ":IR",
":LinalgOps", ":LinalgOps",
":LinalgStructuredOpsIncGen", ":LinalgStructuredOpsIncGen",
@ -6620,6 +6621,7 @@ cc_library(
deps = [ deps = [
":BufferizableOpInterface", ":BufferizableOpInterface",
":BufferizationDialect", ":BufferizationDialect",
":ComprehensiveBufferize",
":IR", ":IR",
":SCFDialect", ":SCFDialect",
":Support", ":Support",