mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-01-16 21:21:06 +00:00
[mlir][linalg][bufferize][NFC] Move analysis-related code to Comprehensive Bufferize
The code in `BufferizableOpInterface`'s header/source no longer contains any analysis code. This makes it easier to run the bufferization with a different analysis or without any analysis. Differential Revision: https://reviews.llvm.org/D117478
This commit is contained in:
parent
5ea98988c6
commit
cd0a923b4c
@ -18,7 +18,6 @@
|
|||||||
#include "mlir/IR/Operation.h"
|
#include "mlir/IR/Operation.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "llvm/ADT/EquivalenceClasses.h"
|
|
||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/SetVector.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
@ -36,23 +35,6 @@ class BufferizationAliasInfo;
|
|||||||
class BufferizableOpInterface;
|
class BufferizableOpInterface;
|
||||||
struct BufferizationOptions;
|
struct BufferizationOptions;
|
||||||
class BufferizationState;
|
class BufferizationState;
|
||||||
struct PostAnalysisStep;
|
|
||||||
|
|
||||||
/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
|
|
||||||
/// executed after the analysis, but before bufferization. They can be used to
|
|
||||||
/// implement custom dialect-specific optimizations.
|
|
||||||
struct PostAnalysisStep {
|
|
||||||
virtual ~PostAnalysisStep() = default;
|
|
||||||
|
|
||||||
/// Run the post analysis step. This function may modify the IR, but must keep
|
|
||||||
/// `aliasInfo` consistent. Newly created operations and operations that
|
|
||||||
/// should be re-analyzed must be added to `newOps`.
|
|
||||||
virtual LogicalResult run(Operation *op, BufferizationState &state,
|
|
||||||
BufferizationAliasInfo &aliasInfo,
|
|
||||||
SmallVector<Operation *> &newOps) = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
using PostAnalysisStepList = std::vector<std::unique_ptr<PostAnalysisStep>>;
|
|
||||||
|
|
||||||
/// Options for ComprehensiveBufferize.
|
/// Options for ComprehensiveBufferize.
|
||||||
struct BufferizationOptions {
|
struct BufferizationOptions {
|
||||||
@ -146,25 +128,6 @@ private:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Options for analysis-enabled bufferization.
|
|
||||||
struct AnalysisBufferizationOptions : public BufferizationOptions {
|
|
||||||
AnalysisBufferizationOptions() = default;
|
|
||||||
|
|
||||||
// AnalysisBufferizationOptions cannot be copied.
|
|
||||||
AnalysisBufferizationOptions(const AnalysisBufferizationOptions &) = delete;
|
|
||||||
|
|
||||||
/// Register a "post analysis" step. Such steps are executed after the
|
|
||||||
/// analysis, but before bufferization.
|
|
||||||
template <typename Step, typename... Args>
|
|
||||||
void addPostAnalysisStep(Args... args) {
|
|
||||||
postAnalysisSteps.emplace_back(
|
|
||||||
std::make_unique<Step>(std::forward<Args>(args)...));
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Registered post analysis steps.
|
|
||||||
PostAnalysisStepList postAnalysisSteps;
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Specify fine-grain relationship between buffers to enable more analysis.
|
/// Specify fine-grain relationship between buffers to enable more analysis.
|
||||||
enum class BufferRelation {
|
enum class BufferRelation {
|
||||||
None,
|
None,
|
||||||
@ -173,93 +136,6 @@ enum class BufferRelation {
|
|||||||
Equivalent
|
Equivalent
|
||||||
};
|
};
|
||||||
|
|
||||||
/// The BufferizationAliasInfo class maintains a list of buffer aliases and
|
|
||||||
/// equivalence classes to support bufferization.
|
|
||||||
class BufferizationAliasInfo {
|
|
||||||
public:
|
|
||||||
explicit BufferizationAliasInfo(Operation *rootOp);
|
|
||||||
|
|
||||||
// BufferizationAliasInfo should be passed as a reference.
|
|
||||||
BufferizationAliasInfo(const BufferizationAliasInfo &) = delete;
|
|
||||||
|
|
||||||
/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
|
|
||||||
/// beginning the alias and equivalence sets only contain `v` itself.
|
|
||||||
void createAliasInfoEntry(Value v);
|
|
||||||
|
|
||||||
/// Insert an info entry for `newValue` and merge its alias set with that of
|
|
||||||
/// `alias`.
|
|
||||||
void insertNewBufferAlias(Value newValue, Value alias);
|
|
||||||
|
|
||||||
/// Insert an info entry for `newValue` and merge its alias set with that of
|
|
||||||
/// `alias`. Additionally, merge their equivalence classes.
|
|
||||||
void insertNewBufferEquivalence(Value newValue, Value alias);
|
|
||||||
|
|
||||||
/// Set the inPlace bufferization spec to true.
|
|
||||||
/// Merge result's and operand's aliasing sets and iterate to a fixed point.
|
|
||||||
void bufferizeInPlace(OpOperand &operand, BufferizationState &state);
|
|
||||||
|
|
||||||
/// Set the inPlace bufferization spec to false.
|
|
||||||
void bufferizeOutOfPlace(OpOperand &operand);
|
|
||||||
|
|
||||||
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
|
|
||||||
bool areEquivalentBufferizedValues(Value v1, Value v2) const {
|
|
||||||
return equivalentInfo.isEquivalent(v1, v2);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Union the alias sets of `v1` and `v2`.
|
|
||||||
void unionAliasSets(Value v1, Value v2) { aliasInfo.unionSets(v1, v2); }
|
|
||||||
|
|
||||||
/// Union the equivalence classes of `v1` and `v2`.
|
|
||||||
void unionEquivalenceClasses(Value v1, Value v2) {
|
|
||||||
equivalentInfo.unionSets(v1, v2);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Apply `fun` to all the members of the equivalence class of `v`.
|
|
||||||
void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
|
|
||||||
|
|
||||||
/// Apply `fun` to all aliases of `v`.
|
|
||||||
void applyOnAliases(Value v, function_ref<void(Value)> fun) const;
|
|
||||||
|
|
||||||
/// Mark a value as in-place bufferized.
|
|
||||||
void markInPlace(OpOperand &o) { inplaceBufferized.insert(&o); }
|
|
||||||
|
|
||||||
/// Return `true` if a value was marked as in-place bufferized.
|
|
||||||
bool isInPlace(OpOperand &opOperand) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
/// llvm::EquivalenceClasses wants comparable elements. This comparator uses
|
|
||||||
/// uses pointer comparison on the defining op. This is a poor man's
|
|
||||||
/// comparison but it's not like UnionFind needs ordering anyway.
|
|
||||||
struct ValueComparator {
|
|
||||||
bool operator()(const Value &lhs, const Value &rhs) const {
|
|
||||||
return lhs.getImpl() < rhs.getImpl();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
using EquivalenceClassRangeType = llvm::iterator_range<
|
|
||||||
llvm::EquivalenceClasses<Value, ValueComparator>::member_iterator>;
|
|
||||||
/// Check that aliasInfo for `v` exists and return a reference to it.
|
|
||||||
EquivalenceClassRangeType getAliases(Value v) const;
|
|
||||||
|
|
||||||
/// Set of all OpResults that were decided to bufferize in-place.
|
|
||||||
llvm::DenseSet<OpOperand *> inplaceBufferized;
|
|
||||||
|
|
||||||
/// Auxiliary structure to store all the values a given value may alias with.
|
|
||||||
/// Alias information is "may be" conservative: In the presence of branches, a
|
|
||||||
/// value may alias with one of multiple other values. The concrete aliasing
|
|
||||||
/// value may not even be known at compile time. All such values are
|
|
||||||
/// considered to be aliases.
|
|
||||||
llvm::EquivalenceClasses<Value, ValueComparator> aliasInfo;
|
|
||||||
|
|
||||||
/// Auxiliary structure to store all the equivalent buffer classes. Equivalent
|
|
||||||
/// buffer information is "must be" conservative: Only if two values are
|
|
||||||
/// guaranteed to be equivalent at runtime, they said to be equivalent. It is
|
|
||||||
/// possible that, in the presence of branches, it cannot be determined
|
|
||||||
/// statically if two values are equivalent. In that case, the values are
|
|
||||||
/// considered to be not equivalent.
|
|
||||||
llvm::EquivalenceClasses<Value, ValueComparator> equivalentInfo;
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Return `true` if the given value is a BlockArgument of a FuncOp.
|
/// Return `true` if the given value is a BlockArgument of a FuncOp.
|
||||||
bool isFunctionArgument(Value value);
|
bool isFunctionArgument(Value value);
|
||||||
|
|
||||||
@ -391,33 +267,6 @@ private:
|
|||||||
const BufferizationOptions &options;
|
const BufferizationOptions &options;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// State for analysis-enabled bufferization. This class keeps track of alias
|
|
||||||
/// (via BufferizationAliasInfo) to decide if tensor OpOperands should bufferize
|
|
||||||
/// in-place.
|
|
||||||
class AnalysisBufferizationState : public BufferizationState {
|
|
||||||
public:
|
|
||||||
AnalysisBufferizationState(Operation *op,
|
|
||||||
const AnalysisBufferizationOptions &options);
|
|
||||||
|
|
||||||
AnalysisBufferizationState(const AnalysisBufferizationState &) = delete;
|
|
||||||
|
|
||||||
virtual ~AnalysisBufferizationState() = default;
|
|
||||||
|
|
||||||
/// Return a reference to the BufferizationAliasInfo.
|
|
||||||
BufferizationAliasInfo &getAliasInfo() { return aliasInfo; }
|
|
||||||
|
|
||||||
/// Return `true` if the given OpResult has been decided to bufferize inplace.
|
|
||||||
bool isInPlace(OpOperand &opOperand) const override;
|
|
||||||
|
|
||||||
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
|
|
||||||
bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
/// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
|
|
||||||
/// functions and `runComprehensiveBufferize` may access this object.
|
|
||||||
BufferizationAliasInfo aliasInfo;
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Replace an op with replacement values. The op is deleted. Tensor OpResults
|
/// Replace an op with replacement values. The op is deleted. Tensor OpResults
|
||||||
/// must be replaced with memref values.
|
/// must be replaced with memref values.
|
||||||
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
|
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
|
||||||
|
@ -9,7 +9,9 @@
|
|||||||
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVEBUFFERIZE_H
|
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVEBUFFERIZE_H
|
||||||
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVEBUFFERIZE_H
|
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVEBUFFERIZE_H
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "llvm/ADT/EquivalenceClasses.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
||||||
@ -21,6 +23,155 @@ class BufferizationAliasInfo;
|
|||||||
struct AnalysisBufferizationOptions;
|
struct AnalysisBufferizationOptions;
|
||||||
class BufferizationState;
|
class BufferizationState;
|
||||||
|
|
||||||
|
/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
|
||||||
|
/// executed after the analysis, but before bufferization. They can be used to
|
||||||
|
/// implement custom dialect-specific optimizations.
|
||||||
|
struct PostAnalysisStep {
|
||||||
|
virtual ~PostAnalysisStep() = default;
|
||||||
|
|
||||||
|
/// Run the post analysis step. This function may modify the IR, but must keep
|
||||||
|
/// `aliasInfo` consistent. Newly created operations and operations that
|
||||||
|
/// should be re-analyzed must be added to `newOps`.
|
||||||
|
virtual LogicalResult run(Operation *op, BufferizationState &state,
|
||||||
|
BufferizationAliasInfo &aliasInfo,
|
||||||
|
SmallVector<Operation *> &newOps) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
using PostAnalysisStepList = std::vector<std::unique_ptr<PostAnalysisStep>>;
|
||||||
|
|
||||||
|
/// Options for analysis-enabled bufferization.
|
||||||
|
struct AnalysisBufferizationOptions : public BufferizationOptions {
|
||||||
|
AnalysisBufferizationOptions() = default;
|
||||||
|
|
||||||
|
// AnalysisBufferizationOptions cannot be copied.
|
||||||
|
AnalysisBufferizationOptions(const AnalysisBufferizationOptions &) = delete;
|
||||||
|
|
||||||
|
/// Register a "post analysis" step. Such steps are executed after the
|
||||||
|
/// analysis, but before bufferization.
|
||||||
|
template <typename Step, typename... Args>
|
||||||
|
void addPostAnalysisStep(Args... args) {
|
||||||
|
postAnalysisSteps.emplace_back(
|
||||||
|
std::make_unique<Step>(std::forward<Args>(args)...));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Registered post analysis steps.
|
||||||
|
PostAnalysisStepList postAnalysisSteps;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// The BufferizationAliasInfo class maintains a list of buffer aliases and
|
||||||
|
/// equivalence classes to support bufferization.
|
||||||
|
class BufferizationAliasInfo {
|
||||||
|
public:
|
||||||
|
explicit BufferizationAliasInfo(Operation *rootOp);
|
||||||
|
|
||||||
|
// BufferizationAliasInfo should be passed as a reference.
|
||||||
|
BufferizationAliasInfo(const BufferizationAliasInfo &) = delete;
|
||||||
|
|
||||||
|
/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
|
||||||
|
/// beginning the alias and equivalence sets only contain `v` itself.
|
||||||
|
void createAliasInfoEntry(Value v);
|
||||||
|
|
||||||
|
/// Insert an info entry for `newValue` and merge its alias set with that of
|
||||||
|
/// `alias`.
|
||||||
|
void insertNewBufferAlias(Value newValue, Value alias);
|
||||||
|
|
||||||
|
/// Insert an info entry for `newValue` and merge its alias set with that of
|
||||||
|
/// `alias`. Additionally, merge their equivalence classes.
|
||||||
|
void insertNewBufferEquivalence(Value newValue, Value alias);
|
||||||
|
|
||||||
|
/// Set the inPlace bufferization spec to true.
|
||||||
|
/// Merge result's and operand's aliasing sets and iterate to a fixed point.
|
||||||
|
void bufferizeInPlace(OpOperand &operand, BufferizationState &state);
|
||||||
|
|
||||||
|
/// Set the inPlace bufferization spec to false.
|
||||||
|
void bufferizeOutOfPlace(OpOperand &operand);
|
||||||
|
|
||||||
|
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
|
||||||
|
bool areEquivalentBufferizedValues(Value v1, Value v2) const {
|
||||||
|
return equivalentInfo.isEquivalent(v1, v2);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Union the alias sets of `v1` and `v2`.
|
||||||
|
void unionAliasSets(Value v1, Value v2) { aliasInfo.unionSets(v1, v2); }
|
||||||
|
|
||||||
|
/// Union the equivalence classes of `v1` and `v2`.
|
||||||
|
void unionEquivalenceClasses(Value v1, Value v2) {
|
||||||
|
equivalentInfo.unionSets(v1, v2);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Apply `fun` to all the members of the equivalence class of `v`.
|
||||||
|
void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
|
||||||
|
|
||||||
|
/// Apply `fun` to all aliases of `v`.
|
||||||
|
void applyOnAliases(Value v, function_ref<void(Value)> fun) const;
|
||||||
|
|
||||||
|
/// Mark a value as in-place bufferized.
|
||||||
|
void markInPlace(OpOperand &o) { inplaceBufferized.insert(&o); }
|
||||||
|
|
||||||
|
/// Return `true` if a value was marked as in-place bufferized.
|
||||||
|
bool isInPlace(OpOperand &opOperand) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// llvm::EquivalenceClasses wants comparable elements. This comparator uses
|
||||||
|
/// uses pointer comparison on the defining op. This is a poor man's
|
||||||
|
/// comparison but it's not like UnionFind needs ordering anyway.
|
||||||
|
struct ValueComparator {
|
||||||
|
bool operator()(const Value &lhs, const Value &rhs) const {
|
||||||
|
return lhs.getImpl() < rhs.getImpl();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
using EquivalenceClassRangeType = llvm::iterator_range<
|
||||||
|
llvm::EquivalenceClasses<Value, ValueComparator>::member_iterator>;
|
||||||
|
/// Check that aliasInfo for `v` exists and return a reference to it.
|
||||||
|
EquivalenceClassRangeType getAliases(Value v) const;
|
||||||
|
|
||||||
|
/// Set of all OpResults that were decided to bufferize in-place.
|
||||||
|
llvm::DenseSet<OpOperand *> inplaceBufferized;
|
||||||
|
|
||||||
|
/// Auxiliary structure to store all the values a given value may alias with.
|
||||||
|
/// Alias information is "may be" conservative: In the presence of branches, a
|
||||||
|
/// value may alias with one of multiple other values. The concrete aliasing
|
||||||
|
/// value may not even be known at compile time. All such values are
|
||||||
|
/// considered to be aliases.
|
||||||
|
llvm::EquivalenceClasses<Value, ValueComparator> aliasInfo;
|
||||||
|
|
||||||
|
/// Auxiliary structure to store all the equivalent buffer classes. Equivalent
|
||||||
|
/// buffer information is "must be" conservative: Only if two values are
|
||||||
|
/// guaranteed to be equivalent at runtime, they said to be equivalent. It is
|
||||||
|
/// possible that, in the presence of branches, it cannot be determined
|
||||||
|
/// statically if two values are equivalent. In that case, the values are
|
||||||
|
/// considered to be not equivalent.
|
||||||
|
llvm::EquivalenceClasses<Value, ValueComparator> equivalentInfo;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// State for analysis-enabled bufferization. This class keeps track of alias
|
||||||
|
/// (via BufferizationAliasInfo) to decide if tensor OpOperands should bufferize
|
||||||
|
/// in-place.
|
||||||
|
class AnalysisBufferizationState : public BufferizationState {
|
||||||
|
public:
|
||||||
|
AnalysisBufferizationState(Operation *op,
|
||||||
|
const AnalysisBufferizationOptions &options);
|
||||||
|
|
||||||
|
AnalysisBufferizationState(const AnalysisBufferizationState &) = delete;
|
||||||
|
|
||||||
|
virtual ~AnalysisBufferizationState() = default;
|
||||||
|
|
||||||
|
/// Return a reference to the BufferizationAliasInfo.
|
||||||
|
BufferizationAliasInfo &getAliasInfo() { return aliasInfo; }
|
||||||
|
|
||||||
|
/// Return `true` if the given OpResult has been decided to bufferize inplace.
|
||||||
|
bool isInPlace(OpOperand &opOperand) const override;
|
||||||
|
|
||||||
|
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
|
||||||
|
bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
|
||||||
|
/// functions and `runComprehensiveBufferize` may access this object.
|
||||||
|
BufferizationAliasInfo aliasInfo;
|
||||||
|
};
|
||||||
|
|
||||||
/// Analyze `op` and its nested ops. Bufferization decisions are stored in
|
/// Analyze `op` and its nested ops. Bufferization decisions are stored in
|
||||||
/// `state`.
|
/// `state`.
|
||||||
LogicalResult analyzeOp(Operation *op, AnalysisBufferizationState &state);
|
LogicalResult analyzeOp(Operation *op, AnalysisBufferizationState &state);
|
||||||
|
@ -9,7 +9,7 @@
|
|||||||
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALGINTERFACEIMPL_H
|
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALGINTERFACEIMPL_H
|
||||||
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALGINTERFACEIMPL_H
|
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_LINALGINTERFACEIMPL_H
|
||||||
|
|
||||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
|
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@
|
|||||||
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H
|
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H
|
||||||
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H
|
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H
|
||||||
|
|
||||||
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
|
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
||||||
|
@ -57,95 +57,6 @@ BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// BufferizationAliasInfo
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
|
|
||||||
rootOp->walk([&](Operation *op) {
|
|
||||||
for (Value v : op->getResults())
|
|
||||||
if (v.getType().isa<TensorType>())
|
|
||||||
createAliasInfoEntry(v);
|
|
||||||
for (Region &r : op->getRegions())
|
|
||||||
for (Block &b : r.getBlocks())
|
|
||||||
for (auto bbArg : b.getArguments())
|
|
||||||
if (bbArg.getType().isa<TensorType>())
|
|
||||||
createAliasInfoEntry(bbArg);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
|
|
||||||
/// beginning the alias and equivalence sets only contain `v` itself.
|
|
||||||
void BufferizationAliasInfo::createAliasInfoEntry(Value v) {
|
|
||||||
aliasInfo.insert(v);
|
|
||||||
equivalentInfo.insert(v);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Insert an info entry for `newValue` and merge its alias set with that of
|
|
||||||
/// `alias`.
|
|
||||||
void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
|
|
||||||
createAliasInfoEntry(newValue);
|
|
||||||
aliasInfo.unionSets(newValue, alias);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Insert an info entry for `newValue` and merge its alias set with that of
|
|
||||||
/// `alias`. Additionally, merge their equivalence classes.
|
|
||||||
void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
|
|
||||||
Value alias) {
|
|
||||||
insertNewBufferAlias(newValue, alias);
|
|
||||||
equivalentInfo.unionSets(newValue, alias);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Return `true` if a value was marked as in-place bufferized.
|
|
||||||
bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const {
|
|
||||||
return inplaceBufferized.contains(&operand);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set the inPlace bufferization spec to true.
|
|
||||||
void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand,
|
|
||||||
BufferizationState &state) {
|
|
||||||
markInPlace(operand);
|
|
||||||
if (OpResult result = state.getAliasingOpResult(operand))
|
|
||||||
aliasInfo.unionSets(result, operand.get());
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set the inPlace bufferization spec to false.
|
|
||||||
void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) {
|
|
||||||
assert(!inplaceBufferized.contains(&operand) &&
|
|
||||||
"OpOperand was already decided to bufferize inplace");
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Apply `fun` to all the members of the equivalence class of `v`.
|
|
||||||
void BufferizationAliasInfo::applyOnEquivalenceClass(
|
|
||||||
Value v, function_ref<void(Value)> fun) const {
|
|
||||||
auto leaderIt = equivalentInfo.findLeader(v);
|
|
||||||
for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
|
|
||||||
++mit) {
|
|
||||||
fun(*mit);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Apply `fun` to all aliases of `v`.
|
|
||||||
void BufferizationAliasInfo::applyOnAliases(
|
|
||||||
Value v, function_ref<void(Value)> fun) const {
|
|
||||||
auto leaderIt = aliasInfo.findLeader(v);
|
|
||||||
for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
|
|
||||||
fun(*mit);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
BufferizationAliasInfo::EquivalenceClassRangeType
|
|
||||||
BufferizationAliasInfo::getAliases(Value v) const {
|
|
||||||
DenseSet<Value> res;
|
|
||||||
auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v));
|
|
||||||
for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end();
|
|
||||||
mit != meit; ++mit) {
|
|
||||||
res.insert(static_cast<Value>(*mit));
|
|
||||||
}
|
|
||||||
return BufferizationAliasInfo::EquivalenceClassRangeType(
|
|
||||||
aliasInfo.member_begin(it), aliasInfo.member_end());
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Helper functions for BufferizableOpInterface
|
// Helper functions for BufferizableOpInterface
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -291,28 +202,6 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
|
|||||||
const BufferizationOptions &options)
|
const BufferizationOptions &options)
|
||||||
: options(options) {}
|
: options(options) {}
|
||||||
|
|
||||||
mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState::
|
|
||||||
AnalysisBufferizationState(Operation *op,
|
|
||||||
const AnalysisBufferizationOptions &options)
|
|
||||||
: BufferizationState(options), aliasInfo(op) {
|
|
||||||
// Set up alias sets for OpResults that must bufferize in-place. This should
|
|
||||||
// be done before making any other bufferization decisions.
|
|
||||||
op->walk([&](BufferizableOpInterface bufferizableOp) {
|
|
||||||
if (!options.isOpAllowed(bufferizableOp))
|
|
||||||
return WalkResult::skip();
|
|
||||||
for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
|
|
||||||
if (opOperand.get().getType().isa<TensorType>())
|
|
||||||
if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) {
|
|
||||||
if (OpResult opResult =
|
|
||||||
bufferizableOp.getAliasingOpResult(opOperand, *this))
|
|
||||||
aliasInfo.unionAliasSets(opOperand.get(), opResult);
|
|
||||||
aliasInfo.markInPlace(opOperand);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return WalkResult::advance();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// bufferization.to_memref is not allowed to change the rank.
|
// bufferization.to_memref is not allowed to change the rank.
|
||||||
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
|
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
@ -602,16 +491,6 @@ bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) {
|
|||||||
return isa<FuncOp>(bbArg.getOwner()->getParentOp());
|
return isa<FuncOp>(bbArg.getOwner()->getParentOp());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState::
|
|
||||||
isInPlace(OpOperand &opOperand) const {
|
|
||||||
return aliasInfo.isInPlace(opOperand);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState::
|
|
||||||
areEquivalentBufferizedValues(Value v1, Value v2) const {
|
|
||||||
return aliasInfo.areEquivalentBufferizedValues(v1, v2);
|
|
||||||
}
|
|
||||||
|
|
||||||
MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType(
|
MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType(
|
||||||
ShapedType shapedType, MemRefLayoutAttrInterface layout,
|
ShapedType shapedType, MemRefLayoutAttrInterface layout,
|
||||||
Attribute memorySpace) {
|
Attribute memorySpace) {
|
||||||
|
@ -48,6 +48,7 @@ add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
|
|||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRBufferizableOpInterface
|
MLIRBufferizableOpInterface
|
||||||
|
MLIRComprehensiveBufferize
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRLinalg
|
MLIRLinalg
|
||||||
MLIRTensor
|
MLIRTensor
|
||||||
@ -58,6 +59,7 @@ add_mlir_dialect_library(MLIRSCFBufferizableOpInterfaceImpl
|
|||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRBufferizableOpInterface
|
MLIRBufferizableOpInterface
|
||||||
|
MLIRComprehensiveBufferize
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRSCF
|
MLIRSCF
|
||||||
)
|
)
|
||||||
|
@ -98,6 +98,129 @@ static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
|
|||||||
OpBuilder(op).getStrArrayAttr(inPlaceVector));
|
OpBuilder(op).getStrArrayAttr(inPlaceVector));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// BufferizationAliasInfo
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
|
||||||
|
rootOp->walk([&](Operation *op) {
|
||||||
|
for (Value v : op->getResults())
|
||||||
|
if (v.getType().isa<TensorType>())
|
||||||
|
createAliasInfoEntry(v);
|
||||||
|
for (Region &r : op->getRegions())
|
||||||
|
for (Block &b : r.getBlocks())
|
||||||
|
for (auto bbArg : b.getArguments())
|
||||||
|
if (bbArg.getType().isa<TensorType>())
|
||||||
|
createAliasInfoEntry(bbArg);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
|
||||||
|
/// beginning the alias and equivalence sets only contain `v` itself.
|
||||||
|
void BufferizationAliasInfo::createAliasInfoEntry(Value v) {
|
||||||
|
aliasInfo.insert(v);
|
||||||
|
equivalentInfo.insert(v);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Insert an info entry for `newValue` and merge its alias set with that of
|
||||||
|
/// `alias`.
|
||||||
|
void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
|
||||||
|
createAliasInfoEntry(newValue);
|
||||||
|
aliasInfo.unionSets(newValue, alias);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Insert an info entry for `newValue` and merge its alias set with that of
|
||||||
|
/// `alias`. Additionally, merge their equivalence classes.
|
||||||
|
void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
|
||||||
|
Value alias) {
|
||||||
|
insertNewBufferAlias(newValue, alias);
|
||||||
|
equivalentInfo.unionSets(newValue, alias);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return `true` if a value was marked as in-place bufferized.
|
||||||
|
bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const {
|
||||||
|
return inplaceBufferized.contains(&operand);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the inPlace bufferization spec to true.
|
||||||
|
void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand,
|
||||||
|
BufferizationState &state) {
|
||||||
|
markInPlace(operand);
|
||||||
|
if (OpResult result = state.getAliasingOpResult(operand))
|
||||||
|
aliasInfo.unionSets(result, operand.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the inPlace bufferization spec to false.
|
||||||
|
void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) {
|
||||||
|
assert(!inplaceBufferized.contains(&operand) &&
|
||||||
|
"OpOperand was already decided to bufferize inplace");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Apply `fun` to all the members of the equivalence class of `v`.
|
||||||
|
void BufferizationAliasInfo::applyOnEquivalenceClass(
|
||||||
|
Value v, function_ref<void(Value)> fun) const {
|
||||||
|
auto leaderIt = equivalentInfo.findLeader(v);
|
||||||
|
for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
|
||||||
|
++mit) {
|
||||||
|
fun(*mit);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Apply `fun` to all aliases of `v`.
|
||||||
|
void BufferizationAliasInfo::applyOnAliases(
|
||||||
|
Value v, function_ref<void(Value)> fun) const {
|
||||||
|
auto leaderIt = aliasInfo.findLeader(v);
|
||||||
|
for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
|
||||||
|
fun(*mit);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
BufferizationAliasInfo::EquivalenceClassRangeType
|
||||||
|
BufferizationAliasInfo::getAliases(Value v) const {
|
||||||
|
DenseSet<Value> res;
|
||||||
|
auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v));
|
||||||
|
for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end();
|
||||||
|
mit != meit; ++mit) {
|
||||||
|
res.insert(static_cast<Value>(*mit));
|
||||||
|
}
|
||||||
|
return BufferizationAliasInfo::EquivalenceClassRangeType(
|
||||||
|
aliasInfo.member_begin(it), aliasInfo.member_end());
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AnalysisBufferizationState
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
AnalysisBufferizationState::AnalysisBufferizationState(
|
||||||
|
Operation *op, const AnalysisBufferizationOptions &options)
|
||||||
|
: BufferizationState(options), aliasInfo(op) {
|
||||||
|
// Set up alias sets for OpResults that must bufferize in-place. This should
|
||||||
|
// be done before making any other bufferization decisions.
|
||||||
|
op->walk([&](BufferizableOpInterface bufferizableOp) {
|
||||||
|
if (!options.isOpAllowed(bufferizableOp))
|
||||||
|
return WalkResult::skip();
|
||||||
|
for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
|
||||||
|
if (opOperand.get().getType().isa<TensorType>())
|
||||||
|
if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) {
|
||||||
|
if (OpResult opResult =
|
||||||
|
bufferizableOp.getAliasingOpResult(opOperand, *this))
|
||||||
|
aliasInfo.unionAliasSets(opOperand.get(), opResult);
|
||||||
|
aliasInfo.markInPlace(opOperand);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return WalkResult::advance();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AnalysisBufferizationState::isInPlace(OpOperand &opOperand) const {
|
||||||
|
return aliasInfo.isInPlace(opOperand);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AnalysisBufferizationState::areEquivalentBufferizedValues(Value v1,
|
||||||
|
Value v2) const {
|
||||||
|
return aliasInfo.areEquivalentBufferizedValues(v1, v2);
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Bufferization-specific alias analysis.
|
// Bufferization-specific alias analysis.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -6601,6 +6601,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":BufferizableOpInterface",
|
":BufferizableOpInterface",
|
||||||
":BufferizationDialect",
|
":BufferizationDialect",
|
||||||
|
":ComprehensiveBufferize",
|
||||||
":IR",
|
":IR",
|
||||||
":LinalgOps",
|
":LinalgOps",
|
||||||
":LinalgStructuredOpsIncGen",
|
":LinalgStructuredOpsIncGen",
|
||||||
@ -6620,6 +6621,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":BufferizableOpInterface",
|
":BufferizableOpInterface",
|
||||||
":BufferizationDialect",
|
":BufferizationDialect",
|
||||||
|
":ComprehensiveBufferize",
|
||||||
":IR",
|
":IR",
|
||||||
":SCFDialect",
|
":SCFDialect",
|
||||||
":Support",
|
":Support",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user