mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-29 11:17:28 +00:00
[mlir][linalg][bufferize][NFC] Move analysis-related code to Comprehensive Bufferize
The code in `BufferizableOpInterface`'s header/source no longer contains any analysis code. This makes it easier to run the bufferization with a different analysis or without any analysis. Differential Revision: https://reviews.llvm.org/D117478
This commit is contained in:
parent
5ea98988c6
commit
cd0a923b4c
@ -18,7 +18,6 @@
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/EquivalenceClasses.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -36,23 +35,6 @@ class BufferizationAliasInfo;
|
||||
class BufferizableOpInterface;
|
||||
struct BufferizationOptions;
|
||||
class BufferizationState;
|
||||
struct PostAnalysisStep;
|
||||
|
||||
/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
|
||||
/// executed after the analysis, but before bufferization. They can be used to
|
||||
/// implement custom dialect-specific optimizations.
|
||||
struct PostAnalysisStep {
|
||||
virtual ~PostAnalysisStep() = default;
|
||||
|
||||
/// Run the post analysis step. This function may modify the IR, but must keep
|
||||
/// `aliasInfo` consistent. Newly created operations and operations that
|
||||
/// should be re-analyzed must be added to `newOps`.
|
||||
virtual LogicalResult run(Operation *op, BufferizationState &state,
|
||||
BufferizationAliasInfo &aliasInfo,
|
||||
SmallVector<Operation *> &newOps) = 0;
|
||||
};
|
||||
|
||||
using PostAnalysisStepList = std::vector<std::unique_ptr<PostAnalysisStep>>;
|
||||
|
||||
/// Options for ComprehensiveBufferize.
|
||||
struct BufferizationOptions {
|
||||
@ -146,25 +128,6 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
/// Options for analysis-enabled bufferization.
|
||||
struct AnalysisBufferizationOptions : public BufferizationOptions {
|
||||
AnalysisBufferizationOptions() = default;
|
||||
|
||||
// AnalysisBufferizationOptions cannot be copied.
|
||||
AnalysisBufferizationOptions(const AnalysisBufferizationOptions &) = delete;
|
||||
|
||||
/// Register a "post analysis" step. Such steps are executed after the
|
||||
/// analysis, but before bufferization.
|
||||
template <typename Step, typename... Args>
|
||||
void addPostAnalysisStep(Args... args) {
|
||||
postAnalysisSteps.emplace_back(
|
||||
std::make_unique<Step>(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
/// Registered post analysis steps.
|
||||
PostAnalysisStepList postAnalysisSteps;
|
||||
};
|
||||
|
||||
/// Specify fine-grain relationship between buffers to enable more analysis.
|
||||
enum class BufferRelation {
|
||||
None,
|
||||
@ -173,93 +136,6 @@ enum class BufferRelation {
|
||||
Equivalent
|
||||
};
|
||||
|
||||
/// The BufferizationAliasInfo class maintains a list of buffer aliases and
|
||||
/// equivalence classes to support bufferization.
|
||||
class BufferizationAliasInfo {
|
||||
public:
|
||||
explicit BufferizationAliasInfo(Operation *rootOp);
|
||||
|
||||
// BufferizationAliasInfo should be passed as a reference.
|
||||
BufferizationAliasInfo(const BufferizationAliasInfo &) = delete;
|
||||
|
||||
/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
|
||||
/// beginning the alias and equivalence sets only contain `v` itself.
|
||||
void createAliasInfoEntry(Value v);
|
||||
|
||||
/// Insert an info entry for `newValue` and merge its alias set with that of
|
||||
/// `alias`.
|
||||
void insertNewBufferAlias(Value newValue, Value alias);
|
||||
|
||||
/// Insert an info entry for `newValue` and merge its alias set with that of
|
||||
/// `alias`. Additionally, merge their equivalence classes.
|
||||
void insertNewBufferEquivalence(Value newValue, Value alias);
|
||||
|
||||
/// Set the inPlace bufferization spec to true.
|
||||
/// Merge result's and operand's aliasing sets and iterate to a fixed point.
|
||||
void bufferizeInPlace(OpOperand &operand, BufferizationState &state);
|
||||
|
||||
/// Set the inPlace bufferization spec to false.
|
||||
void bufferizeOutOfPlace(OpOperand &operand);
|
||||
|
||||
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
|
||||
bool areEquivalentBufferizedValues(Value v1, Value v2) const {
|
||||
return equivalentInfo.isEquivalent(v1, v2);
|
||||
}
|
||||
|
||||
/// Union the alias sets of `v1` and `v2`.
|
||||
void unionAliasSets(Value v1, Value v2) { aliasInfo.unionSets(v1, v2); }
|
||||
|
||||
/// Union the equivalence classes of `v1` and `v2`.
|
||||
void unionEquivalenceClasses(Value v1, Value v2) {
|
||||
equivalentInfo.unionSets(v1, v2);
|
||||
}
|
||||
|
||||
/// Apply `fun` to all the members of the equivalence class of `v`.
|
||||
void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
|
||||
|
||||
/// Apply `fun` to all aliases of `v`.
|
||||
void applyOnAliases(Value v, function_ref<void(Value)> fun) const;
|
||||
|
||||
/// Mark a value as in-place bufferized.
|
||||
void markInPlace(OpOperand &o) { inplaceBufferized.insert(&o); }
|
||||
|
||||
/// Return `true` if a value was marked as in-place bufferized.
|
||||
bool isInPlace(OpOperand &opOperand) const;
|
||||
|
||||
private:
|
||||
/// llvm::EquivalenceClasses wants comparable elements. This comparator uses
|
||||
/// uses pointer comparison on the defining op. This is a poor man's
|
||||
/// comparison but it's not like UnionFind needs ordering anyway.
|
||||
struct ValueComparator {
|
||||
bool operator()(const Value &lhs, const Value &rhs) const {
|
||||
return lhs.getImpl() < rhs.getImpl();
|
||||
}
|
||||
};
|
||||
|
||||
using EquivalenceClassRangeType = llvm::iterator_range<
|
||||
llvm::EquivalenceClasses<Value, ValueComparator>::member_iterator>;
|
||||
/// Check that aliasInfo for `v` exists and return a reference to it.
|
||||
EquivalenceClassRangeType getAliases(Value v) const;
|
||||
|
||||
/// Set of all OpResults that were decided to bufferize in-place.
|
||||
llvm::DenseSet<OpOperand *> inplaceBufferized;
|
||||
|
||||
/// Auxiliary structure to store all the values a given value may alias with.
|
||||
/// Alias information is "may be" conservative: In the presence of branches, a
|
||||
/// value may alias with one of multiple other values. The concrete aliasing
|
||||
/// value may not even be known at compile time. All such values are
|
||||
/// considered to be aliases.
|
||||
llvm::EquivalenceClasses<Value, ValueComparator> aliasInfo;
|
||||
|
||||
/// Auxiliary structure to store all the equivalent buffer classes. Equivalent
|
||||
/// buffer information is "must be" conservative: Only if two values are
|
||||
/// guaranteed to be equivalent at runtime, they said to be equivalent. It is
|
||||
/// possible that, in the presence of branches, it cannot be determined
|
||||
/// statically if two values are equivalent. In that case, the values are
|
||||
/// considered to be not equivalent.
|
||||
llvm::EquivalenceClasses<Value, ValueComparator> equivalentInfo;
|
||||
};
|
||||
|
||||
/// Return `true` if the given value is a BlockArgument of a FuncOp.
|
||||
bool isFunctionArgument(Value value);
|
||||
|
||||
@ -391,33 +267,6 @@ private:
|
||||
const BufferizationOptions &options;
|
||||
};
|
||||
|
||||
/// State for analysis-enabled bufferization. This class keeps track of alias
|
||||
/// (via BufferizationAliasInfo) to decide if tensor OpOperands should bufferize
|
||||
/// in-place.
|
||||
class AnalysisBufferizationState : public BufferizationState {
|
||||
public:
|
||||
AnalysisBufferizationState(Operation *op,
|
||||
const AnalysisBufferizationOptions &options);
|
||||
|
||||
AnalysisBufferizationState(const AnalysisBufferizationState &) = delete;
|
||||
|
||||
virtual ~AnalysisBufferizationState() = default;
|
||||
|
||||
/// Return a reference to the BufferizationAliasInfo.
|
||||
BufferizationAliasInfo &getAliasInfo() { return aliasInfo; }
|
||||
|
||||
/// Return `true` if the given OpResult has been decided to bufferize inplace.
|
||||
bool isInPlace(OpOperand &opOperand) const override;
|
||||
|
||||
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
|
||||
bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
|
||||
|
||||
private:
|
||||
/// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
|
||||
/// functions and `runComprehensiveBufferize` may access this object.
|
||||
BufferizationAliasInfo aliasInfo;
|
||||
};
|
||||
|
||||
/// Replace an op with replacement values. The op is deleted. Tensor OpResults
|
||||
/// must be replaced with memref values.
|
||||
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
|
||||
|
@ -9,7 +9,9 @@
|
||||
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVEBUFFERIZE_H
|
||||
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVEBUFFERIZE_H
|
||||
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "llvm/ADT/EquivalenceClasses.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
@ -21,6 +23,155 @@ class BufferizationAliasInfo;
|
||||
struct AnalysisBufferizationOptions;
|
||||
class BufferizationState;
|
||||
|
||||
/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
|
||||
/// executed after the analysis, but before bufferization. They can be used to
|
||||
/// implement custom dialect-specific optimizations.
|
||||
struct PostAnalysisStep {
|
||||
virtual ~PostAnalysisStep() = default;
|
||||
|
||||
/// Run the post analysis step. This function may modify the IR, but must keep
|
||||
/// `aliasInfo` consistent. Newly created operations and operations that
|
||||
/// should be re-analyzed must be added to `newOps`.
|
||||
virtual LogicalResult run(Operation *op, BufferizationState &state,
|
||||
BufferizationAliasInfo &aliasInfo,
|
||||
SmallVector<Operation *> &newOps) = 0;
|
||||
};
|
||||
|
||||
using PostAnalysisStepList = std::vector<std::unique_ptr<PostAnalysisStep>>;
|
||||
|
||||
/// Options for analysis-enabled bufferization.
|
||||
struct AnalysisBufferizationOptions : public BufferizationOptions {
|
||||
AnalysisBufferizationOptions() = default;
|
||||
|
||||
// AnalysisBufferizationOptions cannot be copied.
|
||||
AnalysisBufferizationOptions(const AnalysisBufferizationOptions &) = delete;
|
||||
|
||||
/// Register a "post analysis" step. Such steps are executed after the
|
||||
/// analysis, but before bufferization.
|
||||
template <typename Step, typename... Args>
|
||||
void addPostAnalysisStep(Args... args) {
|
||||
postAnalysisSteps.emplace_back(
|
||||
std::make_unique<Step>(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
/// Registered post analysis steps.
|
||||
PostAnalysisStepList postAnalysisSteps;
|
||||
};
|
||||
|
||||
/// The BufferizationAliasInfo class maintains a list of buffer aliases and
|
||||
/// equivalence classes to support bufferization.
|
||||
class BufferizationAliasInfo {
|
||||
public:
|
||||
explicit BufferizationAliasInfo(Operation *rootOp);
|
||||
|
||||
// BufferizationAliasInfo should be passed as a reference.
|
||||
BufferizationAliasInfo(const BufferizationAliasInfo &) = delete;
|
||||
|
||||
/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
|
||||
/// beginning the alias and equivalence sets only contain `v` itself.
|
||||
void createAliasInfoEntry(Value v);
|
||||
|
||||
/// Insert an info entry for `newValue` and merge its alias set with that of
|
||||
/// `alias`.
|
||||
void insertNewBufferAlias(Value newValue, Value alias);
|
||||
|
||||
/// Insert an info entry for `newValue` and merge its alias set with that of
|
||||
/// `alias`. Additionally, merge their equivalence classes.
|
||||
void insertNewBufferEquivalence(Value newValue, Value alias);
|
||||
|
||||
/// Set the inPlace bufferization spec to true.
|
||||
/// Merge result's and operand's aliasing sets and iterate to a fixed point.
|
||||
void bufferizeInPlace(OpOperand &operand, BufferizationState &state);
|
||||
|
||||
/// Set the inPlace bufferization spec to false.
|
||||
void bufferizeOutOfPlace(OpOperand &operand);
|
||||
|
||||
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
|
||||
bool areEquivalentBufferizedValues(Value v1, Value v2) const {
|
||||
return equivalentInfo.isEquivalent(v1, v2);
|
||||
}
|
||||
|
||||
/// Union the alias sets of `v1` and `v2`.
|
||||
void unionAliasSets(Value v1, Value v2) { aliasInfo.unionSets(v1, v2); }
|
||||
|
||||
/// Union the equivalence classes of `v1` and `v2`.
|
||||
void unionEquivalenceClasses(Value v1, Value v2) {
|
||||
equivalentInfo.unionSets(v1, v2);
|
||||
}
|
||||
|
||||
/// Apply `fun` to all the members of the equivalence class of `v`.
|
||||
void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
|
||||
|
||||
/// Apply `fun` to all aliases of `v`.
|
||||
void applyOnAliases(Value v, function_ref<void(Value)> fun) const;
|
||||
|
||||
/// Mark a value as in-place bufferized.
|
||||
void markInPlace(OpOperand &o) { inplaceBufferized.insert(&o); }
|
||||
|
||||
/// Return `true` if a value was marked as in-place bufferized.
|
||||
bool isInPlace(OpOperand &opOperand) const;
|
||||
|
||||
private:
|
||||
/// llvm::EquivalenceClasses wants comparable elements. This comparator uses
|
||||
/// uses pointer comparison on the defining op. This is a poor man's
|
||||
/// comparison but it's not like UnionFind needs ordering anyway.
|
||||
struct ValueComparator {
|
||||
bool operator()(const Value &lhs, const Value &rhs) const {
|
||||
return lhs.getImpl() < rhs.getImpl();
|
||||
}
|
||||
};
|
||||
|
||||
using EquivalenceClassRangeType = llvm::iterator_range<
|
||||
llvm::EquivalenceClasses<Value, ValueComparator>::member_iterator>;
|
||||
/// Check that aliasInfo for `v` exists and return a reference to it.
|
||||
EquivalenceClassRangeType getAliases(Value v) const;
|
||||
|
||||
/// Set of all OpResults that were decided to bufferize in-place.
|
||||
llvm::DenseSet<OpOperand *> inplaceBufferized;
|
||||
|
||||
/// Auxiliary structure to store all the values a given value may alias with.
|
||||
/// Alias information is "may be" conservative: In the presence of branches, a
|
||||
/// value may alias with one of multiple other values. The concrete aliasing
|
||||
/// value may not even be known at compile time. All such values are
|
||||
/// considered to be aliases.
|
||||
llvm::EquivalenceClasses<Value, ValueComparator> aliasInfo;
|
||||
|
||||
/// Auxiliary structure to store all the equivalent buffer classes. Equivalent
|
||||
/// buffer information is "must be" conservative: Only if two values are
|
||||
/// guaranteed to be equivalent at runtime, they said to be equivalent. It is
|
||||
/// possible that, in the presence of branches, it cannot be determined
|
||||
/// statically if two values are equivalent. In that case, the values are
|
||||
/// considered to be not equivalent.
|
||||
llvm::EquivalenceClasses<Value, ValueComparator> equivalentInfo;
|
||||
};
|
||||
|
||||
/// State for analysis-enabled bufferization. This class keeps track of alias
|
||||
/// (via BufferizationAliasInfo) to decide if tensor OpOperands should bufferize
|
||||
/// in-place.
|
||||
class AnalysisBufferizationState : public BufferizationState {
|
||||
public:
|
||||
AnalysisBufferizationState(Operation *op,
|
||||
const AnalysisBufferizationOptions &options);
|
||||
|
||||
AnalysisBufferizationState(const AnalysisBufferizationState &) = delete;
|
||||
|
||||
virtual ~AnalysisBufferizationState() = default;
|
||||
|
||||
/// Return a reference to the BufferizationAliasInfo.
|
||||
BufferizationAliasInfo &getAliasInfo() { return aliasInfo; }
|
||||
|
||||
/// Return `true` if the given OpResult has been decided to bufferize inplace.
|
||||
bool isInPlace(OpOperand &opOperand) const override;
|
||||
|
||||
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
|
||||
bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
|
||||
|
||||
private:
|
||||
/// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
|
||||
/// functions and `runComprehensiveBufferize` may access this object.
|
||||
BufferizationAliasInfo aliasInfo;
|
||||
};
|
||||
|
||||
/// Analyze `op` and its nested ops. Bufferization decisions are stored in
|
||||
/// `state`.
|
||||
LogicalResult analyzeOp(Operation *op, AnalysisBufferizationState &state);
|
||||
|
@ -9,7 +9,7 @@
|
||||
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALGINTERFACEIMPL_H
|
||||
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALGINTERFACEIMPL_H
|
||||
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
@ -9,7 +9,7 @@
|
||||
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H
|
||||
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H
|
||||
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
@ -57,95 +57,6 @@ BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BufferizationAliasInfo
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
|
||||
rootOp->walk([&](Operation *op) {
|
||||
for (Value v : op->getResults())
|
||||
if (v.getType().isa<TensorType>())
|
||||
createAliasInfoEntry(v);
|
||||
for (Region &r : op->getRegions())
|
||||
for (Block &b : r.getBlocks())
|
||||
for (auto bbArg : b.getArguments())
|
||||
if (bbArg.getType().isa<TensorType>())
|
||||
createAliasInfoEntry(bbArg);
|
||||
});
|
||||
}
|
||||
|
||||
/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
|
||||
/// beginning the alias and equivalence sets only contain `v` itself.
|
||||
void BufferizationAliasInfo::createAliasInfoEntry(Value v) {
|
||||
aliasInfo.insert(v);
|
||||
equivalentInfo.insert(v);
|
||||
}
|
||||
|
||||
/// Insert an info entry for `newValue` and merge its alias set with that of
|
||||
/// `alias`.
|
||||
void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
|
||||
createAliasInfoEntry(newValue);
|
||||
aliasInfo.unionSets(newValue, alias);
|
||||
}
|
||||
|
||||
/// Insert an info entry for `newValue` and merge its alias set with that of
|
||||
/// `alias`. Additionally, merge their equivalence classes.
|
||||
void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
|
||||
Value alias) {
|
||||
insertNewBufferAlias(newValue, alias);
|
||||
equivalentInfo.unionSets(newValue, alias);
|
||||
}
|
||||
|
||||
/// Return `true` if a value was marked as in-place bufferized.
|
||||
bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const {
|
||||
return inplaceBufferized.contains(&operand);
|
||||
}
|
||||
|
||||
/// Set the inPlace bufferization spec to true.
|
||||
void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand,
|
||||
BufferizationState &state) {
|
||||
markInPlace(operand);
|
||||
if (OpResult result = state.getAliasingOpResult(operand))
|
||||
aliasInfo.unionSets(result, operand.get());
|
||||
}
|
||||
|
||||
/// Set the inPlace bufferization spec to false.
|
||||
void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) {
|
||||
assert(!inplaceBufferized.contains(&operand) &&
|
||||
"OpOperand was already decided to bufferize inplace");
|
||||
}
|
||||
|
||||
/// Apply `fun` to all the members of the equivalence class of `v`.
|
||||
void BufferizationAliasInfo::applyOnEquivalenceClass(
|
||||
Value v, function_ref<void(Value)> fun) const {
|
||||
auto leaderIt = equivalentInfo.findLeader(v);
|
||||
for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
|
||||
++mit) {
|
||||
fun(*mit);
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply `fun` to all aliases of `v`.
|
||||
void BufferizationAliasInfo::applyOnAliases(
|
||||
Value v, function_ref<void(Value)> fun) const {
|
||||
auto leaderIt = aliasInfo.findLeader(v);
|
||||
for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
|
||||
fun(*mit);
|
||||
}
|
||||
}
|
||||
|
||||
BufferizationAliasInfo::EquivalenceClassRangeType
|
||||
BufferizationAliasInfo::getAliases(Value v) const {
|
||||
DenseSet<Value> res;
|
||||
auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v));
|
||||
for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end();
|
||||
mit != meit; ++mit) {
|
||||
res.insert(static_cast<Value>(*mit));
|
||||
}
|
||||
return BufferizationAliasInfo::EquivalenceClassRangeType(
|
||||
aliasInfo.member_begin(it), aliasInfo.member_end());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Helper functions for BufferizableOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -291,28 +202,6 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
|
||||
const BufferizationOptions &options)
|
||||
: options(options) {}
|
||||
|
||||
mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState::
|
||||
AnalysisBufferizationState(Operation *op,
|
||||
const AnalysisBufferizationOptions &options)
|
||||
: BufferizationState(options), aliasInfo(op) {
|
||||
// Set up alias sets for OpResults that must bufferize in-place. This should
|
||||
// be done before making any other bufferization decisions.
|
||||
op->walk([&](BufferizableOpInterface bufferizableOp) {
|
||||
if (!options.isOpAllowed(bufferizableOp))
|
||||
return WalkResult::skip();
|
||||
for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
|
||||
if (opOperand.get().getType().isa<TensorType>())
|
||||
if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) {
|
||||
if (OpResult opResult =
|
||||
bufferizableOp.getAliasingOpResult(opOperand, *this))
|
||||
aliasInfo.unionAliasSets(opOperand.get(), opResult);
|
||||
aliasInfo.markInPlace(opOperand);
|
||||
}
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
}
|
||||
|
||||
// bufferization.to_memref is not allowed to change the rank.
|
||||
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
|
||||
#ifndef NDEBUG
|
||||
@ -602,16 +491,6 @@ bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) {
|
||||
return isa<FuncOp>(bbArg.getOwner()->getParentOp());
|
||||
}
|
||||
|
||||
bool mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState::
|
||||
isInPlace(OpOperand &opOperand) const {
|
||||
return aliasInfo.isInPlace(opOperand);
|
||||
}
|
||||
|
||||
bool mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState::
|
||||
areEquivalentBufferizedValues(Value v1, Value v2) const {
|
||||
return aliasInfo.areEquivalentBufferizedValues(v1, v2);
|
||||
}
|
||||
|
||||
MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType(
|
||||
ShapedType shapedType, MemRefLayoutAttrInterface layout,
|
||||
Attribute memorySpace) {
|
||||
|
@ -48,6 +48,7 @@ add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRBufferizableOpInterface
|
||||
MLIRComprehensiveBufferize
|
||||
MLIRIR
|
||||
MLIRLinalg
|
||||
MLIRTensor
|
||||
@ -58,6 +59,7 @@ add_mlir_dialect_library(MLIRSCFBufferizableOpInterfaceImpl
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRBufferizableOpInterface
|
||||
MLIRComprehensiveBufferize
|
||||
MLIRIR
|
||||
MLIRSCF
|
||||
)
|
||||
|
@ -98,6 +98,129 @@ static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
|
||||
OpBuilder(op).getStrArrayAttr(inPlaceVector));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BufferizationAliasInfo
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
|
||||
rootOp->walk([&](Operation *op) {
|
||||
for (Value v : op->getResults())
|
||||
if (v.getType().isa<TensorType>())
|
||||
createAliasInfoEntry(v);
|
||||
for (Region &r : op->getRegions())
|
||||
for (Block &b : r.getBlocks())
|
||||
for (auto bbArg : b.getArguments())
|
||||
if (bbArg.getType().isa<TensorType>())
|
||||
createAliasInfoEntry(bbArg);
|
||||
});
|
||||
}
|
||||
|
||||
/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
|
||||
/// beginning the alias and equivalence sets only contain `v` itself.
|
||||
void BufferizationAliasInfo::createAliasInfoEntry(Value v) {
|
||||
aliasInfo.insert(v);
|
||||
equivalentInfo.insert(v);
|
||||
}
|
||||
|
||||
/// Insert an info entry for `newValue` and merge its alias set with that of
|
||||
/// `alias`.
|
||||
void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
|
||||
createAliasInfoEntry(newValue);
|
||||
aliasInfo.unionSets(newValue, alias);
|
||||
}
|
||||
|
||||
/// Insert an info entry for `newValue` and merge its alias set with that of
|
||||
/// `alias`. Additionally, merge their equivalence classes.
|
||||
void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
|
||||
Value alias) {
|
||||
insertNewBufferAlias(newValue, alias);
|
||||
equivalentInfo.unionSets(newValue, alias);
|
||||
}
|
||||
|
||||
/// Return `true` if a value was marked as in-place bufferized.
|
||||
bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const {
|
||||
return inplaceBufferized.contains(&operand);
|
||||
}
|
||||
|
||||
/// Set the inPlace bufferization spec to true.
|
||||
void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand,
|
||||
BufferizationState &state) {
|
||||
markInPlace(operand);
|
||||
if (OpResult result = state.getAliasingOpResult(operand))
|
||||
aliasInfo.unionSets(result, operand.get());
|
||||
}
|
||||
|
||||
/// Set the inPlace bufferization spec to false.
|
||||
void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) {
|
||||
assert(!inplaceBufferized.contains(&operand) &&
|
||||
"OpOperand was already decided to bufferize inplace");
|
||||
}
|
||||
|
||||
/// Apply `fun` to all the members of the equivalence class of `v`.
|
||||
void BufferizationAliasInfo::applyOnEquivalenceClass(
|
||||
Value v, function_ref<void(Value)> fun) const {
|
||||
auto leaderIt = equivalentInfo.findLeader(v);
|
||||
for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
|
||||
++mit) {
|
||||
fun(*mit);
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply `fun` to all aliases of `v`.
|
||||
void BufferizationAliasInfo::applyOnAliases(
|
||||
Value v, function_ref<void(Value)> fun) const {
|
||||
auto leaderIt = aliasInfo.findLeader(v);
|
||||
for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
|
||||
fun(*mit);
|
||||
}
|
||||
}
|
||||
|
||||
BufferizationAliasInfo::EquivalenceClassRangeType
|
||||
BufferizationAliasInfo::getAliases(Value v) const {
|
||||
DenseSet<Value> res;
|
||||
auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v));
|
||||
for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end();
|
||||
mit != meit; ++mit) {
|
||||
res.insert(static_cast<Value>(*mit));
|
||||
}
|
||||
return BufferizationAliasInfo::EquivalenceClassRangeType(
|
||||
aliasInfo.member_begin(it), aliasInfo.member_end());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AnalysisBufferizationState
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
AnalysisBufferizationState::AnalysisBufferizationState(
|
||||
Operation *op, const AnalysisBufferizationOptions &options)
|
||||
: BufferizationState(options), aliasInfo(op) {
|
||||
// Set up alias sets for OpResults that must bufferize in-place. This should
|
||||
// be done before making any other bufferization decisions.
|
||||
op->walk([&](BufferizableOpInterface bufferizableOp) {
|
||||
if (!options.isOpAllowed(bufferizableOp))
|
||||
return WalkResult::skip();
|
||||
for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
|
||||
if (opOperand.get().getType().isa<TensorType>())
|
||||
if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) {
|
||||
if (OpResult opResult =
|
||||
bufferizableOp.getAliasingOpResult(opOperand, *this))
|
||||
aliasInfo.unionAliasSets(opOperand.get(), opResult);
|
||||
aliasInfo.markInPlace(opOperand);
|
||||
}
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
}
|
||||
|
||||
bool AnalysisBufferizationState::isInPlace(OpOperand &opOperand) const {
|
||||
return aliasInfo.isInPlace(opOperand);
|
||||
}
|
||||
|
||||
bool AnalysisBufferizationState::areEquivalentBufferizedValues(Value v1,
|
||||
Value v2) const {
|
||||
return aliasInfo.areEquivalentBufferizedValues(v1, v2);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Bufferization-specific alias analysis.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -6601,6 +6601,7 @@ cc_library(
|
||||
deps = [
|
||||
":BufferizableOpInterface",
|
||||
":BufferizationDialect",
|
||||
":ComprehensiveBufferize",
|
||||
":IR",
|
||||
":LinalgOps",
|
||||
":LinalgStructuredOpsIncGen",
|
||||
@ -6620,6 +6621,7 @@ cc_library(
|
||||
deps = [
|
||||
":BufferizableOpInterface",
|
||||
":BufferizationDialect",
|
||||
":ComprehensiveBufferize",
|
||||
":IR",
|
||||
":SCFDialect",
|
||||
":Support",
|
||||
|
Loading…
Reference in New Issue
Block a user