mirror of
https://github.com/capstone-engine/llvm-capstone.git
synced 2025-02-10 11:23:52 +00:00
[mlir][Vector] Modernize default lowering of vector transpose
This patch removes an old recursive implementation that lowers vector.transpose to extract/insert operations and replaces it with an iterative approach that leverages newer linearization/delinearization utilities. The patch should be NFC (no functional change) except for the order in which the extract/insert ops are generated. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D121321
This commit is contained in:
parent
3c9e849943
commit
f71f9958b9
@ -32,19 +32,6 @@ class LinalgDependenceGraph;
|
||||
/// `[0, permutation.size())`.
|
||||
bool isPermutation(ArrayRef<int64_t> permutation);
|
||||
|
||||
/// Apply the permutation defined by `permutation` to `inVec`.
|
||||
/// Result element `i` is the original `inVec[permutation[i]]`.
|
||||
/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
|
||||
/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
|
||||
template <typename T, unsigned N>
|
||||
void applyPermutationToVector(SmallVector<T, N> &inVec,
|
||||
ArrayRef<int64_t> permutation) {
|
||||
SmallVector<T, N> auxVec(inVec.size());
|
||||
for (const auto &en : enumerate(permutation))
|
||||
auxVec[en.index()] = inVec[en.value()];
|
||||
inVec = auxVec;
|
||||
}
|
||||
|
||||
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
|
||||
/// the type of `source`.
|
||||
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
|
||||
|
@ -30,6 +30,19 @@ int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
|
||||
SmallVector<int64_t, 4> delinearize(ArrayRef<int64_t> strides,
|
||||
int64_t linearIndex);
|
||||
|
||||
/// Apply the permutation defined by `permutation` to `inVec`.
|
||||
/// Result element `i` is the original `inVec[permutation[i]]`.
|
||||
/// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector
|
||||
/// `permutation = [2, 0, 1]`, this function leaves `inVec = ['c', 'a', 'b']`.
|
||||
template <typename T, unsigned N>
|
||||
void applyPermutationToVector(SmallVector<T, N> &inVec,
|
||||
ArrayRef<int64_t> permutation) {
|
||||
SmallVector<T, N> auxVec(inVec.size());
|
||||
for (const auto &en : enumerate(permutation))
|
||||
auxVec[en.index()] = inVec[en.value()];
|
||||
inVec = auxVec;
|
||||
}
|
||||
|
||||
/// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
|
||||
SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
|
||||
unsigned dropFront = 0,
|
||||
|
@ -18,6 +18,7 @@
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
@ -17,6 +17,7 @@
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/SCF/Utils/Utils.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
||||
#include "mlir/IR/AsmState.h"
|
||||
|
@ -14,6 +14,7 @@
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
|
@ -20,6 +20,7 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/Transforms.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
@ -19,6 +19,7 @@
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
|
||||
#include "mlir/IR/ImplicitLocOpBuilder.h"
|
||||
@ -300,16 +301,18 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Return the number of leftmost dimensions from the first rightmost transposed
|
||||
/// dimension found in 'transpose'.
|
||||
size_t getNumDimsFromFirstTransposedDim(ArrayRef<int64_t> transpose) {
|
||||
/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
|
||||
/// transposed.
|
||||
void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
|
||||
SmallVectorImpl<int64_t> &result) {
|
||||
size_t numTransposedDims = transpose.size();
|
||||
for (size_t transpDim : llvm::reverse(transpose)) {
|
||||
if (transpDim != numTransposedDims - 1)
|
||||
break;
|
||||
numTransposedDims--;
|
||||
}
|
||||
return numTransposedDims;
|
||||
|
||||
result.append(transpose.begin(), transpose.begin() + numTransposedDims);
|
||||
}
|
||||
|
||||
/// Progressive lowering of TransposeOp.
|
||||
@ -334,6 +337,8 @@ public:
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
|
||||
Value input = op.vector();
|
||||
VectorType inputType = op.getVectorType();
|
||||
VectorType resType = op.getResultType();
|
||||
|
||||
// Set up convenience transposition table.
|
||||
@ -354,7 +359,7 @@ public:
|
||||
Type flattenedType =
|
||||
VectorType::get(resType.getNumElements(), resType.getElementType());
|
||||
auto matrix =
|
||||
rewriter.create<vector::ShapeCastOp>(loc, flattenedType, op.vector());
|
||||
rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
|
||||
auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
|
||||
auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
|
||||
Value trans = rewriter.create<vector::FlatTransposeOp>(
|
||||
@ -365,54 +370,40 @@ public:
|
||||
|
||||
// Generate unrolled extract/insert ops. We do not unroll the rightmost
|
||||
// (i.e., highest-order) dimensions that are not transposed and leave them
|
||||
// in vector form to improve performance.
|
||||
size_t numLeftmostTransposedDims = getNumDimsFromFirstTransposedDim(transp);
|
||||
|
||||
// The type of the extract operation will be scalar if all the dimensions
|
||||
// are unrolled. Otherwise, it will be a vector with the shape of the
|
||||
// dimensions that are not transposed.
|
||||
Type extractType =
|
||||
numLeftmostTransposedDims == transp.size()
|
||||
? resType.getElementType()
|
||||
: VectorType::Builder(resType).setShape(
|
||||
resType.getShape().drop_front(numLeftmostTransposedDims));
|
||||
// in vector form to improve performance. Therefore, we prune those
|
||||
// dimensions from the shape/transpose data structures used to generate the
|
||||
// extract/insert ops.
|
||||
SmallVector<int64_t, 4> prunedTransp;
|
||||
pruneNonTransposedDims(transp, prunedTransp);
|
||||
size_t numPrunedDims = transp.size() - prunedTransp.size();
|
||||
auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
|
||||
SmallVector<int64_t, 4> ones(prunedInShape.size(), 1);
|
||||
auto prunedInStrides = computeStrides(prunedInShape, ones);
|
||||
|
||||
// Generates the extract/insert operations for every scalar/vector element
|
||||
// of the leftmost transposed dimensions. We traverse every transpose
|
||||
// element using a linearized index that we delinearize to generate the
|
||||
// appropriate indices for the extract/insert operations.
|
||||
Value result = rewriter.create<arith::ConstantOp>(
|
||||
loc, resType, rewriter.getZeroAttr(resType));
|
||||
SmallVector<int64_t, 4> lhs(numLeftmostTransposedDims, 0);
|
||||
SmallVector<int64_t, 4> rhs(numLeftmostTransposedDims, 0);
|
||||
rewriter.replaceOp(op, expandIndices(loc, resType, extractType, 0,
|
||||
numLeftmostTransposedDims, transp, lhs,
|
||||
rhs, op.vector(), result, rewriter));
|
||||
int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
|
||||
|
||||
for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
|
||||
++linearIdx) {
|
||||
auto extractIdxs = delinearize(prunedInStrides, linearIdx);
|
||||
SmallVector<int64_t, 4> insertIdxs(extractIdxs);
|
||||
applyPermutationToVector(insertIdxs, prunedTransp);
|
||||
Value extractOp =
|
||||
rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
|
||||
result =
|
||||
rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
// Builds the indices arrays for the lhs and rhs. Generates the extract/insert
|
||||
// operations when all the ranks go over the last dimension being transposed.
|
||||
Value expandIndices(Location loc, VectorType resType, Type extractType,
|
||||
int64_t pos, int64_t numLeftmostTransposedDims,
|
||||
SmallVector<int64_t, 4> &transp,
|
||||
SmallVector<int64_t, 4> &lhs,
|
||||
SmallVector<int64_t, 4> &rhs, Value input, Value result,
|
||||
PatternRewriter &rewriter) const {
|
||||
if (pos >= numLeftmostTransposedDims) {
|
||||
auto ridx = rewriter.getI64ArrayAttr(rhs);
|
||||
auto lidx = rewriter.getI64ArrayAttr(lhs);
|
||||
Value e =
|
||||
rewriter.create<vector::ExtractOp>(loc, extractType, input, ridx);
|
||||
return rewriter.create<vector::InsertOp>(loc, resType, e, result, lidx);
|
||||
}
|
||||
for (int64_t d = 0, e = resType.getDimSize(pos); d < e; ++d) {
|
||||
lhs[pos] = d;
|
||||
rhs[transp[pos]] = d;
|
||||
result = expandIndices(loc, resType, extractType, pos + 1,
|
||||
numLeftmostTransposedDims, transp, lhs, rhs, input,
|
||||
result, rewriter);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Options to control the vector patterns.
|
||||
vector::VectorTransformsOptions vectorTransformOptions;
|
||||
};
|
||||
|
@ -8,14 +8,14 @@
|
||||
// ELTWISE: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
|
||||
// ELTWISE: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32>
|
||||
// ELTWISE: %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32>
|
||||
// ELTWISE: %[[T2:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32>
|
||||
// ELTWISE: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<3x2xf32>
|
||||
// ELTWISE: %[[T4:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32>
|
||||
// ELTWISE: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [1, 0] : f32 into vector<3x2xf32>
|
||||
// ELTWISE: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32>
|
||||
// ELTWISE: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 1] : f32 into vector<3x2xf32>
|
||||
// ELTWISE: %[[T8:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32>
|
||||
// ELTWISE: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [2, 0] : f32 into vector<3x2xf32>
|
||||
// ELTWISE: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32>
|
||||
// ELTWISE: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : f32 into vector<3x2xf32>
|
||||
// ELTWISE: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32>
|
||||
// ELTWISE: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2, 0] : f32 into vector<3x2xf32>
|
||||
// ELTWISE: %[[T6:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32>
|
||||
// ELTWISE: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [0, 1] : f32 into vector<3x2xf32>
|
||||
// ELTWISE: %[[T8:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32>
|
||||
// ELTWISE: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<3x2xf32>
|
||||
// ELTWISE: %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32>
|
||||
// ELTWISE: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32>
|
||||
// ELTWISE: return %[[T11]] : vector<3x2xf32>
|
||||
|
Loading…
x
Reference in New Issue
Block a user