mirror of https://github.com/capstone-engine/llvm-capstone.git
synced 2024-12-04 03:44:59 +00:00
[mlir][vector] Cleanup VectorUnroll and create a generic tile iteration utility

This change refactors some of the utilities used to unroll larger vector
computations into smaller vector computations. The indexing computations used
here are rather generic and are useful in other dialects and downstream
projects. Therefore, a utility for iterating over all possible tile offsets
for a particular pair of static (shape, tiled shape) is introduced in
IndexingUtils and replaces the existing computations in the vector unrolling
transformations. This builds on the refactoring of IndexingUtils introduced
in 203fad476b.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D150000
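For orientation, here is a minimal sketch of how the new range is meant to be driven (a hypothetical standalone snippet, not part of this commit; it mirrors the unit tests added below):

```cpp
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/Support/raw_ostream.h"

// Walk all 2x4 tiles of a 4x8 shape in the default (identity) loop order.
// Prints: {0, 0} {0, 4} {2, 0} {2, 4}
void printTileOffsets() {
  for (llvm::SmallVector<int64_t> offsets :
       mlir::StaticTileOffsetRange({4, 8}, {2, 4}))
    llvm::outs() << "{" << offsets[0] << ", " << offsets[1] << "} ";
}
```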
This commit is contained in:
parent ed4daeaa13
commit 831041be79
mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -18,7 +18,9 @@
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/iterator.h"
 #include <optional>
+#include <utility>

 namespace mlir {
 class ArrayAttr;
@@ -195,6 +197,23 @@ SmallVector<AffineExpr> delinearize(AffineExpr linearIndex,
 // Permutation utils.
 //===----------------------------------------------------------------------===//

+template <typename T>
+SmallVector<T> applyPermutation(ArrayRef<T> input,
+                                ArrayRef<int64_t> permutation) {
+  assert(input.size() == permutation.size() &&
+         "expected input rank to equal permutation rank");
+  auto permutationRange = llvm::map_range(
+      llvm::seq<unsigned>(0, input.size()),
+      [&](int64_t idx) -> T { return input[permutation[idx]]; });
+  return llvm::to_vector(permutationRange);
+}
+
+template <typename T>
+SmallVector<T> applyPermutation(const SmallVectorImpl<T> &input,
+                                ArrayRef<int64_t> permutation) {
+  return applyPermutation(ArrayRef(input), permutation);
+}
+
 /// Apply the permutation defined by `permutation` to `inVec`.
 /// Element `i` in `inVec` is mapped to location `j = permutation[i]`.
 /// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation
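A quick illustration of the gather semantics of the new helper (hypothetical values): the result at position `i` is `input[permutation[i]]`.

```cpp
// Sketch: permutation {2, 0, 1} gathers input[2], input[0], input[1].
SmallVector<char> permuted = applyPermutation(
    ArrayRef<char>{'a', 'b', 'c'}, ArrayRef<int64_t>{2, 0, 1});
// permuted == {'c', 'a', 'b'}
```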
@@ -203,10 +222,7 @@ SmallVector<AffineExpr> delinearize(AffineExpr linearIndex,
 template <typename T, unsigned N>
 void applyPermutationToVector(SmallVector<T, N> &inVec,
                               ArrayRef<int64_t> permutation) {
-  SmallVector<T, N> auxVec(inVec.size());
-  for (const auto &en : enumerate(permutation))
-    auxVec[en.index()] = inVec[en.value()];
-  inVec = auxVec;
+  inVec = applyPermutation(inVec, permutation);
 }

 /// Helper method to apply to inverse a permutation.
@@ -239,6 +255,138 @@ std::pair<AffineExpr, SmallVector<OpFoldResult>>
 computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<OpFoldResult> strides,
                    ArrayRef<OpFoldResult> indices);

+//===----------------------------------------------------------------------===//
+// Utilities for decomposing larger shapes
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+/// Encapsulates the set of parameters that are used to make tile offset
+/// calculations in the TileOffsetRangeIterator.
+class TileOffsetRangeImpl {
+public:
+  TileOffsetRangeImpl(ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
+                      ArrayRef<int64_t> loopOrder);
+
+  int64_t getMaxLinearIndex() const { return maxLinearIndex; }
+
+  SmallVector<int64_t> getStaticTileOffsets(int64_t linearIndex) const;
+
+  SmallVector<AffineExpr> getDynamicTileOffsets(AffineExpr linearIndex) const;
+
+  template <typename T>
+  SmallVector<T> getTileOffsets(T linearIndex) const {
+    if constexpr (std::is_same_v<T, int64_t>)
+      return getStaticTileOffsets(linearIndex);
+    else
+      return getDynamicTileOffsets(linearIndex);
+  }
+
+private:
+  /// The sub-shape that divides the larger outer shape (which is provided to
+  /// the constructor).
+  SmallVector<int64_t> tileShape;
+  /// The inverse permutation to the `loopOrder` permutation provided in the
+  /// constructor.
+  SmallVector<int64_t> inverseLoopOrder;
+  /// The strides for the basis 'div(shape, tileShape)' permuted by `loopOrder`.
+  SmallVector<int64_t> sliceStrides;
+  /// The maximum linear index in the iteration space given by basis 'div(shape,
+  /// tileShape)'.
+  int64_t maxLinearIndex;
+};
+
+/// The STL-style iterator implementation for StaticTileOffsetRange.
+template <typename ElementType>
+class TileOffsetRangeIterator
+    : public llvm::iterator_facade_base<TileOffsetRangeIterator<ElementType>,
+                                        std::forward_iterator_tag,
+                                        SmallVector<ElementType>> {
+public:
+  TileOffsetRangeIterator(const TileOffsetRangeImpl &params, ElementType index)
+      : params(params), index(index) {}
+
+  void operator++() { incrementIndex(1); }
+  TileOffsetRangeIterator operator++(int) {
+    const auto copy = *this;
+    ++*this;
+    return copy;
+  }
+
+  bool operator==(const TileOffsetRangeIterator &other) const {
+    return index == other.index;
+  }
+  bool operator!=(const TileOffsetRangeIterator &other) const {
+    return index != other.index;
+  }
+
+  SmallVector<ElementType> operator*() const {
+    return params.getTileOffsets(index);
+  }
+  void operator+=(int64_t offset) { incrementIndex(offset); }
+
+private:
+  void incrementIndex(int64_t offset) { index = index + offset; }
+  const TileOffsetRangeImpl params;
+  int64_t index;
+};
+} // namespace detail
+
+/// A range-style iterator that allows for iterating over the offsets of all
+/// potential tiles of size `tileShape` within the larger shape `shape`, using
+/// an ordering specified by `loopOrder`. The `loopOrder` specifies the order of
+/// unrolling by numbering the dimensions in order from "outer most for loop"
+/// (slowest changing) to "inner most for loop" (fastest changing).
+///
+/// For example, for `shape = {10, 20, 30}`, `tileShape = {5, 10, 15}`, and
+/// `loopOrder = {2, 0, 1}`, iterating over this range will yield offsets:
+///
+/// ```
+/// {0, 0, 0}, {0, 10, 0}, {5, 0, 0}, {5, 10, 0}, {0, 0, 15},
+/// {0, 10, 15}, {5, 0, 15}, {5, 10, 15}
+/// ```
+///
+/// This is useful in contexts where a vector computation over a larger shape
+/// needs to be unrolled to a set of operations on subsets of the original
+/// operands, such as during the "vector unrolling" transformations.
+///
+/// The size of `tileShape` must be less than or equal to the size of `shape`.
+/// If the rank of `tileShape` is smaller than `shape`, then `tileShape`
+/// elements correspond to the trailing dimensions of `shape`, the leading
+/// dimensions are considered untiled, and `tileShape` is effectively prepended
+/// with the leading dims of `shape`.
+class StaticTileOffsetRange {
+public:
+  using IteratorTy = detail::TileOffsetRangeIterator<int64_t>;
+  using ParamsTy = detail::TileOffsetRangeImpl;
+
+  StaticTileOffsetRange(ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
+                        ArrayRef<int64_t> loopOrder)
+      : params(shape, tileShape, loopOrder), beginValue(params, 0),
+        pastEndValue(params, params.getMaxLinearIndex()) {
+    assert(shape.size() >= tileShape.size());
+    assert(loopOrder.size() == shape.size());
+  }
+
+  /// Create the range with identity loop order.
+  StaticTileOffsetRange(ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape)
+      : params(shape, tileShape,
+               llvm::to_vector(llvm::seq<int64_t>(0, shape.size()))),
+        beginValue(params, 0),
+        pastEndValue(params, params.getMaxLinearIndex()) {
+    assert(shape.size() >= tileShape.size());
+  }
+
+  IteratorTy begin() const { return beginValue; }
+  IteratorTy end() const { return pastEndValue; }
+
+  /// Returns the total number of tiles that fit in the larger shape.
+  size_t size() const { return params.getMaxLinearIndex(); }
+
+private:
+  const ParamsTy params;
+  IteratorTy beginValue;
+  IteratorTy pastEndValue;
+};
 } // namespace mlir

 #endif // MLIR_DIALECT_UTILS_INDEXINGUTILS_H
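A short sketch of the rank-extension behavior described in the doc comment above (assumed from that comment and the `checkLeadingOneFill` test added at the end of this commit):

```cpp
// Tiling a 4x8 shape with a rank-1 tile <4>: the tile shape is treated as
// <1x4>, so the range yields {0, 0}, {0, 4}, {1, 0}, {1, 4}, {2, 0}, ...
for (llvm::SmallVector<int64_t> offsets :
     mlir::StaticTileOffsetRange({4, 8}, {4})) {
  // offsets[0] steps by 1 over rows, offsets[1] steps by 4 over columns.
}
```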
mlir/include/mlir/IR/AffineExpr.h
@@ -17,6 +17,7 @@
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Casting.h"
 #include <functional>
 #include <type_traits>
@@ -250,6 +251,8 @@
 AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context);
 AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context);
 AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context);
+SmallVector<AffineExpr> getAffineConstantExprs(ArrayRef<int64_t> constants,
+                                               MLIRContext *context);
 AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
                                  AffineExpr rhs);
mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -181,9 +181,8 @@ AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,

 AffineExpr mlir::linearize(MLIRContext *ctx, ArrayRef<AffineExpr> offsets,
                            ArrayRef<int64_t> basis) {
-  SmallVector<AffineExpr> basisExprs = llvm::to_vector(llvm::map_range(
-      basis, [ctx](int64_t v) { return getAffineConstantExpr(v, ctx); }));
-  return linearize(ctx, offsets, basisExprs);
+  return linearize(ctx, offsets, getAffineConstantExprs(basis, ctx));
 }

 SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
@@ -196,9 +195,7 @@ SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,

 SmallVector<AffineExpr> mlir::delinearize(AffineExpr linearIndex,
                                           ArrayRef<int64_t> strides) {
   MLIRContext *ctx = linearIndex.getContext();
-  SmallVector<AffineExpr> basisExprs = llvm::to_vector(llvm::map_range(
-      strides, [ctx](int64_t v) { return getAffineConstantExpr(v, ctx); }));
-  return delinearize(linearIndex, ArrayRef<AffineExpr>{basisExprs});
+  return delinearize(linearIndex, getAffineConstantExprs(strides, ctx));
 }

 //===----------------------------------------------------------------------===//
@@ -302,3 +299,56 @@ mlir::computeLinearIndex(OpFoldResult sourceOffset,

   return {expr, values};
 }
+
+//===----------------------------------------------------------------------===//
+// TileOffsetRange
+//===----------------------------------------------------------------------===//
+
+/// Apply left-padding by 1 to the tile shape if required.
+static SmallVector<int64_t> padTileShapeToSize(ArrayRef<int64_t> tileShape,
+                                               unsigned paddedSize) {
+  assert(tileShape.size() <= paddedSize &&
+         "expected tileShape to <= paddedSize");
+  if (tileShape.size() == paddedSize)
+    return to_vector(tileShape);
+  SmallVector<int64_t> result(paddedSize - tileShape.size(), 1);
+  llvm::append_range(result, tileShape);
+  return result;
+}
+
+mlir::detail::TileOffsetRangeImpl::TileOffsetRangeImpl(
+    ArrayRef<int64_t> shape, ArrayRef<int64_t> tileShape,
+    ArrayRef<int64_t> loopOrder)
+    : tileShape(padTileShapeToSize(tileShape, shape.size())),
+      inverseLoopOrder(invertPermutationVector(loopOrder)),
+      sliceStrides(shape.size()) {
+  // Divide the shape by the tile shape.
+  std::optional<SmallVector<int64_t>> shapeRatio =
+      mlir::computeShapeRatio(shape, tileShape);
+  assert(shapeRatio && shapeRatio->size() == shape.size() &&
+         "target shape does not evenly divide the original shape");
+  assert(isPermutationVector(loopOrder) && loopOrder.size() == shape.size() &&
+         "expected loop order to be a permutation of rank equal to outer "
+         "shape");
+
+  maxLinearIndex = mlir::computeMaxLinearIndex(*shapeRatio);
+  mlir::applyPermutationToVector(*shapeRatio, loopOrder);
+  sliceStrides = mlir::computeStrides(*shapeRatio);
+}
+
+SmallVector<int64_t> mlir::detail::TileOffsetRangeImpl::getStaticTileOffsets(
+    int64_t linearIndex) const {
+  SmallVector<int64_t> tileCoords = applyPermutation(
+      delinearize(linearIndex, sliceStrides), inverseLoopOrder);
+  return computeElementwiseMul(tileCoords, tileShape);
+}
+
+SmallVector<AffineExpr>
+mlir::detail::TileOffsetRangeImpl::getDynamicTileOffsets(
+    AffineExpr linearIndex) const {
+  MLIRContext *ctx = linearIndex.getContext();
+  SmallVector<AffineExpr> tileCoords = applyPermutation(
+      delinearize(linearIndex, sliceStrides), inverseLoopOrder);
+  return mlir::computeElementwiseMul(tileCoords,
+                                     getAffineConstantExprs(tileShape, ctx));
+}
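To make the stride bookkeeping above concrete, here is the same computation carried out by hand for one configuration from the unit tests (a hand-worked trace, not code from the commit):

```cpp
// shape = {8, 4, 2}, tileShape = {4, 2, 1}, loopOrder = {1, 0, 2}:
//   shapeRatio = shape / tileShape               = {2, 2, 2}
//   maxLinearIndex = 2 * 2 * 2                   = 8
//   shapeRatio permuted by loopOrder             = {2, 2, 2}
//   sliceStrides = computeStrides(^)             = {4, 2, 1}
// For a linear index i:
//   delinearize(i, {4, 2, 1})                    = {i/4, (i/2)%2, i%2}
//   applyPermutation(^, inverseLoopOrder {1,0,2}) = {(i/2)%2, i/4, i%2}
//   multiplied elementwise by tileShape {4, 2, 1} = {((i/2)%2)*4, (i/4)*2, i%2}
// e.g. i = 5 yields the tile offset {0, 2, 1}.
```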
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -29,77 +29,6 @@
 using namespace mlir;
 using namespace mlir::vector;

-/// During unrolling from `originalShape` to `targetShape` return the offset for
-/// the slice `index`.
-static SmallVector<int64_t> getVectorOffset(ArrayRef<int64_t> ratioStrides,
-                                            int64_t index,
-                                            ArrayRef<int64_t> targetShape) {
-  return computeElementwiseMul(delinearize(index, ratioStrides), targetShape);
-}
-
-/// A functor that accomplishes the same thing as `getVectorOffset` but
-/// allows for reordering the traversal of the dimensions. The order of
-/// traversal is given in "for loop order" (outer to inner).
-namespace {
-class DecomposeShapeIterator {
-private:
-  SmallVector<int64_t> vectorShape;
-  SmallVector<int64_t> loopOrder;
-  SmallVector<int64_t> sliceStrides;
-  int64_t maxIndexVal{1};
-
-public:
-  DecomposeShapeIterator(ArrayRef<int64_t> originalShape,
-                         ArrayRef<int64_t> targetShape,
-                         ArrayRef<int64_t> loopOrder)
-      : vectorShape(targetShape.begin(), targetShape.end()),
-        loopOrder(loopOrder.begin(), loopOrder.end()),
-        sliceStrides(originalShape.size()) {
-    assert(originalShape.size() >= targetShape.size());
-    assert(loopOrder.size() == originalShape.size());
-
-    // Compute the count for each dimension.
-    auto maybeShapeRatio = computeShapeRatio(originalShape, targetShape);
-    assert(maybeShapeRatio && "Shape does not evenly divide");
-    // Pad `sliceDimCounts` with leading 1s so that all sizes match.
-    SmallVector<int64_t> sliceDimCounts = *maybeShapeRatio;
-    maxIndexVal = computeMaxLinearIndex(sliceDimCounts);
-
-    // Reversing "loop order" gives dimensions from fastest varying to slowest
-    // varying (smallest stride to largest stride).
-    int64_t accum = 1;
-    for (auto idx : llvm::reverse(loopOrder)) {
-      sliceStrides[idx] = accum;
-      accum *= sliceDimCounts[idx];
-    }
-  }
-
-  // Turn the linear index into a d-tuple based on units of vectors of size
-  // `vectorShape`. The linear index is assumed to represent traversal of the
-  // dimensions based on `order`.
-  SmallVector<int64_t> delinearize(int64_t index) const {
-    // Traverse in for loop order (largest stride to smallest stride).
-    SmallVector<int64_t> vectorOffsets(sliceStrides.size());
-    for (auto idx : loopOrder) {
-      vectorOffsets[idx] = index / sliceStrides[idx];
-      index %= sliceStrides[idx];
-    }
-    return vectorOffsets;
-  }
-
-  int64_t maxIndex() const { return maxIndexVal; }
-
-  /// Return the offset within d-tuple based on the ordering given by
-  /// `loopOrder`.
-  SmallVector<int64_t> getVectorOffset(int64_t index) const {
-    SmallVector<int64_t> vectorOffsets = delinearize(index);
-    SmallVector<int64_t> elementOffsets =
-        computeElementwiseMul(vectorShape, vectorOffsets);
-    return elementOffsets;
-  }
-};
-} // namespace
-
 /// Compute the indices of the slice `index` for a transfer op.
 static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
                                                ArrayRef<Value> indices,
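For any downstream user of the removed `DecomposeShapeIterator`, the migration suggested by this diff is mechanical (a before/after sketch; `originalSize`, `targetShape`, and `loopOrder` are the names used by the surrounding patterns):

```cpp
// Before (removed in this commit):
//   DecomposeShapeIterator indexToOffsets(originalSize, *targetShape, loopOrder);
//   for (int64_t i = 0; i < indexToOffsets.maxIndex(); ++i) {
//     SmallVector<int64_t> offsets = indexToOffsets.getVectorOffset(i);
//     ...
//   }
// After:
for (SmallVector<int64_t> offsets :
     StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
  // ... use `offsets` exactly as before ...
}
```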
@@ -232,13 +161,10 @@ struct UnrollTransferReadPattern
         VectorType::get(*targetShape, sourceVectorType.getElementType());
     SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                        readOp.getIndices().end());
-
     SmallVector<int64_t> loopOrder =
         getUnrollOrder(originalSize.size(), readOp, options);
-    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
-                                          loopOrder);
-    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
-      SmallVector<int64_t> elementOffsets = indexToOffsets.getVectorOffset(i);
+    for (SmallVector<int64_t> elementOffsets :
+         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
       SmallVector<Value> indices =
           sliceTransferIndices(elementOffsets, originalIndices,
                                readOp.getPermutationMap(), loc, rewriter);
@@ -283,14 +209,11 @@ struct UnrollTransferWritePattern
     ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
     SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                        writeOp.getIndices().end());
-
     SmallVector<int64_t> loopOrder =
         getUnrollOrder(originalSize.size(), writeOp, options);
-    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
-                                          loopOrder);
     Value resultTensor;
-    for (int64_t i = 0; i < indexToOffsets.maxIndex(); i++) {
-      SmallVector<int64_t> elementOffsets = indexToOffsets.getVectorOffset(i);
+    for (SmallVector<int64_t> elementOffsets :
+         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
       Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
           loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
       SmallVector<Value> indices =
@@ -355,11 +278,9 @@ struct UnrollContractionPattern

     SmallVector<int64_t> loopOrder = getUnrollOrder(
         contractOp.getIteratorTypes().size(), contractOp, options);
-    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
-                                          loopOrder);
-    const int64_t sliceCount = indexToOffsets.maxIndex();
-    for (int64_t i = 0; i < sliceCount; i++) {
-      SmallVector<int64_t> offsets = indexToOffsets.getVectorOffset(i);
-
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
       SmallVector<Value> slicesOperands(contractOp.getNumOperands());

       // Helper to compute the new shape of each operand and extract the slice.
@@ -439,22 +360,16 @@ struct UnrollMultiReductionPattern
     if (!targetShape)
       return failure();
     SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
-    SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
     llvm::MapVector<
         SmallVector<int64_t>, Value,
         llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
         accCache;
-    // Compute shape ratio of 'shape' and 'sizes'.
-    int64_t sliceCount = computeMaxLinearIndex(ratio);
     Location loc = reductionOp.getLoc();

-    // Stride of the ratios, this gives us the offsets of sliceCount in a basis
-    // of multiples of the targetShape.
-    auto ratioStrides = computeStrides(ratio);
-    for (int64_t i = 0; i < sliceCount; i++) {
-      SmallVector<int64_t> offsets =
-          getVectorOffset(ratioStrides, i, *targetShape);
-
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalSize, *targetShape)) {
       SmallVector<Value> operands;
       SmallVector<int64_t> operandStrides(offsets.size(), 1);
       Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
@@ -520,8 +435,6 @@ struct UnrollElementwisePattern : public RewritePattern {
     auto dstVecType = cast<VectorType>(op->getResult(0).getType());
     SmallVector<int64_t> originalSize =
         *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
-    SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
-    int64_t sliceCount = computeMaxLinearIndex(ratio);
     Location loc = op->getLoc();
     // Prepare the result vector.
     Value result = rewriter.create<arith::ConstantOp>(
@@ -530,12 +443,9 @@ struct UnrollElementwisePattern : public RewritePattern {
     VectorType newVecType =
         VectorType::get(*targetShape, dstVecType.getElementType());

-    // Stride of the ratios, this gives us the offsets of sliceCount in a basis
-    // of multiples of the targetShape.
-    auto ratioStrides = computeStrides(ratio);
-    for (int64_t i = 0; i < sliceCount; i++) {
-      SmallVector<int64_t> offsets =
-          getVectorOffset(ratioStrides, i, *targetShape);
     // Create the unrolled computation.
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalSize, *targetShape)) {
       SmallVector<Value> extractOperands;
       for (OpOperand &operand : op->getOpOperands()) {
         auto vecType = dyn_cast<VectorType>(operand.get().getType());
@@ -574,19 +484,12 @@ struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
     if (!targetShape)
       return failure();
     SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
-    auto ratio = *computeShapeRatio(originalSize, *targetShape);
-    int64_t sliceCount = ratio[0];

     // Create unrolled vector reduction.
     Location loc = reductionOp.getLoc();
     Value accumulator = nullptr;
-
-    // Stride of the ratios, this gives us the offsets of sliceCount in a basis
-    // of multiples of the targetShape.
-    auto ratioStrides = computeStrides(ratio);
-    for (int64_t i = 0; i < sliceCount; ++i) {
-      SmallVector<int64_t> offsets =
-          getVectorOffset(ratioStrides, i, *targetShape);
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalSize, *targetShape)) {
       SmallVector<int64_t> strides(offsets.size(), 1);
       Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
           loc, reductionOp.getVector(), offsets, *targetShape, strides);
@@ -630,20 +533,16 @@ struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
     SmallVector<int64_t> strides(targetShape->size(), 1);
     Location loc = transposeOp.getLoc();
     ArrayRef<int64_t> originalSize = originalVectorType.getShape();
-    SmallVector<int64_t> ratio = *computeShapeRatio(originalSize, *targetShape);
-    int64_t sliceCount = computeMaxLinearIndex(ratio);

     // Prepare the result vector;
     Value result = rewriter.create<arith::ConstantOp>(
         loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
     SmallVector<int64_t> permutation;
     transposeOp.getTransp(permutation);

-    // Stride of the ratios, this gives us the offsets of sliceCount in a basis
-    // of multiples of the targetShape.
-    auto ratioStrides = computeStrides(ratio);
-    for (int64_t i = 0; i < sliceCount; i++) {
-      SmallVector<int64_t> elementOffsets =
-          getVectorOffset(ratioStrides, i, *targetShape);
     // Unroll the computation.
+    for (SmallVector<int64_t> elementOffsets :
+         StaticTileOffsetRange(originalSize, *targetShape)) {
       SmallVector<int64_t> permutedOffsets(elementOffsets.size());
       SmallVector<int64_t> permutedShape(elementOffsets.size());
       // Compute the source offsets and shape.
@@ -694,13 +593,11 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {

     SmallVector<int64_t> loopOrder =
         getUnrollOrder(originalSize.size(), gatherOp, options);
-    DecomposeShapeIterator indexToOffsets(originalSize, *targetShape,
-                                          loopOrder);
-    for (int64_t i = 0, e = indexToOffsets.maxIndex(); i < e; ++i) {
+    for (SmallVector<int64_t> elementOffsets :
+         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
       // To get the unrolled gather, extract the same slice based on the
       // decomposed shape from each of the index, mask, and pass-through
       // vectors.
-      SmallVector<int64_t> elementOffsets = indexToOffsets.getVectorOffset(i);
       Value indexSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
           loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
       Value maskSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
mlir/lib/IR/AffineExpr.cpp
@@ -533,6 +533,14 @@ AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
   return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
 }

+SmallVector<AffineExpr>
+mlir::getAffineConstantExprs(ArrayRef<int64_t> constants,
+                             MLIRContext *context) {
+  return llvm::to_vector(llvm::map_range(constants, [&](int64_t constant) {
+    return getAffineConstantExpr(constant, context);
+  }));
+}
+
 /// Simplify add expression. Return nullptr if it can't be simplified.
 static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
mlir/unittests/Dialect/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_unittest(MLIRDialectUtilsTests
   StructuredOpsUtilsTest.cpp
+  IndexingUtilsTest.cpp
 )
 target_link_libraries(MLIRDialectUtilsTests
   PRIVATE
mlir/unittests/Dialect/Utils/IndexingUtilsTest.cpp (new file, 71 lines)
@@ -0,0 +1,71 @@
//===- IndexingUtilsTest.cpp - IndexingUtils unit tests -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "gtest/gtest.h"

using namespace mlir;

TEST(StaticTileOffsetRange, checkIteratorCanonicalOrder) {
  // Tile <4x8> by <2x4> with canonical row-major order.
  std::vector<SmallVector<int64_t>> expected = {{0, 0}, {0, 4}, {2, 0}, {2, 4}};
  for (auto [idx, tileOffset] :
       llvm::enumerate(StaticTileOffsetRange({4, 8}, {2, 4}, {0, 1})))
    EXPECT_EQ(tileOffset, expected[idx]);

  // Check the constructor for default order and test use with zip iterator.
  for (auto [tileOffset, tileOffsetDefault] :
       llvm::zip(StaticTileOffsetRange({4, 8}, {2, 4}, {0, 1}),
                 StaticTileOffsetRange({4, 8}, {2, 4})))
    EXPECT_EQ(tileOffset, tileOffsetDefault);
}

TEST(StaticTileOffsetRange, checkIteratorRowMajorOrder) {
  // Tile <4x8> by <2x4> with loop order {1, 0} (dim 0 varies fastest).
  std::vector<SmallVector<int64_t>> expected = {{0, 0}, {2, 0}, {0, 4}, {2, 4}};
  for (auto [idx, tileOffset] :
       llvm::enumerate(StaticTileOffsetRange({4, 8}, {2, 4}, {1, 0})))
    EXPECT_EQ(tileOffset, expected[idx]);
}

TEST(StaticTileOffsetRange, checkLeadingOneFill) {
  // Tile <4x8> by <4>. A smaller tile shape gets right-aligned to the shape.
  for (auto [idx, tileOffset] :
       llvm::enumerate(StaticTileOffsetRange({4, 8}, {4}))) {
    SmallVector<int64_t> expected = {static_cast<int64_t>(idx) / 2,
                                     static_cast<int64_t>(idx) % 2 * 4};
    EXPECT_EQ(tileOffset, expected);
  }
  for (auto [idx, tileOffset] :
       llvm::enumerate(StaticTileOffsetRange({1, 4, 8}, {4}, {2, 1, 0}))) {
    SmallVector<int64_t> expected = {0, static_cast<int64_t>(idx) % 4,
                                     (static_cast<int64_t>(idx) / 4) * 4};
    EXPECT_EQ(tileOffset, expected);
  }
}

TEST(StaticTileOffsetRange, checkIterator3DPermutation) {
  // Tile <8x4x2> by <4x2x1> with permutation [1, 0, 2].
  for (auto [idx, tileOffset] : llvm::enumerate(
           StaticTileOffsetRange({8, 4, 2}, {4, 2, 1}, {1, 0, 2}))) {
    SmallVector<int64_t> expected = {((static_cast<int64_t>(idx) / 2) % 2) * 4,
                                     ((static_cast<int64_t>(idx) / 4) % 2) * 2,
                                     static_cast<int64_t>(idx) % 2};
    EXPECT_EQ(tileOffset, expected);
  }

  // Tile <10x20x30> by <5x10x15> with permutation [2, 0, 1].
  for (auto [idx, tileOffset] : llvm::enumerate(
           StaticTileOffsetRange({10, 20, 30}, {5, 10, 15}, {2, 0, 1}))) {
    SmallVector<int64_t> expected = {((static_cast<int64_t>(idx) / 2) % 2) * 5,
                                     (static_cast<int64_t>(idx) % 2) * 10,
                                     (static_cast<int64_t>(idx) / 4) % 2 * 15};
    EXPECT_EQ(tileOffset, expected);
  }
}
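One more sanity check that could be layered on top of these tests (a hypothetical addition, not part of this commit): the degenerate case where the tile equals the shape should yield exactly one all-zero offset.

```cpp
TEST(StaticTileOffsetRange, checkSingleTileSmokeTest) {
  // Hypothetical extra test: tiling a shape by itself gives one tile at {0, 0}.
  StaticTileOffsetRange range({4, 8}, {4, 8});
  EXPECT_EQ(range.size(), size_t(1));
  for (SmallVector<int64_t> tileOffset : range)
    EXPECT_EQ(tileOffset, SmallVector<int64_t>({0, 0}));
}
```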