[mlir][DialectUtils] Cleanup IndexingUtils and provide more affine variants while reusing implementations

Differential Revision: https://reviews.llvm.org/D145784
Nicolas Vasilache 2023-03-13 12:24:58 -07:00
parent c113d0b766
commit 203fad476b
11 changed files with 348 additions and 145 deletions


@ -23,35 +23,68 @@
namespace mlir {
class ArrayAttr;
/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
//===----------------------------------------------------------------------===//
// Utils that operate on static integer values.
//===----------------------------------------------------------------------===//
/// Given the strides together with a linear index in the dimension
/// space, returns the vector-space offsets in each dimension for a
/// de-linearized index.
SmallVector<int64_t> delinearize(ArrayRef<int64_t> strides,
int64_t linearIndex);
/// Given a set of sizes, return the suffix product.
///
/// When applied to slicing, this is the calculation needed to derive the
/// strides (i.e. the number of linear indices to skip along the (k-1) most
/// minor dimensions to get the next k-slice).
///
/// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`.
///
/// Assuming `sizes` is `[s0, .. sn]`, return the vector<int64_t>
/// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
///
/// `sizes` elements are asserted to be positive.
///
/// Return an empty vector if `sizes` is empty.
SmallVector<int64_t> computeSuffixProduct(ArrayRef<int64_t> sizes);
inline SmallVector<int64_t> computeStrides(ArrayRef<int64_t> sizes) {
return computeSuffixProduct(sizes);
}
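For illustration (not part of this diff; hypothetical 2x5x8 iteration space):

  SmallVector<int64_t> strides = computeStrides({2, 5, 8});
  // strides == [40, 8, 1]: stepping the outermost index skips 5 * 8 = 40
  // linear elements, the middle index skips 8, the innermost skips 1.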
/// Given a set of sizes, compute and return the strides (i.e. the number of
/// linear indices to skip along the (k-1) most minor dimensions to get the next
/// k-slice). This is also the basis that one can use to linearize an n-D offset
/// confined to `[0 .. sizes]`.
SmallVector<int64_t> computeStrides(ArrayRef<int64_t> sizes);
/// Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
///
/// Return an empty vector if `v1` and `v2` are empty.
SmallVector<int64_t> computeElementwiseMul(ArrayRef<int64_t> v1,
ArrayRef<int64_t> v2);
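A quick sketch with hypothetical values:

  SmallVector<int64_t> prod = computeElementwiseMul({2, 3, 4}, {5, 6, 7});
  // prod == [10, 18, 28]; mismatched sizes assert inside llvm::zip_equal.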
/// Compute and return the multi-dimensional integral ratio of `subShape` to
/// the trailing dimensions of `shape`. This represents how many times
/// `subShape` fits within `shape`.
/// If integral division is not possible, return std::nullopt.
/// Return the number of elements of basis (i.e. the max linear index).
///
/// `basis` elements are asserted to be positive.
///
/// Return `0` if `basis` is empty.
int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis);
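E.g. (hypothetical values):

  int64_t numElements = computeMaxLinearIndex({2, 5, 8}); // 2 * 5 * 8 == 80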
/// Return the linearized index of 'offsets' w.r.t. 'basis'.
///
/// `basis` elements are asserted to be positive.
int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
/// Given the strides together with a linear index in the dimension space,
/// return the vector-space offsets in each dimension for a de-linearized index.
///
/// Let `li = linearIndex`; assuming `strides` is `[s0, .. sn]`, return the
/// vector of int64_t
/// `[li / s0, (li % s0) / s1, ..., (li % s0 % .. % s(n-1)) / sn]`.
///
/// `strides` elements are asserted to be positive.
SmallVector<int64_t> delinearize(int64_t linearIndex,
ArrayRef<int64_t> strides);
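A sketch showing that, for in-bounds offsets, `linearize` and `delinearize`
round-trip (hypothetical values):

  int64_t li = linearize({1, 2, 3}, {40, 8, 1}); // 1*40 + 2*8 + 3 == 59
  SmallVector<int64_t> offsets = delinearize(li, {40, 8, 1}); // [1, 2, 3]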
/// Return the multi-dimensional integral ratio of `subShape` to the trailing
/// dimensions of `shape`. This represents how many times `subShape` fits
/// within `shape`. If integral division is not possible, return std::nullopt.
/// The trailing `subShape.size()` entries of both shapes are assumed (and
/// enforced) to only contain non-negative values.
///
/// Examples:
/// - shapeRatio({3, 5, 8}, {2, 5, 2}) returns {3, 2, 1}.
/// - shapeRatio({3, 8}, {2, 5, 2}) returns std::nullopt (subshape has
///   higher rank).
/// - shapeRatio({42, 2, 10, 32}, {2, 5, 2}) returns {42, 1, 2, 16} which is
/// derived as {42(leading shape dim), 2/2, 10/5, 32/2}.
@ -60,14 +93,96 @@ SmallVector<int64_t> computeElementwiseMul(ArrayRef<int64_t> v1,
std::optional<SmallVector<int64_t>>
computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape);
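Callers handle the non-integral case through the std::optional; a hedged
sketch:

  // How many 2x5x2 subshapes tile the trailing dims of 42x2x10x32?
  std::optional<SmallVector<int64_t>> ratio =
      computeShapeRatio({42, 2, 10, 32}, {2, 5, 2});
  // ratio == [42, 1, 2, 16]; std::nullopt if a trailing dim did not divide.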
//===----------------------------------------------------------------------===//
// Utils that operate on AffineExpr.
//===----------------------------------------------------------------------===//
/// Given a set of sizes, return the suffix product.
///
/// When applied to slicing, this is the calculation needed to derive the
/// strides (i.e. the number of linear indices to skip along the (k-1) most
/// minor dimensions to get the next k-slice).
///
/// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`.
///
/// Assuming `sizes` is `[s0, .. sn]`, return the vector<AffineExpr>
/// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
///
/// It is the caller's responsibility to pass AffineExprs of the proper kind
/// that result in a valid AffineExpr (i.e. one cannot multiply two
/// AffineDimExprs or divide by an AffineDimExpr).
///
/// `sizes` elements are expected to bind to non-negative values.
///
/// Return an empty vector if `sizes` is empty.
SmallVector<AffineExpr> computeSuffixProduct(ArrayRef<AffineExpr> sizes);
inline SmallVector<AffineExpr> computeStrides(ArrayRef<AffineExpr> sizes) {
return computeSuffixProduct(sizes);
}
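A sketch of the symbolic path (illustrative; `bindSymbolsList` lives in
AffineExpr.h and is also touched by this commit):

  MLIRContext ctx;
  SmallVector<AffineExpr> sizes(3);
  bindSymbolsList(&ctx, MutableArrayRef{sizes}); // [s0, s1, s2]
  SmallVector<AffineExpr> strides = computeStrides(sizes);
  // strides ~ [s1 * s2, s2, 1] (up to AffineExpr simplification)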
/// Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
///
/// It is the caller's responsibility to pass AffineExprs of the proper kind
/// that result in a valid AffineExpr (i.e. one cannot multiply two
/// AffineDimExprs or divide by an AffineDimExpr).
///
/// Return an empty vector if `v1` and `v2` are empty.
SmallVector<AffineExpr> computeElementwiseMul(ArrayRef<AffineExpr> v1,
ArrayRef<AffineExpr> v2);
/// Return the number of elements of basis (i.e. the max linear index).
///
/// It is the caller's responsibility to pass AffineExprs of the proper kind
/// that result in a valid AffineExpr (i.e. one cannot multiply two
/// AffineDimExprs or divide by an AffineDimExpr).
///
/// `basis` elements are expected to bind to non-negative values.
///
/// Return the `0` AffineConstantExpr if `basis` is empty.
AffineExpr computeMaxLinearIndex(MLIRContext *ctx, ArrayRef<AffineExpr> basis);
/// Return the linearized index of 'offsets' w.r.t. 'basis'.
///
/// Assuming `offsets` is `[o0, .. on]` and `basis` is `[b0, .. bn]`, return the
/// AffineExpr `o0 * b0 + .. + on * bn`.
///
/// It is the caller's responsibility to pass AffineExprs of the proper kind
/// that result in a valid AffineExpr (i.e. one cannot multiply two
/// AffineDimExprs or divide by an AffineDimExpr).
///
/// `basis` elements are expected to bind to non-negative values.
AffineExpr linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
ArrayRef<AffineExpr> basis);
AffineExpr linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
ArrayRef<int64_t> basis);
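Illustrative use, linearizing dim exprs against constant strides:

  MLIRContext ctx;
  SmallVector<AffineExpr> dims(3);
  bindDimsList(&ctx, MutableArrayRef{dims}); // [d0, d1, d2]
  AffineExpr linear = linearize(&ctx, dims, ArrayRef<int64_t>{40, 8, 1});
  // linear == d0 * 40 + d1 * 8 + d2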
/// Given the strides together with a linear index in the dimension space,
/// return the vector-space offsets in each dimension for a de-linearized index.
///
/// Let `li = linearIndex`; assuming `strides` is `[s0, .. sn]`, return the
/// vector of AffineExpr
/// `[li floordiv s0, (li mod s0) floordiv s1, ...,
///  (li mod s0 mod .. mod s(n-1)) floordiv sn]`.
///
/// It is the caller's responsibility to pass AffineExprs of the proper kind
/// that result in a valid AffineExpr (i.e. one cannot multiply two
/// AffineDimExprs or divide by an AffineDimExpr).
///
/// `strides` elements are expected to bind to non-negative values.
SmallVector<AffineExpr> delinearize(AffineExpr linearIndex,
ArrayRef<AffineExpr> strides);
SmallVector<AffineExpr> delinearize(AffineExpr linearIndex,
ArrayRef<int64_t> strides);
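And the converse, as a sketch:

  MLIRContext ctx;
  AffineExpr d0 = getAffineDimExpr(0, &ctx);
  SmallVector<AffineExpr> offsets = delinearize(d0, ArrayRef<int64_t>{40, 8, 1});
  // offsets ~ [d0 floordiv 40, (d0 mod 40) floordiv 8, (d0 mod 40) mod 8]
  // (up to AffineExpr simplification)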
//===----------------------------------------------------------------------===//
// Permutation utils.
//===----------------------------------------------------------------------===//
/// Apply the permutation defined by `permutation` to `inVec`.
/// The element at position `permutation[i]` in `inVec` is mapped to position
/// `i` in the result.
/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation
/// vector `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a',
/// 'b']`.
template <typename T, unsigned N>
void applyPermutationToVector(SmallVector<T, N> &inVec,
ArrayRef<int64_t> permutation) {
@ -83,18 +198,11 @@ SmallVector<int64_t> invertPermutationVector(ArrayRef<int64_t> permutation);
/// Method to check if an interchange vector is a permutation.
bool isPermutationVector(ArrayRef<int64_t> interchange);
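Sketch of the permutation utilities on hypothetical values:

  SmallVector<char> vec = {'a', 'b', 'c'};
  applyPermutationToVector(vec, {2, 0, 1}); // vec == ['c', 'a', 'b']
  SmallVector<int64_t> inv = invertPermutationVector({2, 0, 1}); // [1, 2, 0]
  bool ok = isPermutationVector({2, 0, 1}); // true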
/// Helper to return a subset of `arrayAttr` as a vector of int64_t.
// TODO: Port everything relevant to DenseArrayAttr and drop this util.
SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
unsigned dropBack = 0);
/// Computes and returns linearized affine expression w.r.t. `basis`.
mlir::AffineExpr getLinearAffineExpr(ArrayRef<int64_t> basis, mlir::Builder &b);
/// Given the strides in the dimension space, returns the affine expressions for
/// vector-space offsets in each dimension for a de-linearized index.
SmallVector<mlir::AffineExpr>
getDelinearizedAffineExpr(ArrayRef<int64_t> strides, mlir::Builder &b);
} // namespace mlir
#endif // MLIR_DIALECT_UTILS_INDEXINGUTILS_H


@ -321,13 +321,6 @@ void bindSymbols(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &...exprs) {
bindSymbols<N + 1, AffineExprTy2 &...>(ctx, exprs...);
}
template <typename AffineExprTy>
void bindSymbolsList(MLIRContext *ctx, SmallVectorImpl<AffineExprTy> &exprs) {
int idx = 0;
for (AffineExprTy &e : exprs)
e = getAffineSymbolExpr(idx++, ctx);
}
} // namespace detail
/// Bind a list of AffineExpr references to DimExpr at positions:
@ -337,6 +330,13 @@ void bindDims(MLIRContext *ctx, AffineExprTy &...exprs) {
detail::bindDims<0>(ctx, exprs...);
}
template <typename AffineExprTy>
void bindDimsList(MLIRContext *ctx, MutableArrayRef<AffineExprTy> exprs) {
int idx = 0;
for (AffineExprTy &e : exprs)
e = getAffineDimExpr(idx++, ctx);
}
/// Bind a list of AffineExpr references to SymbolExpr at positions:
/// [0 .. sizeof...(exprs)]
template <typename... AffineExprTy>
@ -344,6 +344,13 @@ void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs) {
detail::bindSymbols<0>(ctx, exprs...);
}
template <typename AffineExprTy>
void bindSymbolsList(MLIRContext *ctx, MutableArrayRef<AffineExprTy> exprs) {
int idx = 0;
for (AffineExprTy &e : exprs)
e = getAffineSymbolExpr(idx++, ctx);
}
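Usage sketch for the two new list-binding helpers (mirrors the call sites
updated in this commit):

  MLIRContext ctx;
  SmallVector<AffineExpr> dims(2), syms(3);
  bindDimsList(&ctx, MutableArrayRef{dims}); // dims == [d0, d1]
  bindSymbolsList(&ctx, MutableArrayRef{syms}); // syms == [s0, s1, s2]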
} // namespace mlir
namespace llvm {


@ -103,7 +103,7 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
loc, DenseElementsAttr::get(vecType, initValueAttr));
SmallVector<int64_t> strides = computeStrides(shape);
for (int64_t linearIndex = 0; linearIndex < numElements; ++linearIndex) {
SmallVector<int64_t> positions = delinearize(strides, linearIndex);
SmallVector<int64_t> positions = delinearize(linearIndex, strides);
SmallVector<Value> operands;
for (Value input : op->getOperands())
operands.push_back(


@ -89,7 +89,7 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
vecType, FloatAttr::get(vecType.getElementType(), 0.0)));
SmallVector<int64_t> strides = computeStrides(shape);
for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
SmallVector<int64_t> positions = delinearize(strides, linearIndex);
SmallVector<int64_t> positions = delinearize(linearIndex, strides);
SmallVector<Value> operands;
for (auto input : op->getOperands())
operands.push_back(


@ -134,7 +134,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
SmallVector<Value> results(maxIndex);
for (int64_t i = 0; i < maxIndex; ++i) {
auto offsets = delinearize(strides, i);
auto offsets = delinearize(i, strides);
SmallVector<Value> extracted(expandedOperands.size());
for (const auto &tuple : llvm::enumerate(expandedOperands))
@ -152,7 +152,7 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
for (int64_t i = 0; i < maxIndex; ++i)
result = builder.create<vector::InsertOp>(results[i], result,
delinearize(strides, i));
delinearize(i, strides));
// Reshape back to the original vector shape.
return builder.create<vector::ShapeCastOp>(


@ -75,7 +75,7 @@ public:
SmallVector<OpFoldResult> values(2 * sourceRank + 1);
SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
detail::bindSymbolsList(rewriter.getContext(), symbols);
bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
AffineExpr expr = symbols.front();
values[0] = ShapedType::isDynamic(sourceOffset)
? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
@ -262,10 +262,9 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
auto sourceType = source.getType().cast<MemRefType>();
auto [strides, offset] = getStridesAndOffset(sourceType);
OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
                              ? origStrides[groupId]
                              : builder.getIndexAttr(strides[groupId]);
// Apply the original stride to all the strides.
int64_t doneStrideIdx = 0;


@ -18,6 +18,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallBitVector.h"
@ -54,24 +55,25 @@ resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
memref::ExpandShapeOp expandShapeOp,
ValueRange indices,
SmallVectorImpl<Value> &sourceIndices) {
for (SmallVector<int64_t, 2> groups :
expandShapeOp.getReassociationIndices()) {
MLIRContext *ctx = rewriter.getContext();
for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
assert(!groups.empty() && "association indices groups cannot be empty");
unsigned groupSize = groups.size();
SmallVector<int64_t> suffixProduct(groupSize);
// Calculate suffix product of dimension sizes for all dimensions of expand
// shape op result.
suffixProduct[groupSize - 1] = 1;
for (unsigned i = groupSize - 1; i > 0; i--)
suffixProduct[i - 1] =
suffixProduct[i] *
expandShapeOp.getType().cast<MemRefType>().getDimSize(groups[i]);
SmallVector<Value> dynamicIndices(groupSize);
for (unsigned i = 0; i < groupSize; i++)
dynamicIndices[i] = indices[groups[i]];
int64_t groupSize = groups.size();
// Construct the expression for the index value w.r.t. the expand shape op
// source, corresponding to the indices w.r.t. the expand shape op result.
AffineExpr srcIndexExpr = getLinearAffineExpr(suffixProduct, rewriter);
SmallVector<int64_t> sizes(groupSize);
for (int64_t i = 0; i < groupSize; ++i)
sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]);
SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
SmallVector<AffineExpr> dims(groupSize);
bindDimsList(ctx, MutableArrayRef{dims});
AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);
// Apply permutation and create AffineApplyOp.
SmallVector<Value> dynamicIndices(groupSize);
for (int64_t i = 0; i < groupSize; i++)
dynamicIndices[i] = indices[groups[i]];
sourceIndices.push_back(rewriter.create<AffineApplyOp>(
loc,
AffineMap::get(/*numDims=*/groupSize, /*numSymbols=*/0, srcIndexExpr),
@ -98,35 +100,39 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
memref::CollapseShapeOp collapseShapeOp,
ValueRange indices,
SmallVectorImpl<Value> &sourceIndices) {
unsigned cnt = 0;
int64_t cnt = 0;
SmallVector<Value> tmp(indices.size());
SmallVector<Value> dynamicIndices;
for (SmallVector<int64_t, 2> groups :
collapseShapeOp.getReassociationIndices()) {
for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
assert(!groups.empty() && "association indices groups cannot be empty");
dynamicIndices.push_back(indices[cnt++]);
unsigned groupSize = groups.size();
SmallVector<int64_t> suffixProduct(groupSize);
int64_t groupSize = groups.size();
// Calculate suffix product for all collapse op source dimension sizes.
suffixProduct[groupSize - 1] = 1;
for (unsigned i = groupSize - 1; i > 0; i--)
suffixProduct[i - 1] =
suffixProduct[i] * collapseShapeOp.getSrcType().getDimSize(groups[i]);
SmallVector<int64_t> sizes(groupSize);
for (int64_t i = 0; i < groupSize; ++i)
sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
// Derive the index values along all dimensions of the source corresponding
// to the index w.r.t. the collapse shape op output.
SmallVector<AffineExpr, 4> srcIndexExpr =
getDelinearizedAffineExpr(suffixProduct, rewriter);
for (unsigned i = 0; i < groupSize; i++)
auto d0 = rewriter.getAffineDimExpr(0);
SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);
// Construct the AffineApplyOp for each delinearizingExpr.
for (int64_t i = 0; i < groupSize; i++)
sourceIndices.push_back(rewriter.create<AffineApplyOp>(
loc, AffineMap::get(/*numDims=*/1, /*numSymbols=*/0, srcIndexExpr[i]),
loc,
AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
delinearizingExprs[i]),
dynamicIndices));
dynamicIndices.clear();
}
if (collapseShapeOp.getReassociationIndices().empty()) {
auto zeroAffineMap = rewriter.getConstantAffineMap(0);
unsigned srcRank =
int64_t srcRank =
collapseShapeOp.getViewSource().getType().cast<MemRefType>().getRank();
for (unsigned i = 0; i < srcRank; i++)
for (int64_t i = 0; i < srcRank; i++)
sourceIndices.push_back(
rewriter.create<AffineApplyOp>(loc, zeroAffineMap, dynamicIndices));
}
@ -157,9 +163,9 @@ resolveSourceIndicesSubView(Location loc, PatternRewriter &rewriter,
SmallVector<Value> useIndices;
// Check if this is rank-reducing case. Then for every unit-dim size add a
// zero to the indices.
unsigned resultDim = 0;
int64_t resultDim = 0;
llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
for (auto dim : llvm::seq<int64_t>(0, subViewOp.getSourceType().getRank())) {
if (unusedDims.test(dim))
useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
else
@ -171,7 +177,7 @@ resolveSourceIndicesSubView(Location loc, PatternRewriter &rewriter,
for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
SmallVector<Value> dynamicOperands;
AffineExpr expr = rewriter.getAffineDimExpr(0);
unsigned numSymbols = 0;
int64_t numSymbols = 0;
dynamicOperands.push_back(useIndices[index]);
// Multiply the stride.
@ -378,7 +384,7 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
affineMap, indices, loadOp.getLoc(), rewriter);
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
SmallVector<Value, 4> sourceIndices;
SmallVector<Value> sourceIndices;
if (failed(resolveSourceIndicesSubView(loadOp.getLoc(), rewriter, subViewOp,
indices, sourceIndices)))
return failure();
@ -424,7 +430,7 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
affineMap, indices, loadOp.getLoc(), rewriter);
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
SmallVector<Value, 4> sourceIndices;
SmallVector<Value> sourceIndices;
if (failed(resolveSourceIndicesExpandShape(
loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
return failure();
@ -456,7 +462,7 @@ LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
affineMap, indices, loadOp.getLoc(), rewriter);
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
SmallVector<Value, 4> sourceIndices;
SmallVector<Value> sourceIndices;
if (failed(resolveSourceIndicesCollapseShape(
loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
return failure();
@ -488,7 +494,7 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
affineMap, indices, storeOp.getLoc(), rewriter);
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
SmallVector<Value, 4> sourceIndices;
SmallVector<Value> sourceIndices;
if (failed(resolveSourceIndicesSubView(storeOp.getLoc(), rewriter, subViewOp,
indices, sourceIndices)))
return failure();
@ -533,7 +539,7 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
affineMap, indices, storeOp.getLoc(), rewriter);
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
SmallVector<Value, 4> sourceIndices;
SmallVector<Value> sourceIndices;
if (failed(resolveSourceIndicesExpandShape(
storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
return failure();
@ -566,7 +572,7 @@ LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
affineMap, indices, storeOp.getLoc(), rewriter);
indices.assign(expandedIndices.begin(), expandedIndices.end());
}
SmallVector<Value, 4> sourceIndices;
SmallVector<Value> sourceIndices;
if (failed(resolveSourceIndicesCollapseShape(
storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
return failure();


@ -11,27 +11,100 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/STLExtras.h"
#include <numeric>
#include <optional>
using namespace mlir;
SmallVector<int64_t> mlir::computeStrides(ArrayRef<int64_t> sizes) {
SmallVector<int64_t> strides(sizes.size(), 1);
template <typename ExprType>
SmallVector<ExprType> computeSuffixProductImpl(ArrayRef<ExprType> sizes,
ExprType unit) {
if (sizes.empty())
return {};
SmallVector<ExprType> strides(sizes.size(), unit);
for (int64_t r = strides.size() - 2; r >= 0; --r)
strides[r] = strides[r + 1] * sizes[r + 1];
return strides;
}
SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
ArrayRef<int64_t> v2) {
SmallVector<int64_t> result;
for (auto it : llvm::zip(v1, v2))
template <typename ExprType>
SmallVector<ExprType> computeElementwiseMulImpl(ArrayRef<ExprType> v1,
ArrayRef<ExprType> v2) {
// Early exit if both are empty, let zip_equal fail if only 1 is empty.
if (v1.empty() && v2.empty())
return {};
SmallVector<ExprType> result;
for (auto it : llvm::zip_equal(v1, v2))
result.push_back(std::get<0>(it) * std::get<1>(it));
return result;
}
template <typename ExprType>
ExprType linearizeImpl(ArrayRef<ExprType> offsets, ArrayRef<ExprType> basis,
ExprType zero) {
assert(offsets.size() == basis.size());
ExprType linearIndex = zero;
for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
linearIndex = linearIndex + offsets[idx] * basis[idx];
return linearIndex;
}
template <typename ExprType, typename DivOpTy>
SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
ArrayRef<ExprType> strides,
DivOpTy divOp) {
int64_t rank = strides.size();
SmallVector<ExprType> offsets(rank);
for (int64_t r = 0; r < rank; ++r) {
offsets[r] = divOp(linearIndex, strides[r]);
linearIndex = linearIndex % strides[r];
}
return offsets;
}
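The two public delinearize variants below share this template and differ only
in the division functor; an illustrative instantiation for the static case:

  // delinearizeImpl<int64_t>(59, ArrayRef<int64_t>{40, 8, 1},
  //                          [](int64_t e1, int64_t e2) { return e1 / e2; });
  // yields [1, 2, 3]; the AffineExpr variant passes e1.floorDiv(e2) instead.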
//===----------------------------------------------------------------------===//
// Utils that operate on static integer values.
//===----------------------------------------------------------------------===//
SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
assert(llvm::all_of(sizes, [](int64_t s) { return s > 0; }) &&
       "sizes must be positive");
int64_t unit = 1;
return ::computeSuffixProductImpl(sizes, unit);
}
SmallVector<int64_t> mlir::computeElementwiseMul(ArrayRef<int64_t> v1,
ArrayRef<int64_t> v2) {
return computeElementwiseMulImpl(v1, v2);
}
int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
       "basis must be positive");
if (basis.empty())
return 0;
return std::accumulate(basis.begin(), basis.end(), 1,
std::multiplies<int64_t>());
}
int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
       "basis must be positive");
int64_t zero = 0;
return linearizeImpl(offsets, basis, zero);
}
SmallVector<int64_t> mlir::delinearize(int64_t linearIndex,
ArrayRef<int64_t> strides) {
assert(llvm::all_of(strides, [](int64_t s) { return s > 0; }) &&
       "strides must be positive");
return delinearizeImpl(linearIndex, strides,
[](int64_t e1, int64_t e2) { return e1 / e2; });
}
std::optional<SmallVector<int64_t>>
mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) {
if (shape.size() < subShape.size())
@ -60,35 +133,67 @@ mlir::computeShapeRatio(ArrayRef<int64_t> shape, ArrayRef<int64_t> subShape) {
return SmallVector<int64_t>{result.rbegin(), result.rend()};
}
int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
assert(offsets.size() == basis.size());
int64_t linearIndex = 0;
for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
linearIndex += offsets[idx] * basis[idx];
return linearIndex;
//===----------------------------------------------------------------------===//
// Utils that operate on AffineExpr.
//===----------------------------------------------------------------------===//
SmallVector<AffineExpr> mlir::computeSuffixProduct(ArrayRef<AffineExpr> sizes) {
if (sizes.empty())
return {};
AffineExpr unit = getAffineConstantExpr(1, sizes.front().getContext());
return ::computeSuffixProductImpl(sizes, unit);
}
llvm::SmallVector<int64_t> mlir::delinearize(ArrayRef<int64_t> sliceStrides,
int64_t index) {
int64_t rank = sliceStrides.size();
SmallVector<int64_t> vectorOffsets(rank);
for (int64_t r = 0; r < rank; ++r) {
assert(sliceStrides[r] > 0);
vectorOffsets[r] = index / sliceStrides[r];
index %= sliceStrides[r];
}
return vectorOffsets;
SmallVector<AffineExpr> mlir::computeElementwiseMul(ArrayRef<AffineExpr> v1,
ArrayRef<AffineExpr> v2) {
return computeElementwiseMulImpl(v1, v2);
}
int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
AffineExpr mlir::computeMaxLinearIndex(MLIRContext *ctx,
ArrayRef<AffineExpr> basis) {
if (basis.empty())
return 0;
return std::accumulate(basis.begin(), basis.end(), 1,
std::multiplies<int64_t>());
return getAffineConstantExpr(0, ctx);
return std::accumulate(basis.begin(), basis.end(),
getAffineConstantExpr(1, ctx),
std::multiplies<AffineExpr>());
}
AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
ArrayRef<AffineExpr> basis) {
AffineExpr zero = getAffineConstantExpr(0, ctx);
return linearizeImpl(offsets, basis, zero);
}
AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
ArrayRef<int64_t> basis) {
SmallVector<AffineExpr> basisExprs = llvm::to_vector(llvm::map_range(
basis, [ctx](int64_t v) { return getAffineConstantExpr(v, ctx); }));
return linearize(ctx, offsets, basisExprs);
}
SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
ArrayRef<AffineExpr> strides) {
return delinearizeImpl(
linearIndex, strides,
[](AffineExpr e1, AffineExpr e2) { return e1.floorDiv(e2); });
}
SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
ArrayRef<int64_t> strides) {
MLIRContext *ctx = linearIndex.getContext();
SmallVector<AffineExpr> basisExprs = llvm::to_vector(llvm::map_range(
strides, [ctx](int64_t v) { return getAffineConstantExpr(v, ctx); }));
return delinearize(linearIndex, ArrayRef<AffineExpr>{basisExprs});
}
//===----------------------------------------------------------------------===//
// Permutation utils.
//===----------------------------------------------------------------------===//
SmallVector<int64_t>
mlir::invertPermutationVector(ArrayRef<int64_t> permutation) {
assert(llvm::all_of(permutation, [](int64_t s) { return s >= 0; }) &&
"permutation must be non-negative");
SmallVector<int64_t> inversion(permutation.size());
for (const auto &pos : llvm::enumerate(permutation)) {
inversion[pos.value()] = pos.index();
@ -97,6 +202,8 @@ mlir::invertPermutationVector(ArrayRef<int64_t> permutation) {
}
bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
assert(llvm::all_of(interchange, [](int64_t s) { return s >= 0; }) &&
"permutation must be non-negative");
llvm::SmallDenseSet<int64_t, 4> seenVals;
for (auto val : interchange) {
if (seenVals.count(val))
@ -106,9 +213,9 @@ bool mlir::isPermutationVector(ArrayRef<int64_t> interchange) {
return seenVals.size() == interchange.size();
}
SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
                                          unsigned dropFront,
                                          unsigned dropBack) {
assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
auto range = arrayAttr.getAsRange<IntegerAttr>();
SmallVector<int64_t> res;
@ -118,26 +225,3 @@ llvm::SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
res.push_back((*it).getValue().getSExtValue());
return res;
}
mlir::AffineExpr mlir::getLinearAffineExpr(ArrayRef<int64_t> basis,
mlir::Builder &b) {
AffineExpr resultExpr = b.getAffineDimExpr(0);
resultExpr = resultExpr * basis[0];
for (unsigned i = 1; i < basis.size(); i++)
resultExpr = resultExpr + b.getAffineDimExpr(i) * basis[i];
return resultExpr;
}
llvm::SmallVector<mlir::AffineExpr>
mlir::getDelinearizedAffineExpr(mlir::ArrayRef<int64_t> strides, Builder &b) {
AffineExpr resultExpr = b.getAffineDimExpr(0);
int64_t rank = strides.size();
SmallVector<AffineExpr> vectorOffsets(rank);
vectorOffsets[0] = resultExpr.floorDiv(strides[0]);
resultExpr = resultExpr % strides[0];
for (unsigned i = 1; i < rank; i++) {
vectorOffsets[i] = resultExpr.floorDiv(strides[i]);
resultExpr = resultExpr % strides[i];
}
return vectorOffsets;
}


@ -1558,7 +1558,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
}
std::reverse(newStrides.begin(), newStrides.end());
SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
extractOp->setAttr(ExtractOp::getPositionAttrStrName(),


@ -457,7 +457,7 @@ public:
for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
++linearIdx) {
auto extractIdxs = delinearize(prunedInStrides, linearIdx);
auto extractIdxs = delinearize(linearIdx, prunedInStrides);
SmallVector<int64_t> insertIdxs(extractIdxs);
applyPermutationToVector(insertIdxs, prunedTransp);
Value extractOp =
@ -588,8 +588,7 @@ public:
loc, resType, rewriter.getZeroAttr(resType));
for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
auto pos = rewriter.getI64ArrayAttr(d);
Value x = rewriter.create<vector::ExtractOp>(loc, op.getLhs(), pos);
Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
Value r = nullptr;
if (acc)


@ -31,7 +31,7 @@ using namespace mlir::vector;
static SmallVector<int64_t> getVectorOffset(ArrayRef<int64_t> ratioStrides,
int64_t index,
ArrayRef<int64_t> targetShape) {
return computeElementwiseMul(delinearize(ratioStrides, index), targetShape);
return computeElementwiseMul(delinearize(index, ratioStrides), targetShape);
}
/// A functor that accomplishes the same thing as `getVectorOffset` but